diff options
Diffstat (limited to 'device/allowedips.go')
-rw-r--r-- | device/allowedips.go | 119 |
1 files changed, 77 insertions, 42 deletions
diff --git a/device/allowedips.go b/device/allowedips.go index 1564d2d..d613121 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -14,9 +14,15 @@ import ( "unsafe" ) +type parentIndirection struct { + parentBit **trieEntry + parentBitType uint8 +} + type trieEntry struct { peer *Peer child [2]*trieEntry + parent parentIndirection cidr uint8 bitAtByte uint8 bitAtShift uint8 @@ -114,43 +120,45 @@ func (node *trieEntry) maskSelf() { } } -func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry { - - // at leaf +func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) { + for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { + parent = node + if parent.cidr == cidr { + exact = true + return + } + bit := node.choose(ip) + node = node.child[bit] + } + return +} - if node == nil { +func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) { + if *trie.parentBit == nil { node := &trieEntry{ - bits: ip, peer: peer, + parent: trie, + bits: ip, cidr: cidr, bitAtByte: cidr / 8, bitAtShift: 7 - (cidr % 8), } node.maskSelf() node.addToPeerEntries() - return node + *trie.parentBit = node + return } - - // traverse deeper - - common := commonBits(node.bits, ip) - if node.cidr <= cidr && common >= node.cidr { - if node.cidr == cidr { - node.removeFromPeerEntries() - node.peer = peer - node.addToPeerEntries() - return node - } - bit := node.choose(ip) - node.child[bit] = node.child[bit].insert(ip, cidr, peer) - return node + node, exact := (*trie.parentBit).nodePlacement(ip, cidr) + if exact { + node.removeFromPeerEntries() + node.peer = peer + node.addToPeerEntries() + return } - // split node - newNode := &trieEntry{ - bits: ip, peer: peer, + bits: ip, cidr: cidr, bitAtByte: cidr / 8, bitAtShift: 7 - (cidr % 8), @@ -158,34 +166,61 @@ func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry { newNode.maskSelf() newNode.addToPeerEntries() + var down *trieEntry + if node == nil { + down = *trie.parentBit + } else { + bit := node.choose(ip) + down = node.child[bit] + if down == nil { + newNode.parent = parentIndirection{&node.child[bit], bit} + node.child[bit] = newNode + return + } + } + common := commonBits(down.bits, ip) if common < cidr { cidr = common } - - // check for shorter prefix + parent := node if newNode.cidr == cidr { - bit := newNode.choose(node.bits) - newNode.child[bit] = node - return newNode + bit := newNode.choose(down.bits) + down.parent = parentIndirection{&newNode.child[bit], bit} + newNode.child[bit] = down + if parent == nil { + newNode.parent = trie + *trie.parentBit = newNode + } else { + bit := parent.choose(newNode.bits) + newNode.parent = parentIndirection{&parent.child[bit], bit} + parent.child[bit] = newNode + } + return } - // create new parent for node & newNode - - parent := &trieEntry{ - bits: append([]byte{}, ip...), - peer: nil, + node = &trieEntry{ + bits: append([]byte{}, newNode.bits...), cidr: cidr, bitAtByte: cidr / 8, bitAtShift: 7 - (cidr % 8), } - parent.maskSelf() - - bit := parent.choose(ip) - parent.child[bit] = newNode - parent.child[bit^1] = node - - return parent + node.maskSelf() + + bit := node.choose(down.bits) + down.parent = parentIndirection{&node.child[bit], bit} + node.child[bit] = down + bit = node.choose(newNode.bits) + newNode.parent = parentIndirection{&node.child[bit], bit} + node.child[bit] = newNode + if parent == nil { + node.parent = trie + *trie.parentBit = node + } else { + bit := parent.choose(node.bits) + node.parent = parentIndirection{&parent.child[bit], bit} + parent.child[bit] = node + } } func (node *trieEntry) lookup(ip net.IP) *Peer { @@ -236,9 +271,9 @@ func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) { switch len(ip) { case net.IPv6len: - table.IPv6 = table.IPv6.insert(ip, cidr, peer) + parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer) case net.IPv4len: - table.IPv4 = table.IPv4.insert(ip, cidr, peer) + parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer) default: panic(errors.New("inserting unknown address type")) } |