// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package atomicptrmap doesn't exist. This file must be instantiated using the
// go_template_instance rule in tools/go_generics/defs.bzl.
package atomicptrmap

import (
	"reflect"
	"runtime"
	"sync/atomic"
	"unsafe"

	"gvisor.dev/gvisor/pkg/gohacks"
	"gvisor.dev/gvisor/pkg/sync"
)

// Key is a required type parameter.
type Key struct{}

// Value is a required type parameter.
type Value struct{}

const (
	// ShardOrder is an optional parameter specifying the base-2 log of the
	// number of shards per AtomicPtrMap. Higher values of ShardOrder reduce
	// unnecessary synchronization between unrelated concurrent operations,
	// improving performance for write-heavy workloads, but increase memory
	// usage for small maps.
	ShardOrder = 0
)

// Hasher is an optional type parameter. If Hasher is provided, it must define
// the Init and Hash methods. One Hasher will be shared by all AtomicPtrMaps.
type Hasher struct {
	defaultHasher
}

// defaultHasher is the default Hasher. This indirection exists because
// defaultHasher must exist even if a custom Hasher is provided, to prevent the
// Go compiler from complaining about defaultHasher's unused imports.
type defaultHasher struct {
	fn   func(unsafe.Pointer, uintptr) uintptr
	seed uintptr
}

// Init initializes the Hasher.
func (h *defaultHasher) Init() {
	h.fn = sync.MapKeyHasher(map[Key]*Value(nil))
	h.seed = sync.RandUintptr()
}

// Hash returns the hash value for the given Key.
func (h *defaultHasher) Hash(key Key) uintptr {
	return h.fn(gohacks.Noescape(unsafe.Pointer(&key)), h.seed)
}

var hasher Hasher

func init() {
	hasher.Init()
}

// An AtomicPtrMap maps Keys to non-nil pointers to Values. AtomicPtrMap are
// safe for concurrent use from multiple goroutines without additional
// synchronization.
//
// The zero value of AtomicPtrMap is empty (maps all Keys to nil) and ready for
// use. AtomicPtrMaps must not be copied after first use.
//
// sync.Map may be faster than AtomicPtrMap if most operations on the map are
// concurrent writes to a fixed set of keys. AtomicPtrMap is usually faster in
// other circumstances.
type AtomicPtrMap struct {
	// AtomicPtrMap is implemented as a hash table with the following
	// properties:
	//
	// * Collisions are resolved with quadratic probing. Of the two major
	// alternatives, Robin Hood linear probing makes it difficult for writers
	// to execute in parallel, and bucketing is less effective in Go due to
	// lack of SIMD.
	//
	// * The table is optionally divided into shards indexed by hash to further
	// reduce unnecessary synchronization.

	shards [1 << ShardOrder]apmShard
}

func (m *AtomicPtrMap) shard(hash uintptr) *apmShard {
	// Go defines right shifts >= width of shifted unsigned operand as 0, so
	// this is correct even if ShardOrder is 0 (although nogo complains because
	// nogo is dumb).
	const indexLSB = unsafe.Sizeof(uintptr(0))*8 - ShardOrder
	index := hash >> indexLSB
	return (*apmShard)(unsafe.Pointer(uintptr(unsafe.Pointer(&m.shards)) + (index * unsafe.Sizeof(apmShard{}))))
}

type apmShard struct {
	apmShardMutationData
	_ [apmShardMutationDataPadding]byte
	apmShardLookupData
	_ [apmShardLookupDataPadding]byte
}

type apmShardMutationData struct {
	dirtyMu  sync.Mutex // serializes slot transitions out of empty
	dirty    uintptr    // # slots with val != nil
	count    uintptr    // # slots with val != nil and val != tombstone()
	rehashMu sync.Mutex // serializes rehashing
}

type apmShardLookupData struct {
	seq   sync.SeqCount  // allows atomic reads of slots+mask
	slots unsafe.Pointer // [mask+1]slot or nil; protected by rehashMu/seq
	mask  uintptr        // always (a power of 2) - 1; protected by rehashMu/seq
}

