diff options
Diffstat (limited to 'ratelimiter/ratelimiter.go')
-rw-r--r-- | ratelimiter/ratelimiter.go | 99 |
1 files changed, 53 insertions, 46 deletions
diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go index 772c45a..a6d0ea2 100644 --- a/ratelimiter/ratelimiter.go +++ b/ratelimiter/ratelimiter.go @@ -20,21 +20,23 @@ const ( ) type RatelimiterEntry struct { - sync.Mutex + mu sync.Mutex lastTime time.Time tokens int64 } type Ratelimiter struct { - sync.RWMutex - stopReset chan struct{} + mu sync.RWMutex + timeNow func() time.Time + + stopReset chan struct{} // send to reset, close to stop tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry } func (rate *Ratelimiter) Close() { - rate.Lock() - defer rate.Unlock() + rate.mu.Lock() + defer rate.mu.Unlock() if rate.stopReset != nil { close(rate.stopReset) @@ -42,11 +44,14 @@ func (rate *Ratelimiter) Close() { } func (rate *Ratelimiter) Init() { - rate.Lock() - defer rate.Unlock() + rate.mu.Lock() + defer rate.mu.Unlock() - // stop any ongoing garbage collection routine + if rate.timeNow == nil { + rate.timeNow = time.Now + } + // stop any ongoing garbage collection routine if rate.stopReset != nil { close(rate.stopReset) } @@ -55,50 +60,52 @@ func (rate *Ratelimiter) Init() { rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry) - // start garbage collection routine + stopReset := rate.stopReset // store in case Init is called again. + // Start garbage collection routine. go func() { ticker := time.NewTicker(time.Second) ticker.Stop() for { select { - case _, ok := <-rate.stopReset: + case _, ok := <-stopReset: ticker.Stop() - if ok { - ticker = time.NewTicker(time.Second) - } else { + if !ok { return } + ticker = time.NewTicker(time.Second) case <-ticker.C: - func() { - rate.Lock() - defer rate.Unlock() - - for key, entry := range rate.tableIPv4 { - entry.Lock() - if time.Since(entry.lastTime) > garbageCollectTime { - delete(rate.tableIPv4, key) - } - entry.Unlock() - } - - for key, entry := range rate.tableIPv6 { - entry.Lock() - if time.Since(entry.lastTime) > garbageCollectTime { - delete(rate.tableIPv6, key) - } - entry.Unlock() - } - - if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 { - ticker.Stop() - } - }() + if rate.cleanup() { + ticker.Stop() + } } } }() } +func (rate *Ratelimiter) cleanup() (empty bool) { + rate.mu.Lock() + defer rate.mu.Unlock() + + for key, entry := range rate.tableIPv4 { + entry.mu.Lock() + if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { + delete(rate.tableIPv4, key) + } + entry.mu.Unlock() + } + + for key, entry := range rate.tableIPv6 { + entry.mu.Lock() + if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { + delete(rate.tableIPv6, key) + } + entry.mu.Unlock() + } + + return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 +} + func (rate *Ratelimiter) Allow(ip net.IP) bool { var entry *RatelimiterEntry var keyIPv4 [net.IPv4len]byte @@ -109,7 +116,7 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool { IPv4 := ip.To4() IPv6 := ip.To16() - rate.RLock() + rate.mu.RLock() if IPv4 != nil { copy(keyIPv4[:], IPv4) @@ -119,15 +126,15 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool { entry = rate.tableIPv6[keyIPv6] } - rate.RUnlock() + rate.mu.RUnlock() // make new entry if not found if entry == nil { entry = new(RatelimiterEntry) entry.tokens = maxTokens - packetCost - entry.lastTime = time.Now() - rate.Lock() + entry.lastTime = rate.timeNow() + rate.mu.Lock() if IPv4 != nil { rate.tableIPv4[keyIPv4] = entry if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 { @@ -139,14 +146,14 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool { rate.stopReset <- struct{}{} } } - rate.Unlock() + rate.mu.Unlock() return true } // add tokens to entry - entry.Lock() - now := time.Now() + entry.mu.Lock() + now := rate.timeNow() entry.tokens += now.Sub(entry.lastTime).Nanoseconds() entry.lastTime = now if entry.tokens > maxTokens { @@ -157,9 +164,9 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool { if entry.tokens > packetCost { entry.tokens -= packetCost - entry.Unlock() + entry.mu.Unlock() return true } - entry.Unlock() + entry.mu.Unlock() return false } |