summaryrefslogtreecommitdiffhomepage
path: root/allowedips.go
diff options
context:
space:
mode:
Diffstat (limited to 'allowedips.go')
-rw-r--r--allowedips.go96
1 files changed, 42 insertions, 54 deletions
diff --git a/allowedips.go b/allowedips.go
index df53abf..e700dc4 100644
--- a/allowedips.go
+++ b/allowedips.go
@@ -7,8 +7,10 @@ package main
import (
"errors"
+ "math/bits"
"net"
"sync"
+ "unsafe"
)
type trieEntry struct {
@@ -23,62 +25,48 @@ type trieEntry struct {
bit_at_shift uint
}
-/* Finds length of matching prefix
- *
- * TODO: Only use during insertion (xor + prefix mask for lookup)
- * Check out
- * prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits)
- * https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match
- *
- * Assumption:
- * len(ip1) == len(ip2)
- * len(ip1) mod 4 = 0
- */
-func commonBits(ip1 []byte, ip2 []byte) uint {
- var i uint
- size := uint(len(ip1))
-
- for i = 0; i < size; i++ {
- v := ip1[i] ^ ip2[i]
- if v != 0 {
- v >>= 1
- if v == 0 {
- return i*8 + 7
- }
-
- v >>= 1
- if v == 0 {
- return i*8 + 6
- }
-
- v >>= 1
- if v == 0 {
- return i*8 + 5
- }
-
- v >>= 1
- if v == 0 {
- return i*8 + 4
- }
-
- v >>= 1
- if v == 0 {
- return i*8 + 3
- }
-
- v >>= 1
- if v == 0 {
- return i*8 + 2
- }
-
- v >>= 1
- if v == 0 {
- return i*8 + 1
- }
- return i * 8
+func isLittleEndian() bool {
+ one := uint32(1)
+ return *(*byte)(unsafe.Pointer(&one)) != 0
+}
+
+func swapU32(i uint32) uint32 {
+ if !isLittleEndian() {
+ return i
+ }
+
+ return bits.ReverseBytes32(i)
+}
+
+func swapU64(i uint64) uint64 {
+ if !isLittleEndian() {
+ return i
+ }
+
+ return bits.ReverseBytes64(i)
+}
+
+func commonBits(ip1 net.IP, ip2 net.IP) uint {
+ size := len(ip1)
+ if size == net.IPv4len {
+ a := (*uint32)(unsafe.Pointer(&ip1[0]))
+ b := (*uint32)(unsafe.Pointer(&ip2[0]))
+ x := *a ^ *b
+ return uint(bits.LeadingZeros32(swapU32(x)))
+ } else if size == net.IPv6len {
+ a := (*uint64)(unsafe.Pointer(&ip1[0]))
+ b := (*uint64)(unsafe.Pointer(&ip2[0]))
+ x := *a ^ *b
+ if x != 0 {
+ return uint(bits.LeadingZeros64(swapU64(x)))
}
+ a = (*uint64)(unsafe.Pointer(&ip1[8]))
+ b = (*uint64)(unsafe.Pointer(&ip2[8]))
+ x = *a ^ *b
+ return 64 + uint(bits.LeadingZeros64(swapU64(x)))
+ } else {
+ panic("Wrong size bit string")
}
- return i * 8
}
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {