summaryrefslogtreecommitdiffhomepage
path: root/pkg/gate
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/gate')
-rw-r--r--pkg/gate/BUILD22
-rw-r--r--pkg/gate/gate.go120
-rw-r--r--pkg/gate/gate_test.go175
3 files changed, 317 insertions, 0 deletions
diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD
new file mode 100644
index 000000000..381474d9e
--- /dev/null
+++ b/pkg/gate/BUILD
@@ -0,0 +1,22 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "gate",
+ srcs = [
+ "gate.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/gate",
+ visibility = ["//visibility:public"],
+)
+
+go_test(
+ name = "gate_test",
+ srcs = [
+ "gate_test.go",
+ ],
+ deps = [
+ ":gate",
+ ],
+)
diff --git a/pkg/gate/gate.go b/pkg/gate/gate.go
new file mode 100644
index 000000000..47651c761
--- /dev/null
+++ b/pkg/gate/gate.go
@@ -0,0 +1,120 @@
+// Package gate provides a usage Gate synchronization primitive.
+package gate
+
+import (
+ "sync/atomic"
+)
+
+const (
+ // gateClosed is the bit set in the gate's user count to indicate that
+ // it has been closed. It is the MSB of the 32-bit field; the other 31
+ // bits carry the actual count.
+ gateClosed = 0x80000000
+)
+
+// Gate is a synchronization primitive that allows concurrent goroutines to
+// "enter" it as long as it hasn't been closed yet. Once it's been closed,
+// goroutines cannot enter it anymore, but are allowed to leave, and the closer
+// will be informed when all goroutines have left.
+//
+// Many goroutines are allowed to enter the gate concurrently, but only one is
+// allowed to close it.
+//
+// This is similar to a r/w critical section, except that goroutines "entering"
+// never block: they either enter immediately or fail to enter. The closer will
+// block waiting for all goroutines currently inside the gate to leave.
+//
+// This function is implemented efficiently. On x86, only one interlocked
+// operation is performed on enter, and one on leave.
+//
+// This is useful, for example, in cases when a goroutine is trying to clean up
+// an object for which multiple goroutines have pointers. In such a case, users
+// would be required to enter and leave the gates, and the cleaner would wait
+// until all users are gone (and no new ones are allowed) before proceeding.
+//
+// Users:
+//
+// if !g.Enter() {
+// // Gate is closed, we can't use the object.
+// return
+// }
+//
+// // Do something with object.
+// [...]
+//
+// g.Leave()
+//
+// Closer:
+//
+// // Prevent new users from using the object, and wait for the existing
+// // ones to complete.
+// g.Close()
+//
+// // Clean up the object.
+// [...]
+//
+type Gate struct {
+ userCount uint32
+ done chan struct{}
+}
+
+// Enter tries to enter the gate. It will succeed if it hasn't been closed yet,
+// in which case the caller must eventually call Leave().
+//
+// This function is thread-safe.
+func (g *Gate) Enter() bool {
+ if g == nil {
+ return false
+ }
+
+ for {
+ v := atomic.LoadUint32(&g.userCount)
+ if v&gateClosed != 0 {
+ return false
+ }
+
+ if atomic.CompareAndSwapUint32(&g.userCount, v, v+1) {
+ return true
+ }
+ }
+}
+
+// Leave leaves the gate. This must only be called after a successful call to
+// Enter(). If the gate has been closed and this is the last one inside the
+// gate, it will notify the closer that the gate is done.
+//
+// This function is thread-safe.
+func (g *Gate) Leave() {
+ for {
+ v := atomic.LoadUint32(&g.userCount)
+ if v&^gateClosed == 0 {
+ panic("leaving a gate with zero usage count")
+ }
+
+ if atomic.CompareAndSwapUint32(&g.userCount, v, v-1) {
+ if v == gateClosed+1 {
+ close(g.done)
+ }
+ return
+ }
+ }
+}
+
+// Close closes the gate for entering, and waits until all goroutines [that are
+// currently inside the gate] leave before returning.
+//
+// Only one goroutine can call this function.
+func (g *Gate) Close() {
+ for {
+ v := atomic.LoadUint32(&g.userCount)
+ if v&^gateClosed != 0 && g.done == nil {
+ g.done = make(chan struct{})
+ }
+ if atomic.CompareAndSwapUint32(&g.userCount, v, v|gateClosed) {
+ if v&^gateClosed != 0 {
+ <-g.done
+ }
+ return
+ }
+ }
+}
diff --git a/pkg/gate/gate_test.go b/pkg/gate/gate_test.go
new file mode 100644
index 000000000..b3b101a0c
--- /dev/null
+++ b/pkg/gate/gate_test.go
@@ -0,0 +1,175 @@
+package gate_test
+
+import (
+ "sync"
+ "testing"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/gate"
+)
+
+func TestBasicEnter(t *testing.T) {
+ var g gate.Gate
+
+ if !g.Enter() {
+ t.Fatalf("Failed to enter when it should be allowed")
+ }
+
+ g.Leave()
+
+ g.Close()
+
+ if g.Enter() {
+ t.Fatalf("Allowed to enter when it should fail")
+ }
+}
+
+func enterFunc(t *testing.T, g *gate.Gate, enter, leave, reenter chan struct{}, done1, done2, done3 *sync.WaitGroup) {
+ // Wait until instructed to enter.
+ <-enter
+ if !g.Enter() {
+ t.Errorf("Failed to enter when it should be allowed")
+ }
+
+ done1.Done()
+
+ // Wait until instructed to leave.
+ <-leave
+ g.Leave()
+
+ done2.Done()
+
+ // Wait until instructed to reenter.
+ <-reenter
+ if g.Enter() {
+ t.Errorf("Allowed to enter when it should fail")
+ }
+ done3.Done()
+}
+
+func TestConcurrentEnter(t *testing.T) {
+ var g gate.Gate
+ var done1, done2, done3 sync.WaitGroup
+
+ // Create 1000 worker goroutines.
+ enter := make(chan struct{})
+ leave := make(chan struct{})
+ reenter := make(chan struct{})
+ done1.Add(1000)
+ done2.Add(1000)
+ done3.Add(1000)
+ for i := 0; i < 1000; i++ {
+ go enterFunc(t, &g, enter, leave, reenter, &done1, &done2, &done3)
+ }
+
+ // Tell them all to enter, then leave.
+ close(enter)
+ done1.Wait()
+
+ close(leave)
+ done2.Wait()
+
+ // Close the gate, then have the workers try to enter again.
+ g.Close()
+ close(reenter)
+ done3.Wait()
+}
+
+func closeFunc(g *gate.Gate, done chan struct{}) {
+ g.Close()
+ close(done)
+}
+
+func TestCloseWaits(t *testing.T) {
+ var g gate.Gate
+
+ // Enter 10 times.
+ for i := 0; i < 10; i++ {
+ if !g.Enter() {
+ t.Fatalf("Failed to enter when it should be allowed")
+ }
+ }
+
+ // Launch closer. Check that it doesn't complete.
+ done := make(chan struct{})
+ go closeFunc(&g, done)
+
+ for i := 0; i < 10; i++ {
+ select {
+ case <-done:
+ t.Fatalf("Close function completed too soon")
+ case <-time.After(100 * time.Millisecond):
+ }
+
+ g.Leave()
+ }
+
+ // Now the closer must complete.
+ <-done
+}
+
+func TestMultipleSerialCloses(t *testing.T) {
+ var g gate.Gate
+
+ // Enter 10 times.
+ for i := 0; i < 10; i++ {
+ if !g.Enter() {
+ t.Fatalf("Failed to enter when it should be allowed")
+ }
+ }
+
+ // Launch closer. Check that it doesn't complete.
+ done := make(chan struct{})
+ go closeFunc(&g, done)
+
+ for i := 0; i < 10; i++ {
+ select {
+ case <-done:
+ t.Fatalf("Close function completed too soon")
+ case <-time.After(100 * time.Millisecond):
+ }
+
+ g.Leave()
+ }
+
+ // Now the closer must complete.
+ <-done
+
+ // Close again should not block.
+ done = make(chan struct{})
+ go closeFunc(&g, done)
+
+ select {
+ case <-done:
+ case <-time.After(2 * time.Second):
+ t.Fatalf("Second Close is blocking")
+ }
+}
+
+func worker(g *gate.Gate, done *sync.WaitGroup) {
+ for {
+ if !g.Enter() {
+ break
+ }
+ g.Leave()
+ }
+ done.Done()
+}
+
+func TestConcurrentAll(t *testing.T) {
+ var g gate.Gate
+ var done sync.WaitGroup
+
+ // Launch 1000 goroutines to concurrently enter/leave.
+ done.Add(1000)
+ for i := 0; i < 1000; i++ {
+ go worker(&g, &done)
+ }
+
+ // Wait for the goroutines to do some work, then close the gate.
+ time.Sleep(2 * time.Second)
+ g.Close()
+
+ // Wait for all of them to complete.
+ done.Wait()
+}