const (
	cacheLineBytes = 64
	// Cache line padding is enabled if sharding is.
	apmEnablePadding = (ShardOrder + 63) >> 6 // 0 if ShardOrder == 0, 1 otherwise
	// The -1 and +1 below are required to ensure that if unsafe.Sizeof(T) %
	// cacheLineBytes == 0, then padding is 0 (rather than cacheLineBytes).
	apmShardMutationDataRequiredPadding = cacheLineBytes - (((unsafe.Sizeof(apmShardMutationData{}) - 1) % cacheLineBytes) + 1)
	apmShardMutationDataPadding         = apmEnablePadding * apmShardMutationDataRequiredPadding
	apmShardLookupDataRequiredPadding   = cacheLineBytes - (((unsafe.Sizeof(apmShardLookupData{}) - 1) % cacheLineBytes) + 1)
	apmShardLookupDataPadding           = apmEnablePadding * apmShardLookupDataRequiredPadding

	// These define fractional thresholds for when apmShard.rehash() is called
	// (i.e. the load factor) and when it rehases to a larger table
	// respectively. They are chosen such that the rehash threshold = the
	// expansion threshold + 1/2, so that when reuse of deleted slots is rare
	// or non-existent, rehashing occurs after the insertion of at least 1/2
	// the table's size in new entries, which is acceptably infrequent.
	apmRehashThresholdNum    = 2
	apmRehashThresholdDen    = 3
	apmExpansionThresholdNum = 1
	apmExpansionThresholdDen = 6
)

type apmSlot struct {
	// slot states are indicated by val:
	//
	// * Empty: val == nil; key is meaningless. May transition to full or
	// evacuated with dirtyMu locked.
	//
	// * Full: val != nil, tombstone(), or evacuated(); key is immutable. val
	// is the Value mapped to key. May transition to deleted or evacuated.
	//
	// * Deleted: val == tombstone(); key is still immutable. key is mapped to
	// no Value. May transition to full or evacuated.
	//
	// * Evacuated: val == evacuated(); key is immutable. Set by rehashing on
	// slots that have already been moved, requiring readers to wait for
	// rehashing to complete and use the new table. Terminal state.
	//
	// Note that once val is non-nil, it cannot become nil again. That is, the
	// transition from empty to non-empty is irreversible for a given slot;
	// the only way to create more empty slots is by rehashing.
	val unsafe.Pointer
	key Key
}

func apmSlotAt(slots unsafe.Pointer, pos uintptr) *apmSlot {
	return (*apmSlot)(unsafe.Pointer(uintptr(slots) + pos*unsafe.Sizeof(apmSlot{})))
}

var tombstoneObj byte

func tombstone() unsafe.Pointer {
	return unsafe.Pointer(&tombstoneObj)
}

var evacuatedObj byte

func evacuated() unsafe.Pointer {
	return unsafe.Pointer(&evacuatedObj)
}

// Load returns the Value stored in m for key.
func (m *AtomicPtrMap) Load(key Key) *Value {
	hash := hasher.Hash(key)
	shard := m.shard(hash)

retry:
	epoch := shard.seq.BeginRead()
	slots := atomic.LoadPointer(&shard.slots)
	mask := atomic.LoadUintptr(&shard.mask)
	if !shard.seq.ReadOk(epoch) {
		goto retry
	}
	if slots == nil {
		return nil
	}

	i := hash & mask
	inc := uintptr(1)
	for {
		slot := apmSlotAt(slots, i)
		slotVal := atomic.LoadPointer(&slot.val)
		if slotVal == nil {
			// Empty slot; end of probe sequence.
			return nil
		}
		if slotVal == evacuated() {
			// Racing with rehashing.
			goto retry
		}
		if slot.key == key {
			if slotVal == tombstone() {
				return nil
			}
			return (*Value)(slotVal)
		}
		i = (i + inc) & mask
		inc++
	}
}

// Store stores the Value val for key.
func (m *AtomicPtrMap) Store(key Key, val *Value) {
	m.maybeCompareAndSwap(key, false, nil, val)
}

// Swap stores the Value val for key and returns the previously-mapped Value.
func (m *AtomicPtrMap) Swap(key Key, val *Value) *Value {
	return m.maybeCompareAndSwap(key, false, nil, val)
}

// CompareAndSwap checks that the Value stored for key is oldVal; if it is, it
// stores the Value newVal for key. CompareAndSwap returns the previous Value
// stored for key, whether or not it stores newVal.
func (m *AtomicPtrMap) CompareAndSwap(key Key, oldVal, newVal *Value) *Value {
	return m.maybeCompareAndSwap(key, true, oldVal, newVal)
}

