diff options
author | David Crawshaw <crawshaw@tailscale.com> | 2019-12-08 18:22:31 -0500 |
---|---|---|
committer | David Crawshaw <david@zentus.com> | 2020-03-30 18:38:36 +1100 |
commit | 9cd8909df2ad882b81b611b4656020aeceb6c9b2 (patch) | |
tree | 3e09b88bfdf9069d96b4c98ef39aeaea7b5b1d3a /ratelimiter/ratelimiter.go | |
parent | ae88e2a2cda0faab68ad667223cd26ffd54d1bee (diff) |
ratelimiter: use a fake clock in tests and style cleanups
The existing test would occasionally flake out with:
--- FAIL: TestRatelimiter (0.12s)
ratelimiter_test.go:99: Test failed for 127.0.0.1 , on: 7 ( not having refilled enough ) expected: false got: true
FAIL
FAIL golang.zx2c4.com/wireguard/ratelimiter 0.171s
The fake clock also means the tests run much faster, so
testing this package with -count=1000 now takes < 100ms.
While here, several style cleanups. The most significant one
is unembeding the sync.Mutex fields in the rate limiter objects.
Embedded as they were, the lock methods were accessible
outside the ratelimiter package. As they aren't needed externally,
keep them internal to make them easier to reason about.
Passes `go test -race -count=10000 ./ratelimiter`
Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
Diffstat (limited to 'ratelimiter/ratelimiter.go')
-rw-r--r-- | ratelimiter/ratelimiter.go | 99 |
1 files changed, 53 insertions, 46 deletions
diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go index 772c45a..a6d0ea2 100644 --- a/ratelimiter/ratelimiter.go +++ b/ratelimiter/ratelimiter.go @@ -20,21 +20,23 @@ const ( ) type RatelimiterEntry struct { - sync.Mutex + mu sync.Mutex lastTime time.Time tokens int64 } type Ratelimiter struct { - sync.RWMutex - stopReset chan struct{} + mu sync.RWMutex + timeNow func() time.Time + + stopReset chan struct{} // send to reset, close to stop tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry } func (rate *Ratelimiter) Close() { - rate.Lock() - defer rate.Unlock() + rate.mu.Lock() + defer rate.mu.Unlock() if rate.stopReset != nil { close(rate.stopReset) @@ -42,11 +44,14 @@ func (rate *Ratelimiter) Close() { } func (rate *Ratelimiter) Init() { - rate.Lock() - defer rate.Unlock() + rate.mu.Lock() + defer rate.mu.Unlock() - // stop any ongoing garbage collection routine + if rate.timeNow == nil { + rate.timeNow = time.Now + } + // stop any ongoing garbage collection routine if rate.stopReset != nil { close(rate.stopReset) } @@ -55,50 +60,52 @@ func (rate *Ratelimiter) Init() { rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry) - // start garbage collection routine + stopReset := rate.stopReset // store in case Init is called again. + // Start garbage collection routine. go func() { ticker := time.NewTicker(time.Second) ticker.Stop() for { select { - case _, ok := <-rate.stopReset: + case _, ok := <-stopReset: ticker.Stop() - if ok { - ticker = time.NewTicker(time.Second) - } else { + if !ok { return } + ticker = time.NewTicker(time.Second) case <-ticker.C: - func() { - rate.Lock() - defer rate.Unlock() - - for key, entry := range rate.tableIPv4 { - entry.Lock() - if time.Since(entry.lastTime) > garbageCollectTime { - delete(rate.tableIPv4, key) - } - entry.Unlock() - } - - for key, entry := range rate.tableIPv6 { - entry.Lock() - if time.Since(entry.lastTime) > garbageCollectTime { - delete(rate.tableIPv6, key) - } - entry.Unlock() - } - - if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 { - ticker.Stop() - } - }() + if rate.cleanup() { + ticker.Stop() + } } } }() } +func (rate *Ratelimiter) cleanup() (empty bool) { + rate.mu.Lock() + defer rate.mu.Unlock() + + for key, entry := range rate.tableIPv4 { + entry.mu.Lock() + if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { + delete(rate.tableIPv4, key) + } + entry.mu.Unlock() + } + + for key, entry := range rate.tableIPv6 { + entry.mu.Lock() + if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { + delete(rate.tableIPv6, key) + } + entry.mu.Unlock() + } + + return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 +} + func (rate *Ratelimiter) Allow(ip net.IP) bool { var entry *RatelimiterEntry var keyIPv4 [net.IPv4len]byte @@ -109,7 +116,7 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool { IPv4 := ip.To4() IPv6 := ip.To16() - rate.RLock() + rate.mu.RLock() if IPv4 != nil { copy(keyIPv4[:], IPv4) @@ -119,15 +126,15 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool { entry = rate.tableIPv6[keyIPv6] } - rate.RUnlock() + rate.mu.RUnlock() // make new entry if not found if entry == nil { entry = new(RatelimiterEntry) entry.tokens = maxTokens - packetCost - entry.lastTime = time.Now() - rate.Lock() + entry.lastTime = rate.timeNow() + rate.mu.Lock() if IPv4 != nil { rate.tableIPv4[keyIPv4] = entry if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 { @@ -139,14 +146,14 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool { rate.stopReset <- struct{}{} } } - rate.Unlock() + rate.mu.Unlock() return true } // add tokens to entry - entry.Lock() - now := time.Now() + entry.mu.Lock() + now := rate.timeNow() entry.tokens += now.Sub(entry.lastTime).Nanoseconds() entry.lastTime = now if entry.tokens > maxTokens { @@ -157,9 +164,9 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool { if entry.tokens > packetCost { entry.tokens -= packetCost - entry.Unlock() + entry.mu.Unlock() return true } - entry.Unlock() + entry.mu.Unlock() return false } |