diff options
author | Christian Pointner <equinox@anytun.org> | 2017-11-04 17:42:36 +0100 |
---|---|---|
committer | Christian Pointner <equinox@anytun.org> | 2017-11-04 17:42:36 +0100 |
commit | 9001741588299e21b35b187ced9125967d5bcad5 (patch) | |
tree | 6de627205691e12331e495a87d16f66b28e858c1 /satp | |
parent | some tests for SA (cont'd) (diff) |
some tests for SA (cont'd)
Diffstat (limited to 'satp')
-rw-r--r-- | satp/security-association.go | 7 | ||||
-rw-r--r-- | satp/security-association_test.go | 48 |
2 files changed, 51 insertions, 4 deletions
diff --git a/satp/security-association.go b/satp/security-association.go index c5093de..00798f9 100644 --- a/satp/security-association.go +++ b/satp/security-association.go @@ -74,9 +74,12 @@ func (sa *SecurityAssociation) EndpointCompareAndUpdate(idx uint, ep *net.UDPAdd return false } -func (sa *SecurityAssociation) GetEndpointsAndNextSequenceNumber() (seqNum uint32, eps []*net.UDPAddr) { +func (sa *SecurityAssociation) GetEndpointsAndNextSequenceNumber(epsIn []*net.UDPAddr) (seqNum uint32, eps []*net.UDPAddr) { seqNum = atomic.AddUint32(&sa.nextSeqNr, 1) - 1 - eps = make([]*net.UDPAddr, len(sa.endpoints)) + eps = epsIn + if eps == nil { + eps = make([]*net.UDPAddr, len(sa.endpoints)) + } for i := range sa.endpoints { eps[i] = (*net.UDPAddr)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&(sa.endpoints[i]))))) } diff --git a/satp/security-association_test.go b/satp/security-association_test.go index 4bbf7b7..2b1d3cd 100644 --- a/satp/security-association_test.go +++ b/satp/security-association_test.go @@ -159,10 +159,10 @@ func TestSecurityAssociationEndpointUpdate(t *testing.T) { t.Fatalf("endpoints[0] is %v but should be %v", sa.endpoints[0], addr4) } if sa.endpoints[1] != nil { - t.Fatalf("endpoints[1] is %v but should be nil", sa.endpoints[0]) + t.Fatalf("endpoints[1] is %v but should be nil", sa.endpoints[1]) } if !EndpointsEqual(sa.endpoints[2], addr6) { - t.Fatalf("endpoints[2] is %v but should be %v", sa.endpoints[0], addr6) + t.Fatalf("endpoints[2] is %v but should be %v", sa.endpoints[2], addr6) } } @@ -200,3 +200,47 @@ func TestSecurityAssociationEndpointCompareAndUpdate(t *testing.T) { t.Fatalf("updateting endpoint %v with %v should return changed = true", addr4, addr6) } } + +func TestSecurityAssociationGetEndpointsAndNextSequenceNumber(t *testing.T) { + sa := NewSecurityAssociation(nil, 3, 0, 0) + + seq, _ := sa.GetEndpointsAndNextSequenceNumber(nil) + if seq != 0 { + t.Fatalf("next sequnce number returned is %d but should be %d", seq, 0) + } + seq, _ = sa.GetEndpointsAndNextSequenceNumber(nil) + seq, _ = sa.GetEndpointsAndNextSequenceNumber(nil) + seq, _ = sa.GetEndpointsAndNextSequenceNumber(nil) + if seq != 3 { + t.Fatalf("next sequnce number returned is %d but should be %d", seq, 3) + } + + sa = NewSecurityAssociation(nil, 3, (^uint32(0)), 0) + eps := make([]*net.UDPAddr, 3) + for i := range eps { + if eps[i] != nil { + t.Fatalf("endpoints[%d] is %v but should be nil", i) + } + } + seq, _ = sa.GetEndpointsAndNextSequenceNumber(eps) + if seq != (^uint32(0)) { + t.Fatalf("next sequnce number returned is %d but should be %d", seq, (^uint32(0))) + } + addr4, _ := net.ResolveUDPAddr("udp4", "1.2.3.4:444") + sa.EndpointUpdate(0, addr4) + addr6, _ := net.ResolveUDPAddr("udp6", "[2a01:1234::2]:666") + sa.EndpointUpdate(2, addr6) + seq, _ = sa.GetEndpointsAndNextSequenceNumber(eps) + if seq != 0 { + t.Fatalf("next sequnce number returned is %d but should be %d", seq, 0) + } + if !EndpointsEqual(sa.endpoints[0], addr4) { + t.Fatalf("endpoints[0] is %v but should be %v", sa.endpoints[0], addr4) + } + if sa.endpoints[1] != nil { + t.Fatalf("endpoints[1] is %v but should be nil", sa.endpoints[1]) + } + if !EndpointsEqual(sa.endpoints[2], addr6) { + t.Fatalf("endpoints[2] is %v but should be %v", sa.endpoints[2], addr6) + } +} |