diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/buffer/safemem.go | 82 | ||||
-rw-r--r-- | pkg/segment/BUILD | 2 | ||||
-rw-r--r-- | pkg/segment/set.go | 400 | ||||
-rw-r--r-- | pkg/segment/test/BUILD | 18 | ||||
-rw-r--r-- | pkg/segment/test/segment_test.go | 397 | ||||
-rw-r--r-- | pkg/segment/test/set_functions.go | 32 | ||||
-rw-r--r-- | pkg/sentry/fs/g3doc/.gitignore | 1 | ||||
-rw-r--r-- | pkg/sentry/fs/g3doc/fuse.md | 260 | ||||
-rw-r--r-- | pkg/sentry/fsimpl/gofer/gofer.go | 48 | ||||
-rw-r--r-- | pkg/sentry/fsimpl/tmpfs/regular_file.go | 2 | ||||
-rw-r--r-- | pkg/sentry/kernel/pipe/BUILD | 2 | ||||
-rw-r--r-- | pkg/sentry/kernel/pipe/pipe.go | 6 | ||||
-rw-r--r-- | pkg/sentry/kernel/pipe/pipe_unsafe.go | 35 | ||||
-rw-r--r-- | pkg/sentry/kernel/pipe/vfs.go | 219 | ||||
-rw-r--r-- | pkg/sentry/mm/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/mm/vma.go | 4 | ||||
-rw-r--r-- | pkg/sentry/syscalls/linux/vfs2/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/syscalls/linux/vfs2/splice.go | 286 | ||||
-rw-r--r-- | pkg/sentry/syscalls/linux/vfs2/vfs2.go | 4 | ||||
-rw-r--r-- | pkg/sentry/vfs/file_description.go | 5 | ||||
-rw-r--r-- | pkg/test/dockerutil/dockerutil.go | 116 |
21 files changed, 1801 insertions, 120 deletions
diff --git a/pkg/buffer/safemem.go b/pkg/buffer/safemem.go index 0e5b86344..b789e56e9 100644 --- a/pkg/buffer/safemem.go +++ b/pkg/buffer/safemem.go @@ -28,12 +28,11 @@ func (b *buffer) ReadBlock() safemem.Block { return safemem.BlockFromSafeSlice(b.ReadSlice()) } -// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. -// -// This will advance the write index. -func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { - need := int(srcs.NumBytes()) - if need == 0 { +// WriteFromSafememReader writes up to count bytes from r to v and advances the +// write index by the number of bytes written. It calls r.ReadToBlocks() at +// most once. +func (v *View) WriteFromSafememReader(r safemem.Reader, count uint64) (uint64, error) { + if count == 0 { return 0, nil } @@ -50,32 +49,33 @@ func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { } // Does the last block have sufficient capacity alone? - if l := firstBuf.WriteSize(); l >= need { - dst = safemem.BlockSeqOf(firstBuf.WriteBlock()) + if l := uint64(firstBuf.WriteSize()); l >= count { + dst = safemem.BlockSeqOf(firstBuf.WriteBlock().TakeFirst64(count)) } else { // Append blocks until sufficient. - need -= l + count -= l blocks = append(blocks, firstBuf.WriteBlock()) - for need > 0 { + for count > 0 { emptyBuf := bufferPool.Get().(*buffer) v.data.PushBack(emptyBuf) - need -= emptyBuf.WriteSize() - blocks = append(blocks, emptyBuf.WriteBlock()) + block := emptyBuf.WriteBlock().TakeFirst64(count) + count -= uint64(block.Len()) + blocks = append(blocks, block) } dst = safemem.BlockSeqFromSlice(blocks) } - // Perform the copy. - n, err := safemem.CopySeq(dst, srcs) + // Perform I/O. + n, err := r.ReadToBlocks(dst) v.size += int64(n) // Update all indices. - for left := int(n); left > 0; firstBuf = firstBuf.Next() { - if l := firstBuf.WriteSize(); left >= l { + for left := n; left > 0; firstBuf = firstBuf.Next() { + if l := firstBuf.WriteSize(); left >= uint64(l) { firstBuf.WriteMove(l) // Whole block. - left -= l + left -= uint64(l) } else { - firstBuf.WriteMove(left) // Partial block. + firstBuf.WriteMove(int(left)) // Partial block. left = 0 } } @@ -83,14 +83,16 @@ func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { return n, err } -// ReadToBlocks implements safemem.Reader.ReadToBlocks. -// -// This will not advance the read index; the caller should follow -// this call with a call to TrimFront in order to remove the read -// data from the buffer. This is done to support pipe sematics. -func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { - need := int(dsts.NumBytes()) - if need == 0 { +// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. It advances the +// write index by the number of bytes written. +func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { + return v.WriteFromSafememReader(&safemem.BlockSeqReader{srcs}, srcs.NumBytes()) +} + +// ReadToSafememWriter reads up to count bytes from v to w. It does not advance +// the read index. It calls w.WriteFromBlocks() at most once. +func (v *View) ReadToSafememWriter(w safemem.Writer, count uint64) (uint64, error) { + if count == 0 { return 0, nil } @@ -105,25 +107,27 @@ func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { } // Is all the data in a single block? - if l := firstBuf.ReadSize(); l >= need { - src = safemem.BlockSeqOf(firstBuf.ReadBlock()) + if l := uint64(firstBuf.ReadSize()); l >= count { + src = safemem.BlockSeqOf(firstBuf.ReadBlock().TakeFirst64(count)) } else { // Build a list of all the buffers. - need -= l + count -= l blocks = append(blocks, firstBuf.ReadBlock()) - for buf := firstBuf.Next(); buf != nil && need > 0; buf = buf.Next() { - need -= buf.ReadSize() - blocks = append(blocks, buf.ReadBlock()) + for buf := firstBuf.Next(); buf != nil && count > 0; buf = buf.Next() { + block := buf.ReadBlock().TakeFirst64(count) + count -= uint64(block.Len()) + blocks = append(blocks, block) } src = safemem.BlockSeqFromSlice(blocks) } - // Perform the copy. - n, err := safemem.CopySeq(dsts, src) - - // See above: we would normally advance the read index here, but we - // don't do that in order to support pipe semantics. We rely on a - // separate call to TrimFront() in this case. + // Perform I/O. As documented, we don't advance the read index. + return w.WriteFromBlocks(src) +} - return n, err +// ReadToBlocks implements safemem.Reader.ReadToBlocks. It does not advance the +// read index by the number of bytes read, such that it's only safe to call if +// the caller guarantees that ReadToBlocks will only be called once. +func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { + return v.ReadToSafememWriter(&safemem.BlockSeqWriter{dsts}, dsts.NumBytes()) } diff --git a/pkg/segment/BUILD b/pkg/segment/BUILD index 1b487b887..f57ccc170 100644 --- a/pkg/segment/BUILD +++ b/pkg/segment/BUILD @@ -21,6 +21,8 @@ go_template( ], opt_consts = [ "minDegree", + # trackGaps must either be 0 or 1. + "trackGaps", ], types = [ "Key", diff --git a/pkg/segment/set.go b/pkg/segment/set.go index 03e4f258f..1a17ad9cb 100644 --- a/pkg/segment/set.go +++ b/pkg/segment/set.go @@ -36,6 +36,34 @@ type Range interface{} // Value is a required type parameter. type Value interface{} +// trackGaps is an optional parameter. +// +// If trackGaps is 1, the Set will track maximum gap size recursively, +// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this +// case, Key must be an unsigned integer. +// +// trackGaps must be 0 or 1. +const trackGaps = 0 + +var _ = uint8(trackGaps << 7) // Will fail if not zero or one. + +// dynamicGap is a type that disappears if trackGaps is 0. +type dynamicGap [trackGaps]Key + +// Get returns the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *dynamicGap) Get() Key { + return d[:][0] +} + +// Set sets the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *dynamicGap) Set(v Key) { + d[:][0] = v +} + // Functions is a required type parameter that must be a struct implementing // the methods defined by Functions. type Functions interface { @@ -327,8 +355,12 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator { } if prev.Ok() && prev.End() == r.Start { if mval, ok := (Functions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + shrinkMaxGap := trackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() prev.SetEndUnchecked(r.End) prev.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } if next.Ok() && next.Start() == r.End { val = mval if mval, ok := (Functions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { @@ -342,11 +374,16 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator { } if next.Ok() && next.Start() == r.End { if mval, ok := (Functions{}).Merge(r, val, next.Range(), next.Value()); ok { + shrinkMaxGap := trackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() next.SetStartUnchecked(r.Start) next.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } return next } } + // InsertWithoutMergingUnchecked will maintain maxGap if necessary. return s.InsertWithoutMergingUnchecked(gap, r, val) } @@ -373,11 +410,15 @@ func (s *Set) InsertWithoutMerging(gap GapIterator, r Range, val Value) Iterator // Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). func (s *Set) InsertWithoutMergingUnchecked(gap GapIterator, r Range, val Value) Iterator { gap = gap.node.rebalanceBeforeInsert(gap) + splitMaxGap := trackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get()) copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) gap.node.keys[gap.index] = r gap.node.values[gap.index] = val gap.node.nrSegments++ + if splitMaxGap { + gap.node.updateMaxGapLeaf() + } return Iterator{gap.node, gap.index} } @@ -399,12 +440,23 @@ func (s *Set) Remove(seg Iterator) GapIterator { // overlap. seg.SetRangeUnchecked(victim.Range()) seg.SetValue(victim.Value()) + // Need to update the nextAdjacentNode's maxGap because the gap in between + // must have been modified by updating seg.Range() to victim.Range(). + // seg.NextSegment() must exist since the last segment can't be in a + // non-leaf node. + nextAdjacentNode := seg.NextSegment().node + if trackGaps != 0 { + nextAdjacentNode.updateMaxGapLeaf() + } return s.Remove(victim).NextGap() } copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) Functions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) seg.node.nrSegments-- + if trackGaps != 0 { + seg.node.updateMaxGapLeaf() + } return seg.node.rebalanceAfterRemove(GapIterator{seg.node, seg.index}) } @@ -455,6 +507,7 @@ func (s *Set) MergeUnchecked(first, second Iterator) Iterator { // overlaps second. first.SetEndUnchecked(second.End()) first.SetValue(mval) + // Remove will handle the maxGap update if necessary. return s.Remove(second).PrevSegment() } } @@ -631,6 +684,12 @@ type node struct { // than "isLeaf" because false must be the correct value for an empty root. hasChildren bool + // The longest gap within this node. If the node is a leaf, it's simply the + // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys + // including the 0th and nrSegments-th gap possibly shared with its upper-level + // nodes; if it's a non-leaf node, it's the max of all children's maxGap. + maxGap dynamicGap + // Nodes store keys and values in separate arrays to maximize locality in // the common case (scanning keys for lookup). keys [maxDegree - 1]Range @@ -676,12 +735,12 @@ func (n *node) nextSibling() *node { // required for insertion, and returns an updated iterator to the position // represented by gap. func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator { - if n.parent != nil { - gap = n.parent.rebalanceBeforeInsert(gap) - } if n.nrSegments < maxDegree-1 { return gap } + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } if n.parent == nil { // n is root. Move all segments before and after n's median segment // into new child nodes adjacent to the median segment, which is now @@ -719,6 +778,13 @@ func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator { n.hasChildren = true n.children[0] = left n.children[1] = right + // In this case, n's maxGap won't violated as it's still the root, + // but the left and right children should be updated locally as they + // are newly split from n. + if trackGaps != 0 { + left.updateMaxGapLocal() + right.updateMaxGapLocal() + } if gap.node != n { return gap } @@ -758,6 +824,12 @@ func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator { } } n.nrSegments = minDegree - 1 + // MaxGap of n's parent is not violated because the segments within is not changed. + // n and its sibling's maxGap need to be updated locally as they are two new nodes split from old n. + if trackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } // gap.node can't be n.parent because gaps are always in leaf nodes. if gap.node != n { return gap @@ -821,6 +893,12 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { } n.nrSegments++ sibling.nrSegments-- + // n's parent's maxGap does not need to be updated as its content is unmodified. + // n and its sibling must be updated with (new) maxGap because of the shift of keys. + if trackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } if gap.node == sibling && gap.index == sibling.nrSegments { return GapIterator{n, 0} } @@ -849,6 +927,12 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { } n.nrSegments++ sibling.nrSegments-- + // n's parent's maxGap does not need to be updated as its content is unmodified. + // n and its sibling must be updated with (new) maxGap because of the shift of keys. + if trackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } if gap.node == sibling { if gap.index == 0 { return GapIterator{n, n.nrSegments} @@ -886,6 +970,7 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { p.children[0] = nil p.children[1] = nil } + // No need to update maxGap of p as its content is not changed. if gap.node == left { return GapIterator{p, gap.index} } @@ -932,11 +1017,152 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { } p.children[p.nrSegments] = nil p.nrSegments-- + // Update maxGap of left locally, no need to change p and right because + // p's contents is not changed and right is already invalid. + if trackGaps != 0 { + left.updateMaxGapLocal() + } // This process robs p of one segment, so recurse into rebalancing p. n = p } } +// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no +// necessary update. +// +// Preconditions: n must be a leaf node, trackGaps must be 1. +func (n *node) updateMaxGapLeaf() { + if n.hasChildren { + panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n)) + } + max := n.calculateMaxGapLeaf() + if max == n.maxGap.Get() { + // If new max equals the old maxGap, no update is needed. + return + } + oldMax := n.maxGap.Get() + n.maxGap.Set(max) + if max > oldMax { + // Grow ancestor maxGaps. + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() >= max { + // p and its ancestors already contain an equal or larger gap. + break + } + // Only if new maxGap is larger than parent's + // old maxGap, propagate this update to parent. + p.maxGap.Set(max) + } + return + } + // Shrink ancestor maxGaps. + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() > oldMax { + // p and its ancestors still contain a larger gap. + break + } + // If new max is smaller than the old maxGap, and this gap used + // to be the maxGap of its parent, iterate parent's children + // and calculate parent's new maxGap.(It's probable that parent + // has two children with the old maxGap, but we need to check it anyway.) + parentNewMax := p.calculateMaxGapInternal() + if p.maxGap.Get() == parentNewMax { + // p and its ancestors still contain a gap of at least equal size. + break + } + // If p's new maxGap differs from the old one, propagate this update. + p.maxGap.Set(parentNewMax) + } +} + +// updateMaxGapLocal updates maxGap of the calling node solely with no +// propagation to ancestor nodes. +// +// Precondition: trackGaps must be 1. +func (n *node) updateMaxGapLocal() { + if !n.hasChildren { + // Leaf node iterates its gaps. + n.maxGap.Set(n.calculateMaxGapLeaf()) + } else { + // Non-leaf node iterates its children. + n.maxGap.Set(n.calculateMaxGapInternal()) + } +} + +// calculateMaxGapLeaf iterates the gaps within a leaf node and calculate the +// max. +// +// Preconditions: n must be a leaf node. +func (n *node) calculateMaxGapLeaf() Key { + max := GapIterator{n, 0}.Range().Length() + for i := 1; i <= n.nrSegments; i++ { + if current := (GapIterator{n, i}).Range().Length(); current > max { + max = current + } + } + return max +} + +// calculateMaxGapInternal iterates children's maxGap within an internal node n +// and calculate the max. +// +// Preconditions: n must be a non-leaf node. +func (n *node) calculateMaxGapInternal() Key { + max := n.children[0].maxGap.Get() + for i := 1; i <= n.nrSegments; i++ { + if current := n.children[i].maxGap.Get(); current > max { + max = current + } + } + return max +} + +// searchFirstLargeEnoughGap returns the first gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *node) searchFirstLargeEnoughGap(minSize Key) GapIterator { + if n.maxGap.Get() < minSize { + return GapIterator{} + } + if n.hasChildren { + for i := 0; i <= n.nrSegments; i++ { + if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := 0; i <= n.nrSegments; i++ { + currentGap := GapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// searchLastLargeEnoughGap returns the last gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *node) searchLastLargeEnoughGap(minSize Key) GapIterator { + if n.maxGap.Get() < minSize { + return GapIterator{} + } + if n.hasChildren { + for i := n.nrSegments; i >= 0; i-- { + if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := n.nrSegments; i >= 0; i-- { + currentGap := GapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + // A Iterator is conceptually one of: // // - A pointer to a segment in a set; or @@ -1243,6 +1469,122 @@ func (gap GapIterator) NextGap() GapIterator { return seg.NextGap() } +// NextLargeEnoughGap returns the iterated gap's first next gap with larger +// length than minSize. If not found, return a terminal gap iterator (does NOT +// include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap GapIterator) NextLargeEnoughGap(minSize Key) GapIterator { + if trackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments { + // If gap is the trailing gap of an non-leaf node, + // translate it to the equivalent gap on leaf level. + gap.node = gap.NextSegment().node + gap.index = 0 + return gap.nextLargeEnoughGapHelper(minSize) + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the trailing gap of a non-leaf node. +func (gap GapIterator) nextLargeEnoughGapHelper(minSize Key) GapIterator { + // Crawl up the tree if no large enough gap in current node or the + // current gap is the trailing one on leaf level. + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + // If no large enough gap throughout the whole set, return a terminal + // gap iterator. + if gap.node == nil { + return GapIterator{} + } + // Iterate subsequent gaps. + gap.index++ + for gap.index <= gap.node.nrSegments { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index++ + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == gap.node.nrSegments { + // If gap is the trailing gap of a non-leaf node, crawl up to + // parent again and do recursion. + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// PrevLargeEnoughGap returns the iterated gap's first prev gap with larger or +// equal length than minSize. If not found, return a terminal gap iterator +// (does NOT include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap GapIterator) PrevLargeEnoughGap(minSize Key) GapIterator { + if trackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == 0 { + // If gap is the first gap of an non-leaf node, + // translate it to the equivalent gap on leaf level. + gap.node = gap.PrevSegment().node + gap.index = gap.node.nrSegments + return gap.prevLargeEnoughGapHelper(minSize) + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the first gap of a non-leaf node. +func (gap GapIterator) prevLargeEnoughGapHelper(minSize Key) GapIterator { + // Crawl up the tree if no large enough gap in current node or the + // current gap is the first one on leaf level. + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + // If no large enough gap throughout the whole set, return a terminal + // gap iterator. + if gap.node == nil { + return GapIterator{} + } + // Iterate previous gaps. + gap.index-- + for gap.index >= 0 { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index-- + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == 0 { + // If gap is the first gap of a non-leaf node, crawl up to + // parent again and do recursion. + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.prevLargeEnoughGapHelper(minSize) +} + // segmentBeforePosition returns the predecessor segment of the position given // by n.children[i], which may or may not contain a child. If no such segment // exists, segmentBeforePosition returns a terminal iterator. @@ -1271,7 +1613,7 @@ func segmentAfterPosition(n *node, i int) Iterator { func zeroValueSlice(slice []Value) { // TODO(jamieliu): check if Go is actually smart enough to optimize a - // ClearValue that assigns nil to a memset here + // ClearValue that assigns nil to a memset here. for i := range slice { Functions{}.ClearValue(&slice[i]) } @@ -1310,7 +1652,15 @@ func (n *node) writeDebugString(buf *bytes.Buffer, prefix string) { child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) } buf.WriteString(prefix) - buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + if n.hasChildren { + if trackGaps != 0 { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get())) + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } } if child := n.children[n.nrSegments]; child != nil { child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) @@ -1362,3 +1712,43 @@ func (s *Set) ImportSortedSlices(sds *SegmentDataSlices) error { } return nil } + +// segmentTestCheck returns an error if s is incorrectly sorted, does not +// contain exactly expectedSegments segments, or contains a segment which +// fails the passed check. +// +// This should be used only for testing, and has been added to this package for +// templating convenience. +func (s *Set) segmentTestCheck(expectedSegments int, segFunc func(int, Range, Value) error) error { + havePrev := false + prev := Key(0) + nrSegments := 0 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + next := seg.Start() + if havePrev && prev >= next { + return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) + } + if segFunc != nil { + if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil { + return err + } + } + prev = next + havePrev = true + nrSegments++ + } + if nrSegments != expectedSegments { + return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments) + } + return nil +} + +// countSegments counts the number of segments in the set. +// +// Similar to Check, this should only be used for testing. +func (s *Set) countSegments() (segments int) { + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + segments++ + } + return segments +} diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD index f2d8462d8..131bf09b9 100644 --- a/pkg/segment/test/BUILD +++ b/pkg/segment/test/BUILD @@ -29,10 +29,28 @@ go_template_instance( }, ) +go_template_instance( + name = "gap_set", + out = "gap_set.go", + consts = { + "trackGaps": "1", + }, + package = "segment", + prefix = "gap", + template = "//pkg/segment:generic_set", + types = { + "Key": "int", + "Range": "Range", + "Value": "int", + "Functions": "gapSetFunctions", + }, +) + go_library( name = "segment", testonly = 1, srcs = [ + "gap_set.go", "int_range.go", "int_set.go", "set_functions.go", diff --git a/pkg/segment/test/segment_test.go b/pkg/segment/test/segment_test.go index 97b16c158..85fa19096 100644 --- a/pkg/segment/test/segment_test.go +++ b/pkg/segment/test/segment_test.go @@ -17,6 +17,7 @@ package segment import ( "fmt" "math/rand" + "reflect" "testing" ) @@ -32,61 +33,65 @@ const ( // valueOffset is the difference between the value and start of test // segments. valueOffset = 100000 + + // intervalLength is the interval used by random gap tests. + intervalLength = 10 ) func shuffle(xs []int) { - for i := range xs { - j := rand.Intn(i + 1) - xs[i], xs[j] = xs[j], xs[i] - } + rand.Shuffle(len(xs), func(i, j int) { xs[i], xs[j] = xs[j], xs[i] }) } -func randPermutation(size int) []int { +func randIntervalPermutation(size int) []int { p := make([]int, size) for i := range p { - p[i] = i + p[i] = intervalLength * i } shuffle(p) return p } -// checkSet returns an error if s is incorrectly sorted, does not contain -// exactly expectedSegments segments, or contains a segment for which val != -// key + valueOffset. -func checkSet(s *Set, expectedSegments int) error { - havePrev := false - prev := 0 - nrSegments := 0 - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - next := seg.Start() - if havePrev && prev >= next { - return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) - } - if got, want := seg.Value(), seg.Start()+valueOffset; got != want { - return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nrSegments, seg.Start(), got, want) - } - prev = next - havePrev = true - nrSegments++ - } - if nrSegments != expectedSegments { - return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments) +// validate can be passed to Check. +func validate(nr int, r Range, v int) error { + if got, want := v, r.Start+valueOffset; got != want { + return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nr, r.Start, got, want) } return nil } -// countSegmentsIn returns the number of segments in s. -func countSegmentsIn(s *Set) int { - var count int - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - count++ +// checkSetMaxGap returns an error if maxGap inside all nodes of s is not well +// maintained. +func checkSetMaxGap(s *gapSet) error { + n := s.root + return checkNodeMaxGap(&n) +} + +// checkNodeMaxGap returns an error if maxGap inside the subtree rooted by n is +// not well maintained. +func checkNodeMaxGap(n *gapnode) error { + var max int + if !n.hasChildren { + max = n.calculateMaxGapLeaf() + } else { + for i := 0; i <= n.nrSegments; i++ { + child := n.children[i] + if err := checkNodeMaxGap(child); err != nil { + return err + } + if temp := child.maxGap.Get(); i == 0 || temp > max { + max = temp + } + } + } + if max != n.maxGap.Get() { + return fmt.Errorf("maxGap wrong in node\n%vexpected: %d got: %d", n, max, n.maxGap) } - return count + return nil } func TestAddRandom(t *testing.T) { var s Set - order := randPermutation(testSize) + order := rand.Perm(testSize) var nrInsertions int for i, j := range order { if !s.AddWithoutMerging(Range{j, j + 1}, j+valueOffset) { @@ -94,12 +99,12 @@ func TestAddRandom(t *testing.T) { break } nrInsertions++ - if err := checkSet(&s, nrInsertions); err != nil { + if err := s.segmentTestCheck(nrInsertions, validate); err != nil { t.Errorf("Iteration %d: %v", i, err) break } } - if got, want := countSegmentsIn(&s), nrInsertions; got != want { + if got, want := s.countSegments(), nrInsertions; got != want { t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) } if t.Failed() { @@ -115,7 +120,156 @@ func TestRemoveRandom(t *testing.T) { t.Fatalf("Failed to insert segment %d", i) } } - order := randPermutation(testSize) + order := rand.Perm(testSize) + var nrRemovals int + for i, j := range order { + seg := s.FindSegment(j) + if !seg.Ok() { + t.Errorf("Iteration %d: failed to find segment with key %d", i, j) + break + } + s.Remove(seg) + nrRemovals++ + if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil { + t.Errorf("Iteration %d: %v", i, err) + break + } + } + if got, want := s.countSegments(), testSize-nrRemovals; got != want { + t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) + } + if t.Failed() { + t.Logf("Removal order: %v", order[:nrRemovals]) + t.Logf("Set contents:\n%v", &s) + t.FailNow() + } +} + +func TestMaxGapAddRandom(t *testing.T) { + var s gapSet + order := rand.Perm(testSize) + var nrInsertions int + for i, j := range order { + if !s.AddWithoutMerging(Range{j, j + 1}, j+valueOffset) { + t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) + break + } + nrInsertions++ + if err := s.segmentTestCheck(nrInsertions, validate); err != nil { + t.Errorf("Iteration %d: %v", i, err) + break + } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When inserting %d: %v", j, err) + break + } + } + if got, want := s.countSegments(), nrInsertions; got != want { + t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) + } + if t.Failed() { + t.Logf("Insertion order: %v", order[:nrInsertions]) + t.Logf("Set contents:\n%v", &s) + } +} + +func TestMaxGapAddRandomWithRandomInterval(t *testing.T) { + var s gapSet + order := randIntervalPermutation(testSize) + var nrInsertions int + for i, j := range order { + if !s.AddWithoutMerging(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) { + t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) + break + } + nrInsertions++ + if err := s.segmentTestCheck(nrInsertions, validate); err != nil { + t.Errorf("Iteration %d: %v", i, err) + break + } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When inserting %d: %v", j, err) + break + } + } + if got, want := s.countSegments(), nrInsertions; got != want { + t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) + } + if t.Failed() { + t.Logf("Insertion order: %v", order[:nrInsertions]) + t.Logf("Set contents:\n%v", &s) + } +} + +func TestMaxGapAddRandomWithMerge(t *testing.T) { + var s gapSet + order := randIntervalPermutation(testSize) + nrInsertions := 1 + for i, j := range order { + if !s.Add(Range{j, j + intervalLength}, j+valueOffset) { + t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) + break + } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When inserting %d: %v", j, err) + break + } + } + if got, want := s.countSegments(), nrInsertions; got != want { + t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) + } + if t.Failed() { + t.Logf("Insertion order: %v", order) + t.Logf("Set contents:\n%v", &s) + } +} + +func TestMaxGapRemoveRandom(t *testing.T) { + var s gapSet + for i := 0; i < testSize; i++ { + if !s.AddWithoutMerging(Range{i, i + 1}, i+valueOffset) { + t.Fatalf("Failed to insert segment %d", i) + } + } + order := rand.Perm(testSize) + var nrRemovals int + for i, j := range order { + seg := s.FindSegment(j) + if !seg.Ok() { + t.Errorf("Iteration %d: failed to find segment with key %d", i, j) + break + } + temprange := seg.Range() + s.Remove(seg) + nrRemovals++ + if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil { + t.Errorf("Iteration %d: %v", i, err) + break + } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When removing %v: %v", temprange, err) + break + } + } + if got, want := s.countSegments(), testSize-nrRemovals; got != want { + t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) + } + if t.Failed() { + t.Logf("Removal order: %v", order[:nrRemovals]) + t.Logf("Set contents:\n%v", &s) + t.FailNow() + } +} + +func TestMaxGapRemoveHalfRandom(t *testing.T) { + var s gapSet + for i := 0; i < testSize; i++ { + if !s.AddWithoutMerging(Range{intervalLength * i, intervalLength*i + rand.Intn(intervalLength-1) + 1}, intervalLength*i+valueOffset) { + t.Fatalf("Failed to insert segment %d", i) + } + } + order := randIntervalPermutation(testSize) + order = order[:testSize/2] var nrRemovals int for i, j := range order { seg := s.FindSegment(j) @@ -123,14 +277,19 @@ func TestRemoveRandom(t *testing.T) { t.Errorf("Iteration %d: failed to find segment with key %d", i, j) break } + temprange := seg.Range() s.Remove(seg) nrRemovals++ - if err := checkSet(&s, testSize-nrRemovals); err != nil { + if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil { t.Errorf("Iteration %d: %v", i, err) break } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When removing %v: %v", temprange, err) + break + } } - if got, want := countSegmentsIn(&s), testSize-nrRemovals; got != want { + if got, want := s.countSegments(), testSize-nrRemovals; got != want { t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) } if t.Failed() { @@ -140,6 +299,148 @@ func TestRemoveRandom(t *testing.T) { } } +func TestMaxGapAddRandomRemoveRandomHalfWithMerge(t *testing.T) { + var s gapSet + order := randIntervalPermutation(testSize * 2) + order = order[:testSize] + for i, j := range order { + if !s.Add(Range{j, j + intervalLength}, j+valueOffset) { + t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) + break + } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When inserting %d: %v", j, err) + break + } + } + shuffle(order) + var nrRemovals int + for _, j := range order { + seg := s.FindSegment(j) + if !seg.Ok() { + continue + } + temprange := seg.Range() + s.Remove(seg) + nrRemovals++ + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When removing %v: %v", temprange, err) + break + } + } + if t.Failed() { + t.Logf("Removal order: %v", order[:nrRemovals]) + t.Logf("Set contents:\n%v", &s) + t.FailNow() + } +} + +func TestNextLargeEnoughGap(t *testing.T) { + var s gapSet + order := randIntervalPermutation(testSize * 2) + order = order[:testSize] + for i, j := range order { + if !s.Add(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) { + t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) + break + } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When inserting %d: %v", j, err) + break + } + } + shuffle(order) + order = order[:testSize/2] + for _, j := range order { + seg := s.FindSegment(j) + if !seg.Ok() { + continue + } + temprange := seg.Range() + s.Remove(seg) + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When removing %v: %v", temprange, err) + break + } + } + minSize := 7 + var gapArr1 []int + for gap := s.LowerBoundGap(0).NextLargeEnoughGap(minSize); gap.Ok(); gap = gap.NextLargeEnoughGap(minSize) { + if gap.Range().Length() < minSize { + t.Errorf("NextLargeEnoughGap wrong, gap %v has length %d, wanted %d", gap.Range(), gap.Range().Length(), minSize) + } else { + gapArr1 = append(gapArr1, gap.Range().Start) + } + } + var gapArr2 []int + for gap := s.LowerBoundGap(0).NextGap(); gap.Ok(); gap = gap.NextGap() { + if gap.Range().Length() >= minSize { + gapArr2 = append(gapArr2, gap.Range().Start) + } + } + + if !reflect.DeepEqual(gapArr2, gapArr1) { + t.Errorf("Search result not correct, got: %v, wanted: %v", gapArr1, gapArr2) + } + if t.Failed() { + t.Logf("Set contents:\n%v", &s) + t.FailNow() + } +} + +func TestPrevLargeEnoughGap(t *testing.T) { + var s gapSet + order := randIntervalPermutation(testSize * 2) + order = order[:testSize] + for i, j := range order { + if !s.Add(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) { + t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) + break + } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When inserting %d: %v", j, err) + break + } + } + end := s.LastSegment().End() + shuffle(order) + order = order[:testSize/2] + for _, j := range order { + seg := s.FindSegment(j) + if !seg.Ok() { + continue + } + temprange := seg.Range() + s.Remove(seg) + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When removing %v: %v", temprange, err) + break + } + } + minSize := 7 + var gapArr1 []int + for gap := s.UpperBoundGap(end + intervalLength).PrevLargeEnoughGap(minSize); gap.Ok(); gap = gap.PrevLargeEnoughGap(minSize) { + if gap.Range().Length() < minSize { + t.Errorf("PrevLargeEnoughGap wrong, gap length %d, wanted %d", gap.Range().Length(), minSize) + } else { + gapArr1 = append(gapArr1, gap.Range().Start) + } + } + var gapArr2 []int + for gap := s.UpperBoundGap(end + intervalLength).PrevGap(); gap.Ok(); gap = gap.PrevGap() { + if gap.Range().Length() >= minSize { + gapArr2 = append(gapArr2, gap.Range().Start) + } + } + if !reflect.DeepEqual(gapArr2, gapArr1) { + t.Errorf("Search result not correct, got: %v, wanted: %v", gapArr1, gapArr2) + } + if t.Failed() { + t.Logf("Set contents:\n%v", &s) + t.FailNow() + } +} + func TestAddSequentialAdjacent(t *testing.T) { var s Set var nrInsertions int @@ -148,12 +449,12 @@ func TestAddSequentialAdjacent(t *testing.T) { t.Fatalf("Failed to insert segment %d", i) } nrInsertions++ - if err := checkSet(&s, nrInsertions); err != nil { + if err := s.segmentTestCheck(nrInsertions, validate); err != nil { t.Errorf("Iteration %d: %v", i, err) break } } - if got, want := countSegmentsIn(&s), nrInsertions; got != want { + if got, want := s.countSegments(), nrInsertions; got != want { t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) } if t.Failed() { @@ -202,12 +503,12 @@ func TestAddSequentialNonAdjacent(t *testing.T) { t.Fatalf("Failed to insert segment %d", i) } nrInsertions++ - if err := checkSet(&s, nrInsertions); err != nil { + if err := s.segmentTestCheck(nrInsertions, validate); err != nil { t.Errorf("Iteration %d: %v", i, err) break } } - if got, want := countSegmentsIn(&s), nrInsertions; got != want { + if got, want := s.countSegments(), nrInsertions; got != want { t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) } if t.Failed() { @@ -293,7 +594,7 @@ Tests: var i int for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { if i > len(test.final) { - t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, countSegmentsIn(&s), len(test.final), &s) + t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, s.countSegments(), len(test.final), &s) continue Tests } if got, want := seg.Range(), test.final[i]; got != want { @@ -351,7 +652,7 @@ Tests: var i int for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { if i > len(test.final) { - t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, countSegmentsIn(&s), len(test.final), &s) + t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, s.countSegments(), len(test.final), &s) continue Tests } if got, want := seg.Range(), test.final[i]; got != want { @@ -378,7 +679,7 @@ func benchmarkAddSequential(b *testing.B, size int) { } func benchmarkAddRandom(b *testing.B, size int) { - order := randPermutation(size) + order := rand.Perm(size) b.ResetTimer() for n := 0; n < b.N; n++ { @@ -416,7 +717,7 @@ func benchmarkFindRandom(b *testing.B, size int) { b.Fatalf("Failed to insert segment %d", i) } } - order := randPermutation(size) + order := rand.Perm(size) b.ResetTimer() for n := 0; n < b.N; n++ { @@ -470,7 +771,7 @@ func benchmarkAddFindRemoveSequential(b *testing.B, size int) { } func benchmarkAddFindRemoveRandom(b *testing.B, size int) { - order := randPermutation(size) + order := rand.Perm(size) b.ResetTimer() for n := 0; n < b.N; n++ { diff --git a/pkg/segment/test/set_functions.go b/pkg/segment/test/set_functions.go index bcddb39bb..7cd895cc7 100644 --- a/pkg/segment/test/set_functions.go +++ b/pkg/segment/test/set_functions.go @@ -14,21 +14,16 @@ package segment -// Basic numeric constants that we define because the math package doesn't. -// TODO(nlacasse): These should be Math.MaxInt64/MinInt64? -const ( - maxInt = int(^uint(0) >> 1) - minInt = -maxInt - 1 -) - type setFunctions struct{} -func (setFunctions) MinKey() int { - return minInt +// MinKey returns the minimum key for the set. +func (s setFunctions) MinKey() int { + return -s.MaxKey() - 1 } +// MaxKey returns the maximum key for the set. func (setFunctions) MaxKey() int { - return maxInt + return int(^uint(0) >> 1) } func (setFunctions) ClearValue(*int) {} @@ -40,3 +35,20 @@ func (setFunctions) Merge(_ Range, val1 int, _ Range, _ int) (int, bool) { func (setFunctions) Split(_ Range, val int, _ int) (int, int) { return val, val } + +type gapSetFunctions struct { + setFunctions +} + +// MinKey is adjusted to make sure no add overflow would happen in test cases. +// e.g. A gap with range {MinInt32, 2} would cause overflow in Range().Length(). +// +// Normally Keys should be unsigned to avoid these issues. +func (s gapSetFunctions) MinKey() int { + return s.setFunctions.MinKey() / 2 +} + +// MaxKey returns the maximum key for the set. +func (s gapSetFunctions) MaxKey() int { + return s.setFunctions.MaxKey() / 2 +} diff --git a/pkg/sentry/fs/g3doc/.gitignore b/pkg/sentry/fs/g3doc/.gitignore new file mode 100644 index 000000000..2d19fc766 --- /dev/null +++ b/pkg/sentry/fs/g3doc/.gitignore @@ -0,0 +1 @@ +*.html diff --git a/pkg/sentry/fs/g3doc/fuse.md b/pkg/sentry/fs/g3doc/fuse.md new file mode 100644 index 000000000..635cc009b --- /dev/null +++ b/pkg/sentry/fs/g3doc/fuse.md @@ -0,0 +1,260 @@ +# Foreword + +This document describes an on-going project to support FUSE filesystems within +the sentry. This is intended to become the final documentation for this +subsystem, and is therefore written in the past tense. However FUSE support is +currently incomplete and the document will be updated as things progress. + +# FUSE: Filesystem in Userspace + +The sentry supports dispatching filesystem operations to a FUSE server, allowing +FUSE filesystem to be used with a sandbox. + +## Overview + +FUSE has two main components: + +1. A client kernel driver (canonically `fuse.ko` in Linux), which forwards + filesystem operations (usually initiated by syscalls) to the server. + +2. A server, which is a userspace daemon that implements the actual filesystem. + +The sentry implements the client component, which allows a server daemon running +within the sandbox to implement a filesystem within the sandbox. + +A FUSE filesystem is initialized with `mount(2)`, typically with the help of a +utility like `fusermount(1)`. Various mount options exist for establishing +ownership and access permissions on the filesystem, but the most important mount +option is a file descriptor used to establish communication between the client +and server. + +The FUSE device FD is obtained by opening `/dev/fuse`. During regular operation, +the client and server use the FUSE protocol described in `fuse(4)` to service +filesystem operations. See the "Protocol" section below for more information +about this protocol. The core of the sentry support for FUSE is the client-side +implementation of this protocol. + +## FUSE in the Sentry + +The sentry's FUSE client targets VFS2 and has the following components: + +- An implementation of `/dev/fuse`. + +- A VFS2 filesystem for mapping syscalls to FUSE ops. Since we're targeting + VFS2, one point of contention may be the lack of inodes in VFS2. We can + tentatively implement a kernfs-based filesystem to bridge the gap in APIs. + The kernfs base functionality can serve the role of the Linux inode cache + and, the filesystem can map VFS2 syscalls to kernfs inode operations; see + the `kernfs.Inode` interface. + +The FUSE protocol lends itself well to marshaling with `go_marshal`. The various +request and response packets can be defined in the ABI package and converted to +and from the wire format using `go_marshal`. + +### Design Goals + +- While filesystem performance is always important, the sentry's FUSE support + is primarily concerned with compatibility, with performance as a secondary + concern. + +- Avoiding deadlocks from a hung server daemon. + +- Consider the potential for denial of service from a malicious server daemon. + Protecting itself from userspace is already a design goal for the sentry, + but needs additional consideration for FUSE. Normally, an operating system + doesn't rely on userspace to make progress with filesystem operations. Since + this changes with FUSE, it opens up the possibility of creating a chain of + dependencies controlled by userspace, which could affect an entire sandbox. + For example: a FUSE op can block a syscall, which could be holding a + subsystem lock, which can then block another task goroutine. + +### Milestones + +Below are some broad goals to aim for while implementing FUSE in the sentry. +Many FUSE ops can be grouped into broad categories of functionality, and most +ops can be implemented in parallel. + +#### Minimal client that can mount a trivial FUSE filesystem. + +- Implement `/dev/fuse`. + +- Implement basic FUSE ops like `FUSE_INIT`, `FUSE_DESTROY`. + +#### Read-only mount with basic file operations + +- Implement the majority of file, directory and file descriptor FUSE ops. For + this milestone, we can skip uncommon or complex operations like mmap, mknod, + file locking, poll, and extended attributes. We can stub these out along + with any ops that modify the filesystem. The exact list of required ops are + to be determined, but the goal is to mount a real filesystem as read-only, + and be able to read contents from the filesystem in the sentry. + +#### Full read-write support + +- Implement the remaining FUSE ops and decide if we can omit rarely used + operations like ioctl. + +# Appendix + +## FUSE Protocol + +The FUSE protocol is a request-response protocol. All requests are initiated by +the client. The wire-format for the protocol is raw c structs serialized to +memory. + +All FUSE requests begin with the following request header: + +```c +struct fuse_in_header { + uint32_t len; // Length of the request, including this header. + uint32_t opcode; // Requested operation. + uint64_t unique; // A unique identifier for this request. + uint64_t nodeid; // ID of the filesystem object being operated on. + uint32_t uid; // UID of the requesting process. + uint32_t gid; // GID of the requesting process. + uint32_t pid; // PID of the requesting process. + uint32_t padding; +}; +``` + +The request is then followed by a payload specific to the `opcode`. + +All responses begin with this response header: + +```c +struct fuse_out_header { + uint32_t len; // Length of the response, including this header. + int32_t error; // Status of the request, 0 if success. + uint64_t unique; // The unique identifier from the corresponding request. +}; +``` + +The response payload also depends on the request `opcode`. If `error != 0`, the +response payload must be empty. + +### Operations + +The following is a list of all FUSE operations used in `fuse_in_header.opcode` +as of Linux v4.4, and a brief description of their purpose. These are defined in +`uapi/linux/fuse.h`. Many of these have a corresponding request and response +payload struct; `fuse(4)` has details for some of these. We also note how these +operations map to the sentry virtual filesystem. + +#### FUSE meta-operations + +These operations are specific to FUSE and don't have a corresponding action in a +generic filesystem. + +- `FUSE_INIT`: This operation initializes a new FUSE filesystem, and is the + first message sent by the client after mount. This is used for version and + feature negotiation. This is related to `mount(2)`. +- `FUSE_DESTROY`: Teardown a FUSE filesystem, related to `unmount(2)`. +- `FUSE_INTERRUPT`: Interrupts an in-flight operation, specified by the + `fuse_in_header.unique` value provided in the corresponding request header. + The client can send at most one of these per request, and will enter an + uninterruptible wait for a reply. The server is expected to reply promptly. +- `FUSE_FORGET`: A hint to the server that server should evict the indicate + node from any caches. This is wired up to `(struct + super_operations).evict_inode` in Linux, which is in turned hooked as the + inode cache shrinker which is typically triggered by system memory pressure. +- `FUSE_BATCH_FORGET`: Batch version of `FUSE_FORGET`. + +#### Filesystem Syscalls + +These FUSE ops map directly to an equivalent filesystem syscall, or family of +syscalls. The relevant syscalls have a similar name to the operation, unless +otherwise noted. + +Node creation: + +- `FUSE_MKNOD` +- `FUSE_MKDIR` +- `FUSE_CREATE`: This is equivalent to `open(2)` and `creat(2)`, which + atomically creates and opens a node. + +Node attributes and extended attributes: + +- `FUSE_GETATTR` +- `FUSE_SETATTR` +- `FUSE_SETXATTR` +- `FUSE_GETXATTR` +- `FUSE_LISTXATTR` +- `FUSE_REMOVEXATTR` + +Node link manipulation: + +- `FUSE_READLINK` +- `FUSE_LINK` +- `FUSE_SYMLINK` +- `FUSE_UNLINK` + +Directory operations: + +- `FUSE_RMDIR` +- `FUSE_RENAME` +- `FUSE_RENAME2` +- `FUSE_OPENDIR`: `open(2)` for directories. +- `FUSE_RELEASEDIR`: `close(2)` for directories. +- `FUSE_READDIR` +- `FUSE_READDIRPLUS` +- `FUSE_FSYNCDIR`: `fsync(2)` for directories. +- `FUSE_LOOKUP`: Establishes a unique identifier for a FS node. This is + reminiscent of `VirtualFilesystem.GetDentryAt` in that it resolves a path + component to a node. However the returned identifier is opaque to the + client. The server must remember this mapping, as this is how the client + will reference the node in the future. + +File operations: + +- `FUSE_OPEN`: `open(2)` for files. +- `FUSE_RELEASE`: `close(2)` for files. +- `FUSE_FSYNC` +- `FUSE_FALLOCATE` +- `FUSE_SETUPMAPPING`: Creates a memory map on a file for `mmap(2)`. +- `FUSE_REMOVEMAPPING`: Removes a memory map for `munmap(2)`. + +File locking: + +- `FUSE_GETLK` +- `FUSE_SETLK` +- `FUSE_SETLKW` +- `FUSE_COPY_FILE_RANGE` + +File descriptor operations: + +- `FUSE_IOCTL` +- `FUSE_POLL` +- `FUSE_LSEEK` + +Filesystem operations: + +- `FUSE_STATFS` + +#### Permissions + +- `FUSE_ACCESS` is used to check if a node is accessible, as part of many + syscall implementations. Maps to `vfs.FilesystemImpl.AccessAt` in the + sentry. + +#### I/O Operations + +These ops are used to read and write file pages. They're used to implement both +I/O syscalls like `read(2)`, `write(2)` and `mmap(2)`. + +- `FUSE_READ` +- `FUSE_WRITE` + +#### Miscellaneous + +- `FUSE_FLUSH`: Used by the client to indicate when a file descriptor is + closed. Distinct from `FUSE_FSYNC`, which corresponds to an `fsync(2)` + syscall from the user. Maps to `vfs.FileDescriptorImpl.Release` in the + sentry. +- `FUSE_BMAP`: Old address space API for block defrag. Probably not needed. +- `FUSE_NOTIFY_REPLY`: [TODO: what does this do?] + +# References + +- `fuse(4)` manpage. +- Linux kernel FUSE documentation: + https://www.kernel.org/doc/html/latest/filesystems/fuse.html diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 6295f6b54..131da332f 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -84,12 +84,6 @@ type filesystem struct { // devMinor is the filesystem's minor device number. devMinor is immutable. devMinor uint32 - // uid and gid are the effective KUID and KGID of the filesystem's creator, - // and are used as the owner and group for files that don't specify one. - // uid and gid are immutable. - uid auth.KUID - gid auth.KGID - // renameMu serves two purposes: // // - It synchronizes path resolution with renaming initiated by this @@ -122,6 +116,8 @@ type filesystemOptions struct { fd int aname string interop InteropMode // derived from the "cache" mount option + dfltuid auth.KUID + dfltgid auth.KGID msize uint32 version string @@ -230,6 +226,15 @@ type InternalFilesystemOptions struct { OpenSocketsByConnecting bool } +// _V9FS_DEFUID and _V9FS_DEFGID (from Linux's fs/9p/v9fs.h) are the default +// UIDs and GIDs used for files that do not provide a specific owner or group +// respectively. +const ( + // uint32(-2) doesn't work in Go. + _V9FS_DEFUID = auth.KUID(4294967294) + _V9FS_DEFGID = auth.KGID(4294967294) +) + // Name implements vfs.FilesystemType.Name. func (FilesystemType) Name() string { return Name @@ -315,6 +320,31 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } } + // Parse the default UID and GID. + fsopts.dfltuid = _V9FS_DEFUID + if dfltuidstr, ok := mopts["dfltuid"]; ok { + delete(mopts, "dfltuid") + dfltuid, err := strconv.ParseUint(dfltuidstr, 10, 32) + if err != nil { + ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: dfltuid=%s", dfltuidstr) + return nil, nil, syserror.EINVAL + } + // In Linux, dfltuid is interpreted as a UID and is converted to a KUID + // in the caller's user namespace, but goferfs isn't + // application-mountable. + fsopts.dfltuid = auth.KUID(dfltuid) + } + fsopts.dfltgid = _V9FS_DEFGID + if dfltgidstr, ok := mopts["dfltgid"]; ok { + delete(mopts, "dfltgid") + dfltgid, err := strconv.ParseUint(dfltgidstr, 10, 32) + if err != nil { + ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: dfltgid=%s", dfltgidstr) + return nil, nil, syserror.EINVAL + } + fsopts.dfltgid = auth.KGID(dfltgid) + } + // Parse the 9P message size. fsopts.msize = 1024 * 1024 // 1M, tested to give good enough performance up to 64M if msizestr, ok := mopts["msize"]; ok { @@ -422,8 +452,6 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt client: client, clock: ktime.RealtimeClockFromContext(ctx), devMinor: devMinor, - uid: creds.EffectiveKUID, - gid: creds.EffectiveKGID, syncableDentries: make(map[*dentry]struct{}), specialFileFDs: make(map[*specialFileFD]struct{}), } @@ -672,8 +700,8 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma file: file, ino: qid.Path, mode: uint32(attr.Mode), - uid: uint32(fs.uid), - gid: uint32(fs.gid), + uid: uint32(fs.opts.dfltuid), + gid: uint32(fs.opts.dfltgid), blockSize: usermem.PageSize, handle: handle{ fd: -1, diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index 3f433d666..fee174375 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -312,7 +312,7 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off f := fd.inode().impl.(*regularFile) if end := offset + srclen; end < offset { // Overflow. - return 0, syserror.EFBIG + return 0, syserror.EINVAL } var err error diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index f29dc0472..7bfa9075a 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -8,6 +8,7 @@ go_library( "device.go", "node.go", "pipe.go", + "pipe_unsafe.go", "pipe_util.go", "reader.go", "reader_writer.go", @@ -20,6 +21,7 @@ go_library( "//pkg/amutex", "//pkg/buffer", "//pkg/context", + "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/device", "//pkg/sentry/fs", diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 62c8691f1..79645d7d2 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -207,7 +207,10 @@ func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) { p.mu.Lock() defer p.mu.Unlock() + return p.readLocked(ctx, ops) +} +func (p *Pipe) readLocked(ctx context.Context, ops readOps) (int64, error) { // Is the pipe empty? if p.view.Size() == 0 { if !p.HasWriters() { @@ -246,7 +249,10 @@ type writeOps struct { func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) { p.mu.Lock() defer p.mu.Unlock() + return p.writeLocked(ctx, ops) +} +func (p *Pipe) writeLocked(ctx context.Context, ops writeOps) (int64, error) { // Can't write to a pipe with no readers. if !p.HasReaders() { return 0, syscall.EPIPE diff --git a/pkg/sentry/kernel/pipe/pipe_unsafe.go b/pkg/sentry/kernel/pipe/pipe_unsafe.go new file mode 100644 index 000000000..dd60cba24 --- /dev/null +++ b/pkg/sentry/kernel/pipe/pipe_unsafe.go @@ -0,0 +1,35 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pipe + +import ( + "unsafe" +) + +// lockTwoPipes locks both x.mu and y.mu in an order that is guaranteed to be +// consistent for both lockTwoPipes(x, y) and lockTwoPipes(y, x), such that +// concurrent calls cannot deadlock. +// +// Preconditions: x != y. +func lockTwoPipes(x, y *Pipe) { + // Lock the two pipes in order of increasing address. + if uintptr(unsafe.Pointer(x)) < uintptr(unsafe.Pointer(y)) { + x.mu.Lock() + y.mu.Lock() + } else { + y.mu.Lock() + x.mu.Lock() + } +} diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index b54f08a30..2602bed72 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -16,7 +16,9 @@ package pipe import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -150,7 +152,9 @@ func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) * return &fd.vfsfd } -// VFSPipeFD implements vfs.FileDescriptionImpl for pipes. +// VFSPipeFD implements vfs.FileDescriptionImpl for pipes. It also implements +// non-atomic usermem.IO methods, allowing it to be passed as usermem.IO to +// other FileDescriptions for splice(2) and tee(2). type VFSPipeFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -229,3 +233,216 @@ func (fd *VFSPipeFD) PipeSize() int64 { func (fd *VFSPipeFD) SetPipeSize(size int64) (int64, error) { return fd.pipe.SetFifoSize(size) } + +// IOSequence returns a useremm.IOSequence that reads up to count bytes from, +// or writes up to count bytes to, fd. +func (fd *VFSPipeFD) IOSequence(count int64) usermem.IOSequence { + return usermem.IOSequence{ + IO: fd, + Addrs: usermem.AddrRangeSeqOf(usermem.AddrRange{0, usermem.Addr(count)}), + } +} + +// CopyIn implements usermem.IO.CopyIn. +func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, opts usermem.IOOpts) (int, error) { + origCount := int64(len(dst)) + n, err := fd.pipe.read(ctx, readOps{ + left: func() int64 { + return int64(len(dst)) + }, + limit: func(l int64) { + dst = dst[:l] + }, + read: func(view *buffer.View) (int64, error) { + n, err := view.ReadAt(dst, 0) + view.TrimFront(int64(n)) + return int64(n), err + }, + }) + if n > 0 { + fd.pipe.Notify(waiter.EventOut) + } + if err == nil && n != origCount { + return int(n), syserror.ErrWouldBlock + } + return int(n), err +} + +// CopyOut implements usermem.IO.CopyOut. +func (fd *VFSPipeFD) CopyOut(ctx context.Context, addr usermem.Addr, src []byte, opts usermem.IOOpts) (int, error) { + origCount := int64(len(src)) + n, err := fd.pipe.write(ctx, writeOps{ + left: func() int64 { + return int64(len(src)) + }, + limit: func(l int64) { + src = src[:l] + }, + write: func(view *buffer.View) (int64, error) { + view.Append(src) + return int64(len(src)), nil + }, + }) + if n > 0 { + fd.pipe.Notify(waiter.EventIn) + } + if err == nil && n != origCount { + return int(n), syserror.ErrWouldBlock + } + return int(n), err +} + +// ZeroOut implements usermem.IO.ZeroOut. +func (fd *VFSPipeFD) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int64, opts usermem.IOOpts) (int64, error) { + origCount := toZero + n, err := fd.pipe.write(ctx, writeOps{ + left: func() int64 { + return toZero + }, + limit: func(l int64) { + toZero = l + }, + write: func(view *buffer.View) (int64, error) { + view.Grow(view.Size()+toZero, true /* zero */) + return toZero, nil + }, + }) + if n > 0 { + fd.pipe.Notify(waiter.EventIn) + } + if err == nil && n != origCount { + return n, syserror.ErrWouldBlock + } + return n, err +} + +// CopyInTo implements usermem.IO.CopyInTo. +func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) { + count := ars.NumBytes() + if count == 0 { + return 0, nil + } + origCount := count + n, err := fd.pipe.read(ctx, readOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + read: func(view *buffer.View) (int64, error) { + n, err := view.ReadToSafememWriter(dst, uint64(count)) + view.TrimFront(int64(n)) + return int64(n), err + }, + }) + if n > 0 { + fd.pipe.Notify(waiter.EventOut) + } + if err == nil && n != origCount { + return n, syserror.ErrWouldBlock + } + return n, err +} + +// CopyOutFrom implements usermem.IO.CopyOutFrom. +func (fd *VFSPipeFD) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) { + count := ars.NumBytes() + if count == 0 { + return 0, nil + } + origCount := count + n, err := fd.pipe.write(ctx, writeOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + write: func(view *buffer.View) (int64, error) { + n, err := view.WriteFromSafememReader(src, uint64(count)) + return int64(n), err + }, + }) + if n > 0 { + fd.pipe.Notify(waiter.EventIn) + } + if err == nil && n != origCount { + return n, syserror.ErrWouldBlock + } + return n, err +} + +// SwapUint32 implements usermem.IO.SwapUint32. +func (fd *VFSPipeFD) SwapUint32(ctx context.Context, addr usermem.Addr, new uint32, opts usermem.IOOpts) (uint32, error) { + // How did a pipe get passed as the virtual address space to futex(2)? + panic("VFSPipeFD.SwapUint32 called unexpectedly") +} + +// CompareAndSwapUint32 implements usermem.IO.CompareAndSwapUint32. +func (fd *VFSPipeFD) CompareAndSwapUint32(ctx context.Context, addr usermem.Addr, old, new uint32, opts usermem.IOOpts) (uint32, error) { + panic("VFSPipeFD.CompareAndSwapUint32 called unexpectedly") +} + +// LoadUint32 implements usermem.IO.LoadUint32. +func (fd *VFSPipeFD) LoadUint32(ctx context.Context, addr usermem.Addr, opts usermem.IOOpts) (uint32, error) { + panic("VFSPipeFD.LoadUint32 called unexpectedly") +} + +// Splice reads up to count bytes from src and writes them to dst. It returns +// the number of bytes moved. +// +// Preconditions: count > 0. +func Splice(ctx context.Context, dst, src *VFSPipeFD, count int64) (int64, error) { + return spliceOrTee(ctx, dst, src, count, true /* removeFromSrc */) +} + +// Tee reads up to count bytes from src and writes them to dst, without +// removing the read bytes from src. It returns the number of bytes copied. +// +// Preconditions: count > 0. +func Tee(ctx context.Context, dst, src *VFSPipeFD, count int64) (int64, error) { + return spliceOrTee(ctx, dst, src, count, false /* removeFromSrc */) +} + +// Preconditions: count > 0. +func spliceOrTee(ctx context.Context, dst, src *VFSPipeFD, count int64, removeFromSrc bool) (int64, error) { + if dst.pipe == src.pipe { + return 0, syserror.EINVAL + } + + lockTwoPipes(dst.pipe, src.pipe) + defer dst.pipe.mu.Unlock() + defer src.pipe.mu.Unlock() + + n, err := dst.pipe.writeLocked(ctx, writeOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + write: func(dstView *buffer.View) (int64, error) { + return src.pipe.readLocked(ctx, readOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + read: func(srcView *buffer.View) (int64, error) { + n, err := srcView.ReadToSafememWriter(dstView, uint64(count)) + if n > 0 && removeFromSrc { + srcView.TrimFront(int64(n)) + } + return int64(n), err + }, + }) + }, + }) + if n > 0 { + dst.pipe.Notify(waiter.EventIn) + src.pipe.Notify(waiter.EventOut) + } + return n, err +} diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index 73591dab7..a036ce53c 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -25,6 +25,7 @@ go_template_instance( out = "vma_set.go", consts = { "minDegree": "8", + "trackGaps": "1", }, imports = { "usermem": "gvisor.dev/gvisor/pkg/usermem", diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go index 9a14e69e6..16d8207e9 100644 --- a/pkg/sentry/mm/vma.go +++ b/pkg/sentry/mm/vma.go @@ -195,7 +195,7 @@ func (mm *MemoryManager) applicationAddrRange() usermem.AddrRange { // Preconditions: mm.mappingMu must be locked. func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bounds usermem.AddrRange) (usermem.Addr, error) { - for gap := mm.vmas.LowerBoundGap(bounds.Start); gap.Ok() && gap.Start() < bounds.End; gap = gap.NextGap() { + for gap := mm.vmas.LowerBoundGap(bounds.Start); gap.Ok() && gap.Start() < bounds.End; gap = gap.NextLargeEnoughGap(usermem.Addr(length)) { if gr := gap.availableRange().Intersect(bounds); uint64(gr.Length()) >= length { // Can we shift up to match the alignment? if offset := uint64(gr.Start) % alignment; offset != 0 { @@ -214,7 +214,7 @@ func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bou // Preconditions: mm.mappingMu must be locked. func (mm *MemoryManager) findHighestAvailableLocked(length, alignment uint64, bounds usermem.AddrRange) (usermem.Addr, error) { - for gap := mm.vmas.UpperBoundGap(bounds.End); gap.Ok() && gap.End() > bounds.Start; gap = gap.PrevGap() { + for gap := mm.vmas.UpperBoundGap(bounds.End); gap.Ok() && gap.End() > bounds.Start; gap = gap.PrevLargeEnoughGap(usermem.Addr(length)) { if gr := gap.availableRange().Intersect(bounds); uint64(gr.Length()) >= length { // Can we shift down to match the alignment? start := gr.End - usermem.Addr(length) diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index f882ef840..d56927ff5 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -22,6 +22,7 @@ go_library( "setstat.go", "signal.go", "socket.go", + "splice.go", "stat.go", "stat_amd64.go", "stat_arm64.go", diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go new file mode 100644 index 000000000..8f3c22a02 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -0,0 +1,286 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/waiter" +) + +// Splice implements Linux syscall splice(2). +func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + inFD := args[0].Int() + inOffsetPtr := args[1].Pointer() + outFD := args[2].Int() + outOffsetPtr := args[3].Pointer() + count := int64(args[4].SizeT()) + flags := args[5].Int() + + if count == 0 { + return 0, nil, nil + } + if count > int64(kernel.MAX_RW_COUNT) { + count = int64(kernel.MAX_RW_COUNT) + } + + // Check for invalid flags. + if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 { + return 0, nil, syserror.EINVAL + } + + // Get file descriptions. + inFile := t.GetFileVFS2(inFD) + if inFile == nil { + return 0, nil, syserror.EBADF + } + defer inFile.DecRef() + outFile := t.GetFileVFS2(outFD) + if outFile == nil { + return 0, nil, syserror.EBADF + } + defer outFile.DecRef() + + // Check that both files support the required directionality. + if !inFile.IsReadable() || !outFile.IsWritable() { + return 0, nil, syserror.EBADF + } + + // The operation is non-blocking if anything is non-blocking. + // + // N.B. This is a rather simplistic heuristic that avoids some + // poor edge case behavior since the exact semantics here are + // underspecified and vary between versions of Linux itself. + nonBlock := ((inFile.StatusFlags()|outFile.StatusFlags())&linux.O_NONBLOCK != 0) || (flags&linux.SPLICE_F_NONBLOCK != 0) + + // At least one file description must represent a pipe. + inPipeFD, inIsPipe := inFile.Impl().(*pipe.VFSPipeFD) + outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD) + if !inIsPipe && !outIsPipe { + return 0, nil, syserror.EINVAL + } + + // Copy in offsets. + inOffset := int64(-1) + if inOffsetPtr != 0 { + if inIsPipe { + return 0, nil, syserror.ESPIPE + } + if inFile.Options().DenyPRead { + return 0, nil, syserror.EINVAL + } + if _, err := t.CopyIn(inOffsetPtr, &inOffset); err != nil { + return 0, nil, err + } + if inOffset < 0 { + return 0, nil, syserror.EINVAL + } + } + outOffset := int64(-1) + if outOffsetPtr != 0 { + if outIsPipe { + return 0, nil, syserror.ESPIPE + } + if outFile.Options().DenyPWrite { + return 0, nil, syserror.EINVAL + } + if _, err := t.CopyIn(outOffsetPtr, &outOffset); err != nil { + return 0, nil, err + } + if outOffset < 0 { + return 0, nil, syserror.EINVAL + } + } + + // Move data. + var ( + n int64 + err error + inCh chan struct{} + outCh chan struct{} + ) + for { + // If both input and output are pipes, delegate to the pipe + // implementation. Otherwise, exactly one end is a pipe, which we + // ensure is consistently ordered after the non-pipe FD's locks by + // passing the pipe FD as usermem.IO to the non-pipe end. + switch { + case inIsPipe && outIsPipe: + n, err = pipe.Splice(t, outPipeFD, inPipeFD, count) + case inIsPipe: + if outOffset != -1 { + n, err = outFile.PWrite(t, inPipeFD.IOSequence(count), outOffset, vfs.WriteOptions{}) + outOffset += n + } else { + n, err = outFile.Write(t, inPipeFD.IOSequence(count), vfs.WriteOptions{}) + } + case outIsPipe: + if inOffset != -1 { + n, err = inFile.PRead(t, outPipeFD.IOSequence(count), inOffset, vfs.ReadOptions{}) + inOffset += n + } else { + n, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{}) + } + } + if n != 0 || err != syserror.ErrWouldBlock || nonBlock { + break + } + + // Note that the blocking behavior here is a bit different than the + // normal pattern. Because we need to have both data to read and data + // to write simultaneously, we actually explicitly block on both of + // these cases in turn before returning to the splice operation. + if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { + if inCh == nil { + inCh = make(chan struct{}, 1) + inW, _ := waiter.NewChannelEntry(inCh) + inFile.EventRegister(&inW, eventMaskRead) + defer inFile.EventUnregister(&inW) + continue // Need to refresh readiness. + } + if err = t.Block(inCh); err != nil { + break + } + } + if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { + if outCh == nil { + outCh = make(chan struct{}, 1) + outW, _ := waiter.NewChannelEntry(outCh) + outFile.EventRegister(&outW, eventMaskWrite) + defer outFile.EventUnregister(&outW) + continue // Need to refresh readiness. + } + if err = t.Block(outCh); err != nil { + break + } + } + } + + // Copy updated offsets out. + if inOffsetPtr != 0 { + if _, err := t.CopyOut(inOffsetPtr, &inOffset); err != nil { + return 0, nil, err + } + } + if outOffsetPtr != 0 { + if _, err := t.CopyOut(outOffsetPtr, &outOffset); err != nil { + return 0, nil, err + } + } + + if n == 0 { + return 0, nil, err + } + return uintptr(n), nil, nil +} + +// Tee implements Linux syscall tee(2). +func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + inFD := args[0].Int() + outFD := args[1].Int() + count := int64(args[2].SizeT()) + flags := args[3].Int() + + if count == 0 { + return 0, nil, nil + } + if count > int64(kernel.MAX_RW_COUNT) { + count = int64(kernel.MAX_RW_COUNT) + } + + // Check for invalid flags. + if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 { + return 0, nil, syserror.EINVAL + } + + // Get file descriptions. + inFile := t.GetFileVFS2(inFD) + if inFile == nil { + return 0, nil, syserror.EBADF + } + defer inFile.DecRef() + outFile := t.GetFileVFS2(outFD) + if outFile == nil { + return 0, nil, syserror.EBADF + } + defer outFile.DecRef() + + // Check that both files support the required directionality. + if !inFile.IsReadable() || !outFile.IsWritable() { + return 0, nil, syserror.EBADF + } + + // The operation is non-blocking if anything is non-blocking. + // + // N.B. This is a rather simplistic heuristic that avoids some + // poor edge case behavior since the exact semantics here are + // underspecified and vary between versions of Linux itself. + nonBlock := ((inFile.StatusFlags()|outFile.StatusFlags())&linux.O_NONBLOCK != 0) || (flags&linux.SPLICE_F_NONBLOCK != 0) + + // Both file descriptions must represent pipes. + inPipeFD, inIsPipe := inFile.Impl().(*pipe.VFSPipeFD) + outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD) + if !inIsPipe || !outIsPipe { + return 0, nil, syserror.EINVAL + } + + // Copy data. + var ( + inCh chan struct{} + outCh chan struct{} + ) + for { + n, err := pipe.Tee(t, outPipeFD, inPipeFD, count) + if n != 0 { + return uintptr(n), nil, nil + } + if err != syserror.ErrWouldBlock || nonBlock { + return 0, nil, err + } + + // Note that the blocking behavior here is a bit different than the + // normal pattern. Because we need to have both data to read and data + // to write simultaneously, we actually explicitly block on both of + // these cases in turn before returning to the tee operation. + if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { + if inCh == nil { + inCh = make(chan struct{}, 1) + inW, _ := waiter.NewChannelEntry(inCh) + inFile.EventRegister(&inW, eventMaskRead) + defer inFile.EventUnregister(&inW) + continue // Need to refresh readiness. + } + if err := t.Block(inCh); err != nil { + return 0, nil, err + } + } + if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { + if outCh == nil { + outCh = make(chan struct{}, 1) + outW, _ := waiter.NewChannelEntry(outCh) + outFile.EventRegister(&outW, eventMaskWrite) + defer outFile.EventUnregister(&outW) + continue // Need to refresh readiness. + } + if err := t.Block(outCh); err != nil { + return 0, nil, err + } + } + } +} diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go index a332d01bd..083fdcf82 100644 --- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go @@ -134,8 +134,8 @@ func Override() { s.Table[269] = syscalls.Supported("faccessat", Faccessat) s.Table[270] = syscalls.Supported("pselect", Pselect) s.Table[271] = syscalls.Supported("ppoll", Ppoll) - delete(s.Table, 275) // splice - delete(s.Table, 276) // tee + s.Table[275] = syscalls.Supported("splice", Splice) + s.Table[276] = syscalls.Supported("tee", Tee) s.Table[277] = syscalls.Supported("sync_file_range", SyncFileRange) s.Table[280] = syscalls.Supported("utimensat", Utimensat) s.Table[281] = syscalls.Supported("epoll_pwait", EpollPwait) diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index cfabd936c..bb294563d 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -210,6 +210,11 @@ func (fd *FileDescription) VirtualDentry() VirtualDentry { return fd.vd } +// Options returns the options passed to fd.Init(). +func (fd *FileDescription) Options() FileDescriptionOptions { + return fd.opts +} + // StatusFlags returns file description status flags, as for fcntl(F_GETFL). func (fd *FileDescription) StatusFlags() uint32 { return atomic.LoadUint32(&fd.statusFlags) diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go index 5f2af9f3b..c45d2ecbc 100644 --- a/pkg/test/dockerutil/dockerutil.go +++ b/pkg/test/dockerutil/dockerutil.go @@ -148,6 +148,62 @@ func (m MountMode) String() string { panic(fmt.Sprintf("invalid mode: %d", m)) } +// DockerNetwork contains the name of a docker network. +type DockerNetwork struct { + logger testutil.Logger + Name string + Subnet *net.IPNet + containers []*Docker +} + +// NewDockerNetwork sets up the struct for a Docker network. Names of networks +// will be unique. +func NewDockerNetwork(logger testutil.Logger) *DockerNetwork { + return &DockerNetwork{ + logger: logger, + Name: testutil.RandomID(logger.Name()), + } +} + +// Create calls 'docker network create'. +func (n *DockerNetwork) Create(args ...string) error { + a := []string{"docker", "network", "create"} + if n.Subnet != nil { + a = append(a, fmt.Sprintf("--subnet=%s", n.Subnet)) + } + a = append(a, args...) + a = append(a, n.Name) + return testutil.Command(n.logger, a...).Run() +} + +// Connect calls 'docker network connect' with the arguments provided. +func (n *DockerNetwork) Connect(container *Docker, args ...string) error { + a := []string{"docker", "network", "connect"} + a = append(a, args...) + a = append(a, n.Name, container.Name) + if err := testutil.Command(n.logger, a...).Run(); err != nil { + return err + } + n.containers = append(n.containers, container) + return nil +} + +// Cleanup cleans up the docker network and all the containers attached to it. +func (n *DockerNetwork) Cleanup() error { + for _, c := range n.containers { + // Don't propagate the error, it might be that the container + // was already cleaned up. + if err := c.Kill(); err != nil { + n.logger.Logf("unable to kill container during cleanup: %s", err) + } + } + + if err := testutil.Command(n.logger, "docker", "network", "rm", n.Name).Run(); err != nil { + return err + } + return nil +} + // Docker contains the name and the runtime of a docker container. type Docker struct { logger testutil.Logger @@ -162,9 +218,13 @@ type Docker struct { // // Names of containers will be unique. func MakeDocker(logger testutil.Logger) *Docker { + // Slashes are not allowed in container names. + name := testutil.RandomID(logger.Name()) + name = strings.ReplaceAll(name, "/", "-") + return &Docker{ logger: logger, - Name: testutil.RandomID(logger.Name()), + Name: name, Runtime: *runtime, } } @@ -309,7 +369,9 @@ func (d *Docker) argsFor(r *RunOpts, command string, p []string) (rv []string) { rv = append(rv, d.Name) } else { rv = append(rv, d.mounts...) - rv = append(rv, fmt.Sprintf("--runtime=%s", d.Runtime)) + if len(d.Runtime) > 0 { + rv = append(rv, fmt.Sprintf("--runtime=%s", d.Runtime)) + } rv = append(rv, fmt.Sprintf("--name=%s", d.Name)) rv = append(rv, testutil.ImageByName(r.Image)) } @@ -477,6 +539,56 @@ func (d *Docker) FindIP() (net.IP, error) { return ip, nil } +// A NetworkInterface is container's network interface information. +type NetworkInterface struct { + IPv4 net.IP + MAC net.HardwareAddr +} + +// ListNetworks returns the network interfaces of the container, keyed by +// Docker network name. +func (d *Docker) ListNetworks() (map[string]NetworkInterface, error) { + const format = `{{json .NetworkSettings.Networks}}` + out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput() + if err != nil { + return nil, fmt.Errorf("error network interfaces: %q: %w", string(out), err) + } + + networks := map[string]map[string]string{} + if err := json.Unmarshal(out, &networks); err != nil { + return nil, fmt.Errorf("error decoding network interfaces: %w", err) + } + + interfaces := map[string]NetworkInterface{} + for name, iface := range networks { + var netface NetworkInterface + + rawIP := strings.TrimSpace(iface["IPAddress"]) + if rawIP != "" { + ip := net.ParseIP(rawIP) + if ip == nil { + return nil, fmt.Errorf("invalid IP: %q", rawIP) + } + // Docker's IPAddress field is IPv4. The IPv6 address + // is stored in the GlobalIPv6Address field. + netface.IPv4 = ip + } + + rawMAC := strings.TrimSpace(iface["MacAddress"]) + if rawMAC != "" { + mac, err := net.ParseMAC(rawMAC) + if err != nil { + return nil, fmt.Errorf("invalid MAC: %q: %w", rawMAC, err) + } + netface.MAC = mac + } + + interfaces[name] = netface + } + + return interfaces, nil +} + // SandboxPid returns the PID to the sandbox process. func (d *Docker) SandboxPid() (int, error) { out, err := testutil.Command(d.logger, "docker", "inspect", "-f={{.State.Pid}}", d.Name).CombinedOutput() |