diff options
Diffstat (limited to 'pkg/waiter')
-rw-r--r-- | pkg/waiter/BUILD | 35 | ||||
-rw-r--r-- | pkg/waiter/waiter_list.go | 193 | ||||
-rw-r--r-- | pkg/waiter/waiter_state_autogen.go | 69 | ||||
-rw-r--r-- | pkg/waiter/waiter_test.go | 192 |
4 files changed, 262 insertions, 227 deletions
diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD deleted file mode 100644 index 852480a09..000000000 --- a/pkg/waiter/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "waiter_list", - out = "waiter_list.go", - package = "waiter", - prefix = "waiter", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*Entry", - "Linker": "*Entry", - }, -) - -go_library( - name = "waiter", - srcs = [ - "waiter.go", - "waiter_list.go", - ], - visibility = ["//visibility:public"], - deps = ["//pkg/sync"], -) - -go_test( - name = "waiter_test", - size = "small", - srcs = [ - "waiter_test.go", - ], - library = ":waiter", -) diff --git a/pkg/waiter/waiter_list.go b/pkg/waiter/waiter_list.go new file mode 100644 index 000000000..35431f5a4 --- /dev/null +++ b/pkg/waiter/waiter_list.go @@ -0,0 +1,193 @@ +package waiter + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type waiterElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (waiterElementMapper) linkerFor(elem *Entry) *Entry { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type waiterList struct { + head *Entry + tail *Entry +} + +// Reset resets list l to the empty state. +func (l *waiterList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *waiterList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *waiterList) Front() *Entry { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *waiterList) Back() *Entry { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *waiterList) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +func (l *waiterList) PushFront(e *Entry) { + linker := waiterElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + waiterElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *waiterList) PushBack(e *Entry) { + linker := waiterElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + waiterElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *waiterList) PushBackList(m *waiterList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + waiterElementMapper{}.linkerFor(l.tail).SetNext(m.head) + waiterElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *waiterList) InsertAfter(b, e *Entry) { + bLinker := waiterElementMapper{}.linkerFor(b) + eLinker := waiterElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + waiterElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *waiterList) InsertBefore(a, e *Entry) { + aLinker := waiterElementMapper{}.linkerFor(a) + eLinker := waiterElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + waiterElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *waiterList) Remove(e *Entry) { + linker := waiterElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + waiterElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + waiterElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type waiterEntry struct { + next *Entry + prev *Entry +} + +// Next returns the entry that follows e in the list. +func (e *waiterEntry) Next() *Entry { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *waiterEntry) Prev() *Entry { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *waiterEntry) SetNext(elem *Entry) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *waiterEntry) SetPrev(elem *Entry) { + e.prev = elem +} diff --git a/pkg/waiter/waiter_state_autogen.go b/pkg/waiter/waiter_state_autogen.go new file mode 100644 index 000000000..cf7f5fc2c --- /dev/null +++ b/pkg/waiter/waiter_state_autogen.go @@ -0,0 +1,69 @@ +// automatically generated by stateify. + +package waiter + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (x *Entry) beforeSave() {} +func (x *Entry) save(m state.Map) { + x.beforeSave() + m.Save("Context", &x.Context) + m.Save("Callback", &x.Callback) + m.Save("mask", &x.mask) + m.Save("waiterEntry", &x.waiterEntry) +} + +func (x *Entry) afterLoad() {} +func (x *Entry) load(m state.Map) { + m.Load("Context", &x.Context) + m.Load("Callback", &x.Callback) + m.Load("mask", &x.mask) + m.Load("waiterEntry", &x.waiterEntry) +} + +func (x *Queue) beforeSave() {} +func (x *Queue) save(m state.Map) { + x.beforeSave() + if !state.IsZeroValue(&x.list) { + m.Failf("list is %#v, expected zero", &x.list) + } +} + +func (x *Queue) afterLoad() {} +func (x *Queue) load(m state.Map) { +} + +func (x *waiterList) beforeSave() {} +func (x *waiterList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *waiterList) afterLoad() {} +func (x *waiterList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *waiterEntry) beforeSave() {} +func (x *waiterEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *waiterEntry) afterLoad() {} +func (x *waiterEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("pkg/waiter.Entry", (*Entry)(nil), state.Fns{Save: (*Entry).save, Load: (*Entry).load}) + state.Register("pkg/waiter.Queue", (*Queue)(nil), state.Fns{Save: (*Queue).save, Load: (*Queue).load}) + state.Register("pkg/waiter.waiterList", (*waiterList)(nil), state.Fns{Save: (*waiterList).save, Load: (*waiterList).load}) + state.Register("pkg/waiter.waiterEntry", (*waiterEntry)(nil), state.Fns{Save: (*waiterEntry).save, Load: (*waiterEntry).load}) +} diff --git a/pkg/waiter/waiter_test.go b/pkg/waiter/waiter_test.go deleted file mode 100644 index c1b94a4f3..000000000 --- a/pkg/waiter/waiter_test.go +++ /dev/null @@ -1,192 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package waiter - -import ( - "sync/atomic" - "testing" -) - -type callbackStub struct { - f func(e *Entry) -} - -// Callback implements EntryCallback.Callback. -func (c *callbackStub) Callback(e *Entry) { - c.f(e) -} - -func TestEmptyQueue(t *testing.T) { - var q Queue - - // Notify the zero-value of a queue. - q.Notify(EventIn) - - // Register then unregister a waiter, then notify the queue. - cnt := 0 - e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}} - q.EventRegister(&e, EventIn) - q.EventUnregister(&e) - q.Notify(EventIn) - if cnt != 0 { - t.Errorf("Callback was called when it shouldn't have been") - } -} - -func TestMask(t *testing.T) { - // Register a waiter. - var q Queue - var cnt int - e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}} - q.EventRegister(&e, EventIn|EventErr) - - // Notify with an overlapping mask. - cnt = 0 - q.Notify(EventIn | EventOut) - if cnt != 1 { - t.Errorf("Callback wasn't called when it should have been") - } - - // Notify with a subset mask. - cnt = 0 - q.Notify(EventIn) - if cnt != 1 { - t.Errorf("Callback wasn't called when it should have been") - } - - // Notify with a superset mask. - cnt = 0 - q.Notify(EventIn | EventErr | EventOut) - if cnt != 1 { - t.Errorf("Callback wasn't called when it should have been") - } - - // Notify with the exact same mask. - cnt = 0 - q.Notify(EventIn | EventErr) - if cnt != 1 { - t.Errorf("Callback wasn't called when it should have been") - } - - // Notify with a disjoint mask. - cnt = 0 - q.Notify(EventOut | EventHUp) - if cnt != 0 { - t.Errorf("Callback was called when it shouldn't have been") - } -} - -func TestConcurrentRegistration(t *testing.T) { - var q Queue - var cnt int - const concurrency = 1000 - - ch1 := make(chan struct{}) - ch2 := make(chan struct{}) - ch3 := make(chan struct{}) - - // Create goroutines that will all register/unregister concurrently. - for i := 0; i < concurrency; i++ { - go func() { - var e Entry - e.Callback = &callbackStub{func(entry *Entry) { - cnt++ - if entry != &e { - t.Errorf("entry = %p, want %p", entry, &e) - } - }} - - // Wait for notification, then register. - <-ch1 - q.EventRegister(&e, EventIn|EventErr) - - // Tell main goroutine that we're done registering. - ch2 <- struct{}{} - - // Wait for notification, then unregister. - <-ch3 - q.EventUnregister(&e) - - // Tell main goroutine that we're done unregistering. - ch2 <- struct{}{} - }() - } - - // Let the goroutines register. - close(ch1) - for i := 0; i < concurrency; i++ { - <-ch2 - } - - // Issue a notification. - q.Notify(EventIn) - if cnt != concurrency { - t.Errorf("cnt = %d, want %d", cnt, concurrency) - } - - // Let the goroutine unregister. - close(ch3) - for i := 0; i < concurrency; i++ { - <-ch2 - } - - // Issue a notification. - q.Notify(EventIn) - if cnt != concurrency { - t.Errorf("cnt = %d, want %d", cnt, concurrency) - } -} - -func TestConcurrentNotification(t *testing.T) { - var q Queue - var cnt int32 - const concurrency = 1000 - const waiterCount = 1000 - - // Register waiters. - for i := 0; i < waiterCount; i++ { - var e Entry - e.Callback = &callbackStub{func(entry *Entry) { - atomic.AddInt32(&cnt, 1) - if entry != &e { - t.Errorf("entry = %p, want %p", entry, &e) - } - }} - - q.EventRegister(&e, EventIn|EventErr) - } - - // Launch notifiers. - ch1 := make(chan struct{}) - ch2 := make(chan struct{}) - for i := 0; i < concurrency; i++ { - go func() { - <-ch1 - q.Notify(EventIn) - ch2 <- struct{}{} - }() - } - - // Let notifiers go. - close(ch1) - for i := 0; i < concurrency; i++ { - <-ch2 - } - - // Check the count. - if cnt != concurrency*waiterCount { - t.Errorf("cnt = %d, want %d", cnt, concurrency*waiterCount) - } -} |