diff options
-rw-r--r-- | device/allowedips.go | 119 | ||||
-rw-r--r-- | device/allowedips_rand_test.go | 12 | ||||
-rw-r--r-- | device/allowedips_test.go | 23 |
3 files changed, 95 insertions, 59 deletions
diff --git a/device/allowedips.go b/device/allowedips.go index 1564d2d..d613121 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -14,9 +14,15 @@ import ( "unsafe" ) +type parentIndirection struct { + parentBit **trieEntry + parentBitType uint8 +} + type trieEntry struct { peer *Peer child [2]*trieEntry + parent parentIndirection cidr uint8 bitAtByte uint8 bitAtShift uint8 @@ -114,43 +120,45 @@ func (node *trieEntry) maskSelf() { } } -func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry { - - // at leaf +func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) { + for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { + parent = node + if parent.cidr == cidr { + exact = true + return + } + bit := node.choose(ip) + node = node.child[bit] + } + return +} - if node == nil { +func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) { + if *trie.parentBit == nil { node := &trieEntry{ - bits: ip, peer: peer, + parent: trie, + bits: ip, cidr: cidr, bitAtByte: cidr / 8, bitAtShift: 7 - (cidr % 8), } node.maskSelf() node.addToPeerEntries() - return node + *trie.parentBit = node + return } - - // traverse deeper - - common := commonBits(node.bits, ip) - if node.cidr <= cidr && common >= node.cidr { - if node.cidr == cidr { - node.removeFromPeerEntries() - node.peer = peer - node.addToPeerEntries() - return node - } - bit := node.choose(ip) - node.child[bit] = node.child[bit].insert(ip, cidr, peer) - return node + node, exact := (*trie.parentBit).nodePlacement(ip, cidr) + if exact { + node.removeFromPeerEntries() + node.peer = peer + node.addToPeerEntries() + return } - // split node - newNode := &trieEntry{ - bits: ip, peer: peer, + bits: ip, cidr: cidr, bitAtByte: cidr / 8, bitAtShift: 7 - (cidr % 8), @@ -158,34 +166,61 @@ func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry { newNode.maskSelf() newNode.addToPeerEntries() + var down *trieEntry + if node == nil { + down = *trie.parentBit + } else { + bit := node.choose(ip) + down = node.child[bit] + if down == nil { + newNode.parent = parentIndirection{&node.child[bit], bit} + node.child[bit] = newNode + return + } + } + common := commonBits(down.bits, ip) if common < cidr { cidr = common } - - // check for shorter prefix + parent := node if newNode.cidr == cidr { - bit := newNode.choose(node.bits) - newNode.child[bit] = node - return newNode + bit := newNode.choose(down.bits) + down.parent = parentIndirection{&newNode.child[bit], bit} + newNode.child[bit] = down + if parent == nil { + newNode.parent = trie + *trie.parentBit = newNode + } else { + bit := parent.choose(newNode.bits) + newNode.parent = parentIndirection{&parent.child[bit], bit} + parent.child[bit] = newNode + } + return } - // create new parent for node & newNode - - parent := &trieEntry{ - bits: append([]byte{}, ip...), - peer: nil, + node = &trieEntry{ + bits: append([]byte{}, newNode.bits...), cidr: cidr, bitAtByte: cidr / 8, bitAtShift: 7 - (cidr % 8), } - parent.maskSelf() - - bit := parent.choose(ip) - parent.child[bit] = newNode - parent.child[bit^1] = node - - return parent + node.maskSelf() + + bit := node.choose(down.bits) + down.parent = parentIndirection{&node.child[bit], bit} + node.child[bit] = down + bit = node.choose(newNode.bits) + newNode.parent = parentIndirection{&node.child[bit], bit} + node.child[bit] = newNode + if parent == nil { + node.parent = trie + *trie.parentBit = node + } else { + bit := parent.choose(node.bits) + node.parent = parentIndirection{&parent.child[bit], bit} + parent.child[bit] = node + } } func (node *trieEntry) lookup(ip net.IP) *Peer { @@ -236,9 +271,9 @@ func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) { switch len(ip) { case net.IPv6len: - table.IPv6 = table.IPv6.insert(ip, cidr, peer) + parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer) case net.IPv4len: - table.IPv4 = table.IPv4.insert(ip, cidr, peer) + parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer) default: panic(errors.New("inserting unknown address type")) } diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go index 2da8795..48a5bcd 100644 --- a/device/allowedips_rand_test.go +++ b/device/allowedips_rand_test.go @@ -65,9 +65,9 @@ func (r SlowRouter) Lookup(addr []byte) *Peer { } func TestTrieRandomIPv4(t *testing.T) { - var trie *trieEntry var slow SlowRouter var peers []*Peer + var allowedIPs AllowedIPs rand.Seed(1) @@ -82,7 +82,7 @@ func TestTrieRandomIPv4(t *testing.T) { rand.Read(addr[:]) cidr := uint8(rand.Uint32() % (AddressLength * 8)) index := rand.Int() % NumberOfPeers - trie = trie.insert(addr[:], cidr, peers[index]) + allowedIPs.Insert(addr[:], cidr, peers[index]) slow = slow.Insert(addr[:], cidr, peers[index]) } @@ -90,7 +90,7 @@ func TestTrieRandomIPv4(t *testing.T) { var addr [AddressLength]byte rand.Read(addr[:]) peer1 := slow.Lookup(addr[:]) - peer2 := trie.lookup(addr[:]) + peer2 := allowedIPs.LookupIPv4(addr[:]) if peer1 != peer2 { t.Error("Trie did not match naive implementation, for:", addr) } @@ -98,9 +98,9 @@ func TestTrieRandomIPv4(t *testing.T) { } func TestTrieRandomIPv6(t *testing.T) { - var trie *trieEntry var slow SlowRouter var peers []*Peer + var allowedIPs AllowedIPs rand.Seed(1) @@ -115,7 +115,7 @@ func TestTrieRandomIPv6(t *testing.T) { rand.Read(addr[:]) cidr := uint8(rand.Uint32() % (AddressLength * 8)) index := rand.Int() % NumberOfPeers - trie = trie.insert(addr[:], cidr, peers[index]) + allowedIPs.Insert(addr[:], cidr, peers[index]) slow = slow.Insert(addr[:], cidr, peers[index]) } @@ -123,7 +123,7 @@ func TestTrieRandomIPv6(t *testing.T) { var addr [AddressLength]byte rand.Read(addr[:]) peer1 := slow.Lookup(addr[:]) - peer2 := trie.lookup(addr[:]) + peer2 := allowedIPs.LookupIPv6(addr[:]) if peer1 != peer2 { t.Error("Trie did not match naive implementation, for:", addr) } diff --git a/device/allowedips_test.go b/device/allowedips_test.go index 8dc8438..cbd32cc 100644 --- a/device/allowedips_test.go +++ b/device/allowedips_test.go @@ -42,6 +42,7 @@ func TestCommonBits(t *testing.T) { func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { var trie *trieEntry var peers []*Peer + root := parentIndirection{&trie, 2} rand.Seed(1) @@ -56,7 +57,7 @@ func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *test rand.Read(addr[:]) cidr := uint8(rand.Uint32() % (AddressLength * 8)) index := rand.Int() % peerNumber - trie = trie.insert(addr[:], cidr, peers[index]) + root.insert(addr[:], cidr, peers[index]) } for n := 0; n < b.N; n++ { @@ -94,21 +95,21 @@ func TestTrieIPv4(t *testing.T) { g := &Peer{} h := &Peer{} - var trie *trieEntry + var allowedIPs AllowedIPs insert := func(peer *Peer, a, b, c, d byte, cidr uint8) { - trie = trie.insert([]byte{a, b, c, d}, cidr, peer) + allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer) } assertEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.lookup([]byte{a, b, c, d}) + p := allowedIPs.LookupIPv4([]byte{a, b, c, d}) if p != peer { t.Error("Assert EQ failed") } } assertNEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.lookup([]byte{a, b, c, d}) + p := allowedIPs.LookupIPv4([]byte{a, b, c, d}) if p == peer { t.Error("Assert NEQ failed") } @@ -150,7 +151,7 @@ func TestTrieIPv4(t *testing.T) { assertEQ(a, 192, 0, 0, 0) assertEQ(a, 255, 0, 0, 0) - trie = trie.removeByPeer(a) + allowedIPs.RemoveByPeer(a) assertNEQ(a, 1, 0, 0, 0) assertNEQ(a, 64, 0, 0, 0) @@ -158,12 +159,12 @@ func TestTrieIPv4(t *testing.T) { assertNEQ(a, 192, 0, 0, 0) assertNEQ(a, 255, 0, 0, 0) - trie = nil + allowedIPs = AllowedIPs{} insert(a, 192, 168, 0, 0, 16) insert(a, 192, 168, 0, 0, 24) - trie = trie.removeByPeer(a) + allowedIPs.RemoveByPeer(a) assertNEQ(a, 192, 168, 0, 1) } @@ -181,7 +182,7 @@ func TestTrieIPv6(t *testing.T) { g := &Peer{} h := &Peer{} - var trie *trieEntry + var allowedIPs AllowedIPs expand := func(a uint32) []byte { var out [4]byte @@ -198,7 +199,7 @@ func TestTrieIPv6(t *testing.T) { addr = append(addr, expand(b)...) addr = append(addr, expand(c)...) addr = append(addr, expand(d)...) - trie = trie.insert(addr, cidr, peer) + allowedIPs.Insert(addr, cidr, peer) } assertEQ := func(peer *Peer, a, b, c, d uint32) { @@ -207,7 +208,7 @@ func TestTrieIPv6(t *testing.T) { addr = append(addr, expand(b)...) addr = append(addr, expand(c)...) addr = append(addr, expand(d)...) - p := trie.lookup(addr) + p := allowedIPs.LookupIPv6(addr) if p != peer { t.Error("Assert EQ failed") } |