diff options
Diffstat (limited to 'trie.go')
-rw-r--r-- | trie.go | 233 |
1 files changed, 0 insertions, 233 deletions
diff --git a/trie.go b/trie.go deleted file mode 100644 index 03f0722..0000000 --- a/trie.go +++ /dev/null @@ -1,233 +0,0 @@ -/* SPDX-License-Identifier: GPL-2.0 - * - * Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. - */ - -package main - -import ( - "errors" - "net" -) - -/* Binary trie - * - * The net.IPs used here are not formatted the - * same way as those created by the "net" functions. - * Here the IPs are slices of either 4 or 16 byte (not always 16) - * - * Synchronization done separately - * See: routing.go - */ - -type Trie struct { - cidr uint - child [2]*Trie - bits []byte - peer *Peer - - // index of "branching" bit - - bit_at_byte uint - bit_at_shift uint -} - -/* Finds length of matching prefix - * - * TODO: Only use during insertion (xor + prefix mask for lookup) - * Check out - * prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits) - * https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match - * - * Assumption: - * len(ip1) == len(ip2) - * len(ip1) mod 4 = 0 - */ -func commonBits(ip1 []byte, ip2 []byte) uint { - var i uint - size := uint(len(ip1)) - - for i = 0; i < size; i++ { - v := ip1[i] ^ ip2[i] - if v != 0 { - v >>= 1 - if v == 0 { - return i*8 + 7 - } - - v >>= 1 - if v == 0 { - return i*8 + 6 - } - - v >>= 1 - if v == 0 { - return i*8 + 5 - } - - v >>= 1 - if v == 0 { - return i*8 + 4 - } - - v >>= 1 - if v == 0 { - return i*8 + 3 - } - - v >>= 1 - if v == 0 { - return i*8 + 2 - } - - v >>= 1 - if v == 0 { - return i*8 + 1 - } - return i * 8 - } - } - return i * 8 -} - -func (node *Trie) RemovePeer(p *Peer) *Trie { - if node == nil { - return node - } - - // walk recursively - - node.child[0] = node.child[0].RemovePeer(p) - node.child[1] = node.child[1].RemovePeer(p) - - if node.peer != p { - return node - } - - // remove peer & merge - - node.peer = nil - if node.child[0] == nil { - return node.child[1] - } - return node.child[0] -} - -func (node *Trie) choose(ip net.IP) byte { - return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 -} - -func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie { - - // at leaf - - if node == nil { - return &Trie{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), - } - } - - // traverse deeper - - common := commonBits(node.bits, ip) - if node.cidr <= cidr && common >= node.cidr { - if node.cidr == cidr { - node.peer = peer - return node - } - bit := node.choose(ip) - node.child[bit] = node.child[bit].Insert(ip, cidr, peer) - return node - } - - // split node - - newNode := &Trie{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), - } - - cidr = min(cidr, common) - - // check for shorter prefix - - if newNode.cidr == cidr { - bit := newNode.choose(node.bits) - newNode.child[bit] = node - return newNode - } - - // create new parent for node & newNode - - parent := &Trie{ - bits: ip, - peer: nil, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), - } - - bit := parent.choose(ip) - parent.child[bit] = newNode - parent.child[bit^1] = node - - return parent -} - -func (node *Trie) Lookup(ip net.IP) *Peer { - var found *Peer - size := uint(len(ip)) - for node != nil && commonBits(node.bits, ip) >= node.cidr { - if node.peer != nil { - found = node.peer - } - if node.bit_at_byte == size { - break - } - bit := node.choose(ip) - node = node.child[bit] - } - return found -} - -func (node *Trie) Count() uint { - if node == nil { - return 0 - } - l := node.child[0].Count() - r := node.child[1].Count() - return l + r -} - -func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet { - if node == nil { - return results - } - if node.peer == p { - var mask net.IPNet - mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8) - if len(node.bits) == net.IPv4len { - mask.IP = net.IPv4( - node.bits[0], - node.bits[1], - node.bits[2], - node.bits[3], - ) - } else if len(node.bits) == net.IPv6len { - mask.IP = node.bits - } else { - panic(errors.New("bug: unexpected address length")) - } - results = append(results, mask) - } - results = node.child[0].AllowedIPs(p, results) - results = node.child[1].AllowedIPs(p, results) - return results -} |