summaryrefslogtreecommitdiffhomepage
path: root/ratelimiter
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2021-11-05 01:52:54 +0100
committerJason A. Donenfeld <Jason@zx2c4.com>2021-11-23 22:03:15 +0100
commitef8d6804d77d9ce09f0e2c7f6d85bbe222712b73 (patch)
tree5b4a3b53dfb092f10cf11fbe0b5724f58df3a1bf /ratelimiter
parentde7c702ace45b8eeba7f4de8ecd9c85c80806264 (diff)
global: use netip where possible now
There are more places where we'll need to add it later, when Go 1.18 comes out with support for it in the "net" package. Also, allowedips still uses slices internally, which might be suboptimal. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'ratelimiter')
-rw-r--r--ratelimiter/ratelimiter.go58
-rw-r--r--ratelimiter/ratelimiter_test.go33
2 files changed, 29 insertions, 62 deletions
diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go
index 2f7aa2a..8e78d5e 100644
--- a/ratelimiter/ratelimiter.go
+++ b/ratelimiter/ratelimiter.go
@@ -6,9 +6,10 @@
package ratelimiter
import (
- "net"
"sync"
"time"
+
+ "golang.zx2c4.com/go118/netip"
)
const (
@@ -30,8 +31,7 @@ type Ratelimiter struct {
timeNow func() time.Time
stopReset chan struct{} // send to reset, close to stop
- tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
- tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
+ table map[netip.Addr]*RatelimiterEntry
}
func (rate *Ratelimiter) Close() {
@@ -57,8 +57,7 @@ func (rate *Ratelimiter) Init() {
}
rate.stopReset = make(chan struct{})
- rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
- rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
+ rate.table = make(map[netip.Addr]*RatelimiterEntry)
stopReset := rate.stopReset // store in case Init is called again.
@@ -87,71 +86,39 @@ func (rate *Ratelimiter) cleanup() (empty bool) {
rate.mu.Lock()
defer rate.mu.Unlock()
- for key, entry := range rate.tableIPv4 {
+ for key, entry := range rate.table {
entry.mu.Lock()
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
- delete(rate.tableIPv4, key)
+ delete(rate.table, key)
}
entry.mu.Unlock()
}
- for key, entry := range rate.tableIPv6 {
- entry.mu.Lock()
- if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
- delete(rate.tableIPv6, key)
- }
- entry.mu.Unlock()
- }
-
- return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
+ return len(rate.table) == 0
}
-func (rate *Ratelimiter) Allow(ip net.IP) bool {
+func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
var entry *RatelimiterEntry
- var keyIPv4 [net.IPv4len]byte
- var keyIPv6 [net.IPv6len]byte
-
// lookup entry
-
- IPv4 := ip.To4()
- IPv6 := ip.To16()
-
rate.mu.RLock()
-
- if IPv4 != nil {
- copy(keyIPv4[:], IPv4)
- entry = rate.tableIPv4[keyIPv4]
- } else {
- copy(keyIPv6[:], IPv6)
- entry = rate.tableIPv6[keyIPv6]
- }
-
+ entry = rate.table[ip]
rate.mu.RUnlock()
// make new entry if not found
-
if entry == nil {
entry = new(RatelimiterEntry)
entry.tokens = maxTokens - packetCost
entry.lastTime = rate.timeNow()
rate.mu.Lock()
- if IPv4 != nil {
- rate.tableIPv4[keyIPv4] = entry
- if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
- rate.stopReset <- struct{}{}
- }
- } else {
- rate.tableIPv6[keyIPv6] = entry
- if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
- rate.stopReset <- struct{}{}
- }
+ rate.table[ip] = entry
+ if len(rate.table) == 1 {
+ rate.stopReset <- struct{}{}
}
rate.mu.Unlock()
return true
}
// add tokens to entry
-
entry.mu.Lock()
now := rate.timeNow()
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
@@ -161,7 +128,6 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
}
// subtract cost of packet
-
if entry.tokens > packetCost {
entry.tokens -= packetCost
entry.mu.Unlock()
diff --git a/ratelimiter/ratelimiter_test.go b/ratelimiter/ratelimiter_test.go
index f231fe5..3e06ff7 100644
--- a/ratelimiter/ratelimiter_test.go
+++ b/ratelimiter/ratelimiter_test.go
@@ -6,9 +6,10 @@
package ratelimiter
import (
- "net"
"testing"
"time"
+
+ "golang.zx2c4.com/go118/netip"
)
type result struct {
@@ -71,21 +72,21 @@ func TestRatelimiter(t *testing.T) {
text: "packet following 2 packet burst",
})
- ips := []net.IP{
- net.ParseIP("127.0.0.1"),
- net.ParseIP("192.168.1.1"),
- net.ParseIP("172.167.2.3"),
- net.ParseIP("97.231.252.215"),
- net.ParseIP("248.97.91.167"),
- net.ParseIP("188.208.233.47"),
- net.ParseIP("104.2.183.179"),
- net.ParseIP("72.129.46.120"),
- net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
- net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
- net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
- net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
- net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
- net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
+ ips := []netip.Addr{
+ netip.MustParseAddr("127.0.0.1"),
+ netip.MustParseAddr("192.168.1.1"),
+ netip.MustParseAddr("172.167.2.3"),
+ netip.MustParseAddr("97.231.252.215"),
+ netip.MustParseAddr("248.97.91.167"),
+ netip.MustParseAddr("188.208.233.47"),
+ netip.MustParseAddr("104.2.183.179"),
+ netip.MustParseAddr("72.129.46.120"),
+ netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
+ netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
+ netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
+ netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
+ netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
+ netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
}
now := time.Now()