diff options
Diffstat (limited to 'allowedips.go')
-rw-r--r-- | allowedips.go | 96 |
1 files changed, 42 insertions, 54 deletions
diff --git a/allowedips.go b/allowedips.go index df53abf..e700dc4 100644 --- a/allowedips.go +++ b/allowedips.go @@ -7,8 +7,10 @@ package main import ( "errors" + "math/bits" "net" "sync" + "unsafe" ) type trieEntry struct { @@ -23,62 +25,48 @@ type trieEntry struct { bit_at_shift uint } -/* Finds length of matching prefix - * - * TODO: Only use during insertion (xor + prefix mask for lookup) - * Check out - * prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits) - * https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match - * - * Assumption: - * len(ip1) == len(ip2) - * len(ip1) mod 4 = 0 - */ -func commonBits(ip1 []byte, ip2 []byte) uint { - var i uint - size := uint(len(ip1)) - - for i = 0; i < size; i++ { - v := ip1[i] ^ ip2[i] - if v != 0 { - v >>= 1 - if v == 0 { - return i*8 + 7 - } - - v >>= 1 - if v == 0 { - return i*8 + 6 - } - - v >>= 1 - if v == 0 { - return i*8 + 5 - } - - v >>= 1 - if v == 0 { - return i*8 + 4 - } - - v >>= 1 - if v == 0 { - return i*8 + 3 - } - - v >>= 1 - if v == 0 { - return i*8 + 2 - } - - v >>= 1 - if v == 0 { - return i*8 + 1 - } - return i * 8 +func isLittleEndian() bool { + one := uint32(1) + return *(*byte)(unsafe.Pointer(&one)) != 0 +} + +func swapU32(i uint32) uint32 { + if !isLittleEndian() { + return i + } + + return bits.ReverseBytes32(i) +} + +func swapU64(i uint64) uint64 { + if !isLittleEndian() { + return i + } + + return bits.ReverseBytes64(i) +} + +func commonBits(ip1 net.IP, ip2 net.IP) uint { + size := len(ip1) + if size == net.IPv4len { + a := (*uint32)(unsafe.Pointer(&ip1[0])) + b := (*uint32)(unsafe.Pointer(&ip2[0])) + x := *a ^ *b + return uint(bits.LeadingZeros32(swapU32(x))) + } else if size == net.IPv6len { + a := (*uint64)(unsafe.Pointer(&ip1[0])) + b := (*uint64)(unsafe.Pointer(&ip2[0])) + x := *a ^ *b + if x != 0 { + return uint(bits.LeadingZeros64(swapU64(x))) } + a = (*uint64)(unsafe.Pointer(&ip1[8])) + b = (*uint64)(unsafe.Pointer(&ip2[8])) + x = *a ^ *b + return 64 + uint(bits.LeadingZeros64(swapU64(x))) + } else { + panic("Wrong size bit string") } - return i * 8 } func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { |