summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--satp/security-association.go7
-rw-r--r--satp/security-association_test.go112
2 files changed, 111 insertions, 8 deletions
diff --git a/satp/security-association.go b/satp/security-association.go
index b004f47..f20e138 100644
--- a/satp/security-association.go
+++ b/satp/security-association.go
@@ -57,14 +57,14 @@ func (sa *SecurityAssociation) KeyGenerate(dir Direction, usage KeyUsage, sequen
func (sa *SecurityAssociation) EndpointUpdate(idx uint, ep *net.UDPAddr) {
if idx >= uint(len(sa.endpoints)) {
- return
+ return // panic???
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&(sa.endpoints[idx]))), unsafe.Pointer(ep))
}
func (sa *SecurityAssociation) EndpointCompareAndUpdate(idx uint, ep *net.UDPAddr) bool {
if idx >= uint(len(sa.endpoints)) {
- return false
+ return false // panic???
}
old := (*net.UDPAddr)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&(sa.endpoints[idx])))))
if !EndpointsEqual(old, ep) {
@@ -96,9 +96,6 @@ func (sa *SecurityAssociation) SequenceNumberCheckAndSet(senderID uint16, sequen
func NewSecurityAssociation(kd KeyDerivation, numEndpoints uint, initialSeqNrOutbound, initialSeqNrInbound uint32) (sa *SecurityAssociation) {
sa = &SecurityAssociation{kd: kd}
sa.endpoints = make([]*net.UDPAddr, numEndpoints)
- for i := range sa.endpoints {
- sa.endpoints[i] = &net.UDPAddr{}
- }
sa.nextSeqNr = initialSeqNrOutbound
sa.initialSeqNrInbound = initialSeqNrInbound
sa.seqWins = &sync.Map{}
diff --git a/satp/security-association_test.go b/satp/security-association_test.go
index 553510c..d2bdb35 100644
--- a/satp/security-association_test.go
+++ b/satp/security-association_test.go
@@ -31,12 +31,118 @@
package satp
import (
+ "crypto/rand"
+ "net"
"testing"
)
func TestSecurityAssociationNew(t *testing.T) {
- sa := NewSecurityAssociation(nil, 0, 0, 0)
- if sa.endpoints == nil {
- t.Fatal("endpoints must not be nil")
+ testvectors := []struct {
+ numEndpoints uint
+ initSeqOut uint32
+ initSeqIn uint32
+ }{
+ {0, 0, 0},
+ {1, 0, 0},
+ {3, 12, 0},
+ {17, 0, 144},
+ {17, 124, 144},
+ }
+
+ for _, vector := range testvectors {
+ sa := NewSecurityAssociation(nil, vector.numEndpoints, vector.initSeqOut, vector.initSeqIn)
+ if sa == nil {
+ t.Fatal("NewSecurityAssociation returned nil")
+ }
+ if sa.endpoints == nil {
+ t.Fatal("endpoints must not be nil")
+ }
+ if uint(len(sa.endpoints)) != vector.numEndpoints {
+ t.Fatalf("wrong number of endpoints is %d but should be %d", len(sa.endpoints), vector.numEndpoints)
+ }
+ }
+}
+
+func TestSecurityAssociationGenerate(t *testing.T) {
+ var keymat [46]byte
+ rand.Read(keymat[:])
+
+ kd, err := NewAESCTRKeyDerivation(keymat[:32], keymat[32:], RoleLeft)
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ sa := NewSecurityAssociation(kd, 1, 0, 0)
+
+ var out [32]byte
+ err = sa.KeyGenerate(Outbound, UsageEncryptKey, 23, out[:32])
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+ err = sa.KeyGenerate(Outbound, UsageEncryptSalt, 23, out[:14])
+ if err != nil {
+ t.Fatal("unexpected error:", err)
}
}
+
+func TestSecurityAssociationEndpointsEqual(t *testing.T) {
+ local4Num444, _ := net.ResolveUDPAddr("udp4", "127.0.0.1:444")
+ local4Name444, _ := net.ResolveUDPAddr("udp4", "localhost:444")
+ local4Num1234, _ := net.ResolveUDPAddr("udp4", "127.0.0.1:1234")
+ global4Num1234, _ := net.ResolveUDPAddr("udp4", "1.2.3.4:1234")
+ global4Num444, _ := net.ResolveUDPAddr("udp4", "1.2.3.4:444")
+ local6Num666, _ := net.ResolveUDPAddr("udp6", "[::1]:666")
+ local6Name666, _ := net.ResolveUDPAddr("udp6", "localhost:666")
+ local6Name1234, _ := net.ResolveUDPAddr("udp6", "localhost:1234")
+ global6Num666, _ := net.ResolveUDPAddr("udp6", "[2a02::1234:1]:666")
+ global6Num1234, _ := net.ResolveUDPAddr("udp6", "[2a02::1234:1]:1234")
+
+ testvectors := []struct {
+ a, b *net.UDPAddr
+ equal bool
+ }{
+ {nil, nil, true},
+ {local4Num444, nil, false},
+ {local4Num444, local4Num444, true},
+ {local4Name444, local4Num444, true},
+ {local4Name444, local4Num1234, false},
+ {global4Num1234, local4Num1234, false},
+ {global4Num1234, global4Num1234, true},
+ {global4Num1234, global4Num444, false},
+ {local6Num666, local6Num666, true},
+ {local6Name666, local6Num666, true},
+ {local6Name666, local6Name1234, false},
+ {global6Num1234, local6Name1234, false},
+ {global6Num1234, global6Num1234, true},
+ {global6Num1234, global6Num666, false},
+ {local4Num1234, local6Name1234, false},
+ }
+
+ for _, vector := range testvectors {
+ if vector.equal {
+ if !EndpointsEqual(vector.a, vector.b) {
+ t.Fatalf("endpoints %v and %v should be equal but aren't", vector.a, vector.b)
+ }
+ } else {
+ if EndpointsEqual(vector.a, vector.b) {
+ t.Fatalf("endpoints %v and %v shouldn't be equal but are", vector.a, vector.b)
+ }
+ }
+ }
+}
+
+// func TestSecurityAssociationEndpointUpdate(t *testing.T) {
+// addr4, err := net.ResolveUDPAddr("udp4", "1.2.3.4:444")
+// if err != nil {
+// t.Fatal("unexpected error:", err)
+// }
+
+// addr6, err := net.ResolveUDPAddr("udp6", "[2a02::1]:666")
+// if err != nil {
+// t.Fatal("unexpected error:", err)
+// }
+// sa := NewSecurityAssociation(nil, 3, 0, 0)
+// sa.EndpointUpdate(0, addr4)
+// sa.EndpointUpdate(2, addr6)
+// t.Logf("%v", sa.endpoints)
+// }