diff options
Diffstat (limited to 'pkg/waiter')
-rw-r--r-- | pkg/waiter/BUILD | 35 | ||||
-rw-r--r-- | pkg/waiter/waiter.go | 250 | ||||
-rw-r--r-- | pkg/waiter/waiter_test.go | 192 |
3 files changed, 477 insertions, 0 deletions
diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD new file mode 100644 index 000000000..852480a09 --- /dev/null +++ b/pkg/waiter/BUILD @@ -0,0 +1,35 @@ +load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "waiter_list", + out = "waiter_list.go", + package = "waiter", + prefix = "waiter", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*Entry", + "Linker": "*Entry", + }, +) + +go_library( + name = "waiter", + srcs = [ + "waiter.go", + "waiter_list.go", + ], + visibility = ["//visibility:public"], + deps = ["//pkg/sync"], +) + +go_test( + name = "waiter_test", + size = "small", + srcs = [ + "waiter_test.go", + ], + library = ":waiter", +) diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go new file mode 100644 index 000000000..707eb085b --- /dev/null +++ b/pkg/waiter/waiter.go @@ -0,0 +1,250 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package waiter provides the implementation of a wait queue, where waiters can +// be enqueued to be notified when an event of interest happens. +// +// Becoming readable and/or writable are examples of events. Waiters are +// expected to use a pattern similar to this to make a blocking function out of +// a non-blocking one: +// +// func (o *object) blockingRead(...) error { +// err := o.nonBlockingRead(...) +// if err != ErrAgain { +// // Completed with no need to wait! +// return err +// } +// +// e := createOrGetWaiterEntry(...) +// o.EventRegister(&e, waiter.EventIn) +// defer o.EventUnregister(&e) +// +// // We need to try to read again after registration because the +// // object may have become readable between the last attempt to +// // read and read registration. +// err = o.nonBlockingRead(...) +// for err == ErrAgain { +// wait() +// err = o.nonBlockingRead(...) +// } +// +// return err +// } +// +// Another goroutine needs to notify waiters when events happen. For example: +// +// func (o *object) Write(...) ... { +// // Do write work. +// [...] +// +// if oldDataAvailableSize == 0 && dataAvailableSize > 0 { +// // If no data was available and now some data is +// // available, the object became readable, so notify +// // potential waiters about this. +// o.Notify(waiter.EventIn) +// } +// } +package waiter + +import ( + "gvisor.dev/gvisor/pkg/sync" +) + +// EventMask represents io events as used in the poll() syscall. +type EventMask uint64 + +// Events that waiters can wait on. The meaning is the same as those in the +// poll() syscall. +const ( + EventIn EventMask = 0x01 // POLLIN + EventPri EventMask = 0x02 // POLLPRI + EventOut EventMask = 0x04 // POLLOUT + EventErr EventMask = 0x08 // POLLERR + EventHUp EventMask = 0x10 // POLLHUP + + allEvents EventMask = 0x1f +) + +// EventMaskFromLinux returns an EventMask representing the supported events +// from the Linux events e, which is in the format used by poll(2). +func EventMaskFromLinux(e uint32) EventMask { + // Our flag definitions are currently identical to Linux. + return EventMask(e) & allEvents +} + +// ToLinux returns e in the format used by Linux poll(2). +func (e EventMask) ToLinux() uint32 { + // Our flag definitions are currently identical to Linux. + return uint32(e) +} + +// Waitable contains the methods that need to be implemented by waitable +// objects. +type Waitable interface { + // Readiness returns what the object is currently ready for. If it's + // not ready for a desired purpose, the caller may use EventRegister and + // EventUnregister to get notifications once the object becomes ready. + // + // Implementations should allow for events like EventHUp and EventErr + // to be returned regardless of whether they are in the input EventMask. + Readiness(mask EventMask) EventMask + + // EventRegister registers the given waiter entry to receive + // notifications when an event occurs that makes the object ready for + // at least one of the events in mask. + EventRegister(e *Entry, mask EventMask) + + // EventUnregister unregisters a waiter entry previously registered with + // EventRegister(). + EventUnregister(e *Entry) +} + +// EntryCallback provides a notify callback. +type EntryCallback interface { + // Callback is the function to be called when the waiter entry is + // notified. It is responsible for doing whatever is needed to wake up + // the waiter. + // + // The callback is supposed to perform minimal work, and cannot call + // any method on the queue itself because it will be locked while the + // callback is running. + Callback(e *Entry) +} + +// Entry represents a waiter that can be add to the a wait queue. It can +// only be in one queue at a time, and is added "intrusively" to the queue with +// no extra memory allocations. +// +// +stateify savable +type Entry struct { + // Context stores any state the waiter may wish to store in the entry + // itself, which may be used at wake up time. + // + // Note that use of this field is optional and state may alternatively be + // stored in the callback itself. + Context interface{} + + Callback EntryCallback + + // The following fields are protected by the queue lock. + mask EventMask + waiterEntry +} + +type channelCallback struct{} + +// Callback implements EntryCallback.Callback. +func (*channelCallback) Callback(e *Entry) { + ch := e.Context.(chan struct{}) + select { + case ch <- struct{}{}: + default: + } +} + +// NewChannelEntry initializes a new Entry that does a non-blocking write to a +// struct{} channel when the callback is called. It returns the new Entry +// instance and the channel being used. +// +// If a channel isn't specified (i.e., if "c" is nil), then NewChannelEntry +// allocates a new channel. +func NewChannelEntry(c chan struct{}) (Entry, chan struct{}) { + if c == nil { + c = make(chan struct{}, 1) + } + + return Entry{Context: c, Callback: &channelCallback{}}, c +} + +// Queue represents the wait queue where waiters can be added and +// notifiers can notify them when events happen. +// +// The zero value for waiter.Queue is an empty queue ready for use. +// +// +stateify savable +type Queue struct { + list waiterList `state:"zerovalue"` + mu sync.RWMutex `state:"nosave"` +} + +// EventRegister adds a waiter to the wait queue; the waiter will be notified +// when at least one of the events specified in mask happens. +func (q *Queue) EventRegister(e *Entry, mask EventMask) { + q.mu.Lock() + e.mask = mask + q.list.PushBack(e) + q.mu.Unlock() +} + +// EventUnregister removes the given waiter entry from the wait queue. +func (q *Queue) EventUnregister(e *Entry) { + q.mu.Lock() + q.list.Remove(e) + q.mu.Unlock() +} + +// Notify notifies all waiters in the queue whose masks have at least one bit +// in common with the notification mask. +func (q *Queue) Notify(mask EventMask) { + q.mu.RLock() + for e := q.list.Front(); e != nil; e = e.Next() { + if mask&e.mask != 0 { + e.Callback.Callback(e) + } + } + q.mu.RUnlock() +} + +// Events returns the set of events being waited on. It is the union of the +// masks of all registered entries. +func (q *Queue) Events() EventMask { + ret := EventMask(0) + + q.mu.RLock() + for e := q.list.Front(); e != nil; e = e.Next() { + ret |= e.mask + } + q.mu.RUnlock() + + return ret +} + +// IsEmpty returns if the wait queue is empty or not. +func (q *Queue) IsEmpty() bool { + q.mu.Lock() + defer q.mu.Unlock() + + return q.list.Front() == nil +} + +// AlwaysReady implements the Waitable interface but is always ready. Embedding +// this struct into another struct makes it implement the boilerplate empty +// functions automatically. +type AlwaysReady struct { +} + +// Readiness always returns the input mask because this object is always ready. +func (*AlwaysReady) Readiness(mask EventMask) EventMask { + return mask +} + +// EventRegister doesn't do anything because this object doesn't need to issue +// notifications because its readiness never changes. +func (*AlwaysReady) EventRegister(*Entry, EventMask) { +} + +// EventUnregister doesn't do anything because this object doesn't need to issue +// notifications because its readiness never changes. +func (*AlwaysReady) EventUnregister(e *Entry) { +} diff --git a/pkg/waiter/waiter_test.go b/pkg/waiter/waiter_test.go new file mode 100644 index 000000000..c1b94a4f3 --- /dev/null +++ b/pkg/waiter/waiter_test.go @@ -0,0 +1,192 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package waiter + +import ( + "sync/atomic" + "testing" +) + +type callbackStub struct { + f func(e *Entry) +} + +// Callback implements EntryCallback.Callback. +func (c *callbackStub) Callback(e *Entry) { + c.f(e) +} + +func TestEmptyQueue(t *testing.T) { + var q Queue + + // Notify the zero-value of a queue. + q.Notify(EventIn) + + // Register then unregister a waiter, then notify the queue. + cnt := 0 + e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}} + q.EventRegister(&e, EventIn) + q.EventUnregister(&e) + q.Notify(EventIn) + if cnt != 0 { + t.Errorf("Callback was called when it shouldn't have been") + } +} + +func TestMask(t *testing.T) { + // Register a waiter. + var q Queue + var cnt int + e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}} + q.EventRegister(&e, EventIn|EventErr) + + // Notify with an overlapping mask. + cnt = 0 + q.Notify(EventIn | EventOut) + if cnt != 1 { + t.Errorf("Callback wasn't called when it should have been") + } + + // Notify with a subset mask. + cnt = 0 + q.Notify(EventIn) + if cnt != 1 { + t.Errorf("Callback wasn't called when it should have been") + } + + // Notify with a superset mask. + cnt = 0 + q.Notify(EventIn | EventErr | EventOut) + if cnt != 1 { + t.Errorf("Callback wasn't called when it should have been") + } + + // Notify with the exact same mask. + cnt = 0 + q.Notify(EventIn | EventErr) + if cnt != 1 { + t.Errorf("Callback wasn't called when it should have been") + } + + // Notify with a disjoint mask. + cnt = 0 + q.Notify(EventOut | EventHUp) + if cnt != 0 { + t.Errorf("Callback was called when it shouldn't have been") + } +} + +func TestConcurrentRegistration(t *testing.T) { + var q Queue + var cnt int + const concurrency = 1000 + + ch1 := make(chan struct{}) + ch2 := make(chan struct{}) + ch3 := make(chan struct{}) + + // Create goroutines that will all register/unregister concurrently. + for i := 0; i < concurrency; i++ { + go func() { + var e Entry + e.Callback = &callbackStub{func(entry *Entry) { + cnt++ + if entry != &e { + t.Errorf("entry = %p, want %p", entry, &e) + } + }} + + // Wait for notification, then register. + <-ch1 + q.EventRegister(&e, EventIn|EventErr) + + // Tell main goroutine that we're done registering. + ch2 <- struct{}{} + + // Wait for notification, then unregister. + <-ch3 + q.EventUnregister(&e) + + // Tell main goroutine that we're done unregistering. + ch2 <- struct{}{} + }() + } + + // Let the goroutines register. + close(ch1) + for i := 0; i < concurrency; i++ { + <-ch2 + } + + // Issue a notification. + q.Notify(EventIn) + if cnt != concurrency { + t.Errorf("cnt = %d, want %d", cnt, concurrency) + } + + // Let the goroutine unregister. + close(ch3) + for i := 0; i < concurrency; i++ { + <-ch2 + } + + // Issue a notification. + q.Notify(EventIn) + if cnt != concurrency { + t.Errorf("cnt = %d, want %d", cnt, concurrency) + } +} + +func TestConcurrentNotification(t *testing.T) { + var q Queue + var cnt int32 + const concurrency = 1000 + const waiterCount = 1000 + + // Register waiters. + for i := 0; i < waiterCount; i++ { + var e Entry + e.Callback = &callbackStub{func(entry *Entry) { + atomic.AddInt32(&cnt, 1) + if entry != &e { + t.Errorf("entry = %p, want %p", entry, &e) + } + }} + + q.EventRegister(&e, EventIn|EventErr) + } + + // Launch notifiers. + ch1 := make(chan struct{}) + ch2 := make(chan struct{}) + for i := 0; i < concurrency; i++ { + go func() { + <-ch1 + q.Notify(EventIn) + ch2 <- struct{}{} + }() + } + + // Let notifiers go. + close(ch1) + for i := 0; i < concurrency; i++ { + <-ch2 + } + + // Check the count. + if cnt != concurrency*waiterCount { + t.Errorf("cnt = %d, want %d", cnt, concurrency*waiterCount) + } +} |