diff options
Diffstat (limited to 'device/allowedips.go')
-rw-r--r-- | device/allowedips.go | 42 |
1 files changed, 23 insertions, 19 deletions
diff --git a/device/allowedips.go b/device/allowedips.go index c08399b..7a0b275 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -12,6 +12,8 @@ import ( "net" "sync" "unsafe" + + "golang.zx2c4.com/go118/netip" ) type parentIndirection struct { @@ -26,7 +28,7 @@ type trieEntry struct { cidr uint8 bitAtByte uint8 bitAtShift uint8 - bits net.IP + bits []byte perPeerElem *list.Element } @@ -51,7 +53,7 @@ func swapU64(i uint64) uint64 { return bits.ReverseBytes64(i) } -func commonBits(ip1 net.IP, ip2 net.IP) uint8 { +func commonBits(ip1, ip2 []byte) uint8 { size := len(ip1) if size == net.IPv4len { a := (*uint32)(unsafe.Pointer(&ip1[0])) @@ -85,7 +87,7 @@ func (node *trieEntry) removeFromPeerEntries() { } } -func (node *trieEntry) choose(ip net.IP) byte { +func (node *trieEntry) choose(ip []byte) byte { return (ip[node.bitAtByte] >> node.bitAtShift) & 1 } @@ -104,7 +106,7 @@ func (node *trieEntry) zeroizePointers() { node.parent.parentBit = nil } -func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) { +func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) { for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { parent = node if parent.cidr == cidr { @@ -117,7 +119,7 @@ func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, return } -func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) { +func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) { if *trie.parentBit == nil { node := &trieEntry{ peer: peer, @@ -207,7 +209,7 @@ func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) { } } -func (node *trieEntry) lookup(ip net.IP) *Peer { +func (node *trieEntry) lookup(ip []byte) *Peer { var found *Peer size := uint8(len(ip)) for node != nil && commonBits(node.bits, ip) >= node.cidr { @@ -229,13 +231,14 @@ type AllowedIPs struct { mutex sync.RWMutex } -func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) { +func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) { table.mutex.RLock() defer table.mutex.RUnlock() for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { node := elem.Value.(*trieEntry) - if !cb(node.bits, node.cidr) { + a, _ := netip.AddrFromSlice(node.bits) + if !cb(netip.PrefixFrom(a, int(node.cidr))) { return } } @@ -283,28 +286,29 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) { } } -func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) { +func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() - switch len(ip) { - case net.IPv6len: - parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer) - case net.IPv4len: - parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer) - default: + if prefix.Addr().Is6() { + ip := prefix.Addr().As16() + parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + } else if prefix.Addr().Is4() { + ip := prefix.Addr().As4() + parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + } else { panic(errors.New("inserting unknown address type")) } } -func (table *AllowedIPs) Lookup(address []byte) *Peer { +func (table *AllowedIPs) Lookup(ip []byte) *Peer { table.mutex.RLock() defer table.mutex.RUnlock() - switch len(address) { + switch len(ip) { case net.IPv6len: - return table.IPv6.lookup(address) + return table.IPv6.lookup(ip) case net.IPv4len: - return table.IPv4.lookup(address) + return table.IPv4.lookup(ip) default: panic(errors.New("looking up unknown address type")) } |