summaryrefslogtreecommitdiffhomepage
path: root/pkg/tmutex
diff options
context:
space:
mode:
authorGoogler <noreply@google.com>2018-04-27 10:37:02 -0700
committerAdin Scannell <ascannell@google.com>2018-04-28 01:44:26 -0400
commitd02b74a5dcfed4bfc8f2f8e545bca4d2afabb296 (patch)
tree54f95eef73aee6bacbfc736fffc631be2605ed53 /pkg/tmutex
parentf70210e742919f40aa2f0934a22f1c9ba6dada62 (diff)
Check in gVisor.
PiperOrigin-RevId: 194583126 Change-Id: Ica1d8821a90f74e7e745962d71801c598c652463
Diffstat (limited to 'pkg/tmutex')
-rw-r--r--pkg/tmutex/BUILD17
-rw-r--r--pkg/tmutex/tmutex.go71
-rw-r--r--pkg/tmutex/tmutex_test.go247
3 files changed, 335 insertions, 0 deletions
diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD
new file mode 100644
index 000000000..5d1614d35
--- /dev/null
+++ b/pkg/tmutex/BUILD
@@ -0,0 +1,17 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "tmutex",
+ srcs = ["tmutex.go"],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tmutex",
+ visibility = ["//:sandbox"],
+)
+
+go_test(
+ name = "tmutex_test",
+ size = "medium",
+ srcs = ["tmutex_test.go"],
+ embed = [":tmutex"],
+)
diff --git a/pkg/tmutex/tmutex.go b/pkg/tmutex/tmutex.go
new file mode 100644
index 000000000..61779654f
--- /dev/null
+++ b/pkg/tmutex/tmutex.go
@@ -0,0 +1,71 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package tmutex provides the implementation of a mutex that implements an
+// efficient TryLock function in addition to Lock and Unlock.
+package tmutex
+
+import (
+ "sync/atomic"
+)
+
+// Mutex is a mutual exclusion primitive that implements TryLock in addition
+// to Lock and Unlock.
+type Mutex struct {
+ v int32
+ ch chan struct{}
+}
+
+// Init initializes the mutex.
+func (m *Mutex) Init() {
+ m.v = 1
+ m.ch = make(chan struct{}, 1)
+}
+
+// Lock acquires the mutex. If it is currently held by another goroutine, Lock
+// will wait until it has a chance to acquire it.
+func (m *Mutex) Lock() {
+ // Uncontended case.
+ if atomic.AddInt32(&m.v, -1) == 0 {
+ return
+ }
+
+ for {
+ // Try to acquire the mutex again, at the same time making sure
+ // that m.v is negative, which indicates to the owner of the
+ // lock that it is contended, which will force it to try to wake
+ // someone up when it releases the mutex.
+ if v := atomic.LoadInt32(&m.v); v >= 0 && atomic.SwapInt32(&m.v, -1) == 1 {
+ return
+ }
+
+ // Wait for the mutex to be released before trying again.
+ <-m.ch
+ }
+}
+
+// TryLock attempts to acquire the mutex without blocking. If the mutex is
+// currently held by another goroutine, it fails to acquire it and returns
+// false.
+func (m *Mutex) TryLock() bool {
+ v := atomic.LoadInt32(&m.v)
+ if v <= 0 {
+ return false
+ }
+ return atomic.CompareAndSwapInt32(&m.v, 1, 0)
+}
+
+// Unlock releases the mutex.
+func (m *Mutex) Unlock() {
+ if atomic.SwapInt32(&m.v, 1) == 0 {
+ // There were no pending waiters.
+ return
+ }
+
+ // Wake some waiter up.
+ select {
+ case m.ch <- struct{}{}:
+ default:
+ }
+}
diff --git a/pkg/tmutex/tmutex_test.go b/pkg/tmutex/tmutex_test.go
new file mode 100644
index 000000000..e1b5fd4e2
--- /dev/null
+++ b/pkg/tmutex/tmutex_test.go
@@ -0,0 +1,247 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tmutex
+
+import (
+ "fmt"
+ "runtime"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+func TestBasicLock(t *testing.T) {
+ var m Mutex
+ m.Init()
+
+ m.Lock()
+
+ // Try blocking lock the mutex from a different goroutine. This must
+ // not block because the mutex is held.
+ ch := make(chan struct{}, 1)
+ go func() {
+ m.Lock()
+ ch <- struct{}{}
+ m.Unlock()
+ ch <- struct{}{}
+ }()
+
+ select {
+ case <-ch:
+ t.Fatalf("Lock succeeded on locked mutex")
+ case <-time.After(100 * time.Millisecond):
+ }
+
+ // Unlock the mutex and make sure that the goroutine waiting on Lock()
+ // unblocks and succeeds.
+ m.Unlock()
+
+ select {
+ case <-ch:
+ case <-time.After(100 * time.Millisecond):
+ t.Fatalf("Lock failed to acquire unlocked mutex")
+ }
+
+ // Make sure we can lock and unlock again.
+ m.Lock()
+ m.Unlock()
+}
+
+func TestTryLock(t *testing.T) {
+ var m Mutex
+ m.Init()
+
+ // Try to lock. It should succeed.
+ if !m.TryLock() {
+ t.Fatalf("TryLock failed on unlocked mutex")
+ }
+
+ // Try to lock again, it should now fail.
+ if m.TryLock() {
+ t.Fatalf("TryLock succeeded on locked mutex")
+ }
+
+ // Try blocking lock the mutex from a different goroutine. This must
+ // not block because the mutex is held.
+ ch := make(chan struct{}, 1)
+ go func() {
+ m.Lock()
+ ch <- struct{}{}
+ m.Unlock()
+ }()
+
+ select {
+ case <-ch:
+ t.Fatalf("Lock succeeded on locked mutex")
+ case <-time.After(100 * time.Millisecond):
+ }
+
+ // Unlock the mutex and make sure that the goroutine waiting on Lock()
+ // unblocks and succeeds.
+ m.Unlock()
+
+ select {
+ case <-ch:
+ case <-time.After(100 * time.Millisecond):
+ t.Fatalf("Lock failed to acquire unlocked mutex")
+ }
+}
+
+func TestMutualExclusion(t *testing.T) {
+ var m Mutex
+ m.Init()
+
+ // Test mutual exclusion by running "gr" goroutines concurrently, and
+ // have each one increment a counter "iters" times within the critical
+ // section established by the mutex.
+ //
+ // If at the end the counter is not gr * iters, then we know that
+ // goroutines ran concurrently within the critical section.
+ //
+ // If one of the goroutines doesn't complete, it's likely a bug that
+ // causes to it to wait forever.
+ const gr = 1000
+ const iters = 100000
+ v := 0
+ var wg sync.WaitGroup
+ for i := 0; i < gr; i++ {
+ wg.Add(1)
+ go func() {
+ for j := 0; j < iters; j++ {
+ m.Lock()
+ v++
+ m.Unlock()
+ }
+ wg.Done()
+ }()
+ }
+
+ wg.Wait()
+
+ if v != gr*iters {
+ t.Fatalf("Bad count: got %v, want %v", v, gr*iters)
+ }
+}
+
+func TestMutualExclusionWithTryLock(t *testing.T) {
+ var m Mutex
+ m.Init()
+
+ // Similar to the previous, with the addition of some goroutines that
+ // only increment the count if TryLock succeeds.
+ const gr = 1000
+ const iters = 100000
+ total := int64(gr * iters)
+ var tryTotal int64
+ v := int64(0)
+ var wg sync.WaitGroup
+ for i := 0; i < gr; i++ {
+ wg.Add(2)
+ go func() {
+ for j := 0; j < iters; j++ {
+ m.Lock()
+ v++
+ m.Unlock()
+ }
+ wg.Done()
+ }()
+ go func() {
+ local := int64(0)
+ for j := 0; j < iters; j++ {
+ if m.TryLock() {
+ v++
+ m.Unlock()
+ local++
+ }
+ }
+ atomic.AddInt64(&tryTotal, local)
+ wg.Done()
+ }()
+ }
+
+ wg.Wait()
+
+ t.Logf("tryTotal = %d", tryTotal)
+ total += tryTotal
+
+ if v != total {
+ t.Fatalf("Bad count: got %v, want %v", v, total)
+ }
+}
+
+// BenchmarkTmutex is equivalent to TestMutualExclusion, with the following
+// differences:
+//
+// - The number of goroutines is variable, with the maximum value depending on
+// GOMAXPROCS.
+//
+// - The number of iterations per benchmark is controlled by the benchmarking
+// framework.
+//
+// - Care is taken to ensure that all goroutines participating in the benchmark
+// have been created before the benchmark begins.
+func BenchmarkTmutex(b *testing.B) {
+ for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var m Mutex
+ m.Init()
+
+ var ready sync.WaitGroup
+ begin := make(chan struct{})
+ var end sync.WaitGroup
+ for i := 0; i < n; i++ {
+ ready.Add(1)
+ end.Add(1)
+ go func() {
+ ready.Done()
+ <-begin
+ for j := 0; j < b.N; j++ {
+ m.Lock()
+ m.Unlock()
+ }
+ end.Done()
+ }()
+ }
+
+ ready.Wait()
+ b.ResetTimer()
+ close(begin)
+ end.Wait()
+ })
+ }
+}
+
+// BenchmarkSyncMutex is equivalent to BenchmarkTmutex, but uses sync.Mutex as
+// a comparison point.
+func BenchmarkSyncMutex(b *testing.B) {
+ for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var m sync.Mutex
+
+ var ready sync.WaitGroup
+ begin := make(chan struct{})
+ var end sync.WaitGroup
+ for i := 0; i < n; i++ {
+ ready.Add(1)
+ end.Add(1)
+ go func() {
+ ready.Done()
+ <-begin
+ for j := 0; j < b.N; j++ {
+ m.Lock()
+ m.Unlock()
+ }
+ end.Done()
+ }()
+ }
+
+ ready.Wait()
+ b.ResetTimer()
+ close(begin)
+ end.Wait()
+ })
+ }
+}