diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-11-05 01:52:54 +0100 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-11-23 22:03:15 +0100 |
commit | ef8d6804d77d9ce09f0e2c7f6d85bbe222712b73 (patch) | |
tree | 5b4a3b53dfb092f10cf11fbe0b5724f58df3a1bf /ratelimiter | |
parent | de7c702ace45b8eeba7f4de8ecd9c85c80806264 (diff) |
global: use netip where possible now
There are more places where we'll need to add it later, when Go 1.18
comes out with support for it in the "net" package. Also, allowedips
still uses slices internally, which might be suboptimal.
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'ratelimiter')
-rw-r--r-- | ratelimiter/ratelimiter.go | 58 | ||||
-rw-r--r-- | ratelimiter/ratelimiter_test.go | 33 |
2 files changed, 29 insertions, 62 deletions
diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go index 2f7aa2a..8e78d5e 100644 --- a/ratelimiter/ratelimiter.go +++ b/ratelimiter/ratelimiter.go @@ -6,9 +6,10 @@ package ratelimiter import ( - "net" "sync" "time" + + "golang.zx2c4.com/go118/netip" ) const ( @@ -30,8 +31,7 @@ type Ratelimiter struct { timeNow func() time.Time stopReset chan struct{} // send to reset, close to stop - tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry - tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry + table map[netip.Addr]*RatelimiterEntry } func (rate *Ratelimiter) Close() { @@ -57,8 +57,7 @@ func (rate *Ratelimiter) Init() { } rate.stopReset = make(chan struct{}) - rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) - rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry) + rate.table = make(map[netip.Addr]*RatelimiterEntry) stopReset := rate.stopReset // store in case Init is called again. @@ -87,71 +86,39 @@ func (rate *Ratelimiter) cleanup() (empty bool) { rate.mu.Lock() defer rate.mu.Unlock() - for key, entry := range rate.tableIPv4 { + for key, entry := range rate.table { entry.mu.Lock() if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { - delete(rate.tableIPv4, key) + delete(rate.table, key) } entry.mu.Unlock() } - for key, entry := range rate.tableIPv6 { - entry.mu.Lock() - if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { - delete(rate.tableIPv6, key) - } - entry.mu.Unlock() - } - - return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 + return len(rate.table) == 0 } -func (rate *Ratelimiter) Allow(ip net.IP) bool { +func (rate *Ratelimiter) Allow(ip netip.Addr) bool { var entry *RatelimiterEntry - var keyIPv4 [net.IPv4len]byte - var keyIPv6 [net.IPv6len]byte - // lookup entry - - IPv4 := ip.To4() - IPv6 := ip.To16() - rate.mu.RLock() - - if IPv4 != nil { - copy(keyIPv4[:], IPv4) - entry = rate.tableIPv4[keyIPv4] - } else { - copy(keyIPv6[:], IPv6) - entry = rate.tableIPv6[keyIPv6] - } - + entry = rate.table[ip] rate.mu.RUnlock() // make new entry if not found - if entry == nil { entry = new(RatelimiterEntry) entry.tokens = maxTokens - packetCost entry.lastTime = rate.timeNow() rate.mu.Lock() - if IPv4 != nil { - rate.tableIPv4[keyIPv4] = entry - if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 { - rate.stopReset <- struct{}{} - } - } else { - rate.tableIPv6[keyIPv6] = entry - if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 { - rate.stopReset <- struct{}{} - } + rate.table[ip] = entry + if len(rate.table) == 1 { + rate.stopReset <- struct{}{} } rate.mu.Unlock() return true } // add tokens to entry - entry.mu.Lock() now := rate.timeNow() entry.tokens += now.Sub(entry.lastTime).Nanoseconds() @@ -161,7 +128,6 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool { } // subtract cost of packet - if entry.tokens > packetCost { entry.tokens -= packetCost entry.mu.Unlock() diff --git a/ratelimiter/ratelimiter_test.go b/ratelimiter/ratelimiter_test.go index f231fe5..3e06ff7 100644 --- a/ratelimiter/ratelimiter_test.go +++ b/ratelimiter/ratelimiter_test.go @@ -6,9 +6,10 @@ package ratelimiter import ( - "net" "testing" "time" + + "golang.zx2c4.com/go118/netip" ) type result struct { @@ -71,21 +72,21 @@ func TestRatelimiter(t *testing.T) { text: "packet following 2 packet burst", }) - ips := []net.IP{ - net.ParseIP("127.0.0.1"), - net.ParseIP("192.168.1.1"), - net.ParseIP("172.167.2.3"), - net.ParseIP("97.231.252.215"), - net.ParseIP("248.97.91.167"), - net.ParseIP("188.208.233.47"), - net.ParseIP("104.2.183.179"), - net.ParseIP("72.129.46.120"), - net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), - net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"), - net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), - net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), - net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), - net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), + ips := []netip.Addr{ + netip.MustParseAddr("127.0.0.1"), + netip.MustParseAddr("192.168.1.1"), + netip.MustParseAddr("172.167.2.3"), + netip.MustParseAddr("97.231.252.215"), + netip.MustParseAddr("248.97.91.167"), + netip.MustParseAddr("188.208.233.47"), + netip.MustParseAddr("104.2.183.179"), + netip.MustParseAddr("72.129.46.120"), + netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), + netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"), + netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), + netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), + netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), + netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), } now := time.Now() |