func (m *AtomicPtrMap) maybeCompareAndSwap(key Key, compare bool, typedOldVal, typedNewVal *Value) *Value {
	hash := hasher.Hash(key)
	shard := m.shard(hash)
	oldVal := tombstone()
	if typedOldVal != nil {
		oldVal = unsafe.Pointer(typedOldVal)
	}
	newVal := tombstone()
	if typedNewVal != nil {
		newVal = unsafe.Pointer(typedNewVal)
	}

retry:
	epoch := shard.seq.BeginRead()
	slots := atomic.LoadPointer(&shard.slots)
	mask := atomic.LoadUintptr(&shard.mask)
	if !shard.seq.ReadOk(epoch) {
		goto retry
	}
	if slots == nil {
		if (compare && oldVal != tombstone()) || newVal == tombstone() {
			return nil
		}
		// Need to allocate a table before insertion.
		shard.rehash(nil)
		goto retry
	}

	i := hash & mask
	inc := uintptr(1)
	for {
		slot := apmSlotAt(slots, i)
		slotVal := atomic.LoadPointer(&slot.val)
		if slotVal == nil {
			if (compare && oldVal != tombstone()) || newVal == tombstone() {
				return nil
			}
			// Try to grab this slot for ourselves.
			shard.dirtyMu.Lock()
			slotVal = atomic.LoadPointer(&slot.val)
			if slotVal == nil {
				// Check if we need to rehash before dirtying a slot.
				if dirty, capacity := shard.dirty+1, mask+1; dirty*apmRehashThresholdDen >= capacity*apmRehashThresholdNum {
					shard.dirtyMu.Unlock()
					shard.rehash(slots)
					goto retry
				}
				slot.key = key
				atomic.StorePointer(&slot.val, newVal) // transitions slot to full
				shard.dirty++
				atomic.AddUintptr(&shard.count, 1)
				shard.dirtyMu.Unlock()
				return nil
			}
			// Raced with another store; the slot is no longer empty. Continue
			// with the new value of slotVal since we may have raced with
			// another store of key.
			shard.dirtyMu.Unlock()
		}
		if slotVal == evacuated() {
			// Racing with rehashing.
			goto retry
		}
		if slot.key == key {
			// We're reusing an existing slot, so rehashing isn't necessary.
			for {
				if (compare && oldVal != slotVal) || newVal == slotVal {
					if slotVal == tombstone() {
						return nil
					}
					return (*Value)(slotVal)
				}
				if atomic.CompareAndSwapPointer(&slot.val, slotVal, newVal) {
					if slotVal == tombstone() {
						atomic.AddUintptr(&shard.count, 1)
						return nil
					}
					if newVal == tombstone() {
						atomic.AddUintptr(&shard.count, ^uintptr(0) /* -1 */)
					}
					return (*Value)(slotVal)
				}
				slotVal = atomic.LoadPointer(&slot.val)
				if slotVal == evacuated() {
					goto retry
				}
			}
		}
		// This produces a triangular number sequence of offsets from the
		// initially-probed position.
		i = (i + inc) & mask
		inc++
	}
}

