summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/memmap/mapping_set_impl.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/memmap/mapping_set_impl.go')
-rw-r--r--[-rwxr-xr-x]pkg/sentry/memmap/mapping_set_impl.go377
1 files changed, 373 insertions, 4 deletions
diff --git a/pkg/sentry/memmap/mapping_set_impl.go b/pkg/sentry/memmap/mapping_set_impl.go
index e632f28a5..cb4281950 100755..100644
--- a/pkg/sentry/memmap/mapping_set_impl.go
+++ b/pkg/sentry/memmap/mapping_set_impl.go
@@ -5,6 +5,34 @@ import (
"fmt"
)
+// trackGaps is an optional parameter.
+//
+// If trackGaps is 1, the Set will track maximum gap size recursively,
+// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this
+// case, Key must be an unsigned integer.
+//
+// trackGaps must be 0 or 1.
+const MappingtrackGaps = 0
+
+var _ = uint8(MappingtrackGaps << 7) // Will fail if not zero or one.
+
+// dynamicGap is a type that disappears if trackGaps is 0.
+type MappingdynamicGap [MappingtrackGaps]uint64
+
+// Get returns the value of the gap.
+//
+// Precondition: trackGaps must be non-zero.
+func (d *MappingdynamicGap) Get() uint64 {
+ return d[:][0]
+}
+
+// Set sets the value of the gap.
+//
+// Precondition: trackGaps must be non-zero.
+func (d *MappingdynamicGap) Set(v uint64) {
+ d[:][0] = v
+}
+
const (
// minDegree is the minimum degree of an internal node in a Set B-tree.
//
@@ -263,8 +291,12 @@ func (s *MappingSet) Insert(gap MappingGapIterator, r MappableRange, val Mapping
}
if prev.Ok() && prev.End() == r.Start {
if mval, ok := (mappingSetFunctions{}).Merge(prev.Range(), prev.Value(), r, val); ok {
+ shrinkMaxGap := MappingtrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get()
prev.SetEndUnchecked(r.End)
prev.SetValue(mval)
+ if shrinkMaxGap {
+ gap.node.updateMaxGapLeaf()
+ }
if next.Ok() && next.Start() == r.End {
val = mval
if mval, ok := (mappingSetFunctions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok {
@@ -278,11 +310,16 @@ func (s *MappingSet) Insert(gap MappingGapIterator, r MappableRange, val Mapping
}
if next.Ok() && next.Start() == r.End {
if mval, ok := (mappingSetFunctions{}).Merge(r, val, next.Range(), next.Value()); ok {
+ shrinkMaxGap := MappingtrackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get()
next.SetStartUnchecked(r.Start)
next.SetValue(mval)
+ if shrinkMaxGap {
+ gap.node.updateMaxGapLeaf()
+ }
return next
}
}
+
return s.InsertWithoutMergingUnchecked(gap, r, val)
}
@@ -309,11 +346,15 @@ func (s *MappingSet) InsertWithoutMerging(gap MappingGapIterator, r MappableRang
// Preconditions: r.Start >= gap.Start(); r.End <= gap.End().
func (s *MappingSet) InsertWithoutMergingUnchecked(gap MappingGapIterator, r MappableRange, val MappingsOfRange) MappingIterator {
gap = gap.node.rebalanceBeforeInsert(gap)
+ splitMaxGap := MappingtrackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get())
copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments])
copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments])
gap.node.keys[gap.index] = r
gap.node.values[gap.index] = val
gap.node.nrSegments++
+ if splitMaxGap {
+ gap.node.updateMaxGapLeaf()
+ }
return MappingIterator{gap.node, gap.index}
}
@@ -328,12 +369,20 @@ func (s *MappingSet) Remove(seg MappingIterator) MappingGapIterator {
seg.SetRangeUnchecked(victim.Range())
seg.SetValue(victim.Value())
+
+ nextAdjacentNode := seg.NextSegment().node
+ if MappingtrackGaps != 0 {
+ nextAdjacentNode.updateMaxGapLeaf()
+ }
return s.Remove(victim).NextGap()
}
copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments])
copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments])
mappingSetFunctions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1])
seg.node.nrSegments--
+ if MappingtrackGaps != 0 {
+ seg.node.updateMaxGapLeaf()
+ }
return seg.node.rebalanceAfterRemove(MappingGapIterator{seg.node, seg.index})
}
@@ -383,6 +432,7 @@ func (s *MappingSet) MergeUnchecked(first, second MappingIterator) MappingIterat
first.SetEndUnchecked(second.End())
first.SetValue(mval)
+
return s.Remove(second).PrevSegment()
}
}
@@ -558,6 +608,12 @@ type Mappingnode struct {
// than "isLeaf" because false must be the correct value for an empty root.
hasChildren bool
+ // The longest gap within this node. If the node is a leaf, it's simply the
+ // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys
+ // including the 0th and nrSegments-th gap possibly shared with its upper-level
+ // nodes; if it's a non-leaf node, it's the max of all children's maxGap.
+ maxGap MappingdynamicGap
+
// Nodes store keys and values in separate arrays to maximize locality in
// the common case (scanning keys for lookup).
keys [MappingmaxDegree - 1]MappableRange
@@ -603,12 +659,12 @@ func (n *Mappingnode) nextSibling() *Mappingnode {
// required for insertion, and returns an updated iterator to the position
// represented by gap.
func (n *Mappingnode) rebalanceBeforeInsert(gap MappingGapIterator) MappingGapIterator {
- if n.parent != nil {
- gap = n.parent.rebalanceBeforeInsert(gap)
- }
if n.nrSegments < MappingmaxDegree-1 {
return gap
}
+ if n.parent != nil {
+ gap = n.parent.rebalanceBeforeInsert(gap)
+ }
if n.parent == nil {
left := &Mappingnode{
@@ -644,6 +700,11 @@ func (n *Mappingnode) rebalanceBeforeInsert(gap MappingGapIterator) MappingGapIt
n.hasChildren = true
n.children[0] = left
n.children[1] = right
+
+ if MappingtrackGaps != 0 {
+ left.updateMaxGapLocal()
+ right.updateMaxGapLocal()
+ }
if gap.node != n {
return gap
}
@@ -681,6 +742,11 @@ func (n *Mappingnode) rebalanceBeforeInsert(gap MappingGapIterator) MappingGapIt
}
n.nrSegments = MappingminDegree - 1
+ if MappingtrackGaps != 0 {
+ n.updateMaxGapLocal()
+ sibling.updateMaxGapLocal()
+ }
+
if gap.node != n {
return gap
}
@@ -726,6 +792,11 @@ func (n *Mappingnode) rebalanceAfterRemove(gap MappingGapIterator) MappingGapIte
}
n.nrSegments++
sibling.nrSegments--
+
+ if MappingtrackGaps != 0 {
+ n.updateMaxGapLocal()
+ sibling.updateMaxGapLocal()
+ }
if gap.node == sibling && gap.index == sibling.nrSegments {
return MappingGapIterator{n, 0}
}
@@ -754,6 +825,11 @@ func (n *Mappingnode) rebalanceAfterRemove(gap MappingGapIterator) MappingGapIte
}
n.nrSegments++
sibling.nrSegments--
+
+ if MappingtrackGaps != 0 {
+ n.updateMaxGapLocal()
+ sibling.updateMaxGapLocal()
+ }
if gap.node == sibling {
if gap.index == 0 {
return MappingGapIterator{n, n.nrSegments}
@@ -786,6 +862,7 @@ func (n *Mappingnode) rebalanceAfterRemove(gap MappingGapIterator) MappingGapIte
p.children[0] = nil
p.children[1] = nil
}
+
if gap.node == left {
return MappingGapIterator{p, gap.index}
}
@@ -832,10 +909,146 @@ func (n *Mappingnode) rebalanceAfterRemove(gap MappingGapIterator) MappingGapIte
p.children[p.nrSegments] = nil
p.nrSegments--
+ if MappingtrackGaps != 0 {
+ left.updateMaxGapLocal()
+ }
+
n = p
}
}
+// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no
+// necessary update.
+//
+// Preconditions: n must be a leaf node, trackGaps must be 1.
+func (n *Mappingnode) updateMaxGapLeaf() {
+ if n.hasChildren {
+ panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n))
+ }
+ max := n.calculateMaxGapLeaf()
+ if max == n.maxGap.Get() {
+
+ return
+ }
+ oldMax := n.maxGap.Get()
+ n.maxGap.Set(max)
+ if max > oldMax {
+
+ for p := n.parent; p != nil; p = p.parent {
+ if p.maxGap.Get() >= max {
+
+ break
+ }
+
+ p.maxGap.Set(max)
+ }
+ return
+ }
+
+ for p := n.parent; p != nil; p = p.parent {
+ if p.maxGap.Get() > oldMax {
+
+ break
+ }
+
+ parentNewMax := p.calculateMaxGapInternal()
+ if p.maxGap.Get() == parentNewMax {
+
+ break
+ }
+
+ p.maxGap.Set(parentNewMax)
+ }
+}
+
+// updateMaxGapLocal updates maxGap of the calling node solely with no
+// propagation to ancestor nodes.
+//
+// Precondition: trackGaps must be 1.
+func (n *Mappingnode) updateMaxGapLocal() {
+ if !n.hasChildren {
+
+ n.maxGap.Set(n.calculateMaxGapLeaf())
+ } else {
+
+ n.maxGap.Set(n.calculateMaxGapInternal())
+ }
+}
+
+// calculateMaxGapLeaf iterates the gaps within a leaf node and calculate the
+// max.
+//
+// Preconditions: n must be a leaf node.
+func (n *Mappingnode) calculateMaxGapLeaf() uint64 {
+ max := MappingGapIterator{n, 0}.Range().Length()
+ for i := 1; i <= n.nrSegments; i++ {
+ if current := (MappingGapIterator{n, i}).Range().Length(); current > max {
+ max = current
+ }
+ }
+ return max
+}
+
+// calculateMaxGapInternal iterates children's maxGap within an internal node n
+// and calculate the max.
+//
+// Preconditions: n must be a non-leaf node.
+func (n *Mappingnode) calculateMaxGapInternal() uint64 {
+ max := n.children[0].maxGap.Get()
+ for i := 1; i <= n.nrSegments; i++ {
+ if current := n.children[i].maxGap.Get(); current > max {
+ max = current
+ }
+ }
+ return max
+}
+
+// searchFirstLargeEnoughGap returns the first gap having at least minSize length
+// in the subtree rooted by n. If not found, return a terminal gap iterator.
+func (n *Mappingnode) searchFirstLargeEnoughGap(minSize uint64) MappingGapIterator {
+ if n.maxGap.Get() < minSize {
+ return MappingGapIterator{}
+ }
+ if n.hasChildren {
+ for i := 0; i <= n.nrSegments; i++ {
+ if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ }
+ } else {
+ for i := 0; i <= n.nrSegments; i++ {
+ currentGap := MappingGapIterator{n, i}
+ if currentGap.Range().Length() >= minSize {
+ return currentGap
+ }
+ }
+ }
+ panic(fmt.Sprintf("invalid maxGap in %v", n))
+}
+
+// searchLastLargeEnoughGap returns the last gap having at least minSize length
+// in the subtree rooted by n. If not found, return a terminal gap iterator.
+func (n *Mappingnode) searchLastLargeEnoughGap(minSize uint64) MappingGapIterator {
+ if n.maxGap.Get() < minSize {
+ return MappingGapIterator{}
+ }
+ if n.hasChildren {
+ for i := n.nrSegments; i >= 0; i-- {
+ if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ }
+ } else {
+ for i := n.nrSegments; i >= 0; i-- {
+ currentGap := MappingGapIterator{n, i}
+ if currentGap.Range().Length() >= minSize {
+ return currentGap
+ }
+ }
+ }
+ panic(fmt.Sprintf("invalid maxGap in %v", n))
+}
+
// A Iterator is conceptually one of:
//
// - A pointer to a segment in a set; or
@@ -1141,6 +1354,114 @@ func (gap MappingGapIterator) NextGap() MappingGapIterator {
return seg.NextGap()
}
+// NextLargeEnoughGap returns the iterated gap's first next gap with larger
+// length than minSize. If not found, return a terminal gap iterator (does NOT
+// include this gap itself).
+//
+// Precondition: trackGaps must be 1.
+func (gap MappingGapIterator) NextLargeEnoughGap(minSize uint64) MappingGapIterator {
+ if MappingtrackGaps != 1 {
+ panic("set is not tracking gaps")
+ }
+ if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments {
+
+ gap.node = gap.NextSegment().node
+ gap.index = 0
+ return gap.nextLargeEnoughGapHelper(minSize)
+ }
+ return gap.nextLargeEnoughGapHelper(minSize)
+}
+
+// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap
+// to do the real recursions.
+//
+// Preconditions: gap is NOT the trailing gap of a non-leaf node.
+func (gap MappingGapIterator) nextLargeEnoughGapHelper(minSize uint64) MappingGapIterator {
+
+ for gap.node != nil &&
+ (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) {
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+
+ if gap.node == nil {
+ return MappingGapIterator{}
+ }
+
+ gap.index++
+ for gap.index <= gap.node.nrSegments {
+ if gap.node.hasChildren {
+ if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ } else {
+ if gap.Range().Length() >= minSize {
+ return gap
+ }
+ }
+ gap.index++
+ }
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ if gap.node != nil && gap.index == gap.node.nrSegments {
+
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ return gap.nextLargeEnoughGapHelper(minSize)
+}
+
+// PrevLargeEnoughGap returns the iterated gap's first prev gap with larger or
+// equal length than minSize. If not found, return a terminal gap iterator
+// (does NOT include this gap itself).
+//
+// Precondition: trackGaps must be 1.
+func (gap MappingGapIterator) PrevLargeEnoughGap(minSize uint64) MappingGapIterator {
+ if MappingtrackGaps != 1 {
+ panic("set is not tracking gaps")
+ }
+ if gap.node != nil && gap.node.hasChildren && gap.index == 0 {
+
+ gap.node = gap.PrevSegment().node
+ gap.index = gap.node.nrSegments
+ return gap.prevLargeEnoughGapHelper(minSize)
+ }
+ return gap.prevLargeEnoughGapHelper(minSize)
+}
+
+// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap
+// to do the real recursions.
+//
+// Preconditions: gap is NOT the first gap of a non-leaf node.
+func (gap MappingGapIterator) prevLargeEnoughGapHelper(minSize uint64) MappingGapIterator {
+
+ for gap.node != nil &&
+ (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) {
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+
+ if gap.node == nil {
+ return MappingGapIterator{}
+ }
+
+ gap.index--
+ for gap.index >= 0 {
+ if gap.node.hasChildren {
+ if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ } else {
+ if gap.Range().Length() >= minSize {
+ return gap
+ }
+ }
+ gap.index--
+ }
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ if gap.node != nil && gap.index == 0 {
+
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ return gap.prevLargeEnoughGapHelper(minSize)
+}
+
// segmentBeforePosition returns the predecessor segment of the position given
// by n.children[i], which may or may not contain a child. If no such segment
// exists, segmentBeforePosition returns a terminal iterator.
@@ -1207,7 +1528,15 @@ func (n *Mappingnode) writeDebugString(buf *bytes.Buffer, prefix string) {
child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i))
}
buf.WriteString(prefix)
- buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ if n.hasChildren {
+ if MappingtrackGaps != 0 {
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get()))
+ } else {
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ } else {
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
}
if child := n.children[n.nrSegments]; child != nil {
child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments))
@@ -1259,6 +1588,46 @@ func (s *MappingSet) ImportSortedSlices(sds *MappingSegmentDataSlices) error {
}
return nil
}
+
+// segmentTestCheck returns an error if s is incorrectly sorted, does not
+// contain exactly expectedSegments segments, or contains a segment which
+// fails the passed check.
+//
+// This should be used only for testing, and has been added to this package for
+// templating convenience.
+func (s *MappingSet) segmentTestCheck(expectedSegments int, segFunc func(int, MappableRange, MappingsOfRange) error) error {
+ havePrev := false
+ prev := uint64(0)
+ nrSegments := 0
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ next := seg.Start()
+ if havePrev && prev >= next {
+ return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments)
+ }
+ if segFunc != nil {
+ if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil {
+ return err
+ }
+ }
+ prev = next
+ havePrev = true
+ nrSegments++
+ }
+ if nrSegments != expectedSegments {
+ return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments)
+ }
+ return nil
+}
+
+// countSegments counts the number of segments in the set.
+//
+// Similar to Check, this should only be used for testing.
+func (s *MappingSet) countSegments() (segments int) {
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ segments++
+ }
+ return segments
+}
func (s *MappingSet) saveRoot() *MappingSegmentDataSlices {
return s.ExportSortedSlices()
}