diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-06-03 16:12:29 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-06-03 16:29:43 +0200 |
commit | 841756e328c743fec624e9259921ea6d815911d5 (patch) | |
tree | 6b31dc3cd2e7b6e86c38ed6a1b7a36c326ef064d | |
parent | c382222eab9e3814f4df75fd25f8e9e31484b5e0 (diff) |
device: simplify allowedips lookup signature
The inliner should handle this for us.
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r-- | device/allowedips.go | 17 | ||||
-rw-r--r-- | device/allowedips_rand_test.go | 4 | ||||
-rw-r--r-- | device/allowedips_test.go | 6 | ||||
-rw-r--r-- | device/receive.go | 4 | ||||
-rw-r--r-- | device/send.go | 4 |
5 files changed, 18 insertions, 17 deletions
diff --git a/device/allowedips.go b/device/allowedips.go index 7af9fc7..95615ab 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -285,14 +285,15 @@ func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) { } } -func (table *AllowedIPs) LookupIPv4(address []byte) *Peer { +func (table *AllowedIPs) Lookup(address []byte) *Peer { table.mutex.RLock() defer table.mutex.RUnlock() - return table.IPv4.lookup(address) -} - -func (table *AllowedIPs) LookupIPv6(address []byte) *Peer { - table.mutex.RLock() - defer table.mutex.RUnlock() - return table.IPv6.lookup(address) + switch len(address) { + case net.IPv6len: + return table.IPv6.lookup(address) + case net.IPv4len: + return table.IPv4.lookup(address) + default: + panic(errors.New("looking up unknown address type")) + } } diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go index c5f80fe..8d1e633 100644 --- a/device/allowedips_rand_test.go +++ b/device/allowedips_rand_test.go @@ -108,7 +108,7 @@ func TestTrieRandom(t *testing.T) { var addr4 [4]byte rand.Read(addr4[:]) peer1 := slow4.Lookup(addr4[:]) - peer2 := allowedIPs.LookupIPv4(addr4[:]) + peer2 := allowedIPs.Lookup(addr4[:]) if peer1 != peer2 { t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2) } @@ -116,7 +116,7 @@ func TestTrieRandom(t *testing.T) { var addr6 [16]byte rand.Read(addr6[:]) peer1 = slow6.Lookup(addr6[:]) - peer2 = allowedIPs.LookupIPv6(addr6[:]) + peer2 = allowedIPs.Lookup(addr6[:]) if peer1 != peer2 { t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2) } diff --git a/device/allowedips_test.go b/device/allowedips_test.go index cbd32cc..7701cde 100644 --- a/device/allowedips_test.go +++ b/device/allowedips_test.go @@ -102,14 +102,14 @@ func TestTrieIPv4(t *testing.T) { } assertEQ := func(peer *Peer, a, b, c, d byte) { - p := allowedIPs.LookupIPv4([]byte{a, b, c, d}) + p := allowedIPs.Lookup([]byte{a, b, c, d}) if p != peer { t.Error("Assert EQ failed") } } assertNEQ := func(peer *Peer, a, b, c, d byte) { - p := allowedIPs.LookupIPv4([]byte{a, b, c, d}) + p := allowedIPs.Lookup([]byte{a, b, c, d}) if p == peer { t.Error("Assert NEQ failed") } @@ -208,7 +208,7 @@ func TestTrieIPv6(t *testing.T) { addr = append(addr, expand(b)...) addr = append(addr, expand(c)...) addr = append(addr, expand(d)...) - p := allowedIPs.LookupIPv6(addr) + p := allowedIPs.Lookup(addr) if p != peer { t.Error("Assert EQ failed") } diff --git a/device/receive.go b/device/receive.go index 1182246..5857481 100644 --- a/device/receive.go +++ b/device/receive.go @@ -447,7 +447,7 @@ func (peer *Peer) RoutineSequentialReceiver() { } elem.packet = elem.packet[:length] src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] - if device.allowedips.LookupIPv4(src) != peer { + if device.allowedips.Lookup(src) != peer { device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) goto skip } @@ -464,7 +464,7 @@ func (peer *Peer) RoutineSequentialReceiver() { } elem.packet = elem.packet[:length] src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] - if device.allowedips.LookupIPv6(src) != peer { + if device.allowedips.Lookup(src) != peer { device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) goto skip } diff --git a/device/send.go b/device/send.go index a4f07e4..b05c69e 100644 --- a/device/send.go +++ b/device/send.go @@ -254,14 +254,14 @@ func (device *Device) RoutineReadFromTUN() { continue } dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] - peer = device.allowedips.LookupIPv4(dst) + peer = device.allowedips.Lookup(dst) case ipv6.Version: if len(elem.packet) < ipv6.HeaderLen { continue } dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] - peer = device.allowedips.LookupIPv6(dst) + peer = device.allowedips.Lookup(dst) default: device.log.Verbosef("Received packet with unknown IP version") |