// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package atomicptrmap import ( "context" "fmt" "math/rand" "reflect" "runtime" "testing" "time" "gvisor.dev/gvisor/pkg/sync" ) func TestConsistencyWithGoMap(t *testing.T) { const maxKey = 16 var vals [4]*testValue for i := 1; /* leave vals[0] nil */ i < len(vals); i++ { vals[i] = new(testValue) } var ( m = make(map[int64]*testValue) apm testAtomicPtrMap ) for i := 0; i < 100000; i++ { // Apply a random operation to both m and apm and expect them to have // the same result. Bias toward CompareAndSwap, which has the most // cases; bias away from Range and RangeRepeatable, which are // relatively expensive. switch rand.Intn(10) { case 0, 1: // Load key := rand.Int63n(maxKey) want := m[key] got := apm.Load(key) t.Logf("Load(%d) = %p", key, got) if got != want { t.Fatalf("got %p, wanted %p", got, want) } case 2, 3: // Swap key := rand.Int63n(maxKey) val := vals[rand.Intn(len(vals))] want := m[key] if val != nil { m[key] = val } else { delete(m, key) } got := apm.Swap(key, val) t.Logf("Swap(%d, %p) = %p", key, val, got) if got != want { t.Fatalf("got %p, wanted %p", got, want) } case 4, 5, 6, 7: // CompareAndSwap key := rand.Int63n(maxKey) oldVal := vals[rand.Intn(len(vals))] newVal := vals[rand.Intn(len(vals))] want := m[key] if want == oldVal { if newVal != nil { m[key] = newVal } else { delete(m, key) } } got := apm.CompareAndSwap(key, oldVal, newVal) t.Logf("CompareAndSwap(%d, %p, %p) = %p", key, oldVal, newVal, got) if got != want { t.Fatalf("got %p, wanted %p", got, want) } case 8: // Range got := make(map[int64]*testValue) var ( haveDup = false dup int64 ) apm.Range(func(key int64, val *testValue) bool { if _, ok := got[key]; ok && !haveDup { haveDup = true dup = key } got[key] = val return true }) t.Logf("Range() = %v", got) if !reflect.DeepEqual(got, m) { t.Fatalf("got %v, wanted %v", got, m) } if haveDup { t.Fatalf("got duplicate key %d", dup) } case 9: // RangeRepeatable got := make(map[int64]*testValue) apm.RangeRepeatable(func(key int64, val *testValue) bool { got[key] = val return true }) t.Logf("RangeRepeatable() = %v", got) if !reflect.DeepEqual(got, m) { t.Fatalf("got %v, wanted %v", got, m) } } } } func TestConcurrentHeterogeneous(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) var ( apm testAtomicPtrMap wg sync.WaitGroup ) defer func() { cancel() wg.Wait() }() possibleKeyValuePairs := make(map[int64]map[*testValue]struct{}) addKeyValuePair := func(key int64, val *testValue) { values := possibleKeyValuePairs[key] if values == nil { values = make(map[*testValue]struct{}) possibleKeyValuePairs[key] = values } values[val] = struct{}{} } const numValuesPerKey = 4 // These goroutines use keys not used by any other goroutine. const numPrivateKeys = 3 for i := 0; i < numPrivateKeys; i++ { key := int64(i) var vals [numValuesPerKey]*testValue for i := 1; /* leave vals[0] nil */ i < len(vals); i++ { val := new(testValue) vals[i] = val addKeyValuePair(key, val) } wg.Add(1) go func() { defer wg.Done() r := rand.New(rand.NewSource(rand.Int63())) var stored *testValue for ctx.Err() == nil { switch r.Intn(4) { case 0: got := apm.Load(key) if got != stored { t.Errorf("Load(%d): got %p, wanted %p", key, got, stored) return } case 1: val := vals[r.Intn(len(vals))] want := stored stored = val got := apm.Swap(key, val) if got != want { t.Errorf("Swap(%d, %p): got %p, wanted %p", key, val, got, want) return } case 2, 3: oldVal := vals[r.Intn(len(vals))] newVal := vals[r.Intn(len(vals))] want := stored if stored == oldVal { stored = newVal } got := apm.CompareAndSwap(key, oldVal, newVal) if got != want { t.Errorf("CompareAndSwap(%d, %p, %p): got %p, wanted %p", key, oldVal, newVal, got, want) return } } } }() } // These goroutines share a small set of keys. const numSharedKeys = 2 var ( sharedKeys [numSharedKeys]int64 sharedValues = make(map[int64][]*testValue) sharedValuesSet = make(map[int64]map[*testValue]struct{}) ) for i := range sharedKeys { key := int64(numPrivateKeys + i) sharedKeys[i] = key vals := make([]*testValue, numValuesPerKey) valsSet := make(map[*testValue]struct{}) for j := range vals { val := new(testValue) vals[j] = val valsSet[val] = struct{}{} addKeyValuePair(key, val) } sharedValues[key] = vals sharedValuesSet[key] = valsSet } randSharedValue := func(r *rand.Rand, key int64) *testValue { vals := sharedValues[key] return vals[r.Intn(len(vals))] } for i := 0; i < 3; i++ { wg.Add(1) go func() { defer wg.Done() r := rand.New(rand.NewSource(rand.Int63())) for ctx.Err() == nil { keyIndex := r.Intn(len(sharedKeys)) key := sharedKeys[keyIndex] var ( op string got *testValue ) switch r.Intn(4) { case 0: op = "Load" got = apm.Load(key) case 1: op = "Swap" got = apm.Swap(key, randSharedValue(r, key)) case 2, 3: op = "CompareAndSwap" got = apm.CompareAndSwap(key, randSharedValue(r, key), randSharedValue(r, key)) } if got != nil { valsSet := sharedValuesSet[key] if _, ok := valsSet[got]; !ok { t.Errorf("%s: got key %d, value %p; expected value in %v", op, key, got, valsSet) return } } } }() } // This goroutine repeatedly searches for unused keys. wg.Add(1) go func() { defer wg.Done() r := rand.New(rand.NewSource(rand.Int63())) for ctx.Err() == nil { key := -1 - r.Int63() if got := apm.Load(key); got != nil { t.Errorf("Load(%d): got %p, wanted nil", key, got) } } }() // This goroutine repeatedly calls RangeRepeatable() and checks that each // key corresponds to an expected value. wg.Add(1) go func() { defer wg.Done() abort := false for !abort && ctx.Err() == nil { apm.RangeRepeatable(func(key int64, val *testValue) bool { values, ok := possibleKeyValuePairs[key] if !ok { t.Errorf("RangeRepeatable: got invalid key %d", key) abort = true return false } if _, ok := values[val]; !ok { t.Errorf("RangeRepeatable: got key %d, value %p; expected one of %v", key, val, values) abort = true return false } return true }) } }() // Finally, the main goroutine spins for the length of the test calling // Range() and checking that each key that it observes is unique and // corresponds to an expected value. seenKeys := make(map[int64]struct{}) const testDuration = 5 * time.Second end := time.Now().Add(testDuration) abort := false for time.Now().Before(end) { apm.Range(func(key int64, val *testValue) bool { values, ok := possibleKeyValuePairs[key] if !ok { t.Errorf("Range: got invalid key %d", key) abort = true return false } if _, ok := values[val]; !ok { t.Errorf("Range: got key %d, value %p; expected one of %v", key, val, values) abort = true return false } if _, ok := seenKeys[key]; ok { t.Errorf("Range: got duplicate key %d", key) abort = true return false } seenKeys[key] = struct{}{} return true }) if abort { break } for k := range seenKeys { delete(seenKeys, k) } } } type benchmarkableMap interface { Load(key int64) *testValue Store(key int64, val *testValue) LoadOrStore(key int64, val *testValue) (*testValue, bool) Delete(key int64) } // rwMutexMap implements benchmarkableMap for a RWMutex-protected Go map. type rwMutexMap struct { mu sync.RWMutex m map[int64]*testValue } func (m *rwMutexMap) Load(key int64) *testValue { m.mu.RLock() defer m.mu.RUnlock() return m.m[key] } func (m *rwMutexMap) Store(key int64, val *testValue) { m.mu.Lock() defer m.mu.Unlock() if m.m == nil { m.m = make(map[int64]*testValue) } m.m[key] = val } func (m *rwMutexMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) { m.mu.Lock() defer m.mu.Unlock() if m.m == nil { m.m = make(map[int64]*testValue) } if oldVal, ok := m.m[key]; ok { return oldVal, true } m.m[key] = val return val, false } func (m *rwMutexMap) Delete(key int64) { m.mu.Lock() defer m.mu.Unlock() delete(m.m, key) } // syncMap implements benchmarkableMap for a sync.Map. type syncMap struct { m sync.Map } func (m *syncMap) Load(key int64) *testValue { val, ok := m.m.Load(key) if !ok { return nil } return val.(*testValue) } func (m *syncMap) Store(key int64, val *testValue) { m.m.Store(key, val) } func (m *syncMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) { actual, loaded := m.m.LoadOrStore(key, val) return actual.(*testValue), loaded } func (m *syncMap) Delete(key int64) { m.m.Delete(key) } // benchmarkableAtomicPtrMap implements benchmarkableMap for testAtomicPtrMap. type benchmarkableAtomicPtrMap struct { m testAtomicPtrMap } func (m *benchmarkableAtomicPtrMap) Load(key int64) *testValue { return m.m.Load(key) } func (m *benchmarkableAtomicPtrMap) Store(key int64, val *testValue) { m.m.Store(key, val) } func (m *benchmarkableAtomicPtrMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) { if prev := m.m.CompareAndSwap(key, nil, val); prev != nil { return prev, true } return val, false } func (m *benchmarkableAtomicPtrMap) Delete(key int64) { m.m.Store(key, nil) } // benchmarkableAtomicPtrMapSharded implements benchmarkableMap for testAtomicPtrMapSharded. type benchmarkableAtomicPtrMapSharded struct { m testAtomicPtrMapSharded } func (m *benchmarkableAtomicPtrMapSharded) Load(key int64) *testValue { return m.m.Load(key) } func (m *benchmarkableAtomicPtrMapSharded) Store(key int64, val *testValue) { m.m.Store(key, val) } func (m *benchmarkableAtomicPtrMapSharded) LoadOrStore(key int64, val *testValue) (*testValue, bool) { if prev := m.m.CompareAndSwap(key, nil, val); prev != nil { return prev, true } return val, false } func (m *benchmarkableAtomicPtrMapSharded) Delete(key int64) { m.m.Store(key, nil) } var mapImpls = [...]struct { name string ctor func() benchmarkableMap }{ { name: "RWMutexMap", ctor: func() benchmarkableMap { return new(rwMutexMap) }, }, { name: "SyncMap", ctor: func() benchmarkableMap { return new(syncMap) }, }, { name: "AtomicPtrMap", ctor: func() benchmarkableMap { return new(benchmarkableAtomicPtrMap) }, }, { name: "AtomicPtrMapSharded", ctor: func() benchmarkableMap { return new(benchmarkableAtomicPtrMapSharded) }, }, } func benchmarkStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) { m := mapCtor() val := &testValue{} for i := 0; i < b.N; i++ { m.Store(int64(i), val) } for i := 0; i < b.N; i++ { m.Delete(int64(i)) } } func BenchmarkStoreDelete(b *testing.B) { for _, mapImpl := range mapImpls { b.Run(mapImpl.name, func(b *testing.B) { benchmarkStoreDelete(b, mapImpl.ctor) }) } } func benchmarkLoadOrStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) { m := mapCtor() val := &testValue{} for i := 0; i < b.N; i++ { m.LoadOrStore(int64(i), val) } for i := 0; i < b.N; i++ { m.Delete(int64(i)) } } func BenchmarkLoadOrStoreDelete(b *testing.B) { for _, mapImpl := range mapImpls { b.Run(mapImpl.name, func(b *testing.B) { benchmarkLoadOrStoreDelete(b, mapImpl.ctor) }) } } func benchmarkLookupPositive(b *testing.B, mapCtor func() benchmarkableMap) { m := mapCtor() val := &testValue{} for i := 0; i < b.N; i++ { m.Store(int64(i), val) } b.ResetTimer() for i := 0; i < b.N; i++ { m.Load(int64(i)) } } func BenchmarkLookupPositive(b *testing.B) { for _, mapImpl := range mapImpls { b.Run(mapImpl.name, func(b *testing.B) { benchmarkLookupPositive(b, mapImpl.ctor) }) } } func benchmarkLookupNegative(b *testing.B, mapCtor func() benchmarkableMap) { m := mapCtor() val := &testValue{} for i := 0; i < b.N; i++ { m.Store(int64(i), val) } b.ResetTimer() for i := 0; i < b.N; i++ { m.Load(int64(-1 - i)) } } func BenchmarkLookupNegative(b *testing.B) { for _, mapImpl := range mapImpls { b.Run(mapImpl.name, func(b *testing.B) { benchmarkLookupNegative(b, mapImpl.ctor) }) } } type benchmarkConcurrentOptions struct { // loadsPerMutationPair is the number of map lookups between each // insertion/deletion pair. loadsPerMutationPair int // If changeKeys is true, the keys used by each goroutine change between // iterations of the test. changeKeys bool } func benchmarkConcurrent(b *testing.B, mapCtor func() benchmarkableMap, opts benchmarkConcurrentOptions) { var ( started sync.WaitGroup workers sync.WaitGroup ) started.Add(1) m := mapCtor() val := &testValue{} // Insert a large number of unused elements into the map so that used // elements are distributed throughout memory. for i := 0; i < 10000; i++ { m.Store(int64(-1-i), val) } // n := ceil(b.N / (opts.loadsPerMutationPair + 2)) n := (b.N + opts.loadsPerMutationPair + 1) / (opts.loadsPerMutationPair + 2) for i, procs := 0, runtime.GOMAXPROCS(0); i < procs; i++ { workerID := i workers.Add(1) go func() { defer workers.Done() started.Wait() for i := 0; i < n; i++ { var key int64 if opts.changeKeys { key = int64(workerID*n + i) } else { key = int64(workerID) } m.LoadOrStore(key, val) for j := 0; j < opts.loadsPerMutationPair; j++ { m.Load(key) } m.Delete(key) } }() } b.ResetTimer() started.Done() workers.Wait() } func BenchmarkConcurrent(b *testing.B) { changeKeysChoices := [...]struct { name string val bool }{ {"FixedKeys", false}, {"ChangingKeys", true}, } writePcts := [...]struct { name string loadsPerMutationPair int }{ {"1PercentWrites", 198}, {"10PercentWrites", 18}, {"50PercentWrites", 2}, } for _, changeKeys := range changeKeysChoices { for _, writePct := range writePcts { for _, mapImpl := range mapImpls { name := fmt.Sprintf("%s_%s_%s", changeKeys.name, writePct.name, mapImpl.name) b.Run(name, func(b *testing.B) { benchmarkConcurrent(b, mapImpl.ctor, benchmarkConcurrentOptions{ loadsPerMutationPair: writePct.loadsPerMutationPair, changeKeys: changeKeys.val, }) }) } } } }