Diffstat (limited to 'pkg')
-rw-r--r--pkg/goid/BUILD1
-rw-r--r--pkg/goid/goid_arm64.s (renamed from pkg/sentry/fsimpl/gofer/pagemath.go)24
-rw-r--r--pkg/linewriter/BUILD2
-rw-r--r--pkg/log/BUILD2
-rw-r--r--pkg/segment/BUILD2
-rw-r--r--pkg/segment/set.go400
-rw-r--r--pkg/segment/test/BUILD18
-rw-r--r--pkg/segment/test/segment_test.go397
-rw-r--r--pkg/segment/test/set_functions.go32
-rw-r--r--pkg/sentry/arch/syscalls_arm64.go2
-rw-r--r--pkg/sentry/fs/fsutil/frame_ref_set.go40
-rw-r--r--pkg/sentry/fs/g3doc/fuse.md218
-rw-r--r--pkg/sentry/fsimpl/devpts/line_discipline.go4
-rw-r--r--pkg/sentry/fsimpl/devpts/master.go4
-rw-r--r--pkg/sentry/fsimpl/devpts/queue.go4
-rw-r--r--pkg/sentry/fsimpl/devpts/slave.go4
-rw-r--r--pkg/sentry/fsimpl/devpts/terminal.go4
-rw-r--r--pkg/sentry/fsimpl/gofer/BUILD1
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go8
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go47
-rw-r--r--pkg/sentry/fsimpl/host/BUILD4
-rw-r--r--pkg/sentry/fsimpl/host/host.go132
-rw-r--r--pkg/sentry/fsimpl/host/mmap.go132
-rw-r--r--pkg/sentry/fsimpl/tmpfs/BUILD3
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go46
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go42
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file_test.go138
-rw-r--r--pkg/sentry/fsimpl/tmpfs/stat_test.go4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go38
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs_test.go156
-rw-r--r--pkg/sentry/kernel/BUILD1
-rw-r--r--pkg/sentry/kernel/kernel.go25
-rw-r--r--pkg/sentry/kernel/pipe/pipe_util.go2
-rw-r--r--pkg/sentry/kernel/task_syscall.go4
-rw-r--r--pkg/sentry/mm/BUILD1
-rw-r--r--pkg/sentry/mm/vma.go4
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.go6
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.s10
-rw-r--r--pkg/sentry/socket/hostinet/socket.go2
-rw-r--r--pkg/sentry/socket/netstack/netstack.go45
-rw-r--r--pkg/sentry/syscalls/linux/sys_splice.go2
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/BUILD2
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fd.go10
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/memfd.go63
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/setstat.go123
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/vfs2.go4
-rw-r--r--pkg/state/BUILD1
-rw-r--r--pkg/sync/BUILD2
-rw-r--r--pkg/tcpip/header/tcp.go29
-rw-r--r--pkg/tcpip/stack/stack.go4
-rw-r--r--pkg/tcpip/tcpip.go90
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go66
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go68
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go52
-rw-r--r--pkg/tcpip/transport/tcp/snd.go59
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go116
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go32
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go18
-rw-r--r--pkg/test/dockerutil/dockerutil.go116
-rw-r--r--pkg/usermem/addr.go17
60 files changed, 2263 insertions, 620 deletions
diff --git a/pkg/goid/BUILD b/pkg/goid/BUILD
index ea8d2422c..7a82631c5 100644
--- a/pkg/goid/BUILD
+++ b/pkg/goid/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"goid.go",
"goid_amd64.s",
+ "goid_arm64.s",
"goid_race.go",
"goid_unsafe.go",
],
diff --git a/pkg/sentry/fsimpl/gofer/pagemath.go b/pkg/goid/goid_arm64.s
index 847cb0784..a7465b75d 100644
--- a/pkg/sentry/fsimpl/gofer/pagemath.go
+++ b/pkg/goid/goid_arm64.s
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,20 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package gofer
+#include "textflag.h"
-import (
- "gvisor.dev/gvisor/pkg/usermem"
-)
-
-// This are equivalent to usermem.Addr.RoundDown/Up, but without the
-// potentially truncating conversion to usermem.Addr. This is necessary because
-// there is no way to define generic "PageRoundDown/Up" functions in Go.
-
-func pageRoundDown(x uint64) uint64 {
- return x &^ (usermem.PageSize - 1)
-}
-
-func pageRoundUp(x uint64) uint64 {
- return pageRoundDown(x + usermem.PageSize - 1)
-}
+// func getg() *g
+TEXT ·getg(SB),NOSPLIT,$0-8
+ MOVD g, R0 // g
+ MOVD R0, ret+0(FP)
+ RET
diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD
index 41bf104d0..f84d03700 100644
--- a/pkg/linewriter/BUILD
+++ b/pkg/linewriter/BUILD
@@ -5,6 +5,8 @@ package(licenses = ["notice"])
go_library(
name = "linewriter",
srcs = ["linewriter.go"],
+ marshal = False,
+ stateify = False,
visibility = ["//visibility:public"],
deps = ["//pkg/sync"],
)
diff --git a/pkg/log/BUILD b/pkg/log/BUILD
index a7c8f7bef..3ed6aba5c 100644
--- a/pkg/log/BUILD
+++ b/pkg/log/BUILD
@@ -10,6 +10,8 @@ go_library(
"json_k8s.go",
"log.go",
],
+ marshal = False,
+ stateify = False,
visibility = [
"//visibility:public",
],
diff --git a/pkg/segment/BUILD b/pkg/segment/BUILD
index 1b487b887..f57ccc170 100644
--- a/pkg/segment/BUILD
+++ b/pkg/segment/BUILD
@@ -21,6 +21,8 @@ go_template(
],
opt_consts = [
"minDegree",
+ # trackGaps must either be 0 or 1.
+ "trackGaps",
],
types = [
"Key",
diff --git a/pkg/segment/set.go b/pkg/segment/set.go
index 03e4f258f..1a17ad9cb 100644
--- a/pkg/segment/set.go
+++ b/pkg/segment/set.go
@@ -36,6 +36,34 @@ type Range interface{}
// Value is a required type parameter.
type Value interface{}
+// trackGaps is an optional parameter.
+//
+// If trackGaps is 1, the Set will track maximum gap size recursively,
+// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this
+// case, Key must be an unsigned integer.
+//
+// trackGaps must be 0 or 1.
+const trackGaps = 0
+
+var _ = uint8(trackGaps << 7) // Will fail if not zero or one.
+
+// dynamicGap is a type that disappears if trackGaps is 0.
+type dynamicGap [trackGaps]Key
+
+// Get returns the value of the gap.
+//
+// Precondition: trackGaps must be non-zero.
+func (d *dynamicGap) Get() Key {
+ return d[:][0]
+}
+
+// Set sets the value of the gap.
+//
+// Precondition: trackGaps must be non-zero.
+func (d *dynamicGap) Set(v Key) {
+ d[:][0] = v
+}
+
// Functions is a required type parameter that must be a struct implementing
// the methods defined by Functions.
type Functions interface {
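Review bot: `Get` and `Set` index through `d[:][0]`, and nothing in the comments says why: slicing first keeps the constant index from being rejected at compile time when `trackGaps` is 0 and the array has length zero, so the stated precondition is enforced only by a runtime index panic. A one-line comment such as `// d[:][0] rather than d[0] so this still compiles when trackGaps == 0; calling it then panics.` would record that. Separately, the comment on `trackGaps` requires an unsigned `Key`, while the `gap_set` instance added in pkg/segment/test/BUILD uses `"Key": "int"`; either the instance or the requirement should change, or the comment should say what goes wrong with signed keys.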
@@ -327,8 +355,12 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator {
}
if prev.Ok() && prev.End() == r.Start {
if mval, ok := (Functions{}).Merge(prev.Range(), prev.Value(), r, val); ok {
+ shrinkMaxGap := trackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get()
prev.SetEndUnchecked(r.End)
prev.SetValue(mval)
+ if shrinkMaxGap {
+ gap.node.updateMaxGapLeaf()
+ }
if next.Ok() && next.Start() == r.End {
val = mval
if mval, ok := (Functions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok {
@@ -342,11 +374,16 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator {
}
if next.Ok() && next.Start() == r.End {
if mval, ok := (Functions{}).Merge(r, val, next.Range(), next.Value()); ok {
+ shrinkMaxGap := trackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get()
next.SetStartUnchecked(r.Start)
next.SetValue(mval)
+ if shrinkMaxGap {
+ gap.node.updateMaxGapLeaf()
+ }
return next
}
}
+ // InsertWithoutMergingUnchecked will maintain maxGap if necessary.
return s.InsertWithoutMergingUnchecked(gap, r, val)
}
@@ -373,11 +410,15 @@ func (s *Set) InsertWithoutMerging(gap GapIterator, r Range, val Value) Iterator
// Preconditions: r.Start >= gap.Start(); r.End <= gap.End().
func (s *Set) InsertWithoutMergingUnchecked(gap GapIterator, r Range, val Value) Iterator {
gap = gap.node.rebalanceBeforeInsert(gap)
+ splitMaxGap := trackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get())
copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments])
copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments])
gap.node.keys[gap.index] = r
gap.node.values[gap.index] = val
gap.node.nrSegments++
+ if splitMaxGap {
+ gap.node.updateMaxGapLeaf()
+ }
return Iterator{gap.node, gap.index}
}
@@ -399,12 +440,23 @@ func (s *Set) Remove(seg Iterator) GapIterator {
// overlap.
seg.SetRangeUnchecked(victim.Range())
seg.SetValue(victim.Value())
+ // Need to update the nextAdjacentNode's maxGap because the gap in between
+ // must have been modified by updating seg.Range() to victim.Range().
+ // seg.NextSegment() must exist since the last segment can't be in a
+ // non-leaf node.
+ nextAdjacentNode := seg.NextSegment().node
+ if trackGaps != 0 {
+ nextAdjacentNode.updateMaxGapLeaf()
+ }
return s.Remove(victim).NextGap()
}
copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments])
copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments])
Functions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1])
seg.node.nrSegments--
+ if trackGaps != 0 {
+ seg.node.updateMaxGapLeaf()
+ }
return seg.node.rebalanceAfterRemove(GapIterator{seg.node, seg.index})
}
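Review bot: `nextAdjacentNode := seg.NextSegment().node` runs unconditionally, so instantiations with `trackGaps == 0` now pay for a successor walk in this removal path although the result is used only inside `if trackGaps != 0`. Moving the lookup inside the guard, e.g. `if trackGaps != 0 { seg.NextSegment().node.updateMaxGapLeaf() }`, leaves non-tracking sets as they were. The comment's claim that the next segment always exists is the only thing protecting this call from a terminal iterator with a nil node, so it should stay directly above the call. Remove's body starts outside the hunk.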
@@ -455,6 +507,7 @@ func (s *Set) MergeUnchecked(first, second Iterator) Iterator {
// overlaps second.
first.SetEndUnchecked(second.End())
first.SetValue(mval)
+ // Remove will handle the maxGap update if necessary.
return s.Remove(second).PrevSegment()
}
}
@@ -631,6 +684,12 @@ type node struct {
// than "isLeaf" because false must be the correct value for an empty root.
hasChildren bool
+ // The longest gap within this node. If the node is a leaf, it's simply the
+ // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys
+ // including the 0th and nrSegments-th gap possibly shared with its upper-level
+ // nodes; if it's a non-leaf node, it's the max of all children's maxGap.
+ maxGap dynamicGap
+
// Nodes store keys and values in separate arrays to maximize locality in
// the common case (scanning keys for lookup).
keys [maxDegree - 1]Range
@@ -676,12 +735,12 @@ func (n *node) nextSibling() *node {
// required for insertion, and returns an updated iterator to the position
// represented by gap.
func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator {
- if n.parent != nil {
- gap = n.parent.rebalanceBeforeInsert(gap)
- }
if n.nrSegments < maxDegree-1 {
return gap
}
+ if n.parent != nil {
+ gap = n.parent.rebalanceBeforeInsert(gap)
+ }
if n.parent == nil {
// n is root. Move all segments before and after n's median segment
// into new child nodes adjacent to the median segment, which is now
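Review bot: Moving the `n.parent` recursion below the `n.nrSegments < maxDegree-1` early return changes when ancestors are split for every Set instantiation, not only gap-tracking ones, and the hunk adds no comment saying why. Presumably the intent is that ancestors are rebalanced only when n itself is about to split and push a segment up; if so, a comment such as `// Ancestors need room only if n is about to split.` above the recursion would record it. The function's body continues outside the hunk.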
@@ -719,6 +778,13 @@ func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator {
n.hasChildren = true
n.children[0] = left
n.children[1] = right
+ // In this case, n's maxGap won't violated as it's still the root,
+ // but the left and right children should be updated locally as they
+ // are newly split from n.
+ if trackGaps != 0 {
+ left.updateMaxGapLocal()
+ right.updateMaxGapLocal()
+ }
if gap.node != n {
return gap
}
@@ -758,6 +824,12 @@ func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator {
}
}
n.nrSegments = minDegree - 1
+ // MaxGap of n's parent is not violated because the segments within is not changed.
+ // n and its sibling's maxGap need to be updated locally as they are two new nodes split from old n.
+ if trackGaps != 0 {
+ n.updateMaxGapLocal()
+ sibling.updateMaxGapLocal()
+ }
// gap.node can't be n.parent because gaps are always in leaf nodes.
if gap.node != n {
return gap
@@ -821,6 +893,12 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator {
}
n.nrSegments++
sibling.nrSegments--
+ // n's parent's maxGap does not need to be updated as its content is unmodified.
+ // n and its sibling must be updated with (new) maxGap because of the shift of keys.
+ if trackGaps != 0 {
+ n.updateMaxGapLocal()
+ sibling.updateMaxGapLocal()
+ }
if gap.node == sibling && gap.index == sibling.nrSegments {
return GapIterator{n, 0}
}
@@ -849,6 +927,12 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator {
}
n.nrSegments++
sibling.nrSegments--
+ // n's parent's maxGap does not need to be updated as its content is unmodified.
+ // n and its sibling must be updated with (new) maxGap because of the shift of keys.
+ if trackGaps != 0 {
+ n.updateMaxGapLocal()
+ sibling.updateMaxGapLocal()
+ }
if gap.node == sibling {
if gap.index == 0 {
return GapIterator{n, n.nrSegments}
@@ -886,6 +970,7 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator {
p.children[0] = nil
p.children[1] = nil
}
+ // No need to update maxGap of p as its content is not changed.
if gap.node == left {
return GapIterator{p, gap.index}
}
@@ -932,11 +1017,152 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator {
}
p.children[p.nrSegments] = nil
p.nrSegments--
+ // Update maxGap of left locally, no need to change p and right because
+ // p's contents is not changed and right is already invalid.
+ if trackGaps != 0 {
+ left.updateMaxGapLocal()
+ }
// This process robs p of one segment, so recurse into rebalancing p.
n = p
}
}
+// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no
+// necessary update.
+//
+// Preconditions: n must be a leaf node, trackGaps must be 1.
+func (n *node) updateMaxGapLeaf() {
+ if n.hasChildren {
+ panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n))
+ }
+ max := n.calculateMaxGapLeaf()
+ if max == n.maxGap.Get() {
+ // If new max equals the old maxGap, no update is needed.
+ return
+ }
+ oldMax := n.maxGap.Get()
+ n.maxGap.Set(max)
+ if max > oldMax {
+ // Grow ancestor maxGaps.
+ for p := n.parent; p != nil; p = p.parent {
+ if p.maxGap.Get() >= max {
+ // p and its ancestors already contain an equal or larger gap.
+ break
+ }
+ // Only if new maxGap is larger than parent's
+ // old maxGap, propagate this update to parent.
+ p.maxGap.Set(max)
+ }
+ return
+ }
+ // Shrink ancestor maxGaps.
+ for p := n.parent; p != nil; p = p.parent {
+ if p.maxGap.Get() > oldMax {
+ // p and its ancestors still contain a larger gap.
+ break
+ }
+ // If new max is smaller than the old maxGap, and this gap used
+ // to be the maxGap of its parent, iterate parent's children
+ // and calculate parent's new maxGap.(It's probable that parent
+ // has two children with the old maxGap, but we need to check it anyway.)
+ parentNewMax := p.calculateMaxGapInternal()
+ if p.maxGap.Get() == parentNewMax {
+ // p and its ancestors still contain a gap of at least equal size.
+ break
+ }
+ // If p's new maxGap differs from the old one, propagate this update.
+ p.maxGap.Set(parentNewMax)
+ }
+}
+
+// updateMaxGapLocal updates maxGap of the calling node solely with no
+// propagation to ancestor nodes.
+//
+// Precondition: trackGaps must be 1.
+func (n *node) updateMaxGapLocal() {
+ if !n.hasChildren {
+ // Leaf node iterates its gaps.
+ n.maxGap.Set(n.calculateMaxGapLeaf())
+ } else {
+ // Non-leaf node iterates its children.
+ n.maxGap.Set(n.calculateMaxGapInternal())
+ }
+}
+
+// calculateMaxGapLeaf iterates the gaps within a leaf node and calculate the
+// max.
+//
+// Preconditions: n must be a leaf node.
+func (n *node) calculateMaxGapLeaf() Key {
+ max := GapIterator{n, 0}.Range().Length()
+ for i := 1; i <= n.nrSegments; i++ {
+ if current := (GapIterator{n, i}).Range().Length(); current > max {
+ max = current
+ }
+ }
+ return max
+}
+
+// calculateMaxGapInternal iterates children's maxGap within an internal node n
+// and calculate the max.
+//
+// Preconditions: n must be a non-leaf node.
+func (n *node) calculateMaxGapInternal() Key {
+ max := n.children[0].maxGap.Get()
+ for i := 1; i <= n.nrSegments; i++ {
+ if current := n.children[i].maxGap.Get(); current > max {
+ max = current
+ }
+ }
+ return max
+}
+
+// searchFirstLargeEnoughGap returns the first gap having at least minSize length
+// in the subtree rooted by n. If not found, return a terminal gap iterator.
+func (n *node) searchFirstLargeEnoughGap(minSize Key) GapIterator {
+ if n.maxGap.Get() < minSize {
+ return GapIterator{}
+ }
+ if n.hasChildren {
+ for i := 0; i <= n.nrSegments; i++ {
+ if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ }
+ } else {
+ for i := 0; i <= n.nrSegments; i++ {
+ currentGap := GapIterator{n, i}
+ if currentGap.Range().Length() >= minSize {
+ return currentGap
+ }
+ }
+ }
+ panic(fmt.Sprintf("invalid maxGap in %v", n))
+}
+
+// searchLastLargeEnoughGap returns the last gap having at least minSize length
+// in the subtree rooted by n. If not found, return a terminal gap iterator.
+func (n *node) searchLastLargeEnoughGap(minSize Key) GapIterator {
+ if n.maxGap.Get() < minSize {
+ return GapIterator{}
+ }
+ if n.hasChildren {
+ for i := n.nrSegments; i >= 0; i-- {
+ if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ }
+ } else {
+ for i := n.nrSegments; i >= 0; i-- {
+ currentGap := GapIterator{n, i}
+ if currentGap.Range().Length() >= minSize {
+ return currentGap
+ }
+ }
+ }
+ panic(fmt.Sprintf("invalid maxGap in %v", n))
+}
+
// A Iterator is conceptually one of:
//
// - A pointer to a segment in a set; or
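Review bot: The doc comments on `searchFirstLargeEnoughGap` and `searchLastLargeEnoughGap` promise a terminal iterator when nothing is found, but that only happens when `n.maxGap.Get() < minSize`; if maxGap claims a large enough gap and the scan finds none, both functions panic. Failing hard on a stale maxGap is defensible, but it belongs in the contract: add `// Panics if n.maxGap is at least minSize but no such gap exists (maxGap is stale).` to both comments. In `updateMaxGapLeaf`, taking `oldMax := n.maxGap.Get()` before the equality check would avoid the second `Get` and let the early return read `if max == oldMax`.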
@@ -1243,6 +1469,122 @@ func (gap GapIterator) NextGap() GapIterator {
return seg.NextGap()
}
+// NextLargeEnoughGap returns the iterated gap's first next gap with larger
+// length than minSize. If not found, return a terminal gap iterator (does NOT
+// include this gap itself).
+//
+// Precondition: trackGaps must be 1.
+func (gap GapIterator) NextLargeEnoughGap(minSize Key) GapIterator {
+ if trackGaps != 1 {
+ panic("set is not tracking gaps")
+ }
+ if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments {
+ // If gap is the trailing gap of an non-leaf node,
+ // translate it to the equivalent gap on leaf level.
+ gap.node = gap.NextSegment().node
+ gap.index = 0
+ return gap.nextLargeEnoughGapHelper(minSize)
+ }
+ return gap.nextLargeEnoughGapHelper(minSize)
+}
+
+// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap
+// to do the real recursions.
+//
+// Preconditions: gap is NOT the trailing gap of a non-leaf node.
+func (gap GapIterator) nextLargeEnoughGapHelper(minSize Key) GapIterator {
+ // Crawl up the tree if no large enough gap in current node or the
+ // current gap is the trailing one on leaf level.
+ for gap.node != nil &&
+ (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) {
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ // If no large enough gap throughout the whole set, return a terminal
+ // gap iterator.
+ if gap.node == nil {
+ return GapIterator{}
+ }
+ // Iterate subsequent gaps.
+ gap.index++
+ for gap.index <= gap.node.nrSegments {
+ if gap.node.hasChildren {
+ if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ } else {
+ if gap.Range().Length() >= minSize {
+ return gap
+ }
+ }
+ gap.index++
+ }
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ if gap.node != nil && gap.index == gap.node.nrSegments {
+ // If gap is the trailing gap of a non-leaf node, crawl up to
+ // parent again and do recursion.
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ return gap.nextLargeEnoughGapHelper(minSize)
+}
+
+// PrevLargeEnoughGap returns the iterated gap's first prev gap with larger or
+// equal length than minSize. If not found, return a terminal gap iterator
+// (does NOT include this gap itself).
+//
+// Precondition: trackGaps must be 1.
+func (gap GapIterator) PrevLargeEnoughGap(minSize Key) GapIterator {
+ if trackGaps != 1 {
+ panic("set is not tracking gaps")
+ }
+ if gap.node != nil && gap.node.hasChildren && gap.index == 0 {
+ // If gap is the first gap of an non-leaf node,
+ // translate it to the equivalent gap on leaf level.
+ gap.node = gap.PrevSegment().node
+ gap.index = gap.node.nrSegments
+ return gap.prevLargeEnoughGapHelper(minSize)
+ }
+ return gap.prevLargeEnoughGapHelper(minSize)
+}
+
+// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap
+// to do the real recursions.
+//
+// Preconditions: gap is NOT the first gap of a non-leaf node.
+func (gap GapIterator) prevLargeEnoughGapHelper(minSize Key) GapIterator {
+ // Crawl up the tree if no large enough gap in current node or the
+ // current gap is the first one on leaf level.
+ for gap.node != nil &&
+ (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) {
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ // If no large enough gap throughout the whole set, return a terminal
+ // gap iterator.
+ if gap.node == nil {
+ return GapIterator{}
+ }
+ // Iterate previous gaps.
+ gap.index--
+ for gap.index >= 0 {
+ if gap.node.hasChildren {
+ if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ } else {
+ if gap.Range().Length() >= minSize {
+ return gap
+ }
+ }
+ gap.index--
+ }
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ if gap.node != nil && gap.index == 0 {
+ // If gap is the first gap of a non-leaf node, crawl up to
+ // parent again and do recursion.
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ return gap.prevLargeEnoughGapHelper(minSize)
+}
+
// segmentBeforePosition returns the predecessor segment of the position given
// by n.children[i], which may or may not contain a child. If no such segment
// exists, segmentBeforePosition returns a terminal iterator.
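Review bot: The comment on `NextLargeEnoughGap` says it returns a gap "with larger length than minSize", but the helper accepts `gap.Range().Length() >= minSize`, which matches the "larger or equal" wording on `PrevLargeEnoughGap`. Suggested first line: `// NextLargeEnoughGap returns the first gap after the iterated gap whose length is at least minSize.` In both exported functions the `if` branch and the fall-through end with the same helper call, so the `return` inside the branch can be dropped.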
@@ -1271,7 +1613,7 @@ func segmentAfterPosition(n *node, i int) Iterator {
func zeroValueSlice(slice []Value) {
// TODO(jamieliu): check if Go is actually smart enough to optimize a
- // ClearValue that assigns nil to a memset here
+ // ClearValue that assigns nil to a memset here.
for i := range slice {
Functions{}.ClearValue(&slice[i])
}
@@ -1310,7 +1652,15 @@ func (n *node) writeDebugString(buf *bytes.Buffer, prefix string) {
child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i))
}
buf.WriteString(prefix)
- buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ if n.hasChildren {
+ if trackGaps != 0 {
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get()))
+ } else {
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ } else {
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
}
if child := n.children[n.nrSegments]; child != nil {
child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments))
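Review bot: The new branch writes the identical non-maxGap format line in two `else` arms; a single condition `if n.hasChildren && trackGaps != 0 {` with one `else` produces the same output. Leaf nodes also carry a maxGap (it is what `updateMaxGapLeaf` maintains), so if the debug string is meant to help diagnose a stale maxGap it may be worth printing it for leaves too; that depends on intent not visible here. writeDebugString starts and ends outside the hunk.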
@@ -1362,3 +1712,43 @@ func (s *Set) ImportSortedSlices(sds *SegmentDataSlices) error {
}
return nil
}
+
+// segmentTestCheck returns an error if s is incorrectly sorted, does not
+// contain exactly expectedSegments segments, or contains a segment which
+// fails the passed check.
+//
+// This should be used only for testing, and has been added to this package for
+// templating convenience.
+func (s *Set) segmentTestCheck(expectedSegments int, segFunc func(int, Range, Value) error) error {
+ havePrev := false
+ prev := Key(0)
+ nrSegments := 0
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ next := seg.Start()
+ if havePrev && prev >= next {
+ return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments)
+ }
+ if segFunc != nil {
+ if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil {
+ return err
+ }
+ }
+ prev = next
+ havePrev = true
+ nrSegments++
+ }
+ if nrSegments != expectedSegments {
+ return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments)
+ }
+ return nil
+}
+
+// countSegments counts the number of segments in the set.
+//
+// Similar to Check, this should only be used for testing.
+func (s *Set) countSegments() (segments int) {
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ segments++
+ }
+ return segments
+}
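Review bot: The comment on `countSegments` says "Similar to Check", but the function added above it is named `segmentTestCheck`; the comment on `validate` in segment_test.go ("can be passed to Check") carries the same stale name. Suggested wording: `// Similar to segmentTestCheck, this should only be used for testing.` Also, `prev := Key(0)` and the `%d` verbs make these test-only helpers assume an integer `Key` in every instantiation of the template, which the doc comment could state.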
diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD
index f2d8462d8..131bf09b9 100644
--- a/pkg/segment/test/BUILD
+++ b/pkg/segment/test/BUILD
@@ -29,10 +29,28 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "gap_set",
+ out = "gap_set.go",
+ consts = {
+ "trackGaps": "1",
+ },
+ package = "segment",
+ prefix = "gap",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "int",
+ "Range": "Range",
+ "Value": "int",
+ "Functions": "gapSetFunctions",
+ },
+)
+
go_library(
name = "segment",
testonly = 1,
srcs = [
+ "gap_set.go",
"int_range.go",
"int_set.go",
"set_functions.go",
diff --git a/pkg/segment/test/segment_test.go b/pkg/segment/test/segment_test.go
index 97b16c158..85fa19096 100644
--- a/pkg/segment/test/segment_test.go
+++ b/pkg/segment/test/segment_test.go
@@ -17,6 +17,7 @@ package segment
import (
"fmt"
"math/rand"
+ "reflect"
"testing"
)
@@ -32,61 +33,65 @@ const (
// valueOffset is the difference between the value and start of test
// segments.
valueOffset = 100000
+
+ // intervalLength is the interval used by random gap tests.
+ intervalLength = 10
)
func shuffle(xs []int) {
- for i := range xs {
- j := rand.Intn(i + 1)
- xs[i], xs[j] = xs[j], xs[i]
- }
+ rand.Shuffle(len(xs), func(i, j int) { xs[i], xs[j] = xs[j], xs[i] })
}
-func randPermutation(size int) []int {
+func randIntervalPermutation(size int) []int {
p := make([]int, size)
for i := range p {
- p[i] = i
+ p[i] = intervalLength * i
}
shuffle(p)
return p
}
-// checkSet returns an error if s is incorrectly sorted, does not contain
-// exactly expectedSegments segments, or contains a segment for which val !=
-// key + valueOffset.
-func checkSet(s *Set, expectedSegments int) error {
- havePrev := false
- prev := 0
- nrSegments := 0
- for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
- next := seg.Start()
- if havePrev && prev >= next {
- return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments)
- }
- if got, want := seg.Value(), seg.Start()+valueOffset; got != want {
- return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nrSegments, seg.Start(), got, want)
- }
- prev = next
- havePrev = true
- nrSegments++
- }
- if nrSegments != expectedSegments {
- return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments)
+// validate can be passed to Check.
+func validate(nr int, r Range, v int) error {
+ if got, want := v, r.Start+valueOffset; got != want {
+ return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nr, r.Start, got, want)
}
return nil
}
-// countSegmentsIn returns the number of segments in s.
-func countSegmentsIn(s *Set) int {
- var count int
- for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
- count++
+// checkSetMaxGap returns an error if maxGap inside all nodes of s is not well
+// maintained.
+func checkSetMaxGap(s *gapSet) error {
+ n := s.root
+ return checkNodeMaxGap(&n)
+}
+
+// checkNodeMaxGap returns an error if maxGap inside the subtree rooted by n is
+// not well maintained.
+func checkNodeMaxGap(n *gapnode) error {
+ var max int
+ if !n.hasChildren {
+ max = n.calculateMaxGapLeaf()
+ } else {
+ for i := 0; i <= n.nrSegments; i++ {
+ child := n.children[i]
+ if err := checkNodeMaxGap(child); err != nil {
+ return err
+ }
+ if temp := child.maxGap.Get(); i == 0 || temp > max {
+ max = temp
+ }
+ }
+ }
+ if max != n.maxGap.Get() {
+ return fmt.Errorf("maxGap wrong in node\n%vexpected: %d got: %d", n, max, n.maxGap)
}
- return count
+ return nil
}
func TestAddRandom(t *testing.T) {
var s Set
- order := randPermutation(testSize)
+ order := rand.Perm(testSize)
var nrInsertions int
for i, j := range order {
if !s.AddWithoutMerging(Range{j, j + 1}, j+valueOffset) {
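Review bot: In `checkNodeMaxGap` the error message passes `n.maxGap` to `%d`, which is the `dynamicGap` array rather than its value, so a failure prints it in array form next to a plain `max`; use `n.maxGap.Get()` as the last argument. `checkSetMaxGap` also copies the whole root node into `n` before taking its address; `return checkNodeMaxGap(&s.root)` checks the node in place.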
@@ -94,12 +99,12 @@ func TestAddRandom(t *testing.T) {
break
}
nrInsertions++
- if err := checkSet(&s, nrInsertions); err != nil {
+ if err := s.segmentTestCheck(nrInsertions, validate); err != nil {
t.Errorf("Iteration %d: %v", i, err)
break
}
}
- if got, want := countSegmentsIn(&s), nrInsertions; got != want {
+ if got, want := s.countSegments(), nrInsertions; got != want {
t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
}
if t.Failed() {
@@ -115,7 +120,156 @@ func TestRemoveRandom(t *testing.T) {
t.Fatalf("Failed to insert segment %d", i)
}
}
- order := randPermutation(testSize)
+ order := rand.Perm(testSize)
+ var nrRemovals int
+ for i, j := range order {
+ seg := s.FindSegment(j)
+ if !seg.Ok() {
+ t.Errorf("Iteration %d: failed to find segment with key %d", i, j)
+ break
+ }
+ s.Remove(seg)
+ nrRemovals++
+ if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil {
+ t.Errorf("Iteration %d: %v", i, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), testSize-nrRemovals; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Removal order: %v", order[:nrRemovals])
+ t.Logf("Set contents:\n%v", &s)
+ t.FailNow()
+ }
+}
+
+func TestMaxGapAddRandom(t *testing.T) {
+ var s gapSet
+ order := rand.Perm(testSize)
+ var nrInsertions int
+ for i, j := range order {
+ if !s.AddWithoutMerging(Range{j, j + 1}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ nrInsertions++
+ if err := s.segmentTestCheck(nrInsertions, validate); err != nil {
+ t.Errorf("Iteration %d: %v", i, err)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), nrInsertions; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Insertion order: %v", order[:nrInsertions])
+ t.Logf("Set contents:\n%v", &s)
+ }
+}
+
+func TestMaxGapAddRandomWithRandomInterval(t *testing.T) {
+ var s gapSet
+ order := randIntervalPermutation(testSize)
+ var nrInsertions int
+ for i, j := range order {
+ if !s.AddWithoutMerging(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ nrInsertions++
+ if err := s.segmentTestCheck(nrInsertions, validate); err != nil {
+ t.Errorf("Iteration %d: %v", i, err)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), nrInsertions; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Insertion order: %v", order[:nrInsertions])
+ t.Logf("Set contents:\n%v", &s)
+ }
+}
+
+func TestMaxGapAddRandomWithMerge(t *testing.T) {
+ var s gapSet
+ order := randIntervalPermutation(testSize)
+ nrInsertions := 1
+ for i, j := range order {
+ if !s.Add(Range{j, j + intervalLength}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), nrInsertions; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Insertion order: %v", order)
+ t.Logf("Set contents:\n%v", &s)
+ }
+}
+
+func TestMaxGapRemoveRandom(t *testing.T) {
+ var s gapSet
+ for i := 0; i < testSize; i++ {
+ if !s.AddWithoutMerging(Range{i, i + 1}, i+valueOffset) {
+ t.Fatalf("Failed to insert segment %d", i)
+ }
+ }
+ order := rand.Perm(testSize)
+ var nrRemovals int
+ for i, j := range order {
+ seg := s.FindSegment(j)
+ if !seg.Ok() {
+ t.Errorf("Iteration %d: failed to find segment with key %d", i, j)
+ break
+ }
+ temprange := seg.Range()
+ s.Remove(seg)
+ nrRemovals++
+ if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil {
+ t.Errorf("Iteration %d: %v", i, err)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When removing %v: %v", temprange, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), testSize-nrRemovals; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Removal order: %v", order[:nrRemovals])
+ t.Logf("Set contents:\n%v", &s)
+ t.FailNow()
+ }
+}
+
+func TestMaxGapRemoveHalfRandom(t *testing.T) {
+ var s gapSet
+ for i := 0; i < testSize; i++ {
+ if !s.AddWithoutMerging(Range{intervalLength * i, intervalLength*i + rand.Intn(intervalLength-1) + 1}, intervalLength*i+valueOffset) {
+ t.Fatalf("Failed to insert segment %d", i)
+ }
+ }
+ order := randIntervalPermutation(testSize)
+ order = order[:testSize/2]
var nrRemovals int
for i, j := range order {
seg := s.FindSegment(j)
@@ -123,14 +277,19 @@ func TestRemoveRandom(t *testing.T) {
t.Errorf("Iteration %d: failed to find segment with key %d", i, j)
break
}
+ temprange := seg.Range()
s.Remove(seg)
nrRemovals++
- if err := checkSet(&s, testSize-nrRemovals); err != nil {
+ if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil {
t.Errorf("Iteration %d: %v", i, err)
break
}
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When removing %v: %v", temprange, err)
+ break
+ }
}
- if got, want := countSegmentsIn(&s), testSize-nrRemovals; got != want {
+ if got, want := s.countSegments(), testSize-nrRemovals; got != want {
t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
}
if t.Failed() {
@@ -140,6 +299,148 @@ func TestRemoveRandom(t *testing.T) {
}
}
+func TestMaxGapAddRandomRemoveRandomHalfWithMerge(t *testing.T) {
+ var s gapSet
+ order := randIntervalPermutation(testSize * 2)
+ order = order[:testSize]
+ for i, j := range order {
+ if !s.Add(Range{j, j + intervalLength}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ shuffle(order)
+ var nrRemovals int
+ for _, j := range order {
+ seg := s.FindSegment(j)
+ if !seg.Ok() {
+ continue
+ }
+ temprange := seg.Range()
+ s.Remove(seg)
+ nrRemovals++
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When removing %v: %v", temprange, err)
+ break
+ }
+ }
+ if t.Failed() {
+ t.Logf("Removal order: %v", order[:nrRemovals])
+ t.Logf("Set contents:\n%v", &s)
+ t.FailNow()
+ }
+}
+
+func TestNextLargeEnoughGap(t *testing.T) {
+ var s gapSet
+ order := randIntervalPermutation(testSize * 2)
+ order = order[:testSize]
+ for i, j := range order {
+ if !s.Add(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ shuffle(order)
+ order = order[:testSize/2]
+ for _, j := range order {
+ seg := s.FindSegment(j)
+ if !seg.Ok() {
+ continue
+ }
+ temprange := seg.Range()
+ s.Remove(seg)
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When removing %v: %v", temprange, err)
+ break
+ }
+ }
+ minSize := 7
+ var gapArr1 []int
+ for gap := s.LowerBoundGap(0).NextLargeEnoughGap(minSize); gap.Ok(); gap = gap.NextLargeEnoughGap(minSize) {
+ if gap.Range().Length() < minSize {
+ t.Errorf("NextLargeEnoughGap wrong, gap %v has length %d, wanted %d", gap.Range(), gap.Range().Length(), minSize)
+ } else {
+ gapArr1 = append(gapArr1, gap.Range().Start)
+ }
+ }
+ var gapArr2 []int
+ for gap := s.LowerBoundGap(0).NextGap(); gap.Ok(); gap = gap.NextGap() {
+ if gap.Range().Length() >= minSize {
+ gapArr2 = append(gapArr2, gap.Range().Start)
+ }
+ }
+
+ if !reflect.DeepEqual(gapArr2, gapArr1) {
+ t.Errorf("Search result not correct, got: %v, wanted: %v", gapArr1, gapArr2)
+ }
+ if t.Failed() {
+ t.Logf("Set contents:\n%v", &s)
+ t.FailNow()
+ }
+}
+
+func TestPrevLargeEnoughGap(t *testing.T) {
+ var s gapSet
+ order := randIntervalPermutation(testSize * 2)
+ order = order[:testSize]
+ for i, j := range order {
+ if !s.Add(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ end := s.LastSegment().End()
+ shuffle(order)
+ order = order[:testSize/2]
+ for _, j := range order {
+ seg := s.FindSegment(j)
+ if !seg.Ok() {
+ continue
+ }
+ temprange := seg.Range()
+ s.Remove(seg)
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When removing %v: %v", temprange, err)
+ break
+ }
+ }
+ minSize := 7
+ var gapArr1 []int
+ for gap := s.UpperBoundGap(end + intervalLength).PrevLargeEnoughGap(minSize); gap.Ok(); gap = gap.PrevLargeEnoughGap(minSize) {
+ if gap.Range().Length() < minSize {
+ t.Errorf("PrevLargeEnoughGap wrong, gap length %d, wanted %d", gap.Range().Length(), minSize)
+ } else {
+ gapArr1 = append(gapArr1, gap.Range().Start)
+ }
+ }
+ var gapArr2 []int
+ for gap := s.UpperBoundGap(end + intervalLength).PrevGap(); gap.Ok(); gap = gap.PrevGap() {
+ if gap.Range().Length() >= minSize {
+ gapArr2 = append(gapArr2, gap.Range().Start)
+ }
+ }
+ if !reflect.DeepEqual(gapArr2, gapArr1) {
+ t.Errorf("Search result not correct, got: %v, wanted: %v", gapArr1, gapArr2)
+ }
+ if t.Failed() {
+ t.Logf("Set contents:\n%v", &s)
+ t.FailNow()
+ }
+}
+
func TestAddSequentialAdjacent(t *testing.T) {
var s Set
var nrInsertions int
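Review bot: In `TestMaxGapAddRandomRemoveRandomHalfWithMerge` the failure log prints `order[:nrRemovals]`, but the removal loop skips keys whose segment is no longer found (`continue`), so `nrRemovals` counts successful removals and not positions in `order`. After the first skip, the logged prefix can list keys that were never removed and leave out ones that were, which makes a failing run harder to replay. Recording the keys as they are removed keeps the log accurate: `var removed []int` before the loop, `removed = append(removed, j)` next to `nrRemovals++`, and `t.Logf("Removal order: %v", removed)` in the failure branch.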
@@ -148,12 +449,12 @@ func TestAddSequentialAdjacent(t *testing.T) {
t.Fatalf("Failed to insert segment %d", i)
}
nrInsertions++
- if err := checkSet(&s, nrInsertions); err != nil {
+ if err := s.segmentTestCheck(nrInsertions, validate); err != nil {
t.Errorf("Iteration %d: %v", i, err)
break
}
}
- if got, want := countSegmentsIn(&s), nrInsertions; got != want {
+ if got, want := s.countSegments(), nrInsertions; got != want {
t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
}
if t.Failed() {
@@ -202,12 +503,12 @@ func TestAddSequentialNonAdjacent(t *testing.T) {
t.Fatalf("Failed to insert segment %d", i)
}
nrInsertions++
- if err := checkSet(&s, nrInsertions); err != nil {
+ if err := s.segmentTestCheck(nrInsertions, validate); err != nil {
t.Errorf("Iteration %d: %v", i, err)
break
}
}
- if got, want := countSegmentsIn(&s), nrInsertions; got != want {
+ if got, want := s.countSegments(), nrInsertions; got != want {
t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
}
if t.Failed() {
@@ -293,7 +594,7 @@ Tests:
var i int
for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
if i > len(test.final) {
- t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, countSegmentsIn(&s), len(test.final), &s)
+ t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, s.countSegments(), len(test.final), &s)
continue Tests
}
if got, want := seg.Range(), test.final[i]; got != want {
@@ -351,7 +652,7 @@ Tests:
var i int
for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
if i > len(test.final) {
- t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, countSegmentsIn(&s), len(test.final), &s)
+ t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, s.countSegments(), len(test.final), &s)
continue Tests
}
if got, want := seg.Range(), test.final[i]; got != want {
@@ -378,7 +679,7 @@ func benchmarkAddSequential(b *testing.B, size int) {
}
func benchmarkAddRandom(b *testing.B, size int) {
- order := randPermutation(size)
+ order := rand.Perm(size)
b.ResetTimer()
for n := 0; n < b.N; n++ {
@@ -416,7 +717,7 @@ func benchmarkFindRandom(b *testing.B, size int) {
b.Fatalf("Failed to insert segment %d", i)
}
}
- order := randPermutation(size)
+ order := rand.Perm(size)
b.ResetTimer()
for n := 0; n < b.N; n++ {
@@ -470,7 +771,7 @@ func benchmarkAddFindRemoveSequential(b *testing.B, size int) {
}
func benchmarkAddFindRemoveRandom(b *testing.B, size int) {
- order := randPermutation(size)
+ order := rand.Perm(size)
b.ResetTimer()
for n := 0; n < b.N; n++ {
diff --git a/pkg/segment/test/set_functions.go b/pkg/segment/test/set_functions.go
index bcddb39bb..7cd895cc7 100644
--- a/pkg/segment/test/set_functions.go
+++ b/pkg/segment/test/set_functions.go
@@ -14,21 +14,16 @@
package segment
-// Basic numeric constants that we define because the math package doesn't.
-// TODO(nlacasse): These should be Math.MaxInt64/MinInt64?
-const (
- maxInt = int(^uint(0) >> 1)
- minInt = -maxInt - 1
-)
-
type setFunctions struct{}
-func (setFunctions) MinKey() int {
- return minInt
+// MinKey returns the minimum key for the set.
+func (s setFunctions) MinKey() int {
+ return -s.MaxKey() - 1
}
+// MaxKey returns the maximum key for the set.
func (setFunctions) MaxKey() int {
- return maxInt
+ return int(^uint(0) >> 1)
}
func (setFunctions) ClearValue(*int) {}
@@ -40,3 +35,20 @@ func (setFunctions) Merge(_ Range, val1 int, _ Range, _ int) (int, bool) {
func (setFunctions) Split(_ Range, val int, _ int) (int, int) {
return val, val
}
+
+type gapSetFunctions struct {
+ setFunctions
+}
+
+// MinKey is adjusted to make sure no add overflow would happen in test cases.
+// e.g. A gap with range {MinInt32, 2} would cause overflow in Range().Length().
+//
+// Normally Keys should be unsigned to avoid these issues.
+func (s gapSetFunctions) MinKey() int {
+ return s.setFunctions.MinKey() / 2
+}
+
+// MaxKey returns the maximum key for the set.
+func (s gapSetFunctions) MaxKey() int {
+ return s.setFunctions.MaxKey() / 2
+}
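Review bot: The comment on `gapSetFunctions.MinKey` gives `{MinInt32, 2}` as the overflowing range, but the keys are `int` and `setFunctions.MaxKey` derives its bound from the width of `uint`, so on a 64-bit build that value is not the minimum key. I would write the example in terms of the embedded type, e.g. `{setFunctions.MinKey(), 2}`. The `MaxKey` comment repeats the base type's wording and does not say that the bound is halved; something like `// MaxKey is halved along with MinKey so that MaxKey() - MinKey() still fits in an int.` would record why both overrides exist.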
diff --git a/pkg/sentry/arch/syscalls_arm64.go b/pkg/sentry/arch/syscalls_arm64.go
index 92d062513..95dfd1e90 100644
--- a/pkg/sentry/arch/syscalls_arm64.go
+++ b/pkg/sentry/arch/syscalls_arm64.go
@@ -23,7 +23,7 @@ const restartSyscallNr = uintptr(128)
//
// In linux, at the entry of the syscall handler(el0_svc_common()), value of R0
// is saved to the pt_regs.orig_x0 in kernel code. But currently, the orig_x0
-// was not accessible to the user space application, so we have to do the same
+// was not accessible to the userspace application, so we have to do the same
// operation in the sentry code to save the R0 value into the App context.
func (c *context64) SyscallSaveOrig() {
c.OrigR0 = c.Regs.Regs[0]
diff --git a/pkg/sentry/fs/fsutil/frame_ref_set.go b/pkg/sentry/fs/fsutil/frame_ref_set.go
index 6564fd0c6..dd6f5aba6 100644
--- a/pkg/sentry/fs/fsutil/frame_ref_set.go
+++ b/pkg/sentry/fs/fsutil/frame_ref_set.go
@@ -18,6 +18,7 @@ import (
"math"
"gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
)
// FrameRefSetFunctions implements segment.Functions for FrameRefSet.
@@ -49,3 +50,42 @@ func (FrameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform.
func (FrameRefSetFunctions) Split(_ platform.FileRange, val uint64, _ uint64) (uint64, uint64) {
return val, val
}
+
+// IncRefAndAccount adds a reference on the range fr. All newly inserted segments
+// are accounted as host page cache memory mappings.
+func (refs *FrameRefSet) IncRefAndAccount(fr platform.FileRange) {
+ seg, gap := refs.Find(fr.Start)
+ for {
+ switch {
+ case seg.Ok() && seg.Start() < fr.End:
+ seg = refs.Isolate(seg, fr)
+ seg.SetValue(seg.Value() + 1)
+ seg, gap = seg.NextNonEmpty()
+ case gap.Ok() && gap.Start() < fr.End:
+ newRange := gap.Range().Intersect(fr)
+ usage.MemoryAccounting.Inc(newRange.Length(), usage.Mapped)
+ seg, gap = refs.InsertWithoutMerging(gap, newRange, 1).NextNonEmpty()
+ default:
+ refs.MergeAdjacent(fr)
+ return
+ }
+ }
+}
+
+// DecRefAndAccount removes a reference on the range fr and untracks segments
+// that are removed from memory accounting.
+func (refs *FrameRefSet) DecRefAndAccount(fr platform.FileRange) {
+ seg := refs.FindSegment(fr.Start)
+
+ for seg.Ok() && seg.Start() < fr.End {
+ seg = refs.Isolate(seg, fr)
+ if old := seg.Value(); old == 1 {
+ usage.MemoryAccounting.Dec(seg.Range().Length(), usage.Mapped)
+ seg = refs.Remove(seg).NextSegment()
+ } else {
+ seg.SetValue(old - 1)
+ seg = seg.NextSegment()
+ }
+ }
+ refs.MergeAdjacent(fr)
+}
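Review bot: `DecRefAndAccount` does nothing when `fr.Start` falls in a gap. It starts from `refs.FindSegment(fr.Start)` and the loop runs only while `seg.Ok()`, so segments further inside `fr` keep their counts and their `usage.Mapped` accounting. `IncRefAndAccount` starts from `refs.Find` and handles both the segment and the gap case, so the two helpers are not symmetric and an unbalanced release passes silently. Now that this is shared code, the doc comment should state the contract, for example `// Preconditions: every offset in fr currently holds a reference.` The gofer callers in this commit take `dataMu` around the call, so the comment should probably also say that the caller must hold the lock that guards refs.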
diff --git a/pkg/sentry/fs/g3doc/fuse.md b/pkg/sentry/fs/g3doc/fuse.md
index c3988aa43..635cc009b 100644
--- a/pkg/sentry/fs/g3doc/fuse.md
+++ b/pkg/sentry/fs/g3doc/fuse.md
@@ -7,20 +7,20 @@ currently incomplete and the document will be updated as things progress.
# FUSE: Filesystem in Userspace
-The sentry supports dispatching filesystem operations to a FUSE server,
-allowing FUSE filesystem to be used with a sandbox.
+The sentry supports dispatching filesystem operations to a FUSE server, allowing
+FUSE filesystem to be used with a sandbox.
## Overview
FUSE has two main components:
-1. A client kernel driver (canonically `fuse.ko` in Linux), which forwards
- filesystem operations (usually initiated by syscalls) to the server.
+1. A client kernel driver (canonically `fuse.ko` in Linux), which forwards
+ filesystem operations (usually initiated by syscalls) to the server.
-2. A server, which is a userspace daemon that implements the actual filesystem.
+2. A server, which is a userspace daemon that implements the actual filesystem.
-The sentry implements the client component, which allows a server daemon
-running within the sandbox to implement a filesystem within the sandbox.
+The sentry implements the client component, which allows a server daemon running
+within the sandbox to implement a filesystem within the sandbox.
A FUSE filesystem is initialized with `mount(2)`, typically with the help of a
utility like `fusermount(1)`. Various mount options exist for establishing
@@ -30,43 +30,43 @@ and server.
The FUSE device FD is obtained by opening `/dev/fuse`. During regular operation,
the client and server use the FUSE protocol described in `fuse(4)` to service
-filesystem operations. See the "Protocol" section below for more
-information about this protocol. The core of the sentry support for FUSE is the
-client-side implementation of this protocol.
+filesystem operations. See the "Protocol" section below for more information
+about this protocol. The core of the sentry support for FUSE is the client-side
+implementation of this protocol.
## FUSE in the Sentry
The sentry's FUSE client targets VFS2 and has the following components:
-- An implementation of `/dev/fuse`.
+- An implementation of `/dev/fuse`.
-- A VFS2 filesystem for mapping syscalls to FUSE ops. Since we're targeting
- VFS2, one point of contention may be the lack of inodes in VFS2. We can
- tentatively implement a kernfs-based filesystem to bridge the gap in APIs. The
- kernfs base functionality can serve the role of the Linux inode cache and, the
- filesystem can map VFS2 syscalls to kernfs inode operations; see the
- `kernfs.Inode` interface.
+- A VFS2 filesystem for mapping syscalls to FUSE ops. Since we're targeting
+ VFS2, one point of contention may be the lack of inodes in VFS2. We can
+ tentatively implement a kernfs-based filesystem to bridge the gap in APIs.
+ The kernfs base functionality can serve the role of the Linux inode cache
+ and, the filesystem can map VFS2 syscalls to kernfs inode operations; see
+ the `kernfs.Inode` interface.
-The FUSE protocol lends itself well to marshaling with `go_marshal`. The
-various request and response packets can be defined in the ABI package and
-converted to and from the wire format using `go_marshal`.
+The FUSE protocol lends itself well to marshaling with `go_marshal`. The various
+request and response packets can be defined in the ABI package and converted to
+and from the wire format using `go_marshal`.
### Design Goals
-- While filesystem performance is always important, the sentry's FUSE support is
- primarily concerned with compatibility, with performance as a secondary
- concern.
+- While filesystem performance is always important, the sentry's FUSE support
+ is primarily concerned with compatibility, with performance as a secondary
+ concern.
-- Avoiding deadlocks from a hung server daemon.
+- Avoiding deadlocks from a hung server daemon.
-- Consider the potential for denial of service from a malicious server
- daemon. Protecting itself from userspace is already a design goal for the
- sentry, but needs additional consideration for FUSE. Normally, an operating
- system doesn't rely on userspace to make progress with filesystem
- operations. Since this changes with FUSE, it opens up the possibility of
- creating a chain of dependencies controlled by userspace, which could affect
- an entire sandbox. For example: a FUSE op can block a syscall, which could be
- holding a subsystem lock, which can then block another task goroutine.
+- Consider the potential for denial of service from a malicious server daemon.
+ Protecting itself from userspace is already a design goal for the sentry,
+ but needs additional consideration for FUSE. Normally, an operating system
+ doesn't rely on userspace to make progress with filesystem operations. Since
+ this changes with FUSE, it opens up the possibility of creating a chain of
+ dependencies controlled by userspace, which could affect an entire sandbox.
+ For example: a FUSE op can block a syscall, which could be holding a
+ subsystem lock, which can then block another task goroutine.
### Milestones
@@ -76,23 +76,23 @@ ops can be implemented in parallel.
#### Minimal client that can mount a trivial FUSE filesystem.
-- Implement `/dev/fuse`.
+- Implement `/dev/fuse`.
-- Implement basic FUSE ops like `FUSE_INIT`, `FUSE_DESTROY`.
+- Implement basic FUSE ops like `FUSE_INIT`, `FUSE_DESTROY`.
#### Read-only mount with basic file operations
-- Implement the majority of file, directory and file descriptor FUSE ops. For
- this milestone, we can skip uncommon or complex operations like mmap, mknod,
- file locking, poll, and extended attributes. We can stub these out along with
- any ops that modify the filesystem. The exact list of required ops are to be
- determined, but the goal is to mount a real filesystem as read-only, and be
- able to read contents from the filesystem in the sentry.
+- Implement the majority of file, directory and file descriptor FUSE ops. For
+ this milestone, we can skip uncommon or complex operations like mmap, mknod,
+ file locking, poll, and extended attributes. We can stub these out along
+ with any ops that modify the filesystem. The exact list of required ops are
+ to be determined, but the goal is to mount a real filesystem as read-only,
+ and be able to read contents from the filesystem in the sentry.
#### Full read-write support
-- Implement the remaining FUSE ops and decide if we can omit rarely used
- operations like ioctl.
+- Implement the remaining FUSE ops and decide if we can omit rarely used
+ operations like ioctl.
# Appendix
@@ -145,19 +145,19 @@ operations map to the sentry virtual filesystem.
These operations are specific to FUSE and don't have a corresponding action in a
generic filesystem.
-- `FUSE_INIT`: This operation initializes a new FUSE filesystem, and is the
- first message sent by the client after mount. This is used for version and
- feature negotiation. This is related to `mount(2)`.
-- `FUSE_DESTROY`: Teardown a FUSE filesystem, related to `unmount(2)`.
-- `FUSE_INTERRUPT`: Interrupts an in-flight operation, specified by the
- `fuse_in_header.unique` value provided in the corresponding request
- header. The client can send at most one of these per request, and will enter
- an uninterruptible wait for a reply. The server is expected to reply promptly.
-- `FUSE_FORGET`: A hint to the server that server should evict the indicate node
- from any caches. This is wired up to `(struct super_operations).evict_inode`
- in Linux, which is in turned hooked as the inode cache shrinker which is
- typically triggered by system memory pressure.
-- `FUSE_BATCH_FORGET`: Batch version of `FUSE_FORGET`.
+- `FUSE_INIT`: This operation initializes a new FUSE filesystem, and is the
+ first message sent by the client after mount. This is used for version and
+ feature negotiation. This is related to `mount(2)`.
+- `FUSE_DESTROY`: Teardown a FUSE filesystem, related to `unmount(2)`.
+- `FUSE_INTERRUPT`: Interrupts an in-flight operation, specified by the
+ `fuse_in_header.unique` value provided in the corresponding request header.
+ The client can send at most one of these per request, and will enter an
+ uninterruptible wait for a reply. The server is expected to reply promptly.
+- `FUSE_FORGET`: A hint to the server that server should evict the indicate
+ node from any caches. This is wired up to `(struct
+ super_operations).evict_inode` in Linux, which is in turned hooked as the
+ inode cache shrinker which is typically triggered by system memory pressure.
+- `FUSE_BATCH_FORGET`: Batch version of `FUSE_FORGET`.
#### Filesystem Syscalls
@@ -167,92 +167,94 @@ otherwise noted.
Node creation:
-- `FUSE_MKNOD`
-- `FUSE_MKDIR`
-- `FUSE_CREATE`: This is equivalent to `open(2)` and `creat(2)`, which
- atomically creates and opens a node.
+- `FUSE_MKNOD`
+- `FUSE_MKDIR`
+- `FUSE_CREATE`: This is equivalent to `open(2)` and `creat(2)`, which
+ atomically creates and opens a node.
Node attributes and extended attributes:
-- `FUSE_GETATTR`
-- `FUSE_SETATTR`
-- `FUSE_SETXATTR`
-- `FUSE_GETXATTR`
-- `FUSE_LISTXATTR`
-- `FUSE_REMOVEXATTR`
+- `FUSE_GETATTR`
+- `FUSE_SETATTR`
+- `FUSE_SETXATTR`
+- `FUSE_GETXATTR`
+- `FUSE_LISTXATTR`
+- `FUSE_REMOVEXATTR`
Node link manipulation:
-- `FUSE_READLINK`
-- `FUSE_LINK`
-- `FUSE_SYMLINK`
-- `FUSE_UNLINK`
+- `FUSE_READLINK`
+- `FUSE_LINK`
+- `FUSE_SYMLINK`
+- `FUSE_UNLINK`
Directory operations:
-- `FUSE_RMDIR`
-- `FUSE_RENAME`
-- `FUSE_RENAME2`
-- `FUSE_OPENDIR`: `open(2)` for directories.
-- `FUSE_RELEASEDIR`: `close(2)` for directories.
-- `FUSE_READDIR`
-- `FUSE_READDIRPLUS`
-- `FUSE_FSYNCDIR`: `fsync(2)` for directories.
-- `FUSE_LOOKUP`: Establishes a unique identifier for a FS node. This is
- reminiscent of `VirtualFilesystem.GetDentryAt` in that it resolves a path
- component to a node. However the returned identifier is opaque to the
- client. The server must remember this mapping, as this is how the client will
- reference the node in the future.
+- `FUSE_RMDIR`
+- `FUSE_RENAME`
+- `FUSE_RENAME2`
+- `FUSE_OPENDIR`: `open(2)` for directories.
+- `FUSE_RELEASEDIR`: `close(2)` for directories.
+- `FUSE_READDIR`
+- `FUSE_READDIRPLUS`
+- `FUSE_FSYNCDIR`: `fsync(2)` for directories.
+- `FUSE_LOOKUP`: Establishes a unique identifier for a FS node. This is
+ reminiscent of `VirtualFilesystem.GetDentryAt` in that it resolves a path
+ component to a node. However the returned identifier is opaque to the
+ client. The server must remember this mapping, as this is how the client
+ will reference the node in the future.
File operations:
-- `FUSE_OPEN`: `open(2)` for files.
-- `FUSE_RELEASE`: `close(2)` for files.
-- `FUSE_FSYNC`
-- `FUSE_FALLOCATE`
-- `FUSE_SETUPMAPPING`: Creates a memory map on a file for `mmap(2)`.
-- `FUSE_REMOVEMAPPING`: Removes a memory map for `munmap(2)`.
+- `FUSE_OPEN`: `open(2)` for files.
+- `FUSE_RELEASE`: `close(2)` for files.
+- `FUSE_FSYNC`
+- `FUSE_FALLOCATE`
+- `FUSE_SETUPMAPPING`: Creates a memory map on a file for `mmap(2)`.
+- `FUSE_REMOVEMAPPING`: Removes a memory map for `munmap(2)`.
File locking:
-- `FUSE_GETLK`
-- `FUSE_SETLK`
-- `FUSE_SETLKW`
-- `FUSE_COPY_FILE_RANGE`
+- `FUSE_GETLK`
+- `FUSE_SETLK`
+- `FUSE_SETLKW`
+- `FUSE_COPY_FILE_RANGE`
File descriptor operations:
-- `FUSE_IOCTL`
-- `FUSE_POLL`
-- `FUSE_LSEEK`
+- `FUSE_IOCTL`
+- `FUSE_POLL`
+- `FUSE_LSEEK`
Filesystem operations:
-- `FUSE_STATFS`
+- `FUSE_STATFS`
#### Permissions
-- `FUSE_ACCESS` is used to check if a node is accessible, as part of many
- syscall implementations. Maps to `vfs.FilesystemImpl.AccessAt`
- in the sentry.
+- `FUSE_ACCESS` is used to check if a node is accessible, as part of many
+ syscall implementations. Maps to `vfs.FilesystemImpl.AccessAt` in the
+ sentry.
#### I/O Operations
These ops are used to read and write file pages. They're used to implement both
I/O syscalls like `read(2)`, `write(2)` and `mmap(2)`.
-- `FUSE_READ`
-- `FUSE_WRITE`
+- `FUSE_READ`
+- `FUSE_WRITE`
#### Miscellaneous
-- `FUSE_FLUSH`: Used by the client to indicate when a file descriptor is
- closed. Distinct from `FUSE_FSYNC`, which corresponds to an `fsync(2)` syscall
- from the user. Maps to `vfs.FileDescriptorImpl.Release` in the sentry.
-- `FUSE_BMAP`: Old address space API for block defrag. Probably not needed.
-- `FUSE_NOTIFY_REPLY`: [TODO: what does this do?]
+- `FUSE_FLUSH`: Used by the client to indicate when a file descriptor is
+ closed. Distinct from `FUSE_FSYNC`, which corresponds to an `fsync(2)`
+ syscall from the user. Maps to `vfs.FileDescriptorImpl.Release` in the
+ sentry.
+- `FUSE_BMAP`: Old address space API for block defrag. Probably not needed.
+- `FUSE_NOTIFY_REPLY`: [TODO: what does this do?]
# References
-- `fuse(4)` manpage.
-- Linux kernel FUSE documentation: https://www.kernel.org/doc/html/latest/filesystems/fuse.html
+- `fuse(4)` manpage.
+- Linux kernel FUSE documentation:
+ https://www.kernel.org/doc/html/latest/filesystems/fuse.html
diff --git a/pkg/sentry/fsimpl/devpts/line_discipline.go b/pkg/sentry/fsimpl/devpts/line_discipline.go
index e201801d6..f7bc325d1 100644
--- a/pkg/sentry/fsimpl/devpts/line_discipline.go
+++ b/pkg/sentry/fsimpl/devpts/line_discipline.go
@@ -27,8 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-// LINT.IfChange
-
const (
// canonMaxBytes is the number of bytes that fit into a single line of
// terminal input in canonical mode. This corresponds to N_TTY_BUF_SIZE
@@ -445,5 +443,3 @@ func (l *lineDiscipline) peek(b []byte) int {
}
return size
}
-
-// LINT.ThenChange(../../fs/tty/line_discipline.go)
diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go
index 04a292927..7a7ce5d81 100644
--- a/pkg/sentry/fsimpl/devpts/master.go
+++ b/pkg/sentry/fsimpl/devpts/master.go
@@ -27,8 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-// LINT.IfChange
-
// masterInode is the inode for the master end of the Terminal.
type masterInode struct {
kernfs.InodeAttrs
@@ -222,5 +220,3 @@ func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) {
unimpl.EmitUnimplementedEvent(ctx)
}
}
-
-// LINT.ThenChange(../../fs/tty/master.go)
diff --git a/pkg/sentry/fsimpl/devpts/queue.go b/pkg/sentry/fsimpl/devpts/queue.go
index 29a6be858..dffb4232c 100644
--- a/pkg/sentry/fsimpl/devpts/queue.go
+++ b/pkg/sentry/fsimpl/devpts/queue.go
@@ -25,8 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-// LINT.IfChange
-
// waitBufMaxBytes is the maximum size of a wait buffer. It is based on
// TTYB_DEFAULT_MEM_LIMIT.
const waitBufMaxBytes = 131072
@@ -236,5 +234,3 @@ func (q *queue) waitBufAppend(b []byte) {
q.waitBuf = append(q.waitBuf, b)
q.waitBufLen += uint64(len(b))
}
-
-// LINT.ThenChange(../../fs/tty/queue.go)
diff --git a/pkg/sentry/fsimpl/devpts/slave.go b/pkg/sentry/fsimpl/devpts/slave.go
index 0a98dc896..526cd406c 100644
--- a/pkg/sentry/fsimpl/devpts/slave.go
+++ b/pkg/sentry/fsimpl/devpts/slave.go
@@ -26,8 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-// LINT.IfChange
-
// slaveInode is the inode for the slave end of the Terminal.
type slaveInode struct {
kernfs.InodeAttrs
@@ -182,5 +180,3 @@ func (sfd *slaveFileDescription) Stat(ctx context.Context, opts vfs.StatOptions)
fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem()
return sfd.inode.Stat(fs, opts)
}
-
-// LINT.ThenChange(../../fs/tty/slave.go)
diff --git a/pkg/sentry/fsimpl/devpts/terminal.go b/pkg/sentry/fsimpl/devpts/terminal.go
index b44e673d8..7d2781c54 100644
--- a/pkg/sentry/fsimpl/devpts/terminal.go
+++ b/pkg/sentry/fsimpl/devpts/terminal.go
@@ -22,8 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// LINT.IfChanges
-
// Terminal is a pseudoterminal.
//
// +stateify savable
@@ -120,5 +118,3 @@ func (tm *Terminal) tty(isMaster bool) *kernel.TTY {
}
return tm.slaveKTTY
}
-
-// LINT.ThenChange(../../fs/tty/terminal.go)
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index 5ce82b793..67e916525 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -36,7 +36,6 @@ go_library(
"gofer.go",
"handle.go",
"p9file.go",
- "pagemath.go",
"regular_file.go",
"socket.go",
"special_file.go",
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 353e2cf5b..6295f6b54 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -869,8 +869,8 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
Size: stat.Mask&linux.STATX_SIZE != 0,
ATime: stat.Mask&linux.STATX_ATIME != 0,
MTime: stat.Mask&linux.STATX_MTIME != 0,
- ATimeNotSystemTime: stat.Atime.Nsec != linux.UTIME_NOW,
- MTimeNotSystemTime: stat.Mtime.Nsec != linux.UTIME_NOW,
+ ATimeNotSystemTime: stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW,
+ MTimeNotSystemTime: stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec != linux.UTIME_NOW,
}, p9.SetAttr{
Permissions: p9.FileMode(stat.Mode),
UID: p9.UID(stat.UID),
@@ -928,8 +928,8 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
// so we can't race with Write or another truncate.)
d.dataMu.Unlock()
if d.size < oldSize {
- oldpgend := pageRoundUp(oldSize)
- newpgend := pageRoundUp(d.size)
+ oldpgend, _ := usermem.PageRoundUp(oldSize)
+ newpgend, _ := usermem.PageRoundUp(d.size)
if oldpgend != newpgend {
d.mapsMu.Lock()
d.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
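Review bot: Both `usermem.PageRoundUp` calls in this hunk discard the second result, while the `PWrite` hunk in regular_file.go treats `!ok` as an overflow and returns `EINVAL`. If `oldSize` or `d.size` can be within a page of the uint64 maximum, the rounded value presumably wraps, and the `oldpgend != newpgend` branch would then call `Invalidate` with a wrong range. The same `, _ :=` pattern appears in the `ReadToBlocks` and `Translate` hunks of regular_file.go. Either check `ok` at these sites, or put a comment next to each call saying why the size is bounded, e.g. `// size is capped at <limit>, so PageRoundUp cannot overflow`. `setStat` starts and ends outside this hunk.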
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 857f7c74e..0d10cf7ac 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -148,9 +148,9 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
return 0, err
}
// Remove touched pages from the cache.
- pgstart := pageRoundDown(uint64(offset))
- pgend := pageRoundUp(uint64(offset + src.NumBytes()))
- if pgend < pgstart {
+ pgstart := usermem.PageRoundDown(uint64(offset))
+ pgend, ok := usermem.PageRoundUp(uint64(offset + src.NumBytes()))
+ if !ok {
return 0, syserror.EINVAL
}
mr := memmap.MappableRange{pgstart, pgend}
@@ -306,9 +306,10 @@ func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error)
if fillCache {
// Read into the cache, then re-enter the loop to read from the
// cache.
+ gapEnd, _ := usermem.PageRoundUp(gapMR.End)
reqMR := memmap.MappableRange{
- Start: pageRoundDown(gapMR.Start),
- End: pageRoundUp(gapMR.End),
+ Start: usermem.PageRoundDown(gapMR.Start),
+ End: gapEnd,
}
optMR := gap.Range()
err := rw.d.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), mf, usage.PageCache, rw.d.handle.readToBlocksAt)
@@ -671,7 +672,7 @@ func (d *dentry) Translate(ctx context.Context, required, optional memmap.Mappab
// Constrain translations to d.size (rounded up) to prevent translation to
// pages that may be concurrently truncated.
- pgend := pageRoundUp(d.size)
+ pgend, _ := usermem.PageRoundUp(d.size)
var beyondEOF bool
if required.End > pgend {
if required.Start >= pgend {
@@ -818,43 +819,15 @@ type dentryPlatformFile struct {
// IncRef implements platform.File.IncRef.
func (d *dentryPlatformFile) IncRef(fr platform.FileRange) {
d.dataMu.Lock()
- seg, gap := d.fdRefs.Find(fr.Start)
- for {
- switch {
- case seg.Ok() && seg.Start() < fr.End:
- seg = d.fdRefs.Isolate(seg, fr)
- seg.SetValue(seg.Value() + 1)
- seg, gap = seg.NextNonEmpty()
- case gap.Ok() && gap.Start() < fr.End:
- newRange := gap.Range().Intersect(fr)
- usage.MemoryAccounting.Inc(newRange.Length(), usage.Mapped)
- seg, gap = d.fdRefs.InsertWithoutMerging(gap, newRange, 1).NextNonEmpty()
- default:
- d.fdRefs.MergeAdjacent(fr)
- d.dataMu.Unlock()
- return
- }
- }
+ d.fdRefs.IncRefAndAccount(fr)
+ d.dataMu.Unlock()
}
// DecRef implements platform.File.DecRef.
func (d *dentryPlatformFile) DecRef(fr platform.FileRange) {
d.dataMu.Lock()
- seg := d.fdRefs.FindSegment(fr.Start)
-
- for seg.Ok() && seg.Start() < fr.End {
- seg = d.fdRefs.Isolate(seg, fr)
- if old := seg.Value(); old == 1 {
- usage.MemoryAccounting.Dec(seg.Range().Length(), usage.Mapped)
- seg = d.fdRefs.Remove(seg).NextSegment()
- } else {
- seg.SetValue(old - 1)
- seg = seg.NextSegment()
- }
- }
- d.fdRefs.MergeAdjacent(fr)
+ d.fdRefs.DecRefAndAccount(fr)
d.dataMu.Unlock()
-
}
// MapInternal implements platform.File.MapInternal.
diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD
index 39509f703..ca0fe6d2b 100644
--- a/pkg/sentry/fsimpl/host/BUILD
+++ b/pkg/sentry/fsimpl/host/BUILD
@@ -8,6 +8,7 @@ go_library(
"control.go",
"host.go",
"ioctl_unsafe.go",
+ "mmap.go",
"socket.go",
"socket_iovec.go",
"socket_unsafe.go",
@@ -23,12 +24,15 @@ go_library(
"//pkg/fspath",
"//pkg/log",
"//pkg/refs",
+ "//pkg/safemem",
"//pkg/sentry/arch",
+ "//pkg/sentry/fs/fsutil",
"//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/hostfd",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/memmap",
+ "//pkg/sentry/platform",
"//pkg/sentry/socket/control",
"//pkg/sentry/socket/unix",
"//pkg/sentry/socket/unix/transport",
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index 8caf55a1b..18b127521 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -86,15 +86,13 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions)
i := &inode{
hostFD: hostFD,
- seekable: seekable,
+ ino: fs.NextIno(),
isTTY: opts.IsTTY,
- canMap: canMap(uint32(fileType)),
wouldBlock: wouldBlock(uint32(fileType)),
- ino: fs.NextIno(),
- // For simplicity, set offset to 0. Technically, we should use the existing
- // offset on the host if the file is seekable.
- offset: 0,
+ seekable: seekable,
+ canMap: canMap(uint32(fileType)),
}
+ i.pf.inode = i
// Non-seekable files can't be memory mapped, assert this.
if !i.seekable && i.canMap {
@@ -117,6 +115,10 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions)
// i.open will take a reference on d.
defer d.DecRef()
+
+ // For simplicity, fileDescription.offset is set to 0. Technically, we
+ // should only set to 0 on files that are not seekable (sockets, pipes,
+ // etc.), and use the offset from the host fd otherwise when importing.
return i.open(ctx, d.VFSDentry(), mnt, flags)
}
@@ -189,11 +191,15 @@ type inode struct {
// This field is initialized at creation time and is immutable.
hostFD int
- // wouldBlock is true if the host FD would return EWOULDBLOCK for
- // operations that would block.
+ // ino is an inode number unique within this filesystem.
//
// This field is initialized at creation time and is immutable.
- wouldBlock bool
+ ino uint64
+
+ // isTTY is true if this file represents a TTY.
+ //
+ // This field is initialized at creation time and is immutable.
+ isTTY bool
// seekable is false if the host fd points to a file representing a stream,
// e.g. a socket or a pipe. Such files are not seekable and can return
@@ -202,29 +208,29 @@ type inode struct {
// This field is initialized at creation time and is immutable.
seekable bool
- // isTTY is true if this file represents a TTY.
+ // wouldBlock is true if the host FD would return EWOULDBLOCK for
+ // operations that would block.
//
// This field is initialized at creation time and is immutable.
- isTTY bool
+ wouldBlock bool
+
+ // Event queue for blocking operations.
+ queue waiter.Queue
// canMap specifies whether we allow the file to be memory mapped.
//
// This field is initialized at creation time and is immutable.
canMap bool
- // ino is an inode number unique within this filesystem.
- //
- // This field is initialized at creation time and is immutable.
- ino uint64
+ // mapsMu protects mappings.
+ mapsMu sync.Mutex
- // offsetMu protects offset.
- offsetMu sync.Mutex
-
- // offset specifies the current file offset.
- offset int64
+ // If canMap is true, mappings tracks mappings of hostFD into
+ // memmap.MappingSpaces.
+ mappings memmap.MappingSet
- // Event queue for blocking operations.
- queue waiter.Queue
+ // pf implements platform.File for mappings of hostFD.
+ pf inodePlatformFile
}
// CheckPermissions implements kernfs.Inode.
@@ -388,6 +394,21 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
if err := syscall.Ftruncate(i.hostFD, int64(s.Size)); err != nil {
return err
}
+ oldSize := uint64(hostStat.Size)
+ if s.Size < oldSize {
+ oldpgend, _ := usermem.PageRoundUp(oldSize)
+ newpgend, _ := usermem.PageRoundUp(s.Size)
+ if oldpgend != newpgend {
+ i.mapsMu.Lock()
+ i.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
+ // Compare Linux's mm/truncate.c:truncate_setsize() =>
+ // truncate_pagecache() =>
+ // mm/memory.c:unmap_mapping_range(evencows=1).
+ InvalidatePrivate: true,
+ })
+ i.mapsMu.Unlock()
+ }
+ }
}
if m&(linux.STATX_ATIME|linux.STATX_MTIME) != 0 {
ts := [2]syscall.Timespec{
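Review bot: `memmap.MappableRange{newpgend, oldpgend}` is an unkeyed literal of an imported struct type, which go vet's composites check flags and which silently changes meaning if the field order ever changes. `memmap.MappableRange{Start: newpgend, End: oldpgend}` is safer. Separately, `oldSize` comes from `hostStat`, which is read outside this hunk. If the host file can be resized through another descriptor between that stat and the `Ftruncate`, the invalidated range may not cover every page that was dropped, so a short comment stating that assumption would help.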
@@ -464,9 +485,6 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u
return vfsfd, nil
}
- // For simplicity, set offset to 0. Technically, we should
- // only set to 0 on files that are not seekable (sockets, pipes, etc.),
- // and use the offset from the host fd otherwise.
fd := &fileDescription{inode: i}
vfsfd := &fd.vfsfd
if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil {
@@ -487,6 +505,13 @@ type fileDescription struct {
//
// inode is immutable after fileDescription creation.
inode *inode
+
+ // offsetMu protects offset.
+ offsetMu sync.Mutex
+
+ // offset specifies the current file offset. It is only meaningful when
+ // inode.seekable is true.
+ offset int64
}
// SetStat implements vfs.FileDescriptionImpl.
@@ -532,10 +557,10 @@ func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts
return n, err
}
// TODO(gvisor.dev/issue/1672): Cache pages, when forced to do so.
- i.offsetMu.Lock()
- n, err := readFromHostFD(ctx, i.hostFD, dst, i.offset, opts.Flags)
- i.offset += n
- i.offsetMu.Unlock()
+ f.offsetMu.Lock()
+ n, err := readFromHostFD(ctx, i.hostFD, dst, f.offset, opts.Flags)
+ f.offset += n
+ f.offsetMu.Unlock()
return n, err
}
@@ -572,10 +597,10 @@ func (f *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opt
}
// TODO(gvisor.dev/issue/1672): Cache pages, when forced to do so.
// TODO(gvisor.dev/issue/1672): Write to end of file and update offset if O_APPEND is set on this file.
- i.offsetMu.Lock()
- n, err := writeToHostFD(ctx, i.hostFD, src, i.offset, opts.Flags)
- i.offset += n
- i.offsetMu.Unlock()
+ f.offsetMu.Lock()
+ n, err := writeToHostFD(ctx, i.hostFD, src, f.offset, opts.Flags)
+ f.offset += n
+ f.offsetMu.Unlock()
return n, err
}
@@ -600,41 +625,41 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i
return 0, syserror.ESPIPE
}
- i.offsetMu.Lock()
- defer i.offsetMu.Unlock()
+ f.offsetMu.Lock()
+ defer f.offsetMu.Unlock()
switch whence {
case linux.SEEK_SET:
if offset < 0 {
- return i.offset, syserror.EINVAL
+ return f.offset, syserror.EINVAL
}
- i.offset = offset
+ f.offset = offset
case linux.SEEK_CUR:
- // Check for overflow. Note that underflow cannot occur, since i.offset >= 0.
- if offset > math.MaxInt64-i.offset {
- return i.offset, syserror.EOVERFLOW
+ // Check for overflow. Note that underflow cannot occur, since f.offset >= 0.
+ if offset > math.MaxInt64-f.offset {
+ return f.offset, syserror.EOVERFLOW
}
- if i.offset+offset < 0 {
- return i.offset, syserror.EINVAL
+ if f.offset+offset < 0 {
+ return f.offset, syserror.EINVAL
}
- i.offset += offset
+ f.offset += offset
case linux.SEEK_END:
var s syscall.Stat_t
if err := syscall.Fstat(i.hostFD, &s); err != nil {
- return i.offset, err
+ return f.offset, err
}
size := s.Size
// Check for overflow. Note that underflow cannot occur, since size >= 0.
if offset > math.MaxInt64-size {
- return i.offset, syserror.EOVERFLOW
+ return f.offset, syserror.EOVERFLOW
}
if size+offset < 0 {
- return i.offset, syserror.EINVAL
+ return f.offset, syserror.EINVAL
}
- i.offset = size + offset
+ f.offset = size + offset
case linux.SEEK_DATA, linux.SEEK_HOLE:
// Modifying the offset in the host file table should not matter, since
@@ -643,16 +668,16 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i
// For reading and writing, we always rely on our internal offset.
n, err := unix.Seek(i.hostFD, offset, int(whence))
if err != nil {
- return i.offset, err
+ return f.offset, err
}
- i.offset = n
+ f.offset = n
default:
// Invalid whence.
- return i.offset, syserror.EINVAL
+ return f.offset, syserror.EINVAL
}
- return i.offset, nil
+ return f.offset, nil
}
// Sync implements FileDescriptionImpl.
@@ -666,8 +691,9 @@ func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts
if !f.inode.canMap {
return syserror.ENODEV
}
- // TODO(gvisor.dev/issue/1672): Implement ConfigureMMap and Mappable interface.
- return syserror.ENODEV
+ i := f.inode
+ i.pf.fileMapperInitOnce.Do(i.pf.fileMapper.Init)
+ return vfs.GenericConfigureMMap(&f.vfsfd, i, opts)
}
// EventRegister implements waiter.Waitable.EventRegister.
diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/host/mmap.go
new file mode 100644
index 000000000..8545a82f0
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/mmap.go
@@ -0,0 +1,132 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// inodePlatformFile implements platform.File. It exists solely because inode
+// cannot implement both kernfs.Inode.IncRef and platform.File.IncRef.
+//
+// inodePlatformFile should only be used if inode.canMap is true.
+type inodePlatformFile struct {
+ *inode
+
+ // fdRefsMu protects fdRefs.
+ fdRefsMu sync.Mutex
+
+ // fdRefs counts references on platform.File offsets. It is used solely for
+ // memory accounting.
+ fdRefs fsutil.FrameRefSet
+
+ // fileMapper caches mappings of the host file represented by this inode.
+ fileMapper fsutil.HostFileMapper
+
+ // fileMapperInitOnce is used to lazily initialize fileMapper.
+ fileMapperInitOnce sync.Once
+}
+
+// IncRef implements platform.File.IncRef.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inodePlatformFile) IncRef(fr platform.FileRange) {
+ i.fdRefsMu.Lock()
+ i.fdRefs.IncRefAndAccount(fr)
+ i.fdRefsMu.Unlock()
+}
+
+// DecRef implements platform.File.DecRef.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inodePlatformFile) DecRef(fr platform.FileRange) {
+ i.fdRefsMu.Lock()
+ i.fdRefs.DecRefAndAccount(fr)
+ i.fdRefsMu.Unlock()
+}
+
+// MapInternal implements platform.File.MapInternal.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inodePlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+ return i.fileMapper.MapInternal(fr, i.hostFD, at.Write)
+}
+
+// FD implements platform.File.FD.
+func (i *inodePlatformFile) FD() int {
+ return i.hostFD
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inode) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ i.mapsMu.Lock()
+ mapped := i.mappings.AddMapping(ms, ar, offset, writable)
+ for _, r := range mapped {
+ i.pf.fileMapper.IncRefOn(r)
+ }
+ i.mapsMu.Unlock()
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inode) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ i.mapsMu.Lock()
+ unmapped := i.mappings.RemoveMapping(ms, ar, offset, writable)
+ for _, r := range unmapped {
+ i.pf.fileMapper.DecRefOn(r)
+ }
+ i.mapsMu.Unlock()
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inode) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ return i.AddMapping(ctx, ms, dstAR, offset, writable)
+}
+
+// Translate implements memmap.Mappable.Translate.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inode) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ mr := optional
+ return []memmap.Translation{
+ {
+ Source: mr,
+ File: &i.pf,
+ Offset: mr.Start,
+ Perms: usermem.AnyAccess,
+ },
+ }, nil
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inode) InvalidateUnsavable(ctx context.Context) error {
+ // We expect the same host fd across save/restore, so all translations
+ // should be valid.
+ return nil
+}
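Review bot: `(*inode).Translate` near the end of this file returns the whole `optional` range with `Perms: usermem.AnyAccess` and never consults the file size or `at`. The gofer Translate touched earlier in this commit instead constrains translations to the page-rounded size "to prevent translation to pages that may be concurrently truncated". If a clamp is unnecessary here because accesses past EOF fault on the host mapping and SetStat invalidates on shrink, the function should say so, e.g. `// No size clamp: the host mapping faults beyond EOF and SetStat invalidates mappings on truncation.`. I cannot confirm that reasoning from this hunk. It would also help to note why `AnyAccess` is correct when the host FD may have been opened read-only, since `MapInternal` passes `at.Write` straight to the file mapper.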
diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD
index a2d9649e7..007be1572 100644
--- a/pkg/sentry/fsimpl/tmpfs/BUILD
+++ b/pkg/sentry/fsimpl/tmpfs/BUILD
@@ -52,7 +52,6 @@ go_library(
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/lock",
- "//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/pipe",
"//pkg/sentry/kernel/time",
@@ -96,6 +95,7 @@ go_test(
"pipe_test.go",
"regular_file_test.go",
"stat_test.go",
+ "tmpfs_test.go",
],
library = ":tmpfs",
deps = [
@@ -105,7 +105,6 @@ go_test(
"//pkg/sentry/contexttest",
"//pkg/sentry/fs/lock",
"//pkg/sentry/kernel/auth",
- "//pkg/sentry/kernel/contexttest",
"//pkg/sentry/vfs",
"//pkg/syserror",
"//pkg/usermem",
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index 36ffcb592..80fa7b29d 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -16,6 +16,7 @@ package tmpfs
import (
"fmt"
+ "sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -24,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// Sync implements vfs.FilesystemImpl.Sync.
@@ -76,8 +78,8 @@ afterSymlink:
return nil, err
}
if symlink, ok := child.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() {
- // TODO(gvisor.dev/issue/1197): Symlink traversals updates
- // access time.
+ // Symlink traversal updates access time.
+ atomic.StoreInt64(&d.inode.atime, d.inode.fs.clock.Now().Nanoseconds())
if err := rp.HandleSymlink(symlink.target); err != nil {
return nil, err
}
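Review bot: The atime store added here targets `d.inode`, but the symlink being followed is `child`, since the type assertion on the line above is on `child.inode.impl`. The matching hunk at `afterTrailingSymlink` stores to `child.inode.atime`. Unless updating the containing dentry is intended, this should read `atomic.StoreInt64(&child.inode.atime, child.inode.fs.clock.Now().Nanoseconds())`. The enclosing function starts and ends outside the hunk, so the role of `d` is inferred from the surrounding lines.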
@@ -361,8 +363,8 @@ afterTrailingSymlink:
}
// Do we need to resolve a trailing symlink?
if symlink, ok := child.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() {
- // TODO(gvisor.dev/issue/1197): Symlink traversals updates
- // access time.
+ // Symlink traversal updates access time.
+ atomic.StoreInt64(&child.inode.atime, child.inode.fs.clock.Now().Nanoseconds())
if err := rp.HandleSymlink(symlink.target); err != nil {
return nil, err
}
@@ -636,12 +638,19 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
fs.mu.RLock()
defer fs.mu.RUnlock()
- _, err := resolveLocked(rp)
- if err != nil {
+ if _, err := resolveLocked(rp); err != nil {
return linux.Statfs{}, err
}
- // TODO(gvisor.dev/issue/1197): Actually implement statfs.
- return linux.Statfs{}, syserror.ENOSYS
+ statfs := linux.Statfs{
+ Type: linux.TMPFS_MAGIC,
+ BlockSize: usermem.PageSize,
+ FragmentSize: usermem.PageSize,
+ NameLength: linux.NAME_MAX,
+ // TODO(b/29637826): Allow configuring a tmpfs size and enforce it.
+ Blocks: 0,
+ BlocksFree: 0,
+ }
+ return statfs, nil
}
// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
@@ -763,5 +772,24 @@ func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath,
func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
fs.mu.RLock()
defer fs.mu.RUnlock()
- return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b)
+ mnt := vd.Mount()
+ d := vd.Dentry().Impl().(*dentry)
+ for {
+ if mnt == vfsroot.Mount() && &d.vfsd == vfsroot.Dentry() {
+ return vfs.PrependPathAtVFSRootError{}
+ }
+ if &d.vfsd == mnt.Root() {
+ return nil
+ }
+ if d.parent == nil {
+ if d.name != "" {
+ // This must be an anonymous memfd file.
+ b.PrependComponent("/" + d.name)
+ return vfs.PrependPathSyntheticError{}
+ }
+ return vfs.PrependPathAtNonMountRootError{}
+ }
+ b.PrependComponent(d.name)
+ d = d.parent
+ }
}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index 57e5e28ec..3f433d666 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -88,6 +88,7 @@ type regularFile struct {
func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode {
file := &regularFile{
memFile: fs.memFile,
+ seals: linux.F_SEAL_SEAL,
}
file.inode.init(file, fs, creds, linux.S_IFREG|mode)
file.inode.nlink = 1 // from parent directory
@@ -577,3 +578,44 @@ exitLoop:
return done, retErr
}
+
+// GetSeals returns the current set of seals on a memfd inode.
+func GetSeals(fd *vfs.FileDescription) (uint32, error) {
+ f, ok := fd.Impl().(*regularFileFD)
+ if !ok {
+ return 0, syserror.EINVAL
+ }
+ rf := f.inode().impl.(*regularFile)
+ rf.dataMu.RLock()
+ defer rf.dataMu.RUnlock()
+ return rf.seals, nil
+}
+
+// AddSeals adds new file seals to a memfd inode.
+func AddSeals(fd *vfs.FileDescription, val uint32) error {
+ f, ok := fd.Impl().(*regularFileFD)
+ if !ok {
+ return syserror.EINVAL
+ }
+ rf := f.inode().impl.(*regularFile)
+ rf.mapsMu.Lock()
+ defer rf.mapsMu.Unlock()
+ rf.dataMu.RLock()
+ defer rf.dataMu.RUnlock()
+
+ if rf.seals&linux.F_SEAL_SEAL != 0 {
+ // Seal applied which prevents addition of any new seals.
+ return syserror.EPERM
+ }
+
+ // F_SEAL_WRITE can only be added if there are no active writable maps.
+ if rf.seals&linux.F_SEAL_WRITE == 0 && val&linux.F_SEAL_WRITE != 0 {
+ if rf.writableMappingPages > 0 {
+ return syserror.EBUSY
+ }
+ }
+
+ // Seals can only be added, never removed.
+ rf.seals |= val
+ return nil
+}
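Review bot: AddSeals performs `rf.seals |= val` while holding `rf.dataMu` only for reading, so the store can run concurrently with GetSeals and with any other path that reads `seals` under `dataMu.RLock()`. That is a data race on the very bits that enforce the seal. `mapsMu` serialises AddSeals callers against each other, but GetSeals does not take it. The write needs the exclusive lock, `rf.dataMu.Lock()` with `defer rf.dataMu.Unlock()` in place of the RLock pair. `val` is also OR-ed in without masking to known seal bits, and whether the caller validates it is not visible in this hunk.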
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
index 0399725cf..64e1c40ad 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
@@ -18,152 +18,16 @@ import (
"bytes"
"fmt"
"io"
- "sync/atomic"
"testing"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
- "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
-// nextFileID is used to generate unique file names.
-var nextFileID int64
-
-// newTmpfsRoot creates a new tmpfs mount, and returns the root. If the error
-// is not nil, then cleanup should be called when the root is no longer needed.
-func newTmpfsRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentry, func(), error) {
- creds := auth.CredentialsFromContext(ctx)
-
- vfsObj := &vfs.VirtualFilesystem{}
- if err := vfsObj.Init(); err != nil {
- return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err)
- }
-
- vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
- AllowUserMount: true,
- })
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
- if err != nil {
- return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("failed to create tmpfs root mount: %v", err)
- }
- root := mntns.Root()
- return vfsObj, root, func() {
- root.DecRef()
- mntns.DecRef()
- }, nil
-}
-
-// newFileFD creates a new file in a new tmpfs mount, and returns the FD. If
-// the returned err is not nil, then cleanup should be called when the FD is no
-// longer needed.
-func newFileFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) {
- creds := auth.CredentialsFromContext(ctx)
- vfsObj, root, cleanup, err := newTmpfsRoot(ctx)
- if err != nil {
- return nil, nil, err
- }
-
- filename := fmt.Sprintf("tmpfs-test-file-%d", atomic.AddInt64(&nextFileID, 1))
-
- // Create the file that will be write/read.
- fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(filename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
- Mode: linux.ModeRegular | mode,
- })
- if err != nil {
- cleanup()
- return nil, nil, fmt.Errorf("failed to create file %q: %v", filename, err)
- }
-
- return fd, cleanup, nil
-}
-
-// newDirFD is like newFileFD, but for directories.
-func newDirFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) {
- creds := auth.CredentialsFromContext(ctx)
- vfsObj, root, cleanup, err := newTmpfsRoot(ctx)
- if err != nil {
- return nil, nil, err
- }
-
- dirname := fmt.Sprintf("tmpfs-test-dir-%d", atomic.AddInt64(&nextFileID, 1))
-
- // Create the dir.
- if err := vfsObj.MkdirAt(ctx, creds, &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(dirname),
- }, &vfs.MkdirOptions{
- Mode: linux.ModeDirectory | mode,
- }); err != nil {
- cleanup()
- return nil, nil, fmt.Errorf("failed to create directory %q: %v", dirname, err)
- }
-
- // Open the dir and return it.
- fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(dirname),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY | linux.O_DIRECTORY,
- })
- if err != nil {
- cleanup()
- return nil, nil, fmt.Errorf("failed to open directory %q: %v", dirname, err)
- }
-
- return fd, cleanup, nil
-}
-
-// newPipeFD is like newFileFD, but for pipes.
-func newPipeFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) {
- creds := auth.CredentialsFromContext(ctx)
- vfsObj, root, cleanup, err := newTmpfsRoot(ctx)
- if err != nil {
- return nil, nil, err
- }
-
- pipename := fmt.Sprintf("tmpfs-test-pipe-%d", atomic.AddInt64(&nextFileID, 1))
-
- // Create the pipe.
- if err := vfsObj.MknodAt(ctx, creds, &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(pipename),
- }, &vfs.MknodOptions{
- Mode: linux.ModeNamedPipe | mode,
- }); err != nil {
- cleanup()
- return nil, nil, fmt.Errorf("failed to create pipe %q: %v", pipename, err)
- }
-
- // Open the pipe and return it.
- fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(pipename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR,
- })
- if err != nil {
- cleanup()
- return nil, nil, fmt.Errorf("failed to open pipe %q: %v", pipename, err)
- }
-
- return fd, cleanup, nil
-}
-
// Test that we can write some data to a file and read it back.`
func TestSimpleWriteRead(t *testing.T) {
ctx := contexttest.Context(t)
diff --git a/pkg/sentry/fsimpl/tmpfs/stat_test.go b/pkg/sentry/fsimpl/tmpfs/stat_test.go
index 60c2c980e..f7ee4aab2 100644
--- a/pkg/sentry/fsimpl/tmpfs/stat_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/stat_test.go
@@ -19,8 +19,8 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
"gvisor.dev/gvisor/pkg/sentry/vfs"
)
@@ -29,7 +29,6 @@ func TestStatAfterCreate(t *testing.T) {
mode := linux.FileMode(0644)
// Run with different file types.
- // TODO(gvisor.dev/issue/1197): Also test symlinks and sockets.
for _, typ := range []string{"file", "dir", "pipe"} {
t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) {
var (
@@ -175,7 +174,6 @@ func TestSetStat(t *testing.T) {
mode := linux.FileMode(0644)
// Run with different file types.
- // TODO(gvisor.dev/issue/1197): Also test symlinks and sockets.
for _, typ := range []string{"file", "dir", "pipe"} {
t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) {
var (
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index 405928bd0..1e781aecd 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -94,7 +94,7 @@ type FilesystemOpts struct {
}
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
-func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, _ string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
memFileProvider := pgalloc.MemoryFileProviderFromContext(ctx)
if memFileProvider == nil {
panic("MemoryFileProviderFromContext returned nil")
@@ -139,6 +139,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return &fs.vfsfs, &root.vfsd, nil
}
+// NewFilesystem returns a new tmpfs filesystem.
+func NewFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*vfs.Filesystem, *vfs.Dentry, error) {
+ return FilesystemType{}.GetFilesystem(ctx, vfsObj, creds, "", vfs.GetFilesystemOptions{})
+}
+
// Release implements vfs.FilesystemImpl.Release.
func (fs *filesystem) Release() {
fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
@@ -658,3 +663,34 @@ func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOption
func (fd *fileDescription) Removexattr(ctx context.Context, name string) error {
return fd.inode().removexattr(auth.CredentialsFromContext(ctx), name)
}
+
+// NewMemfd creates a new tmpfs regular file and file description that can back
+// an anonymous fd created by memfd_create.
+func NewMemfd(mount *vfs.Mount, creds *auth.Credentials, allowSeals bool, name string) (*vfs.FileDescription, error) {
+ fs, ok := mount.Filesystem().Impl().(*filesystem)
+ if !ok {
+ panic("NewMemfd() called with non-tmpfs mount")
+ }
+
+ // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd inodes are set up with
+ // S_IRWXUGO.
+ mode := linux.FileMode(0777)
+ inode := fs.newRegularFile(creds, mode)
+ rf := inode.impl.(*regularFile)
+ if allowSeals {
+ rf.seals = 0
+ }
+
+ d := fs.newDentry(inode)
+ defer d.DecRef()
+ d.name = name
+
+ // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd files are set up with
+ // FMODE_READ | FMODE_WRITE.
+ var fd regularFileFD
+ flags := uint32(linux.O_RDWR)
+ if err := fd.vfsfd.Init(&fd, flags, mount, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
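Review bot: NewMemfd's doc comment does not say what `allowSeals` does, and the behaviour is only visible by reading two files. `newRegularFile` now initialises `seals` to `linux.F_SEAL_SEAL` and this function clears it only when `allowSeals` is true, so with `allowSeals` false every later AddSeals returns EPERM. I would add that to the comment, e.g. `// If allowSeals is false the file keeps F_SEAL_SEAL, so no seals can ever be added.`. Since the function already returns an error, the non-tmpfs mount case could return one rather than panic, if callers can pass an arbitrary mount.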
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go
new file mode 100644
index 000000000..a240fb276
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go
@@ -0,0 +1,156 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// nextFileID is used to generate unique file names.
+var nextFileID int64
+
+// newTmpfsRoot creates a new tmpfs mount, and returns the root. If the error
+// is not nil, then cleanup should be called when the root is no longer needed.
+func newTmpfsRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentry, func(), error) {
+ creds := auth.CredentialsFromContext(ctx)
+
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err)
+ }
+
+ vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
+ if err != nil {
+ return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("failed to create tmpfs root mount: %v", err)
+ }
+ root := mntns.Root()
+ return vfsObj, root, func() {
+ root.DecRef()
+ mntns.DecRef()
+ }, nil
+}
+
+// newFileFD creates a new file in a new tmpfs mount, and returns the FD. If
+// the returned err is not nil, then cleanup should be called when the FD is no
+// longer needed.
+func newFileFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) {
+ creds := auth.CredentialsFromContext(ctx)
+ vfsObj, root, cleanup, err := newTmpfsRoot(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ filename := fmt.Sprintf("tmpfs-test-file-%d", atomic.AddInt64(&nextFileID, 1))
+
+ // Create the file that will be write/read.
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
+ Mode: linux.ModeRegular | mode,
+ })
+ if err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to create file %q: %v", filename, err)
+ }
+
+ return fd, cleanup, nil
+}
+
+// newDirFD is like newFileFD, but for directories.
+func newDirFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) {
+ creds := auth.CredentialsFromContext(ctx)
+ vfsObj, root, cleanup, err := newTmpfsRoot(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ dirname := fmt.Sprintf("tmpfs-test-dir-%d", atomic.AddInt64(&nextFileID, 1))
+
+ // Create the dir.
+ if err := vfsObj.MkdirAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(dirname),
+ }, &vfs.MkdirOptions{
+ Mode: linux.ModeDirectory | mode,
+ }); err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to create directory %q: %v", dirname, err)
+ }
+
+ // Open the dir and return it.
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(dirname),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY | linux.O_DIRECTORY,
+ })
+ if err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to open directory %q: %v", dirname, err)
+ }
+
+ return fd, cleanup, nil
+}
+
+// newPipeFD is like newFileFD, but for pipes.
+func newPipeFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) {
+ creds := auth.CredentialsFromContext(ctx)
+ vfsObj, root, cleanup, err := newTmpfsRoot(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ name := fmt.Sprintf("tmpfs-test-%d", atomic.AddInt64(&nextFileID, 1))
+
+ if err := vfsObj.MknodAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(name),
+ }, &vfs.MknodOptions{
+ Mode: linux.ModeNamedPipe | mode,
+ }); err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to create pipe %q: %v", name, err)
+ }
+
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(name),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to open pipe %q: %v", name, err)
+ }
+
+ return fd, cleanup, nil
+}
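Review bot: The doc comments on `newTmpfsRoot` and `newFileFD` have the cleanup condition inverted. They say cleanup should be called "if the error is not nil", yet on every error path the helper either returns a nil cleanup func or has already called `cleanup()` itself. The wording should be along the lines of `// If the returned error is nil, cleanup must be called when the root is no longer needed.`, and likewise for the FD. `newDirFD` and `newPipeFD` inherit the same wording through "is like newFileFD".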
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index 8104f50f3..a28eab8b8 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -173,6 +173,7 @@ go_library(
"//pkg/sentry/fsimpl/pipefs",
"//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/fsimpl/timerfd",
+ "//pkg/sentry/fsimpl/tmpfs",
"//pkg/sentry/hostcpu",
"//pkg/sentry/inet",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 3617da8c6..5efeb3767 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -53,6 +53,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fsimpl/pipefs"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/timerfd"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/hostcpu"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -259,6 +260,10 @@ type Kernel struct {
// syscalls (as opposed to named pipes created by mknod()).
pipeMount *vfs.Mount
+ // shmMount is the Mount used for anonymous files created by the
+ // memfd_create() syscalls. It is analagous to Linux's shm_mnt.
+ shmMount *vfs.Mount
+
// socketMount is the Mount used for sockets created by the socket() and
// socketpair() syscalls. There are several cases where a socket dentry will
// not be contained in socketMount:
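Review bot: The comment on the new `shmMount` field misspells "analogous" and says "syscalls" although only memfd_create() is named. Suggested wording: `// shmMount is the Mount used for anonymous files created by the memfd_create() syscall. It is analogous to Linux's shm_mnt.` The Kernel struct starts and ends outside the hunk.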
@@ -330,6 +335,9 @@ func (k *Kernel) Init(args InitKernelArgs) error {
if args.Timekeeper == nil {
return fmt.Errorf("Timekeeper is nil")
}
+ if args.Timekeeper.clocks == nil {
+ return fmt.Errorf("Must call Timekeeper.SetClocks() before Kernel.Init()")
+ }
if args.RootUserNamespace == nil {
return fmt.Errorf("RootUserNamespace is nil")
}
@@ -384,6 +392,18 @@ func (k *Kernel) Init(args InitKernelArgs) error {
}
k.pipeMount = pipeMount
+ tmpfsFilesystem, tmpfsRoot, err := tmpfs.NewFilesystem(k.SupervisorContext(), &k.vfs, auth.NewRootCredentials(k.rootUserNamespace))
+ if err != nil {
+ return fmt.Errorf("failed to create tmpfs filesystem: %v", err)
+ }
+ defer tmpfsFilesystem.DecRef()
+ defer tmpfsRoot.DecRef()
+ shmMount, err := k.vfs.NewDisconnectedMount(tmpfsFilesystem, tmpfsRoot, &vfs.MountOptions{})
+ if err != nil {
+ return fmt.Errorf("failed to create tmpfs mount: %v", err)
+ }
+ k.shmMount = shmMount
+
socketFilesystem, err := sockfs.NewFilesystem(&k.vfs)
if err != nil {
return fmt.Errorf("failed to create sockfs filesystem: %v", err)
@@ -1656,6 +1676,11 @@ func (k *Kernel) PipeMount() *vfs.Mount {
return k.pipeMount
}
+// ShmMount returns the tmpfs mount.
+func (k *Kernel) ShmMount() *vfs.Mount {
+ return k.shmMount
+}
+
// SocketMount returns the sockfs mount.
func (k *Kernel) SocketMount() *vfs.Mount {
return k.socketMount
diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go
index 5a1d4fd57..aacf28da2 100644
--- a/pkg/sentry/kernel/pipe/pipe_util.go
+++ b/pkg/sentry/kernel/pipe/pipe_util.go
@@ -144,7 +144,7 @@ func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArgume
if v > math.MaxInt32 {
v = math.MaxInt32 // Silently truncate.
}
- // Copy result to user-space.
+ // Copy result to userspace.
_, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
AddressSpaceActive: true,
})
diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go
index c9db78e06..a5903b0b5 100644
--- a/pkg/sentry/kernel/task_syscall.go
+++ b/pkg/sentry/kernel/task_syscall.go
@@ -199,10 +199,10 @@ func (t *Task) doSyscall() taskRunState {
//
// On x86, register rax was shared by syscall number and return
// value, and at the entry of the syscall handler, the rax was
- // saved to regs.orig_rax which was exposed to user space.
+ // saved to regs.orig_rax which was exposed to userspace.
// But on arm64, syscall number was passed through X8, and the X0
// was shared by the first syscall argument and return value. The
- // X0 was saved to regs.orig_x0 which was not exposed to user space.
+ // X0 was saved to regs.orig_x0 which was not exposed to userspace.
// So we have to do the same operation here to save the X0 value
// into the task context.
t.Arch().SyscallSaveOrig()
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index 73591dab7..a036ce53c 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -25,6 +25,7 @@ go_template_instance(
out = "vma_set.go",
consts = {
"minDegree": "8",
+ "trackGaps": "1",
},
imports = {
"usermem": "gvisor.dev/gvisor/pkg/usermem",
diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go
index 9a14e69e6..16d8207e9 100644
--- a/pkg/sentry/mm/vma.go
+++ b/pkg/sentry/mm/vma.go
@@ -195,7 +195,7 @@ func (mm *MemoryManager) applicationAddrRange() usermem.AddrRange {
// Preconditions: mm.mappingMu must be locked.
func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bounds usermem.AddrRange) (usermem.Addr, error) {
- for gap := mm.vmas.LowerBoundGap(bounds.Start); gap.Ok() && gap.Start() < bounds.End; gap = gap.NextGap() {
+ for gap := mm.vmas.LowerBoundGap(bounds.Start); gap.Ok() && gap.Start() < bounds.End; gap = gap.NextLargeEnoughGap(usermem.Addr(length)) {
if gr := gap.availableRange().Intersect(bounds); uint64(gr.Length()) >= length {
// Can we shift up to match the alignment?
if offset := uint64(gr.Start) % alignment; offset != 0 {
@@ -214,7 +214,7 @@ func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bou
// Preconditions: mm.mappingMu must be locked.
func (mm *MemoryManager) findHighestAvailableLocked(length, alignment uint64, bounds usermem.AddrRange) (usermem.Addr, error) {
- for gap := mm.vmas.UpperBoundGap(bounds.End); gap.Ok() && gap.End() > bounds.Start; gap = gap.PrevGap() {
+ for gap := mm.vmas.UpperBoundGap(bounds.End); gap.Ok() && gap.End() > bounds.Start; gap = gap.PrevLargeEnoughGap(usermem.Addr(length)) {
if gr := gap.availableRange().Intersect(bounds); uint64(gr.Length()) >= length {
// Can we shift down to match the alignment?
start := gr.End - usermem.Addr(length)
diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go
index 444a83913..a6345010d 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.go
+++ b/pkg/sentry/platform/ring0/lib_arm64.go
@@ -38,6 +38,12 @@ func SaveVRegs(*byte)
// LoadVRegs loads V0-V31 registers.
func LoadVRegs(*byte)
+// GetTLS returns the value of TPIDR_EL0 register.
+func GetTLS() (value uint64)
+
+// SetTLS writes the TPIDR_EL0 value.
+func SetTLS(value uint64)
+
// Init sets function pointers based on architectural features.
//
// This must be called prior to using ring0.
diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s
index 0e6a6235b..b63e14b41 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.s
+++ b/pkg/sentry/platform/ring0/lib_arm64.s
@@ -15,6 +15,16 @@
#include "funcdata.h"
#include "textflag.h"
+TEXT ·GetTLS(SB),NOSPLIT,$0-8
+ MRS TPIDR_EL0, R1
+ MOVD R1, ret+0(FP)
+ RET
+
+TEXT ·SetTLS(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R1
+ MSR R1, TPIDR_EL0
+ RET
+
TEXT ·CPACREL1(SB),NOSPLIT,$0-8
WORD $0xd5381041 // MRS CPACR_EL1, R1
MOVD R1, ret+0(FP)
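Review bot: `SetTLS` loads its argument as `addr+0(FP)`, but the declaration added in lib_arm64.go names the parameter `value` (`func SetTLS(value uint64)`). The assembler only uses the offset, so the stub works, but the asmdecl check in `go vet` compares the symbolic name with the Go declaration and would likely report the mismatch. Writing `MOVD value+0(FP), R1` keeps the stub and its declaration in agreement.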
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index b49433326..c11e82c10 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -555,7 +555,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
if uint64(src.NumBytes()) != srcs.NumBytes() {
return 0, nil
}
- if srcs.IsEmpty() {
+ if srcs.IsEmpty() && len(controlBuf) == 0 {
return 0, nil
}
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 9d032f052..60df51dae 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1321,6 +1321,29 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return int32(time.Duration(v) / time.Second), nil
+ case linux.TCP_SYNCNT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptInt(tcpip.TCPSynCountOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.TCP_WINDOW_CLAMP:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptInt(tcpip.TCPWindowClampOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
default:
emitUnimplementedEventTCP(t, name)
}
@@ -1790,6 +1813,22 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPDeferAcceptOption(time.Second * time.Duration(v))))
+ case linux.TCP_SYNCNT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := usermem.ByteOrder.Uint32(optVal)
+
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TCPSynCountOption, int(v)))
+
+ case linux.TCP_WINDOW_CLAMP:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := usermem.ByteOrder.Uint32(optVal)
+
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TCPWindowClampOption, int(v)))
+
case linux.TCP_REPAIR_OPTIONS:
t.Kernel().EmitUnimplementedEvent(t)
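Review bot: In the new `TCP_WINDOW_CLAMP` case the option value is decoded with `Uint32` and widened with `int(v)`, so a negative int supplied by the application reaches the endpoint as a large positive number wherever int is 64 bits. The matching SetSockOptInt hunk in pkg/tcpip/transport/tcp/endpoint.go only raises values below `rs.Min/2`, so such a value skips that floor and is stored as the clamp. Decoding with the sign kept, `v := int32(usermem.ByteOrder.Uint32(optVal))`, lets the lower bound apply. The same decode in the `TCP_SYNCNT` case is harmless only because the endpoint rejects anything outside 1–255. setSockOptTCP starts and ends outside the hunk.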
@@ -2679,7 +2718,7 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
v = math.MaxInt32
}
- // Copy result to user-space.
+ // Copy result to userspace.
_, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
AddressSpaceActive: true,
})
@@ -2748,7 +2787,7 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
if v > math.MaxInt32 {
v = math.MaxInt32
}
- // Copy result to user-space.
+ // Copy result to userspace.
_, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
AddressSpaceActive: true,
})
@@ -2764,7 +2803,7 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
v = math.MaxInt32
}
- // Copy result to user-space.
+ // Copy result to userspace.
_, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
AddressSpaceActive: true,
})
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index df0d0f461..39f2b79ec 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -16,7 +16,6 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -26,7 +25,6 @@ import (
// doSplice implements a blocking splice operation.
func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonBlocking bool) (int64, error) {
- log.Infof("NLAC: doSplice opts: %+v", opts)
if opts.Length < 0 || opts.SrcStart < 0 || opts.DstStart < 0 || (opts.SrcStart+opts.Length < 0) {
return 0, syserror.EINVAL
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index c32f942fb..f882ef840 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -13,6 +13,7 @@ go_library(
"fscontext.go",
"getdents.go",
"ioctl.go",
+ "memfd.go",
"mmap.go",
"path.go",
"pipe.go",
@@ -43,6 +44,7 @@ go_library(
"//pkg/sentry/fsimpl/pipefs",
"//pkg/sentry/fsimpl/signalfd",
"//pkg/sentry/fsimpl/timerfd",
+ "//pkg/sentry/fsimpl/tmpfs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/pipe",
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
index 8181d80f4..ca0f7fd1e 100644
--- a/pkg/sentry/syscalls/linux/vfs2/fd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -17,6 +17,7 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
@@ -157,6 +158,15 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, syserror.EBADF
}
return uintptr(pipefile.PipeSize()), nil, nil
+ case linux.F_GET_SEALS:
+ val, err := tmpfs.GetSeals(file)
+ return uintptr(val), nil, err
+ case linux.F_ADD_SEALS:
+ if !file.IsWritable() {
+ return 0, nil, syserror.EPERM
+ }
+ err := tmpfs.AddSeals(file, args[2].Uint())
+ return 0, nil, err
default:
// TODO(gvisor.dev/issue/1623): Everything else is not yet supported.
return 0, nil, syserror.EINVAL
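Review bot: The `F_ADD_SEALS` case reads the seal mask with `args[2].Uint()`, which appears to keep only the low 32 bits of the fcntl argument, so a value with upper bits set would be accepted after truncation. Linux's memfd fcntl path rejects such an argument with EINVAL before looking at the file mode. A guard ahead of the writability check, `if args[2].Uint64()>>32 != 0 { return 0, nil, syserror.EINVAL }`, would keep the sealing interface strict about its input. Fcntl starts and ends outside the hunk.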
diff --git a/pkg/sentry/syscalls/linux/vfs2/memfd.go b/pkg/sentry/syscalls/linux/vfs2/memfd.go
new file mode 100644
index 000000000..bbe248d17
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/memfd.go
@@ -0,0 +1,63 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const (
+ memfdPrefix = "memfd:"
+ memfdMaxNameLen = linux.NAME_MAX - len(memfdPrefix)
+ memfdAllFlags = uint32(linux.MFD_CLOEXEC | linux.MFD_ALLOW_SEALING)
+)
+
+// MemfdCreate implements the linux syscall memfd_create(2).
+func MemfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Uint()
+
+ if flags&^memfdAllFlags != 0 {
+ // Unknown bits in flags.
+ return 0, nil, syserror.EINVAL
+ }
+
+ allowSeals := flags&linux.MFD_ALLOW_SEALING != 0
+ cloExec := flags&linux.MFD_CLOEXEC != 0
+
+ name, err := t.CopyInString(addr, memfdMaxNameLen)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ shmMount := t.Kernel().ShmMount()
+ file, err := tmpfs.NewMemfd(shmMount, t.Credentials(), allowSeals, memfdPrefix+name)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: cloExec,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
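Review bot: MemfdCreate passes through whatever error `t.CopyInString(addr, memfdMaxNameLen)` returns for an unterminated or over-long name, while memfd_create(2) documents EINVAL for a name that is too long. CopyInString is not visible here, so its error code for that case should be confirmed and mapped if it differs. The const block also has no comments. A line such as `// memfdMaxNameLen is the longest user-supplied name that still fits in NAME_MAX once memfdPrefix is prepended.` would explain the bound applied to the user-controlled string.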
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
index 4e61f1452..09ecfed26 100644
--- a/pkg/sentry/syscalls/linux/vfs2/setstat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -246,73 +246,104 @@ func Utimes(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, err
}
- opts := vfs.SetStatOptions{
- Stat: linux.Statx{
- Mask: linux.STATX_ATIME | linux.STATX_MTIME,
- },
- }
- if timesAddr == 0 {
- opts.Stat.Atime.Nsec = linux.UTIME_NOW
- opts.Stat.Mtime.Nsec = linux.UTIME_NOW
- } else {
- var times [2]linux.Timeval
- if _, err := t.CopyIn(timesAddr, &times); err != nil {
- return 0, nil, err
- }
- opts.Stat.Atime = linux.StatxTimestamp{
- Sec: times[0].Sec,
- Nsec: uint32(times[0].Usec * 1000),
- }
- opts.Stat.Mtime = linux.StatxTimestamp{
- Sec: times[1].Sec,
- Nsec: uint32(times[1].Usec * 1000),
- }
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForUtimes(t, timesAddr, &opts); err != nil {
+ return 0, nil, err
}
return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &opts)
}
-// Utimensat implements Linux syscall utimensat(2).
-func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+// Futimesat implements Linux syscall futimesat(2).
+func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
dirfd := args[0].Int()
pathAddr := args[1].Pointer()
timesAddr := args[2].Pointer()
- flags := args[3].Int()
- if flags&^linux.AT_SYMLINK_NOFOLLOW != 0 {
- return 0, nil, syserror.EINVAL
- }
-
- path, err := copyInPath(t, pathAddr)
- if err != nil {
- return 0, nil, err
+ // "If filename is NULL and dfd refers to an open file, then operate on the
+ // file. Otherwise look up filename, possibly using dfd as a starting
+ // point." - fs/utimes.c
+ var path fspath.Path
+ shouldAllowEmptyPath := allowEmptyPath
+ if dirfd == linux.AT_FDCWD || pathAddr != 0 {
+ var err error
+ path, err = copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ shouldAllowEmptyPath = disallowEmptyPath
}
var opts vfs.SetStatOptions
- if err := populateSetStatOptionsForUtimens(t, timesAddr, &opts); err != nil {
+ if err := populateSetStatOptionsForUtimes(t, timesAddr, &opts); err != nil {
return 0, nil, err
}
- return 0, nil, setstatat(t, dirfd, path, disallowEmptyPath, followFinalSymlink, &opts)
+ return 0, nil, setstatat(t, dirfd, path, shouldAllowEmptyPath, followFinalSymlink, &opts)
}
-// Futimens implements Linux syscall futimens(2).
-func Futimens(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- fd := args[0].Int()
- timesAddr := args[1].Pointer()
-
- file := t.GetFileVFS2(fd)
- if file == nil {
- return 0, nil, syserror.EBADF
+func populateSetStatOptionsForUtimes(t *kernel.Task, timesAddr usermem.Addr, opts *vfs.SetStatOptions) error {
+ if timesAddr == 0 {
+ opts.Stat.Mask = linux.STATX_ATIME | linux.STATX_MTIME
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ return nil
}
- defer file.DecRef()
+ var times [2]linux.Timeval
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return err
+ }
+ if times[0].Usec < 0 || times[0].Usec > 999999 || times[1].Usec < 0 || times[1].Usec > 999999 {
+ return syserror.EINVAL
+ }
+ opts.Stat.Mask = linux.STATX_ATIME | linux.STATX_MTIME
+ opts.Stat.Atime = linux.StatxTimestamp{
+ Sec: times[0].Sec,
+ Nsec: uint32(times[0].Usec * 1000),
+ }
+ opts.Stat.Mtime = linux.StatxTimestamp{
+ Sec: times[1].Sec,
+ Nsec: uint32(times[1].Usec * 1000),
+ }
+ return nil
+}
+// Utimensat implements Linux syscall utimensat(2).
+func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ timesAddr := args[2].Pointer()
+ flags := args[3].Int()
+
+ // Linux requires that the UTIME_OMIT check occur before checking path or
+ // flags.
var opts vfs.SetStatOptions
if err := populateSetStatOptionsForUtimens(t, timesAddr, &opts); err != nil {
return 0, nil, err
}
+ if opts.Stat.Mask == 0 {
+ return 0, nil, nil
+ }
- return 0, nil, file.SetStat(t, opts)
+ if flags&^linux.AT_SYMLINK_NOFOLLOW != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // "If filename is NULL and dfd refers to an open file, then operate on the
+ // file. Otherwise look up filename, possibly using dfd as a starting
+ // point." - fs/utimes.c
+ var path fspath.Path
+ shouldAllowEmptyPath := allowEmptyPath
+ if dirfd == linux.AT_FDCWD || pathAddr != 0 {
+ var err error
+ path, err = copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ shouldAllowEmptyPath = disallowEmptyPath
+ }
+
+ return 0, nil, setstatat(t, dirfd, path, shouldAllowEmptyPath, shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_NOFOLLOW == 0), &opts)
}
func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, opts *vfs.SetStatOptions) error {
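Review bot: The new helper populateSetStatOptionsForUtimes has no doc comment, although it is where the user-supplied timeval pair is validated. Suggested: `// populateSetStatOptionsForUtimes fills opts from the [2]Timeval at timesAddr, or selects UTIME_NOW for both timestamps when timesAddr is 0; a tv_usec outside 0..999999 yields EINVAL.` The NULL-filename block is also repeated verbatim in Futimesat and Utimensat. Moving it into one shared helper would stop the two syscalls drifting apart in how they treat an empty path.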
@@ -327,6 +358,9 @@ func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, op
return err
}
if times[0].Nsec != linux.UTIME_OMIT {
+ if times[0].Nsec != linux.UTIME_NOW && (times[0].Nsec < 0 || times[0].Nsec > 999999999) {
+ return syserror.EINVAL
+ }
opts.Stat.Mask |= linux.STATX_ATIME
opts.Stat.Atime = linux.StatxTimestamp{
Sec: times[0].Sec,
@@ -334,6 +368,9 @@ func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, op
}
}
if times[1].Nsec != linux.UTIME_OMIT {
+ if times[1].Nsec != linux.UTIME_NOW && (times[1].Nsec < 0 || times[1].Nsec > 999999999) {
+ return syserror.EINVAL
+ }
opts.Stat.Mask |= linux.STATX_MTIME
opts.Stat.Mtime = linux.StatxTimestamp{
Sec: times[1].Sec,
diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
index 9c04677f1..a332d01bd 100644
--- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go
+++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
@@ -123,7 +123,7 @@ func Override() {
s.Table[258] = syscalls.Supported("mkdirat", Mkdirat)
s.Table[259] = syscalls.Supported("mknodat", Mknodat)
s.Table[260] = syscalls.Supported("fchownat", Fchownat)
- s.Table[261] = syscalls.Supported("futimens", Futimens)
+ s.Table[261] = syscalls.Supported("futimesat", Futimesat)
s.Table[262] = syscalls.Supported("newfstatat", Newfstatat)
s.Table[263] = syscalls.Supported("unlinkat", Unlinkat)
s.Table[264] = syscalls.Supported("renameat", Renameat)
@@ -158,7 +158,7 @@ func Override() {
s.Table[306] = syscalls.Supported("syncfs", Syncfs)
s.Table[307] = syscalls.Supported("sendmmsg", SendMMsg)
s.Table[316] = syscalls.Supported("renameat2", Renameat2)
- delete(s.Table, 319) // memfd_create
+ s.Table[319] = syscalls.Supported("memfd_create", MemfdCreate)
s.Table[322] = syscalls.Supported("execveat", Execveat)
s.Table[327] = syscalls.Supported("preadv2", Preadv2)
s.Table[328] = syscalls.Supported("pwritev2", Pwritev2)
diff --git a/pkg/state/BUILD b/pkg/state/BUILD
index 921af9d63..2b1350135 100644
--- a/pkg/state/BUILD
+++ b/pkg/state/BUILD
@@ -47,6 +47,7 @@ go_library(
"state.go",
"stats.go",
],
+ marshal = False,
stateify = False,
visibility = ["//:sandbox"],
deps = [
diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD
index 0e35d7d17..d0d77e19c 100644
--- a/pkg/sync/BUILD
+++ b/pkg/sync/BUILD
@@ -39,6 +39,8 @@ go_library(
"seqcount.go",
"sync.go",
],
+ marshal = False,
+ stateify = False,
)
go_test(
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
index 29454c4b9..4c6f808e5 100644
--- a/pkg/tcpip/header/tcp.go
+++ b/pkg/tcpip/header/tcp.go
@@ -66,6 +66,14 @@ const (
TCPOptionSACK = 5
)
+// Option Lengths.
+const (
+ TCPOptionMSSLength = 4
+ TCPOptionTSLength = 10
+ TCPOptionWSLength = 3
+ TCPOptionSackPermittedLength = 2
+)
+
// TCPFields contains the fields of a TCP packet. It is used to describe the
// fields of a packet that needs to be encoded.
type TCPFields struct {
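Review bot: The new constant `TCPOptionSackPermittedLength` spells the initialism differently from the existing `TCPOptionSACKPermitted` and `EncodeSACKPermittedOption` that it sits beside. `TCPOptionSACKPermittedLength` would match them and Go's convention of keeping initialisms in one case. The block comment `// Option Lengths.` could also say that the values are the total encoded size of each option in bytes, kind and length octets included, since the encoders below both bounds-check against them and write them into `b[1]`.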
@@ -494,14 +502,11 @@ func ParseTCPOptions(b []byte) TCPOptions {
// returns without encoding anything. It returns the number of bytes written to
// the provided buffer.
func EncodeMSSOption(mss uint32, b []byte) int {
- // mssOptionSize is the number of bytes in a valid MSS option.
- const mssOptionSize = 4
-
- if len(b) < mssOptionSize {
+ if len(b) < TCPOptionMSSLength {
return 0
}
- b[0], b[1], b[2], b[3] = TCPOptionMSS, mssOptionSize, byte(mss>>8), byte(mss)
- return mssOptionSize
+ b[0], b[1], b[2], b[3] = TCPOptionMSS, TCPOptionMSSLength, byte(mss>>8), byte(mss)
+ return TCPOptionMSSLength
}
// EncodeWSOption encodes the WS TCP option with the WS value in the
@@ -509,10 +514,10 @@ func EncodeMSSOption(mss uint32, b []byte) int {
// returns without encoding anything. It returns the number of bytes written to
// the provided buffer.
func EncodeWSOption(ws int, b []byte) int {
- if len(b) < 3 {
+ if len(b) < TCPOptionWSLength {
return 0
}
- b[0], b[1], b[2] = TCPOptionWS, 3, uint8(ws)
+ b[0], b[1], b[2] = TCPOptionWS, TCPOptionWSLength, uint8(ws)
return int(b[1])
}
@@ -521,10 +526,10 @@ func EncodeWSOption(ws int, b []byte) int {
// just returns without encoding anything. It returns the number of bytes
// written to the provided buffer.
func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int {
- if len(b) < 10 {
+ if len(b) < TCPOptionTSLength {
return 0
}
- b[0], b[1] = TCPOptionTS, 10
+ b[0], b[1] = TCPOptionTS, TCPOptionTSLength
binary.BigEndian.PutUint32(b[2:], tsVal)
binary.BigEndian.PutUint32(b[6:], tsEcr)
return int(b[1])
@@ -535,11 +540,11 @@ func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int {
// encoding anything. It returns the number of bytes written to the provided
// buffer.
func EncodeSACKPermittedOption(b []byte) int {
- if len(b) < 2 {
+ if len(b) < TCPOptionSackPermittedLength {
return 0
}
- b[0], b[1] = TCPOptionSACKPermitted, 2
+ b[0], b[1] = TCPOptionSACKPermitted, TCPOptionSackPermittedLength
return int(b[1])
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index b39ffa9fb..0ab4c3e19 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -235,11 +235,11 @@ type RcvBufAutoTuneParams struct {
// was started.
MeasureTime time.Time
- // CopiedBytes is the number of bytes copied to user space since
+ // CopiedBytes is the number of bytes copied to userspace since
// this measure began.
CopiedBytes int
- // PrevCopiedBytes is the number of bytes copied to user space in
+ // PrevCopiedBytes is the number of bytes copied to userspace in
// the previous RTT period.
PrevCopiedBytes int
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 1ca4088c9..b7b227328 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -110,6 +110,71 @@ var (
ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"}
)
+var messageToError map[string]*Error
+
+var populate sync.Once
+
+// StringToError converts an error message to the error.
+func StringToError(s string) *Error {
+ populate.Do(func() {
+ var errors = []*Error{
+ ErrUnknownProtocol,
+ ErrUnknownNICID,
+ ErrUnknownDevice,
+ ErrUnknownProtocolOption,
+ ErrDuplicateNICID,
+ ErrDuplicateAddress,
+ ErrNoRoute,
+ ErrBadLinkEndpoint,
+ ErrAlreadyBound,
+ ErrInvalidEndpointState,
+ ErrAlreadyConnecting,
+ ErrAlreadyConnected,
+ ErrNoPortAvailable,
+ ErrPortInUse,
+ ErrBadLocalAddress,
+ ErrClosedForSend,
+ ErrClosedForReceive,
+ ErrWouldBlock,
+ ErrConnectionRefused,
+ ErrTimeout,
+ ErrAborted,
+ ErrConnectStarted,
+ ErrDestinationRequired,
+ ErrNotSupported,
+ ErrQueueSizeNotSupported,
+ ErrNotConnected,
+ ErrConnectionReset,
+ ErrConnectionAborted,
+ ErrNoSuchFile,
+ ErrInvalidOptionValue,
+ ErrNoLinkAddress,
+ ErrBadAddress,
+ ErrNetworkUnreachable,
+ ErrMessageTooLong,
+ ErrNoBufferSpace,
+ ErrBroadcastDisabled,
+ ErrNotPermitted,
+ ErrAddressFamilyNotSupported,
+ }
+
+ messageToError = make(map[string]*Error)
+ for _, e := range errors {
+ if messageToError[e.String()] != nil {
+ panic("tcpip errors with duplicated message: " + e.String())
+ }
+ messageToError[e.String()] = e
+ }
+ })
+
+ e, ok := messageToError[s]
+ if !ok {
+ panic("unknown error message: " + s)
+ }
+
+ return e
+}
+
// Errors related to Subnet
var (
errSubnetLengthMismatch = errors.New("subnet length of address and mask differ")
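Review bot: StringToError is now exported, but its doc comment does not say that it panics when the message is not in the hard-coded list. The callers in endpoint_state.go feed it strings read back from saved state, so a corrupt image or an error added to the package without being added to this list stops the process. The comment should state the contract, e.g. `// StringToError converts an error message to the error. It panics if s is not the message of a known tcpip error.` Inside the closure the local `errors` also shadows the imported `errors` package used just below; `errs` would avoid that.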
@@ -622,6 +687,19 @@ const (
//
// A zero value indicates the default.
TTLOption
+
+ // TCPSynCountOption is used by SetSockOpt/GetSockOpt to specify the number of
+ // SYN retransmits that TCP should send before aborting the attempt to
+ // connect. It cannot exceed 255.
+ //
+ // NOTE: This option is currently only stubbed out and is no-op.
+ TCPSynCountOption
+
+ // TCPWindowClampOption is used by SetSockOpt/GetSockOpt to bound the size
+ // of the advertised window to this value.
+ //
+ // NOTE: This option is currently only stubed out and is a no-op
+ TCPWindowClampOption
)
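Review bot: The NOTE lines on the two new options have small wording slips that will surface in generated docs: `is no-op` should read `is a no-op`, and `stubed` should read `stubbed` with a closing period. The following hunk has the same kind of slip in `TCPMaxRTOOption is use by`, which should be `is used by`. Since these comments are the only place a caller learns that TCP_SYNCNT and TCP_WINDOW_CLAMP are accepted but not yet enforced, they are worth getting exact.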
// ErrorOption is used in GetSockOpt to specify that the last error reported by
@@ -685,11 +763,23 @@ type TCPDeferAcceptOption time.Duration
// default MinRTO used by the Stack.
type TCPMinRTOOption time.Duration
+// TCPMaxRTOOption is use by SetSockOpt/GetSockOpt to allow overriding
+// default MaxRTO used by the Stack.
+type TCPMaxRTOOption time.Duration
+
+// TCPMaxRetriesOption is used by SetSockOpt/GetSockOpt to set/get the
+// maximum number of retransmits after which we time out the connection.
+type TCPMaxRetriesOption uint64
+
// TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify
// the number of endpoints that can be in SYN-RCVD state before the stack
// switches to using SYN cookies.
type TCPSynRcvdCountThresholdOption uint64
+// TCPSynRetriesOption is used by SetSockOpt/GetSockOpt to specify stack-wide
+// default for number of times SYN is retransmitted before aborting a connect.
+type TCPSynRetriesOption uint8
+
// MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a
// default interface for multicast.
type MulticastInterfaceOption struct {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 07d3e64c8..b5ba972f1 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -470,6 +470,17 @@ type endpoint struct {
// for this endpoint using the TCP_MAXSEG setsockopt.
userMSS uint16
+ // maxSynRetries is the maximum number of SYN retransmits that TCP should
+ // send before aborting the attempt to connect. It cannot exceed 255.
+ //
+ // NOTE: This is currently a no-op and does not change the SYN
+ // retransmissions.
+ maxSynRetries uint8
+
+ // windowClamp is used to bound the size of the advertised window to
+ // this value.
+ windowClamp uint32
+
// The following fields are used to manage the send buffer. When
// segments are ready to be sent, they are added to sndQueue and the
// protocol goroutine is signaled via sndWaker.
@@ -795,8 +806,10 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
interval: 75 * time.Second,
count: 9,
},
- uniqueID: s.UniqueID(),
- txHash: s.Rand().Uint32(),
+ uniqueID: s.UniqueID(),
+ txHash: s.Rand().Uint32(),
+ windowClamp: DefaultReceiveBufferSize,
+ maxSynRetries: DefaultSynRetries,
}
var ss SendBufferSizeOption
@@ -829,6 +842,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.tcpLingerTimeout = time.Duration(tcpLT)
}
+ var synRetries tcpip.TCPSynRetriesOption
+ if err := s.TransportProtocolOption(ProtocolNumber, &synRetries); err == nil {
+ e.maxSynRetries = uint8(synRetries)
+ }
+
if p := s.GetTCPProbe(); p != nil {
e.probe = p
}
@@ -1079,7 +1097,7 @@ func (e *endpoint) initialReceiveWindow() int {
}
// ModerateRecvBuf adjusts the receive buffer and the advertised window
-// based on the number of bytes copied to user space.
+// based on the number of bytes copied to userspace.
func (e *endpoint) ModerateRecvBuf(copied int) {
e.LockUser()
defer e.UnlockUser()
@@ -1603,6 +1621,36 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
e.ttl = uint8(v)
e.UnlockUser()
+ case tcpip.TCPSynCountOption:
+ if v < 1 || v > 255 {
+ return tcpip.ErrInvalidOptionValue
+ }
+ e.LockUser()
+ e.maxSynRetries = uint8(v)
+ e.UnlockUser()
+
+ case tcpip.TCPWindowClampOption:
+ if v == 0 {
+ e.LockUser()
+ switch e.EndpointState() {
+ case StateClose, StateInitial:
+ e.windowClamp = 0
+ e.UnlockUser()
+ return nil
+ default:
+ e.UnlockUser()
+ return tcpip.ErrInvalidOptionValue
+ }
+ }
+ var rs ReceiveBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if v < rs.Min/2 {
+ v = rs.Min / 2
+ }
+ }
+ e.LockUser()
+ e.windowClamp = uint32(v)
+ e.UnlockUser()
}
return nil
}
@@ -1826,6 +1874,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.UnlockUser()
return v, nil
+ case tcpip.TCPSynCountOption:
+ e.LockUser()
+ v := int(e.maxSynRetries)
+ e.UnlockUser()
+ return v, nil
+
+ case tcpip.TCPWindowClampOption:
+ e.LockUser()
+ v := int(e.windowClamp)
+ e.UnlockUser()
+ return v, nil
+
default:
return -1, tcpip.ErrUnknownProtocolOption
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 8b7562396..fc43c11e2 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -314,7 +314,7 @@ func (e *endpoint) loadLastError(s string) {
return
}
- e.lastError = loadError(s)
+ e.lastError = tcpip.StringToError(s)
}
// saveHardError is invoked by stateify.
@@ -332,71 +332,7 @@ func (e *EndpointInfo) loadHardError(s string) {
return
}
- e.HardError = loadError(s)
-}
-
-var messageToError map[string]*tcpip.Error
-
-var populate sync.Once
-
-func loadError(s string) *tcpip.Error {
- populate.Do(func() {
- var errors = []*tcpip.Error{
- tcpip.ErrUnknownProtocol,
- tcpip.ErrUnknownNICID,
- tcpip.ErrUnknownDevice,
- tcpip.ErrUnknownProtocolOption,
- tcpip.ErrDuplicateNICID,
- tcpip.ErrDuplicateAddress,
- tcpip.ErrNoRoute,
- tcpip.ErrBadLinkEndpoint,
- tcpip.ErrAlreadyBound,
- tcpip.ErrInvalidEndpointState,
- tcpip.ErrAlreadyConnecting,
- tcpip.ErrAlreadyConnected,
- tcpip.ErrNoPortAvailable,
- tcpip.ErrPortInUse,
- tcpip.ErrBadLocalAddress,
- tcpip.ErrClosedForSend,
- tcpip.ErrClosedForReceive,
- tcpip.ErrWouldBlock,
- tcpip.ErrConnectionRefused,
- tcpip.ErrTimeout,
- tcpip.ErrAborted,
- tcpip.ErrConnectStarted,
- tcpip.ErrDestinationRequired,
- tcpip.ErrNotSupported,
- tcpip.ErrQueueSizeNotSupported,
- tcpip.ErrNotConnected,
- tcpip.ErrConnectionReset,
- tcpip.ErrConnectionAborted,
- tcpip.ErrNoSuchFile,
- tcpip.ErrInvalidOptionValue,
- tcpip.ErrNoLinkAddress,
- tcpip.ErrBadAddress,
- tcpip.ErrNetworkUnreachable,
- tcpip.ErrMessageTooLong,
- tcpip.ErrNoBufferSpace,
- tcpip.ErrBroadcastDisabled,
- tcpip.ErrNotPermitted,
- tcpip.ErrAddressFamilyNotSupported,
- }
-
- messageToError = make(map[string]*tcpip.Error)
- for _, e := range errors {
- if messageToError[e.String()] != nil {
- panic("tcpip errors with duplicated message: " + e.String())
- }
- messageToError[e.String()] = e
- }
- })
-
- e, ok := messageToError[s]
- if !ok {
- panic("unknown error message: " + s)
- }
-
- return e
+ e.HardError = tcpip.StringToError(s)
}
// saveMeasureTime is invoked by stateify.
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index cfd9a4e8e..2a2a7ddeb 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -64,6 +64,10 @@ const (
// DefaultTCPTimeWaitTimeout is the amount of time that sockets linger
// in TIME_WAIT state before being marked closed.
DefaultTCPTimeWaitTimeout = 60 * time.Second
+
+ // DefaultSynRetries is the default value for the number of SYN retransmits
+ // before a connect is aborted.
+ DefaultSynRetries = 6
)
// SACKEnabled option can be used to enable SACK support in the TCP
@@ -163,7 +167,10 @@ type protocol struct {
tcpLingerTimeout time.Duration
tcpTimeWaitTimeout time.Duration
minRTO time.Duration
+ maxRTO time.Duration
+ maxRetries uint32
synRcvdCount synRcvdCounter
+ synRetries uint8
dispatcher *dispatcher
}
@@ -340,12 +347,36 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case tcpip.TCPMaxRTOOption:
+ if v < 0 {
+ v = tcpip.TCPMaxRTOOption(MaxRTO)
+ }
+ p.mu.Lock()
+ p.maxRTO = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPMaxRetriesOption:
+ p.mu.Lock()
+ p.maxRetries = uint32(v)
+ p.mu.Unlock()
+ return nil
+
case tcpip.TCPSynRcvdCountThresholdOption:
p.mu.Lock()
p.synRcvdCount.SetThreshold(uint64(v))
p.mu.Unlock()
return nil
+ case tcpip.TCPSynRetriesOption:
+ if v < 1 || v > 255 {
+ return tcpip.ErrInvalidOptionValue
+ }
+ p.mu.Lock()
+ p.synRetries = uint8(v)
+ p.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
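Review bot: `TCPMaxRTOOption` and `TCPMaxRetriesOption` are stored without range validation, unlike `TCPSynRetriesOption` just below, which rejects values outside 1..255. A `TCPMaxRTOOption` of 0 passes the `v < 0` test, and the snd.go hunk then starts from `remaining := s.maxRTO` and returns false on `remaining <= 0`, so with no user timeout the first retransmit timer expiry would time the connection out. Likewise a `maxRetries` of 0 makes `s.unackZeroWindowProbes >= s.maxRetries` always true. Nothing here checks `maxRTO` against `p.minRTO` either. Consider treating zero like the negative case (`if v <= 0 {`) or returning `tcpip.ErrInvalidOptionValue`, and documenting the accepted range for both options; `uint32(v)` may also truncate if the option type is wider, which is not visible here. `SetOption` begins above this hunk.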
@@ -414,12 +445,30 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
p.mu.RUnlock()
return nil
+ case *tcpip.TCPMaxRTOOption:
+ p.mu.RLock()
+ *v = tcpip.TCPMaxRTOOption(p.maxRTO)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPMaxRetriesOption:
+ p.mu.RLock()
+ *v = tcpip.TCPMaxRetriesOption(p.maxRetries)
+ p.mu.RUnlock()
+ return nil
+
case *tcpip.TCPSynRcvdCountThresholdOption:
p.mu.RLock()
*v = tcpip.TCPSynRcvdCountThresholdOption(p.synRcvdCount.Threshold())
p.mu.RUnlock()
return nil
+ case *tcpip.TCPSynRetriesOption:
+ p.mu.RLock()
+ *v = tcpip.TCPSynRetriesOption(p.synRetries)
+ p.mu.RUnlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -452,6 +501,9 @@ func NewProtocol() stack.TransportProtocol {
tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold},
dispatcher: newDispatcher(runtime.GOMAXPROCS(0)),
+ synRetries: DefaultSynRetries,
minRTO: MinRTO,
+ maxRTO: MaxRTO,
+ maxRetries: MaxRetries,
}
}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 9e547a221..06dc9b7d7 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -43,7 +43,8 @@ const (
nDupAckThreshold = 3
// MaxRetries is the maximum number of probe retries sender does
- // before timing out the connection, Linux default TCP_RETR2.
+ // before timing out the connection.
+ // Linux default TCP_RETR2, net.ipv4.tcp_retries2.
MaxRetries = 15
)
@@ -165,6 +166,12 @@ type sender struct {
// minRTO is the minimum permitted value for sender.rto.
minRTO time.Duration
+ // maxRTO is the maximum permitted value for sender.rto.
+ maxRTO time.Duration
+
+ // maxRetries is the maximum permitted retransmissions.
+ maxRetries uint32
+
// maxPayloadSize is the maximum size of the payload of a given segment.
// It is initialized on demand.
maxPayloadSize int
@@ -276,12 +283,24 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
// etc.
s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss)
- // Get Stack wide minRTO.
- var v tcpip.TCPMinRTOOption
- if err := ep.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil {
+ // Get Stack wide config.
+ var minRTO tcpip.TCPMinRTOOption
+ if err := ep.stack.TransportProtocolOption(ProtocolNumber, &minRTO); err != nil {
panic(fmt.Sprintf("unable to get minRTO from stack: %s", err))
}
- s.minRTO = time.Duration(v)
+ s.minRTO = time.Duration(minRTO)
+
+ var maxRTO tcpip.TCPMaxRTOOption
+ if err := ep.stack.TransportProtocolOption(ProtocolNumber, &maxRTO); err != nil {
+ panic(fmt.Sprintf("unable to get maxRTO from stack: %s", err))
+ }
+ s.maxRTO = time.Duration(maxRTO)
+
+ var maxRetries tcpip.TCPMaxRetriesOption
+ if err := ep.stack.TransportProtocolOption(ProtocolNumber, &maxRetries); err != nil {
+ panic(fmt.Sprintf("unable to get maxRetries from stack: %s", err))
+ }
+ s.maxRetries = uint32(maxRetries)
return s
}
@@ -485,7 +504,7 @@ func (s *sender) retransmitTimerExpired() bool {
}
elapsed := time.Since(s.firstRetransmittedSegXmitTime)
- remaining := MaxRTO
+ remaining := s.maxRTO
if uto != 0 {
// Cap to the user specified timeout if one is specified.
remaining = uto - elapsed
@@ -494,24 +513,17 @@ func (s *sender) retransmitTimerExpired() bool {
// Always honor the user-timeout irrespective of whether the zero
// window probes were acknowledged.
// net/ipv4/tcp_timer.c::tcp_probe_timer()
- if remaining <= 0 || s.unackZeroWindowProbes >= MaxRetries {
+ if remaining <= 0 || s.unackZeroWindowProbes >= s.maxRetries {
return false
}
- if s.rto >= MaxRTO {
- // RFC 1122 section: 4.2.2.17
- // A TCP MAY keep its offered receive window closed
- // indefinitely. As long as the receiving TCP continues to
- // send acknowledgments in response to the probe segments, the
- // sending TCP MUST allow the connection to stay open.
- if !(s.zeroWindowProbing && s.unackZeroWindowProbes == 0) {
- return false
- }
- }
-
// Set new timeout. The timer will be restarted by the call to sendData
// below.
s.rto *= 2
+ // Cap the RTO as per RFC 1122 4.2.3.1, RFC 6298 5.5
+ if s.rto > s.maxRTO {
+ s.rto = s.maxRTO
+ }
// Cap RTO to remaining time.
if s.rto > remaining {
@@ -565,9 +577,20 @@ func (s *sender) retransmitTimerExpired() bool {
// send.
if s.zeroWindowProbing {
s.sendZeroWindowProbe()
+ // RFC 1122 4.2.2.17: A TCP MAY keep its offered receive window closed
+ // indefinitely. As long as the receiving TCP continues to send
+ // acknowledgments in response to the probe segments, the sending TCP
+ // MUST allow the connection to stay open.
return true
}
+ seg := s.writeNext
+ // RFC 1122 4.2.3.5: Close the connection when the number of
+ // retransmissions for this segment is beyond a limit.
+ if seg != nil && seg.xmitCount > s.maxRetries {
+ return false
+ }
+
s.sendData()
return true
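Review bot: The retransmission limit added before `s.sendData()` is written as `seg.xmitCount > s.maxRetries`, while the zero-window-probe limit earlier in this function uses `>=` on `s.unackZeroWindowProbes`, and nothing explains the difference. If `xmitCount` includes the original transmission (TestMaxRetransmitsTimeout expects one transmit plus `numRetries` retransmits, which suggests so), a comment such as `// xmitCount counts the first transmission too, so this permits maxRetries retransmits.` would prevent an off-by-one "fix" later. The check also reads `s.writeNext`, which is assigned outside this hunk, so please confirm it is the segment being retransmitted at this point. `retransmitTimerExpired` starts above the hunk.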
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index d2c90ebd5..6ef32a1b3 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -2994,6 +2994,101 @@ func TestSendOnResetConnection(t *testing.T) {
}
}
+// TestMaxRetransmitsTimeout tests if the connection is timed out after
+// a segment has been retransmitted MaxRetries times.
+func TestMaxRetransmitsTimeout(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ const numRetries = 2
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRetriesOption(numRetries)); err != nil {
+ t.Fatalf("could not set protocol option MaxRetries.\n")
+ }
+
+ c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
+ if err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ // Expect first transmit and MaxRetries retransmits.
+ for i := 0; i < numRetries+1; i++ {
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh),
+ ),
+ )
+ }
+ // Wait for the connection to timeout after MaxRetries retransmits.
+ initRTO := 1 * time.Second
+ select {
+ case <-notifyCh:
+ case <-time.After((2 << numRetries) * initRTO):
+ t.Fatalf("connection still alive after maximum retransmits.\n")
+ }
+
+ // Send an ACK and expect a RST as the connection would have been closed.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ })
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ ),
+ )
+
+ if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got)
+ }
+}
+
+// TestMaxRTO tests if the retransmit interval caps to MaxRTO.
+func TestMaxRTO(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ rto := 1 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRTOOption(rto)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPMaxRTO(%d) failed: %s", rto, err)
+ }
+
+ c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+
+ _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
+ if err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ const numRetransmits = 2
+ for i := 0; i < numRetransmits; i++ {
+ start := time.Now()
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() {
+ t.Errorf("Retransmit interval not capped to MaxRTO.\n")
+ }
+ }
+}
+
func TestFinImmediately(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
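Review bot: In TestMaxRetransmitsTimeout the setup failure is reported with `t.Fatalf("could not set protocol option MaxRetries.\n")`, which drops `err`; I would write `t.Fatalf("SetTransportProtocolOption(%d, TCPMaxRetriesOption(%d)) failed: %s", tcp.ProtocolNumber, numRetries, err)` to match the message style used in TestMaxRTO. The TestMaxRTO message itself has an unbalanced parenthesis (`tcpip.TCPMaxRTO(%d) failed`). TestMaxRTO also asserts on wall-clock time with `time.Since(start).Round(time.Second).Seconds() != rto.Seconds()`, so a delayed `GetPacket` on a loaded machine fails the test. A comparison against an upper bound with some slack would be less fragile, and the literal `1 * time.Second` used as the initial RTO in both tests deserves a comment saying where that value comes from.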
@@ -5774,7 +5869,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// Invoke the moderation API. This is required for auto-tuning
// to happen. This method is normally expected to be invoked
// from a higher layer than tcpip.Endpoint. So we simulate
- // copying to user-space by invoking it explicitly here.
+ // copying to userspace by invoking it explicitly here.
c.EP.ModerateRecvBuf(totalCopied)
// Now send a keep-alive packet to trigger an ACK so that we can
@@ -6605,9 +6700,16 @@ func TestTCPUserTimeout(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
+ defer c.WQ.EventUnregister(&waitEntry)
+
origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
- userTimeout := 50 * time.Millisecond
+ // Ensure that on the next retransmit timer fire, the user timeout has
+ // expired.
+ initRTO := 1 * time.Second
+ userTimeout := initRTO / 2
c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
// Send some data and wait before ACKing it.
@@ -6627,9 +6729,13 @@ func TestTCPUserTimeout(t *testing.T) {
),
)
- // Wait for a little over the minimum retransmit timeout of 200ms for
- // the retransmitTimer to fire and close the connection.
- time.Sleep(tcp.MinRTO + 10*time.Millisecond)
+ // Wait for the retransmit timer to be fired and the user timeout to cause
+ // close of the connection.
+ select {
+ case <-notifyCh:
+ case <-time.After(2 * initRTO):
+ t.Fatalf("connection still alive after %s, should have been closed after :%s", 2*initRTO, userTimeout)
+ }
// No packet should be received as the connection should be silently
// closed due to timeout.
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 756ab913a..647b2067a 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -106,6 +106,9 @@ type endpoint struct {
bindToDevice tcpip.NICID
broadcast bool
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError *tcpip.Error `state:".(string)"`
+
// Values used to reserve a port or register a transport endpoint.
// (which ever happens first).
boundBindToDevice tcpip.NICID
@@ -188,6 +191,15 @@ func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
}
+func (e *endpoint) takeLastError() *tcpip.Error {
+ e.lastErrorMu.Lock()
+ defer e.lastErrorMu.Unlock()
+
+ err := e.lastError
+ e.lastError = nil
+ return err
+}
+
// Abort implements stack.TransportEndpoint.Abort.
func (e *endpoint) Abort() {
e.Close()
@@ -243,6 +255,10 @@ func (e *endpoint) IPTables() (stack.IPTables, error) {
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ if err := e.takeLastError(); err != nil {
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
e.rcvMu.Lock()
if e.rcvList.Empty() {
@@ -382,6 +398,10 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ if err := e.takeLastError(); err != nil {
+ return 0, nil, err
+ }
+
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -853,6 +873,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
case tcpip.ErrorOption:
+ return e.takeLastError()
case *tcpip.MulticastInterfaceOption:
e.mu.Lock()
*o = tcpip.MulticastInterfaceOption{
@@ -1316,6 +1337,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+ if typ == stack.ControlPortUnreachable {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ if e.state == StateConnected {
+ e.lastErrorMu.Lock()
+ defer e.lastErrorMu.Unlock()
+
+ e.lastError = tcpip.ErrConnectionRefused
+ }
+ }
}
// State implements tcpip.Endpoint.State.
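Review bot: `HandleControlPacket` records `tcpip.ErrConnectionRefused` in `e.lastError` but does not signal the endpoint's waiters, so a caller already blocked or polling on the socket only sees the error on its next `Read`, `write` or `GetSockOpt` call. If the endpoint has a waiter queue available here, notifying it after setting the error would surface the failure promptly. The handler also decides purely on `typ` and `e.state`, leaving `id` and `pkt` unused. Because ICMP errors are unauthenticated, a comment stating that the stack has already matched the ICMP payload to this endpoint's connection (if it has; not visible here) would document what this trust rests on.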
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 466bd9381..851e6b635 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -37,6 +37,24 @@ func (u *udpPacket) loadData(data buffer.VectorisedView) {
u.data = data
}
+// saveLastError is invoked by stateify.
+func (e *endpoint) saveLastError() string {
+ if e.lastError == nil {
+ return ""
+ }
+
+ return e.lastError.String()
+}
+
+// loadLastError is invoked by stateify.
+func (e *endpoint) loadLastError(s string) {
+ if s == "" {
+ return
+ }
+
+ e.lastError = tcpip.StringToError(s)
+}
+
// beforeSave is invoked by stateify.
func (e *endpoint) beforeSave() {
// Stop incoming packets from being handled (and mutate endpoint state).
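Review bot: `saveLastError` and `loadLastError` read and write `e.lastError` without holding `e.lastErrorMu`, whereas `takeLastError` and `HandleControlPacket` in endpoint.go always take it. This is presumably safe if stateify only calls them once `beforeSave` has stopped packet handling, but that ordering is not visible here. Either take the lock (`e.lastErrorMu.Lock()` / `defer e.lastErrorMu.Unlock()`) or add a comment such as `// No lock needed: called by stateify while the endpoint is quiesced.`. It would also help to note what `tcpip.StringToError` does with a string it does not recognise, since the TCP `loadError` that this commit removes panicked in that case.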
diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go
index 5f2af9f3b..c45d2ecbc 100644
--- a/pkg/test/dockerutil/dockerutil.go
+++ b/pkg/test/dockerutil/dockerutil.go
@@ -148,6 +148,62 @@ func (m MountMode) String() string {
panic(fmt.Sprintf("invalid mode: %d", m))
}
+// DockerNetwork contains the name of a docker network.
+type DockerNetwork struct {
+ logger testutil.Logger
+ Name string
+ Subnet *net.IPNet
+ containers []*Docker
+}
+
+// NewDockerNetwork sets up the struct for a Docker network. Names of networks
+// will be unique.
+func NewDockerNetwork(logger testutil.Logger) *DockerNetwork {
+ return &DockerNetwork{
+ logger: logger,
+ Name: testutil.RandomID(logger.Name()),
+ }
+}
+
+// Create calls 'docker network create'.
+func (n *DockerNetwork) Create(args ...string) error {
+ a := []string{"docker", "network", "create"}
+ if n.Subnet != nil {
+ a = append(a, fmt.Sprintf("--subnet=%s", n.Subnet))
+ }
+ a = append(a, args...)
+ a = append(a, n.Name)
+ return testutil.Command(n.logger, a...).Run()
+}
+
+// Connect calls 'docker network connect' with the arguments provided.
+func (n *DockerNetwork) Connect(container *Docker, args ...string) error {
+ a := []string{"docker", "network", "connect"}
+ a = append(a, args...)
+ a = append(a, n.Name, container.Name)
+ if err := testutil.Command(n.logger, a...).Run(); err != nil {
+ return err
+ }
+ n.containers = append(n.containers, container)
+ return nil
+}
+
+// Cleanup cleans up the docker network and all the containers attached to it.
+func (n *DockerNetwork) Cleanup() error {
+ for _, c := range n.containers {
+ // Don't propagate the error, it might be that the container
+ // was already cleaned up.
+ if err := c.Kill(); err != nil {
+ n.logger.Logf("unable to kill container during cleanup: %s", err)
+ }
+ }
+
+ if err := testutil.Command(n.logger, "docker", "network", "rm", n.Name).Run(); err != nil {
+ return err
+ }
+ return nil
+}
+
// Docker contains the name and the runtime of a docker container.
type Docker struct {
logger testutil.Logger
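Review bot: `NewDockerNetwork` uses `testutil.RandomID(logger.Name())` as the network name unchanged, while the `MakeDocker` hunk in this same commit replaces `/` with `-` because slashes are not allowed in container names. If `logger.Name()` can contain a slash (the MakeDocker change implies it can), the network name should get the same treatment, e.g. `Name: strings.ReplaceAll(testutil.RandomID(logger.Name()), "/", "-")`, or both constructors should share a small helper. Separately, the tail of `Cleanup` can be written as `return testutil.Command(n.logger, "docker", "network", "rm", n.Name).Run()`.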
@@ -162,9 +218,13 @@ type Docker struct {
//
// Names of containers will be unique.
func MakeDocker(logger testutil.Logger) *Docker {
+ // Slashes are not allowed in container names.
+ name := testutil.RandomID(logger.Name())
+ name = strings.ReplaceAll(name, "/", "-")
+
return &Docker{
logger: logger,
- Name: testutil.RandomID(logger.Name()),
+ Name: name,
Runtime: *runtime,
}
}
@@ -309,7 +369,9 @@ func (d *Docker) argsFor(r *RunOpts, command string, p []string) (rv []string) {
rv = append(rv, d.Name)
} else {
rv = append(rv, d.mounts...)
- rv = append(rv, fmt.Sprintf("--runtime=%s", d.Runtime))
+ if len(d.Runtime) > 0 {
+ rv = append(rv, fmt.Sprintf("--runtime=%s", d.Runtime))
+ }
rv = append(rv, fmt.Sprintf("--name=%s", d.Name))
rv = append(rv, testutil.ImageByName(r.Image))
}
@@ -477,6 +539,56 @@ func (d *Docker) FindIP() (net.IP, error) {
return ip, nil
}
+// A NetworkInterface is container's network interface information.
+type NetworkInterface struct {
+ IPv4 net.IP
+ MAC net.HardwareAddr
+}
+
+// ListNetworks returns the network interfaces of the container, keyed by
+// Docker network name.
+func (d *Docker) ListNetworks() (map[string]NetworkInterface, error) {
+ const format = `{{json .NetworkSettings.Networks}}`
+ out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput()
+ if err != nil {
+ return nil, fmt.Errorf("error network interfaces: %q: %w", string(out), err)
+ }
+
+ networks := map[string]map[string]string{}
+ if err := json.Unmarshal(out, &networks); err != nil {
+ return nil, fmt.Errorf("error decoding network interfaces: %w", err)
+ }
+
+ interfaces := map[string]NetworkInterface{}
+ for name, iface := range networks {
+ var netface NetworkInterface
+
+ rawIP := strings.TrimSpace(iface["IPAddress"])
+ if rawIP != "" {
+ ip := net.ParseIP(rawIP)
+ if ip == nil {
+ return nil, fmt.Errorf("invalid IP: %q", rawIP)
+ }
+ // Docker's IPAddress field is IPv4. The IPv6 address
+ // is stored in the GlobalIPv6Address field.
+ netface.IPv4 = ip
+ }
+
+ rawMAC := strings.TrimSpace(iface["MacAddress"])
+ if rawMAC != "" {
+ mac, err := net.ParseMAC(rawMAC)
+ if err != nil {
+ return nil, fmt.Errorf("invalid MAC: %q: %w", rawMAC, err)
+ }
+ netface.MAC = mac
+ }
+
+ interfaces[name] = netface
+ }
+
+ return interfaces, nil
+}
+
// SandboxPid returns the PID to the sandbox process.
func (d *Docker) SandboxPid() (int, error) {
out, err := testutil.Command(d.logger, "docker", "inspect", "-f={{.State.Pid}}", d.Name).CombinedOutput()
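Review bot: `ListNetworks` decodes the inspect output into `map[string]map[string]string`, which makes `json.Unmarshal` return an error as soon as any field of a network entry is not a JSON string. As far as I know Docker's endpoint settings include numeric and nested fields (for example a prefix length), so this is likely to fail on real output. Decoding into only the fields that are used avoids this: `var networks map[string]struct{ IPAddress, MacAddress string }`, then read `iface.IPAddress` and `iface.MacAddress`. Using `CombinedOutput()` also feeds anything docker prints on stderr into the JSON parser, so capturing stdout only would be safer if the command wrapper offers it. The first error string (`"error network interfaces"`) is missing a verb such as "listing".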
diff --git a/pkg/usermem/addr.go b/pkg/usermem/addr.go
index e79210804..c4100481e 100644
--- a/pkg/usermem/addr.go
+++ b/pkg/usermem/addr.go
@@ -106,3 +106,20 @@ func (ar AddrRange) IsPageAligned() bool {
func (ar AddrRange) String() string {
return fmt.Sprintf("[%#x, %#x)", ar.Start, ar.End)
}
+
+// PageRoundDown/Up are equivalent to Addr.RoundDown/Up, but without the
+// potentially truncating conversion from uint64 to Addr. This is necessary
+// because there is no way to define generic "PageRoundDown/Up" functions in Go.
+
+// PageRoundDown returns x rounded down to the nearest page boundary.
+func PageRoundDown(x uint64) uint64 {
+ return x &^ (PageSize - 1)
+}
+
+// PageRoundUp returns x rounded up to the nearest page boundary.
+// ok is true iff rounding up did not wrap around.
+func PageRoundUp(x uint64) (addr uint64, ok bool) {
+ addr = PageRoundDown(x + PageSize - 1)
+ ok = addr >= x
+ return
+}
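Review bot: `PageRoundUp`'s doc comment says `ok` reports whether rounding wrapped, but not that `addr` is meaningless when `ok` is false (the wrapped sum rounds down to a small value). Since these helpers exist to avoid truncation on caller-supplied offsets, I would say so explicitly, e.g. `// If ok is false, addr is not a valid result and must not be used.`. The free-standing paragraph above `PageRoundDown` is separated from both functions by a blank line, so godoc will not attach it to either; folding it into the `PageRoundDown` comment would keep the rationale with the code.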