diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2018-05-13 19:33:41 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2018-05-13 19:34:28 +0200 |
commit | 2326d6a4d75f9f3736046cc526eb593a403d4c7a (patch) | |
tree | da31622899bb84e256b22addd1604b8250582c44 /allowedips.go | |
parent | e94185681f4c19874392019b0d43bdfefcbb85b5 (diff) |
Odds and ends
Diffstat (limited to 'allowedips.go')
-rw-r--r-- | allowedips.go | 273 |
1 files changed, 273 insertions, 0 deletions
diff --git a/allowedips.go b/allowedips.go new file mode 100644 index 0000000..df53abf --- /dev/null +++ b/allowedips.go @@ -0,0 +1,273 @@ +/* SPDX-License-Identifier: GPL-2.0 + * + * Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. + */ + +package main + +import ( + "errors" + "net" + "sync" +) + +type trieEntry struct { + cidr uint + child [2]*trieEntry + bits []byte + peer *Peer + + // index of "branching" bit + + bit_at_byte uint + bit_at_shift uint +} + +/* Finds length of matching prefix + * + * TODO: Only use during insertion (xor + prefix mask for lookup) + * Check out + * prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits) + * https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match + * + * Assumption: + * len(ip1) == len(ip2) + * len(ip1) mod 4 = 0 + */ +func commonBits(ip1 []byte, ip2 []byte) uint { + var i uint + size := uint(len(ip1)) + + for i = 0; i < size; i++ { + v := ip1[i] ^ ip2[i] + if v != 0 { + v >>= 1 + if v == 0 { + return i*8 + 7 + } + + v >>= 1 + if v == 0 { + return i*8 + 6 + } + + v >>= 1 + if v == 0 { + return i*8 + 5 + } + + v >>= 1 + if v == 0 { + return i*8 + 4 + } + + v >>= 1 + if v == 0 { + return i*8 + 3 + } + + v >>= 1 + if v == 0 { + return i*8 + 2 + } + + v >>= 1 + if v == 0 { + return i*8 + 1 + } + return i * 8 + } + } + return i * 8 +} + +func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { + if node == nil { + return node + } + + // walk recursively + + node.child[0] = node.child[0].removeByPeer(p) + node.child[1] = node.child[1].removeByPeer(p) + + if node.peer != p { + return node + } + + // remove peer & merge + + node.peer = nil + if node.child[0] == nil { + return node.child[1] + } + return node.child[0] +} + +func (node *trieEntry) choose(ip net.IP) byte { + return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 +} + +func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { + + // at leaf + + if node == nil { + return &trieEntry{ + bits: ip, + peer: peer, + cidr: cidr, + bit_at_byte: cidr / 8, + bit_at_shift: 7 - (cidr % 8), + } + } + + // traverse deeper + + common := commonBits(node.bits, ip) + if node.cidr <= cidr && common >= node.cidr { + if node.cidr == cidr { + node.peer = peer + return node + } + bit := node.choose(ip) + node.child[bit] = node.child[bit].insert(ip, cidr, peer) + return node + } + + // split node + + newNode := &trieEntry{ + bits: ip, + peer: peer, + cidr: cidr, + bit_at_byte: cidr / 8, + bit_at_shift: 7 - (cidr % 8), + } + + cidr = min(cidr, common) + + // check for shorter prefix + + if newNode.cidr == cidr { + bit := newNode.choose(node.bits) + newNode.child[bit] = node + return newNode + } + + // create new parent for node & newNode + + parent := &trieEntry{ + bits: ip, + peer: nil, + cidr: cidr, + bit_at_byte: cidr / 8, + bit_at_shift: 7 - (cidr % 8), + } + + bit := parent.choose(ip) + parent.child[bit] = newNode + parent.child[bit^1] = node + + return parent +} + +func (node *trieEntry) lookup(ip net.IP) *Peer { + var found *Peer + size := uint(len(ip)) + for node != nil && commonBits(node.bits, ip) >= node.cidr { + if node.peer != nil { + found = node.peer + } + if node.bit_at_byte == size { + break + } + bit := node.choose(ip) + node = node.child[bit] + } + return found +} + +func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet { + if node == nil { + return results + } + if node.peer == p { + var mask net.IPNet + mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8) + if len(node.bits) == net.IPv4len { + mask.IP = net.IPv4( + node.bits[0], + node.bits[1], + node.bits[2], + node.bits[3], + ) + } else if len(node.bits) == net.IPv6len { + mask.IP = node.bits + } else { + panic(errors.New("unexpected address length")) + } + results = append(results, mask) + } + results = node.child[0].entriesForPeer(p, results) + results = node.child[1].entriesForPeer(p, results) + return results +} + +type AllowedIPs struct { + IPv4 *trieEntry + IPv6 *trieEntry + mutex sync.RWMutex +} + +func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet { + table.mutex.RLock() + defer table.mutex.RUnlock() + + allowed := make([]net.IPNet, 0, 10) + allowed = table.IPv4.entriesForPeer(peer, allowed) + allowed = table.IPv6.entriesForPeer(peer, allowed) + return allowed +} + +func (table *AllowedIPs) Reset() { + table.mutex.Lock() + defer table.mutex.Unlock() + + table.IPv4 = nil + table.IPv6 = nil +} + +func (table *AllowedIPs) RemoveByPeer(peer *Peer) { + table.mutex.Lock() + defer table.mutex.Unlock() + + table.IPv4 = table.IPv4.removeByPeer(peer) + table.IPv6 = table.IPv6.removeByPeer(peer) +} + +func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) { + table.mutex.Lock() + defer table.mutex.Unlock() + + switch len(ip) { + case net.IPv6len: + table.IPv6 = table.IPv6.insert(ip, cidr, peer) + case net.IPv4len: + table.IPv4 = table.IPv4.insert(ip, cidr, peer) + default: + panic(errors.New("inserting unknown address type")) + } +} + +func (table *AllowedIPs) LookupIPv4(address []byte) *Peer { + table.mutex.RLock() + defer table.mutex.RUnlock() + return table.IPv4.lookup(address) +} + +func (table *AllowedIPs) LookupIPv6(address []byte) *Peer { + table.mutex.RLock() + defer table.mutex.RUnlock() + return table.IPv6.lookup(address) +} |