From 9001741588299e21b35b187ced9125967d5bcad5 Mon Sep 17 00:00:00 2001 From: Christian Pointner Date: Sat, 4 Nov 2017 17:42:36 +0100 Subject: some tests for SA (cont'd) --- satp/security-association.go | 7 ++++-- satp/security-association_test.go | 48 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/satp/security-association.go b/satp/security-association.go index c5093de..00798f9 100644 --- a/satp/security-association.go +++ b/satp/security-association.go @@ -74,9 +74,12 @@ func (sa *SecurityAssociation) EndpointCompareAndUpdate(idx uint, ep *net.UDPAdd return false } -func (sa *SecurityAssociation) GetEndpointsAndNextSequenceNumber() (seqNum uint32, eps []*net.UDPAddr) { +func (sa *SecurityAssociation) GetEndpointsAndNextSequenceNumber(epsIn []*net.UDPAddr) (seqNum uint32, eps []*net.UDPAddr) { seqNum = atomic.AddUint32(&sa.nextSeqNr, 1) - 1 - eps = make([]*net.UDPAddr, len(sa.endpoints)) + eps = epsIn + if eps == nil { + eps = make([]*net.UDPAddr, len(sa.endpoints)) + } for i := range sa.endpoints { eps[i] = (*net.UDPAddr)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&(sa.endpoints[i]))))) } diff --git a/satp/security-association_test.go b/satp/security-association_test.go index 4bbf7b7..2b1d3cd 100644 --- a/satp/security-association_test.go +++ b/satp/security-association_test.go @@ -159,10 +159,10 @@ func TestSecurityAssociationEndpointUpdate(t *testing.T) { t.Fatalf("endpoints[0] is %v but should be %v", sa.endpoints[0], addr4) } if sa.endpoints[1] != nil { - t.Fatalf("endpoints[1] is %v but should be nil", sa.endpoints[0]) + t.Fatalf("endpoints[1] is %v but should be nil", sa.endpoints[1]) } if !EndpointsEqual(sa.endpoints[2], addr6) { - t.Fatalf("endpoints[2] is %v but should be %v", sa.endpoints[0], addr6) + t.Fatalf("endpoints[2] is %v but should be %v", sa.endpoints[2], addr6) } } @@ -200,3 +200,47 @@ func TestSecurityAssociationEndpointCompareAndUpdate(t *testing.T) { t.Fatalf("updateting endpoint %v with %v should return changed = true", addr4, addr6) } } + +func TestSecurityAssociationGetEndpointsAndNextSequenceNumber(t *testing.T) { + sa := NewSecurityAssociation(nil, 3, 0, 0) + + seq, _ := sa.GetEndpointsAndNextSequenceNumber(nil) + if seq != 0 { + t.Fatalf("next sequnce number returned is %d but should be %d", seq, 0) + } + seq, _ = sa.GetEndpointsAndNextSequenceNumber(nil) + seq, _ = sa.GetEndpointsAndNextSequenceNumber(nil) + seq, _ = sa.GetEndpointsAndNextSequenceNumber(nil) + if seq != 3 { + t.Fatalf("next sequnce number returned is %d but should be %d", seq, 3) + } + + sa = NewSecurityAssociation(nil, 3, (^uint32(0)), 0) + eps := make([]*net.UDPAddr, 3) + for i := range eps { + if eps[i] != nil { + t.Fatalf("endpoints[%d] is %v but should be nil", i) + } + } + seq, _ = sa.GetEndpointsAndNextSequenceNumber(eps) + if seq != (^uint32(0)) { + t.Fatalf("next sequnce number returned is %d but should be %d", seq, (^uint32(0))) + } + addr4, _ := net.ResolveUDPAddr("udp4", "1.2.3.4:444") + sa.EndpointUpdate(0, addr4) + addr6, _ := net.ResolveUDPAddr("udp6", "[2a01:1234::2]:666") + sa.EndpointUpdate(2, addr6) + seq, _ = sa.GetEndpointsAndNextSequenceNumber(eps) + if seq != 0 { + t.Fatalf("next sequnce number returned is %d but should be %d", seq, 0) + } + if !EndpointsEqual(sa.endpoints[0], addr4) { + t.Fatalf("endpoints[0] is %v but should be %v", sa.endpoints[0], addr4) + } + if sa.endpoints[1] != nil { + t.Fatalf("endpoints[1] is %v but should be nil", sa.endpoints[1]) + } + if !EndpointsEqual(sa.endpoints[2], addr6) { + t.Fatalf("endpoints[2] is %v but should be %v", sa.endpoints[2], addr6) + } +} -- cgit v1.2.3