summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2021-06-03 13:51:03 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2021-06-03 13:51:03 +0200
commit4a57024b94edf23a20f1e4289052d0717227683b (patch)
tree1449ee35b5f8e8d585504547cb855acc78153dff
parent64cb82f2b3f5207f025a1c7ddf4d3043887d5712 (diff)
device: reduce size of trie struct
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r--device/allowedips.go68
-rw-r--r--device/allowedips_rand_test.go8
-rw-r--r--device/allowedips_test.go11
-rw-r--r--device/misc.go7
-rw-r--r--device/uapi.go4
5 files changed, 45 insertions, 53 deletions
diff --git a/device/allowedips.go b/device/allowedips.go
index b6f096a..1564d2d 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -15,13 +15,13 @@ import (
)
type trieEntry struct {
- child [2]*trieEntry
- peer *Peer
- bits net.IP
- cidr uint
- bit_at_byte uint
- bit_at_shift uint
- perPeerElem *list.Element
+ peer *Peer
+ child [2]*trieEntry
+ cidr uint8
+ bitAtByte uint8
+ bitAtShift uint8
+ bits net.IP
+ perPeerElem *list.Element
}
func isLittleEndian() bool {
@@ -45,24 +45,24 @@ func swapU64(i uint64) uint64 {
return bits.ReverseBytes64(i)
}
-func commonBits(ip1 net.IP, ip2 net.IP) uint {
+func commonBits(ip1 net.IP, ip2 net.IP) uint8 {
size := len(ip1)
if size == net.IPv4len {
a := (*uint32)(unsafe.Pointer(&ip1[0]))
b := (*uint32)(unsafe.Pointer(&ip2[0]))
x := *a ^ *b
- return uint(bits.LeadingZeros32(swapU32(x)))
+ return uint8(bits.LeadingZeros32(swapU32(x)))
} else if size == net.IPv6len {
a := (*uint64)(unsafe.Pointer(&ip1[0]))
b := (*uint64)(unsafe.Pointer(&ip2[0]))
x := *a ^ *b
if x != 0 {
- return uint(bits.LeadingZeros64(swapU64(x)))
+ return uint8(bits.LeadingZeros64(swapU64(x)))
}
a = (*uint64)(unsafe.Pointer(&ip1[8]))
b = (*uint64)(unsafe.Pointer(&ip2[8]))
x = *a ^ *b
- return 64 + uint(bits.LeadingZeros64(swapU64(x)))
+ return 64 + uint8(bits.LeadingZeros64(swapU64(x)))
} else {
panic("Wrong size bit string")
}
@@ -104,7 +104,7 @@ func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
}
func (node *trieEntry) choose(ip net.IP) byte {
- return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
+ return (ip[node.bitAtByte] >> node.bitAtShift) & 1
}
func (node *trieEntry) maskSelf() {
@@ -114,17 +114,17 @@ func (node *trieEntry) maskSelf() {
}
}
-func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
+func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry {
// at leaf
if node == nil {
node := &trieEntry{
- bits: ip,
- peer: peer,
- cidr: cidr,
- bit_at_byte: cidr / 8,
- bit_at_shift: 7 - (cidr % 8),
+ bits: ip,
+ peer: peer,
+ cidr: cidr,
+ bitAtByte: cidr / 8,
+ bitAtShift: 7 - (cidr % 8),
}
node.maskSelf()
node.addToPeerEntries()
@@ -149,16 +149,18 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
// split node
newNode := &trieEntry{
- bits: ip,
- peer: peer,
- cidr: cidr,
- bit_at_byte: cidr / 8,
- bit_at_shift: 7 - (cidr % 8),
+ bits: ip,
+ peer: peer,
+ cidr: cidr,
+ bitAtByte: cidr / 8,
+ bitAtShift: 7 - (cidr % 8),
}
newNode.maskSelf()
newNode.addToPeerEntries()
- cidr = min(cidr, common)
+ if common < cidr {
+ cidr = common
+ }
// check for shorter prefix
@@ -171,11 +173,11 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
// create new parent for node & newNode
parent := &trieEntry{
- bits: append([]byte{}, ip...),
- peer: nil,
- cidr: cidr,
- bit_at_byte: cidr / 8,
- bit_at_shift: 7 - (cidr % 8),
+ bits: append([]byte{}, ip...),
+ peer: nil,
+ cidr: cidr,
+ bitAtByte: cidr / 8,
+ bitAtShift: 7 - (cidr % 8),
}
parent.maskSelf()
@@ -188,12 +190,12 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
func (node *trieEntry) lookup(ip net.IP) *Peer {
var found *Peer
- size := uint(len(ip))
+ size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
if node.peer != nil {
found = node.peer
}
- if node.bit_at_byte == size {
+ if node.bitAtByte == size {
break
}
bit := node.choose(ip)
@@ -208,7 +210,7 @@ type AllowedIPs struct {
mutex sync.RWMutex
}
-func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint) bool) {
+func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) {
table.mutex.RLock()
defer table.mutex.RUnlock()
@@ -228,7 +230,7 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.IPv6 = table.IPv6.removeByPeer(peer)
}
-func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
+func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go
index bb3fb43..2da8795 100644
--- a/device/allowedips_rand_test.go
+++ b/device/allowedips_rand_test.go
@@ -19,7 +19,7 @@ const (
type SlowNode struct {
peer *Peer
- cidr uint
+ cidr uint8
bits []byte
}
@@ -37,7 +37,7 @@ func (r SlowRouter) Swap(i, j int) {
r[i], r[j] = r[j], r[i]
}
-func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter {
+func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
for _, t := range r {
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
t.peer = peer
@@ -80,7 +80,7 @@ func TestTrieRandomIPv4(t *testing.T) {
for n := 0; n < NumberOfAddresses; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
- cidr := uint(rand.Uint32() % (AddressLength * 8))
+ cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers
trie = trie.insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index])
@@ -113,7 +113,7 @@ func TestTrieRandomIPv6(t *testing.T) {
for n := 0; n < NumberOfAddresses; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
- cidr := uint(rand.Uint32() % (AddressLength * 8))
+ cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers
trie = trie.insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index])
diff --git a/device/allowedips_test.go b/device/allowedips_test.go
index cdd65cf..8dc8438 100644
--- a/device/allowedips_test.go
+++ b/device/allowedips_test.go
@@ -11,13 +11,10 @@ import (
"testing"
)
-/* Todo: More comprehensive
- */
-
type testPairCommonBits struct {
s1 []byte
s2 []byte
- match uint
+ match uint8
}
func TestCommonBits(t *testing.T) {
@@ -57,7 +54,7 @@ func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *test
for n := 0; n < addressNumber; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
- cidr := uint(rand.Uint32() % (AddressLength * 8))
+ cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % peerNumber
trie = trie.insert(addr[:], cidr, peers[index])
}
@@ -99,7 +96,7 @@ func TestTrieIPv4(t *testing.T) {
var trie *trieEntry
- insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
+ insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
}
@@ -195,7 +192,7 @@ func TestTrieIPv6(t *testing.T) {
return out[:]
}
- insert := func(peer *Peer, a, b, c, d uint32, cidr uint) {
+ insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
var addr []byte
addr = append(addr, expand(a)...)
addr = append(addr, expand(b)...)
diff --git a/device/misc.go b/device/misc.go
index 2c2510f..4126704 100644
--- a/device/misc.go
+++ b/device/misc.go
@@ -39,10 +39,3 @@ func (a *AtomicBool) Set(val bool) {
}
atomic.StoreInt32(&a.int32, flag)
}
-
-func min(a, b uint) uint {
- if a > b {
- return b
- }
- return a
-}
diff --git a/device/uapi.go b/device/uapi.go
index 659af0a..66ecd48 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -121,7 +121,7 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
- device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint) bool {
+ device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool {
sendf("allowed_ip=%s/%d", ip.String(), cidr)
return true
})
@@ -379,7 +379,7 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
return nil
}
ones, _ := network.Mask.Size()
- device.allowedips.Insert(network.IP, uint(ones), peer.Peer)
+ device.allowedips.Insert(network.IP, uint8(ones), peer.Peer)
case "protocol_version":
if value != "1" {