diff options
Diffstat (limited to 'pkg/tcpip/faketime')
-rw-r--r-- | pkg/tcpip/faketime/BUILD | 24 | ||||
-rw-r--r-- | pkg/tcpip/faketime/faketime.go | 236 | ||||
-rw-r--r-- | pkg/tcpip/faketime/faketime_test.go | 95 |
3 files changed, 355 insertions, 0 deletions
diff --git a/pkg/tcpip/faketime/BUILD b/pkg/tcpip/faketime/BUILD new file mode 100644 index 000000000..114d43df3 --- /dev/null +++ b/pkg/tcpip/faketime/BUILD @@ -0,0 +1,24 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "faketime", + srcs = ["faketime.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "@com_github_dpjacques_clockwork//:go_default_library", + ], +) + +go_test( + name = "faketime_test", + size = "small", + srcs = [ + "faketime_test.go", + ], + deps = [ + "//pkg/tcpip/faketime", + ], +) diff --git a/pkg/tcpip/faketime/faketime.go b/pkg/tcpip/faketime/faketime.go new file mode 100644 index 000000000..f7a4fbde1 --- /dev/null +++ b/pkg/tcpip/faketime/faketime.go @@ -0,0 +1,236 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package faketime provides a fake clock that implements tcpip.Clock interface. +package faketime + +import ( + "container/heap" + "sync" + "time" + + "github.com/dpjacques/clockwork" + "gvisor.dev/gvisor/pkg/tcpip" +) + +// NullClock implements a clock that never advances. +type NullClock struct{} + +var _ tcpip.Clock = (*NullClock)(nil) + +// NowNanoseconds implements tcpip.Clock.NowNanoseconds. +func (*NullClock) NowNanoseconds() int64 { + return 0 +} + +// NowMonotonic implements tcpip.Clock.NowMonotonic. +func (*NullClock) NowMonotonic() int64 { + return 0 +} + +// AfterFunc implements tcpip.Clock.AfterFunc. +func (*NullClock) AfterFunc(time.Duration, func()) tcpip.Timer { + return nil +} + +// ManualClock implements tcpip.Clock and only advances manually with Advance +// method. +type ManualClock struct { + clock clockwork.FakeClock + + // mu protects the fields below. + mu sync.RWMutex + + // times is min-heap of times. A heap is used for quick retrieval of the next + // upcoming time of scheduled work. + times *timeHeap + + // waitGroups stores one WaitGroup for all work scheduled to execute at the + // same time via AfterFunc. This allows parallel execution of all functions + // passed to AfterFunc scheduled for the same time. + waitGroups map[time.Time]*sync.WaitGroup +} + +// NewManualClock creates a new ManualClock instance. +func NewManualClock() *ManualClock { + return &ManualClock{ + clock: clockwork.NewFakeClock(), + times: &timeHeap{}, + waitGroups: make(map[time.Time]*sync.WaitGroup), + } +} + +var _ tcpip.Clock = (*ManualClock)(nil) + +// NowNanoseconds implements tcpip.Clock.NowNanoseconds. +func (mc *ManualClock) NowNanoseconds() int64 { + return mc.clock.Now().UnixNano() +} + +// NowMonotonic implements tcpip.Clock.NowMonotonic. +func (mc *ManualClock) NowMonotonic() int64 { + return mc.NowNanoseconds() +} + +// AfterFunc implements tcpip.Clock.AfterFunc. +func (mc *ManualClock) AfterFunc(d time.Duration, f func()) tcpip.Timer { + until := mc.clock.Now().Add(d) + wg := mc.addWait(until) + return &manualTimer{ + clock: mc, + until: until, + timer: mc.clock.AfterFunc(d, func() { + defer wg.Done() + f() + }), + } +} + +// addWait adds an additional wait to the WaitGroup for parallel execution of +// all work scheduled for t. Returns a reference to the WaitGroup modified. +func (mc *ManualClock) addWait(t time.Time) *sync.WaitGroup { + mc.mu.RLock() + wg, ok := mc.waitGroups[t] + mc.mu.RUnlock() + + if ok { + wg.Add(1) + return wg + } + + mc.mu.Lock() + heap.Push(mc.times, t) + mc.mu.Unlock() + + wg = &sync.WaitGroup{} + wg.Add(1) + + mc.mu.Lock() + mc.waitGroups[t] = wg + mc.mu.Unlock() + + return wg +} + +// removeWait removes a wait from the WaitGroup for parallel execution of all +// work scheduled for t. +func (mc *ManualClock) removeWait(t time.Time) { + mc.mu.RLock() + defer mc.mu.RUnlock() + + wg := mc.waitGroups[t] + wg.Done() +} + +// Advance executes all work that have been scheduled to execute within d from +// the current time. Blocks until all work has completed execution. +func (mc *ManualClock) Advance(d time.Duration) { + // Block until all the work is done + until := mc.clock.Now().Add(d) + for { + mc.mu.Lock() + if mc.times.Len() == 0 { + mc.mu.Unlock() + break + } + + t := heap.Pop(mc.times).(time.Time) + if t.After(until) { + // No work to do + heap.Push(mc.times, t) + mc.mu.Unlock() + break + } + mc.mu.Unlock() + + diff := t.Sub(mc.clock.Now()) + mc.clock.Advance(diff) + + mc.mu.RLock() + wg := mc.waitGroups[t] + mc.mu.RUnlock() + + wg.Wait() + + mc.mu.Lock() + delete(mc.waitGroups, t) + mc.mu.Unlock() + } + if now := mc.clock.Now(); until.After(now) { + mc.clock.Advance(until.Sub(now)) + } +} + +type manualTimer struct { + clock *ManualClock + timer clockwork.Timer + + mu sync.RWMutex + until time.Time +} + +var _ tcpip.Timer = (*manualTimer)(nil) + +// Reset implements tcpip.Timer.Reset. +func (t *manualTimer) Reset(d time.Duration) { + if !t.timer.Reset(d) { + return + } + + t.mu.Lock() + defer t.mu.Unlock() + + t.clock.removeWait(t.until) + t.until = t.clock.clock.Now().Add(d) + t.clock.addWait(t.until) +} + +// Stop implements tcpip.Timer.Stop. +func (t *manualTimer) Stop() bool { + if !t.timer.Stop() { + return false + } + + t.mu.RLock() + defer t.mu.RUnlock() + + t.clock.removeWait(t.until) + return true +} + +type timeHeap []time.Time + +var _ heap.Interface = (*timeHeap)(nil) + +func (h timeHeap) Len() int { + return len(h) +} + +func (h timeHeap) Less(i, j int) bool { + return h[i].Before(h[j]) +} + +func (h timeHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +func (h *timeHeap) Push(x interface{}) { + *h = append(*h, x.(time.Time)) +} + +func (h *timeHeap) Pop() interface{} { + last := (*h)[len(*h)-1] + *h = (*h)[:len(*h)-1] + return last +} diff --git a/pkg/tcpip/faketime/faketime_test.go b/pkg/tcpip/faketime/faketime_test.go new file mode 100644 index 000000000..c2704df2c --- /dev/null +++ b/pkg/tcpip/faketime/faketime_test.go @@ -0,0 +1,95 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package faketime_test + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip/faketime" +) + +func TestManualClockAdvance(t *testing.T) { + const timeout = time.Millisecond + clock := faketime.NewManualClock() + start := clock.NowMonotonic() + clock.Advance(timeout) + if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, timeout; got != want { + t.Errorf("got = %d, want = %d", got, want) + } +} + +func TestManualClockAfterFunc(t *testing.T) { + const ( + timeout1 = time.Millisecond // timeout for counter1 + timeout2 = 2 * time.Millisecond // timeout for counter2 + ) + tests := []struct { + name string + advance time.Duration + wantCounter1 int + wantCounter2 int + }{ + { + name: "before timeout1", + advance: timeout1 - 1, + wantCounter1: 0, + wantCounter2: 0, + }, + { + name: "timeout1", + advance: timeout1, + wantCounter1: 1, + wantCounter2: 0, + }, + { + name: "timeout2", + advance: timeout2, + wantCounter1: 1, + wantCounter2: 1, + }, + { + name: "after timeout2", + advance: timeout2 + 1, + wantCounter1: 1, + wantCounter2: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + counter1 := 0 + counter2 := 0 + clock.AfterFunc(timeout1, func() { + counter1++ + }) + clock.AfterFunc(timeout2, func() { + counter2++ + }) + start := clock.NowMonotonic() + clock.Advance(test.advance) + if got, want := counter1, test.wantCounter1; got != want { + t.Errorf("got counter1 = %d, want = %d", got, want) + } + if got, want := counter2, test.wantCounter2; got != want { + t.Errorf("got counter2 = %d, want = %d", got, want) + } + if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, test.advance; got != want { + t.Errorf("got elapsed = %d, want = %d", got, want) + } + }) + } +} |