summaryrefslogtreecommitdiffhomepage
path: root/device/allowedips_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/allowedips_test.go')
-rw-r--r--device/allowedips_test.go260
1 files changed, 260 insertions, 0 deletions
diff --git a/device/allowedips_test.go b/device/allowedips_test.go
new file mode 100644
index 0000000..075ff06
--- /dev/null
+++ b/device/allowedips_test.go
@@ -0,0 +1,260 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "math/rand"
+ "net"
+ "testing"
+)
+
+/* Todo: More comprehensive
+ */
+
+type testPairCommonBits struct {
+ s1 []byte
+ s2 []byte
+ match uint
+}
+
+type testPairTrieInsert struct {
+ key []byte
+ cidr uint
+ peer *Peer
+}
+
+type testPairTrieLookup struct {
+ key []byte
+ peer *Peer
+}
+
+func printTrie(t *testing.T, p *trieEntry) {
+ if p == nil {
+ return
+ }
+ t.Log(p)
+ printTrie(t, p.child[0])
+ printTrie(t, p.child[1])
+}
+
+func TestCommonBits(t *testing.T) {
+
+ tests := []testPairCommonBits{
+ {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
+ {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
+ {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31},
+ {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15},
+ {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0},
+ }
+
+ for _, p := range tests {
+ v := commonBits(p.s1, p.s2)
+ if v != p.match {
+ t.Error(
+ "For slice", p.s1, p.s2,
+ "expected match", p.match,
+ ",but got", v,
+ )
+ }
+ }
+}
+
+func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
+ var trie *trieEntry
+ var peers []*Peer
+
+ rand.Seed(1)
+
+ const AddressLength = 4
+
+ for n := 0; n < peerNumber; n += 1 {
+ peers = append(peers, &Peer{})
+ }
+
+ for n := 0; n < addressNumber; n += 1 {
+ var addr [AddressLength]byte
+ rand.Read(addr[:])
+ cidr := uint(rand.Uint32() % (AddressLength * 8))
+ index := rand.Int() % peerNumber
+ trie = trie.insert(addr[:], cidr, peers[index])
+ }
+
+ for n := 0; n < b.N; n += 1 {
+ var addr [AddressLength]byte
+ rand.Read(addr[:])
+ trie.lookup(addr[:])
+ }
+}
+
+func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) {
+ benchmarkTrie(100, 1000, net.IPv4len, b)
+}
+
+func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) {
+ benchmarkTrie(10, 10, net.IPv4len, b)
+}
+
+func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) {
+ benchmarkTrie(100, 1000, net.IPv6len, b)
+}
+
+func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) {
+ benchmarkTrie(10, 10, net.IPv6len, b)
+}
+
+/* Test ported from kernel implementation:
+ * selftest/allowedips.h
+ */
+func TestTrieIPv4(t *testing.T) {
+ a := &Peer{}
+ b := &Peer{}
+ c := &Peer{}
+ d := &Peer{}
+ e := &Peer{}
+ g := &Peer{}
+ h := &Peer{}
+
+ var trie *trieEntry
+
+ insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
+ trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
+ }
+
+ assertEQ := func(peer *Peer, a, b, c, d byte) {
+ p := trie.lookup([]byte{a, b, c, d})
+ if p != peer {
+ t.Error("Assert EQ failed")
+ }
+ }
+
+ assertNEQ := func(peer *Peer, a, b, c, d byte) {
+ p := trie.lookup([]byte{a, b, c, d})
+ if p == peer {
+ t.Error("Assert NEQ failed")
+ }
+ }
+
+ insert(a, 192, 168, 4, 0, 24)
+ insert(b, 192, 168, 4, 4, 32)
+ insert(c, 192, 168, 0, 0, 16)
+ insert(d, 192, 95, 5, 64, 27)
+ insert(c, 192, 95, 5, 65, 27)
+ insert(e, 0, 0, 0, 0, 0)
+ insert(g, 64, 15, 112, 0, 20)
+ insert(h, 64, 15, 123, 211, 25)
+ insert(a, 10, 0, 0, 0, 25)
+ insert(b, 10, 0, 0, 128, 25)
+ insert(a, 10, 1, 0, 0, 30)
+ insert(b, 10, 1, 0, 4, 30)
+ insert(c, 10, 1, 0, 8, 29)
+ insert(d, 10, 1, 0, 16, 29)
+
+ assertEQ(a, 192, 168, 4, 20)
+ assertEQ(a, 192, 168, 4, 0)
+ assertEQ(b, 192, 168, 4, 4)
+ assertEQ(c, 192, 168, 200, 182)
+ assertEQ(c, 192, 95, 5, 68)
+ assertEQ(e, 192, 95, 5, 96)
+ assertEQ(g, 64, 15, 116, 26)
+ assertEQ(g, 64, 15, 127, 3)
+
+ insert(a, 1, 0, 0, 0, 32)
+ insert(a, 64, 0, 0, 0, 32)
+ insert(a, 128, 0, 0, 0, 32)
+ insert(a, 192, 0, 0, 0, 32)
+ insert(a, 255, 0, 0, 0, 32)
+
+ assertEQ(a, 1, 0, 0, 0)
+ assertEQ(a, 64, 0, 0, 0)
+ assertEQ(a, 128, 0, 0, 0)
+ assertEQ(a, 192, 0, 0, 0)
+ assertEQ(a, 255, 0, 0, 0)
+
+ trie = trie.removeByPeer(a)
+
+ assertNEQ(a, 1, 0, 0, 0)
+ assertNEQ(a, 64, 0, 0, 0)
+ assertNEQ(a, 128, 0, 0, 0)
+ assertNEQ(a, 192, 0, 0, 0)
+ assertNEQ(a, 255, 0, 0, 0)
+
+ trie = nil
+
+ insert(a, 192, 168, 0, 0, 16)
+ insert(a, 192, 168, 0, 0, 24)
+
+ trie = trie.removeByPeer(a)
+
+ assertNEQ(a, 192, 168, 0, 1)
+}
+
+/* Test ported from kernel implementation:
+ * selftest/allowedips.h
+ */
+func TestTrieIPv6(t *testing.T) {
+ a := &Peer{}
+ b := &Peer{}
+ c := &Peer{}
+ d := &Peer{}
+ e := &Peer{}
+ f := &Peer{}
+ g := &Peer{}
+ h := &Peer{}
+
+ var trie *trieEntry
+
+ expand := func(a uint32) []byte {
+ var out [4]byte
+ out[0] = byte(a >> 24 & 0xff)
+ out[1] = byte(a >> 16 & 0xff)
+ out[2] = byte(a >> 8 & 0xff)
+ out[3] = byte(a & 0xff)
+ return out[:]
+ }
+
+ insert := func(peer *Peer, a, b, c, d uint32, cidr uint) {
+ var addr []byte
+ addr = append(addr, expand(a)...)
+ addr = append(addr, expand(b)...)
+ addr = append(addr, expand(c)...)
+ addr = append(addr, expand(d)...)
+ trie = trie.insert(addr, cidr, peer)
+ }
+
+ assertEQ := func(peer *Peer, a, b, c, d uint32) {
+ var addr []byte
+ addr = append(addr, expand(a)...)
+ addr = append(addr, expand(b)...)
+ addr = append(addr, expand(c)...)
+ addr = append(addr, expand(d)...)
+ p := trie.lookup(addr)
+ if p != peer {
+ t.Error("Assert EQ failed")
+ }
+ }
+
+ insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
+ insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
+ insert(e, 0, 0, 0, 0, 0)
+ insert(f, 0, 0, 0, 0, 0)
+ insert(g, 0x24046800, 0, 0, 0, 32)
+ insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64)
+ insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128)
+ insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
+ insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
+
+ assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543)
+ assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee)
+ assertEQ(f, 0x26075300, 0x60006b01, 0, 0)
+ assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006)
+ assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678)
+ assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678)
+ assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678)
+ assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678)
+ assertEQ(h, 0x24046800, 0x40040800, 0, 0)
+ assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
+ assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
+}