diff options
Diffstat (limited to 'ratelimiter/ratelimiter.go')
-rw-r--r-- | ratelimiter/ratelimiter.go | 58 |
1 files changed, 12 insertions, 46 deletions
diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go index 2f7aa2a..8e78d5e 100644 --- a/ratelimiter/ratelimiter.go +++ b/ratelimiter/ratelimiter.go @@ -6,9 +6,10 @@ package ratelimiter import ( - "net" "sync" "time" + + "golang.zx2c4.com/go118/netip" ) const ( @@ -30,8 +31,7 @@ type Ratelimiter struct { timeNow func() time.Time stopReset chan struct{} // send to reset, close to stop - tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry - tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry + table map[netip.Addr]*RatelimiterEntry } func (rate *Ratelimiter) Close() { @@ -57,8 +57,7 @@ func (rate *Ratelimiter) Init() { } rate.stopReset = make(chan struct{}) - rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) - rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry) + rate.table = make(map[netip.Addr]*RatelimiterEntry) stopReset := rate.stopReset // store in case Init is called again. @@ -87,71 +86,39 @@ func (rate *Ratelimiter) cleanup() (empty bool) { rate.mu.Lock() defer rate.mu.Unlock() - for key, entry := range rate.tableIPv4 { + for key, entry := range rate.table { entry.mu.Lock() if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { - delete(rate.tableIPv4, key) + delete(rate.table, key) } entry.mu.Unlock() } - for key, entry := range rate.tableIPv6 { - entry.mu.Lock() - if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { - delete(rate.tableIPv6, key) - } - entry.mu.Unlock() - } - - return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 + return len(rate.table) == 0 } -func (rate *Ratelimiter) Allow(ip net.IP) bool { +func (rate *Ratelimiter) Allow(ip netip.Addr) bool { var entry *RatelimiterEntry - var keyIPv4 [net.IPv4len]byte - var keyIPv6 [net.IPv6len]byte - // lookup entry - - IPv4 := ip.To4() - IPv6 := ip.To16() - rate.mu.RLock() - - if IPv4 != nil { - copy(keyIPv4[:], IPv4) - entry = rate.tableIPv4[keyIPv4] - } else { - copy(keyIPv6[:], IPv6) - entry = rate.tableIPv6[keyIPv6] - } - + entry = rate.table[ip] rate.mu.RUnlock() // make new entry if not found - if entry == nil { entry = new(RatelimiterEntry) entry.tokens = maxTokens - packetCost entry.lastTime = rate.timeNow() rate.mu.Lock() - if IPv4 != nil { - rate.tableIPv4[keyIPv4] = entry - if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 { - rate.stopReset <- struct{}{} - } - } else { - rate.tableIPv6[keyIPv6] = entry - if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 { - rate.stopReset <- struct{}{} - } + rate.table[ip] = entry + if len(rate.table) == 1 { + rate.stopReset <- struct{}{} } rate.mu.Unlock() return true } // add tokens to entry - entry.mu.Lock() now := rate.timeNow() entry.tokens += now.Sub(entry.lastTime).Nanoseconds() @@ -161,7 +128,6 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool { } // subtract cost of packet - if entry.tokens > packetCost { entry.tokens -= packetCost entry.mu.Unlock() |