summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--device/allowedips.go97
-rw-r--r--device/device.go1
-rw-r--r--device/peer.go1
-rw-r--r--device/uapi.go7
4 files changed, 62 insertions, 44 deletions
diff --git a/device/allowedips.go b/device/allowedips.go
index 143bda3..b5e40e9 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -14,15 +14,14 @@ import (
)
type trieEntry struct {
- cidr uint
- child [2]*trieEntry
- bits net.IP
- peer *Peer
-
- // index of "branching" bit
-
- bit_at_byte uint
- bit_at_shift uint
+ child [2]*trieEntry
+ peer *Peer
+ bits net.IP
+ cidr uint
+ bit_at_byte uint
+ bit_at_shift uint
+ nextEntryForPeer *trieEntry
+ pprevEntryForPeer **trieEntry
}
func isLittleEndian() bool {
@@ -69,6 +68,31 @@ func commonBits(ip1 net.IP, ip2 net.IP) uint {
}
}
+func (node *trieEntry) addToPeerEntries() {
+ p := node.peer
+ first := p.firstTrieEntry
+ node.nextEntryForPeer = first
+ if first != nil {
+ first.pprevEntryForPeer = &node.nextEntryForPeer
+ }
+ p.firstTrieEntry = node
+ node.pprevEntryForPeer = &p.firstTrieEntry
+}
+
+func (node *trieEntry) removeFromPeerEntries() {
+ if node.pprevEntryForPeer == nil {
+ return
+ }
+ next := node.nextEntryForPeer
+ pprev := node.pprevEntryForPeer
+ *pprev = next
+ if next != nil {
+ next.pprevEntryForPeer = pprev
+ }
+ node.nextEntryForPeer = nil
+ node.pprevEntryForPeer = nil
+}
+
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
if node == nil {
return node
@@ -85,6 +109,7 @@ func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
// remove peer & merge
+ node.removeFromPeerEntries()
node.peer = nil
if node.child[0] == nil {
return node.child[1]
@@ -96,18 +121,28 @@ func (node *trieEntry) choose(ip net.IP) byte {
return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
}
+func (node *trieEntry) maskSelf() {
+ mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
+ for i := 0; i < len(mask); i++ {
+ node.bits[i] &= mask[i]
+ }
+}
+
func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
// at leaf
if node == nil {
- return &trieEntry{
+ node := &trieEntry{
bits: ip,
peer: peer,
cidr: cidr,
bit_at_byte: cidr / 8,
bit_at_shift: 7 - (cidr % 8),
}
+ node.maskSelf()
+ node.addToPeerEntries()
+ return node
}
// traverse deeper
@@ -115,7 +150,9 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
common := commonBits(node.bits, ip)
if node.cidr <= cidr && common >= node.cidr {
if node.cidr == cidr {
+ node.removeFromPeerEntries()
node.peer = peer
+ node.addToPeerEntries()
return node
}
bit := node.choose(ip)
@@ -132,6 +169,8 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
bit_at_byte: cidr / 8,
bit_at_shift: 7 - (cidr % 8),
}
+ newNode.maskSelf()
+ newNode.addToPeerEntries()
cidr = min(cidr, common)
@@ -146,12 +185,13 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
// create new parent for node & newNode
parent := &trieEntry{
- bits: ip,
+ bits: append([]byte{}, ip...),
peer: nil,
cidr: cidr,
bit_at_byte: cidr / 8,
bit_at_shift: 7 - (cidr % 8),
}
+ parent.maskSelf()
bit := parent.choose(ip)
parent.child[bit] = newNode
@@ -176,44 +216,21 @@ func (node *trieEntry) lookup(ip net.IP) *Peer {
return found
}
-func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
- if node == nil {
- return results
- }
- if node.peer == p {
- mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
- results = append(results, net.IPNet{
- Mask: mask,
- IP: node.bits.Mask(mask),
- })
- }
- results = node.child[0].entriesForPeer(p, results)
- results = node.child[1].entriesForPeer(p, results)
- return results
-}
-
type AllowedIPs struct {
IPv4 *trieEntry
IPv6 *trieEntry
mutex sync.RWMutex
}
-func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
+func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint) bool) {
table.mutex.RLock()
defer table.mutex.RUnlock()
- allowed := make([]net.IPNet, 0, 10)
- allowed = table.IPv4.entriesForPeer(peer, allowed)
- allowed = table.IPv6.entriesForPeer(peer, allowed)
- return allowed
-}
-
-func (table *AllowedIPs) Reset() {
- table.mutex.Lock()
- defer table.mutex.Unlock()
-
- table.IPv4 = nil
- table.IPv6 = nil
+ for node := peer.firstTrieEntry; node != nil; node = node.nextEntryForPeer {
+ if !cb(node.bits, node.cidr) {
+ return
+ }
+ }
}
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
diff --git a/device/device.go b/device/device.go
index ebcbd9e..47c4944 100644
--- a/device/device.go
+++ b/device/device.go
@@ -314,7 +314,6 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
device.rate.underLoadUntil.Store(time.Time{})
device.indexTable.Init()
- device.allowedips.Reset()
device.PopulatePools()
diff --git a/device/peer.go b/device/peer.go
index 5324ae4..a103b5d 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -28,6 +28,7 @@ type Peer struct {
device *Device
endpoint conn.Endpoint
persistentKeepaliveInterval uint32 // accessed atomically
+ firstTrieEntry *trieEntry
// These fields are accessed with atomic operations, which must be
// 64-bit aligned even on 32-bit platforms. Go guarantees that an
diff --git a/device/uapi.go b/device/uapi.go
index 148a7a2..cbfe25e 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -108,9 +108,10 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
- for _, ip := range device.allowedips.EntriesForPeer(peer) {
- sendf("allowed_ip=%s", ip.String())
- }
+ device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint) bool {
+ sendf("allowed_ip=%s/%d", ip.String(), cidr)
+ return true
+ })
}
}()