diff options
Diffstat (limited to 'ratelimiter/ratelimiter_test.go')
-rw-r--r-- | ratelimiter/ratelimiter_test.go | 54 |
1 files changed, 35 insertions, 19 deletions
diff --git a/ratelimiter/ratelimiter_test.go b/ratelimiter/ratelimiter_test.go index 659bdfb..25d5d63 100644 --- a/ratelimiter/ratelimiter_test.go +++ b/ratelimiter/ratelimiter_test.go @@ -11,22 +11,21 @@ import ( "time" ) -type RatelimiterResult struct { +type result struct { allowed bool text string wait time.Duration } func TestRatelimiter(t *testing.T) { + var rate Ratelimiter + var expectedResults []result - var ratelimiter Ratelimiter - var expectedResults []RatelimiterResult - - Nano := func(nano int64) time.Duration { + nano := func(nano int64) time.Duration { return time.Nanosecond * time.Duration(nano) } - Add := func(res RatelimiterResult) { + add := func(res result) { expectedResults = append( expectedResults, res, @@ -34,40 +33,40 @@ func TestRatelimiter(t *testing.T) { } for i := 0; i < packetsBurstable; i++ { - Add(RatelimiterResult{ + add(result{ allowed: true, text: "initial burst", }) } - Add(RatelimiterResult{ + add(result{ allowed: false, text: "after burst", }) - Add(RatelimiterResult{ + add(result{ allowed: true, - wait: Nano(time.Second.Nanoseconds() / packetsPerSecond), + wait: nano(time.Second.Nanoseconds() / packetsPerSecond), text: "filling tokens for single packet", }) - Add(RatelimiterResult{ + add(result{ allowed: false, text: "not having refilled enough", }) - Add(RatelimiterResult{ + add(result{ allowed: true, - wait: 2 * (Nano(time.Second.Nanoseconds() / packetsPerSecond)), + wait: 2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)), text: "filling tokens for two packet burst", }) - Add(RatelimiterResult{ + add(result{ allowed: true, text: "second packet in 2 packet burst", }) - Add(RatelimiterResult{ + add(result{ allowed: false, text: "packet following 2 packet burst", }) @@ -89,14 +88,31 @@ func TestRatelimiter(t *testing.T) { net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), } - ratelimiter.Init() + now := time.Now() + rate.timeNow = func() time.Time { + return now + } + defer func() { + // Lock to avoid data race with cleanup goroutine from Init. + rate.mu.Lock() + defer rate.mu.Unlock() + + rate.timeNow = time.Now + }() + timeSleep := func(d time.Duration) { + now = now.Add(d + 1) + rate.cleanup() + } + + rate.Init() + defer rate.Close() for i, res := range expectedResults { - time.Sleep(res.wait) + timeSleep(res.wait) for _, ip := range ips { - allowed := ratelimiter.Allow(ip) + allowed := rate.Allow(ip) if allowed != res.allowed { - t.Fatal("Test failed for", ip.String(), ", on:", i, "(", res.text, ")", "expected:", res.allowed, "got:", allowed) + t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed) } } } |