summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/mm/pma_set.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/mm/pma_set.go')
-rw-r--r--[-rwxr-xr-x]pkg/sentry/mm/pma_set.go377
1 files changed, 373 insertions, 4 deletions
diff --git a/pkg/sentry/mm/pma_set.go b/pkg/sentry/mm/pma_set.go
index 8906e4edc..d0cc1f9d3 100755..100644
--- a/pkg/sentry/mm/pma_set.go
+++ b/pkg/sentry/mm/pma_set.go
@@ -9,6 +9,34 @@ import (
"fmt"
)
+// trackGaps is an optional parameter.
+//
+// If trackGaps is 1, the Set will track maximum gap size recursively,
+// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this
+// case, Key must be an unsigned integer.
+//
+// trackGaps must be 0 or 1.
+const pmatrackGaps = 0
+
+var _ = uint8(pmatrackGaps << 7) // Will fail if not zero or one.
+
+// dynamicGap is a type that disappears if trackGaps is 0.
+type pmadynamicGap [pmatrackGaps]__generics_imported0.Addr
+
+// Get returns the value of the gap.
+//
+// Precondition: trackGaps must be non-zero.
+func (d *pmadynamicGap) Get() __generics_imported0.Addr {
+ return d[:][0]
+}
+
+// Set sets the value of the gap.
+//
+// Precondition: trackGaps must be non-zero.
+func (d *pmadynamicGap) Set(v __generics_imported0.Addr) {
+ d[:][0] = v
+}
+
const (
// minDegree is the minimum degree of an internal node in a Set B-tree.
//
@@ -267,8 +295,12 @@ func (s *pmaSet) Insert(gap pmaGapIterator, r __generics_imported0.AddrRange, va
}
if prev.Ok() && prev.End() == r.Start {
if mval, ok := (pmaSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok {
+ shrinkMaxGap := pmatrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get()
prev.SetEndUnchecked(r.End)
prev.SetValue(mval)
+ if shrinkMaxGap {
+ gap.node.updateMaxGapLeaf()
+ }
if next.Ok() && next.Start() == r.End {
val = mval
if mval, ok := (pmaSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok {
@@ -282,11 +314,16 @@ func (s *pmaSet) Insert(gap pmaGapIterator, r __generics_imported0.AddrRange, va
}
if next.Ok() && next.Start() == r.End {
if mval, ok := (pmaSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok {
+ shrinkMaxGap := pmatrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get()
next.SetStartUnchecked(r.Start)
next.SetValue(mval)
+ if shrinkMaxGap {
+ gap.node.updateMaxGapLeaf()
+ }
return next
}
}
+
return s.InsertWithoutMergingUnchecked(gap, r, val)
}
@@ -313,11 +350,15 @@ func (s *pmaSet) InsertWithoutMerging(gap pmaGapIterator, r __generics_imported0
// Preconditions: r.Start >= gap.Start(); r.End <= gap.End().
func (s *pmaSet) InsertWithoutMergingUnchecked(gap pmaGapIterator, r __generics_imported0.AddrRange, val pma) pmaIterator {
gap = gap.node.rebalanceBeforeInsert(gap)
+ splitMaxGap := pmatrackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get())
copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments])
copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments])
gap.node.keys[gap.index] = r
gap.node.values[gap.index] = val
gap.node.nrSegments++
+ if splitMaxGap {
+ gap.node.updateMaxGapLeaf()
+ }
return pmaIterator{gap.node, gap.index}
}
@@ -332,12 +373,20 @@ func (s *pmaSet) Remove(seg pmaIterator) pmaGapIterator {
seg.SetRangeUnchecked(victim.Range())
seg.SetValue(victim.Value())
+
+ nextAdjacentNode := seg.NextSegment().node
+ if pmatrackGaps != 0 {
+ nextAdjacentNode.updateMaxGapLeaf()
+ }
return s.Remove(victim).NextGap()
}
copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments])
copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments])
pmaSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1])
seg.node.nrSegments--
+ if pmatrackGaps != 0 {
+ seg.node.updateMaxGapLeaf()
+ }
return seg.node.rebalanceAfterRemove(pmaGapIterator{seg.node, seg.index})
}
@@ -387,6 +436,7 @@ func (s *pmaSet) MergeUnchecked(first, second pmaIterator) pmaIterator {
first.SetEndUnchecked(second.End())
first.SetValue(mval)
+
return s.Remove(second).PrevSegment()
}
}
@@ -562,6 +612,12 @@ type pmanode struct {
// than "isLeaf" because false must be the correct value for an empty root.
hasChildren bool
+ // The longest gap within this node. If the node is a leaf, it's simply the
+ // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys
+ // including the 0th and nrSegments-th gap possibly shared with its upper-level
+ // nodes; if it's a non-leaf node, it's the max of all children's maxGap.
+ maxGap pmadynamicGap
+
// Nodes store keys and values in separate arrays to maximize locality in
// the common case (scanning keys for lookup).
keys [pmamaxDegree - 1]__generics_imported0.AddrRange
@@ -607,12 +663,12 @@ func (n *pmanode) nextSibling() *pmanode {
// required for insertion, and returns an updated iterator to the position
// represented by gap.
func (n *pmanode) rebalanceBeforeInsert(gap pmaGapIterator) pmaGapIterator {
- if n.parent != nil {
- gap = n.parent.rebalanceBeforeInsert(gap)
- }
if n.nrSegments < pmamaxDegree-1 {
return gap
}
+ if n.parent != nil {
+ gap = n.parent.rebalanceBeforeInsert(gap)
+ }
if n.parent == nil {
left := &pmanode{
@@ -648,6 +704,11 @@ func (n *pmanode) rebalanceBeforeInsert(gap pmaGapIterator) pmaGapIterator {
n.hasChildren = true
n.children[0] = left
n.children[1] = right
+
+ if pmatrackGaps != 0 {
+ left.updateMaxGapLocal()
+ right.updateMaxGapLocal()
+ }
if gap.node != n {
return gap
}
@@ -685,6 +746,11 @@ func (n *pmanode) rebalanceBeforeInsert(gap pmaGapIterator) pmaGapIterator {
}
n.nrSegments = pmaminDegree - 1
+ if pmatrackGaps != 0 {
+ n.updateMaxGapLocal()
+ sibling.updateMaxGapLocal()
+ }
+
if gap.node != n {
return gap
}
@@ -730,6 +796,11 @@ func (n *pmanode) rebalanceAfterRemove(gap pmaGapIterator) pmaGapIterator {
}
n.nrSegments++
sibling.nrSegments--
+
+ if pmatrackGaps != 0 {
+ n.updateMaxGapLocal()
+ sibling.updateMaxGapLocal()
+ }
if gap.node == sibling && gap.index == sibling.nrSegments {
return pmaGapIterator{n, 0}
}
@@ -758,6 +829,11 @@ func (n *pmanode) rebalanceAfterRemove(gap pmaGapIterator) pmaGapIterator {
}
n.nrSegments++
sibling.nrSegments--
+
+ if pmatrackGaps != 0 {
+ n.updateMaxGapLocal()
+ sibling.updateMaxGapLocal()
+ }
if gap.node == sibling {
if gap.index == 0 {
return pmaGapIterator{n, n.nrSegments}
@@ -790,6 +866,7 @@ func (n *pmanode) rebalanceAfterRemove(gap pmaGapIterator) pmaGapIterator {
p.children[0] = nil
p.children[1] = nil
}
+
if gap.node == left {
return pmaGapIterator{p, gap.index}
}
@@ -836,10 +913,146 @@ func (n *pmanode) rebalanceAfterRemove(gap pmaGapIterator) pmaGapIterator {
p.children[p.nrSegments] = nil
p.nrSegments--
+ if pmatrackGaps != 0 {
+ left.updateMaxGapLocal()
+ }
+
n = p
}
}
+// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no
+// necessary update.
+//
+// Preconditions: n must be a leaf node, trackGaps must be 1.
+func (n *pmanode) updateMaxGapLeaf() {
+ if n.hasChildren {
+ panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n))
+ }
+ max := n.calculateMaxGapLeaf()
+ if max == n.maxGap.Get() {
+
+ return
+ }
+ oldMax := n.maxGap.Get()
+ n.maxGap.Set(max)
+ if max > oldMax {
+
+ for p := n.parent; p != nil; p = p.parent {
+ if p.maxGap.Get() >= max {
+
+ break
+ }
+
+ p.maxGap.Set(max)
+ }
+ return
+ }
+
+ for p := n.parent; p != nil; p = p.parent {
+ if p.maxGap.Get() > oldMax {
+
+ break
+ }
+
+ parentNewMax := p.calculateMaxGapInternal()
+ if p.maxGap.Get() == parentNewMax {
+
+ break
+ }
+
+ p.maxGap.Set(parentNewMax)
+ }
+}
+
+// updateMaxGapLocal updates maxGap of the calling node solely with no
+// propagation to ancestor nodes.
+//
+// Precondition: trackGaps must be 1.
+func (n *pmanode) updateMaxGapLocal() {
+ if !n.hasChildren {
+
+ n.maxGap.Set(n.calculateMaxGapLeaf())
+ } else {
+
+ n.maxGap.Set(n.calculateMaxGapInternal())
+ }
+}
+
+// calculateMaxGapLeaf iterates the gaps within a leaf node and calculate the
+// max.
+//
+// Preconditions: n must be a leaf node.
+func (n *pmanode) calculateMaxGapLeaf() __generics_imported0.Addr {
+ max := pmaGapIterator{n, 0}.Range().Length()
+ for i := 1; i <= n.nrSegments; i++ {
+ if current := (pmaGapIterator{n, i}).Range().Length(); current > max {
+ max = current
+ }
+ }
+ return max
+}
+
+// calculateMaxGapInternal iterates children's maxGap within an internal node n
+// and calculate the max.
+//
+// Preconditions: n must be a non-leaf node.
+func (n *pmanode) calculateMaxGapInternal() __generics_imported0.Addr {
+ max := n.children[0].maxGap.Get()
+ for i := 1; i <= n.nrSegments; i++ {
+ if current := n.children[i].maxGap.Get(); current > max {
+ max = current
+ }
+ }
+ return max
+}
+
+// searchFirstLargeEnoughGap returns the first gap having at least minSize length
+// in the subtree rooted by n. If not found, return a terminal gap iterator.
+func (n *pmanode) searchFirstLargeEnoughGap(minSize __generics_imported0.Addr) pmaGapIterator {
+ if n.maxGap.Get() < minSize {
+ return pmaGapIterator{}
+ }
+ if n.hasChildren {
+ for i := 0; i <= n.nrSegments; i++ {
+ if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ }
+ } else {
+ for i := 0; i <= n.nrSegments; i++ {
+ currentGap := pmaGapIterator{n, i}
+ if currentGap.Range().Length() >= minSize {
+ return currentGap
+ }
+ }
+ }
+ panic(fmt.Sprintf("invalid maxGap in %v", n))
+}
+
+// searchLastLargeEnoughGap returns the last gap having at least minSize length
+// in the subtree rooted by n. If not found, return a terminal gap iterator.
+func (n *pmanode) searchLastLargeEnoughGap(minSize __generics_imported0.Addr) pmaGapIterator {
+ if n.maxGap.Get() < minSize {
+ return pmaGapIterator{}
+ }
+ if n.hasChildren {
+ for i := n.nrSegments; i >= 0; i-- {
+ if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ }
+ } else {
+ for i := n.nrSegments; i >= 0; i-- {
+ currentGap := pmaGapIterator{n, i}
+ if currentGap.Range().Length() >= minSize {
+ return currentGap
+ }
+ }
+ }
+ panic(fmt.Sprintf("invalid maxGap in %v", n))
+}
+
// A Iterator is conceptually one of:
//
// - A pointer to a segment in a set; or
@@ -1145,6 +1358,114 @@ func (gap pmaGapIterator) NextGap() pmaGapIterator {
return seg.NextGap()
}
+// NextLargeEnoughGap returns the iterated gap's first next gap with larger
+// length than minSize. If not found, return a terminal gap iterator (does NOT
+// include this gap itself).
+//
+// Precondition: trackGaps must be 1.
+func (gap pmaGapIterator) NextLargeEnoughGap(minSize __generics_imported0.Addr) pmaGapIterator {
+ if pmatrackGaps != 1 {
+ panic("set is not tracking gaps")
+ }
+ if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments {
+
+ gap.node = gap.NextSegment().node
+ gap.index = 0
+ return gap.nextLargeEnoughGapHelper(minSize)
+ }
+ return gap.nextLargeEnoughGapHelper(minSize)
+}
+
+// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap
+// to do the real recursions.
+//
+// Preconditions: gap is NOT the trailing gap of a non-leaf node.
+func (gap pmaGapIterator) nextLargeEnoughGapHelper(minSize __generics_imported0.Addr) pmaGapIterator {
+
+ for gap.node != nil &&
+ (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) {
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+
+ if gap.node == nil {
+ return pmaGapIterator{}
+ }
+
+ gap.index++
+ for gap.index <= gap.node.nrSegments {
+ if gap.node.hasChildren {
+ if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ } else {
+ if gap.Range().Length() >= minSize {
+ return gap
+ }
+ }
+ gap.index++
+ }
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ if gap.node != nil && gap.index == gap.node.nrSegments {
+
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ return gap.nextLargeEnoughGapHelper(minSize)
+}
+
+// PrevLargeEnoughGap returns the iterated gap's first prev gap with larger or
+// equal length than minSize. If not found, return a terminal gap iterator
+// (does NOT include this gap itself).
+//
+// Precondition: trackGaps must be 1.
+func (gap pmaGapIterator) PrevLargeEnoughGap(minSize __generics_imported0.Addr) pmaGapIterator {
+ if pmatrackGaps != 1 {
+ panic("set is not tracking gaps")
+ }
+ if gap.node != nil && gap.node.hasChildren && gap.index == 0 {
+
+ gap.node = gap.PrevSegment().node
+ gap.index = gap.node.nrSegments
+ return gap.prevLargeEnoughGapHelper(minSize)
+ }
+ return gap.prevLargeEnoughGapHelper(minSize)
+}
+
+// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap
+// to do the real recursions.
+//
+// Preconditions: gap is NOT the first gap of a non-leaf node.
+func (gap pmaGapIterator) prevLargeEnoughGapHelper(minSize __generics_imported0.Addr) pmaGapIterator {
+
+ for gap.node != nil &&
+ (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) {
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+
+ if gap.node == nil {
+ return pmaGapIterator{}
+ }
+
+ gap.index--
+ for gap.index >= 0 {
+ if gap.node.hasChildren {
+ if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ } else {
+ if gap.Range().Length() >= minSize {
+ return gap
+ }
+ }
+ gap.index--
+ }
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ if gap.node != nil && gap.index == 0 {
+
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ return gap.prevLargeEnoughGapHelper(minSize)
+}
+
// segmentBeforePosition returns the predecessor segment of the position given
// by n.children[i], which may or may not contain a child. If no such segment
// exists, segmentBeforePosition returns a terminal iterator.
@@ -1211,7 +1532,15 @@ func (n *pmanode) writeDebugString(buf *bytes.Buffer, prefix string) {
child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i))
}
buf.WriteString(prefix)
- buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ if n.hasChildren {
+ if pmatrackGaps != 0 {
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get()))
+ } else {
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ } else {
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
}
if child := n.children[n.nrSegments]; child != nil {
child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments))
@@ -1263,6 +1592,46 @@ func (s *pmaSet) ImportSortedSlices(sds *pmaSegmentDataSlices) error {
}
return nil
}
+
+// segmentTestCheck returns an error if s is incorrectly sorted, does not
+// contain exactly expectedSegments segments, or contains a segment which
+// fails the passed check.
+//
+// This should be used only for testing, and has been added to this package for
+// templating convenience.
+func (s *pmaSet) segmentTestCheck(expectedSegments int, segFunc func(int, __generics_imported0.AddrRange, pma) error) error {
+ havePrev := false
+ prev := __generics_imported0.Addr(0)
+ nrSegments := 0
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ next := seg.Start()
+ if havePrev && prev >= next {
+ return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments)
+ }
+ if segFunc != nil {
+ if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil {
+ return err
+ }
+ }
+ prev = next
+ havePrev = true
+ nrSegments++
+ }
+ if nrSegments != expectedSegments {
+ return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments)
+ }
+ return nil
+}
+
+// countSegments counts the number of segments in the set.
+//
+// Similar to Check, this should only be used for testing.
+func (s *pmaSet) countSegments() (segments int) {
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ segments++
+ }
+ return segments
+}
func (s *pmaSet) saveRoot() *pmaSegmentDataSlices {
return s.ExportSortedSlices()
}