diff options
Diffstat (limited to 'device/allowedips.go')
-rw-r--r-- | device/allowedips.go | 68 |
1 files changed, 35 insertions, 33 deletions
diff --git a/device/allowedips.go b/device/allowedips.go index b6f096a..1564d2d 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -15,13 +15,13 @@ import ( ) type trieEntry struct { - child [2]*trieEntry - peer *Peer - bits net.IP - cidr uint - bit_at_byte uint - bit_at_shift uint - perPeerElem *list.Element + peer *Peer + child [2]*trieEntry + cidr uint8 + bitAtByte uint8 + bitAtShift uint8 + bits net.IP + perPeerElem *list.Element } func isLittleEndian() bool { @@ -45,24 +45,24 @@ func swapU64(i uint64) uint64 { return bits.ReverseBytes64(i) } -func commonBits(ip1 net.IP, ip2 net.IP) uint { +func commonBits(ip1 net.IP, ip2 net.IP) uint8 { size := len(ip1) if size == net.IPv4len { a := (*uint32)(unsafe.Pointer(&ip1[0])) b := (*uint32)(unsafe.Pointer(&ip2[0])) x := *a ^ *b - return uint(bits.LeadingZeros32(swapU32(x))) + return uint8(bits.LeadingZeros32(swapU32(x))) } else if size == net.IPv6len { a := (*uint64)(unsafe.Pointer(&ip1[0])) b := (*uint64)(unsafe.Pointer(&ip2[0])) x := *a ^ *b if x != 0 { - return uint(bits.LeadingZeros64(swapU64(x))) + return uint8(bits.LeadingZeros64(swapU64(x))) } a = (*uint64)(unsafe.Pointer(&ip1[8])) b = (*uint64)(unsafe.Pointer(&ip2[8])) x = *a ^ *b - return 64 + uint(bits.LeadingZeros64(swapU64(x))) + return 64 + uint8(bits.LeadingZeros64(swapU64(x))) } else { panic("Wrong size bit string") } @@ -104,7 +104,7 @@ func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { } func (node *trieEntry) choose(ip net.IP) byte { - return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 + return (ip[node.bitAtByte] >> node.bitAtShift) & 1 } func (node *trieEntry) maskSelf() { @@ -114,17 +114,17 @@ func (node *trieEntry) maskSelf() { } } -func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { +func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry { // at leaf if node == nil { node := &trieEntry{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), + bits: ip, + peer: peer, + cidr: cidr, + bitAtByte: cidr / 8, + bitAtShift: 7 - (cidr % 8), } node.maskSelf() node.addToPeerEntries() @@ -149,16 +149,18 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { // split node newNode := &trieEntry{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), + bits: ip, + peer: peer, + cidr: cidr, + bitAtByte: cidr / 8, + bitAtShift: 7 - (cidr % 8), } newNode.maskSelf() newNode.addToPeerEntries() - cidr = min(cidr, common) + if common < cidr { + cidr = common + } // check for shorter prefix @@ -171,11 +173,11 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { // create new parent for node & newNode parent := &trieEntry{ - bits: append([]byte{}, ip...), - peer: nil, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), + bits: append([]byte{}, ip...), + peer: nil, + cidr: cidr, + bitAtByte: cidr / 8, + bitAtShift: 7 - (cidr % 8), } parent.maskSelf() @@ -188,12 +190,12 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { func (node *trieEntry) lookup(ip net.IP) *Peer { var found *Peer - size := uint(len(ip)) + size := uint8(len(ip)) for node != nil && commonBits(node.bits, ip) >= node.cidr { if node.peer != nil { found = node.peer } - if node.bit_at_byte == size { + if node.bitAtByte == size { break } bit := node.choose(ip) @@ -208,7 +210,7 @@ type AllowedIPs struct { mutex sync.RWMutex } -func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint) bool) { +func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) { table.mutex.RLock() defer table.mutex.RUnlock() @@ -228,7 +230,7 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) { table.IPv6 = table.IPv6.removeByPeer(peer) } -func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) { +func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() |