diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-06-03 15:40:09 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-06-03 16:29:43 +0200 |
commit | c382222eab9e3814f4df75fd25f8e9e31484b5e0 (patch) | |
tree | 910b69829baae426668c82c83314dcdd9b208437 | |
parent | b41f4cc768021d68b98fed6ca76e7d20fcc38120 (diff) |
device: remove nodes by peer in O(1) instead of O(n)
Now that we have parent pointers hooked up, we can simply go right to
the node and remove it in place, rather than having to recursively walk
the entire trie.
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r-- | device/allowedips.go | 58 | ||||
-rw-r--r-- | device/allowedips_rand_test.go | 96 |
2 files changed, 82 insertions, 72 deletions
diff --git a/device/allowedips.go b/device/allowedips.go index d613121..7af9fc7 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -85,30 +85,6 @@ func (node *trieEntry) removeFromPeerEntries() { } } -func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { - if node == nil { - return node - } - - // walk recursively - - node.child[0] = node.child[0].removeByPeer(p) - node.child[1] = node.child[1].removeByPeer(p) - - if node.peer != p { - return node - } - - // remove peer & merge - - node.removeFromPeerEntries() - node.peer = nil - if node.child[0] == nil { - return node.child[1] - } - return node.child[0] -} - func (node *trieEntry) choose(ip net.IP) byte { return (ip[node.bitAtByte] >> node.bitAtShift) & 1 } @@ -261,8 +237,38 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() - table.IPv4 = table.IPv4.removeByPeer(peer) - table.IPv6 = table.IPv6.removeByPeer(peer) + var next *list.Element + for elem := peer.trieEntries.Front(); elem != nil; elem = next { + next = elem.Next() + node := elem.Value.(*trieEntry) + + node.removeFromPeerEntries() + node.peer = nil + if node.child[0] != nil && node.child[1] != nil { + continue + } + bit := 0 + if node.child[0] == nil { + bit = 1 + } + child := node.child[bit] + if child != nil { + child.parent = node.parent + } + *node.parent.parentBit = child + if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { + continue + } + parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) + if parent.peer != nil { + continue + } + child = parent.child[node.parent.parentBitType^1] + if child != nil { + child.parent = parent.parent + } + *parent.parent.parentBit = child + } } func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) { diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go index 48a5bcd..c5f80fe 100644 --- a/device/allowedips_rand_test.go +++ b/device/allowedips_rand_test.go @@ -7,6 +7,7 @@ package device import ( "math/rand" + "net" "sort" "testing" ) @@ -64,68 +65,71 @@ func (r SlowRouter) Lookup(addr []byte) *Peer { return nil } -func TestTrieRandomIPv4(t *testing.T) { - var slow SlowRouter - var peers []*Peer - var allowedIPs AllowedIPs - - rand.Seed(1) - - const AddressLength = 4 - - for n := 0; n < NumberOfPeers; n++ { - peers = append(peers, &Peer{}) - } - - for n := 0; n < NumberOfAddresses; n++ { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint8(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % NumberOfPeers - allowedIPs.Insert(addr[:], cidr, peers[index]) - slow = slow.Insert(addr[:], cidr, peers[index]) - } - - for n := 0; n < NumberOfTests; n++ { - var addr [AddressLength]byte - rand.Read(addr[:]) - peer1 := slow.Lookup(addr[:]) - peer2 := allowedIPs.LookupIPv4(addr[:]) - if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) +func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter { + n := 0 + for _, x := range r { + if x.peer != peer { + r[n] = x + n++ } } + return r[:n] } -func TestTrieRandomIPv6(t *testing.T) { - var slow SlowRouter +func TestTrieRandom(t *testing.T) { + var slow4, slow6 SlowRouter var peers []*Peer var allowedIPs AllowedIPs rand.Seed(1) - const AddressLength = 16 - for n := 0; n < NumberOfPeers; n++ { peers = append(peers, &Peer{}) } for n := 0; n < NumberOfAddresses; n++ { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint8(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % NumberOfPeers - allowedIPs.Insert(addr[:], cidr, peers[index]) - slow = slow.Insert(addr[:], cidr, peers[index]) + var addr4 [4]byte + rand.Read(addr4[:]) + cidr := uint8(rand.Intn(32) + 1) + index := rand.Intn(NumberOfPeers) + allowedIPs.Insert(addr4[:], cidr, peers[index]) + slow4 = slow4.Insert(addr4[:], cidr, peers[index]) + + var addr6 [16]byte + rand.Read(addr6[:]) + cidr = uint8(rand.Intn(128) + 1) + index = rand.Intn(NumberOfPeers) + allowedIPs.Insert(addr6[:], cidr, peers[index]) + slow6 = slow6.Insert(addr6[:], cidr, peers[index]) } - for n := 0; n < NumberOfTests; n++ { - var addr [AddressLength]byte - rand.Read(addr[:]) - peer1 := slow.Lookup(addr[:]) - peer2 := allowedIPs.LookupIPv6(addr[:]) - if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) + for p := 0; ; p++ { + for n := 0; n < NumberOfTests; n++ { + var addr4 [4]byte + rand.Read(addr4[:]) + peer1 := slow4.Lookup(addr4[:]) + peer2 := allowedIPs.LookupIPv4(addr4[:]) + if peer1 != peer2 { + t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2) + } + + var addr6 [16]byte + rand.Read(addr6[:]) + peer1 = slow6.Lookup(addr6[:]) + peer2 = allowedIPs.LookupIPv6(addr6[:]) + if peer1 != peer2 { + t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2) + } + } + if p >= len(peers) { + break } + allowedIPs.RemoveByPeer(peers[p]) + slow4 = slow4.RemoveByPeer(peers[p]) + slow6 = slow6.RemoveByPeer(peers[p]) + } + + if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { + t.Error("Failed to remove all nodes from trie by peer") } } |