summaryrefslogtreecommitdiffhomepage
path: root/pkg/sync/atomicptrmap
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sync/atomicptrmap')
-rw-r--r--pkg/sync/atomicptrmap/BUILD76
-rw-r--r--pkg/sync/atomicptrmap/atomicptrmap.go20
-rw-r--r--pkg/sync/atomicptrmap/atomicptrmap_test.go635
-rw-r--r--pkg/sync/atomicptrmap/generic_atomicptrmap_unsafe.go500
4 files changed, 1231 insertions, 0 deletions
diff --git a/pkg/sync/atomicptrmap/BUILD b/pkg/sync/atomicptrmap/BUILD
new file mode 100644
index 000000000..b0e218c79
--- /dev/null
+++ b/pkg/sync/atomicptrmap/BUILD
@@ -0,0 +1,76 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
+
+package(
+ default_visibility = ["//visibility:private"],
+ licenses = ["notice"],
+)
+
+go_template(
+ name = "generic_atomicptrmap",
+ srcs = ["generic_atomicptrmap_unsafe.go"],
+ opt_consts = [
+ "ShardOrder",
+ ],
+ opt_types = [
+ "Hasher",
+ ],
+ types = [
+ "Key",
+ "Value",
+ ],
+ deps = [
+ "//pkg/gohacks",
+ "//pkg/sync",
+ ],
+)
+
+go_template_instance(
+ name = "test_atomicptrmap",
+ out = "test_atomicptrmap_unsafe.go",
+ package = "atomicptrmap",
+ prefix = "test",
+ template = ":generic_atomicptrmap",
+ types = {
+ "Key": "int64",
+ "Value": "testValue",
+ },
+)
+
+go_template_instance(
+ name = "test_atomicptrmap_sharded",
+ out = "test_atomicptrmap_sharded_unsafe.go",
+ consts = {
+ "ShardOrder": "4",
+ },
+ package = "atomicptrmap",
+ prefix = "test",
+ suffix = "Sharded",
+ template = ":generic_atomicptrmap",
+ types = {
+ "Key": "int64",
+ "Value": "testValue",
+ },
+)
+
+go_library(
+ name = "atomicptrmap",
+ testonly = 1,
+ srcs = [
+ "atomicptrmap.go",
+ "test_atomicptrmap_sharded_unsafe.go",
+ "test_atomicptrmap_unsafe.go",
+ ],
+ deps = [
+ "//pkg/gohacks",
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "atomicptrmap_test",
+ size = "small",
+ srcs = ["atomicptrmap_test.go"],
+ library = ":atomicptrmap",
+ deps = ["//pkg/sync"],
+)
diff --git a/pkg/sync/atomicptrmap/atomicptrmap.go b/pkg/sync/atomicptrmap/atomicptrmap.go
new file mode 100644
index 000000000..867821ce9
--- /dev/null
+++ b/pkg/sync/atomicptrmap/atomicptrmap.go
@@ -0,0 +1,20 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package atomicptrmap instantiates generic_atomicptrmap for testing.
+package atomicptrmap
+
+type testValue struct {
+ val int
+}
diff --git a/pkg/sync/atomicptrmap/atomicptrmap_test.go b/pkg/sync/atomicptrmap/atomicptrmap_test.go
new file mode 100644
index 000000000..75a9997ef
--- /dev/null
+++ b/pkg/sync/atomicptrmap/atomicptrmap_test.go
@@ -0,0 +1,635 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package atomicptrmap
+
+import (
+ "context"
+ "fmt"
+ "math/rand"
+ "reflect"
+ "runtime"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+func TestConsistencyWithGoMap(t *testing.T) {
+ const maxKey = 16
+ var vals [4]*testValue
+ for i := 1; /* leave vals[0] nil */ i < len(vals); i++ {
+ vals[i] = new(testValue)
+ }
+ var (
+ m = make(map[int64]*testValue)
+ apm testAtomicPtrMap
+ )
+ for i := 0; i < 100000; i++ {
+ // Apply a random operation to both m and apm and expect them to have
+ // the same result. Bias toward CompareAndSwap, which has the most
+ // cases; bias away from Range and RangeRepeatable, which are
+ // relatively expensive.
+ switch rand.Intn(10) {
+ case 0, 1: // Load
+ key := rand.Int63n(maxKey)
+ want := m[key]
+ got := apm.Load(key)
+ t.Logf("Load(%d) = %p", key, got)
+ if got != want {
+ t.Fatalf("got %p, wanted %p", got, want)
+ }
+ case 2, 3: // Swap
+ key := rand.Int63n(maxKey)
+ val := vals[rand.Intn(len(vals))]
+ want := m[key]
+ if val != nil {
+ m[key] = val
+ } else {
+ delete(m, key)
+ }
+ got := apm.Swap(key, val)
+ t.Logf("Swap(%d, %p) = %p", key, val, got)
+ if got != want {
+ t.Fatalf("got %p, wanted %p", got, want)
+ }
+ case 4, 5, 6, 7: // CompareAndSwap
+ key := rand.Int63n(maxKey)
+ oldVal := vals[rand.Intn(len(vals))]
+ newVal := vals[rand.Intn(len(vals))]
+ want := m[key]
+ if want == oldVal {
+ if newVal != nil {
+ m[key] = newVal
+ } else {
+ delete(m, key)
+ }
+ }
+ got := apm.CompareAndSwap(key, oldVal, newVal)
+ t.Logf("CompareAndSwap(%d, %p, %p) = %p", key, oldVal, newVal, got)
+ if got != want {
+ t.Fatalf("got %p, wanted %p", got, want)
+ }
+ case 8: // Range
+ got := make(map[int64]*testValue)
+ var (
+ haveDup = false
+ dup int64
+ )
+ apm.Range(func(key int64, val *testValue) bool {
+ if _, ok := got[key]; ok && !haveDup {
+ haveDup = true
+ dup = key
+ }
+ got[key] = val
+ return true
+ })
+ t.Logf("Range() = %v", got)
+ if !reflect.DeepEqual(got, m) {
+ t.Fatalf("got %v, wanted %v", got, m)
+ }
+ if haveDup {
+ t.Fatalf("got duplicate key %d", dup)
+ }
+ case 9: // RangeRepeatable
+ got := make(map[int64]*testValue)
+ apm.RangeRepeatable(func(key int64, val *testValue) bool {
+ got[key] = val
+ return true
+ })
+ t.Logf("RangeRepeatable() = %v", got)
+ if !reflect.DeepEqual(got, m) {
+ t.Fatalf("got %v, wanted %v", got, m)
+ }
+ }
+ }
+}
+
+func TestConcurrentHeterogeneous(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ var (
+ apm testAtomicPtrMap
+ wg sync.WaitGroup
+ )
+ defer func() {
+ cancel()
+ wg.Wait()
+ }()
+
+ possibleKeyValuePairs := make(map[int64]map[*testValue]struct{})
+ addKeyValuePair := func(key int64, val *testValue) {
+ values := possibleKeyValuePairs[key]
+ if values == nil {
+ values = make(map[*testValue]struct{})
+ possibleKeyValuePairs[key] = values
+ }
+ values[val] = struct{}{}
+ }
+
+ const numValuesPerKey = 4
+
+ // These goroutines use keys not used by any other goroutine.
+ const numPrivateKeys = 3
+ for i := 0; i < numPrivateKeys; i++ {
+ key := int64(i)
+ var vals [numValuesPerKey]*testValue
+ for i := 1; /* leave vals[0] nil */ i < len(vals); i++ {
+ val := new(testValue)
+ vals[i] = val
+ addKeyValuePair(key, val)
+ }
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ r := rand.New(rand.NewSource(rand.Int63()))
+ var stored *testValue
+ for ctx.Err() == nil {
+ switch r.Intn(4) {
+ case 0:
+ got := apm.Load(key)
+ if got != stored {
+ t.Errorf("Load(%d): got %p, wanted %p", key, got, stored)
+ return
+ }
+ case 1:
+ val := vals[r.Intn(len(vals))]
+ want := stored
+ stored = val
+ got := apm.Swap(key, val)
+ if got != want {
+ t.Errorf("Swap(%d, %p): got %p, wanted %p", key, val, got, want)
+ return
+ }
+ case 2, 3:
+ oldVal := vals[r.Intn(len(vals))]
+ newVal := vals[r.Intn(len(vals))]
+ want := stored
+ if stored == oldVal {
+ stored = newVal
+ }
+ got := apm.CompareAndSwap(key, oldVal, newVal)
+ if got != want {
+ t.Errorf("CompareAndSwap(%d, %p, %p): got %p, wanted %p", key, oldVal, newVal, got, want)
+ return
+ }
+ }
+ }
+ }()
+ }
+
+ // These goroutines share a small set of keys.
+ const numSharedKeys = 2
+ var (
+ sharedKeys [numSharedKeys]int64
+ sharedValues = make(map[int64][]*testValue)
+ sharedValuesSet = make(map[int64]map[*testValue]struct{})
+ )
+ for i := range sharedKeys {
+ key := int64(numPrivateKeys + i)
+ sharedKeys[i] = key
+ vals := make([]*testValue, numValuesPerKey)
+ valsSet := make(map[*testValue]struct{})
+ for j := range vals {
+ val := new(testValue)
+ vals[j] = val
+ valsSet[val] = struct{}{}
+ addKeyValuePair(key, val)
+ }
+ sharedValues[key] = vals
+ sharedValuesSet[key] = valsSet
+ }
+ randSharedValue := func(r *rand.Rand, key int64) *testValue {
+ vals := sharedValues[key]
+ return vals[r.Intn(len(vals))]
+ }
+ for i := 0; i < 3; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ r := rand.New(rand.NewSource(rand.Int63()))
+ for ctx.Err() == nil {
+ keyIndex := r.Intn(len(sharedKeys))
+ key := sharedKeys[keyIndex]
+ var (
+ op string
+ got *testValue
+ )
+ switch r.Intn(4) {
+ case 0:
+ op = "Load"
+ got = apm.Load(key)
+ case 1:
+ op = "Swap"
+ got = apm.Swap(key, randSharedValue(r, key))
+ case 2, 3:
+ op = "CompareAndSwap"
+ got = apm.CompareAndSwap(key, randSharedValue(r, key), randSharedValue(r, key))
+ }
+ if got != nil {
+ valsSet := sharedValuesSet[key]
+ if _, ok := valsSet[got]; !ok {
+ t.Errorf("%s: got key %d, value %p; expected value in %v", op, key, got, valsSet)
+ return
+ }
+ }
+ }
+ }()
+ }
+
+ // This goroutine repeatedly searches for unused keys.
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ r := rand.New(rand.NewSource(rand.Int63()))
+ for ctx.Err() == nil {
+ key := -1 - r.Int63()
+ if got := apm.Load(key); got != nil {
+ t.Errorf("Load(%d): got %p, wanted nil", key, got)
+ }
+ }
+ }()
+
+ // This goroutine repeatedly calls RangeRepeatable() and checks that each
+ // key corresponds to an expected value.
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ abort := false
+ for !abort && ctx.Err() == nil {
+ apm.RangeRepeatable(func(key int64, val *testValue) bool {
+ values, ok := possibleKeyValuePairs[key]
+ if !ok {
+ t.Errorf("RangeRepeatable: got invalid key %d", key)
+ abort = true
+ return false
+ }
+ if _, ok := values[val]; !ok {
+ t.Errorf("RangeRepeatable: got key %d, value %p; expected one of %v", key, val, values)
+ abort = true
+ return false
+ }
+ return true
+ })
+ }
+ }()
+
+ // Finally, the main goroutine spins for the length of the test calling
+ // Range() and checking that each key that it observes is unique and
+ // corresponds to an expected value.
+ seenKeys := make(map[int64]struct{})
+ const testDuration = 5 * time.Second
+ end := time.Now().Add(testDuration)
+ abort := false
+ for time.Now().Before(end) {
+ apm.Range(func(key int64, val *testValue) bool {
+ values, ok := possibleKeyValuePairs[key]
+ if !ok {
+ t.Errorf("Range: got invalid key %d", key)
+ abort = true
+ return false
+ }
+ if _, ok := values[val]; !ok {
+ t.Errorf("Range: got key %d, value %p; expected one of %v", key, val, values)
+ abort = true
+ return false
+ }
+ if _, ok := seenKeys[key]; ok {
+ t.Errorf("Range: got duplicate key %d", key)
+ abort = true
+ return false
+ }
+ seenKeys[key] = struct{}{}
+ return true
+ })
+ if abort {
+ break
+ }
+ for k := range seenKeys {
+ delete(seenKeys, k)
+ }
+ }
+}
+
+type benchmarkableMap interface {
+ Load(key int64) *testValue
+ Store(key int64, val *testValue)
+ LoadOrStore(key int64, val *testValue) (*testValue, bool)
+ Delete(key int64)
+}
+
+// rwMutexMap implements benchmarkableMap for a RWMutex-protected Go map.
+type rwMutexMap struct {
+ mu sync.RWMutex
+ m map[int64]*testValue
+}
+
+func (m *rwMutexMap) Load(key int64) *testValue {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+ return m.m[key]
+}
+
+func (m *rwMutexMap) Store(key int64, val *testValue) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if m.m == nil {
+ m.m = make(map[int64]*testValue)
+ }
+ m.m[key] = val
+}
+
+func (m *rwMutexMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if m.m == nil {
+ m.m = make(map[int64]*testValue)
+ }
+ if oldVal, ok := m.m[key]; ok {
+ return oldVal, true
+ }
+ m.m[key] = val
+ return val, false
+}
+
+func (m *rwMutexMap) Delete(key int64) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ delete(m.m, key)
+}
+
+// syncMap implements benchmarkableMap for a sync.Map.
+type syncMap struct {
+ m sync.Map
+}
+
+func (m *syncMap) Load(key int64) *testValue {
+ val, ok := m.m.Load(key)
+ if !ok {
+ return nil
+ }
+ return val.(*testValue)
+}
+
+func (m *syncMap) Store(key int64, val *testValue) {
+ m.m.Store(key, val)
+}
+
+func (m *syncMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
+ actual, loaded := m.m.LoadOrStore(key, val)
+ return actual.(*testValue), loaded
+}
+
+func (m *syncMap) Delete(key int64) {
+ m.m.Delete(key)
+}
+
+// benchmarkableAtomicPtrMap implements benchmarkableMap for testAtomicPtrMap.
+type benchmarkableAtomicPtrMap struct {
+ m testAtomicPtrMap
+}
+
+func (m *benchmarkableAtomicPtrMap) Load(key int64) *testValue {
+ return m.m.Load(key)
+}
+
+func (m *benchmarkableAtomicPtrMap) Store(key int64, val *testValue) {
+ m.m.Store(key, val)
+}
+
+func (m *benchmarkableAtomicPtrMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
+ if prev := m.m.CompareAndSwap(key, nil, val); prev != nil {
+ return prev, true
+ }
+ return val, false
+}
+
+func (m *benchmarkableAtomicPtrMap) Delete(key int64) {
+ m.m.Store(key, nil)
+}
+
+// benchmarkableAtomicPtrMapSharded implements benchmarkableMap for testAtomicPtrMapSharded.
+type benchmarkableAtomicPtrMapSharded struct {
+ m testAtomicPtrMapSharded
+}
+
+func (m *benchmarkableAtomicPtrMapSharded) Load(key int64) *testValue {
+ return m.m.Load(key)
+}
+
+func (m *benchmarkableAtomicPtrMapSharded) Store(key int64, val *testValue) {
+ m.m.Store(key, val)
+}
+
+func (m *benchmarkableAtomicPtrMapSharded) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
+ if prev := m.m.CompareAndSwap(key, nil, val); prev != nil {
+ return prev, true
+ }
+ return val, false
+}
+
+func (m *benchmarkableAtomicPtrMapSharded) Delete(key int64) {
+ m.m.Store(key, nil)
+}
+
+var mapImpls = [...]struct {
+ name string
+ ctor func() benchmarkableMap
+}{
+ {
+ name: "RWMutexMap",
+ ctor: func() benchmarkableMap {
+ return new(rwMutexMap)
+ },
+ },
+ {
+ name: "SyncMap",
+ ctor: func() benchmarkableMap {
+ return new(syncMap)
+ },
+ },
+ {
+ name: "AtomicPtrMap",
+ ctor: func() benchmarkableMap {
+ return new(benchmarkableAtomicPtrMap)
+ },
+ },
+ {
+ name: "AtomicPtrMapSharded",
+ ctor: func() benchmarkableMap {
+ return new(benchmarkableAtomicPtrMapSharded)
+ },
+ },
+}
+
+func benchmarkStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) {
+ m := mapCtor()
+ val := &testValue{}
+ for i := 0; i < b.N; i++ {
+ m.Store(int64(i), val)
+ }
+ for i := 0; i < b.N; i++ {
+ m.Delete(int64(i))
+ }
+}
+
+func BenchmarkStoreDelete(b *testing.B) {
+ for _, mapImpl := range mapImpls {
+ b.Run(mapImpl.name, func(b *testing.B) {
+ benchmarkStoreDelete(b, mapImpl.ctor)
+ })
+ }
+}
+
+func benchmarkLoadOrStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) {
+ m := mapCtor()
+ val := &testValue{}
+ for i := 0; i < b.N; i++ {
+ m.LoadOrStore(int64(i), val)
+ }
+ for i := 0; i < b.N; i++ {
+ m.Delete(int64(i))
+ }
+}
+
+func BenchmarkLoadOrStoreDelete(b *testing.B) {
+ for _, mapImpl := range mapImpls {
+ b.Run(mapImpl.name, func(b *testing.B) {
+ benchmarkLoadOrStoreDelete(b, mapImpl.ctor)
+ })
+ }
+}
+
+func benchmarkLookupPositive(b *testing.B, mapCtor func() benchmarkableMap) {
+ m := mapCtor()
+ val := &testValue{}
+ for i := 0; i < b.N; i++ {
+ m.Store(int64(i), val)
+ }
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ m.Load(int64(i))
+ }
+}
+
+func BenchmarkLookupPositive(b *testing.B) {
+ for _, mapImpl := range mapImpls {
+ b.Run(mapImpl.name, func(b *testing.B) {
+ benchmarkLookupPositive(b, mapImpl.ctor)
+ })
+ }
+}
+
+func benchmarkLookupNegative(b *testing.B, mapCtor func() benchmarkableMap) {
+ m := mapCtor()
+ val := &testValue{}
+ for i := 0; i < b.N; i++ {
+ m.Store(int64(i), val)
+ }
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ m.Load(int64(-1 - i))
+ }
+}
+
+func BenchmarkLookupNegative(b *testing.B) {
+ for _, mapImpl := range mapImpls {
+ b.Run(mapImpl.name, func(b *testing.B) {
+ benchmarkLookupNegative(b, mapImpl.ctor)
+ })
+ }
+}
+
+type benchmarkConcurrentOptions struct {
+ // loadsPerMutationPair is the number of map lookups between each
+ // insertion/deletion pair.
+ loadsPerMutationPair int
+
+ // If changeKeys is true, the keys used by each goroutine change between
+ // iterations of the test.
+ changeKeys bool
+}
+
+func benchmarkConcurrent(b *testing.B, mapCtor func() benchmarkableMap, opts benchmarkConcurrentOptions) {
+ var (
+ started sync.WaitGroup
+ workers sync.WaitGroup
+ )
+ started.Add(1)
+
+ m := mapCtor()
+ val := &testValue{}
+ // Insert a large number of unused elements into the map so that used
+ // elements are distributed throughout memory.
+ for i := 0; i < 10000; i++ {
+ m.Store(int64(-1-i), val)
+ }
+ // n := ceil(b.N / (opts.loadsPerMutationPair + 2))
+ n := (b.N + opts.loadsPerMutationPair + 1) / (opts.loadsPerMutationPair + 2)
+ for i, procs := 0, runtime.GOMAXPROCS(0); i < procs; i++ {
+ workerID := i
+ workers.Add(1)
+ go func() {
+ defer workers.Done()
+ started.Wait()
+ for i := 0; i < n; i++ {
+ var key int64
+ if opts.changeKeys {
+ key = int64(workerID*n + i)
+ } else {
+ key = int64(workerID)
+ }
+ m.LoadOrStore(key, val)
+ for j := 0; j < opts.loadsPerMutationPair; j++ {
+ m.Load(key)
+ }
+ m.Delete(key)
+ }
+ }()
+ }
+
+ b.ResetTimer()
+ started.Done()
+ workers.Wait()
+}
+
+func BenchmarkConcurrent(b *testing.B) {
+ changeKeysChoices := [...]struct {
+ name string
+ val bool
+ }{
+ {"FixedKeys", false},
+ {"ChangingKeys", true},
+ }
+ writePcts := [...]struct {
+ name string
+ loadsPerMutationPair int
+ }{
+ {"1PercentWrites", 198},
+ {"10PercentWrites", 18},
+ {"50PercentWrites", 2},
+ }
+ for _, changeKeys := range changeKeysChoices {
+ for _, writePct := range writePcts {
+ for _, mapImpl := range mapImpls {
+ name := fmt.Sprintf("%s_%s_%s", changeKeys.name, writePct.name, mapImpl.name)
+ b.Run(name, func(b *testing.B) {
+ benchmarkConcurrent(b, mapImpl.ctor, benchmarkConcurrentOptions{
+ loadsPerMutationPair: writePct.loadsPerMutationPair,
+ changeKeys: changeKeys.val,
+ })
+ })
+ }
+ }
+ }
+}
diff --git a/pkg/sync/atomicptrmap/generic_atomicptrmap_unsafe.go b/pkg/sync/atomicptrmap/generic_atomicptrmap_unsafe.go
new file mode 100644
index 000000000..3e98cb309
--- /dev/null
+++ b/pkg/sync/atomicptrmap/generic_atomicptrmap_unsafe.go
@@ -0,0 +1,500 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package atomicptrmap doesn't exist. This file must be instantiated using the
+// go_template_instance rule in tools/go_generics/defs.bzl.
+package atomicptrmap
+
+import (
+ "sync/atomic"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Key is a required type parameter.
+type Key struct{}
+
+// Value is a required type parameter.
+type Value struct{}
+
+const (
+ // ShardOrder is an optional parameter specifying the base-2 log of the
+ // number of shards per AtomicPtrMap. Higher values of ShardOrder reduce
+ // unnecessary synchronization between unrelated concurrent operations,
+ // improving performance for write-heavy workloads, but increase memory
+ // usage for small maps.
+ ShardOrder = 0
+)
+
+// Hasher is an optional type parameter. If Hasher is provided, it must define
+// the Init and Hash methods. One Hasher will be shared by all AtomicPtrMaps.
+type Hasher struct {
+ defaultHasher
+}
+
+// defaultHasher is the default Hasher. This indirection exists because
+// defaultHasher must exist even if a custom Hasher is provided, to prevent the
+// Go compiler from complaining about defaultHasher's unused imports.
+type defaultHasher struct {
+ fn func(unsafe.Pointer, uintptr) uintptr
+ seed uintptr
+}
+
+// Init initializes the Hasher.
+func (h *defaultHasher) Init() {
+ h.fn = sync.MapKeyHasher(map[Key]*Value(nil))
+ h.seed = sync.RandUintptr()
+}
+
+// Hash returns the hash value for the given Key.
+func (h *defaultHasher) Hash(key Key) uintptr {
+ return h.fn(gohacks.Noescape(unsafe.Pointer(&key)), h.seed)
+}
+
+var hasher Hasher
+
+func init() {
+ hasher.Init()
+}
+
+// An AtomicPtrMap maps Keys to non-nil pointers to Values. AtomicPtrMap are
+// safe for concurrent use from multiple goroutines without additional
+// synchronization.
+//
+// The zero value of AtomicPtrMap is empty (maps all Keys to nil) and ready for
+// use. AtomicPtrMaps must not be copied after first use.
+//
+// sync.Map may be faster than AtomicPtrMap if most operations on the map are
+// concurrent writes to a fixed set of keys. AtomicPtrMap is usually faster in
+// other circumstances.
+type AtomicPtrMap struct {
+ // AtomicPtrMap is implemented as a hash table with the following
+ // properties:
+ //
+ // * Collisions are resolved with quadratic probing. Of the two major
+ // alternatives, Robin Hood linear probing makes it difficult for writers
+ // to execute in parallel, and bucketing is less effective in Go due to
+ // lack of SIMD.
+ //
+ // * The table is optionally divided into shards indexed by hash to further
+ // reduce unnecessary synchronization.
+
+ shards [1 << ShardOrder]apmShard
+}
+
+func (m *AtomicPtrMap) shard(hash uintptr) *apmShard {
+ // Go defines right shifts >= width of shifted unsigned operand as 0, so
+ // this is correct even if ShardOrder is 0 (although nogo complains because
+ // nogo is dumb).
+ const indexLSB = unsafe.Sizeof(uintptr(0))*8 - ShardOrder
+ index := hash >> indexLSB
+ return (*apmShard)(unsafe.Pointer(uintptr(unsafe.Pointer(&m.shards)) + (index * unsafe.Sizeof(apmShard{}))))
+}
+
+type apmShard struct {
+ apmShardMutationData
+ _ [apmShardMutationDataPadding]byte
+ apmShardLookupData
+ _ [apmShardLookupDataPadding]byte
+}
+
+type apmShardMutationData struct {
+ dirtyMu sync.Mutex // serializes slot transitions out of empty
+ dirty uintptr // # slots with val != nil
+ count uintptr // # slots with val != nil and val != tombstone()
+ rehashMu sync.Mutex // serializes rehashing
+}
+
+type apmShardLookupData struct {
+ seq sync.SeqCount // allows atomic reads of slots+mask
+ slots unsafe.Pointer // [mask+1]slot or nil; protected by rehashMu/seq
+ mask uintptr // always (a power of 2) - 1; protected by rehashMu/seq
+}
+
+const (
+ cacheLineBytes = 64
+ // Cache line padding is enabled if sharding is.
+ apmEnablePadding = (ShardOrder + 63) >> 6 // 0 if ShardOrder == 0, 1 otherwise
+ // The -1 and +1 below are required to ensure that if unsafe.Sizeof(T) %
+ // cacheLineBytes == 0, then padding is 0 (rather than cacheLineBytes).
+ apmShardMutationDataRequiredPadding = cacheLineBytes - (((unsafe.Sizeof(apmShardMutationData{}) - 1) % cacheLineBytes) + 1)
+ apmShardMutationDataPadding = apmEnablePadding * apmShardMutationDataRequiredPadding
+ apmShardLookupDataRequiredPadding = cacheLineBytes - (((unsafe.Sizeof(apmShardLookupData{}) - 1) % cacheLineBytes) + 1)
+ apmShardLookupDataPadding = apmEnablePadding * apmShardLookupDataRequiredPadding
+
+ // These define fractional thresholds for when apmShard.rehash() is called
+ // (i.e. the load factor) and when it rehases to a larger table
+ // respectively. They are chosen such that the rehash threshold = the
+ // expansion threshold + 1/2, so that when reuse of deleted slots is rare
+ // or non-existent, rehashing occurs after the insertion of at least 1/2
+ // the table's size in new entries, which is acceptably infrequent.
+ apmRehashThresholdNum = 2
+ apmRehashThresholdDen = 3
+ apmExpansionThresholdNum = 1
+ apmExpansionThresholdDen = 6
+)
+
+type apmSlot struct {
+ // slot states are indicated by val:
+ //
+ // * Empty: val == nil; key is meaningless. May transition to full or
+ // evacuated with dirtyMu locked.
+ //
+ // * Full: val != nil, tombstone(), or evacuated(); key is immutable. val
+ // is the Value mapped to key. May transition to deleted or evacuated.
+ //
+ // * Deleted: val == tombstone(); key is still immutable. key is mapped to
+ // no Value. May transition to full or evacuated.
+ //
+ // * Evacuated: val == evacuated(); key is immutable. Set by rehashing on
+ // slots that have already been moved, requiring readers to wait for
+ // rehashing to complete and use the new table. Terminal state.
+ //
+ // Note that once val is non-nil, it cannot become nil again. That is, the
+ // transition from empty to non-empty is irreversible for a given slot;
+ // the only way to create more empty slots is by rehashing.
+ val unsafe.Pointer
+ key Key
+}
+
+func apmSlotAt(slots unsafe.Pointer, pos uintptr) *apmSlot {
+ return (*apmSlot)(unsafe.Pointer(uintptr(slots) + pos*unsafe.Sizeof(apmSlot{})))
+}
+
+var tombstoneObj byte
+
+func tombstone() unsafe.Pointer {
+ return unsafe.Pointer(&tombstoneObj)
+}
+
+var evacuatedObj byte
+
+func evacuated() unsafe.Pointer {
+ return unsafe.Pointer(&evacuatedObj)
+}
+
+// Load returns the Value stored in m for key.
+func (m *AtomicPtrMap) Load(key Key) *Value {
+ hash := hasher.Hash(key)
+ shard := m.shard(hash)
+
+retry:
+ epoch := shard.seq.BeginRead()
+ slots := atomic.LoadPointer(&shard.slots)
+ mask := atomic.LoadUintptr(&shard.mask)
+ if !shard.seq.ReadOk(epoch) {
+ goto retry
+ }
+ if slots == nil {
+ return nil
+ }
+
+ i := hash & mask
+ inc := uintptr(1)
+ for {
+ slot := apmSlotAt(slots, i)
+ slotVal := atomic.LoadPointer(&slot.val)
+ if slotVal == nil {
+ // Empty slot; end of probe sequence.
+ return nil
+ }
+ if slotVal == evacuated() {
+ // Racing with rehashing.
+ goto retry
+ }
+ if slot.key == key {
+ if slotVal == tombstone() {
+ return nil
+ }
+ return (*Value)(slotVal)
+ }
+ i = (i + inc) & mask
+ inc++
+ }
+}
+
+// Store stores the Value val for key.
+func (m *AtomicPtrMap) Store(key Key, val *Value) {
+ m.maybeCompareAndSwap(key, false, nil, val)
+}
+
+// Swap stores the Value val for key and returns the previously-mapped Value.
+func (m *AtomicPtrMap) Swap(key Key, val *Value) *Value {
+ return m.maybeCompareAndSwap(key, false, nil, val)
+}
+
+// CompareAndSwap checks that the Value stored for key is oldVal; if it is, it
+// stores the Value newVal for key. CompareAndSwap returns the previous Value
+// stored for key, whether or not it stores newVal.
+func (m *AtomicPtrMap) CompareAndSwap(key Key, oldVal, newVal *Value) *Value {
+ return m.maybeCompareAndSwap(key, true, oldVal, newVal)
+}
+
+func (m *AtomicPtrMap) maybeCompareAndSwap(key Key, compare bool, typedOldVal, typedNewVal *Value) *Value {
+ hash := hasher.Hash(key)
+ shard := m.shard(hash)
+ oldVal := tombstone()
+ if typedOldVal != nil {
+ oldVal = unsafe.Pointer(typedOldVal)
+ }
+ newVal := tombstone()
+ if typedNewVal != nil {
+ newVal = unsafe.Pointer(typedNewVal)
+ }
+
+retry:
+ epoch := shard.seq.BeginRead()
+ slots := atomic.LoadPointer(&shard.slots)
+ mask := atomic.LoadUintptr(&shard.mask)
+ if !shard.seq.ReadOk(epoch) {
+ goto retry
+ }
+ if slots == nil {
+ if (compare && oldVal != tombstone()) || newVal == tombstone() {
+ return nil
+ }
+ // Need to allocate a table before insertion.
+ shard.rehash(nil)
+ goto retry
+ }
+
+ i := hash & mask
+ inc := uintptr(1)
+ for {
+ slot := apmSlotAt(slots, i)
+ slotVal := atomic.LoadPointer(&slot.val)
+ if slotVal == nil {
+ if (compare && oldVal != tombstone()) || newVal == tombstone() {
+ return nil
+ }
+ // Try to grab this slot for ourselves.
+ shard.dirtyMu.Lock()
+ slotVal = atomic.LoadPointer(&slot.val)
+ if slotVal == nil {
+ // Check if we need to rehash before dirtying a slot.
+ if dirty, capacity := shard.dirty+1, mask+1; dirty*apmRehashThresholdDen >= capacity*apmRehashThresholdNum {
+ shard.dirtyMu.Unlock()
+ shard.rehash(slots)
+ goto retry
+ }
+ slot.key = key
+ atomic.StorePointer(&slot.val, newVal) // transitions slot to full
+ shard.dirty++
+ atomic.AddUintptr(&shard.count, 1)
+ shard.dirtyMu.Unlock()
+ return nil
+ }
+ // Raced with another store; the slot is no longer empty. Continue
+ // with the new value of slotVal since we may have raced with
+ // another store of key.
+ shard.dirtyMu.Unlock()
+ }
+ if slotVal == evacuated() {
+ // Racing with rehashing.
+ goto retry
+ }
+ if slot.key == key {
+ // We're reusing an existing slot, so rehashing isn't necessary.
+ for {
+ if (compare && oldVal != slotVal) || newVal == slotVal {
+ if slotVal == tombstone() {
+ return nil
+ }
+ return (*Value)(slotVal)
+ }
+ if atomic.CompareAndSwapPointer(&slot.val, slotVal, newVal) {
+ if slotVal == tombstone() {
+ atomic.AddUintptr(&shard.count, 1)
+ return nil
+ }
+ if newVal == tombstone() {
+ atomic.AddUintptr(&shard.count, ^uintptr(0) /* -1 */)
+ }
+ return (*Value)(slotVal)
+ }
+ slotVal = atomic.LoadPointer(&slot.val)
+ if slotVal == evacuated() {
+ goto retry
+ }
+ }
+ }
+ // This produces a triangular number sequence of offsets from the
+ // initially-probed position.
+ i = (i + inc) & mask
+ inc++
+ }
+}
+
+// rehash is marked nosplit to avoid preemption during table copying.
+//go:nosplit
+func (shard *apmShard) rehash(oldSlots unsafe.Pointer) {
+ shard.rehashMu.Lock()
+ defer shard.rehashMu.Unlock()
+
+ if shard.slots != oldSlots {
+ // Raced with another call to rehash().
+ return
+ }
+
+ // Determine the size of the new table. Constraints:
+ //
+ // * The size of the table must be a power of two to ensure that every slot
+ // is visitable by every probe sequence under quadratic probing with
+ // triangular numbers.
+ //
+ // * The size of the table cannot decrease because even if shard.count is
+ // currently smaller than shard.dirty, concurrent stores that reuse
+ // existing slots can drive shard.count back up to a maximum of
+ // shard.dirty.
+ newSize := uintptr(8) // arbitrary initial size
+ if oldSlots != nil {
+ oldSize := shard.mask + 1
+ newSize = oldSize
+ if count := atomic.LoadUintptr(&shard.count) + 1; count*apmExpansionThresholdDen > oldSize*apmExpansionThresholdNum {
+ newSize *= 2
+ }
+ }
+
+ // Allocate the new table.
+ newSlotsSlice := make([]apmSlot, newSize)
+ newSlotsHeader := (*gohacks.SliceHeader)(unsafe.Pointer(&newSlotsSlice))
+ newSlots := newSlotsHeader.Data
+ newMask := newSize - 1
+
+ // Start a writer critical section now so that racing users of the old
+ // table that observe evacuated() wait for the new table. (But lock dirtyMu
+ // first since doing so may block, which we don't want to do during the
+ // writer critical section.)
+ shard.dirtyMu.Lock()
+ shard.seq.BeginWrite()
+
+ if oldSlots != nil {
+ realCount := uintptr(0)
+ // Copy old entries to the new table.
+ oldMask := shard.mask
+ for i := uintptr(0); i <= oldMask; i++ {
+ oldSlot := apmSlotAt(oldSlots, i)
+ val := atomic.SwapPointer(&oldSlot.val, evacuated())
+ if val == nil || val == tombstone() {
+ continue
+ }
+ hash := hasher.Hash(oldSlot.key)
+ j := hash & newMask
+ inc := uintptr(1)
+ for {
+ newSlot := apmSlotAt(newSlots, j)
+ if newSlot.val == nil {
+ newSlot.val = val
+ newSlot.key = oldSlot.key
+ break
+ }
+ j = (j + inc) & newMask
+ inc++
+ }
+ realCount++
+ }
+ // Update dirty to reflect that tombstones were not copied to the new
+ // table. Use realCount since a concurrent mutator may not have updated
+ // shard.count yet.
+ shard.dirty = realCount
+ }
+
+ // Switch to the new table.
+ atomic.StorePointer(&shard.slots, newSlots)
+ atomic.StoreUintptr(&shard.mask, newMask)
+
+ shard.seq.EndWrite()
+ shard.dirtyMu.Unlock()
+}
+
+// Range invokes f on each Key-Value pair stored in m. If any call to f returns
+// false, Range stops iteration and returns.
+//
+// Range does not necessarily correspond to any consistent snapshot of the
+// Map's contents: no Key will be visited more than once, but if the Value for
+// any Key is stored or deleted concurrently, Range may reflect any mapping for
+// that Key from any point during the Range call.
+//
+// f must not call other methods on m.
+func (m *AtomicPtrMap) Range(f func(key Key, val *Value) bool) {
+ for si := 0; si < len(m.shards); si++ {
+ shard := &m.shards[si]
+ if !shard.doRange(f) {
+ return
+ }
+ }
+}
+
+func (shard *apmShard) doRange(f func(key Key, val *Value) bool) bool {
+ // We have to lock rehashMu because if we handled races with rehashing by
+ // retrying, f could see the same key twice.
+ shard.rehashMu.Lock()
+ defer shard.rehashMu.Unlock()
+ slots := shard.slots
+ if slots == nil {
+ return true
+ }
+ mask := shard.mask
+ for i := uintptr(0); i <= mask; i++ {
+ slot := apmSlotAt(slots, i)
+ slotVal := atomic.LoadPointer(&slot.val)
+ if slotVal == nil || slotVal == tombstone() {
+ continue
+ }
+ if !f(slot.key, (*Value)(slotVal)) {
+ return false
+ }
+ }
+ return true
+}
+
+// RangeRepeatable is like Range, but:
+//
+// * RangeRepeatable may visit the same Key multiple times in the presence of
+// concurrent mutators, possibly passing different Values to f in different
+// calls.
+//
+// * It is safe for f to call other methods on m.
+func (m *AtomicPtrMap) RangeRepeatable(f func(key Key, val *Value) bool) {
+ for si := 0; si < len(m.shards); si++ {
+ shard := &m.shards[si]
+
+ retry:
+ epoch := shard.seq.BeginRead()
+ slots := atomic.LoadPointer(&shard.slots)
+ mask := atomic.LoadUintptr(&shard.mask)
+ if !shard.seq.ReadOk(epoch) {
+ goto retry
+ }
+ if slots == nil {
+ continue
+ }
+
+ for i := uintptr(0); i <= mask; i++ {
+ slot := apmSlotAt(slots, i)
+ slotVal := atomic.LoadPointer(&slot.val)
+ if slotVal == evacuated() {
+ goto retry
+ }
+ if slotVal == nil || slotVal == tombstone() {
+ continue
+ }
+ if !f(slot.key, (*Value)(slotVal)) {
+ return
+ }
+ }
+ }
+}