From 99ef9725cd769fc1a46675b594e9114bed734c2c Mon Sep 17 00:00:00 2001 From: Christian Pointner Date: Sat, 4 Nov 2017 16:20:21 +0100 Subject: some tests for SA --- satp/security-association.go | 7 +-- satp/security-association_test.go | 112 +++++++++++++++++++++++++++++++++++++- 2 files changed, 111 insertions(+), 8 deletions(-) (limited to 'satp') diff --git a/satp/security-association.go b/satp/security-association.go index b004f47..f20e138 100644 --- a/satp/security-association.go +++ b/satp/security-association.go @@ -57,14 +57,14 @@ func (sa *SecurityAssociation) KeyGenerate(dir Direction, usage KeyUsage, sequen func (sa *SecurityAssociation) EndpointUpdate(idx uint, ep *net.UDPAddr) { if idx >= uint(len(sa.endpoints)) { - return + return // panic??? } atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&(sa.endpoints[idx]))), unsafe.Pointer(ep)) } func (sa *SecurityAssociation) EndpointCompareAndUpdate(idx uint, ep *net.UDPAddr) bool { if idx >= uint(len(sa.endpoints)) { - return false + return false // panic??? } old := (*net.UDPAddr)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&(sa.endpoints[idx]))))) if !EndpointsEqual(old, ep) { @@ -96,9 +96,6 @@ func (sa *SecurityAssociation) SequenceNumberCheckAndSet(senderID uint16, sequen func NewSecurityAssociation(kd KeyDerivation, numEndpoints uint, initialSeqNrOutbound, initialSeqNrInbound uint32) (sa *SecurityAssociation) { sa = &SecurityAssociation{kd: kd} sa.endpoints = make([]*net.UDPAddr, numEndpoints) - for i := range sa.endpoints { - sa.endpoints[i] = &net.UDPAddr{} - } sa.nextSeqNr = initialSeqNrOutbound sa.initialSeqNrInbound = initialSeqNrInbound sa.seqWins = &sync.Map{} diff --git a/satp/security-association_test.go b/satp/security-association_test.go index 553510c..d2bdb35 100644 --- a/satp/security-association_test.go +++ b/satp/security-association_test.go @@ -31,12 +31,118 @@ package satp import ( + "crypto/rand" + "net" "testing" ) func TestSecurityAssociationNew(t *testing.T) { - sa := NewSecurityAssociation(nil, 0, 0, 0) - if sa.endpoints == nil { - t.Fatal("endpoints must not be nil") + testvectors := []struct { + numEndpoints uint + initSeqOut uint32 + initSeqIn uint32 + }{ + {0, 0, 0}, + {1, 0, 0}, + {3, 12, 0}, + {17, 0, 144}, + {17, 124, 144}, + } + + for _, vector := range testvectors { + sa := NewSecurityAssociation(nil, vector.numEndpoints, vector.initSeqOut, vector.initSeqIn) + if sa == nil { + t.Fatal("NewSecurityAssociation returned nil") + } + if sa.endpoints == nil { + t.Fatal("endpoints must not be nil") + } + if uint(len(sa.endpoints)) != vector.numEndpoints { + t.Fatalf("wrong number of endpoints is %d but should be %d", len(sa.endpoints), vector.numEndpoints) + } + } +} + +func TestSecurityAssociationGenerate(t *testing.T) { + var keymat [46]byte + rand.Read(keymat[:]) + + kd, err := NewAESCTRKeyDerivation(keymat[:32], keymat[32:], RoleLeft) + if err != nil { + t.Fatal("unexpected error:", err) + } + + sa := NewSecurityAssociation(kd, 1, 0, 0) + + var out [32]byte + err = sa.KeyGenerate(Outbound, UsageEncryptKey, 23, out[:32]) + if err != nil { + t.Fatal("unexpected error:", err) + } + err = sa.KeyGenerate(Outbound, UsageEncryptSalt, 23, out[:14]) + if err != nil { + t.Fatal("unexpected error:", err) } } + +func TestSecurityAssociationEndpointsEqual(t *testing.T) { + local4Num444, _ := net.ResolveUDPAddr("udp4", "127.0.0.1:444") + local4Name444, _ := net.ResolveUDPAddr("udp4", "localhost:444") + local4Num1234, _ := net.ResolveUDPAddr("udp4", "127.0.0.1:1234") + global4Num1234, _ := net.ResolveUDPAddr("udp4", "1.2.3.4:1234") + global4Num444, _ := net.ResolveUDPAddr("udp4", "1.2.3.4:444") + local6Num666, _ := net.ResolveUDPAddr("udp6", "[::1]:666") + local6Name666, _ := net.ResolveUDPAddr("udp6", "localhost:666") + local6Name1234, _ := net.ResolveUDPAddr("udp6", "localhost:1234") + global6Num666, _ := net.ResolveUDPAddr("udp6", "[2a02::1234:1]:666") + global6Num1234, _ := net.ResolveUDPAddr("udp6", "[2a02::1234:1]:1234") + + testvectors := []struct { + a, b *net.UDPAddr + equal bool + }{ + {nil, nil, true}, + {local4Num444, nil, false}, + {local4Num444, local4Num444, true}, + {local4Name444, local4Num444, true}, + {local4Name444, local4Num1234, false}, + {global4Num1234, local4Num1234, false}, + {global4Num1234, global4Num1234, true}, + {global4Num1234, global4Num444, false}, + {local6Num666, local6Num666, true}, + {local6Name666, local6Num666, true}, + {local6Name666, local6Name1234, false}, + {global6Num1234, local6Name1234, false}, + {global6Num1234, global6Num1234, true}, + {global6Num1234, global6Num666, false}, + {local4Num1234, local6Name1234, false}, + } + + for _, vector := range testvectors { + if vector.equal { + if !EndpointsEqual(vector.a, vector.b) { + t.Fatalf("endpoints %v and %v should be equal but aren't", vector.a, vector.b) + } + } else { + if EndpointsEqual(vector.a, vector.b) { + t.Fatalf("endpoints %v and %v shouldn't be equal but are", vector.a, vector.b) + } + } + } +} + +// func TestSecurityAssociationEndpointUpdate(t *testing.T) { +// addr4, err := net.ResolveUDPAddr("udp4", "1.2.3.4:444") +// if err != nil { +// t.Fatal("unexpected error:", err) +// } + +// addr6, err := net.ResolveUDPAddr("udp6", "[2a02::1]:666") +// if err != nil { +// t.Fatal("unexpected error:", err) +// } +// sa := NewSecurityAssociation(nil, 3, 0, 0) +// sa.EndpointUpdate(0, addr4) +// sa.EndpointUpdate(2, addr6) +// t.Logf("%v", sa.endpoints) +// } -- cgit v1.2.3