summaryrefslogtreecommitdiffhomepage
path: root/device
diff options
context:
space:
mode:
Diffstat (limited to 'device')
-rw-r--r--device/allowedips.go12
-rw-r--r--device/allowedips_test.go11
2 files changed, 22 insertions, 1 deletions
diff --git a/device/allowedips.go b/device/allowedips.go
index 95615ab..c08399b 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -96,6 +96,14 @@ func (node *trieEntry) maskSelf() {
}
}
+func (node *trieEntry) zeroizePointers() {
+ // Make the garbage collector's life slightly easier
+ node.peer = nil
+ node.child[0] = nil
+ node.child[1] = nil
+ node.parent.parentBit = nil
+}
+
func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
parent = node
@@ -257,10 +265,12 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
}
*node.parent.parentBit = child
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
+ node.zeroizePointers()
continue
}
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
if parent.peer != nil {
+ node.zeroizePointers()
continue
}
child = parent.child[node.parent.parentBitType^1]
@@ -268,6 +278,8 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
child.parent = parent.parent
}
*parent.parent.parentBit = child
+ node.zeroizePointers()
+ parent.zeroizePointers()
}
}
diff --git a/device/allowedips_test.go b/device/allowedips_test.go
index 7701cde..2059a88 100644
--- a/device/allowedips_test.go
+++ b/device/allowedips_test.go
@@ -159,7 +159,16 @@ func TestTrieIPv4(t *testing.T) {
assertNEQ(a, 192, 0, 0, 0)
assertNEQ(a, 255, 0, 0, 0)
- allowedIPs = AllowedIPs{}
+ allowedIPs.RemoveByPeer(a)
+ allowedIPs.RemoveByPeer(b)
+ allowedIPs.RemoveByPeer(c)
+ allowedIPs.RemoveByPeer(d)
+ allowedIPs.RemoveByPeer(e)
+ allowedIPs.RemoveByPeer(g)
+ allowedIPs.RemoveByPeer(h)
+ if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
+ t.Error("Expected removing all the peers to empty trie, but it did not")
+ }
insert(a, 192, 168, 0, 0, 16)
insert(a, 192, 168, 0, 0, 24)