summaryrefslogtreecommitdiffhomepage
path: root/ratelimiter
diff options
context:
space:
mode:
Diffstat (limited to 'ratelimiter')
-rw-r--r--ratelimiter/ratelimiter.go99
-rw-r--r--ratelimiter/ratelimiter_test.go54
2 files changed, 88 insertions, 65 deletions
diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go
index 772c45a..a6d0ea2 100644
--- a/ratelimiter/ratelimiter.go
+++ b/ratelimiter/ratelimiter.go
@@ -20,21 +20,23 @@ const (
)
type RatelimiterEntry struct {
- sync.Mutex
+ mu sync.Mutex
lastTime time.Time
tokens int64
}
type Ratelimiter struct {
- sync.RWMutex
- stopReset chan struct{}
+ mu sync.RWMutex
+ timeNow func() time.Time
+
+ stopReset chan struct{} // send to reset, close to stop
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
}
func (rate *Ratelimiter) Close() {
- rate.Lock()
- defer rate.Unlock()
+ rate.mu.Lock()
+ defer rate.mu.Unlock()
if rate.stopReset != nil {
close(rate.stopReset)
@@ -42,11 +44,14 @@ func (rate *Ratelimiter) Close() {
}
func (rate *Ratelimiter) Init() {
- rate.Lock()
- defer rate.Unlock()
+ rate.mu.Lock()
+ defer rate.mu.Unlock()
- // stop any ongoing garbage collection routine
+ if rate.timeNow == nil {
+ rate.timeNow = time.Now
+ }
+ // stop any ongoing garbage collection routine
if rate.stopReset != nil {
close(rate.stopReset)
}
@@ -55,50 +60,52 @@ func (rate *Ratelimiter) Init() {
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
- // start garbage collection routine
+ stopReset := rate.stopReset // store in case Init is called again.
+ // Start garbage collection routine.
go func() {
ticker := time.NewTicker(time.Second)
ticker.Stop()
for {
select {
- case _, ok := <-rate.stopReset:
+ case _, ok := <-stopReset:
ticker.Stop()
- if ok {
- ticker = time.NewTicker(time.Second)
- } else {
+ if !ok {
return
}
+ ticker = time.NewTicker(time.Second)
case <-ticker.C:
- func() {
- rate.Lock()
- defer rate.Unlock()
-
- for key, entry := range rate.tableIPv4 {
- entry.Lock()
- if time.Since(entry.lastTime) > garbageCollectTime {
- delete(rate.tableIPv4, key)
- }
- entry.Unlock()
- }
-
- for key, entry := range rate.tableIPv6 {
- entry.Lock()
- if time.Since(entry.lastTime) > garbageCollectTime {
- delete(rate.tableIPv6, key)
- }
- entry.Unlock()
- }
-
- if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 {
- ticker.Stop()
- }
- }()
+ if rate.cleanup() {
+ ticker.Stop()
+ }
}
}
}()
}
+func (rate *Ratelimiter) cleanup() (empty bool) {
+ rate.mu.Lock()
+ defer rate.mu.Unlock()
+
+ for key, entry := range rate.tableIPv4 {
+ entry.mu.Lock()
+ if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
+ delete(rate.tableIPv4, key)
+ }
+ entry.mu.Unlock()
+ }
+
+ for key, entry := range rate.tableIPv6 {
+ entry.mu.Lock()
+ if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
+ delete(rate.tableIPv6, key)
+ }
+ entry.mu.Unlock()
+ }
+
+ return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
+}
+
func (rate *Ratelimiter) Allow(ip net.IP) bool {
var entry *RatelimiterEntry
var keyIPv4 [net.IPv4len]byte
@@ -109,7 +116,7 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
IPv4 := ip.To4()
IPv6 := ip.To16()
- rate.RLock()
+ rate.mu.RLock()
if IPv4 != nil {
copy(keyIPv4[:], IPv4)
@@ -119,15 +126,15 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
entry = rate.tableIPv6[keyIPv6]
}
- rate.RUnlock()
+ rate.mu.RUnlock()
// make new entry if not found
if entry == nil {
entry = new(RatelimiterEntry)
entry.tokens = maxTokens - packetCost
- entry.lastTime = time.Now()
- rate.Lock()
+ entry.lastTime = rate.timeNow()
+ rate.mu.Lock()
if IPv4 != nil {
rate.tableIPv4[keyIPv4] = entry
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
@@ -139,14 +146,14 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
rate.stopReset <- struct{}{}
}
}
- rate.Unlock()
+ rate.mu.Unlock()
return true
}
// add tokens to entry
- entry.Lock()
- now := time.Now()
+ entry.mu.Lock()
+ now := rate.timeNow()
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
entry.lastTime = now
if entry.tokens > maxTokens {
@@ -157,9 +164,9 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
if entry.tokens > packetCost {
entry.tokens -= packetCost
- entry.Unlock()
+ entry.mu.Unlock()
return true
}
- entry.Unlock()
+ entry.mu.Unlock()
return false
}
diff --git a/ratelimiter/ratelimiter_test.go b/ratelimiter/ratelimiter_test.go
index 659bdfb..25d5d63 100644
--- a/ratelimiter/ratelimiter_test.go
+++ b/ratelimiter/ratelimiter_test.go
@@ -11,22 +11,21 @@ import (
"time"
)
-type RatelimiterResult struct {
+type result struct {
allowed bool
text string
wait time.Duration
}
func TestRatelimiter(t *testing.T) {
+ var rate Ratelimiter
+ var expectedResults []result
- var ratelimiter Ratelimiter
- var expectedResults []RatelimiterResult
-
- Nano := func(nano int64) time.Duration {
+ nano := func(nano int64) time.Duration {
return time.Nanosecond * time.Duration(nano)
}
- Add := func(res RatelimiterResult) {
+ add := func(res result) {
expectedResults = append(
expectedResults,
res,
@@ -34,40 +33,40 @@ func TestRatelimiter(t *testing.T) {
}
for i := 0; i < packetsBurstable; i++ {
- Add(RatelimiterResult{
+ add(result{
allowed: true,
text: "initial burst",
})
}
- Add(RatelimiterResult{
+ add(result{
allowed: false,
text: "after burst",
})
- Add(RatelimiterResult{
+ add(result{
allowed: true,
- wait: Nano(time.Second.Nanoseconds() / packetsPerSecond),
+ wait: nano(time.Second.Nanoseconds() / packetsPerSecond),
text: "filling tokens for single packet",
})
- Add(RatelimiterResult{
+ add(result{
allowed: false,
text: "not having refilled enough",
})
- Add(RatelimiterResult{
+ add(result{
allowed: true,
- wait: 2 * (Nano(time.Second.Nanoseconds() / packetsPerSecond)),
+ wait: 2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)),
text: "filling tokens for two packet burst",
})
- Add(RatelimiterResult{
+ add(result{
allowed: true,
text: "second packet in 2 packet burst",
})
- Add(RatelimiterResult{
+ add(result{
allowed: false,
text: "packet following 2 packet burst",
})
@@ -89,14 +88,31 @@ func TestRatelimiter(t *testing.T) {
net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
}
- ratelimiter.Init()
+ now := time.Now()
+ rate.timeNow = func() time.Time {
+ return now
+ }
+ defer func() {
+ // Lock to avoid data race with cleanup goroutine from Init.
+ rate.mu.Lock()
+ defer rate.mu.Unlock()
+
+ rate.timeNow = time.Now
+ }()
+ timeSleep := func(d time.Duration) {
+ now = now.Add(d + 1)
+ rate.cleanup()
+ }
+
+ rate.Init()
+ defer rate.Close()
for i, res := range expectedResults {
- time.Sleep(res.wait)
+ timeSleep(res.wait)
for _, ip := range ips {
- allowed := ratelimiter.Allow(ip)
+ allowed := rate.Allow(ip)
if allowed != res.allowed {
- t.Fatal("Test failed for", ip.String(), ", on:", i, "(", res.text, ")", "expected:", res.allowed, "got:", allowed)
+ t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed)
}
}
}