summaryrefslogtreecommitdiffhomepage
path: root/ratelimiter
diff options
context:
space:
mode:
authorDavid Crawshaw <crawshaw@tailscale.com>2019-12-08 18:22:31 -0500
committerDavid Crawshaw <david@zentus.com>2020-03-30 18:38:36 +1100
commit9cd8909df2ad882b81b611b4656020aeceb6c9b2 (patch)
tree3e09b88bfdf9069d96b4c98ef39aeaea7b5b1d3a /ratelimiter
parentae88e2a2cda0faab68ad667223cd26ffd54d1bee (diff)
ratelimiter: use a fake clock in tests and style cleanups
The existing test would occasionally flake out with: --- FAIL: TestRatelimiter (0.12s) ratelimiter_test.go:99: Test failed for 127.0.0.1 , on: 7 ( not having refilled enough ) expected: false got: true FAIL FAIL golang.zx2c4.com/wireguard/ratelimiter 0.171s The fake clock also means the tests run much faster, so testing this package with -count=1000 now takes < 100ms. While here, several style cleanups. The most significant one is unembeding the sync.Mutex fields in the rate limiter objects. Embedded as they were, the lock methods were accessible outside the ratelimiter package. As they aren't needed externally, keep them internal to make them easier to reason about. Passes `go test -race -count=10000 ./ratelimiter` Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
Diffstat (limited to 'ratelimiter')
-rw-r--r--ratelimiter/ratelimiter.go99
-rw-r--r--ratelimiter/ratelimiter_test.go54
2 files changed, 88 insertions, 65 deletions
diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go
index 772c45a..a6d0ea2 100644
--- a/ratelimiter/ratelimiter.go
+++ b/ratelimiter/ratelimiter.go
@@ -20,21 +20,23 @@ const (
)
type RatelimiterEntry struct {
- sync.Mutex
+ mu sync.Mutex
lastTime time.Time
tokens int64
}
type Ratelimiter struct {
- sync.RWMutex
- stopReset chan struct{}
+ mu sync.RWMutex
+ timeNow func() time.Time
+
+ stopReset chan struct{} // send to reset, close to stop
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
}
func (rate *Ratelimiter) Close() {
- rate.Lock()
- defer rate.Unlock()
+ rate.mu.Lock()
+ defer rate.mu.Unlock()
if rate.stopReset != nil {
close(rate.stopReset)
@@ -42,11 +44,14 @@ func (rate *Ratelimiter) Close() {
}
func (rate *Ratelimiter) Init() {
- rate.Lock()
- defer rate.Unlock()
+ rate.mu.Lock()
+ defer rate.mu.Unlock()
- // stop any ongoing garbage collection routine
+ if rate.timeNow == nil {
+ rate.timeNow = time.Now
+ }
+ // stop any ongoing garbage collection routine
if rate.stopReset != nil {
close(rate.stopReset)
}
@@ -55,50 +60,52 @@ func (rate *Ratelimiter) Init() {
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
- // start garbage collection routine
+ stopReset := rate.stopReset // store in case Init is called again.
+ // Start garbage collection routine.
go func() {
ticker := time.NewTicker(time.Second)
ticker.Stop()
for {
select {
- case _, ok := <-rate.stopReset:
+ case _, ok := <-stopReset:
ticker.Stop()
- if ok {
- ticker = time.NewTicker(time.Second)
- } else {
+ if !ok {
return
}
+ ticker = time.NewTicker(time.Second)
case <-ticker.C:
- func() {
- rate.Lock()
- defer rate.Unlock()
-
- for key, entry := range rate.tableIPv4 {
- entry.Lock()
- if time.Since(entry.lastTime) > garbageCollectTime {
- delete(rate.tableIPv4, key)
- }
- entry.Unlock()
- }
-
- for key, entry := range rate.tableIPv6 {
- entry.Lock()
- if time.Since(entry.lastTime) > garbageCollectTime {
- delete(rate.tableIPv6, key)
- }
- entry.Unlock()
- }
-
- if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 {
- ticker.Stop()
- }
- }()
+ if rate.cleanup() {
+ ticker.Stop()
+ }
}
}
}()
}
+func (rate *Ratelimiter) cleanup() (empty bool) {
+ rate.mu.Lock()
+ defer rate.mu.Unlock()
+
+ for key, entry := range rate.tableIPv4 {
+ entry.mu.Lock()
+ if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
+ delete(rate.tableIPv4, key)
+ }
+ entry.mu.Unlock()
+ }
+
+ for key, entry := range rate.tableIPv6 {
+ entry.mu.Lock()
+ if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
+ delete(rate.tableIPv6, key)
+ }
+ entry.mu.Unlock()
+ }
+
+ return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
+}
+
func (rate *Ratelimiter) Allow(ip net.IP) bool {
var entry *RatelimiterEntry
var keyIPv4 [net.IPv4len]byte
@@ -109,7 +116,7 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
IPv4 := ip.To4()
IPv6 := ip.To16()
- rate.RLock()
+ rate.mu.RLock()
if IPv4 != nil {
copy(keyIPv4[:], IPv4)
@@ -119,15 +126,15 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
entry = rate.tableIPv6[keyIPv6]
}
- rate.RUnlock()
+ rate.mu.RUnlock()
// make new entry if not found
if entry == nil {
entry = new(RatelimiterEntry)
entry.tokens = maxTokens - packetCost
- entry.lastTime = time.Now()
- rate.Lock()
+ entry.lastTime = rate.timeNow()
+ rate.mu.Lock()
if IPv4 != nil {
rate.tableIPv4[keyIPv4] = entry
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
@@ -139,14 +146,14 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
rate.stopReset <- struct{}{}
}
}
- rate.Unlock()
+ rate.mu.Unlock()
return true
}
// add tokens to entry
- entry.Lock()
- now := time.Now()
+ entry.mu.Lock()
+ now := rate.timeNow()
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
entry.lastTime = now
if entry.tokens > maxTokens {
@@ -157,9 +164,9 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
if entry.tokens > packetCost {
entry.tokens -= packetCost
- entry.Unlock()
+ entry.mu.Unlock()
return true
}
- entry.Unlock()
+ entry.mu.Unlock()
return false
}
diff --git a/ratelimiter/ratelimiter_test.go b/ratelimiter/ratelimiter_test.go
index 659bdfb..25d5d63 100644
--- a/ratelimiter/ratelimiter_test.go
+++ b/ratelimiter/ratelimiter_test.go
@@ -11,22 +11,21 @@ import (
"time"
)
-type RatelimiterResult struct {
+type result struct {
allowed bool
text string
wait time.Duration
}
func TestRatelimiter(t *testing.T) {
+ var rate Ratelimiter
+ var expectedResults []result
- var ratelimiter Ratelimiter
- var expectedResults []RatelimiterResult
-
- Nano := func(nano int64) time.Duration {
+ nano := func(nano int64) time.Duration {
return time.Nanosecond * time.Duration(nano)
}
- Add := func(res RatelimiterResult) {
+ add := func(res result) {
expectedResults = append(
expectedResults,
res,
@@ -34,40 +33,40 @@ func TestRatelimiter(t *testing.T) {
}
for i := 0; i < packetsBurstable; i++ {
- Add(RatelimiterResult{
+ add(result{
allowed: true,
text: "initial burst",
})
}
- Add(RatelimiterResult{
+ add(result{
allowed: false,
text: "after burst",
})
- Add(RatelimiterResult{
+ add(result{
allowed: true,
- wait: Nano(time.Second.Nanoseconds() / packetsPerSecond),
+ wait: nano(time.Second.Nanoseconds() / packetsPerSecond),
text: "filling tokens for single packet",
})
- Add(RatelimiterResult{
+ add(result{
allowed: false,
text: "not having refilled enough",
})
- Add(RatelimiterResult{
+ add(result{
allowed: true,
- wait: 2 * (Nano(time.Second.Nanoseconds() / packetsPerSecond)),
+ wait: 2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)),
text: "filling tokens for two packet burst",
})
- Add(RatelimiterResult{
+ add(result{
allowed: true,
text: "second packet in 2 packet burst",
})
- Add(RatelimiterResult{
+ add(result{
allowed: false,
text: "packet following 2 packet burst",
})
@@ -89,14 +88,31 @@ func TestRatelimiter(t *testing.T) {
net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
}
- ratelimiter.Init()
+ now := time.Now()
+ rate.timeNow = func() time.Time {
+ return now
+ }
+ defer func() {
+ // Lock to avoid data race with cleanup goroutine from Init.
+ rate.mu.Lock()
+ defer rate.mu.Unlock()
+
+ rate.timeNow = time.Now
+ }()
+ timeSleep := func(d time.Duration) {
+ now = now.Add(d + 1)
+ rate.cleanup()
+ }
+
+ rate.Init()
+ defer rate.Close()
for i, res := range expectedResults {
- time.Sleep(res.wait)
+ timeSleep(res.wait)
for _, ip := range ips {
- allowed := ratelimiter.Allow(ip)
+ allowed := rate.Allow(ip)
if allowed != res.allowed {
- t.Fatal("Test failed for", ip.String(), ", on:", i, "(", res.text, ")", "expected:", res.allowed, "got:", allowed)
+ t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed)
}
}
}