summaryrefslogtreecommitdiffhomepage
path: root/pkg/waiter
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/waiter')
-rw-r--r--pkg/waiter/BUILD36
-rw-r--r--pkg/waiter/fdnotifier/BUILD14
-rw-r--r--pkg/waiter/fdnotifier/fdnotifier.go200
-rw-r--r--pkg/waiter/fdnotifier/poll_unsafe.go74
-rw-r--r--pkg/waiter/waiter.go226
-rw-r--r--pkg/waiter/waiter_test.go182
6 files changed, 732 insertions, 0 deletions
diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD
new file mode 100644
index 000000000..7415dd325
--- /dev/null
+++ b/pkg/waiter/BUILD
@@ -0,0 +1,36 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+
+go_stateify(
+ name = "waiter_state",
+ srcs = [
+ "waiter.go",
+ ],
+ out = "waiter_state.go",
+ package = "waiter",
+)
+
+go_library(
+ name = "waiter",
+ srcs = [
+ "waiter.go",
+ "waiter_state.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/waiter",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/ilist",
+ "//pkg/state",
+ ],
+)
+
+go_test(
+ name = "waiter_test",
+ size = "small",
+ srcs = [
+ "waiter_test.go",
+ ],
+ embed = [":waiter"],
+)
diff --git a/pkg/waiter/fdnotifier/BUILD b/pkg/waiter/fdnotifier/BUILD
new file mode 100644
index 000000000..d5b5ee82d
--- /dev/null
+++ b/pkg/waiter/fdnotifier/BUILD
@@ -0,0 +1,14 @@
+package(licenses = ["notice"]) # Apache 2.0
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+ name = "fdnotifier",
+ srcs = [
+ "fdnotifier.go",
+ "poll_unsafe.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/waiter/fdnotifier",
+ visibility = ["//:sandbox"],
+ deps = ["//pkg/waiter"],
+)
diff --git a/pkg/waiter/fdnotifier/fdnotifier.go b/pkg/waiter/fdnotifier/fdnotifier.go
new file mode 100644
index 000000000..8bb93e39b
--- /dev/null
+++ b/pkg/waiter/fdnotifier/fdnotifier.go
@@ -0,0 +1,200 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fdnotifier contains an adapter that translates IO events (e.g., a
+// file became readable/writable) from native FDs to the notifications in the
+// waiter package. It uses epoll in edge-triggered mode to receive notifications
+// for registered FDs.
+package fdnotifier
+
+import (
+ "fmt"
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+type fdInfo struct {
+ queue *waiter.Queue
+ waiting bool
+}
+
+// notifier holds all the state necessary to issue notifications when IO events
+// occur in the observed FDs.
+type notifier struct {
+ // epFD is the epoll file descriptor used to register for io
+ // notifications.
+ epFD int
+
+ // mu protects fdMap.
+ mu sync.Mutex
+
+ // fdMap maps file descriptors to their notification queues and waiting
+ // status.
+ fdMap map[int32]*fdInfo
+}
+
+// newNotifier creates a new notifier object.
+func newNotifier() (*notifier, error) {
+ epfd, err := syscall.EpollCreate1(0)
+ if err != nil {
+ return nil, err
+ }
+
+ w := &notifier{
+ epFD: epfd,
+ fdMap: make(map[int32]*fdInfo),
+ }
+
+ go w.waitAndNotify() // S/R-SAFE: no waiter exists during save / load.
+
+ return w, nil
+}
+
+// waitFD waits on mask for fd. The fdMap mutex must be hold.
+func (n *notifier) waitFD(fd int32, fi *fdInfo, mask waiter.EventMask) error {
+ if !fi.waiting && mask == 0 {
+ return nil
+ }
+
+ e := syscall.EpollEvent{
+ Events: uint32(mask) | -syscall.EPOLLET,
+ Fd: fd,
+ }
+
+ switch {
+ case !fi.waiting && mask != 0:
+ if err := syscall.EpollCtl(n.epFD, syscall.EPOLL_CTL_ADD, int(fd), &e); err != nil {
+ return err
+ }
+ fi.waiting = true
+ case fi.waiting && mask == 0:
+ syscall.EpollCtl(n.epFD, syscall.EPOLL_CTL_DEL, int(fd), nil)
+ fi.waiting = false
+ case fi.waiting && mask != 0:
+ if err := syscall.EpollCtl(n.epFD, syscall.EPOLL_CTL_MOD, int(fd), &e); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// addFD adds an FD to the list of FDs observed by n.
+func (n *notifier) addFD(fd int32, queue *waiter.Queue) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ // Panic if we're already notifying on this FD.
+ if _, ok := n.fdMap[fd]; ok {
+ panic(fmt.Sprintf("File descriptor %v added twice", fd))
+ }
+
+ // We have nothing to wait for at the moment. Just add it to the map.
+ n.fdMap[fd] = &fdInfo{queue: queue}
+}
+
+// updateFD updates the set of events the fd needs to be notified on.
+func (n *notifier) updateFD(fd int32) error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if fi, ok := n.fdMap[fd]; ok {
+ return n.waitFD(fd, fi, fi.queue.Events())
+ }
+
+ return nil
+}
+
+// RemoveFD removes an FD from the list of FDs observed by n.
+func (n *notifier) removeFD(fd int32) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ // Remove from map, then from epoll object.
+ n.waitFD(fd, n.fdMap[fd], 0)
+ delete(n.fdMap, fd)
+}
+
+// hasFD returns true if the fd is in the list of observed FDs.
+func (n *notifier) hasFD(fd int32) bool {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ _, ok := n.fdMap[fd]
+ return ok
+}
+
+// waitAndNotify run is its own goroutine and loops waiting for io event
+// notifications from the epoll object. Once notifications arrive, they are
+// dispatched to the registered queue.
+func (n *notifier) waitAndNotify() error {
+ e := make([]syscall.EpollEvent, 100)
+ for {
+ v, err := epollWait(n.epFD, e, -1)
+ if err == syscall.EINTR {
+ continue
+ }
+
+ if err != nil {
+ return err
+ }
+
+ n.mu.Lock()
+ for i := 0; i < v; i++ {
+ if fi, ok := n.fdMap[e[i].Fd]; ok {
+ fi.queue.Notify(waiter.EventMask(e[i].Events))
+ }
+ }
+ n.mu.Unlock()
+ }
+}
+
+var shared struct {
+ notifier *notifier
+ once sync.Once
+ initErr error
+}
+
+// AddFD adds an FD to the list of observed FDs.
+func AddFD(fd int32, queue *waiter.Queue) error {
+ shared.once.Do(func() {
+ shared.notifier, shared.initErr = newNotifier()
+ })
+
+ if shared.initErr != nil {
+ return shared.initErr
+ }
+
+ shared.notifier.addFD(fd, queue)
+ return nil
+}
+
+// UpdateFD updates the set of events the fd needs to be notified on.
+func UpdateFD(fd int32) error {
+ return shared.notifier.updateFD(fd)
+}
+
+// RemoveFD removes an FD from the list of observed FDs.
+func RemoveFD(fd int32) {
+ shared.notifier.removeFD(fd)
+}
+
+// HasFD returns true if the FD is in the list of observed FDs.
+//
+// This should only be used by tests to assert that FDs are correctly registered.
+func HasFD(fd int32) bool {
+ return shared.notifier.hasFD(fd)
+}
diff --git a/pkg/waiter/fdnotifier/poll_unsafe.go b/pkg/waiter/fdnotifier/poll_unsafe.go
new file mode 100644
index 000000000..26bca2b53
--- /dev/null
+++ b/pkg/waiter/fdnotifier/poll_unsafe.go
@@ -0,0 +1,74 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fdnotifier
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// NonBlockingPoll polls the given FD in non-blocking fashion. It is used just
+// to query the FD's current state.
+func NonBlockingPoll(fd int32, mask waiter.EventMask) waiter.EventMask {
+ e := struct {
+ fd int32
+ events int16
+ revents int16
+ }{
+ fd: fd,
+ events: int16(mask),
+ }
+
+ for {
+ n, _, err := syscall.RawSyscall(syscall.SYS_POLL, uintptr(unsafe.Pointer(&e)), 1, 0)
+ // Interrupted by signal, try again.
+ if err == syscall.EINTR {
+ continue
+ }
+ // If an error occur we'll conservatively say the FD is ready for
+ // whatever is being checked.
+ if err != 0 {
+ return mask
+ }
+
+ // If no FDs were returned, it wasn't ready for anything.
+ if n == 0 {
+ return 0
+ }
+
+ // Otherwise we got the ready events in the revents field.
+ return waiter.EventMask(e.revents)
+ }
+}
+
+// epollWait performs a blocking wait on epfd.
+//
+// Preconditions:
+// * len(events) > 0
+func epollWait(epfd int, events []syscall.EpollEvent, msec int) (int, error) {
+ if len(events) == 0 {
+ panic("Empty events passed to EpollWait")
+ }
+
+ // We actually use epoll_pwait with NULL sigmask instead of epoll_wait
+ // since that is what the Go >= 1.11 runtime prefers.
+ r, _, e := syscall.Syscall6(syscall.SYS_EPOLL_PWAIT, uintptr(epfd), uintptr(unsafe.Pointer(&events[0])), uintptr(len(events)), uintptr(msec), 0, 0)
+ if e != 0 {
+ return 0, e
+ }
+ return int(r), nil
+}
diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go
new file mode 100644
index 000000000..56f53f9c3
--- /dev/null
+++ b/pkg/waiter/waiter.go
@@ -0,0 +1,226 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package waiter provides the implementation of a wait queue, where waiters can
+// be enqueued to be notified when an event of interest happens.
+//
+// Becoming readable and/or writable are examples of events. Waiters are
+// expected to use a pattern similar to this to make a blocking function out of
+// a non-blocking one:
+//
+// func (o *object) blockingRead(...) error {
+// err := o.nonBlockingRead(...)
+// if err != ErrAgain {
+// // Completed with no need to wait!
+// return err
+// }
+//
+// e := createOrGetWaiterEntry(...)
+// o.EventRegister(&e, waiter.EventIn)
+// defer o.EventUnregister(&e)
+//
+// // We need to try to read again after registration because the
+// // object may have become readable between the last attempt to
+// // read and read registration.
+// err = o.nonBlockingRead(...)
+// for err == ErrAgain {
+// wait()
+// err = o.nonBlockingRead(...)
+// }
+//
+// return err
+// }
+//
+// Another goroutine needs to notify waiters when events happen. For example:
+//
+// func (o *object) Write(...) ... {
+// // Do write work.
+// [...]
+//
+// if oldDataAvailableSize == 0 && dataAvailableSize > 0 {
+// // If no data was available and now some data is
+// // available, the object became readable, so notify
+// // potential waiters about this.
+// o.Notify(waiter.EventIn)
+// }
+// }
+package waiter
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/ilist"
+)
+
+// EventMask represents io events as used in the poll() syscall.
+type EventMask uint16
+
+// Events that waiters can wait on. The meaning is the same as those in the
+// poll() syscall.
+const (
+ EventIn EventMask = 0x01 // syscall.EPOLLIN
+ EventPri EventMask = 0x02 // syscall.EPOLLPRI
+ EventOut EventMask = 0x04 // syscall.EPOLLOUT
+ EventErr EventMask = 0x08 // syscall.EPOLLERR
+ EventHUp EventMask = 0x10 // syscall.EPOLLHUP
+ EventNVal EventMask = 0x20 // Not defined in syscall.
+)
+
+// Waitable contains the methods that need to be implemented by waitable
+// objects.
+type Waitable interface {
+ // Readiness returns what the object is currently ready for. If it's
+ // not ready for a desired purpose, the caller may use EventRegister and
+ // EventUnregister to get notifications once the object becomes ready.
+ //
+ // Implementations should allow for events like EventHUp and EventErr
+ // to be returned regardless of whether they are in the input EventMask.
+ Readiness(mask EventMask) EventMask
+
+ // EventRegister registers the given waiter entry to receive
+ // notifications when an event occurs that makes the object ready for
+ // at least one of the events in mask.
+ EventRegister(e *Entry, mask EventMask)
+
+ // EventUnregister unregisters a waiter entry previously registered with
+ // EventRegister().
+ EventUnregister(e *Entry)
+}
+
+// EntryCallback provides a notify callback.
+type EntryCallback interface {
+ // Callback is the function to be called when the waiter entry is
+ // notified. It is responsible for doing whatever is needed to wake up
+ // the waiter.
+ //
+ // The callback is supposed to perform minimal work, and cannot call
+ // any method on the queue itself because it will be locked while the
+ // callback is running.
+ Callback(e *Entry)
+}
+
+// Entry represents a waiter that can be add to the a wait queue. It can
+// only be in one queue at a time, and is added "intrusively" to the queue with
+// no extra memory allocations.
+type Entry struct {
+ // Context stores any state the waiter may wish to store in the entry
+ // itself, which may be used at wake up time.
+ //
+ // Note that use of this field is optional and state may alternatively be
+ // stored in the callback itself.
+ Context interface{}
+
+ Callback EntryCallback
+
+ // The following fields are protected by the queue lock.
+ mask EventMask
+ ilist.Entry
+}
+
+type channelCallback struct{}
+
+// Callback implements EntryCallback.Callback.
+func (*channelCallback) Callback(e *Entry) {
+ ch := e.Context.(chan struct{})
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+}
+
+// NewChannelEntry initializes a new Entry that does a non-blocking write to a
+// struct{} channel when the callback is called. It returns the new Entry
+// instance and the channel being used.
+//
+// If a channel isn't specified (i.e., if "c" is nil), then NewChannelEntry
+// allocates a new channel.
+func NewChannelEntry(c chan struct{}) (Entry, chan struct{}) {
+ if c == nil {
+ c = make(chan struct{}, 1)
+ }
+
+ return Entry{Context: c, Callback: &channelCallback{}}, c
+}
+
+// Queue represents the wait queue where waiters can be added and
+// notifiers can notify them when events happen.
+//
+// The zero value for waiter.Queue is an empty queue ready for use.
+type Queue struct {
+ list ilist.List `state:"zerovalue"`
+ mu sync.RWMutex `state:"nosave"`
+}
+
+// EventRegister adds a waiter to the wait queue; the waiter will be notified
+// when at least one of the events specified in mask happens.
+func (q *Queue) EventRegister(e *Entry, mask EventMask) {
+ q.mu.Lock()
+ e.mask = mask
+ q.list.PushBack(e)
+ q.mu.Unlock()
+}
+
+// EventUnregister removes the given waiter entry from the wait queue.
+func (q *Queue) EventUnregister(e *Entry) {
+ q.mu.Lock()
+ q.list.Remove(e)
+ q.mu.Unlock()
+}
+
+// Notify notifies all waiters in the queue whose masks have at least one bit
+// in common with the notification mask.
+func (q *Queue) Notify(mask EventMask) {
+ q.mu.RLock()
+ for it := q.list.Front(); it != nil; it = it.Next() {
+ e := it.(*Entry)
+ if (mask & e.mask) != 0 {
+ e.Callback.Callback(e)
+ }
+ }
+ q.mu.RUnlock()
+}
+
+// Events returns the set of events being waited on. It is the union of the
+// masks of all registered entries.
+func (q *Queue) Events() EventMask {
+ ret := EventMask(0)
+
+ q.mu.RLock()
+ for it := q.list.Front(); it != nil; it = it.Next() {
+ e := it.(*Entry)
+ ret |= e.mask
+ }
+ q.mu.RUnlock()
+
+ return ret
+}
+
+// IsEmpty returns if the wait queue is empty or not.
+func (q *Queue) IsEmpty() bool {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ return q.list.Front() == nil
+}
+
+// AlwaysReady implements the Waitable interface but is always ready. Embedding
+// this struct into another struct makes it implement the boilerplate empty
+// functions automatically.
+type AlwaysReady struct {
+}
+
+// Readiness always returns the input mask because this object is always ready.
+func (*AlwaysReady) Readiness(mask EventMask) EventMask {
+ return mask
+}
+
+// EventRegister doesn't do anything because this object doesn't need to issue
+// notifications because its readiness never changes.
+func (*AlwaysReady) EventRegister(*Entry, EventMask) {
+}
+
+// EventUnregister doesn't do anything because this object doesn't need to issue
+// notifications because its readiness never changes.
+func (*AlwaysReady) EventUnregister(e *Entry) {
+}
diff --git a/pkg/waiter/waiter_test.go b/pkg/waiter/waiter_test.go
new file mode 100644
index 000000000..1a203350b
--- /dev/null
+++ b/pkg/waiter/waiter_test.go
@@ -0,0 +1,182 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package waiter
+
+import (
+ "sync/atomic"
+ "testing"
+)
+
+type callbackStub struct {
+ f func(e *Entry)
+}
+
+// Callback implements EntryCallback.Callback.
+func (c *callbackStub) Callback(e *Entry) {
+ c.f(e)
+}
+
+func TestEmptyQueue(t *testing.T) {
+ var q Queue
+
+ // Notify the zero-value of a queue.
+ q.Notify(EventIn)
+
+ // Register then unregister a waiter, then notify the queue.
+ cnt := 0
+ e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}}
+ q.EventRegister(&e, EventIn)
+ q.EventUnregister(&e)
+ q.Notify(EventIn)
+ if cnt != 0 {
+ t.Errorf("Callback was called when it shouldn't have been")
+ }
+}
+
+func TestMask(t *testing.T) {
+ // Register a waiter.
+ var q Queue
+ var cnt int
+ e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}}
+ q.EventRegister(&e, EventIn|EventErr)
+
+ // Notify with an overlapping mask.
+ cnt = 0
+ q.Notify(EventIn | EventOut)
+ if cnt != 1 {
+ t.Errorf("Callback wasn't called when it should have been")
+ }
+
+ // Notify with a subset mask.
+ cnt = 0
+ q.Notify(EventIn)
+ if cnt != 1 {
+ t.Errorf("Callback wasn't called when it should have been")
+ }
+
+ // Notify with a superset mask.
+ cnt = 0
+ q.Notify(EventIn | EventErr | EventOut)
+ if cnt != 1 {
+ t.Errorf("Callback wasn't called when it should have been")
+ }
+
+ // Notify with the exact same mask.
+ cnt = 0
+ q.Notify(EventIn | EventErr)
+ if cnt != 1 {
+ t.Errorf("Callback wasn't called when it should have been")
+ }
+
+ // Notify with a disjoint mask.
+ cnt = 0
+ q.Notify(EventOut | EventHUp)
+ if cnt != 0 {
+ t.Errorf("Callback was called when it shouldn't have been")
+ }
+}
+
+func TestConcurrentRegistration(t *testing.T) {
+ var q Queue
+ var cnt int
+ const concurrency = 1000
+
+ ch1 := make(chan struct{})
+ ch2 := make(chan struct{})
+ ch3 := make(chan struct{})
+
+ // Create goroutines that will all register/unregister concurrently.
+ for i := 0; i < concurrency; i++ {
+ go func() {
+ var e Entry
+ e.Callback = &callbackStub{func(entry *Entry) {
+ cnt++
+ if entry != &e {
+ t.Errorf("entry = %p, want %p", entry, &e)
+ }
+ }}
+
+ // Wait for notification, then register.
+ <-ch1
+ q.EventRegister(&e, EventIn|EventErr)
+
+ // Tell main goroutine that we're done registering.
+ ch2 <- struct{}{}
+
+ // Wait for notification, then unregister.
+ <-ch3
+ q.EventUnregister(&e)
+
+ // Tell main goroutine that we're done unregistering.
+ ch2 <- struct{}{}
+ }()
+ }
+
+ // Let the goroutines register.
+ close(ch1)
+ for i := 0; i < concurrency; i++ {
+ <-ch2
+ }
+
+ // Issue a notification.
+ q.Notify(EventIn)
+ if cnt != concurrency {
+ t.Errorf("cnt = %d, want %d", cnt, concurrency)
+ }
+
+ // Let the goroutine unregister.
+ close(ch3)
+ for i := 0; i < concurrency; i++ {
+ <-ch2
+ }
+
+ // Issue a notification.
+ q.Notify(EventIn)
+ if cnt != concurrency {
+ t.Errorf("cnt = %d, want %d", cnt, concurrency)
+ }
+}
+
+func TestConcurrentNotification(t *testing.T) {
+ var q Queue
+ var cnt int32
+ const concurrency = 1000
+ const waiterCount = 1000
+
+ // Register waiters.
+ for i := 0; i < waiterCount; i++ {
+ var e Entry
+ e.Callback = &callbackStub{func(entry *Entry) {
+ atomic.AddInt32(&cnt, 1)
+ if entry != &e {
+ t.Errorf("entry = %p, want %p", entry, &e)
+ }
+ }}
+
+ q.EventRegister(&e, EventIn|EventErr)
+ }
+
+ // Launch notifiers.
+ ch1 := make(chan struct{})
+ ch2 := make(chan struct{})
+ for i := 0; i < concurrency; i++ {
+ go func() {
+ <-ch1
+ q.Notify(EventIn)
+ ch2 <- struct{}{}
+ }()
+ }
+
+ // Let notifiers go.
+ close(ch1)
+ for i := 0; i < concurrency; i++ {
+ <-ch2
+ }
+
+ // Check the count.
+ if cnt != concurrency*waiterCount {
+ t.Errorf("cnt = %d, want %d", cnt, concurrency*waiterCount)
+ }
+}