// rehash is marked nosplit to avoid preemption during table copying.
//go:nosplit
func (shard *apmShard) rehash(oldSlots unsafe.Pointer) {
	shard.rehashMu.Lock()
	defer shard.rehashMu.Unlock()

	if shard.slots != oldSlots {
		// Raced with another call to rehash().
		return
	}

	// Determine the size of the new table. Constraints:
	//
	// * The size of the table must be a power of two to ensure that every slot
	// is visitable by every probe sequence under quadratic probing with
	// triangular numbers.
	//
	// * The size of the table cannot decrease because even if shard.count is
	// currently smaller than shard.dirty, concurrent stores that reuse
	// existing slots can drive shard.count back up to a maximum of
	// shard.dirty.
	newSize := uintptr(8) // arbitrary initial size
	if oldSlots != nil {
		oldSize := shard.mask + 1
		newSize = oldSize
		if count := atomic.LoadUintptr(&shard.count) + 1; count*apmExpansionThresholdDen > oldSize*apmExpansionThresholdNum {
			newSize *= 2
		}
	}

	// Allocate the new table.
	newSlotsSlice := make([]apmSlot, newSize)
	newSlotsReflect := (*reflect.SliceHeader)(unsafe.Pointer(&newSlotsSlice))
	newSlots := unsafe.Pointer(newSlotsReflect.Data)
	runtime.KeepAlive(newSlotsSlice)
	newMask := newSize - 1

	// Start a writer critical section now so that racing users of the old
	// table that observe evacuated() wait for the new table. (But lock dirtyMu
	// first since doing so may block, which we don't want to do during the
	// writer critical section.)
	shard.dirtyMu.Lock()
	shard.seq.BeginWrite()

	if oldSlots != nil {
		realCount := uintptr(0)
		// Copy old entries to the new table.
		oldMask := shard.mask
		for i := uintptr(0); i <= oldMask; i++ {
			oldSlot := apmSlotAt(oldSlots, i)
			val := atomic.SwapPointer(&oldSlot.val, evacuated())
			if val == nil || val == tombstone() {
				continue
			}
			hash := hasher.Hash(oldSlot.key)
			j := hash & newMask
			inc := uintptr(1)
			for {
				newSlot := apmSlotAt(newSlots, j)
				if newSlot.val == nil {
					newSlot.val = val
					newSlot.key = oldSlot.key
					break
				}
				j = (j + inc) & newMask
				inc++
			}
			realCount++
		}
		// Update dirty to reflect that tombstones were not copied to the new
		// table. Use realCount since a concurrent mutator may not have updated
		// shard.count yet.
		shard.dirty = realCount
	}

	// Switch to the new table.
	atomic.StorePointer(&shard.slots, newSlots)
	atomic.StoreUintptr(&shard.mask, newMask)

	shard.seq.EndWrite()
	shard.dirtyMu.Unlock()
}

// Range invokes f on each Key-Value pair stored in m. If any call to f returns
// false, Range stops iteration and returns.
//
// Range does not necessarily correspond to any consistent snapshot of the
// Map's contents: no Key will be visited more than once, but if the Value for
// any Key is stored or deleted concurrently, Range may reflect any mapping for
// that Key from any point during the Range call.
//
// f must not call other methods on m.
func (m *AtomicPtrMap) Range(f func(key Key, val *Value) bool) {
	for si := 0; si < len(m.shards); si++ {
		shard := &m.shards[si]
		if !shard.doRange(f) {
			return
		}
	}
}

func (shard *apmShard) doRange(f func(key Key, val *Value) bool) bool {
	// We have to lock rehashMu because if we handled races with rehashing by
	// retrying, f could see the same key twice.
	shard.rehashMu.Lock()
	defer shard.rehashMu.Unlock()
	slots := shard.slots
	if slots == nil {
		return true
	}
	mask := shard.mask
	for i := uintptr(0); i <= mask; i++ {
		slot := apmSlotAt(slots, i)
		slotVal := atomic.LoadPointer(&slot.val)
		if slotVal == nil || slotVal == tombstone() {
			continue
		}
		if !f(slot.key, (*Value)(slotVal)) {
			return false
		}
	}
	return true
}

// RangeRepeatable is like Range, but:
//
// * RangeRepeatable may visit the same Key multiple times in the presence of
// concurrent mutators, possibly passing different Values to f in different
// calls.
//
// * It is safe for f to call other methods on m.
func (m *AtomicPtrMap) RangeRepeatable(f func(key Key, val *Value) bool) {
	for si := 0; si < len(m.shards); si++ {
		shard := &m.shards[si]

	retry:
		epoch := shard.seq.BeginRead()
		slots := atomic.LoadPointer(&shard.slots)
		mask := atomic.LoadUintptr(&shard.mask)
		if !shard.seq.ReadOk(epoch) {
			goto retry
		}
		if slots == nil {
			continue
		}

		for i := uintptr(0); i <= mask; i++ {
			slot := apmSlotAt(slots, i)
			slotVal := atomic.LoadPointer(&slot.val)
			if slotVal == evacuated() {
				goto retry
			}
			if slotVal == nil || slotVal == tombstone() {
				continue
			}
			if !f(slot.key, (*Value)(slotVal)) {
				return
			}
		}
	}
}