summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/faketime
diff options
context:
space:
mode:
authorToshi Kikuchi <toshik@google.com>2020-09-22 12:43:28 -0700
committergVisor bot <gvisor-bot@google.com>2020-09-22 12:45:23 -0700
commit6e5ea605f4ca9fb8c29adc9510edc980f844ddfc (patch)
tree3b607872fb00ccf17bbbdaae7204a3ef8f94970c /pkg/tcpip/faketime
parent13a9a622e13ccdda76ed02d3de99b565212f6b2f (diff)
Move stack.fakeClock into a separate package
PiperOrigin-RevId: 333138701
Diffstat (limited to 'pkg/tcpip/faketime')
-rw-r--r--pkg/tcpip/faketime/BUILD24
-rw-r--r--pkg/tcpip/faketime/faketime.go216
-rw-r--r--pkg/tcpip/faketime/faketime_test.go95
3 files changed, 335 insertions, 0 deletions
diff --git a/pkg/tcpip/faketime/BUILD b/pkg/tcpip/faketime/BUILD
new file mode 100644
index 000000000..114d43df3
--- /dev/null
+++ b/pkg/tcpip/faketime/BUILD
@@ -0,0 +1,24 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "faketime",
+ srcs = ["faketime.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "@com_github_dpjacques_clockwork//:go_default_library",
+ ],
+)
+
+go_test(
+ name = "faketime_test",
+ size = "small",
+ srcs = [
+ "faketime_test.go",
+ ],
+ deps = [
+ "//pkg/tcpip/faketime",
+ ],
+)
diff --git a/pkg/tcpip/faketime/faketime.go b/pkg/tcpip/faketime/faketime.go
new file mode 100644
index 000000000..1193f1d7d
--- /dev/null
+++ b/pkg/tcpip/faketime/faketime.go
@@ -0,0 +1,216 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package faketime provides a fake clock that implements tcpip.Clock interface.
+package faketime
+
+import (
+ "container/heap"
+ "sync"
+ "time"
+
+ "github.com/dpjacques/clockwork"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// ManualClock implements tcpip.Clock and only advances manually with Advance
+// method.
+type ManualClock struct {
+ clock clockwork.FakeClock
+
+ // mu protects the fields below.
+ mu sync.RWMutex
+
+ // times is min-heap of times. A heap is used for quick retrieval of the next
+ // upcoming time of scheduled work.
+ times *timeHeap
+
+ // waitGroups stores one WaitGroup for all work scheduled to execute at the
+ // same time via AfterFunc. This allows parallel execution of all functions
+ // passed to AfterFunc scheduled for the same time.
+ waitGroups map[time.Time]*sync.WaitGroup
+}
+
+// NewManualClock creates a new ManualClock instance.
+func NewManualClock() *ManualClock {
+ return &ManualClock{
+ clock: clockwork.NewFakeClock(),
+ times: &timeHeap{},
+ waitGroups: make(map[time.Time]*sync.WaitGroup),
+ }
+}
+
+var _ tcpip.Clock = (*ManualClock)(nil)
+
+// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
+func (mc *ManualClock) NowNanoseconds() int64 {
+ return mc.clock.Now().UnixNano()
+}
+
+// NowMonotonic implements tcpip.Clock.NowMonotonic.
+func (mc *ManualClock) NowMonotonic() int64 {
+ return mc.NowNanoseconds()
+}
+
+// AfterFunc implements tcpip.Clock.AfterFunc.
+func (mc *ManualClock) AfterFunc(d time.Duration, f func()) tcpip.Timer {
+ until := mc.clock.Now().Add(d)
+ wg := mc.addWait(until)
+ return &manualTimer{
+ clock: mc,
+ until: until,
+ timer: mc.clock.AfterFunc(d, func() {
+ defer wg.Done()
+ f()
+ }),
+ }
+}
+
+// addWait adds an additional wait to the WaitGroup for parallel execution of
+// all work scheduled for t. Returns a reference to the WaitGroup modified.
+func (mc *ManualClock) addWait(t time.Time) *sync.WaitGroup {
+ mc.mu.RLock()
+ wg, ok := mc.waitGroups[t]
+ mc.mu.RUnlock()
+
+ if ok {
+ wg.Add(1)
+ return wg
+ }
+
+ mc.mu.Lock()
+ heap.Push(mc.times, t)
+ mc.mu.Unlock()
+
+ wg = &sync.WaitGroup{}
+ wg.Add(1)
+
+ mc.mu.Lock()
+ mc.waitGroups[t] = wg
+ mc.mu.Unlock()
+
+ return wg
+}
+
+// removeWait removes a wait from the WaitGroup for parallel execution of all
+// work scheduled for t.
+func (mc *ManualClock) removeWait(t time.Time) {
+ mc.mu.RLock()
+ defer mc.mu.RUnlock()
+
+ wg := mc.waitGroups[t]
+ wg.Done()
+}
+
+// Advance executes all work that have been scheduled to execute within d from
+// the current time. Blocks until all work has completed execution.
+func (mc *ManualClock) Advance(d time.Duration) {
+ // Block until all the work is done
+ until := mc.clock.Now().Add(d)
+ for {
+ mc.mu.Lock()
+ if mc.times.Len() == 0 {
+ mc.mu.Unlock()
+ break
+ }
+
+ t := heap.Pop(mc.times).(time.Time)
+ if t.After(until) {
+ // No work to do
+ heap.Push(mc.times, t)
+ mc.mu.Unlock()
+ break
+ }
+ mc.mu.Unlock()
+
+ diff := t.Sub(mc.clock.Now())
+ mc.clock.Advance(diff)
+
+ mc.mu.RLock()
+ wg := mc.waitGroups[t]
+ mc.mu.RUnlock()
+
+ wg.Wait()
+
+ mc.mu.Lock()
+ delete(mc.waitGroups, t)
+ mc.mu.Unlock()
+ }
+ if now := mc.clock.Now(); until.After(now) {
+ mc.clock.Advance(until.Sub(now))
+ }
+}
+
+type manualTimer struct {
+ clock *ManualClock
+ timer clockwork.Timer
+
+ mu sync.RWMutex
+ until time.Time
+}
+
+var _ tcpip.Timer = (*manualTimer)(nil)
+
+// Reset implements tcpip.Timer.Reset.
+func (t *manualTimer) Reset(d time.Duration) {
+ if !t.timer.Reset(d) {
+ return
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ t.clock.removeWait(t.until)
+ t.until = t.clock.clock.Now().Add(d)
+ t.clock.addWait(t.until)
+}
+
+// Stop implements tcpip.Timer.Stop.
+func (t *manualTimer) Stop() bool {
+ if !t.timer.Stop() {
+ return false
+ }
+
+ t.mu.RLock()
+ defer t.mu.RUnlock()
+
+ t.clock.removeWait(t.until)
+ return true
+}
+
+type timeHeap []time.Time
+
+var _ heap.Interface = (*timeHeap)(nil)
+
+func (h timeHeap) Len() int {
+ return len(h)
+}
+
+func (h timeHeap) Less(i, j int) bool {
+ return h[i].Before(h[j])
+}
+
+func (h timeHeap) Swap(i, j int) {
+ h[i], h[j] = h[j], h[i]
+}
+
+func (h *timeHeap) Push(x interface{}) {
+ *h = append(*h, x.(time.Time))
+}
+
+func (h *timeHeap) Pop() interface{} {
+ last := (*h)[len(*h)-1]
+ *h = (*h)[:len(*h)-1]
+ return last
+}
diff --git a/pkg/tcpip/faketime/faketime_test.go b/pkg/tcpip/faketime/faketime_test.go
new file mode 100644
index 000000000..c2704df2c
--- /dev/null
+++ b/pkg/tcpip/faketime/faketime_test.go
@@ -0,0 +1,95 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package faketime_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+)
+
+func TestManualClockAdvance(t *testing.T) {
+ const timeout = time.Millisecond
+ clock := faketime.NewManualClock()
+ start := clock.NowMonotonic()
+ clock.Advance(timeout)
+ if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, timeout; got != want {
+ t.Errorf("got = %d, want = %d", got, want)
+ }
+}
+
+func TestManualClockAfterFunc(t *testing.T) {
+ const (
+ timeout1 = time.Millisecond // timeout for counter1
+ timeout2 = 2 * time.Millisecond // timeout for counter2
+ )
+ tests := []struct {
+ name string
+ advance time.Duration
+ wantCounter1 int
+ wantCounter2 int
+ }{
+ {
+ name: "before timeout1",
+ advance: timeout1 - 1,
+ wantCounter1: 0,
+ wantCounter2: 0,
+ },
+ {
+ name: "timeout1",
+ advance: timeout1,
+ wantCounter1: 1,
+ wantCounter2: 0,
+ },
+ {
+ name: "timeout2",
+ advance: timeout2,
+ wantCounter1: 1,
+ wantCounter2: 1,
+ },
+ {
+ name: "after timeout2",
+ advance: timeout2 + 1,
+ wantCounter1: 1,
+ wantCounter2: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ counter1 := 0
+ counter2 := 0
+ clock.AfterFunc(timeout1, func() {
+ counter1++
+ })
+ clock.AfterFunc(timeout2, func() {
+ counter2++
+ })
+ start := clock.NowMonotonic()
+ clock.Advance(test.advance)
+ if got, want := counter1, test.wantCounter1; got != want {
+ t.Errorf("got counter1 = %d, want = %d", got, want)
+ }
+ if got, want := counter2, test.wantCounter2; got != want {
+ t.Errorf("got counter2 = %d, want = %d", got, want)
+ }
+ if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, test.advance; got != want {
+ t.Errorf("got elapsed = %d, want = %d", got, want)
+ }
+ })
+ }
+}