summaryrefslogtreecommitdiffhomepage
path: root/pkg/state
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/state')
-rw-r--r--pkg/state/BUILD86
-rw-r--r--pkg/state/README.md158
-rw-r--r--pkg/state/addr_range.go76
-rw-r--r--pkg/state/addr_set.go1643
-rw-r--r--pkg/state/complete_list.go221
-rw-r--r--pkg/state/deferred_list.go206
-rw-r--r--pkg/state/pretty/BUILD13
-rw-r--r--pkg/state/pretty/pretty_state_autogen.go3
-rw-r--r--pkg/state/statefile/BUILD22
-rw-r--r--pkg/state/statefile/statefile_state_autogen.go3
-rw-r--r--pkg/state/statefile/statefile_test.go290
-rw-r--r--pkg/state/tests/BUILD43
-rw-r--r--pkg/state/tests/array.go35
-rw-r--r--pkg/state/tests/array_test.go134
-rw-r--r--pkg/state/tests/bench.go24
-rw-r--r--pkg/state/tests/bench_test.go153
-rw-r--r--pkg/state/tests/bool_test.go31
-rw-r--r--pkg/state/tests/float_test.go118
-rw-r--r--pkg/state/tests/integer.go163
-rw-r--r--pkg/state/tests/integer_test.go94
-rw-r--r--pkg/state/tests/load.go61
-rw-r--r--pkg/state/tests/load_test.go78
-rw-r--r--pkg/state/tests/map.go28
-rw-r--r--pkg/state/tests/map_test.go90
-rw-r--r--pkg/state/tests/register.go21
-rw-r--r--pkg/state/tests/register_test.go178
-rw-r--r--pkg/state/tests/string_test.go34
-rw-r--r--pkg/state/tests/struct.go100
-rw-r--r--pkg/state/tests/struct_test.go100
-rw-r--r--pkg/state/tests/tests.go215
-rw-r--r--pkg/state/wire/BUILD12
31 files changed, 2152 insertions, 2281 deletions
diff --git a/pkg/state/BUILD b/pkg/state/BUILD
deleted file mode 100644
index 92c51879b..000000000
--- a/pkg/state/BUILD
+++ /dev/null
@@ -1,86 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "deferred_list",
- out = "deferred_list.go",
- package = "state",
- prefix = "deferred",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*objectEncodeState",
- "ElementMapper": "deferredMapper",
- "Linker": "*deferredEntry",
- },
-)
-
-go_template_instance(
- name = "complete_list",
- out = "complete_list.go",
- package = "state",
- prefix = "complete",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*objectDecodeState",
- "Linker": "*objectDecodeState",
- },
-)
-
-go_template_instance(
- name = "addr_range",
- out = "addr_range.go",
- package = "state",
- prefix = "addr",
- template = "//pkg/segment:generic_range",
- types = {
- "T": "uintptr",
- },
-)
-
-go_template_instance(
- name = "addr_set",
- out = "addr_set.go",
- consts = {
- "minDegree": "10",
- },
- imports = {
- "reflect": "reflect",
- },
- package = "state",
- prefix = "addr",
- template = "//pkg/segment:generic_set",
- types = {
- "Key": "uintptr",
- "Range": "addrRange",
- "Value": "*objectEncodeState",
- "Functions": "addrSetFunctions",
- },
-)
-
-go_library(
- name = "state",
- srcs = [
- "addr_range.go",
- "addr_set.go",
- "complete_list.go",
- "decode.go",
- "decode_unsafe.go",
- "deferred_list.go",
- "encode.go",
- "encode_unsafe.go",
- "state.go",
- "state_norace.go",
- "state_race.go",
- "stats.go",
- "types.go",
- ],
- marshal = False,
- stateify = False,
- visibility = ["//:sandbox"],
- deps = [
- "//pkg/log",
- "//pkg/state/wire",
- ],
-)
diff --git a/pkg/state/README.md b/pkg/state/README.md
deleted file mode 100644
index 1aa401193..000000000
--- a/pkg/state/README.md
+++ /dev/null
@@ -1,158 +0,0 @@
-# State Encoding and Decoding
-
-The state package implements the encoding and decoding of data structures for
-`go_stateify`. This package is designed for use cases other than the standard
-encoding packages, e.g. `gob` and `json`. Principally:
-
-* This package operates on complex object graphs and accurately serializes and
- restores all relationships. That is, you can have things like: intrusive
- pointers, cycles, and pointer chains of arbitrary depths. These are not
- handled appropriately by existing encoders. This is not an implementation
- flaw: the formats themselves are not capable of representing these graphs,
- as they can only generate directed trees.
-
-* This package allows installing order-dependent load callbacks and then
- resolves that graph at load time, with cycle detection. Similarly, there is
- no analogous feature possible in the standard encoders.
-
-* This package handles the resolution of interfaces, based on a registered
- type name. For interface objects type information is saved in the serialized
- format. This is generally true for `gob` as well, but it works differently.
-
-Here's an overview of how encoding and decoding works.
-
-## Encoding
-
-Encoding produces a `statefile`, which contains a list of chunks of the form
-`(header, payload)`. The payload can either be some raw data, or a series of
-encoded wire objects representing some object graph. All encoded objects are
-defined in the `wire` subpackage.
-
-Encoding of an object graph begins with `encodeState.Save`.
-
-### 1. Memory Map & Encoding
-
-To discover relationships between potentially interdependent data structures
-(for example, a struct may contain pointers to members of other data
-structures), the encoder first walks the object graph and constructs a memory
-map of the objects in the input graph. As this walk progresses, objects are
-queued in the `pending` list and items are placed on the `deferred` list as they
-are discovered. No single object will be encoded multiple times, but the
-discovered relationships between objects may change as more parts of the overall
-object graph are discovered.
-
-The encoder starts at the root object and recursively visits all reachable
-objects, recording the address ranges containing the underlying data for each
-object. This is stored as a segment set (`addrSet`), mapping address ranges to
-the of the object occupying the range; see `encodeState.values`. Note that there
-is special handling for zero-sized types and map objects during this process.
-
-Additionally, the encoder assigns each object a unique identifier which is used
-to indicate relationships between objects in the statefile; see `objectID` in
-`encode.go`.
-
-### 2. Type Serialization
-
-The enoder will subsequently serialize all information about discovered types,
-including field names. These are used during decoding to reconcile these types
-with other internally registered types.
-
-### 3. Object Serialization
-
-With a full address map, and all objects correctly encoded, all object encodings
-are serialized. The assigned `objectID`s aren't explicitly encoded in the
-statefile. The order of object messages in the stream determine their IDs.
-
-### Example
-
-Given the following data structure definitions:
-
-```go
-type system struct {
- o *outer
- i *inner
-}
-
-type outer struct {
- a int64
- cn *container
-}
-
-type container struct {
- n uint64
- elem *inner
-}
-
-type inner struct {
- c container
- x, y uint64
-}
-```
-
-Initialized like this:
-
-```go
-o := outer{
- a: 10,
- cn: nil,
-}
-i := inner{
- x: 20,
- y: 30,
- c: container{},
-}
-s := system{
- o: &o,
- i: &i,
-}
-
-o.cn = &i.c
-o.cn.elem = &i
-
-```
-
-Encoding will produce an object stream like this:
-
-```
-g0r1 = struct{
- i: g0r3,
- o: g0r2,
-}
-g0r2 = struct{
- a: 10,
- cn: g0r3.c,
-}
-g0r3 = struct{
- c: struct{
- elem: g0r3,
- n: 0u,
- },
- x: 20u,
- y: 30u,
-}
-```
-
-Note how `g0r3.c` is correctly encoded as the underlying `container` object for
-`inner.c`, and how the pointer from `outer.cn` points to it, despite `system.i`
-being discovered after the pointer to it in `system.o.cn`. Also note that
-decoding isn't strictly reliant on the order of encoded object stream, as long
-as the relationship between objects are correctly encoded.
-
-## Decoding
-
-Decoding reads the statefile and reconstructs the object graph. Decoding begins
-in `decodeState.Load`. Decoding is performed in a single pass over the object
-stream in the statefile, and a subsequent pass over all deserialized objects is
-done to fire off all loading callbacks in the correctly defined order. Note that
-introducing cycles is possible here, but these are detected and an error will be
-returned.
-
-Decoding is relatively straight forward. For most primitive values, the decoder
-constructs an appropriate object and fills it with the values encoded in the
-statefile. Pointers need special handling, as they must point to a value
-allocated elsewhere. When values are constructed, the decoder indexes them by
-their `objectID`s in `decodeState.objectsByID`. The target of pointers are
-resolved by searching for the target in this index by their `objectID`; see
-`decodeState.register`. For pointers to values inside another value (fields in a
-pointer, elements of an array), the decoder uses the accessor path to walk to
-the appropriate location; see `walkChild`.
diff --git a/pkg/state/addr_range.go b/pkg/state/addr_range.go
new file mode 100644
index 000000000..0b7346e47
--- /dev/null
+++ b/pkg/state/addr_range.go
@@ -0,0 +1,76 @@
+package state
+
+// A Range represents a contiguous range of T.
+//
+// +stateify savable
+type addrRange struct {
+ // Start is the inclusive start of the range.
+ Start uintptr
+
+ // End is the exclusive end of the range.
+ End uintptr
+}
+
+// WellFormed returns true if r.Start <= r.End. All other methods on a Range
+// require that the Range is well-formed.
+//
+//go:nosplit
+func (r addrRange) WellFormed() bool {
+ return r.Start <= r.End
+}
+
+// Length returns the length of the range.
+//
+//go:nosplit
+func (r addrRange) Length() uintptr {
+ return r.End - r.Start
+}
+
+// Contains returns true if r contains x.
+//
+//go:nosplit
+func (r addrRange) Contains(x uintptr) bool {
+ return r.Start <= x && x < r.End
+}
+
+// Overlaps returns true if r and r2 overlap.
+//
+//go:nosplit
+func (r addrRange) Overlaps(r2 addrRange) bool {
+ return r.Start < r2.End && r2.Start < r.End
+}
+
+// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is
+// contained within r.
+//
+//go:nosplit
+func (r addrRange) IsSupersetOf(r2 addrRange) bool {
+ return r.Start <= r2.Start && r.End >= r2.End
+}
+
+// Intersect returns a range consisting of the intersection between r and r2.
+// If r and r2 do not overlap, Intersect returns a range with unspecified
+// bounds, but for which Length() == 0.
+//
+//go:nosplit
+func (r addrRange) Intersect(r2 addrRange) addrRange {
+ if r.Start < r2.Start {
+ r.Start = r2.Start
+ }
+ if r.End > r2.End {
+ r.End = r2.End
+ }
+ if r.End < r.Start {
+ r.End = r.Start
+ }
+ return r
+}
+
+// CanSplitAt returns true if it is legal to split a segment spanning the range
+// r at x; that is, splitting at x would produce two ranges, both of which have
+// non-zero length.
+//
+//go:nosplit
+func (r addrRange) CanSplitAt(x uintptr) bool {
+ return r.Contains(x) && r.Start < x
+}
diff --git a/pkg/state/addr_set.go b/pkg/state/addr_set.go
new file mode 100644
index 000000000..591af5292
--- /dev/null
+++ b/pkg/state/addr_set.go
@@ -0,0 +1,1643 @@
+package state
+
+import (
+ "bytes"
+ "fmt"
+)
+
+// trackGaps is an optional parameter.
+//
+// If trackGaps is 1, the Set will track maximum gap size recursively,
+// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this
+// case, Key must be an unsigned integer.
+//
+// trackGaps must be 0 or 1.
+const addrtrackGaps = 0
+
+var _ = uint8(addrtrackGaps << 7) // Will fail if not zero or one.
+
+// dynamicGap is a type that disappears if trackGaps is 0.
+type addrdynamicGap [addrtrackGaps]uintptr
+
+// Get returns the value of the gap.
+//
+// Precondition: trackGaps must be non-zero.
+func (d *addrdynamicGap) Get() uintptr {
+ return d[:][0]
+}
+
+// Set sets the value of the gap.
+//
+// Precondition: trackGaps must be non-zero.
+func (d *addrdynamicGap) Set(v uintptr) {
+ d[:][0] = v
+}
+
+const (
+ // minDegree is the minimum degree of an internal node in a Set B-tree.
+ //
+ // - Any non-root node has at least minDegree-1 segments.
+ //
+ // - Any non-root internal (non-leaf) node has at least minDegree children.
+ //
+ // - The root node may have fewer than minDegree-1 segments, but it may
+ // only have 0 segments if the tree is empty.
+ //
+ // Our implementation requires minDegree >= 3. Higher values of minDegree
+ // usually improve performance, but increase memory usage for small sets.
+ addrminDegree = 10
+
+ addrmaxDegree = 2 * addrminDegree
+)
+
+// A Set is a mapping of segments with non-overlapping Range keys. The zero
+// value for a Set is an empty set. Set values are not safely movable nor
+// copyable. Set is thread-compatible.
+//
+// +stateify savable
+type addrSet struct {
+ root addrnode `state:".(*addrSegmentDataSlices)"`
+}
+
+// IsEmpty returns true if the set contains no segments.
+func (s *addrSet) IsEmpty() bool {
+ return s.root.nrSegments == 0
+}
+
+// IsEmptyRange returns true iff no segments in the set overlap the given
+// range. This is semantically equivalent to s.SpanRange(r) == 0, but may be
+// more efficient.
+func (s *addrSet) IsEmptyRange(r addrRange) bool {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return true
+ }
+ _, gap := s.Find(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ return r.End <= gap.End()
+}
+
+// Span returns the total size of all segments in the set.
+func (s *addrSet) Span() uintptr {
+ var sz uintptr
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sz += seg.Range().Length()
+ }
+ return sz
+}
+
+// SpanRange returns the total size of the intersection of segments in the set
+// with the given range.
+func (s *addrSet) SpanRange(r addrRange) uintptr {
+ switch {
+ case r.Length() < 0:
+ panic(fmt.Sprintf("invalid range %v", r))
+ case r.Length() == 0:
+ return 0
+ }
+ var sz uintptr
+ for seg := s.LowerBoundSegment(r.Start); seg.Ok() && seg.Start() < r.End; seg = seg.NextSegment() {
+ sz += seg.Range().Intersect(r).Length()
+ }
+ return sz
+}
+
+// FirstSegment returns the first segment in the set. If the set is empty,
+// FirstSegment returns a terminal iterator.
+func (s *addrSet) FirstSegment() addrIterator {
+ if s.root.nrSegments == 0 {
+ return addrIterator{}
+ }
+ return s.root.firstSegment()
+}
+
+// LastSegment returns the last segment in the set. If the set is empty,
+// LastSegment returns a terminal iterator.
+func (s *addrSet) LastSegment() addrIterator {
+ if s.root.nrSegments == 0 {
+ return addrIterator{}
+ }
+ return s.root.lastSegment()
+}
+
+// FirstGap returns the first gap in the set.
+func (s *addrSet) FirstGap() addrGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return addrGapIterator{n, 0}
+}
+
+// LastGap returns the last gap in the set.
+func (s *addrSet) LastGap() addrGapIterator {
+ n := &s.root
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return addrGapIterator{n, n.nrSegments}
+}
+
+// Find returns the segment or gap whose range contains the given key. If a
+// segment is found, the returned Iterator is non-terminal and the
+// returned GapIterator is terminal. Otherwise, the returned Iterator is
+// terminal and the returned GapIterator is non-terminal.
+func (s *addrSet) Find(key uintptr) (addrIterator, addrGapIterator) {
+ n := &s.root
+ for {
+
+ lower := 0
+ upper := n.nrSegments
+ for lower < upper {
+ i := lower + (upper-lower)/2
+ if r := n.keys[i]; key < r.End {
+ if key >= r.Start {
+ return addrIterator{n, i}, addrGapIterator{}
+ }
+ upper = i
+ } else {
+ lower = i + 1
+ }
+ }
+ i := lower
+ if !n.hasChildren {
+ return addrIterator{}, addrGapIterator{n, i}
+ }
+ n = n.children[i]
+ }
+}
+
+// FindSegment returns the segment whose range contains the given key. If no
+// such segment exists, FindSegment returns a terminal iterator.
+func (s *addrSet) FindSegment(key uintptr) addrIterator {
+ seg, _ := s.Find(key)
+ return seg
+}
+
+// LowerBoundSegment returns the segment with the lowest range that contains a
+// key greater than or equal to min. If no such segment exists,
+// LowerBoundSegment returns a terminal iterator.
+func (s *addrSet) LowerBoundSegment(min uintptr) addrIterator {
+ seg, gap := s.Find(min)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.NextSegment()
+}
+
+// UpperBoundSegment returns the segment with the highest range that contains a
+// key less than or equal to max. If no such segment exists, UpperBoundSegment
+// returns a terminal iterator.
+func (s *addrSet) UpperBoundSegment(max uintptr) addrIterator {
+ seg, gap := s.Find(max)
+ if seg.Ok() {
+ return seg
+ }
+ return gap.PrevSegment()
+}
+
+// FindGap returns the gap containing the given key. If no such gap exists
+// (i.e. the set contains a segment containing that key), FindGap returns a
+// terminal iterator.
+func (s *addrSet) FindGap(key uintptr) addrGapIterator {
+ _, gap := s.Find(key)
+ return gap
+}
+
+// LowerBoundGap returns the gap with the lowest range that is greater than or
+// equal to min.
+func (s *addrSet) LowerBoundGap(min uintptr) addrGapIterator {
+ seg, gap := s.Find(min)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.NextGap()
+}
+
+// UpperBoundGap returns the gap with the highest range that is less than or
+// equal to max.
+func (s *addrSet) UpperBoundGap(max uintptr) addrGapIterator {
+ seg, gap := s.Find(max)
+ if gap.Ok() {
+ return gap
+ }
+ return seg.PrevGap()
+}
+
+// Add inserts the given segment into the set and returns true. If the new
+// segment can be merged with adjacent segments, Add will do so. If the new
+// segment would overlap an existing segment, Add returns false. If Add
+// succeeds, all existing iterators are invalidated.
+func (s *addrSet) Add(r addrRange, val *objectEncodeState) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.Insert(gap, r, val)
+ return true
+}
+
+// AddWithoutMerging inserts the given segment into the set and returns true.
+// If it would overlap an existing segment, AddWithoutMerging does nothing and
+// returns false. If AddWithoutMerging succeeds, all existing iterators are
+// invalidated.
+func (s *addrSet) AddWithoutMerging(r addrRange, val *objectEncodeState) bool {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ gap := s.FindGap(r.Start)
+ if !gap.Ok() {
+ return false
+ }
+ if r.End > gap.End() {
+ return false
+ }
+ s.InsertWithoutMergingUnchecked(gap, r, val)
+ return true
+}
+
+// Insert inserts the given segment into the given gap. If the new segment can
+// be merged with adjacent segments, Insert will do so. Insert returns an
+// iterator to the segment containing the inserted value (which may have been
+// merged with other values). All existing iterators (including gap, but not
+// including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid, Insert panics.
+//
+// Insert is semantically equivalent to a InsertWithoutMerging followed by a
+// Merge, but may be more efficient. Note that there is no unchecked variant of
+// Insert since Insert must retrieve and inspect gap's predecessor and
+// successor segments regardless.
+func (s *addrSet) Insert(gap addrGapIterator, r addrRange, val *objectEncodeState) addrIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ prev, next := gap.PrevSegment(), gap.NextSegment()
+ if prev.Ok() && prev.End() > r.Start {
+ panic(fmt.Sprintf("new segment %v overlaps predecessor %v", r, prev.Range()))
+ }
+ if next.Ok() && next.Start() < r.End {
+ panic(fmt.Sprintf("new segment %v overlaps successor %v", r, next.Range()))
+ }
+ if prev.Ok() && prev.End() == r.Start {
+ if mval, ok := (addrSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok {
+ shrinkMaxGap := addrtrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get()
+ prev.SetEndUnchecked(r.End)
+ prev.SetValue(mval)
+ if shrinkMaxGap {
+ gap.node.updateMaxGapLeaf()
+ }
+ if next.Ok() && next.Start() == r.End {
+ val = mval
+ if mval, ok := (addrSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok {
+ prev.SetEndUnchecked(next.End())
+ prev.SetValue(mval)
+ return s.Remove(next).PrevSegment()
+ }
+ }
+ return prev
+ }
+ }
+ if next.Ok() && next.Start() == r.End {
+ if mval, ok := (addrSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok {
+ shrinkMaxGap := addrtrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get()
+ next.SetStartUnchecked(r.Start)
+ next.SetValue(mval)
+ if shrinkMaxGap {
+ gap.node.updateMaxGapLeaf()
+ }
+ return next
+ }
+ }
+
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMerging inserts the given segment into the given gap and
+// returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// If the gap cannot accommodate the segment, or if r is invalid,
+// InsertWithoutMerging panics.
+func (s *addrSet) InsertWithoutMerging(gap addrGapIterator, r addrRange, val *objectEncodeState) addrIterator {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if gr := gap.Range(); !gr.IsSupersetOf(r) {
+ panic(fmt.Sprintf("cannot insert segment range %v into gap range %v", r, gr))
+ }
+ return s.InsertWithoutMergingUnchecked(gap, r, val)
+}
+
+// InsertWithoutMergingUnchecked inserts the given segment into the given gap
+// and returns an iterator to the inserted segment. All existing iterators
+// (including gap, but not including the returned iterator) are invalidated.
+//
+// Preconditions:
+// * r.Start >= gap.Start().
+// * r.End <= gap.End().
+func (s *addrSet) InsertWithoutMergingUnchecked(gap addrGapIterator, r addrRange, val *objectEncodeState) addrIterator {
+ gap = gap.node.rebalanceBeforeInsert(gap)
+ splitMaxGap := addrtrackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get())
+ copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments])
+ copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments])
+ gap.node.keys[gap.index] = r
+ gap.node.values[gap.index] = val
+ gap.node.nrSegments++
+ if splitMaxGap {
+ gap.node.updateMaxGapLeaf()
+ }
+ return addrIterator{gap.node, gap.index}
+}
+
+// Remove removes the given segment and returns an iterator to the vacated gap.
+// All existing iterators (including seg, but not including the returned
+// iterator) are invalidated.
+func (s *addrSet) Remove(seg addrIterator) addrGapIterator {
+
+ if seg.node.hasChildren {
+
+ victim := seg.PrevSegment()
+
+ seg.SetRangeUnchecked(victim.Range())
+ seg.SetValue(victim.Value())
+
+ nextAdjacentNode := seg.NextSegment().node
+ if addrtrackGaps != 0 {
+ nextAdjacentNode.updateMaxGapLeaf()
+ }
+ return s.Remove(victim).NextGap()
+ }
+ copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments])
+ copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments])
+ addrSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1])
+ seg.node.nrSegments--
+ if addrtrackGaps != 0 {
+ seg.node.updateMaxGapLeaf()
+ }
+ return seg.node.rebalanceAfterRemove(addrGapIterator{seg.node, seg.index})
+}
+
+// RemoveAll removes all segments from the set. All existing iterators are
+// invalidated.
+func (s *addrSet) RemoveAll() {
+ s.root = addrnode{}
+}
+
+// RemoveRange removes all segments in the given range. An iterator to the
+// newly formed gap is returned, and all existing iterators are invalidated.
+func (s *addrSet) RemoveRange(r addrRange) addrGapIterator {
+ seg, gap := s.Find(r.Start)
+ if seg.Ok() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ for seg = gap.NextSegment(); seg.Ok() && seg.Start() < r.End; seg = gap.NextSegment() {
+ seg = s.Isolate(seg, r)
+ gap = s.Remove(seg)
+ }
+ return gap
+}
+
+// Merge attempts to merge two neighboring segments. If successful, Merge
+// returns an iterator to the merged segment, and all existing iterators are
+// invalidated. Otherwise, Merge returns a terminal iterator.
+//
+// If first is not the predecessor of second, Merge panics.
+func (s *addrSet) Merge(first, second addrIterator) addrIterator {
+ if first.NextSegment() != second {
+ panic(fmt.Sprintf("attempt to merge non-neighboring segments %v, %v", first.Range(), second.Range()))
+ }
+ return s.MergeUnchecked(first, second)
+}
+
+// MergeUnchecked attempts to merge two neighboring segments. If successful,
+// MergeUnchecked returns an iterator to the merged segment, and all existing
+// iterators are invalidated. Otherwise, MergeUnchecked returns a terminal
+// iterator.
+//
+// Precondition: first is the predecessor of second: first.NextSegment() ==
+// second, first == second.PrevSegment().
+func (s *addrSet) MergeUnchecked(first, second addrIterator) addrIterator {
+ if first.End() == second.Start() {
+ if mval, ok := (addrSetFunctions{}).Merge(first.Range(), first.Value(), second.Range(), second.Value()); ok {
+
+ first.SetEndUnchecked(second.End())
+ first.SetValue(mval)
+
+ return s.Remove(second).PrevSegment()
+ }
+ }
+ return addrIterator{}
+}
+
+// MergeAll attempts to merge all adjacent segments in the set. All existing
+// iterators are invalidated.
+func (s *addrSet) MergeAll() {
+ seg := s.FirstSegment()
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeRange attempts to merge all adjacent segments that contain a key in the
+// specific range. All existing iterators are invalidated.
+func (s *addrSet) MergeRange(r addrRange) {
+ seg := s.LowerBoundSegment(r.Start)
+ if !seg.Ok() {
+ return
+ }
+ next := seg.NextSegment()
+ for next.Ok() && next.Range().Start < r.End {
+ if mseg := s.MergeUnchecked(seg, next); mseg.Ok() {
+ seg, next = mseg, mseg.NextSegment()
+ } else {
+ seg, next = next, next.NextSegment()
+ }
+ }
+}
+
+// MergeAdjacent attempts to merge the segment containing r.Start with its
+// predecessor, and the segment containing r.End-1 with its successor.
+func (s *addrSet) MergeAdjacent(r addrRange) {
+ first := s.FindSegment(r.Start)
+ if first.Ok() {
+ if prev := first.PrevSegment(); prev.Ok() {
+ s.Merge(prev, first)
+ }
+ }
+ last := s.FindSegment(r.End - 1)
+ if last.Ok() {
+ if next := last.NextSegment(); next.Ok() {
+ s.Merge(last, next)
+ }
+ }
+}
+
+// Split splits the given segment at the given key and returns iterators to the
+// two resulting segments. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+//
+// If the segment cannot be split at split (because split is at the start or
+// end of the segment's range, so splitting would produce a segment with zero
+// length, or because split falls outside the segment's range altogether),
+// Split panics.
+func (s *addrSet) Split(seg addrIterator, split uintptr) (addrIterator, addrIterator) {
+ if !seg.Range().CanSplitAt(split) {
+ panic(fmt.Sprintf("can't split %v at %v", seg.Range(), split))
+ }
+ return s.SplitUnchecked(seg, split)
+}
+
+// SplitUnchecked splits the given segment at the given key and returns
+// iterators to the two resulting segments. All existing iterators (including
+// seg, but not including the returned iterators) are invalidated.
+//
+// Preconditions: seg.Start() < key < seg.End().
+func (s *addrSet) SplitUnchecked(seg addrIterator, split uintptr) (addrIterator, addrIterator) {
+ val1, val2 := (addrSetFunctions{}).Split(seg.Range(), seg.Value(), split)
+ end2 := seg.End()
+ seg.SetEndUnchecked(split)
+ seg.SetValue(val1)
+ seg2 := s.InsertWithoutMergingUnchecked(seg.NextGap(), addrRange{split, end2}, val2)
+
+ return seg2.PrevSegment(), seg2
+}
+
+// SplitAt splits the segment straddling split, if one exists. SplitAt returns
+// true if a segment was split and false otherwise. If SplitAt splits a
+// segment, all existing iterators are invalidated.
+func (s *addrSet) SplitAt(split uintptr) bool {
+ if seg := s.FindSegment(split); seg.Ok() && seg.Range().CanSplitAt(split) {
+ s.SplitUnchecked(seg, split)
+ return true
+ }
+ return false
+}
+
+// Isolate ensures that the given segment's range does not escape r by
+// splitting at r.Start and r.End if necessary, and returns an updated iterator
+// to the bounded segment. All existing iterators (including seg, but not
+// including the returned iterators) are invalidated.
+func (s *addrSet) Isolate(seg addrIterator, r addrRange) addrIterator {
+ if seg.Range().CanSplitAt(r.Start) {
+ _, seg = s.SplitUnchecked(seg, r.Start)
+ }
+ if seg.Range().CanSplitAt(r.End) {
+ seg, _ = s.SplitUnchecked(seg, r.End)
+ }
+ return seg
+}
+
+// ApplyContiguous applies a function to a contiguous range of segments,
+// splitting if necessary. The function is applied until the first gap is
+// encountered, at which point the gap is returned. If the function is applied
+// across the entire range, a terminal gap is returned. All existing iterators
+// are invalidated.
+//
+// N.B. The Iterator must not be invalidated by the function.
+func (s *addrSet) ApplyContiguous(r addrRange, fn func(seg addrIterator)) addrGapIterator {
+ seg, gap := s.Find(r.Start)
+ if !seg.Ok() {
+ return gap
+ }
+ for {
+ seg = s.Isolate(seg, r)
+ fn(seg)
+ if seg.End() >= r.End {
+ return addrGapIterator{}
+ }
+ gap = seg.NextGap()
+ if !gap.IsEmpty() {
+ return gap
+ }
+ seg = gap.NextSegment()
+ if !seg.Ok() {
+
+ return addrGapIterator{}
+ }
+ }
+}
+
+// +stateify savable
+type addrnode struct {
+ // An internal binary tree node looks like:
+ //
+ // K
+ // / \
+ // Cl Cr
+ //
+ // where all keys in the subtree rooted by Cl (the left subtree) are less
+ // than K (the key of the parent node), and all keys in the subtree rooted
+ // by Cr (the right subtree) are greater than K.
+ //
+ // An internal B-tree node's indexes work out to look like:
+ //
+ // K0 K1 K2 ... Kn-1
+ // / \/ \/ \ ... / \
+ // C0 C1 C2 C3 ... Cn-1 Cn
+ //
+ // where n is nrSegments.
+ nrSegments int
+
+ // parent is a pointer to this node's parent. If this node is root, parent
+ // is nil.
+ parent *addrnode
+
+ // parentIndex is the index of this node in parent.children.
+ parentIndex int
+
+ // Flag for internal nodes that is technically redundant with "children[0]
+ // != nil", but is stored in the first cache line. "hasChildren" rather
+ // than "isLeaf" because false must be the correct value for an empty root.
+ hasChildren bool
+
+ // The longest gap within this node. If the node is a leaf, it's simply the
+ // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys
+ // including the 0th and nrSegments-th gap possibly shared with its upper-level
+ // nodes; if it's a non-leaf node, it's the max of all children's maxGap.
+ maxGap addrdynamicGap
+
+ // Nodes store keys and values in separate arrays to maximize locality in
+ // the common case (scanning keys for lookup).
+ keys [addrmaxDegree - 1]addrRange
+ values [addrmaxDegree - 1]*objectEncodeState
+ children [addrmaxDegree]*addrnode
+}
+
+// firstSegment returns the first segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *addrnode) firstSegment() addrIterator {
+ for n.hasChildren {
+ n = n.children[0]
+ }
+ return addrIterator{n, 0}
+}
+
+// lastSegment returns the last segment in the subtree rooted by n.
+//
+// Preconditions: n.nrSegments != 0.
+func (n *addrnode) lastSegment() addrIterator {
+ for n.hasChildren {
+ n = n.children[n.nrSegments]
+ }
+ return addrIterator{n, n.nrSegments - 1}
+}
+
+func (n *addrnode) prevSibling() *addrnode {
+ if n.parent == nil || n.parentIndex == 0 {
+ return nil
+ }
+ return n.parent.children[n.parentIndex-1]
+}
+
+func (n *addrnode) nextSibling() *addrnode {
+ if n.parent == nil || n.parentIndex == n.parent.nrSegments {
+ return nil
+ }
+ return n.parent.children[n.parentIndex+1]
+}
+
+// rebalanceBeforeInsert splits n and its ancestors if they are full, as
+// required for insertion, and returns an updated iterator to the position
+// represented by gap.
+func (n *addrnode) rebalanceBeforeInsert(gap addrGapIterator) addrGapIterator {
+ if n.nrSegments < addrmaxDegree-1 {
+ return gap
+ }
+ if n.parent != nil {
+ gap = n.parent.rebalanceBeforeInsert(gap)
+ }
+ if n.parent == nil {
+
+ left := &addrnode{
+ nrSegments: addrminDegree - 1,
+ parent: n,
+ parentIndex: 0,
+ hasChildren: n.hasChildren,
+ }
+ right := &addrnode{
+ nrSegments: addrminDegree - 1,
+ parent: n,
+ parentIndex: 1,
+ hasChildren: n.hasChildren,
+ }
+ copy(left.keys[:addrminDegree-1], n.keys[:addrminDegree-1])
+ copy(left.values[:addrminDegree-1], n.values[:addrminDegree-1])
+ copy(right.keys[:addrminDegree-1], n.keys[addrminDegree:])
+ copy(right.values[:addrminDegree-1], n.values[addrminDegree:])
+ n.keys[0], n.values[0] = n.keys[addrminDegree-1], n.values[addrminDegree-1]
+ addrzeroValueSlice(n.values[1:])
+ if n.hasChildren {
+ copy(left.children[:addrminDegree], n.children[:addrminDegree])
+ copy(right.children[:addrminDegree], n.children[addrminDegree:])
+ addrzeroNodeSlice(n.children[2:])
+ for i := 0; i < addrminDegree; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ right.children[i].parent = right
+ right.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = 1
+ n.hasChildren = true
+ n.children[0] = left
+ n.children[1] = right
+
+ if addrtrackGaps != 0 {
+ left.updateMaxGapLocal()
+ right.updateMaxGapLocal()
+ }
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < addrminDegree {
+ return addrGapIterator{left, gap.index}
+ }
+ return addrGapIterator{right, gap.index - addrminDegree}
+ }
+
+ copy(n.parent.keys[n.parentIndex+1:], n.parent.keys[n.parentIndex:n.parent.nrSegments])
+ copy(n.parent.values[n.parentIndex+1:], n.parent.values[n.parentIndex:n.parent.nrSegments])
+ n.parent.keys[n.parentIndex], n.parent.values[n.parentIndex] = n.keys[addrminDegree-1], n.values[addrminDegree-1]
+ copy(n.parent.children[n.parentIndex+2:], n.parent.children[n.parentIndex+1:n.parent.nrSegments+1])
+ for i := n.parentIndex + 2; i < n.parent.nrSegments+2; i++ {
+ n.parent.children[i].parentIndex = i
+ }
+ sibling := &addrnode{
+ nrSegments: addrminDegree - 1,
+ parent: n.parent,
+ parentIndex: n.parentIndex + 1,
+ hasChildren: n.hasChildren,
+ }
+ n.parent.children[n.parentIndex+1] = sibling
+ n.parent.nrSegments++
+ copy(sibling.keys[:addrminDegree-1], n.keys[addrminDegree:])
+ copy(sibling.values[:addrminDegree-1], n.values[addrminDegree:])
+ addrzeroValueSlice(n.values[addrminDegree-1:])
+ if n.hasChildren {
+ copy(sibling.children[:addrminDegree], n.children[addrminDegree:])
+ addrzeroNodeSlice(n.children[addrminDegree:])
+ for i := 0; i < addrminDegree; i++ {
+ sibling.children[i].parent = sibling
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments = addrminDegree - 1
+
+ if addrtrackGaps != 0 {
+ n.updateMaxGapLocal()
+ sibling.updateMaxGapLocal()
+ }
+
+ if gap.node != n {
+ return gap
+ }
+ if gap.index < addrminDegree {
+ return gap
+ }
+ return addrGapIterator{sibling, gap.index - addrminDegree}
+}
+
+// rebalanceAfterRemove "unsplits" n and its ancestors if they are deficient
+// (contain fewer segments than required by B-tree invariants), as required for
+// removal, and returns an updated iterator to the position represented by gap.
+//
+// Precondition: n is the only node in the tree that may currently violate a
+// B-tree invariant.
+func (n *addrnode) rebalanceAfterRemove(gap addrGapIterator) addrGapIterator {
+ for {
+ if n.nrSegments >= addrminDegree-1 {
+ return gap
+ }
+ if n.parent == nil {
+
+ return gap
+ }
+
+ if sibling := n.prevSibling(); sibling != nil && sibling.nrSegments >= addrminDegree {
+ copy(n.keys[1:], n.keys[:n.nrSegments])
+ copy(n.values[1:], n.values[:n.nrSegments])
+ n.keys[0] = n.parent.keys[n.parentIndex-1]
+ n.values[0] = n.parent.values[n.parentIndex-1]
+ n.parent.keys[n.parentIndex-1] = sibling.keys[sibling.nrSegments-1]
+ n.parent.values[n.parentIndex-1] = sibling.values[sibling.nrSegments-1]
+ addrSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ copy(n.children[1:], n.children[:n.nrSegments+1])
+ n.children[0] = sibling.children[sibling.nrSegments]
+ sibling.children[sibling.nrSegments] = nil
+ n.children[0].parent = n
+ n.children[0].parentIndex = 0
+ for i := 1; i < n.nrSegments+2; i++ {
+ n.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+
+ if addrtrackGaps != 0 {
+ n.updateMaxGapLocal()
+ sibling.updateMaxGapLocal()
+ }
+ if gap.node == sibling && gap.index == sibling.nrSegments {
+ return addrGapIterator{n, 0}
+ }
+ if gap.node == n {
+ return addrGapIterator{n, gap.index + 1}
+ }
+ return gap
+ }
+ if sibling := n.nextSibling(); sibling != nil && sibling.nrSegments >= addrminDegree {
+ n.keys[n.nrSegments] = n.parent.keys[n.parentIndex]
+ n.values[n.nrSegments] = n.parent.values[n.parentIndex]
+ n.parent.keys[n.parentIndex] = sibling.keys[0]
+ n.parent.values[n.parentIndex] = sibling.values[0]
+ copy(sibling.keys[:sibling.nrSegments-1], sibling.keys[1:])
+ copy(sibling.values[:sibling.nrSegments-1], sibling.values[1:])
+ addrSetFunctions{}.ClearValue(&sibling.values[sibling.nrSegments-1])
+ if n.hasChildren {
+ n.children[n.nrSegments+1] = sibling.children[0]
+ copy(sibling.children[:sibling.nrSegments], sibling.children[1:])
+ sibling.children[sibling.nrSegments] = nil
+ n.children[n.nrSegments+1].parent = n
+ n.children[n.nrSegments+1].parentIndex = n.nrSegments + 1
+ for i := 0; i < sibling.nrSegments; i++ {
+ sibling.children[i].parentIndex = i
+ }
+ }
+ n.nrSegments++
+ sibling.nrSegments--
+
+ if addrtrackGaps != 0 {
+ n.updateMaxGapLocal()
+ sibling.updateMaxGapLocal()
+ }
+ if gap.node == sibling {
+ if gap.index == 0 {
+ return addrGapIterator{n, n.nrSegments}
+ }
+ return addrGapIterator{sibling, gap.index - 1}
+ }
+ return gap
+ }
+
+ p := n.parent
+ if p.nrSegments == 1 {
+
+ left, right := p.children[0], p.children[1]
+ p.nrSegments = left.nrSegments + right.nrSegments + 1
+ p.hasChildren = left.hasChildren
+ p.keys[left.nrSegments] = p.keys[0]
+ p.values[left.nrSegments] = p.values[0]
+ copy(p.keys[:left.nrSegments], left.keys[:left.nrSegments])
+ copy(p.values[:left.nrSegments], left.values[:left.nrSegments])
+ copy(p.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(p.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(p.children[:left.nrSegments+1], left.children[:left.nrSegments+1])
+ copy(p.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := 0; i < p.nrSegments+1; i++ {
+ p.children[i].parent = p
+ p.children[i].parentIndex = i
+ }
+ } else {
+ p.children[0] = nil
+ p.children[1] = nil
+ }
+
+ if gap.node == left {
+ return addrGapIterator{p, gap.index}
+ }
+ if gap.node == right {
+ return addrGapIterator{p, gap.index + left.nrSegments + 1}
+ }
+ return gap
+ }
+ // Merge n and either sibling, along with the segment separating the
+ // two, into whichever of the two nodes comes first. This is the
+ // reverse of the non-root splitting case in
+ // node.rebalanceBeforeInsert.
+ var left, right *addrnode
+ if n.parentIndex > 0 {
+ left = n.prevSibling()
+ right = n
+ } else {
+ left = n
+ right = n.nextSibling()
+ }
+
+ if gap.node == right {
+ gap = addrGapIterator{left, gap.index + left.nrSegments + 1}
+ }
+ left.keys[left.nrSegments] = p.keys[left.parentIndex]
+ left.values[left.nrSegments] = p.values[left.parentIndex]
+ copy(left.keys[left.nrSegments+1:], right.keys[:right.nrSegments])
+ copy(left.values[left.nrSegments+1:], right.values[:right.nrSegments])
+ if left.hasChildren {
+ copy(left.children[left.nrSegments+1:], right.children[:right.nrSegments+1])
+ for i := left.nrSegments + 1; i < left.nrSegments+right.nrSegments+2; i++ {
+ left.children[i].parent = left
+ left.children[i].parentIndex = i
+ }
+ }
+ left.nrSegments += right.nrSegments + 1
+ copy(p.keys[left.parentIndex:], p.keys[left.parentIndex+1:p.nrSegments])
+ copy(p.values[left.parentIndex:], p.values[left.parentIndex+1:p.nrSegments])
+ addrSetFunctions{}.ClearValue(&p.values[p.nrSegments-1])
+ copy(p.children[left.parentIndex+1:], p.children[left.parentIndex+2:p.nrSegments+1])
+ for i := 0; i < p.nrSegments; i++ {
+ p.children[i].parentIndex = i
+ }
+ p.children[p.nrSegments] = nil
+ p.nrSegments--
+
+ if addrtrackGaps != 0 {
+ left.updateMaxGapLocal()
+ }
+
+ n = p
+ }
+}
+
+// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no
+// necessary update.
+//
+// Preconditions: n must be a leaf node, trackGaps must be 1.
+func (n *addrnode) updateMaxGapLeaf() {
+ if n.hasChildren {
+ panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n))
+ }
+ max := n.calculateMaxGapLeaf()
+ if max == n.maxGap.Get() {
+
+ return
+ }
+ oldMax := n.maxGap.Get()
+ n.maxGap.Set(max)
+ if max > oldMax {
+
+ for p := n.parent; p != nil; p = p.parent {
+ if p.maxGap.Get() >= max {
+
+ break
+ }
+
+ p.maxGap.Set(max)
+ }
+ return
+ }
+
+ for p := n.parent; p != nil; p = p.parent {
+ if p.maxGap.Get() > oldMax {
+
+ break
+ }
+
+ parentNewMax := p.calculateMaxGapInternal()
+ if p.maxGap.Get() == parentNewMax {
+
+ break
+ }
+
+ p.maxGap.Set(parentNewMax)
+ }
+}
+
+// updateMaxGapLocal updates maxGap of the calling node solely with no
+// propagation to ancestor nodes.
+//
+// Precondition: trackGaps must be 1.
+func (n *addrnode) updateMaxGapLocal() {
+ if !n.hasChildren {
+
+ n.maxGap.Set(n.calculateMaxGapLeaf())
+ } else {
+
+ n.maxGap.Set(n.calculateMaxGapInternal())
+ }
+}
+
+// calculateMaxGapLeaf iterates the gaps within a leaf node and calculate the
+// max.
+//
+// Preconditions: n must be a leaf node.
+func (n *addrnode) calculateMaxGapLeaf() uintptr {
+ max := addrGapIterator{n, 0}.Range().Length()
+ for i := 1; i <= n.nrSegments; i++ {
+ if current := (addrGapIterator{n, i}).Range().Length(); current > max {
+ max = current
+ }
+ }
+ return max
+}
+
+// calculateMaxGapInternal iterates children's maxGap within an internal node n
+// and calculate the max.
+//
+// Preconditions: n must be a non-leaf node.
+func (n *addrnode) calculateMaxGapInternal() uintptr {
+ max := n.children[0].maxGap.Get()
+ for i := 1; i <= n.nrSegments; i++ {
+ if current := n.children[i].maxGap.Get(); current > max {
+ max = current
+ }
+ }
+ return max
+}
+
+// searchFirstLargeEnoughGap returns the first gap having at least minSize length
+// in the subtree rooted by n. If not found, return a terminal gap iterator.
+func (n *addrnode) searchFirstLargeEnoughGap(minSize uintptr) addrGapIterator {
+ if n.maxGap.Get() < minSize {
+ return addrGapIterator{}
+ }
+ if n.hasChildren {
+ for i := 0; i <= n.nrSegments; i++ {
+ if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ }
+ } else {
+ for i := 0; i <= n.nrSegments; i++ {
+ currentGap := addrGapIterator{n, i}
+ if currentGap.Range().Length() >= minSize {
+ return currentGap
+ }
+ }
+ }
+ panic(fmt.Sprintf("invalid maxGap in %v", n))
+}
+
+// searchLastLargeEnoughGap returns the last gap having at least minSize length
+// in the subtree rooted by n. If not found, return a terminal gap iterator.
+func (n *addrnode) searchLastLargeEnoughGap(minSize uintptr) addrGapIterator {
+ if n.maxGap.Get() < minSize {
+ return addrGapIterator{}
+ }
+ if n.hasChildren {
+ for i := n.nrSegments; i >= 0; i-- {
+ if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ }
+ } else {
+ for i := n.nrSegments; i >= 0; i-- {
+ currentGap := addrGapIterator{n, i}
+ if currentGap.Range().Length() >= minSize {
+ return currentGap
+ }
+ }
+ }
+ panic(fmt.Sprintf("invalid maxGap in %v", n))
+}
+
+// A Iterator is conceptually one of:
+//
+// - A pointer to a segment in a set; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Iterators are copyable values and are meaningfully equality-comparable. The
+// zero value of Iterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type addrIterator struct {
+ // node is the node containing the iterated segment. If the iterator is
+ // terminal, node is nil.
+ node *addrnode
+
+ // index is the index of the segment in node.keys/values.
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (seg addrIterator) Ok() bool {
+ return seg.node != nil
+}
+
+// Range returns the iterated segment's range key.
+func (seg addrIterator) Range() addrRange {
+ return seg.node.keys[seg.index]
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (seg addrIterator) Start() uintptr {
+ return seg.node.keys[seg.index].Start
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (seg addrIterator) End() uintptr {
+ return seg.node.keys[seg.index].End
+}
+
+// SetRangeUnchecked mutates the iterated segment's range key. This operation
+// does not invalidate any iterators.
+//
+// Preconditions:
+// * r.Length() > 0.
+// * The new range must not overlap an existing one:
+// * If seg.NextSegment().Ok(), then r.end <= seg.NextSegment().Start().
+// * If seg.PrevSegment().Ok(), then r.start >= seg.PrevSegment().End().
+func (seg addrIterator) SetRangeUnchecked(r addrRange) {
+ seg.node.keys[seg.index] = r
+}
+
+// SetRange mutates the iterated segment's range key. If the new range would
+// cause the iterated segment to overlap another segment, or if the new range
+// is invalid, SetRange panics. This operation does not invalidate any
+// iterators.
+func (seg addrIterator) SetRange(r addrRange) {
+ if r.Length() <= 0 {
+ panic(fmt.Sprintf("invalid segment range %v", r))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && r.Start < prev.End() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, prev.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && r.End > next.Start() {
+ panic(fmt.Sprintf("new segment range %v overlaps segment range %v", r, next.Range()))
+ }
+ seg.SetRangeUnchecked(r)
+}
+
+// SetStartUnchecked mutates the iterated segment's start. This operation does
+// not invalidate any iterators.
+//
+// Preconditions: The new start must be valid:
+// * start < seg.End()
+// * If seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End().
+func (seg addrIterator) SetStartUnchecked(start uintptr) {
+ seg.node.keys[seg.index].Start = start
+}
+
+// SetStart mutates the iterated segment's start. If the new start value would
+// cause the iterated segment to overlap another segment, or would result in an
+// invalid range, SetStart panics. This operation does not invalidate any
+// iterators.
+func (seg addrIterator) SetStart(start uintptr) {
+ if start >= seg.End() {
+ panic(fmt.Sprintf("new start %v would invalidate segment range %v", start, seg.Range()))
+ }
+ if prev := seg.PrevSegment(); prev.Ok() && start < prev.End() {
+ panic(fmt.Sprintf("new start %v would cause segment range %v to overlap segment range %v", start, seg.Range(), prev.Range()))
+ }
+ seg.SetStartUnchecked(start)
+}
+
+// SetEndUnchecked mutates the iterated segment's end. This operation does not
+// invalidate any iterators.
+//
+// Preconditions: The new end must be valid:
+// * end > seg.Start().
+// * If seg.NextSegment().Ok(), then end <= seg.NextSegment().Start().
+func (seg addrIterator) SetEndUnchecked(end uintptr) {
+ seg.node.keys[seg.index].End = end
+}
+
+// SetEnd mutates the iterated segment's end. If the new end value would cause
+// the iterated segment to overlap another segment, or would result in an
+// invalid range, SetEnd panics. This operation does not invalidate any
+// iterators.
+func (seg addrIterator) SetEnd(end uintptr) {
+ if end <= seg.Start() {
+ panic(fmt.Sprintf("new end %v would invalidate segment range %v", end, seg.Range()))
+ }
+ if next := seg.NextSegment(); next.Ok() && end > next.Start() {
+ panic(fmt.Sprintf("new end %v would cause segment range %v to overlap segment range %v", end, seg.Range(), next.Range()))
+ }
+ seg.SetEndUnchecked(end)
+}
+
+// Value returns a copy of the iterated segment's value.
+func (seg addrIterator) Value() *objectEncodeState {
+ return seg.node.values[seg.index]
+}
+
+// ValuePtr returns a pointer to the iterated segment's value. The pointer is
+// invalidated if the iterator is invalidated. This operation does not
+// invalidate any iterators.
+func (seg addrIterator) ValuePtr() **objectEncodeState {
+ return &seg.node.values[seg.index]
+}
+
+// SetValue mutates the iterated segment's value. This operation does not
+// invalidate any iterators.
+func (seg addrIterator) SetValue(val *objectEncodeState) {
+ seg.node.values[seg.index] = val
+}
+
+// PrevSegment returns the iterated segment's predecessor. If there is no
+// preceding segment, PrevSegment returns a terminal iterator.
+func (seg addrIterator) PrevSegment() addrIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index].lastSegment()
+ }
+ if seg.index > 0 {
+ return addrIterator{seg.node, seg.index - 1}
+ }
+ if seg.node.parent == nil {
+ return addrIterator{}
+ }
+ return addrsegmentBeforePosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// NextSegment returns the iterated segment's successor. If there is no
+// succeeding segment, NextSegment returns a terminal iterator.
+func (seg addrIterator) NextSegment() addrIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment()
+ }
+ if seg.index < seg.node.nrSegments-1 {
+ return addrIterator{seg.node, seg.index + 1}
+ }
+ if seg.node.parent == nil {
+ return addrIterator{}
+ }
+ return addrsegmentAfterPosition(seg.node.parent, seg.node.parentIndex)
+}
+
+// PrevGap returns the gap immediately before the iterated segment.
+func (seg addrIterator) PrevGap() addrGapIterator {
+ if seg.node.hasChildren {
+
+ return seg.node.children[seg.index].lastSegment().NextGap()
+ }
+ return addrGapIterator{seg.node, seg.index}
+}
+
+// NextGap returns the gap immediately after the iterated segment.
+func (seg addrIterator) NextGap() addrGapIterator {
+ if seg.node.hasChildren {
+ return seg.node.children[seg.index+1].firstSegment().PrevGap()
+ }
+ return addrGapIterator{seg.node, seg.index + 1}
+}
+
+// PrevNonEmpty returns the iterated segment's predecessor if it is adjacent,
+// or the gap before the iterated segment otherwise. If seg.Start() ==
+// Functions.MinKey(), PrevNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by PrevNonEmpty will be
+// non-terminal.
+func (seg addrIterator) PrevNonEmpty() (addrIterator, addrGapIterator) {
+ gap := seg.PrevGap()
+ if gap.Range().Length() != 0 {
+ return addrIterator{}, gap
+ }
+ return gap.PrevSegment(), addrGapIterator{}
+}
+
+// NextNonEmpty returns the iterated segment's successor if it is adjacent, or
+// the gap after the iterated segment otherwise. If seg.End() ==
+// Functions.MaxKey(), NextNonEmpty will return two terminal iterators.
+// Otherwise, exactly one of the iterators returned by NextNonEmpty will be
+// non-terminal.
+func (seg addrIterator) NextNonEmpty() (addrIterator, addrGapIterator) {
+ gap := seg.NextGap()
+ if gap.Range().Length() != 0 {
+ return addrIterator{}, gap
+ }
+ return gap.NextSegment(), addrGapIterator{}
+}
+
+// A GapIterator is conceptually one of:
+//
+// - A pointer to a position between two segments, before the first segment, or
+// after the last segment in a set, called a *gap*; or
+//
+// - A terminal iterator, which is a sentinel indicating that the end of
+// iteration has been reached.
+//
+// Note that the gap between two adjacent segments exists (iterators to it are
+// non-terminal), but has a length of zero. GapIterator.IsEmpty returns true
+// for such gaps. An empty set contains a single gap, spanning the entire range
+// of the set's keys.
+//
+// GapIterators are copyable values and are meaningfully equality-comparable.
+// The zero value of GapIterator is a terminal iterator.
+//
+// Unless otherwise specified, any mutation of a set invalidates all existing
+// iterators into the set.
+type addrGapIterator struct {
+ // The representation of a GapIterator is identical to that of an Iterator,
+ // except that index corresponds to positions between segments in the same
+ // way as for node.children (see comment for node.nrSegments).
+ node *addrnode
+ index int
+}
+
+// Ok returns true if the iterator is not terminal. All other methods are only
+// valid for non-terminal iterators.
+func (gap addrGapIterator) Ok() bool {
+ return gap.node != nil
+}
+
+// Range returns the range spanned by the iterated gap.
+func (gap addrGapIterator) Range() addrRange {
+ return addrRange{gap.Start(), gap.End()}
+}
+
+// Start is equivalent to Range().Start, but should be preferred if only the
+// start of the range is needed.
+func (gap addrGapIterator) Start() uintptr {
+ if ps := gap.PrevSegment(); ps.Ok() {
+ return ps.End()
+ }
+ return addrSetFunctions{}.MinKey()
+}
+
+// End is equivalent to Range().End, but should be preferred if only the end of
+// the range is needed.
+func (gap addrGapIterator) End() uintptr {
+ if ns := gap.NextSegment(); ns.Ok() {
+ return ns.Start()
+ }
+ return addrSetFunctions{}.MaxKey()
+}
+
+// IsEmpty returns true if the iterated gap is empty (that is, the "gap" is
+// between two adjacent segments.)
+func (gap addrGapIterator) IsEmpty() bool {
+ return gap.Range().Length() == 0
+}
+
+// PrevSegment returns the segment immediately before the iterated gap. If no
+// such segment exists, PrevSegment returns a terminal iterator.
+func (gap addrGapIterator) PrevSegment() addrIterator {
+ return addrsegmentBeforePosition(gap.node, gap.index)
+}
+
+// NextSegment returns the segment immediately after the iterated gap. If no
+// such segment exists, NextSegment returns a terminal iterator.
+func (gap addrGapIterator) NextSegment() addrIterator {
+ return addrsegmentAfterPosition(gap.node, gap.index)
+}
+
+// PrevGap returns the iterated gap's predecessor. If no such gap exists,
+// PrevGap returns a terminal iterator.
+func (gap addrGapIterator) PrevGap() addrGapIterator {
+ seg := gap.PrevSegment()
+ if !seg.Ok() {
+ return addrGapIterator{}
+ }
+ return seg.PrevGap()
+}
+
+// NextGap returns the iterated gap's successor. If no such gap exists, NextGap
+// returns a terminal iterator.
+func (gap addrGapIterator) NextGap() addrGapIterator {
+ seg := gap.NextSegment()
+ if !seg.Ok() {
+ return addrGapIterator{}
+ }
+ return seg.NextGap()
+}
+
+// NextLargeEnoughGap returns the iterated gap's first next gap with larger
+// length than minSize. If not found, return a terminal gap iterator (does NOT
+// include this gap itself).
+//
+// Precondition: trackGaps must be 1.
+func (gap addrGapIterator) NextLargeEnoughGap(minSize uintptr) addrGapIterator {
+ if addrtrackGaps != 1 {
+ panic("set is not tracking gaps")
+ }
+ if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments {
+
+ gap.node = gap.NextSegment().node
+ gap.index = 0
+ return gap.nextLargeEnoughGapHelper(minSize)
+ }
+ return gap.nextLargeEnoughGapHelper(minSize)
+}
+
+// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap
+// to do the real recursions.
+//
+// Preconditions: gap is NOT the trailing gap of a non-leaf node.
+func (gap addrGapIterator) nextLargeEnoughGapHelper(minSize uintptr) addrGapIterator {
+
+ for gap.node != nil &&
+ (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) {
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+
+ if gap.node == nil {
+ return addrGapIterator{}
+ }
+
+ gap.index++
+ for gap.index <= gap.node.nrSegments {
+ if gap.node.hasChildren {
+ if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ } else {
+ if gap.Range().Length() >= minSize {
+ return gap
+ }
+ }
+ gap.index++
+ }
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ if gap.node != nil && gap.index == gap.node.nrSegments {
+
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ return gap.nextLargeEnoughGapHelper(minSize)
+}
+
+// PrevLargeEnoughGap returns the iterated gap's first prev gap with larger or
+// equal length than minSize. If not found, return a terminal gap iterator
+// (does NOT include this gap itself).
+//
+// Precondition: trackGaps must be 1.
+func (gap addrGapIterator) PrevLargeEnoughGap(minSize uintptr) addrGapIterator {
+ if addrtrackGaps != 1 {
+ panic("set is not tracking gaps")
+ }
+ if gap.node != nil && gap.node.hasChildren && gap.index == 0 {
+
+ gap.node = gap.PrevSegment().node
+ gap.index = gap.node.nrSegments
+ return gap.prevLargeEnoughGapHelper(minSize)
+ }
+ return gap.prevLargeEnoughGapHelper(minSize)
+}
+
+// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap
+// to do the real recursions.
+//
+// Preconditions: gap is NOT the first gap of a non-leaf node.
+func (gap addrGapIterator) prevLargeEnoughGapHelper(minSize uintptr) addrGapIterator {
+
+ for gap.node != nil &&
+ (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) {
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+
+ if gap.node == nil {
+ return addrGapIterator{}
+ }
+
+ gap.index--
+ for gap.index >= 0 {
+ if gap.node.hasChildren {
+ if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ } else {
+ if gap.Range().Length() >= minSize {
+ return gap
+ }
+ }
+ gap.index--
+ }
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ if gap.node != nil && gap.index == 0 {
+
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ return gap.prevLargeEnoughGapHelper(minSize)
+}
+
+// segmentBeforePosition returns the predecessor segment of the position given
+// by n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentBeforePosition returns a terminal iterator.
+func addrsegmentBeforePosition(n *addrnode, i int) addrIterator {
+ for i == 0 {
+ if n.parent == nil {
+ return addrIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return addrIterator{n, i - 1}
+}
+
+// segmentAfterPosition returns the successor segment of the position given by
+// n.children[i], which may or may not contain a child. If no such segment
+// exists, segmentAfterPosition returns a terminal iterator.
+func addrsegmentAfterPosition(n *addrnode, i int) addrIterator {
+ for i == n.nrSegments {
+ if n.parent == nil {
+ return addrIterator{}
+ }
+ n, i = n.parent, n.parentIndex
+ }
+ return addrIterator{n, i}
+}
+
+func addrzeroValueSlice(slice []*objectEncodeState) {
+
+ for i := range slice {
+ addrSetFunctions{}.ClearValue(&slice[i])
+ }
+}
+
+func addrzeroNodeSlice(slice []*addrnode) {
+ for i := range slice {
+ slice[i] = nil
+ }
+}
+
+// String stringifies a Set for debugging.
+func (s *addrSet) String() string {
+ return s.root.String()
+}
+
+// String stringifies a node (and all of its children) for debugging.
+func (n *addrnode) String() string {
+ var buf bytes.Buffer
+ n.writeDebugString(&buf, "")
+ return buf.String()
+}
+
+func (n *addrnode) writeDebugString(buf *bytes.Buffer, prefix string) {
+ if n.hasChildren != (n.nrSegments > 0 && n.children[0] != nil) {
+ buf.WriteString(prefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent value of hasChildren: got %v, want %v\n", n.hasChildren, !n.hasChildren))
+ }
+ for i := 0; i < n.nrSegments; i++ {
+ if child := n.children[i]; child != nil {
+ cprefix := fmt.Sprintf("%s- % 3d ", prefix, i)
+ if child.parent != n || child.parentIndex != i {
+ buf.WriteString(cprefix)
+ buf.WriteString(fmt.Sprintf("WARNING: inconsistent linkage to parent: got (%p, %d), want (%p, %d)\n", child.parent, child.parentIndex, n, i))
+ }
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i))
+ }
+ buf.WriteString(prefix)
+ if n.hasChildren {
+ if addrtrackGaps != 0 {
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get()))
+ } else {
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ } else {
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ }
+ if child := n.children[n.nrSegments]; child != nil {
+ child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments))
+ }
+}
+
+// SegmentDataSlices represents segments from a set as slices of start, end, and
+// values. SegmentDataSlices is primarily used as an intermediate representation
+// for save/restore and the layout here is optimized for that.
+//
+// +stateify savable
+type addrSegmentDataSlices struct {
+ Start []uintptr
+ End []uintptr
+ Values []*objectEncodeState
+}
+
+// ExportSortedSlice returns a copy of all segments in the given set, in ascending
+// key order.
+func (s *addrSet) ExportSortedSlices() *addrSegmentDataSlices {
+ var sds addrSegmentDataSlices
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ sds.Start = append(sds.Start, seg.Start())
+ sds.End = append(sds.End, seg.End())
+ sds.Values = append(sds.Values, seg.Value())
+ }
+ sds.Start = sds.Start[:len(sds.Start):len(sds.Start)]
+ sds.End = sds.End[:len(sds.End):len(sds.End)]
+ sds.Values = sds.Values[:len(sds.Values):len(sds.Values)]
+ return &sds
+}
+
+// ImportSortedSlice initializes the given set from the given slice.
+//
+// Preconditions:
+// * s must be empty.
+// * sds must represent a valid set (the segments in sds must have valid
+// lengths that do not overlap).
+// * The segments in sds must be sorted in ascending key order.
+func (s *addrSet) ImportSortedSlices(sds *addrSegmentDataSlices) error {
+ if !s.IsEmpty() {
+ return fmt.Errorf("cannot import into non-empty set %v", s)
+ }
+ gap := s.FirstGap()
+ for i := range sds.Start {
+ r := addrRange{sds.Start[i], sds.End[i]}
+ if !gap.Range().IsSupersetOf(r) {
+ return fmt.Errorf("segment overlaps a preceding segment or is incorrectly sorted: [%d, %d) => %v", sds.Start[i], sds.End[i], sds.Values[i])
+ }
+ gap = s.InsertWithoutMerging(gap, r, sds.Values[i]).NextGap()
+ }
+ return nil
+}
+
+// segmentTestCheck returns an error if s is incorrectly sorted, does not
+// contain exactly expectedSegments segments, or contains a segment which
+// fails the passed check.
+//
+// This should be used only for testing, and has been added to this package for
+// templating convenience.
+func (s *addrSet) segmentTestCheck(expectedSegments int, segFunc func(int, addrRange, *objectEncodeState) error) error {
+ havePrev := false
+ prev := uintptr(0)
+ nrSegments := 0
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ next := seg.Start()
+ if havePrev && prev >= next {
+ return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments)
+ }
+ if segFunc != nil {
+ if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil {
+ return err
+ }
+ }
+ prev = next
+ havePrev = true
+ nrSegments++
+ }
+ if nrSegments != expectedSegments {
+ return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments)
+ }
+ return nil
+}
+
+// countSegments counts the number of segments in the set.
+//
+// Similar to Check, this should only be used for testing.
+func (s *addrSet) countSegments() (segments int) {
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ segments++
+ }
+ return segments
+}
+func (s *addrSet) saveRoot() *addrSegmentDataSlices {
+ return s.ExportSortedSlices()
+}
+
+func (s *addrSet) loadRoot(sds *addrSegmentDataSlices) {
+ if err := s.ImportSortedSlices(sds); err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/state/complete_list.go b/pkg/state/complete_list.go
new file mode 100644
index 000000000..4d340a1af
--- /dev/null
+++ b/pkg/state/complete_list.go
@@ -0,0 +1,221 @@
+package state
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type completeElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (completeElementMapper) linkerFor(elem *objectDecodeState) *objectDecodeState { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type completeList struct {
+ head *objectDecodeState
+ tail *objectDecodeState
+}
+
+// Reset resets list l to the empty state.
+func (l *completeList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *completeList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *completeList) Front() *objectDecodeState {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *completeList) Back() *objectDecodeState {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *completeList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (completeElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *completeList) PushFront(e *objectDecodeState) {
+ linker := completeElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ completeElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *completeList) PushBack(e *objectDecodeState) {
+ linker := completeElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ completeElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *completeList) PushBackList(m *completeList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ completeElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ completeElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *completeList) InsertAfter(b, e *objectDecodeState) {
+ bLinker := completeElementMapper{}.linkerFor(b)
+ eLinker := completeElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ completeElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *completeList) InsertBefore(a, e *objectDecodeState) {
+ aLinker := completeElementMapper{}.linkerFor(a)
+ eLinker := completeElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ completeElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *completeList) Remove(e *objectDecodeState) {
+ linker := completeElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ completeElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ completeElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type completeEntry struct {
+ next *objectDecodeState
+ prev *objectDecodeState
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *completeEntry) Next() *objectDecodeState {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *completeEntry) Prev() *objectDecodeState {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *completeEntry) SetNext(elem *objectDecodeState) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *completeEntry) SetPrev(elem *objectDecodeState) {
+ e.prev = elem
+}
diff --git a/pkg/state/deferred_list.go b/pkg/state/deferred_list.go
new file mode 100644
index 000000000..2753ce4a2
--- /dev/null
+++ b/pkg/state/deferred_list.go
@@ -0,0 +1,206 @@
+package state
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type deferredList struct {
+ head *objectEncodeState
+ tail *objectEncodeState
+}
+
+// Reset resets list l to the empty state.
+func (l *deferredList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *deferredList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *deferredList) Front() *objectEncodeState {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *deferredList) Back() *objectEncodeState {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *deferredList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (deferredMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *deferredList) PushFront(e *objectEncodeState) {
+ linker := deferredMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ deferredMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *deferredList) PushBack(e *objectEncodeState) {
+ linker := deferredMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ deferredMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *deferredList) PushBackList(m *deferredList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ deferredMapper{}.linkerFor(l.tail).SetNext(m.head)
+ deferredMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *deferredList) InsertAfter(b, e *objectEncodeState) {
+ bLinker := deferredMapper{}.linkerFor(b)
+ eLinker := deferredMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ deferredMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *deferredList) InsertBefore(a, e *objectEncodeState) {
+ aLinker := deferredMapper{}.linkerFor(a)
+ eLinker := deferredMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ deferredMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *deferredList) Remove(e *objectEncodeState) {
+ linker := deferredMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ deferredMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ deferredMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type deferredEntry struct {
+ next *objectEncodeState
+ prev *objectEncodeState
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *deferredEntry) Next() *objectEncodeState {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *deferredEntry) Prev() *objectEncodeState {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *deferredEntry) SetNext(elem *objectEncodeState) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *deferredEntry) SetPrev(elem *objectEncodeState) {
+ e.prev = elem
+}
diff --git a/pkg/state/pretty/BUILD b/pkg/state/pretty/BUILD
deleted file mode 100644
index d053802f7..000000000
--- a/pkg/state/pretty/BUILD
+++ /dev/null
@@ -1,13 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "pretty",
- srcs = ["pretty.go"],
- visibility = ["//:sandbox"],
- deps = [
- "//pkg/state",
- "//pkg/state/wire",
- ],
-)
diff --git a/pkg/state/pretty/pretty_state_autogen.go b/pkg/state/pretty/pretty_state_autogen.go
new file mode 100644
index 000000000..e772e34a4
--- /dev/null
+++ b/pkg/state/pretty/pretty_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package pretty
diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD
deleted file mode 100644
index d6c89c7e9..000000000
--- a/pkg/state/statefile/BUILD
+++ /dev/null
@@ -1,22 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "statefile",
- srcs = ["statefile.go"],
- visibility = ["//:sandbox"],
- deps = [
- "//pkg/binary",
- "//pkg/compressio",
- "//pkg/state/wire",
- ],
-)
-
-go_test(
- name = "statefile_test",
- size = "small",
- srcs = ["statefile_test.go"],
- library = ":statefile",
- deps = ["//pkg/compressio"],
-)
diff --git a/pkg/state/statefile/statefile_state_autogen.go b/pkg/state/statefile/statefile_state_autogen.go
new file mode 100644
index 000000000..a2cdaa3f1
--- /dev/null
+++ b/pkg/state/statefile/statefile_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package statefile
diff --git a/pkg/state/statefile/statefile_test.go b/pkg/state/statefile/statefile_test.go
deleted file mode 100644
index 0b470fdec..000000000
--- a/pkg/state/statefile/statefile_test.go
+++ /dev/null
@@ -1,290 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package statefile
-
-import (
- "bytes"
- crand "crypto/rand"
- "encoding/base64"
- "io"
- "math/rand"
- "runtime"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/compressio"
-)
-
-func randomKey() ([]byte, error) {
- r := make([]byte, base64.RawStdEncoding.DecodedLen(keySize))
- if _, err := io.ReadFull(crand.Reader, r); err != nil {
- return nil, err
- }
- key := make([]byte, keySize)
- base64.RawStdEncoding.Encode(key, r)
- return key, nil
-}
-
-type testCase struct {
- name string
- data []byte
- metadata map[string]string
-}
-
-func TestStatefile(t *testing.T) {
- rand.Seed(time.Now().Unix())
-
- cases := []testCase{
- // Various data sizes.
- {"nil", nil, nil},
- {"empty", []byte(""), nil},
- {"some", []byte("_"), nil},
- {"one", []byte("0"), nil},
- {"two", []byte("01"), nil},
- {"three", []byte("012"), nil},
- {"four", []byte("0123"), nil},
- {"five", []byte("01234"), nil},
- {"six", []byte("012356"), nil},
- {"seven", []byte("0123567"), nil},
- {"eight", []byte("01235678"), nil},
-
- // Make sure we have one longer than the hash length.
- {"longer than hash", []byte("012356asdjflkasjlk3jlk23j4lkjaso0d789f0aujw3lkjlkxsdf78asdful2kj3ljka78"), nil},
-
- // Make sure we have one longer than the chunk size.
- {"chunks", make([]byte, 3*compressionChunkSize), nil},
- {"large", make([]byte, 30*compressionChunkSize), nil},
-
- // Different metadata.
- {"one metadata", []byte("data"), map[string]string{"foo": "bar"}},
- {"two metadata", []byte("data"), map[string]string{"foo": "bar", "one": "two"}},
- }
-
- for _, c := range cases {
- // Generate a key.
- integrityKey, err := randomKey()
- if err != nil {
- t.Errorf("can't generate key: got %v, excepted nil", err)
- continue
- }
-
- t.Run(c.name, func(t *testing.T) {
- for _, key := range [][]byte{nil, integrityKey} {
- t.Run("key="+string(key), func(t *testing.T) {
- // Encoding happens via a buffer.
- var bufEncoded bytes.Buffer
- var bufDecoded bytes.Buffer
-
- // Do all the writing.
- w, err := NewWriter(&bufEncoded, key, c.metadata)
- if err != nil {
- t.Fatalf("error creating writer: got %v, expected nil", err)
- }
- if _, err := io.Copy(w, bytes.NewBuffer(c.data)); err != nil {
- t.Fatalf("error during write: got %v, expected nil", err)
- }
-
- // Finish the sum.
- if err := w.Close(); err != nil {
- t.Fatalf("error during close: got %v, expected nil", err)
- }
-
- t.Logf("original data: %d bytes, encoded: %d bytes.",
- len(c.data), len(bufEncoded.Bytes()))
-
- // Do all the reading.
- r, metadata, err := NewReader(bytes.NewReader(bufEncoded.Bytes()), key)
- if err != nil {
- t.Fatalf("error creating reader: got %v, expected nil", err)
- }
- if _, err := io.Copy(&bufDecoded, r); err != nil {
- t.Fatalf("error during read: got %v, expected nil", err)
- }
-
- // Check that the data matches.
- if !bytes.Equal(c.data, bufDecoded.Bytes()) {
- t.Fatalf("data didn't match (%d vs %d bytes)", len(bufDecoded.Bytes()), len(c.data))
- }
-
- // Check that the metadata matches.
- for k, v := range c.metadata {
- nv, ok := metadata[k]
- if !ok {
- t.Fatalf("missing metadata: %s", k)
- }
- if v != nv {
- t.Fatalf("mismatched metdata for %s: got %s, expected %s", k, nv, v)
- }
- }
-
- // Change the data and verify that it fails.
- if key != nil {
- b := append([]byte(nil), bufEncoded.Bytes()...)
- b[rand.Intn(len(b))]++
- bufDecoded.Reset()
- r, _, err = NewReader(bytes.NewReader(b), key)
- if err == nil {
- _, err = io.Copy(&bufDecoded, r)
- }
- if err == nil {
- t.Error("got no error: expected error on data corruption")
- }
- }
-
- // Change the key and verify that it fails.
- newKey := integrityKey
- if len(key) > 0 {
- newKey = append([]byte{}, key...)
- newKey[rand.Intn(len(newKey))]++
- }
- bufDecoded.Reset()
- r, _, err = NewReader(bytes.NewReader(bufEncoded.Bytes()), newKey)
- if err == nil {
- _, err = io.Copy(&bufDecoded, r)
- }
- if err != compressio.ErrHashMismatch {
- t.Errorf("got error: %v, expected ErrHashMismatch on key mismatch", err)
- }
- })
- }
- })
- }
-}
-
-const benchmarkDataSize = 100 * 1024 * 1024
-
-func benchmark(b *testing.B, size int, write bool, compressible bool) {
- b.StopTimer()
- b.SetBytes(benchmarkDataSize)
-
- // Generate source data.
- var source []byte
- if compressible {
- // For compressible data, we use essentially all zeros.
- source = make([]byte, benchmarkDataSize)
- } else {
- // For non-compressible data, we use random base64 data (to
- // make it marginally compressible, a ratio of 75%).
- var sourceBuf bytes.Buffer
- bufW := base64.NewEncoder(base64.RawStdEncoding, &sourceBuf)
- bufR := rand.New(rand.NewSource(0))
- if _, err := io.CopyN(bufW, bufR, benchmarkDataSize); err != nil {
- b.Fatalf("unable to seed random data: %v", err)
- }
- source = sourceBuf.Bytes()
- }
-
- // Generate a random key for integrity check.
- key, err := randomKey()
- if err != nil {
- b.Fatalf("error generating key: %v", err)
- }
-
- // Define our benchmark functions. Prior to running the readState
- // function here, you must execute the writeState function at least
- // once (done below).
- var stateBuf bytes.Buffer
- writeState := func() {
- stateBuf.Reset()
- w, err := NewWriter(&stateBuf, key, nil)
- if err != nil {
- b.Fatalf("error creating writer: %v", err)
- }
- for done := 0; done < len(source); {
- chunk := size // limit size.
- if done+chunk > len(source) {
- chunk = len(source) - done
- }
- n, err := w.Write(source[done : done+chunk])
- done += n
- if n == 0 && err != nil {
- b.Fatalf("error during write: %v", err)
- }
- }
- if err := w.Close(); err != nil {
- b.Fatalf("error closing writer: %v", err)
- }
- }
- readState := func() {
- tmpBuf := bytes.NewBuffer(stateBuf.Bytes())
- r, _, err := NewReader(tmpBuf, key)
- if err != nil {
- b.Fatalf("error creating reader: %v", err)
- }
- for done := 0; done < len(source); {
- chunk := size // limit size.
- if done+chunk > len(source) {
- chunk = len(source) - done
- }
- n, err := r.Read(source[done : done+chunk])
- done += n
- if n == 0 && err != nil {
- b.Fatalf("error during read: %v", err)
- }
- }
- }
- // Generate the state once without timing to ensure that buffers have
- // been appropriately allocated.
- writeState()
- if write {
- b.StartTimer()
- for i := 0; i < b.N; i++ {
- writeState()
- }
- b.StopTimer()
- } else {
- b.StartTimer()
- for i := 0; i < b.N; i++ {
- readState()
- }
- b.StopTimer()
- }
-}
-
-func BenchmarkWrite4KCompressible(b *testing.B) {
- benchmark(b, 4096, true, true)
-}
-
-func BenchmarkWrite4KNoncompressible(b *testing.B) {
- benchmark(b, 4096, true, false)
-}
-
-func BenchmarkWrite1MCompressible(b *testing.B) {
- benchmark(b, 1024*1024, true, true)
-}
-
-func BenchmarkWrite1MNoncompressible(b *testing.B) {
- benchmark(b, 1024*1024, true, false)
-}
-
-func BenchmarkRead4KCompressible(b *testing.B) {
- benchmark(b, 4096, false, true)
-}
-
-func BenchmarkRead4KNoncompressible(b *testing.B) {
- benchmark(b, 4096, false, false)
-}
-
-func BenchmarkRead1MCompressible(b *testing.B) {
- benchmark(b, 1024*1024, false, true)
-}
-
-func BenchmarkRead1MNoncompressible(b *testing.B) {
- benchmark(b, 1024*1024, false, false)
-}
-
-func init() {
- runtime.GOMAXPROCS(runtime.NumCPU())
-}
diff --git a/pkg/state/tests/BUILD b/pkg/state/tests/BUILD
deleted file mode 100644
index 9297cafbe..000000000
--- a/pkg/state/tests/BUILD
+++ /dev/null
@@ -1,43 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "tests",
- srcs = [
- "array.go",
- "bench.go",
- "integer.go",
- "load.go",
- "map.go",
- "register.go",
- "struct.go",
- "tests.go",
- ],
- deps = [
- "//pkg/state",
- "//pkg/state/pretty",
- ],
-)
-
-go_test(
- name = "tests_test",
- size = "small",
- srcs = [
- "array_test.go",
- "bench_test.go",
- "bool_test.go",
- "float_test.go",
- "integer_test.go",
- "load_test.go",
- "map_test.go",
- "register_test.go",
- "string_test.go",
- "struct_test.go",
- ],
- library = ":tests",
- deps = [
- "//pkg/state",
- "//pkg/state/wire",
- ],
-)
diff --git a/pkg/state/tests/array.go b/pkg/state/tests/array.go
deleted file mode 100644
index 0972a80e7..000000000
--- a/pkg/state/tests/array.go
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tests
-
-// +stateify savable
-type arrayContainer struct {
- v [1]interface{}
-}
-
-// +stateify savable
-type arrayPtrContainer struct {
- v *[1]interface{}
-}
-
-// +stateify savable
-type sliceContainer struct {
- v []interface{}
-}
-
-// +stateify savable
-type slicePtrContainer struct {
- v *[]interface{}
-}
diff --git a/pkg/state/tests/array_test.go b/pkg/state/tests/array_test.go
deleted file mode 100644
index a347b2947..000000000
--- a/pkg/state/tests/array_test.go
+++ /dev/null
@@ -1,134 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tests
-
-import (
- "reflect"
- "testing"
-)
-
-var allArrayPrimitives = []interface{}{
- [1]bool{},
- [1]bool{true},
- [2]bool{false, true},
- [1]int{},
- [1]int{1},
- [2]int{0, 1},
- [1]int8{},
- [1]int8{1},
- [2]int8{0, 1},
- [1]int16{},
- [1]int16{1},
- [2]int16{0, 1},
- [1]int32{},
- [1]int32{1},
- [2]int32{0, 1},
- [1]int64{},
- [1]int64{1},
- [2]int64{0, 1},
- [1]uint{},
- [1]uint{1},
- [2]uint{0, 1},
- [1]uintptr{},
- [1]uintptr{1},
- [2]uintptr{0, 1},
- [1]uint8{},
- [1]uint8{1},
- [2]uint8{0, 1},
- [1]uint16{},
- [1]uint16{1},
- [2]uint16{0, 1},
- [1]uint32{},
- [1]uint32{1},
- [2]uint32{0, 1},
- [1]uint64{},
- [1]uint64{1},
- [2]uint64{0, 1},
- [1]string{},
- [1]string{""},
- [1]string{nonEmptyString},
- [2]string{"", nonEmptyString},
-}
-
-func TestArrayPrimitives(t *testing.T) {
- runTestCases(t, false, "plain", flatten(allArrayPrimitives))
- runTestCases(t, false, "pointers", pointersTo(flatten(allArrayPrimitives)))
- runTestCases(t, false, "interfaces", interfacesTo(flatten(allArrayPrimitives)))
- runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allArrayPrimitives))))
-}
-
-func TestSlices(t *testing.T) {
- var allSlices = flatten(
- filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) {
- v := reflect.New(reflect.TypeOf(o)).Elem()
- v.Set(reflect.ValueOf(o))
- return v.Slice(0, v.Len()).Interface(), true
- }),
- filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) {
- v := reflect.New(reflect.TypeOf(o)).Elem()
- v.Set(reflect.ValueOf(o))
- if v.Len() == 0 {
- // Return the pure "nil" value for the slice.
- return reflect.New(v.Slice(0, 0).Type()).Elem().Interface(), true
- }
- return v.Slice(1, v.Len()).Interface(), true
- }),
- filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) {
- v := reflect.New(reflect.TypeOf(o)).Elem()
- v.Set(reflect.ValueOf(o))
- if v.Len() == 0 {
- // Return the zero-valued slice.
- return reflect.MakeSlice(v.Slice(0, 0).Type(), 0, 0).Interface(), true
- }
- return v.Slice(0, v.Len()-1).Interface(), true
- }),
- )
- runTestCases(t, false, "plain", allSlices)
- runTestCases(t, false, "pointers", pointersTo(allSlices))
- runTestCases(t, false, "interfaces", interfacesTo(allSlices))
- runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(allSlices)))
-}
-
-func TestArrayContainers(t *testing.T) {
- var (
- emptyArray [1]interface{}
- fullArray [1]interface{}
- )
- fullArray[0] = &emptyArray
- runTestCases(t, false, "", []interface{}{
- arrayContainer{v: emptyArray},
- arrayContainer{v: fullArray},
- arrayPtrContainer{v: nil},
- arrayPtrContainer{v: &emptyArray},
- arrayPtrContainer{v: &fullArray},
- })
-}
-
-func TestSliceContainers(t *testing.T) {
- var (
- nilSlice []interface{}
- emptySlice = make([]interface{}, 0)
- fullSlice = []interface{}{nil}
- )
- runTestCases(t, false, "", []interface{}{
- sliceContainer{v: nilSlice},
- sliceContainer{v: emptySlice},
- sliceContainer{v: fullSlice},
- slicePtrContainer{v: nil},
- slicePtrContainer{v: &nilSlice},
- slicePtrContainer{v: &emptySlice},
- slicePtrContainer{v: &fullSlice},
- })
-}
diff --git a/pkg/state/tests/bench.go b/pkg/state/tests/bench.go
deleted file mode 100644
index 40869cdfb..000000000
--- a/pkg/state/tests/bench.go
+++ /dev/null
@@ -1,24 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tests
-
-// +stateify savable
-type benchStruct struct {
- B *benchStruct // Must be exported for gob.
-}
-
-func (b *benchStruct) afterLoad() {
- // Do nothing, just force scheduling.
-}
diff --git a/pkg/state/tests/bench_test.go b/pkg/state/tests/bench_test.go
deleted file mode 100644
index 7e102c907..000000000
--- a/pkg/state/tests/bench_test.go
+++ /dev/null
@@ -1,153 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tests
-
-import (
- "bytes"
- "context"
- "encoding/gob"
- "fmt"
- "testing"
-
- "gvisor.dev/gvisor/pkg/state"
- "gvisor.dev/gvisor/pkg/state/wire"
-)
-
-// buildPtrObject builds a benchmark object.
-func buildPtrObject(n int) interface{} {
- b := new(benchStruct)
- for i := 0; i < n; i++ {
- b = &benchStruct{B: b}
- }
- return b
-}
-
-// buildMapObject builds a benchmark object.
-func buildMapObject(n int) interface{} {
- b := new(benchStruct)
- m := make(map[int]*benchStruct)
- for i := 0; i < n; i++ {
- m[i] = b
- }
- return &m
-}
-
-// buildSliceObject builds a benchmark object.
-func buildSliceObject(n int) interface{} {
- b := new(benchStruct)
- s := make([]*benchStruct, 0, n)
- for i := 0; i < n; i++ {
- s = append(s, b)
- }
- return &s
-}
-
-var allObjects = map[string]struct {
- New func(int) interface{}
-}{
- "ptr": {
- New: buildPtrObject,
- },
- "map": {
- New: buildMapObject,
- },
- "slice": {
- New: buildSliceObject,
- },
-}
-
-func buildObjects(n int, fn func(int) interface{}) (iters int, v interface{}) {
- // maxSize is the maximum size of an individual object below. For an N
- // larger than this, we start to return multiple objects.
- const maxSize = 1024
- if n <= maxSize {
- return 1, fn(n)
- }
- iters = (n + maxSize - 1) / maxSize
- return iters, fn(maxSize)
-}
-
-// gobSave is a version of save using gob (no stats available).
-func gobSave(_ context.Context, w wire.Writer, v interface{}) (_ state.Stats, err error) {
- enc := gob.NewEncoder(w)
- err = enc.Encode(v)
- return
-}
-
-// gobLoad is a version of load using gob (no stats available).
-func gobLoad(_ context.Context, r wire.Reader, v interface{}) (_ state.Stats, err error) {
- dec := gob.NewDecoder(r)
- err = dec.Decode(v)
- return
-}
-
-var allAlgos = map[string]struct {
- Save func(context.Context, wire.Writer, interface{}) (state.Stats, error)
- Load func(context.Context, wire.Reader, interface{}) (state.Stats, error)
- MaxPtr int
-}{
- "state": {
- Save: state.Save,
- Load: state.Load,
- },
- "gob": {
- Save: gobSave,
- Load: gobLoad,
- },
-}
-
-func BenchmarkEncoding(b *testing.B) {
- for objName, objInfo := range allObjects {
- for algoName, algoInfo := range allAlgos {
- b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) {
- b.StopTimer()
- n, v := buildObjects(b.N, objInfo.New)
- b.ReportAllocs()
- b.StartTimer()
- for i := 0; i < n; i++ {
- if _, err := algoInfo.Save(context.Background(), discard{}, v); err != nil {
- b.Errorf("save failed: %v", err)
- }
- }
- b.StopTimer()
- })
- }
- }
-}
-
-func BenchmarkDecoding(b *testing.B) {
- for objName, objInfo := range allObjects {
- for algoName, algoInfo := range allAlgos {
- b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) {
- b.StopTimer()
- n, v := buildObjects(b.N, objInfo.New)
- buf := new(bytes.Buffer)
- if _, err := algoInfo.Save(context.Background(), buf, v); err != nil {
- b.Errorf("save failed: %v", err)
- }
- b.ReportAllocs()
- b.StartTimer()
- var r bytes.Reader
- for i := 0; i < n; i++ {
- r.Reset(buf.Bytes())
- if _, err := algoInfo.Load(context.Background(), &r, v); err != nil {
- b.Errorf("load failed: %v", err)
- }
- }
- b.StopTimer()
- })
- }
- }
-}
diff --git a/pkg/state/tests/bool_test.go b/pkg/state/tests/bool_test.go
deleted file mode 100644
index e17cfacf9..000000000
--- a/pkg/state/tests/bool_test.go
+++ /dev/null
@@ -1,31 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tests
-
-import (
- "testing"
-)
-
-var allBools = []bool{
- true,
- false,
-}
-
-func TestBool(t *testing.T) {
- runTestCases(t, false, "plain", flatten(allBools))
- runTestCases(t, false, "pointers", pointersTo(flatten(allBools)))
- runTestCases(t, false, "interfaces", interfacesTo(flatten(allBools)))
- runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allBools))))
-}
diff --git a/pkg/state/tests/float_test.go b/pkg/state/tests/float_test.go
deleted file mode 100644
index 3e89edd9c..000000000
--- a/pkg/state/tests/float_test.go
+++ /dev/null
@@ -1,118 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tests
-
-import (
- "math"
- "testing"
-)
-
-var safeFloat32s = []float32{
- float32(0.0),
- float32(1.0),
- float32(-1.0),
- float32(math.Inf(1)),
- float32(math.Inf(-1)),
-}
-
-var allFloat32s = append(safeFloat32s, float32(math.NaN()))
-
-var safeFloat64s = []float64{
- float64(0.0),
- float64(1.0),
- float64(-1.0),
- math.Inf(1),
- math.Inf(-1),
-}
-
-var allFloat64s = append(safeFloat64s, math.NaN())
-
-func TestFloat(t *testing.T) {
- runTestCases(t, false, "plain", flatten(
- allFloat32s,
- allFloat64s,
- ))
- // See checkEqual for why NaNs are missing.
- runTestCases(t, false, "pointers", pointersTo(flatten(
- safeFloat32s,
- safeFloat64s,
- )))
- runTestCases(t, false, "interfaces", interfacesTo(flatten(
- safeFloat32s,
- safeFloat64s,
- )))
- runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(
- safeFloat32s,
- safeFloat64s,
- ))))
-}
-
-const onlyDouble float64 = 1.0000000000000002
-
-func TestFloatTruncation(t *testing.T) {
- runTestCases(t, true, "pass", []interface{}{
- truncatingFloat32{save: onlyDouble},
- })
- runTestCases(t, false, "fail", []interface{}{
- truncatingFloat32{save: 1.0},
- })
-}
-
-var safeComplex64s = combine(safeFloat32s, safeFloat32s, func(i, j interface{}) interface{} {
- return complex(i.(float32), j.(float32))
-})
-
-var allComplex64s = combine(allFloat32s, allFloat32s, func(i, j interface{}) interface{} {
- return complex(i.(float32), j.(float32))
-})
-
-var safeComplex128s = combine(safeFloat64s, safeFloat64s, func(i, j interface{}) interface{} {
- return complex(i.(float64), j.(float64))
-})
-
-var allComplex128s = combine(allFloat64s, allFloat64s, func(i, j interface{}) interface{} {
- return complex(i.(float64), j.(float64))
-})
-
-func TestComplex(t *testing.T) {
- runTestCases(t, false, "plain", flatten(
- allComplex64s,
- allComplex128s,
- ))
- // See TestFloat; same issue.
- runTestCases(t, false, "pointers", pointersTo(flatten(
- safeComplex64s,
- safeComplex128s,
- )))
- runTestCases(t, false, "interfacse", interfacesTo(flatten(
- safeComplex64s,
- safeComplex128s,
- )))
- runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(flatten(
- safeComplex64s,
- safeComplex128s,
- ))))
-}
-
-func TestComplexTruncation(t *testing.T) {
- runTestCases(t, true, "pass", []interface{}{
- truncatingComplex64{save: complex(onlyDouble, onlyDouble)},
- truncatingComplex64{save: complex(1.0, onlyDouble)},
- truncatingComplex64{save: complex(onlyDouble, 1.0)},
- })
- runTestCases(t, false, "fail", []interface{}{
- truncatingComplex64{save: complex(1.0, 1.0)},
- })
-}
diff --git a/pkg/state/tests/integer.go b/pkg/state/tests/integer.go
deleted file mode 100644
index ca403eed1..000000000
--- a/pkg/state/tests/integer.go
+++ /dev/null
@@ -1,163 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tests
-
-import (
- "gvisor.dev/gvisor/pkg/state"
-)
-
-// +stateify type
-type truncatingUint8 struct {
- save uint64
- load uint8 `state:"nosave"`
-}
-
-func (t *truncatingUint8) StateSave(m state.Sink) {
- m.Save(0, &t.save)
-}
-
-func (t *truncatingUint8) StateLoad(m state.Source) {
- m.Load(0, &t.load)
- t.save = uint64(t.load)
- t.load = 0
-}
-
-var _ state.SaverLoader = (*truncatingUint8)(nil)
-
-// +stateify type
-type truncatingUint16 struct {
- save uint64
- load uint16 `state:"nosave"`
-}
-
-func (t *truncatingUint16) StateSave(m state.Sink) {
- m.Save(0, &t.save)
-}
-
-func (t *truncatingUint16) StateLoad(m state.Source) {
- m.Load(0, &t.load)
- t.save = uint64(t.load)
- t.load = 0
-}
-
-var _ state.SaverLoader = (*truncatingUint16)(nil)
-
-// +stateify type
-type truncatingUint32 struct {
- save uint64
- load uint32 `state:"nosave"`
-}
-
-func (t *truncatingUint32) StateSave(m state.Sink) {
- m.Save(0, &t.save)
-}
-
-func (t *truncatingUint32) StateLoad(m state.Source) {
- m.Load(0, &t.load)
- t.save = uint64(t.load)
- t.load = 0
-}
-
-var _ state.SaverLoader = (*truncatingUint32)(nil)
-
-// +stateify type
-type truncatingInt8 struct {
- save int64
- load int8 `state:"nosave"`
-}
-
-func (t *truncatingInt8) StateSave(m state.Sink) {
- m.Save(0, &t.save)
-}
-
-func (t *truncatingInt8) StateLoad(m state.Source) {
- m.Load(0, &t.load)
- t.save = int64(t.load)
- t.load = 0
-}
-
-var _ state.SaverLoader = (*truncatingInt8)(nil)
-
-// +stateify type
-type truncatingInt16 struct {
- save int64
- load int16 `state:"nosave"`
-}
-
-func (t *truncatingInt16) StateSave(m state.Sink) {
- m.Save(0, &t.save)
-}
-
-func (t *truncatingInt16) StateLoad(m state.Source) {
- m.Load(0, &t.load)
- t.save = int64(t.load)
- t.load = 0
-}
-
-var _ state.SaverLoader = (*truncatingInt16)(nil)
-
-// +stateify type
-type truncatingInt32 struct {
- save int64
- load int32 `state:"nosave"`
-}
-
-func (t *truncatingInt32) StateSave(m state.Sink) {
- m.Save(0, &t.save)
-}
-
-func (t *truncatingInt32) StateLoad(m state.Source) {
- m.Load(0, &t.load)
- t.save = int64(t.load)
- t.load = 0
-}
-
-var _ state.SaverLoader = (*truncatingInt32)(nil)
-
-// +stateify type
-type truncatingFloat32 struct {
- save float64
- load float32 `state:"nosave"`
-}
-
-func (t *truncatingFloat32) StateSave(m state.Sink) {
- m.Save(0, &t.save)
-}
-
-func (t *truncatingFloat32) StateLoad(m state.Source) {
- m.Load(0, &t.load)
- t.save = float64(t.load)
- t.load = 0
-}
-
-var _ state.SaverLoader = (*truncatingFloat32)(nil)
-
-// +stateify type
-type truncatingComplex64 struct {
- save complex128
- load complex64 `state:"nosave"`
-}
-
-func (t *truncatingComplex64) StateSave(m state.Sink) {
- m.Save(0, &t.save)
-}
-
-func (t *truncatingComplex64) StateLoad(m state.Source) {
- m.Load(0, &t.load)
- t.save = complex128(t.load)
- t.load = 0
-}
-
-var _ state.SaverLoader = (*truncatingComplex64)(nil)
diff --git a/pkg/state/tests/integer_test.go b/pkg/state/tests/integer_test.go
deleted file mode 100644
index 2b1609af0..000000000
--- a/pkg/state/tests/integer_test.go
+++ /dev/null
@@ -1,94 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tests
-
-import (
- "math"
- "testing"
-)
-
-var (
- allBasicInts = []int{-1, 0, 1}
- allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8}
- allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16}
- allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32}
- allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64}
- allBasicUints = []uint{0, 1}
- allUintptrs = []uintptr{0, 1, ^uintptr(0)}
- allUint8s = []uint8{0, 1, math.MaxUint8}
- allUint16s = []uint16{0, 1, math.MaxUint16}
- allUint32s = []uint32{0, 1, math.MaxUint32}
- allUint64s = []uint64{0, 1, math.MaxUint64}
-)
-
-var allInts = flatten(
- allBasicInts,
- allInt8s,
- allInt16s,
- allInt32s,
- allInt64s,
-)
-
-var allUints = flatten(
- allBasicUints,
- allUintptrs,
- allUint8s,
- allUint16s,
- allUint32s,
- allUint64s,
-)
-
-func TestInt(t *testing.T) {
- runTestCases(t, false, "plain", allInts)
- runTestCases(t, false, "pointers", pointersTo(allInts))
- runTestCases(t, false, "interfaces", interfacesTo(allInts))
- runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(allInts)))
-}
-
-func TestIntTruncation(t *testing.T) {
- runTestCases(t, true, "pass", []interface{}{
- truncatingInt8{save: math.MinInt8 - 1},
- truncatingInt16{save: math.MinInt16 - 1},
- truncatingInt32{save: math.MinInt32 - 1},
- truncatingInt8{save: math.MaxInt8 + 1},
- truncatingInt16{save: math.MaxInt16 + 1},
- truncatingInt32{save: math.MaxInt32 + 1},
- })
- runTestCases(t, false, "fail", []interface{}{
- truncatingInt8{save: 1},
- truncatingInt16{save: 1},
- truncatingInt32{save: 1},
- })
-}
-
-func TestUint(t *testing.T) {
- runTestCases(t, false, "plain", allUints)
- runTestCases(t, false, "pointers", pointersTo(allUints))
- runTestCases(t, false, "interfaces", interfacesTo(allUints))
- runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(allUints)))
-}
-
-func TestUintTruncation(t *testing.T) {
- runTestCases(t, true, "pass", []interface{}{
- truncatingUint8{save: math.MaxUint8 + 1},
- truncatingUint16{save: math.MaxUint16 + 1},
- truncatingUint32{save: math.MaxUint32 + 1},
- })
- runTestCases(t, false, "fail", []interface{}{
- truncatingUint8{save: 1},
- truncatingUint16{save: 1},
- truncatingUint32{save: 1},
- })
-}
diff --git a/pkg/state/tests/load.go b/pkg/state/tests/load.go
deleted file mode 100644
index a8350c0f3..000000000
--- a/pkg/state/tests/load.go
+++ /dev/null
@@ -1,61 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tests
-
-// +stateify savable
-type genericContainer struct {
- v interface{}
-}
-
-// +stateify savable
-type afterLoadStruct struct {
- v int `state:"nosave"`
-}
-
-func (a *afterLoadStruct) afterLoad() {
- a.v++
-}
-
-// +stateify savable
-type valueLoadStruct struct {
- v int `state:".(int64)"`
-}
-
-func (v *valueLoadStruct) saveV() int64 {
- return int64(v.v) // Save as int64.
-}
-
-func (v *valueLoadStruct) loadV(value int64) {
- v.v = int(value) // Load as int.
-}
-
-// +stateify savable
-type cycleStruct struct {
- c *cycleStruct
-}
-
-// +stateify savable
-type badCycleStruct struct {
- b *badCycleStruct `state:"wait"`
-}
-
-func (b *badCycleStruct) afterLoad() {
- if b.b != b {
- // This is not executable, since AfterLoad requires that the
- // object and all dependencies are complete. This should cause
- // a deadlock error during load.
- panic("badCycleStruct.afterLoad called")
- }
-}
diff --git a/pkg/state/tests/load_test.go b/pkg/state/tests/load_test.go
deleted file mode 100644
index 3c73ac391..000000000
--- a/pkg/state/tests/load_test.go
+++ /dev/null
@@ -1,78 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tests
-
-import (
- "testing"
-)
-
-func TestLoadHooks(t *testing.T) {
- runTestCases(t, false, "load-hooks", []interface{}{
- // Root object being a struct.
- afterLoadStruct{v: 1},
- valueLoadStruct{v: 1},
- genericContainer{v: &afterLoadStruct{v: 1}},
- genericContainer{v: &valueLoadStruct{v: 1}},
- sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}},
- sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}},
- // Root object being a pointer.
- &afterLoadStruct{v: 1},
- &valueLoadStruct{v: 1},
- &genericContainer{v: &afterLoadStruct{v: 1}},
- &genericContainer{v: &valueLoadStruct{v: 1}},
- &sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}},
- &sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}},
- &mapContainer{v: map[int]interface{}{0: &afterLoadStruct{v: 1}}},
- &mapContainer{v: map[int]interface{}{0: &valueLoadStruct{v: 1}}},
- })
-}
-
-func TestCycles(t *testing.T) {
- // cs is a single object cycle.
- cs := cycleStruct{nil}
- cs.c = &cs
-
- // cs1 and cs2 are in a two object cycle.
- cs1 := cycleStruct{nil}
- cs2 := cycleStruct{nil}
- cs1.c = &cs2
- cs2.c = &cs1
-
- runTestCases(t, false, "cycles", []interface{}{
- cs,
- cs1,
- })
-}
-
-func TestDeadlock(t *testing.T) {
- // bs is a single object cycle. This does not cause deadlock because an
- // object cannot wait for itself.
- bs := badCycleStruct{nil}
- bs.b = &bs
-
- runTestCases(t, false, "self", []interface{}{
- &bs,
- })
-
- // bs2 and bs2 are in a deadlocking cycle.
- bs1 := badCycleStruct{nil}
- bs2 := badCycleStruct{nil}
- bs1.b = &bs2
- bs2.b = &bs1
-
- runTestCases(t, true, "deadlock", []interface{}{
- &bs1,
- })
-}
diff --git a/pkg/state/tests/map.go b/pkg/state/tests/map.go
deleted file mode 100644
index db4e548f1..000000000
--- a/pkg/state/tests/map.go
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tests
-
-// +stateify savable
-type mapContainer struct {
- v map[int]interface{}
-}
-
-// +stateify savable
-type mapPtrContainer struct {
- v *map[int]interface{}
-}
-
-// +stateify savable
-type registeredMapStruct struct{}
diff --git a/pkg/state/tests/map_test.go b/pkg/state/tests/map_test.go
deleted file mode 100644
index 92bf0fc01..000000000
--- a/pkg/state/tests/map_test.go
+++ /dev/null
@@ -1,90 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tests
-
-import (
- "reflect"
- "testing"
-)
-
-var allMapPrimitives = []interface{}{
- bool(true),
- int(1),
- int8(1),
- int16(1),
- int32(1),
- int64(1),
- uint(1),
- uintptr(1),
- uint8(1),
- uint16(1),
- uint32(1),
- uint64(1),
- string(""),
- registeredMapStruct{},
-}
-
-var allMapKeys = flatten(allMapPrimitives, pointersTo(allMapPrimitives))
-
-var allMapValues = flatten(allMapPrimitives, pointersTo(allMapPrimitives), interfacesTo(allMapPrimitives))
-
-var emptyMaps = combine(allMapKeys, allMapValues, func(v1, v2 interface{}) interface{} {
- m := reflect.MakeMap(reflect.MapOf(reflect.TypeOf(v1), reflect.TypeOf(v2)))
- return m.Interface()
-})
-
-var fullMaps = combine(allMapKeys, allMapValues, func(v1, v2 interface{}) interface{} {
- m := reflect.MakeMap(reflect.MapOf(reflect.TypeOf(v1), reflect.TypeOf(v2)))
- m.SetMapIndex(reflect.Zero(reflect.TypeOf(v1)), reflect.Zero(reflect.TypeOf(v2)))
- return m.Interface()
-})
-
-func TestMapAliasing(t *testing.T) {
- v := make(map[int]int)
- ptrToV := &v
- aliases := []map[int]int{v, v}
- runTestCases(t, false, "", []interface{}{ptrToV, aliases})
-}
-
-func TestMapsEmpty(t *testing.T) {
- runTestCases(t, false, "plain", emptyMaps)
- runTestCases(t, false, "pointers", pointersTo(emptyMaps))
- runTestCases(t, false, "interfaces", interfacesTo(emptyMaps))
- runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(emptyMaps)))
-}
-
-func TestMapsFull(t *testing.T) {
- runTestCases(t, false, "plain", fullMaps)
- runTestCases(t, false, "pointers", pointersTo(fullMaps))
- runTestCases(t, false, "interfaces", interfacesTo(fullMaps))
- runTestCases(t, false, "interfacesToPointer", interfacesTo(pointersTo(fullMaps)))
-}
-
-func TestMapContainers(t *testing.T) {
- var (
- nilMap map[int]interface{}
- emptyMap = make(map[int]interface{})
- fullMap = map[int]interface{}{0: nil}
- )
- runTestCases(t, false, "", []interface{}{
- mapContainer{v: nilMap},
- mapContainer{v: emptyMap},
- mapContainer{v: fullMap},
- mapPtrContainer{v: nil},
- mapPtrContainer{v: &nilMap},
- mapPtrContainer{v: &emptyMap},
- mapPtrContainer{v: &fullMap},
- })
-}
diff --git a/pkg/state/tests/register.go b/pkg/state/tests/register.go
deleted file mode 100644
index 074d86315..000000000
--- a/pkg/state/tests/register.go
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tests
-
-// +stateify savable
-type alreadyRegisteredStruct struct{}
-
-// +stateify savable
-type alreadyRegisteredOther int
diff --git a/pkg/state/tests/register_test.go b/pkg/state/tests/register_test.go
deleted file mode 100644
index 75bdbfc6e..000000000
--- a/pkg/state/tests/register_test.go
+++ /dev/null
@@ -1,178 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build race
-
-package tests
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/state"
-)
-
-// faker calls itself whatever is in the name field.
-type faker struct {
- Name string
- Fields []string
-}
-
-func (f *faker) StateTypeName() string {
- return f.Name
-}
-
-func (f *faker) StateFields() []string {
- return f.Fields
-}
-
-// fakerWithSaverLoader has all it needs.
-type fakerWithSaverLoader struct {
- faker
-}
-
-func (f *fakerWithSaverLoader) StateSave(m state.Sink) {}
-
-func (f *fakerWithSaverLoader) StateLoad(m state.Source) {}
-
-// fakerOther calls itself .. uh, itself?
-type fakerOther string
-
-func (f *fakerOther) StateTypeName() string {
- return string(*f)
-}
-
-func (f *fakerOther) StateFields() []string {
- return nil
-}
-
-func newFakerOther(name string) *fakerOther {
- f := fakerOther(name)
- return &f
-}
-
-// fakerOtherBadFields returns non-nil fields.
-type fakerOtherBadFields string
-
-func (f *fakerOtherBadFields) StateTypeName() string {
- return string(*f)
-}
-
-func (f *fakerOtherBadFields) StateFields() []string {
- return []string{string(*f)}
-}
-
-func newFakerOtherBadFields(name string) *fakerOtherBadFields {
- f := fakerOtherBadFields(name)
- return &f
-}
-
-// fakerOtherSaverLoader implements SaverLoader methods.
-type fakerOtherSaverLoader string
-
-func (f *fakerOtherSaverLoader) StateTypeName() string {
- return string(*f)
-}
-
-func (f *fakerOtherSaverLoader) StateFields() []string {
- return nil
-}
-
-func (f *fakerOtherSaverLoader) StateSave(m state.Sink) {}
-
-func (f *fakerOtherSaverLoader) StateLoad(m state.Source) {}
-
-func newFakerOtherSaverLoader(name string) *fakerOtherSaverLoader {
- f := fakerOtherSaverLoader(name)
- return &f
-}
-
-func TestRegisterPrimitives(t *testing.T) {
- for _, typeName := range []string{
- "int",
- "int8",
- "int16",
- "int32",
- "int64",
- "uint",
- "uintptr",
- "uint8",
- "uint16",
- "uint32",
- "uint64",
- "float32",
- "float64",
- "complex64",
- "complex128",
- "string",
- } {
- t.Run("struct/"+typeName, func(t *testing.T) {
- defer func() {
- if r := recover(); r == nil {
- t.Errorf("Registering type %q did not panic", typeName)
- }
- }()
- state.Register(&faker{
- Name: typeName,
- })
- })
- t.Run("other/"+typeName, func(t *testing.T) {
- defer func() {
- if r := recover(); r == nil {
- t.Errorf("Registering type %q did not panic", typeName)
- }
- }()
- state.Register(newFakerOther(typeName))
- })
- }
-}
-
-func TestRegisterBad(t *testing.T) {
- const (
- goodName = "foo"
- firstField = "a"
- secondField = "b"
- )
- for name, object := range map[string]state.Type{
- "non-struct-with-fields": newFakerOtherBadFields(goodName),
- "non-struct-with-saverloader": newFakerOtherSaverLoader(goodName),
- "struct-without-saverloader": &faker{Name: goodName},
- "non-struct-duplicate-with-struct": newFakerOther((new(alreadyRegisteredStruct)).StateTypeName()),
- "non-struct-duplicate-with-non-struct": newFakerOther((new(alreadyRegisteredOther)).StateTypeName()),
- "struct-duplicate-with-struct": &fakerWithSaverLoader{faker{Name: (new(alreadyRegisteredStruct)).StateTypeName()}},
- "struct-duplicate-with-non-struct": &fakerWithSaverLoader{faker{Name: (new(alreadyRegisteredOther)).StateTypeName()}},
- "struct-with-empty-field": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{""}}},
- "struct-with-empty-field-and-non-empty": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, ""}}},
- "struct-with-duplicate-field": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, firstField}}},
- "struct-with-duplicate-field-and-non-dup": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, secondField, firstField}}},
- } {
- t.Run(name, func(t *testing.T) {
- defer func() {
- if r := recover(); r == nil {
- t.Errorf("Registering object %#v did not panic", object)
- }
- }()
- state.Register(object)
- })
-
- }
-}
-
-func TestRegisterTypeOnlyStruct(t *testing.T) {
- defer func() {
- if r := recover(); r == nil {
- t.Errorf("Register did not panic")
- }
- }()
- state.Register((*typeOnlyEmptyStruct)(nil))
-}
diff --git a/pkg/state/tests/string_test.go b/pkg/state/tests/string_test.go
deleted file mode 100644
index 44f5a562c..000000000
--- a/pkg/state/tests/string_test.go
+++ /dev/null
@@ -1,34 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tests
-
-import (
- "testing"
-)
-
-const nonEmptyString = "hello world"
-
-var allStrings = []string{
- "",
- nonEmptyString,
- "\\0",
-}
-
-func TestString(t *testing.T) {
- runTestCases(t, false, "plain", flatten(allStrings))
- runTestCases(t, false, "pointers", pointersTo(flatten(allStrings)))
- runTestCases(t, false, "interfaces", interfacesTo(flatten(allStrings)))
- runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allStrings))))
-}
diff --git a/pkg/state/tests/struct.go b/pkg/state/tests/struct.go
deleted file mode 100644
index 69143d194..000000000
--- a/pkg/state/tests/struct.go
+++ /dev/null
@@ -1,100 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tests
-
-type unregisteredEmptyStruct struct{}
-
-// typeOnlyEmptyStruct just implements the state.Type interface.
-type typeOnlyEmptyStruct struct{}
-
-func (*typeOnlyEmptyStruct) StateTypeName() string { return "registeredEmptyStruct" }
-
-func (*typeOnlyEmptyStruct) StateFields() []string { return nil }
-
-// +stateify savable
-type savableEmptyStruct struct{}
-
-// +stateify savable
-type emptyStructPointer struct {
- nothing *struct{}
-}
-
-// +stateify savable
-type outerSame struct {
- inner inner
-}
-
-// +stateify savable
-type outerFieldFirst struct {
- inner inner
- v int64
-}
-
-// +stateify savable
-type outerFieldSecond struct {
- v int64
- inner inner
-}
-
-// +stateify savable
-type outerArray struct {
- inner [2]inner
-}
-
-// +stateify savable
-type outerSlice struct {
- inner []inner
-}
-
-// +stateify savable
-type inner struct {
- v int64
-}
-
-// +stateify savable
-type outerFieldValue struct {
- inner innerFieldValue
-}
-
-// +stateify savable
-type innerFieldValue struct {
- v int64 `state:".(*savedFieldValue)"`
-}
-
-// +stateify savable
-type savedFieldValue struct {
- v int64
-}
-
-func (ifv *innerFieldValue) saveV() *savedFieldValue {
- return &savedFieldValue{ifv.v}
-}
-
-func (ifv *innerFieldValue) loadV(sfv *savedFieldValue) {
- ifv.v = sfv.v
-}
-
-// +stateify savable
-type system struct {
- v1 interface{}
- v2 interface{}
-}
-
-// +stateify savable
-type system3 struct {
- v1 interface{}
- v2 interface{}
- v3 interface{}
-}
diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go
deleted file mode 100644
index 9826f1ee9..000000000
--- a/pkg/state/tests/struct_test.go
+++ /dev/null
@@ -1,100 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tests
-
-import (
- "math/rand"
- "testing"
-)
-
-func TestEmptyStruct(t *testing.T) {
- runTestCases(t, false, "plain", []interface{}{
- unregisteredEmptyStruct{},
- typeOnlyEmptyStruct{},
- savableEmptyStruct{},
- })
- runTestCases(t, false, "pointers", pointersTo([]interface{}{
- unregisteredEmptyStruct{},
- typeOnlyEmptyStruct{},
- savableEmptyStruct{},
- }))
- runTestCases(t, false, "interfaces-pass", interfacesTo([]interface{}{
- // Only registered types can be dispatched via interfaces. All
- // other types should fail, even if it is the empty struct.
- savableEmptyStruct{},
- }))
- runTestCases(t, true, "interfaces-fail", interfacesTo([]interface{}{
- unregisteredEmptyStruct{},
- typeOnlyEmptyStruct{},
- }))
- runTestCases(t, false, "interfacesToPointers-pass", interfacesTo(pointersTo([]interface{}{
- savableEmptyStruct{},
- })))
- runTestCases(t, true, "interfacesToPointers-fail", interfacesTo(pointersTo([]interface{}{
- unregisteredEmptyStruct{},
- typeOnlyEmptyStruct{},
- })))
-
- // Ensuring empty struct aliasing works.
- es := emptyStructPointer{new(struct{})}
- runTestCases(t, false, "empty-struct-pointers", []interface{}{
- emptyStructPointer{},
- es,
- []emptyStructPointer{es, es}, // Same pointer.
- })
-}
-
-func TestEmbeddedPointers(t *testing.T) {
- // Give each int64 a random value to prevent Go from using
- // runtime.staticuint64s, which confounds tests for struct duplication.
- magic := func() int64 {
- for {
- n := rand.Int63()
- if n < 0 || n > 255 {
- return n
- }
- }
- }
-
- ofs := outerSame{inner{magic()}}
- of1 := outerFieldFirst{inner{magic()}, magic()}
- of2 := outerFieldSecond{magic(), inner{magic()}}
- oa := outerArray{[2]inner{{magic()}, {magic()}}}
- osl := outerSlice{oa.inner[:]}
- ofv := outerFieldValue{innerFieldValue{magic()}}
-
- runTestCases(t, false, "embedded-pointers", []interface{}{
- system{&ofs, &ofs.inner},
- system{&ofs.inner, &ofs},
- system{&of1, &of1.inner},
- system{&of1.inner, &of1},
- system{&of2, &of2.inner},
- system{&of2.inner, &of2},
- system{&oa, &oa.inner[0]},
- system{&oa, &oa.inner[1]},
- system{&oa.inner[0], &oa},
- system{&oa.inner[1], &oa},
- system3{&oa, &oa.inner[0], &oa.inner[1]},
- system3{&oa, &oa.inner[1], &oa.inner[0]},
- system3{&oa.inner[0], &oa, &oa.inner[1]},
- system3{&oa.inner[1], &oa, &oa.inner[0]},
- system3{&oa.inner[0], &oa.inner[1], &oa},
- system3{&oa.inner[1], &oa.inner[0], &oa},
- system{&oa, &osl},
- system{&osl, &oa},
- system{&ofv, &ofv.inner},
- system{&ofv.inner, &ofv},
- })
-}
diff --git a/pkg/state/tests/tests.go b/pkg/state/tests/tests.go
deleted file mode 100644
index 435a0e9db..000000000
--- a/pkg/state/tests/tests.go
+++ /dev/null
@@ -1,215 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package tests tests the state packages.
-package tests
-
-import (
- "bytes"
- "context"
- "fmt"
- "math"
- "reflect"
- "testing"
-
- "gvisor.dev/gvisor/pkg/state"
- "gvisor.dev/gvisor/pkg/state/pretty"
-)
-
-// discard is an implementation of wire.Writer.
-type discard struct{}
-
-// Write implements wire.Writer.Write.
-func (discard) Write(p []byte) (int, error) { return len(p), nil }
-
-// WriteByte implements wire.Writer.WriteByte.
-func (discard) WriteByte(byte) error { return nil }
-
-// checkEqual checks if two objects are equal.
-//
-// N.B. This only handles one level of dereferences for NaN. Otherwise we
-// would need to fork the entire implementation of reflect.DeepEqual.
-func checkEqual(root, loadedValue interface{}) bool {
- if reflect.DeepEqual(root, loadedValue) {
- return true
- }
-
- // NaN is not equal to itself. We handle the case of raw floating point
- // primitives here, but don't handle this case nested.
- rf32, ok1 := root.(float32)
- lf32, ok2 := loadedValue.(float32)
- if ok1 && ok2 && math.IsNaN(float64(rf32)) && math.IsNaN(float64(lf32)) {
- return true
- }
- rf64, ok1 := root.(float64)
- lf64, ok2 := loadedValue.(float64)
- if ok1 && ok2 && math.IsNaN(rf64) && math.IsNaN(lf64) {
- return true
- }
-
- // Same real for complex numbers.
- rc64, ok1 := root.(complex64)
- lc64, ok2 := root.(complex64)
- if ok1 && ok2 {
- return checkEqual(real(rc64), real(lc64)) && checkEqual(imag(rc64), imag(lc64))
- }
- rc128, ok1 := root.(complex128)
- lc128, ok2 := root.(complex128)
- if ok1 && ok2 {
- return checkEqual(real(rc128), real(lc128)) && checkEqual(imag(rc128), imag(lc128))
- }
-
- return false
-}
-
-// runTestCases runs a test for each object in objects.
-func runTestCases(t *testing.T, shouldFail bool, prefix string, objects []interface{}) {
- t.Helper()
- for i, root := range objects {
- t.Run(fmt.Sprintf("%s%d", prefix, i), func(t *testing.T) {
- t.Logf("Original object:\n%#v", root)
-
- // Save the passed object.
- saveBuffer := &bytes.Buffer{}
- saveObjectPtr := reflect.New(reflect.TypeOf(root))
- saveObjectPtr.Elem().Set(reflect.ValueOf(root))
- saveStats, err := state.Save(context.Background(), saveBuffer, saveObjectPtr.Interface())
- if err != nil {
- if shouldFail {
- return
- }
- t.Fatalf("Save failed unexpectedly: %v", err)
- }
-
- // Dump the serialized proto to aid with debugging.
- var ppBuf bytes.Buffer
- t.Logf("Raw state:\n%v", saveBuffer.Bytes())
- if err := pretty.PrintText(&ppBuf, bytes.NewReader(saveBuffer.Bytes())); err != nil {
- // We don't count this as a test failure if we
- // have shouldFail set, but we will count as a
- // failure if we were not expecting to fail.
- if !shouldFail {
- t.Errorf("PrettyPrint(html=false) failed unexpected: %v", err)
- }
- }
- if err := pretty.PrintHTML(discard{}, bytes.NewReader(saveBuffer.Bytes())); err != nil {
- // See above.
- if !shouldFail {
- t.Errorf("PrettyPrint(html=true) failed unexpected: %v", err)
- }
- }
- t.Logf("Encoded state:\n%s", ppBuf.String())
- t.Logf("Save stats:\n%s", saveStats.String())
-
- // Load a new copy of the object.
- loadObjectPtr := reflect.New(reflect.TypeOf(root))
- loadStats, err := state.Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface())
- if err != nil {
- if shouldFail {
- return
- }
- t.Fatalf("Load failed unexpectedly: %v", err)
- }
-
- // Compare the values.
- loadedValue := loadObjectPtr.Elem().Interface()
- if !checkEqual(root, loadedValue) {
- if shouldFail {
- return
- }
- t.Fatalf("Objects differ:\n\toriginal: %#v\n\tloaded: %#v\n", root, loadedValue)
- }
-
- // Everything went okay. Is that good?
- if shouldFail {
- t.Fatalf("This test was expected to fail, but didn't.")
- }
- t.Logf("Load stats:\n%s", loadStats.String())
-
- // Truncate half the bytes in the byte stream,
- // and ensure that we can't restore. Then
- // truncate only the final byte and ensure that
- // we can't restore.
- l := saveBuffer.Len()
- halfReader := bytes.NewReader(saveBuffer.Bytes()[:l/2])
- if _, err := state.Load(context.Background(), halfReader, loadObjectPtr.Interface()); err == nil {
- t.Errorf("Load with half bytes succeeded unexpectedly.")
- }
- missingByteReader := bytes.NewReader(saveBuffer.Bytes()[:l-1])
- if _, err := state.Load(context.Background(), missingByteReader, loadObjectPtr.Interface()); err == nil {
- t.Errorf("Load with missing byte succeeded unexpectedly.")
- }
- })
- }
-}
-
-// convert converts the slice to an []interface{}.
-func convert(v interface{}) (r []interface{}) {
- s := reflect.ValueOf(v) // Must be slice.
- for i := 0; i < s.Len(); i++ {
- r = append(r, s.Index(i).Interface())
- }
- return r
-}
-
-// flatten flattens multiple slices.
-func flatten(vs ...interface{}) (r []interface{}) {
- for _, v := range vs {
- r = append(r, convert(v)...)
- }
- return r
-}
-
-// filter maps from one slice to another.
-func filter(vs interface{}, fn func(interface{}) (interface{}, bool)) (r []interface{}) {
- s := reflect.ValueOf(vs)
- for i := 0; i < s.Len(); i++ {
- v, ok := fn(s.Index(i).Interface())
- if ok {
- r = append(r, v)
- }
- }
- return r
-}
-
-// combine combines objects in two slices as specified.
-func combine(v1, v2 interface{}, fn func(_, _ interface{}) interface{}) (r []interface{}) {
- s1 := reflect.ValueOf(v1)
- s2 := reflect.ValueOf(v2)
- for i := 0; i < s1.Len(); i++ {
- for j := 0; j < s2.Len(); j++ {
- // Combine using the given function.
- r = append(r, fn(s1.Index(i).Interface(), s2.Index(j).Interface()))
- }
- }
- return r
-}
-
-// pointersTo is a filter function that returns pointers.
-func pointersTo(vs interface{}) []interface{} {
- return filter(vs, func(o interface{}) (interface{}, bool) {
- v := reflect.New(reflect.TypeOf(o))
- v.Elem().Set(reflect.ValueOf(o))
- return v.Interface(), true
- })
-}
-
-// interfacesTo is a filter function that returns interface objects.
-func interfacesTo(vs interface{}) []interface{} {
- return filter(vs, func(o interface{}) (interface{}, bool) {
- var v [1]interface{}
- v[0] = o
- return v, true
- })
-}
diff --git a/pkg/state/wire/BUILD b/pkg/state/wire/BUILD
deleted file mode 100644
index 311b93dcb..000000000
--- a/pkg/state/wire/BUILD
+++ /dev/null
@@ -1,12 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "wire",
- srcs = ["wire.go"],
- marshal = False,
- stateify = False,
- visibility = ["//:sandbox"],
- deps = ["//pkg/gohacks"],
-)