summaryrefslogtreecommitdiffhomepage
path: root/trie.go
diff options
context:
space:
mode:
Diffstat (limited to 'trie.go')
-rw-r--r--trie.go233
1 files changed, 0 insertions, 233 deletions
diff --git a/trie.go b/trie.go
deleted file mode 100644
index 03f0722..0000000
--- a/trie.go
+++ /dev/null
@@ -1,233 +0,0 @@
-/* SPDX-License-Identifier: GPL-2.0
- *
- * Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
- */
-
-package main
-
-import (
- "errors"
- "net"
-)
-
-/* Binary trie
- *
- * The net.IPs used here are not formatted the
- * same way as those created by the "net" functions.
- * Here the IPs are slices of either 4 or 16 byte (not always 16)
- *
- * Synchronization done separately
- * See: routing.go
- */
-
-type Trie struct {
- cidr uint
- child [2]*Trie
- bits []byte
- peer *Peer
-
- // index of "branching" bit
-
- bit_at_byte uint
- bit_at_shift uint
-}
-
-/* Finds length of matching prefix
- *
- * TODO: Only use during insertion (xor + prefix mask for lookup)
- * Check out
- * prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits)
- * https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match
- *
- * Assumption:
- * len(ip1) == len(ip2)
- * len(ip1) mod 4 = 0
- */
-func commonBits(ip1 []byte, ip2 []byte) uint {
- var i uint
- size := uint(len(ip1))
-
- for i = 0; i < size; i++ {
- v := ip1[i] ^ ip2[i]
- if v != 0 {
- v >>= 1
- if v == 0 {
- return i*8 + 7
- }
-
- v >>= 1
- if v == 0 {
- return i*8 + 6
- }
-
- v >>= 1
- if v == 0 {
- return i*8 + 5
- }
-
- v >>= 1
- if v == 0 {
- return i*8 + 4
- }
-
- v >>= 1
- if v == 0 {
- return i*8 + 3
- }
-
- v >>= 1
- if v == 0 {
- return i*8 + 2
- }
-
- v >>= 1
- if v == 0 {
- return i*8 + 1
- }
- return i * 8
- }
- }
- return i * 8
-}
-
-func (node *Trie) RemovePeer(p *Peer) *Trie {
- if node == nil {
- return node
- }
-
- // walk recursively
-
- node.child[0] = node.child[0].RemovePeer(p)
- node.child[1] = node.child[1].RemovePeer(p)
-
- if node.peer != p {
- return node
- }
-
- // remove peer & merge
-
- node.peer = nil
- if node.child[0] == nil {
- return node.child[1]
- }
- return node.child[0]
-}
-
-func (node *Trie) choose(ip net.IP) byte {
- return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
-}
-
-func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
-
- // at leaf
-
- if node == nil {
- return &Trie{
- bits: ip,
- peer: peer,
- cidr: cidr,
- bit_at_byte: cidr / 8,
- bit_at_shift: 7 - (cidr % 8),
- }
- }
-
- // traverse deeper
-
- common := commonBits(node.bits, ip)
- if node.cidr <= cidr && common >= node.cidr {
- if node.cidr == cidr {
- node.peer = peer
- return node
- }
- bit := node.choose(ip)
- node.child[bit] = node.child[bit].Insert(ip, cidr, peer)
- return node
- }
-
- // split node
-
- newNode := &Trie{
- bits: ip,
- peer: peer,
- cidr: cidr,
- bit_at_byte: cidr / 8,
- bit_at_shift: 7 - (cidr % 8),
- }
-
- cidr = min(cidr, common)
-
- // check for shorter prefix
-
- if newNode.cidr == cidr {
- bit := newNode.choose(node.bits)
- newNode.child[bit] = node
- return newNode
- }
-
- // create new parent for node & newNode
-
- parent := &Trie{
- bits: ip,
- peer: nil,
- cidr: cidr,
- bit_at_byte: cidr / 8,
- bit_at_shift: 7 - (cidr % 8),
- }
-
- bit := parent.choose(ip)
- parent.child[bit] = newNode
- parent.child[bit^1] = node
-
- return parent
-}
-
-func (node *Trie) Lookup(ip net.IP) *Peer {
- var found *Peer
- size := uint(len(ip))
- for node != nil && commonBits(node.bits, ip) >= node.cidr {
- if node.peer != nil {
- found = node.peer
- }
- if node.bit_at_byte == size {
- break
- }
- bit := node.choose(ip)
- node = node.child[bit]
- }
- return found
-}
-
-func (node *Trie) Count() uint {
- if node == nil {
- return 0
- }
- l := node.child[0].Count()
- r := node.child[1].Count()
- return l + r
-}
-
-func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet {
- if node == nil {
- return results
- }
- if node.peer == p {
- var mask net.IPNet
- mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8)
- if len(node.bits) == net.IPv4len {
- mask.IP = net.IPv4(
- node.bits[0],
- node.bits[1],
- node.bits[2],
- node.bits[3],
- )
- } else if len(node.bits) == net.IPv6len {
- mask.IP = node.bits
- } else {
- panic(errors.New("bug: unexpected address length"))
- }
- results = append(results, mask)
- }
- results = node.child[0].AllowedIPs(p, results)
- results = node.child[1].AllowedIPs(p, results)
- return results
-}