diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-06-04 16:33:28 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-06-04 16:33:28 +0200 |
commit | f9b48a961cd271bcc58c4c76b61a84a139e76167 (patch) | |
tree | 7a8f380838ba5844c04e75269e075ddfe373256d | |
parent | d0cf96114fa60b8f7e7a671a7749538edec9d877 (diff) |
device: zero out allowedip node pointers when removing
This should make it a bit easier for the garbage collector.
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r-- | device/allowedips.go | 12 | ||||
-rw-r--r-- | device/allowedips_test.go | 11 |
2 files changed, 22 insertions, 1 deletions
diff --git a/device/allowedips.go b/device/allowedips.go index 95615ab..c08399b 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -96,6 +96,14 @@ func (node *trieEntry) maskSelf() { } } +func (node *trieEntry) zeroizePointers() { + // Make the garbage collector's life slightly easier + node.peer = nil + node.child[0] = nil + node.child[1] = nil + node.parent.parentBit = nil +} + func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) { for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { parent = node @@ -257,10 +265,12 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) { } *node.parent.parentBit = child if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { + node.zeroizePointers() continue } parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) if parent.peer != nil { + node.zeroizePointers() continue } child = parent.child[node.parent.parentBitType^1] @@ -268,6 +278,8 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) { child.parent = parent.parent } *parent.parent.parentBit = child + node.zeroizePointers() + parent.zeroizePointers() } } diff --git a/device/allowedips_test.go b/device/allowedips_test.go index 7701cde..2059a88 100644 --- a/device/allowedips_test.go +++ b/device/allowedips_test.go @@ -159,7 +159,16 @@ func TestTrieIPv4(t *testing.T) { assertNEQ(a, 192, 0, 0, 0) assertNEQ(a, 255, 0, 0, 0) - allowedIPs = AllowedIPs{} + allowedIPs.RemoveByPeer(a) + allowedIPs.RemoveByPeer(b) + allowedIPs.RemoveByPeer(c) + allowedIPs.RemoveByPeer(d) + allowedIPs.RemoveByPeer(e) + allowedIPs.RemoveByPeer(g) + allowedIPs.RemoveByPeer(h) + if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { + t.Error("Expected removing all the peers to empty trie, but it did not") + } insert(a, 192, 168, 0, 0, 16) insert(a, 192, 168, 0, 0, 24) |