diff options
Diffstat (limited to 'ratelimiter')
-rw-r--r-- | ratelimiter/ratelimiter.go | 58 | ||||
-rw-r--r-- | ratelimiter/ratelimiter_test.go | 33 |
2 files changed, 29 insertions, 62 deletions
diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go index 2f7aa2a..8e78d5e 100644 --- a/ratelimiter/ratelimiter.go +++ b/ratelimiter/ratelimiter.go @@ -6,9 +6,10 @@ package ratelimiter import ( - "net" "sync" "time" + + "golang.zx2c4.com/go118/netip" ) const ( @@ -30,8 +31,7 @@ type Ratelimiter struct { timeNow func() time.Time stopReset chan struct{} // send to reset, close to stop - tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry - tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry + table map[netip.Addr]*RatelimiterEntry } func (rate *Ratelimiter) Close() { @@ -57,8 +57,7 @@ func (rate *Ratelimiter) Init() { } rate.stopReset = make(chan struct{}) - rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) - rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry) + rate.table = make(map[netip.Addr]*RatelimiterEntry) stopReset := rate.stopReset // store in case Init is called again. @@ -87,71 +86,39 @@ func (rate *Ratelimiter) cleanup() (empty bool) { rate.mu.Lock() defer rate.mu.Unlock() - for key, entry := range rate.tableIPv4 { + for key, entry := range rate.table { entry.mu.Lock() if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { - delete(rate.tableIPv4, key) + delete(rate.table, key) } entry.mu.Unlock() } - for key, entry := range rate.tableIPv6 { - entry.mu.Lock() - if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { - delete(rate.tableIPv6, key) - } - entry.mu.Unlock() - } - - return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 + return len(rate.table) == 0 } -func (rate *Ratelimiter) Allow(ip net.IP) bool { +func (rate *Ratelimiter) Allow(ip netip.Addr) bool { var entry *RatelimiterEntry - var keyIPv4 [net.IPv4len]byte - var keyIPv6 [net.IPv6len]byte - // lookup entry - - IPv4 := ip.To4() - IPv6 := ip.To16() - rate.mu.RLock() - - if IPv4 != nil { - copy(keyIPv4[:], IPv4) - entry = rate.tableIPv4[keyIPv4] - } else { - copy(keyIPv6[:], IPv6) - entry = rate.tableIPv6[keyIPv6] - } - + entry = rate.table[ip] rate.mu.RUnlock() // make new entry if not found - if entry == nil { entry = new(RatelimiterEntry) entry.tokens = maxTokens - packetCost entry.lastTime = rate.timeNow() rate.mu.Lock() - if IPv4 != nil { - rate.tableIPv4[keyIPv4] = entry - if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 { - rate.stopReset <- struct{}{} - } - } else { - rate.tableIPv6[keyIPv6] = entry - if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 { - rate.stopReset <- struct{}{} - } + rate.table[ip] = entry + if len(rate.table) == 1 { + rate.stopReset <- struct{}{} } rate.mu.Unlock() return true } // add tokens to entry - entry.mu.Lock() now := rate.timeNow() entry.tokens += now.Sub(entry.lastTime).Nanoseconds() @@ -161,7 +128,6 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool { } // subtract cost of packet - if entry.tokens > packetCost { entry.tokens -= packetCost entry.mu.Unlock() diff --git a/ratelimiter/ratelimiter_test.go b/ratelimiter/ratelimiter_test.go index f231fe5..3e06ff7 100644 --- a/ratelimiter/ratelimiter_test.go +++ b/ratelimiter/ratelimiter_test.go @@ -6,9 +6,10 @@ package ratelimiter import ( - "net" "testing" "time" + + "golang.zx2c4.com/go118/netip" ) type result struct { @@ -71,21 +72,21 @@ func TestRatelimiter(t *testing.T) { text: "packet following 2 packet burst", }) - ips := []net.IP{ - net.ParseIP("127.0.0.1"), - net.ParseIP("192.168.1.1"), - net.ParseIP("172.167.2.3"), - net.ParseIP("97.231.252.215"), - net.ParseIP("248.97.91.167"), - net.ParseIP("188.208.233.47"), - net.ParseIP("104.2.183.179"), - net.ParseIP("72.129.46.120"), - net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), - net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"), - net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), - net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), - net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), - net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), + ips := []netip.Addr{ + netip.MustParseAddr("127.0.0.1"), + netip.MustParseAddr("192.168.1.1"), + netip.MustParseAddr("172.167.2.3"), + netip.MustParseAddr("97.231.252.215"), + netip.MustParseAddr("248.97.91.167"), + netip.MustParseAddr("188.208.233.47"), + netip.MustParseAddr("104.2.183.179"), + netip.MustParseAddr("72.129.46.120"), + netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), + netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"), + netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), + netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), + netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), + netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), } now := time.Now() |