summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--Makefile2
-rw-r--r--README.md4
-rw-r--r--g3doc/user_guide/tutorials/docker.md2
-rw-r--r--images/tmpfile/Dockerfile4
-rw-r--r--pkg/buffer/safemem.go82
-rw-r--r--pkg/segment/BUILD2
-rw-r--r--pkg/segment/set.go400
-rw-r--r--pkg/segment/test/BUILD18
-rw-r--r--pkg/segment/test/segment_test.go397
-rw-r--r--pkg/segment/test/set_functions.go32
-rw-r--r--pkg/sentry/fs/g3doc/.gitignore1
-rw-r--r--pkg/sentry/fs/g3doc/fuse.md260
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go48
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go2
-rw-r--r--pkg/sentry/kernel/pipe/BUILD2
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go6
-rw-r--r--pkg/sentry/kernel/pipe/pipe_unsafe.go35
-rw-r--r--pkg/sentry/kernel/pipe/vfs.go219
-rw-r--r--pkg/sentry/mm/BUILD1
-rw-r--r--pkg/sentry/mm/vma.go4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/BUILD1
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/splice.go286
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/vfs2.go4
-rw-r--r--pkg/sentry/vfs/file_description.go5
-rw-r--r--pkg/test/dockerutil/dockerutil.go116
-rw-r--r--test/e2e/integration_test.go56
-rw-r--r--test/packetimpact/README.md21
-rw-r--r--test/packetimpact/netdevs/BUILD15
-rw-r--r--test/packetimpact/netdevs/netdevs.go104
-rw-r--r--test/packetimpact/runner/BUILD20
-rw-r--r--test/packetimpact/runner/defs.bzl (renamed from test/packetimpact/tests/defs.bzl)13
-rw-r--r--test/packetimpact/runner/packetimpact_test.go312
-rw-r--r--test/packetimpact/testbench/BUILD1
-rw-r--r--test/packetimpact/testbench/connections.go4
-rw-r--r--test/packetimpact/testbench/dut.go6
-rw-r--r--test/packetimpact/testbench/rawsockets.go3
-rw-r--r--test/packetimpact/testbench/testbench.go31
-rw-r--r--test/packetimpact/tests/BUILD7
-rwxr-xr-xtest/packetimpact/tests/test_runner.sh325
-rw-r--r--test/runner/runner.go8
-rw-r--r--test/syscalls/linux/socket.cc13
-rw-r--r--test/syscalls/linux/splice.cc49
-rw-r--r--test/util/test_util.cc14
-rw-r--r--test/util/test_util.h1
-rw-r--r--tools/go_generics/generics.go4
45 files changed, 2440 insertions, 500 deletions
diff --git a/Makefile b/Makefile
index 7f382695d..2bcb85e9b 100644
--- a/Makefile
+++ b/Makefile
@@ -116,7 +116,7 @@ unit-tests: ## Runs all unit tests in pkg runsc and tools.
.PHONY: unit-tests
tests: ## Runs all local ptrace system call tests.
- @$(MAKE) test OPTIONS="--test_tag_filter runsc_ptrace test/syscalls/..."
+ @$(MAKE) test OPTIONS="--test_tag_filters runsc_ptrace test/syscalls/..."
.PHONY: tests
##
diff --git a/README.md b/README.md
index b1ed3b4ce..ce3947907 100644
--- a/README.md
+++ b/README.md
@@ -74,7 +74,7 @@ make tests
To run specific tests, you can specify the target:
```
-make test TARGET="//runsc:version_test"
+make test TARGETS="//runsc:version_test"
```
### Using `go get`
@@ -97,7 +97,7 @@ development on this branch is not supported. Development should occur on the
## Community & Governance
-See [GOVERNANCE.md](GOVERANCE.md) for project governance information.
+See [GOVERNANCE.md](GOVERNANCE.md) for project governance information.
The [gvisor-users mailing list][gvisor-users-list] and
[gvisor-dev mailing list][gvisor-dev-list] are good starting points for
diff --git a/g3doc/user_guide/tutorials/docker.md b/g3doc/user_guide/tutorials/docker.md
index c0a3db506..705560038 100644
--- a/g3doc/user_guide/tutorials/docker.md
+++ b/g3doc/user_guide/tutorials/docker.md
@@ -1,4 +1,4 @@
-# WorkPress with Docker
+# WordPress with Docker
This page shows you how to deploy a sample [WordPress][wordpress] site using
[Docker][docker].
diff --git a/images/tmpfile/Dockerfile b/images/tmpfile/Dockerfile
new file mode 100644
index 000000000..e3816c8cb
--- /dev/null
+++ b/images/tmpfile/Dockerfile
@@ -0,0 +1,4 @@
+# Create file under /tmp to ensure files inside '/tmp' are not overridden.
+FROM alpine:3.11.5
+RUN mkdir -p /tmp/foo \
+ && echo 123 > /tmp/foo/file.txt
diff --git a/pkg/buffer/safemem.go b/pkg/buffer/safemem.go
index 0e5b86344..b789e56e9 100644
--- a/pkg/buffer/safemem.go
+++ b/pkg/buffer/safemem.go
@@ -28,12 +28,11 @@ func (b *buffer) ReadBlock() safemem.Block {
return safemem.BlockFromSafeSlice(b.ReadSlice())
}
-// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
-//
-// This will advance the write index.
-func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
- need := int(srcs.NumBytes())
- if need == 0 {
+// WriteFromSafememReader writes up to count bytes from r to v and advances the
+// write index by the number of bytes written. It calls r.ReadToBlocks() at
+// most once.
+func (v *View) WriteFromSafememReader(r safemem.Reader, count uint64) (uint64, error) {
+ if count == 0 {
return 0, nil
}
@@ -50,32 +49,33 @@ func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
}
// Does the last block have sufficient capacity alone?
- if l := firstBuf.WriteSize(); l >= need {
- dst = safemem.BlockSeqOf(firstBuf.WriteBlock())
+ if l := uint64(firstBuf.WriteSize()); l >= count {
+ dst = safemem.BlockSeqOf(firstBuf.WriteBlock().TakeFirst64(count))
} else {
// Append blocks until sufficient.
- need -= l
+ count -= l
blocks = append(blocks, firstBuf.WriteBlock())
- for need > 0 {
+ for count > 0 {
emptyBuf := bufferPool.Get().(*buffer)
v.data.PushBack(emptyBuf)
- need -= emptyBuf.WriteSize()
- blocks = append(blocks, emptyBuf.WriteBlock())
+ block := emptyBuf.WriteBlock().TakeFirst64(count)
+ count -= uint64(block.Len())
+ blocks = append(blocks, block)
}
dst = safemem.BlockSeqFromSlice(blocks)
}
- // Perform the copy.
- n, err := safemem.CopySeq(dst, srcs)
+ // Perform I/O.
+ n, err := r.ReadToBlocks(dst)
v.size += int64(n)
// Update all indices.
- for left := int(n); left > 0; firstBuf = firstBuf.Next() {
- if l := firstBuf.WriteSize(); left >= l {
+ for left := n; left > 0; firstBuf = firstBuf.Next() {
+ if l := firstBuf.WriteSize(); left >= uint64(l) {
firstBuf.WriteMove(l) // Whole block.
- left -= l
+ left -= uint64(l)
} else {
- firstBuf.WriteMove(left) // Partial block.
+ firstBuf.WriteMove(int(left)) // Partial block.
left = 0
}
}
@@ -83,14 +83,16 @@ func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
return n, err
}
-// ReadToBlocks implements safemem.Reader.ReadToBlocks.
-//
-// This will not advance the read index; the caller should follow
-// this call with a call to TrimFront in order to remove the read
-// data from the buffer. This is done to support pipe sematics.
-func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
- need := int(dsts.NumBytes())
- if need == 0 {
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. It advances the
+// write index by the number of bytes written.
+func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ return v.WriteFromSafememReader(&safemem.BlockSeqReader{srcs}, srcs.NumBytes())
+}
+
+// ReadToSafememWriter reads up to count bytes from v to w. It does not advance
+// the read index. It calls w.WriteFromBlocks() at most once.
+func (v *View) ReadToSafememWriter(w safemem.Writer, count uint64) (uint64, error) {
+ if count == 0 {
return 0, nil
}
@@ -105,25 +107,27 @@ func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
}
// Is all the data in a single block?
- if l := firstBuf.ReadSize(); l >= need {
- src = safemem.BlockSeqOf(firstBuf.ReadBlock())
+ if l := uint64(firstBuf.ReadSize()); l >= count {
+ src = safemem.BlockSeqOf(firstBuf.ReadBlock().TakeFirst64(count))
} else {
// Build a list of all the buffers.
- need -= l
+ count -= l
blocks = append(blocks, firstBuf.ReadBlock())
- for buf := firstBuf.Next(); buf != nil && need > 0; buf = buf.Next() {
- need -= buf.ReadSize()
- blocks = append(blocks, buf.ReadBlock())
+ for buf := firstBuf.Next(); buf != nil && count > 0; buf = buf.Next() {
+ block := buf.ReadBlock().TakeFirst64(count)
+ count -= uint64(block.Len())
+ blocks = append(blocks, block)
}
src = safemem.BlockSeqFromSlice(blocks)
}
- // Perform the copy.
- n, err := safemem.CopySeq(dsts, src)
-
- // See above: we would normally advance the read index here, but we
- // don't do that in order to support pipe semantics. We rely on a
- // separate call to TrimFront() in this case.
+ // Perform I/O. As documented, we don't advance the read index.
+ return w.WriteFromBlocks(src)
+}
- return n, err
+// ReadToBlocks implements safemem.Reader.ReadToBlocks. It does not advance the
+// read index by the number of bytes read, such that it's only safe to call if
+// the caller guarantees that ReadToBlocks will only be called once.
+func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ return v.ReadToSafememWriter(&safemem.BlockSeqWriter{dsts}, dsts.NumBytes())
}
diff --git a/pkg/segment/BUILD b/pkg/segment/BUILD
index 1b487b887..f57ccc170 100644
--- a/pkg/segment/BUILD
+++ b/pkg/segment/BUILD
@@ -21,6 +21,8 @@ go_template(
],
opt_consts = [
"minDegree",
+ # trackGaps must either be 0 or 1.
+ "trackGaps",
],
types = [
"Key",
diff --git a/pkg/segment/set.go b/pkg/segment/set.go
index 03e4f258f..1a17ad9cb 100644
--- a/pkg/segment/set.go
+++ b/pkg/segment/set.go
@@ -36,6 +36,34 @@ type Range interface{}
// Value is a required type parameter.
type Value interface{}
+// trackGaps is an optional parameter.
+//
+// If trackGaps is 1, the Set will track maximum gap size recursively,
+// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this
+// case, Key must be an unsigned integer.
+//
+// trackGaps must be 0 or 1.
+const trackGaps = 0
+
+var _ = uint8(trackGaps << 7) // Will fail if not zero or one.
+
+// dynamicGap is a type that disappears if trackGaps is 0.
+type dynamicGap [trackGaps]Key
+
+// Get returns the value of the gap.
+//
+// Precondition: trackGaps must be non-zero.
+func (d *dynamicGap) Get() Key {
+ return d[:][0]
+}
+
+// Set sets the value of the gap.
+//
+// Precondition: trackGaps must be non-zero.
+func (d *dynamicGap) Set(v Key) {
+ d[:][0] = v
+}
+
// Functions is a required type parameter that must be a struct implementing
// the methods defined by Functions.
type Functions interface {
@@ -327,8 +355,12 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator {
}
if prev.Ok() && prev.End() == r.Start {
if mval, ok := (Functions{}).Merge(prev.Range(), prev.Value(), r, val); ok {
+ shrinkMaxGap := trackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get()
prev.SetEndUnchecked(r.End)
prev.SetValue(mval)
+ if shrinkMaxGap {
+ gap.node.updateMaxGapLeaf()
+ }
if next.Ok() && next.Start() == r.End {
val = mval
if mval, ok := (Functions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok {
@@ -342,11 +374,16 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator {
}
if next.Ok() && next.Start() == r.End {
if mval, ok := (Functions{}).Merge(r, val, next.Range(), next.Value()); ok {
+ shrinkMaxGap := trackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get()
next.SetStartUnchecked(r.Start)
next.SetValue(mval)
+ if shrinkMaxGap {
+ gap.node.updateMaxGapLeaf()
+ }
return next
}
}
+ // InsertWithoutMergingUnchecked will maintain maxGap if necessary.
return s.InsertWithoutMergingUnchecked(gap, r, val)
}
@@ -373,11 +410,15 @@ func (s *Set) InsertWithoutMerging(gap GapIterator, r Range, val Value) Iterator
// Preconditions: r.Start >= gap.Start(); r.End <= gap.End().
func (s *Set) InsertWithoutMergingUnchecked(gap GapIterator, r Range, val Value) Iterator {
gap = gap.node.rebalanceBeforeInsert(gap)
+ splitMaxGap := trackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get())
copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments])
copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments])
gap.node.keys[gap.index] = r
gap.node.values[gap.index] = val
gap.node.nrSegments++
+ if splitMaxGap {
+ gap.node.updateMaxGapLeaf()
+ }
return Iterator{gap.node, gap.index}
}
@@ -399,12 +440,23 @@ func (s *Set) Remove(seg Iterator) GapIterator {
// overlap.
seg.SetRangeUnchecked(victim.Range())
seg.SetValue(victim.Value())
+ // Need to update the nextAdjacentNode's maxGap because the gap in between
+ // must have been modified by updating seg.Range() to victim.Range().
+ // seg.NextSegment() must exist since the last segment can't be in a
+ // non-leaf node.
+ nextAdjacentNode := seg.NextSegment().node
+ if trackGaps != 0 {
+ nextAdjacentNode.updateMaxGapLeaf()
+ }
return s.Remove(victim).NextGap()
}
copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments])
copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments])
Functions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1])
seg.node.nrSegments--
+ if trackGaps != 0 {
+ seg.node.updateMaxGapLeaf()
+ }
return seg.node.rebalanceAfterRemove(GapIterator{seg.node, seg.index})
}
@@ -455,6 +507,7 @@ func (s *Set) MergeUnchecked(first, second Iterator) Iterator {
// overlaps second.
first.SetEndUnchecked(second.End())
first.SetValue(mval)
+ // Remove will handle the maxGap update if necessary.
return s.Remove(second).PrevSegment()
}
}
@@ -631,6 +684,12 @@ type node struct {
// than "isLeaf" because false must be the correct value for an empty root.
hasChildren bool
+ // The longest gap within this node. If the node is a leaf, it's simply the
+ // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys
+ // including the 0th and nrSegments-th gap possibly shared with its upper-level
+ // nodes; if it's a non-leaf node, it's the max of all children's maxGap.
+ maxGap dynamicGap
+
// Nodes store keys and values in separate arrays to maximize locality in
// the common case (scanning keys for lookup).
keys [maxDegree - 1]Range
@@ -676,12 +735,12 @@ func (n *node) nextSibling() *node {
// required for insertion, and returns an updated iterator to the position
// represented by gap.
func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator {
- if n.parent != nil {
- gap = n.parent.rebalanceBeforeInsert(gap)
- }
if n.nrSegments < maxDegree-1 {
return gap
}
+ if n.parent != nil {
+ gap = n.parent.rebalanceBeforeInsert(gap)
+ }
if n.parent == nil {
// n is root. Move all segments before and after n's median segment
// into new child nodes adjacent to the median segment, which is now
@@ -719,6 +778,13 @@ func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator {
n.hasChildren = true
n.children[0] = left
n.children[1] = right
+ // In this case, n's maxGap won't violated as it's still the root,
+ // but the left and right children should be updated locally as they
+ // are newly split from n.
+ if trackGaps != 0 {
+ left.updateMaxGapLocal()
+ right.updateMaxGapLocal()
+ }
if gap.node != n {
return gap
}
@@ -758,6 +824,12 @@ func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator {
}
}
n.nrSegments = minDegree - 1
+ // MaxGap of n's parent is not violated because the segments within is not changed.
+ // n and its sibling's maxGap need to be updated locally as they are two new nodes split from old n.
+ if trackGaps != 0 {
+ n.updateMaxGapLocal()
+ sibling.updateMaxGapLocal()
+ }
// gap.node can't be n.parent because gaps are always in leaf nodes.
if gap.node != n {
return gap
@@ -821,6 +893,12 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator {
}
n.nrSegments++
sibling.nrSegments--
+ // n's parent's maxGap does not need to be updated as its content is unmodified.
+ // n and its sibling must be updated with (new) maxGap because of the shift of keys.
+ if trackGaps != 0 {
+ n.updateMaxGapLocal()
+ sibling.updateMaxGapLocal()
+ }
if gap.node == sibling && gap.index == sibling.nrSegments {
return GapIterator{n, 0}
}
@@ -849,6 +927,12 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator {
}
n.nrSegments++
sibling.nrSegments--
+ // n's parent's maxGap does not need to be updated as its content is unmodified.
+ // n and its sibling must be updated with (new) maxGap because of the shift of keys.
+ if trackGaps != 0 {
+ n.updateMaxGapLocal()
+ sibling.updateMaxGapLocal()
+ }
if gap.node == sibling {
if gap.index == 0 {
return GapIterator{n, n.nrSegments}
@@ -886,6 +970,7 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator {
p.children[0] = nil
p.children[1] = nil
}
+ // No need to update maxGap of p as its content is not changed.
if gap.node == left {
return GapIterator{p, gap.index}
}
@@ -932,11 +1017,152 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator {
}
p.children[p.nrSegments] = nil
p.nrSegments--
+ // Update maxGap of left locally, no need to change p and right because
+ // p's contents is not changed and right is already invalid.
+ if trackGaps != 0 {
+ left.updateMaxGapLocal()
+ }
// This process robs p of one segment, so recurse into rebalancing p.
n = p
}
}
+// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no
+// necessary update.
+//
+// Preconditions: n must be a leaf node, trackGaps must be 1.
+func (n *node) updateMaxGapLeaf() {
+ if n.hasChildren {
+ panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n))
+ }
+ max := n.calculateMaxGapLeaf()
+ if max == n.maxGap.Get() {
+ // If new max equals the old maxGap, no update is needed.
+ return
+ }
+ oldMax := n.maxGap.Get()
+ n.maxGap.Set(max)
+ if max > oldMax {
+ // Grow ancestor maxGaps.
+ for p := n.parent; p != nil; p = p.parent {
+ if p.maxGap.Get() >= max {
+ // p and its ancestors already contain an equal or larger gap.
+ break
+ }
+ // Only if new maxGap is larger than parent's
+ // old maxGap, propagate this update to parent.
+ p.maxGap.Set(max)
+ }
+ return
+ }
+ // Shrink ancestor maxGaps.
+ for p := n.parent; p != nil; p = p.parent {
+ if p.maxGap.Get() > oldMax {
+ // p and its ancestors still contain a larger gap.
+ break
+ }
+ // If new max is smaller than the old maxGap, and this gap used
+ // to be the maxGap of its parent, iterate parent's children
+ // and calculate parent's new maxGap.(It's probable that parent
+ // has two children with the old maxGap, but we need to check it anyway.)
+ parentNewMax := p.calculateMaxGapInternal()
+ if p.maxGap.Get() == parentNewMax {
+ // p and its ancestors still contain a gap of at least equal size.
+ break
+ }
+ // If p's new maxGap differs from the old one, propagate this update.
+ p.maxGap.Set(parentNewMax)
+ }
+}
+
+// updateMaxGapLocal updates maxGap of the calling node solely with no
+// propagation to ancestor nodes.
+//
+// Precondition: trackGaps must be 1.
+func (n *node) updateMaxGapLocal() {
+ if !n.hasChildren {
+ // Leaf node iterates its gaps.
+ n.maxGap.Set(n.calculateMaxGapLeaf())
+ } else {
+ // Non-leaf node iterates its children.
+ n.maxGap.Set(n.calculateMaxGapInternal())
+ }
+}
+
+// calculateMaxGapLeaf iterates the gaps within a leaf node and calculate the
+// max.
+//
+// Preconditions: n must be a leaf node.
+func (n *node) calculateMaxGapLeaf() Key {
+ max := GapIterator{n, 0}.Range().Length()
+ for i := 1; i <= n.nrSegments; i++ {
+ if current := (GapIterator{n, i}).Range().Length(); current > max {
+ max = current
+ }
+ }
+ return max
+}
+
+// calculateMaxGapInternal iterates children's maxGap within an internal node n
+// and calculate the max.
+//
+// Preconditions: n must be a non-leaf node.
+func (n *node) calculateMaxGapInternal() Key {
+ max := n.children[0].maxGap.Get()
+ for i := 1; i <= n.nrSegments; i++ {
+ if current := n.children[i].maxGap.Get(); current > max {
+ max = current
+ }
+ }
+ return max
+}
+
+// searchFirstLargeEnoughGap returns the first gap having at least minSize length
+// in the subtree rooted by n. If not found, return a terminal gap iterator.
+func (n *node) searchFirstLargeEnoughGap(minSize Key) GapIterator {
+ if n.maxGap.Get() < minSize {
+ return GapIterator{}
+ }
+ if n.hasChildren {
+ for i := 0; i <= n.nrSegments; i++ {
+ if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ }
+ } else {
+ for i := 0; i <= n.nrSegments; i++ {
+ currentGap := GapIterator{n, i}
+ if currentGap.Range().Length() >= minSize {
+ return currentGap
+ }
+ }
+ }
+ panic(fmt.Sprintf("invalid maxGap in %v", n))
+}
+
+// searchLastLargeEnoughGap returns the last gap having at least minSize length
+// in the subtree rooted by n. If not found, return a terminal gap iterator.
+func (n *node) searchLastLargeEnoughGap(minSize Key) GapIterator {
+ if n.maxGap.Get() < minSize {
+ return GapIterator{}
+ }
+ if n.hasChildren {
+ for i := n.nrSegments; i >= 0; i-- {
+ if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ }
+ } else {
+ for i := n.nrSegments; i >= 0; i-- {
+ currentGap := GapIterator{n, i}
+ if currentGap.Range().Length() >= minSize {
+ return currentGap
+ }
+ }
+ }
+ panic(fmt.Sprintf("invalid maxGap in %v", n))
+}
+
// A Iterator is conceptually one of:
//
// - A pointer to a segment in a set; or
@@ -1243,6 +1469,122 @@ func (gap GapIterator) NextGap() GapIterator {
return seg.NextGap()
}
+// NextLargeEnoughGap returns the iterated gap's first next gap with larger
+// length than minSize. If not found, return a terminal gap iterator (does NOT
+// include this gap itself).
+//
+// Precondition: trackGaps must be 1.
+func (gap GapIterator) NextLargeEnoughGap(minSize Key) GapIterator {
+ if trackGaps != 1 {
+ panic("set is not tracking gaps")
+ }
+ if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments {
+ // If gap is the trailing gap of an non-leaf node,
+ // translate it to the equivalent gap on leaf level.
+ gap.node = gap.NextSegment().node
+ gap.index = 0
+ return gap.nextLargeEnoughGapHelper(minSize)
+ }
+ return gap.nextLargeEnoughGapHelper(minSize)
+}
+
+// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap
+// to do the real recursions.
+//
+// Preconditions: gap is NOT the trailing gap of a non-leaf node.
+func (gap GapIterator) nextLargeEnoughGapHelper(minSize Key) GapIterator {
+ // Crawl up the tree if no large enough gap in current node or the
+ // current gap is the trailing one on leaf level.
+ for gap.node != nil &&
+ (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) {
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ // If no large enough gap throughout the whole set, return a terminal
+ // gap iterator.
+ if gap.node == nil {
+ return GapIterator{}
+ }
+ // Iterate subsequent gaps.
+ gap.index++
+ for gap.index <= gap.node.nrSegments {
+ if gap.node.hasChildren {
+ if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ } else {
+ if gap.Range().Length() >= minSize {
+ return gap
+ }
+ }
+ gap.index++
+ }
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ if gap.node != nil && gap.index == gap.node.nrSegments {
+ // If gap is the trailing gap of a non-leaf node, crawl up to
+ // parent again and do recursion.
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ return gap.nextLargeEnoughGapHelper(minSize)
+}
+
+// PrevLargeEnoughGap returns the iterated gap's first prev gap with larger or
+// equal length than minSize. If not found, return a terminal gap iterator
+// (does NOT include this gap itself).
+//
+// Precondition: trackGaps must be 1.
+func (gap GapIterator) PrevLargeEnoughGap(minSize Key) GapIterator {
+ if trackGaps != 1 {
+ panic("set is not tracking gaps")
+ }
+ if gap.node != nil && gap.node.hasChildren && gap.index == 0 {
+ // If gap is the first gap of an non-leaf node,
+ // translate it to the equivalent gap on leaf level.
+ gap.node = gap.PrevSegment().node
+ gap.index = gap.node.nrSegments
+ return gap.prevLargeEnoughGapHelper(minSize)
+ }
+ return gap.prevLargeEnoughGapHelper(minSize)
+}
+
+// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap
+// to do the real recursions.
+//
+// Preconditions: gap is NOT the first gap of a non-leaf node.
+func (gap GapIterator) prevLargeEnoughGapHelper(minSize Key) GapIterator {
+ // Crawl up the tree if no large enough gap in current node or the
+ // current gap is the first one on leaf level.
+ for gap.node != nil &&
+ (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) {
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ // If no large enough gap throughout the whole set, return a terminal
+ // gap iterator.
+ if gap.node == nil {
+ return GapIterator{}
+ }
+ // Iterate previous gaps.
+ gap.index--
+ for gap.index >= 0 {
+ if gap.node.hasChildren {
+ if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ } else {
+ if gap.Range().Length() >= minSize {
+ return gap
+ }
+ }
+ gap.index--
+ }
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ if gap.node != nil && gap.index == 0 {
+ // If gap is the first gap of a non-leaf node, crawl up to
+ // parent again and do recursion.
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ return gap.prevLargeEnoughGapHelper(minSize)
+}
+
// segmentBeforePosition returns the predecessor segment of the position given
// by n.children[i], which may or may not contain a child. If no such segment
// exists, segmentBeforePosition returns a terminal iterator.
@@ -1271,7 +1613,7 @@ func segmentAfterPosition(n *node, i int) Iterator {
func zeroValueSlice(slice []Value) {
// TODO(jamieliu): check if Go is actually smart enough to optimize a
- // ClearValue that assigns nil to a memset here
+ // ClearValue that assigns nil to a memset here.
for i := range slice {
Functions{}.ClearValue(&slice[i])
}
@@ -1310,7 +1652,15 @@ func (n *node) writeDebugString(buf *bytes.Buffer, prefix string) {
child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i))
}
buf.WriteString(prefix)
- buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ if n.hasChildren {
+ if trackGaps != 0 {
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get()))
+ } else {
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ } else {
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
}
if child := n.children[n.nrSegments]; child != nil {
child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments))
@@ -1362,3 +1712,43 @@ func (s *Set) ImportSortedSlices(sds *SegmentDataSlices) error {
}
return nil
}
+
+// segmentTestCheck returns an error if s is incorrectly sorted, does not
+// contain exactly expectedSegments segments, or contains a segment which
+// fails the passed check.
+//
+// This should be used only for testing, and has been added to this package for
+// templating convenience.
+func (s *Set) segmentTestCheck(expectedSegments int, segFunc func(int, Range, Value) error) error {
+ havePrev := false
+ prev := Key(0)
+ nrSegments := 0
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ next := seg.Start()
+ if havePrev && prev >= next {
+ return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments)
+ }
+ if segFunc != nil {
+ if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil {
+ return err
+ }
+ }
+ prev = next
+ havePrev = true
+ nrSegments++
+ }
+ if nrSegments != expectedSegments {
+ return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments)
+ }
+ return nil
+}
+
+// countSegments counts the number of segments in the set.
+//
+// Similar to Check, this should only be used for testing.
+func (s *Set) countSegments() (segments int) {
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ segments++
+ }
+ return segments
+}
diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD
index f2d8462d8..131bf09b9 100644
--- a/pkg/segment/test/BUILD
+++ b/pkg/segment/test/BUILD
@@ -29,10 +29,28 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "gap_set",
+ out = "gap_set.go",
+ consts = {
+ "trackGaps": "1",
+ },
+ package = "segment",
+ prefix = "gap",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "int",
+ "Range": "Range",
+ "Value": "int",
+ "Functions": "gapSetFunctions",
+ },
+)
+
go_library(
name = "segment",
testonly = 1,
srcs = [
+ "gap_set.go",
"int_range.go",
"int_set.go",
"set_functions.go",
diff --git a/pkg/segment/test/segment_test.go b/pkg/segment/test/segment_test.go
index 97b16c158..85fa19096 100644
--- a/pkg/segment/test/segment_test.go
+++ b/pkg/segment/test/segment_test.go
@@ -17,6 +17,7 @@ package segment
import (
"fmt"
"math/rand"
+ "reflect"
"testing"
)
@@ -32,61 +33,65 @@ const (
// valueOffset is the difference between the value and start of test
// segments.
valueOffset = 100000
+
+ // intervalLength is the interval used by random gap tests.
+ intervalLength = 10
)
func shuffle(xs []int) {
- for i := range xs {
- j := rand.Intn(i + 1)
- xs[i], xs[j] = xs[j], xs[i]
- }
+ rand.Shuffle(len(xs), func(i, j int) { xs[i], xs[j] = xs[j], xs[i] })
}
-func randPermutation(size int) []int {
+func randIntervalPermutation(size int) []int {
p := make([]int, size)
for i := range p {
- p[i] = i
+ p[i] = intervalLength * i
}
shuffle(p)
return p
}
-// checkSet returns an error if s is incorrectly sorted, does not contain
-// exactly expectedSegments segments, or contains a segment for which val !=
-// key + valueOffset.
-func checkSet(s *Set, expectedSegments int) error {
- havePrev := false
- prev := 0
- nrSegments := 0
- for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
- next := seg.Start()
- if havePrev && prev >= next {
- return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments)
- }
- if got, want := seg.Value(), seg.Start()+valueOffset; got != want {
- return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nrSegments, seg.Start(), got, want)
- }
- prev = next
- havePrev = true
- nrSegments++
- }
- if nrSegments != expectedSegments {
- return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments)
+// validate can be passed to Check.
+func validate(nr int, r Range, v int) error {
+ if got, want := v, r.Start+valueOffset; got != want {
+ return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nr, r.Start, got, want)
}
return nil
}
-// countSegmentsIn returns the number of segments in s.
-func countSegmentsIn(s *Set) int {
- var count int
- for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
- count++
+// checkSetMaxGap returns an error if maxGap inside all nodes of s is not well
+// maintained.
+func checkSetMaxGap(s *gapSet) error {
+ n := s.root
+ return checkNodeMaxGap(&n)
+}
+
+// checkNodeMaxGap returns an error if maxGap inside the subtree rooted by n is
+// not well maintained.
+func checkNodeMaxGap(n *gapnode) error {
+ var max int
+ if !n.hasChildren {
+ max = n.calculateMaxGapLeaf()
+ } else {
+ for i := 0; i <= n.nrSegments; i++ {
+ child := n.children[i]
+ if err := checkNodeMaxGap(child); err != nil {
+ return err
+ }
+ if temp := child.maxGap.Get(); i == 0 || temp > max {
+ max = temp
+ }
+ }
+ }
+ if max != n.maxGap.Get() {
+ return fmt.Errorf("maxGap wrong in node\n%vexpected: %d got: %d", n, max, n.maxGap)
}
- return count
+ return nil
}
func TestAddRandom(t *testing.T) {
var s Set
- order := randPermutation(testSize)
+ order := rand.Perm(testSize)
var nrInsertions int
for i, j := range order {
if !s.AddWithoutMerging(Range{j, j + 1}, j+valueOffset) {
@@ -94,12 +99,12 @@ func TestAddRandom(t *testing.T) {
break
}
nrInsertions++
- if err := checkSet(&s, nrInsertions); err != nil {
+ if err := s.segmentTestCheck(nrInsertions, validate); err != nil {
t.Errorf("Iteration %d: %v", i, err)
break
}
}
- if got, want := countSegmentsIn(&s), nrInsertions; got != want {
+ if got, want := s.countSegments(), nrInsertions; got != want {
t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
}
if t.Failed() {
@@ -115,7 +120,156 @@ func TestRemoveRandom(t *testing.T) {
t.Fatalf("Failed to insert segment %d", i)
}
}
- order := randPermutation(testSize)
+ order := rand.Perm(testSize)
+ var nrRemovals int
+ for i, j := range order {
+ seg := s.FindSegment(j)
+ if !seg.Ok() {
+ t.Errorf("Iteration %d: failed to find segment with key %d", i, j)
+ break
+ }
+ s.Remove(seg)
+ nrRemovals++
+ if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil {
+ t.Errorf("Iteration %d: %v", i, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), testSize-nrRemovals; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Removal order: %v", order[:nrRemovals])
+ t.Logf("Set contents:\n%v", &s)
+ t.FailNow()
+ }
+}
+
+func TestMaxGapAddRandom(t *testing.T) {
+ var s gapSet
+ order := rand.Perm(testSize)
+ var nrInsertions int
+ for i, j := range order {
+ if !s.AddWithoutMerging(Range{j, j + 1}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ nrInsertions++
+ if err := s.segmentTestCheck(nrInsertions, validate); err != nil {
+ t.Errorf("Iteration %d: %v", i, err)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), nrInsertions; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Insertion order: %v", order[:nrInsertions])
+ t.Logf("Set contents:\n%v", &s)
+ }
+}
+
+func TestMaxGapAddRandomWithRandomInterval(t *testing.T) {
+ var s gapSet
+ order := randIntervalPermutation(testSize)
+ var nrInsertions int
+ for i, j := range order {
+ if !s.AddWithoutMerging(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ nrInsertions++
+ if err := s.segmentTestCheck(nrInsertions, validate); err != nil {
+ t.Errorf("Iteration %d: %v", i, err)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), nrInsertions; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Insertion order: %v", order[:nrInsertions])
+ t.Logf("Set contents:\n%v", &s)
+ }
+}
+
+func TestMaxGapAddRandomWithMerge(t *testing.T) {
+ var s gapSet
+ order := randIntervalPermutation(testSize)
+ nrInsertions := 1
+ for i, j := range order {
+ if !s.Add(Range{j, j + intervalLength}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), nrInsertions; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Insertion order: %v", order)
+ t.Logf("Set contents:\n%v", &s)
+ }
+}
+
+func TestMaxGapRemoveRandom(t *testing.T) {
+ var s gapSet
+ for i := 0; i < testSize; i++ {
+ if !s.AddWithoutMerging(Range{i, i + 1}, i+valueOffset) {
+ t.Fatalf("Failed to insert segment %d", i)
+ }
+ }
+ order := rand.Perm(testSize)
+ var nrRemovals int
+ for i, j := range order {
+ seg := s.FindSegment(j)
+ if !seg.Ok() {
+ t.Errorf("Iteration %d: failed to find segment with key %d", i, j)
+ break
+ }
+ temprange := seg.Range()
+ s.Remove(seg)
+ nrRemovals++
+ if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil {
+ t.Errorf("Iteration %d: %v", i, err)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When removing %v: %v", temprange, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), testSize-nrRemovals; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Removal order: %v", order[:nrRemovals])
+ t.Logf("Set contents:\n%v", &s)
+ t.FailNow()
+ }
+}
+
+func TestMaxGapRemoveHalfRandom(t *testing.T) {
+ var s gapSet
+ for i := 0; i < testSize; i++ {
+ if !s.AddWithoutMerging(Range{intervalLength * i, intervalLength*i + rand.Intn(intervalLength-1) + 1}, intervalLength*i+valueOffset) {
+ t.Fatalf("Failed to insert segment %d", i)
+ }
+ }
+ order := randIntervalPermutation(testSize)
+ order = order[:testSize/2]
var nrRemovals int
for i, j := range order {
seg := s.FindSegment(j)
@@ -123,14 +277,19 @@ func TestRemoveRandom(t *testing.T) {
t.Errorf("Iteration %d: failed to find segment with key %d", i, j)
break
}
+ temprange := seg.Range()
s.Remove(seg)
nrRemovals++
- if err := checkSet(&s, testSize-nrRemovals); err != nil {
+ if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil {
t.Errorf("Iteration %d: %v", i, err)
break
}
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When removing %v: %v", temprange, err)
+ break
+ }
}
- if got, want := countSegmentsIn(&s), testSize-nrRemovals; got != want {
+ if got, want := s.countSegments(), testSize-nrRemovals; got != want {
t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
}
if t.Failed() {
@@ -140,6 +299,148 @@ func TestRemoveRandom(t *testing.T) {
}
}
+func TestMaxGapAddRandomRemoveRandomHalfWithMerge(t *testing.T) {
+ var s gapSet
+ order := randIntervalPermutation(testSize * 2)
+ order = order[:testSize]
+ for i, j := range order {
+ if !s.Add(Range{j, j + intervalLength}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ shuffle(order)
+ var nrRemovals int
+ for _, j := range order {
+ seg := s.FindSegment(j)
+ if !seg.Ok() {
+ continue
+ }
+ temprange := seg.Range()
+ s.Remove(seg)
+ nrRemovals++
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When removing %v: %v", temprange, err)
+ break
+ }
+ }
+ if t.Failed() {
+ t.Logf("Removal order: %v", order[:nrRemovals])
+ t.Logf("Set contents:\n%v", &s)
+ t.FailNow()
+ }
+}
+
+func TestNextLargeEnoughGap(t *testing.T) {
+ var s gapSet
+ order := randIntervalPermutation(testSize * 2)
+ order = order[:testSize]
+ for i, j := range order {
+ if !s.Add(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ shuffle(order)
+ order = order[:testSize/2]
+ for _, j := range order {
+ seg := s.FindSegment(j)
+ if !seg.Ok() {
+ continue
+ }
+ temprange := seg.Range()
+ s.Remove(seg)
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When removing %v: %v", temprange, err)
+ break
+ }
+ }
+ minSize := 7
+ var gapArr1 []int
+ for gap := s.LowerBoundGap(0).NextLargeEnoughGap(minSize); gap.Ok(); gap = gap.NextLargeEnoughGap(minSize) {
+ if gap.Range().Length() < minSize {
+ t.Errorf("NextLargeEnoughGap wrong, gap %v has length %d, wanted %d", gap.Range(), gap.Range().Length(), minSize)
+ } else {
+ gapArr1 = append(gapArr1, gap.Range().Start)
+ }
+ }
+ var gapArr2 []int
+ for gap := s.LowerBoundGap(0).NextGap(); gap.Ok(); gap = gap.NextGap() {
+ if gap.Range().Length() >= minSize {
+ gapArr2 = append(gapArr2, gap.Range().Start)
+ }
+ }
+
+ if !reflect.DeepEqual(gapArr2, gapArr1) {
+ t.Errorf("Search result not correct, got: %v, wanted: %v", gapArr1, gapArr2)
+ }
+ if t.Failed() {
+ t.Logf("Set contents:\n%v", &s)
+ t.FailNow()
+ }
+}
+
+func TestPrevLargeEnoughGap(t *testing.T) {
+ var s gapSet
+ order := randIntervalPermutation(testSize * 2)
+ order = order[:testSize]
+ for i, j := range order {
+ if !s.Add(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ end := s.LastSegment().End()
+ shuffle(order)
+ order = order[:testSize/2]
+ for _, j := range order {
+ seg := s.FindSegment(j)
+ if !seg.Ok() {
+ continue
+ }
+ temprange := seg.Range()
+ s.Remove(seg)
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When removing %v: %v", temprange, err)
+ break
+ }
+ }
+ minSize := 7
+ var gapArr1 []int
+ for gap := s.UpperBoundGap(end + intervalLength).PrevLargeEnoughGap(minSize); gap.Ok(); gap = gap.PrevLargeEnoughGap(minSize) {
+ if gap.Range().Length() < minSize {
+ t.Errorf("PrevLargeEnoughGap wrong, gap length %d, wanted %d", gap.Range().Length(), minSize)
+ } else {
+ gapArr1 = append(gapArr1, gap.Range().Start)
+ }
+ }
+ var gapArr2 []int
+ for gap := s.UpperBoundGap(end + intervalLength).PrevGap(); gap.Ok(); gap = gap.PrevGap() {
+ if gap.Range().Length() >= minSize {
+ gapArr2 = append(gapArr2, gap.Range().Start)
+ }
+ }
+ if !reflect.DeepEqual(gapArr2, gapArr1) {
+ t.Errorf("Search result not correct, got: %v, wanted: %v", gapArr1, gapArr2)
+ }
+ if t.Failed() {
+ t.Logf("Set contents:\n%v", &s)
+ t.FailNow()
+ }
+}
+
func TestAddSequentialAdjacent(t *testing.T) {
var s Set
var nrInsertions int
@@ -148,12 +449,12 @@ func TestAddSequentialAdjacent(t *testing.T) {
t.Fatalf("Failed to insert segment %d", i)
}
nrInsertions++
- if err := checkSet(&s, nrInsertions); err != nil {
+ if err := s.segmentTestCheck(nrInsertions, validate); err != nil {
t.Errorf("Iteration %d: %v", i, err)
break
}
}
- if got, want := countSegmentsIn(&s), nrInsertions; got != want {
+ if got, want := s.countSegments(), nrInsertions; got != want {
t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
}
if t.Failed() {
@@ -202,12 +503,12 @@ func TestAddSequentialNonAdjacent(t *testing.T) {
t.Fatalf("Failed to insert segment %d", i)
}
nrInsertions++
- if err := checkSet(&s, nrInsertions); err != nil {
+ if err := s.segmentTestCheck(nrInsertions, validate); err != nil {
t.Errorf("Iteration %d: %v", i, err)
break
}
}
- if got, want := countSegmentsIn(&s), nrInsertions; got != want {
+ if got, want := s.countSegments(), nrInsertions; got != want {
t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
}
if t.Failed() {
@@ -293,7 +594,7 @@ Tests:
var i int
for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
if i > len(test.final) {
- t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, countSegmentsIn(&s), len(test.final), &s)
+ t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, s.countSegments(), len(test.final), &s)
continue Tests
}
if got, want := seg.Range(), test.final[i]; got != want {
@@ -351,7 +652,7 @@ Tests:
var i int
for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
if i > len(test.final) {
- t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, countSegmentsIn(&s), len(test.final), &s)
+ t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, s.countSegments(), len(test.final), &s)
continue Tests
}
if got, want := seg.Range(), test.final[i]; got != want {
@@ -378,7 +679,7 @@ func benchmarkAddSequential(b *testing.B, size int) {
}
func benchmarkAddRandom(b *testing.B, size int) {
- order := randPermutation(size)
+ order := rand.Perm(size)
b.ResetTimer()
for n := 0; n < b.N; n++ {
@@ -416,7 +717,7 @@ func benchmarkFindRandom(b *testing.B, size int) {
b.Fatalf("Failed to insert segment %d", i)
}
}
- order := randPermutation(size)
+ order := rand.Perm(size)
b.ResetTimer()
for n := 0; n < b.N; n++ {
@@ -470,7 +771,7 @@ func benchmarkAddFindRemoveSequential(b *testing.B, size int) {
}
func benchmarkAddFindRemoveRandom(b *testing.B, size int) {
- order := randPermutation(size)
+ order := rand.Perm(size)
b.ResetTimer()
for n := 0; n < b.N; n++ {
diff --git a/pkg/segment/test/set_functions.go b/pkg/segment/test/set_functions.go
index bcddb39bb..7cd895cc7 100644
--- a/pkg/segment/test/set_functions.go
+++ b/pkg/segment/test/set_functions.go
@@ -14,21 +14,16 @@
package segment
-// Basic numeric constants that we define because the math package doesn't.
-// TODO(nlacasse): These should be Math.MaxInt64/MinInt64?
-const (
- maxInt = int(^uint(0) >> 1)
- minInt = -maxInt - 1
-)
-
type setFunctions struct{}
-func (setFunctions) MinKey() int {
- return minInt
+// MinKey returns the minimum key for the set.
+func (s setFunctions) MinKey() int {
+ return -s.MaxKey() - 1
}
+// MaxKey returns the maximum key for the set.
func (setFunctions) MaxKey() int {
- return maxInt
+ return int(^uint(0) >> 1)
}
func (setFunctions) ClearValue(*int) {}
@@ -40,3 +35,20 @@ func (setFunctions) Merge(_ Range, val1 int, _ Range, _ int) (int, bool) {
func (setFunctions) Split(_ Range, val int, _ int) (int, int) {
return val, val
}
+
+type gapSetFunctions struct {
+ setFunctions
+}
+
+// MinKey is adjusted to make sure no add overflow would happen in test cases.
+// e.g. A gap with range {MinInt32, 2} would cause overflow in Range().Length().
+//
+// Normally Keys should be unsigned to avoid these issues.
+func (s gapSetFunctions) MinKey() int {
+ return s.setFunctions.MinKey() / 2
+}
+
+// MaxKey returns the maximum key for the set.
+func (s gapSetFunctions) MaxKey() int {
+ return s.setFunctions.MaxKey() / 2
+}
diff --git a/pkg/sentry/fs/g3doc/.gitignore b/pkg/sentry/fs/g3doc/.gitignore
new file mode 100644
index 000000000..2d19fc766
--- /dev/null
+++ b/pkg/sentry/fs/g3doc/.gitignore
@@ -0,0 +1 @@
+*.html
diff --git a/pkg/sentry/fs/g3doc/fuse.md b/pkg/sentry/fs/g3doc/fuse.md
new file mode 100644
index 000000000..635cc009b
--- /dev/null
+++ b/pkg/sentry/fs/g3doc/fuse.md
@@ -0,0 +1,260 @@
+# Foreword
+
+This document describes an on-going project to support FUSE filesystems within
+the sentry. This is intended to become the final documentation for this
+subsystem, and is therefore written in the past tense. However FUSE support is
+currently incomplete and the document will be updated as things progress.
+
+# FUSE: Filesystem in Userspace
+
+The sentry supports dispatching filesystem operations to a FUSE server, allowing
+FUSE filesystem to be used with a sandbox.
+
+## Overview
+
+FUSE has two main components:
+
+1. A client kernel driver (canonically `fuse.ko` in Linux), which forwards
+ filesystem operations (usually initiated by syscalls) to the server.
+
+2. A server, which is a userspace daemon that implements the actual filesystem.
+
+The sentry implements the client component, which allows a server daemon running
+within the sandbox to implement a filesystem within the sandbox.
+
+A FUSE filesystem is initialized with `mount(2)`, typically with the help of a
+utility like `fusermount(1)`. Various mount options exist for establishing
+ownership and access permissions on the filesystem, but the most important mount
+option is a file descriptor used to establish communication between the client
+and server.
+
+The FUSE device FD is obtained by opening `/dev/fuse`. During regular operation,
+the client and server use the FUSE protocol described in `fuse(4)` to service
+filesystem operations. See the "Protocol" section below for more information
+about this protocol. The core of the sentry support for FUSE is the client-side
+implementation of this protocol.
+
+## FUSE in the Sentry
+
+The sentry's FUSE client targets VFS2 and has the following components:
+
+- An implementation of `/dev/fuse`.
+
+- A VFS2 filesystem for mapping syscalls to FUSE ops. Since we're targeting
+ VFS2, one point of contention may be the lack of inodes in VFS2. We can
+ tentatively implement a kernfs-based filesystem to bridge the gap in APIs.
+ The kernfs base functionality can serve the role of the Linux inode cache
+ and, the filesystem can map VFS2 syscalls to kernfs inode operations; see
+ the `kernfs.Inode` interface.
+
+The FUSE protocol lends itself well to marshaling with `go_marshal`. The various
+request and response packets can be defined in the ABI package and converted to
+and from the wire format using `go_marshal`.
+
+### Design Goals
+
+- While filesystem performance is always important, the sentry's FUSE support
+ is primarily concerned with compatibility, with performance as a secondary
+ concern.
+
+- Avoiding deadlocks from a hung server daemon.
+
+- Consider the potential for denial of service from a malicious server daemon.
+ Protecting itself from userspace is already a design goal for the sentry,
+ but needs additional consideration for FUSE. Normally, an operating system
+ doesn't rely on userspace to make progress with filesystem operations. Since
+ this changes with FUSE, it opens up the possibility of creating a chain of
+ dependencies controlled by userspace, which could affect an entire sandbox.
+ For example: a FUSE op can block a syscall, which could be holding a
+ subsystem lock, which can then block another task goroutine.
+
+### Milestones
+
+Below are some broad goals to aim for while implementing FUSE in the sentry.
+Many FUSE ops can be grouped into broad categories of functionality, and most
+ops can be implemented in parallel.
+
+#### Minimal client that can mount a trivial FUSE filesystem.
+
+- Implement `/dev/fuse`.
+
+- Implement basic FUSE ops like `FUSE_INIT`, `FUSE_DESTROY`.
+
+#### Read-only mount with basic file operations
+
+- Implement the majority of file, directory and file descriptor FUSE ops. For
+ this milestone, we can skip uncommon or complex operations like mmap, mknod,
+ file locking, poll, and extended attributes. We can stub these out along
+ with any ops that modify the filesystem. The exact list of required ops are
+ to be determined, but the goal is to mount a real filesystem as read-only,
+ and be able to read contents from the filesystem in the sentry.
+
+#### Full read-write support
+
+- Implement the remaining FUSE ops and decide if we can omit rarely used
+ operations like ioctl.
+
+# Appendix
+
+## FUSE Protocol
+
+The FUSE protocol is a request-response protocol. All requests are initiated by
+the client. The wire-format for the protocol is raw c structs serialized to
+memory.
+
+All FUSE requests begin with the following request header:
+
+```c
+struct fuse_in_header {
+ uint32_t len; // Length of the request, including this header.
+ uint32_t opcode; // Requested operation.
+ uint64_t unique; // A unique identifier for this request.
+ uint64_t nodeid; // ID of the filesystem object being operated on.
+ uint32_t uid; // UID of the requesting process.
+ uint32_t gid; // GID of the requesting process.
+ uint32_t pid; // PID of the requesting process.
+ uint32_t padding;
+};
+```
+
+The request is then followed by a payload specific to the `opcode`.
+
+All responses begin with this response header:
+
+```c
+struct fuse_out_header {
+ uint32_t len; // Length of the response, including this header.
+ int32_t error; // Status of the request, 0 if success.
+ uint64_t unique; // The unique identifier from the corresponding request.
+};
+```
+
+The response payload also depends on the request `opcode`. If `error != 0`, the
+response payload must be empty.
+
+### Operations
+
+The following is a list of all FUSE operations used in `fuse_in_header.opcode`
+as of Linux v4.4, and a brief description of their purpose. These are defined in
+`uapi/linux/fuse.h`. Many of these have a corresponding request and response
+payload struct; `fuse(4)` has details for some of these. We also note how these
+operations map to the sentry virtual filesystem.
+
+#### FUSE meta-operations
+
+These operations are specific to FUSE and don't have a corresponding action in a
+generic filesystem.
+
+- `FUSE_INIT`: This operation initializes a new FUSE filesystem, and is the
+ first message sent by the client after mount. This is used for version and
+ feature negotiation. This is related to `mount(2)`.
+- `FUSE_DESTROY`: Teardown a FUSE filesystem, related to `unmount(2)`.
+- `FUSE_INTERRUPT`: Interrupts an in-flight operation, specified by the
+ `fuse_in_header.unique` value provided in the corresponding request header.
+ The client can send at most one of these per request, and will enter an
+ uninterruptible wait for a reply. The server is expected to reply promptly.
+- `FUSE_FORGET`: A hint to the server that server should evict the indicate
+ node from any caches. This is wired up to `(struct
+ super_operations).evict_inode` in Linux, which is in turned hooked as the
+ inode cache shrinker which is typically triggered by system memory pressure.
+- `FUSE_BATCH_FORGET`: Batch version of `FUSE_FORGET`.
+
+#### Filesystem Syscalls
+
+These FUSE ops map directly to an equivalent filesystem syscall, or family of
+syscalls. The relevant syscalls have a similar name to the operation, unless
+otherwise noted.
+
+Node creation:
+
+- `FUSE_MKNOD`
+- `FUSE_MKDIR`
+- `FUSE_CREATE`: This is equivalent to `open(2)` and `creat(2)`, which
+ atomically creates and opens a node.
+
+Node attributes and extended attributes:
+
+- `FUSE_GETATTR`
+- `FUSE_SETATTR`
+- `FUSE_SETXATTR`
+- `FUSE_GETXATTR`
+- `FUSE_LISTXATTR`
+- `FUSE_REMOVEXATTR`
+
+Node link manipulation:
+
+- `FUSE_READLINK`
+- `FUSE_LINK`
+- `FUSE_SYMLINK`
+- `FUSE_UNLINK`
+
+Directory operations:
+
+- `FUSE_RMDIR`
+- `FUSE_RENAME`
+- `FUSE_RENAME2`
+- `FUSE_OPENDIR`: `open(2)` for directories.
+- `FUSE_RELEASEDIR`: `close(2)` for directories.
+- `FUSE_READDIR`
+- `FUSE_READDIRPLUS`
+- `FUSE_FSYNCDIR`: `fsync(2)` for directories.
+- `FUSE_LOOKUP`: Establishes a unique identifier for a FS node. This is
+ reminiscent of `VirtualFilesystem.GetDentryAt` in that it resolves a path
+ component to a node. However the returned identifier is opaque to the
+ client. The server must remember this mapping, as this is how the client
+ will reference the node in the future.
+
+File operations:
+
+- `FUSE_OPEN`: `open(2)` for files.
+- `FUSE_RELEASE`: `close(2)` for files.
+- `FUSE_FSYNC`
+- `FUSE_FALLOCATE`
+- `FUSE_SETUPMAPPING`: Creates a memory map on a file for `mmap(2)`.
+- `FUSE_REMOVEMAPPING`: Removes a memory map for `munmap(2)`.
+
+File locking:
+
+- `FUSE_GETLK`
+- `FUSE_SETLK`
+- `FUSE_SETLKW`
+- `FUSE_COPY_FILE_RANGE`
+
+File descriptor operations:
+
+- `FUSE_IOCTL`
+- `FUSE_POLL`
+- `FUSE_LSEEK`
+
+Filesystem operations:
+
+- `FUSE_STATFS`
+
+#### Permissions
+
+- `FUSE_ACCESS` is used to check if a node is accessible, as part of many
+ syscall implementations. Maps to `vfs.FilesystemImpl.AccessAt` in the
+ sentry.
+
+#### I/O Operations
+
+These ops are used to read and write file pages. They're used to implement both
+I/O syscalls like `read(2)`, `write(2)` and `mmap(2)`.
+
+- `FUSE_READ`
+- `FUSE_WRITE`
+
+#### Miscellaneous
+
+- `FUSE_FLUSH`: Used by the client to indicate when a file descriptor is
+ closed. Distinct from `FUSE_FSYNC`, which corresponds to an `fsync(2)`
+ syscall from the user. Maps to `vfs.FileDescriptorImpl.Release` in the
+ sentry.
+- `FUSE_BMAP`: Old address space API for block defrag. Probably not needed.
+- `FUSE_NOTIFY_REPLY`: [TODO: what does this do?]
+
+# References
+
+- `fuse(4)` manpage.
+- Linux kernel FUSE documentation:
+ https://www.kernel.org/doc/html/latest/filesystems/fuse.html
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 6295f6b54..131da332f 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -84,12 +84,6 @@ type filesystem struct {
// devMinor is the filesystem's minor device number. devMinor is immutable.
devMinor uint32
- // uid and gid are the effective KUID and KGID of the filesystem's creator,
- // and are used as the owner and group for files that don't specify one.
- // uid and gid are immutable.
- uid auth.KUID
- gid auth.KGID
-
// renameMu serves two purposes:
//
// - It synchronizes path resolution with renaming initiated by this
@@ -122,6 +116,8 @@ type filesystemOptions struct {
fd int
aname string
interop InteropMode // derived from the "cache" mount option
+ dfltuid auth.KUID
+ dfltgid auth.KGID
msize uint32
version string
@@ -230,6 +226,15 @@ type InternalFilesystemOptions struct {
OpenSocketsByConnecting bool
}
+// _V9FS_DEFUID and _V9FS_DEFGID (from Linux's fs/9p/v9fs.h) are the default
+// UIDs and GIDs used for files that do not provide a specific owner or group
+// respectively.
+const (
+ // uint32(-2) doesn't work in Go.
+ _V9FS_DEFUID = auth.KUID(4294967294)
+ _V9FS_DEFGID = auth.KGID(4294967294)
+)
+
// Name implements vfs.FilesystemType.Name.
func (FilesystemType) Name() string {
return Name
@@ -315,6 +320,31 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
}
+ // Parse the default UID and GID.
+ fsopts.dfltuid = _V9FS_DEFUID
+ if dfltuidstr, ok := mopts["dfltuid"]; ok {
+ delete(mopts, "dfltuid")
+ dfltuid, err := strconv.ParseUint(dfltuidstr, 10, 32)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: dfltuid=%s", dfltuidstr)
+ return nil, nil, syserror.EINVAL
+ }
+ // In Linux, dfltuid is interpreted as a UID and is converted to a KUID
+ // in the caller's user namespace, but goferfs isn't
+ // application-mountable.
+ fsopts.dfltuid = auth.KUID(dfltuid)
+ }
+ fsopts.dfltgid = _V9FS_DEFGID
+ if dfltgidstr, ok := mopts["dfltgid"]; ok {
+ delete(mopts, "dfltgid")
+ dfltgid, err := strconv.ParseUint(dfltgidstr, 10, 32)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: dfltgid=%s", dfltgidstr)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.dfltgid = auth.KGID(dfltgid)
+ }
+
// Parse the 9P message size.
fsopts.msize = 1024 * 1024 // 1M, tested to give good enough performance up to 64M
if msizestr, ok := mopts["msize"]; ok {
@@ -422,8 +452,6 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
client: client,
clock: ktime.RealtimeClockFromContext(ctx),
devMinor: devMinor,
- uid: creds.EffectiveKUID,
- gid: creds.EffectiveKGID,
syncableDentries: make(map[*dentry]struct{}),
specialFileFDs: make(map[*specialFileFD]struct{}),
}
@@ -672,8 +700,8 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma
file: file,
ino: qid.Path,
mode: uint32(attr.Mode),
- uid: uint32(fs.uid),
- gid: uint32(fs.gid),
+ uid: uint32(fs.opts.dfltuid),
+ gid: uint32(fs.opts.dfltgid),
blockSize: usermem.PageSize,
handle: handle{
fd: -1,
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index 3f433d666..fee174375 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -312,7 +312,7 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
f := fd.inode().impl.(*regularFile)
if end := offset + srclen; end < offset {
// Overflow.
- return 0, syserror.EFBIG
+ return 0, syserror.EINVAL
}
var err error
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index f29dc0472..7bfa9075a 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -8,6 +8,7 @@ go_library(
"device.go",
"node.go",
"pipe.go",
+ "pipe_unsafe.go",
"pipe_util.go",
"reader.go",
"reader_writer.go",
@@ -20,6 +21,7 @@ go_library(
"//pkg/amutex",
"//pkg/buffer",
"//pkg/context",
+ "//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 62c8691f1..79645d7d2 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -207,7 +207,10 @@ func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) {
p.mu.Lock()
defer p.mu.Unlock()
+ return p.readLocked(ctx, ops)
+}
+func (p *Pipe) readLocked(ctx context.Context, ops readOps) (int64, error) {
// Is the pipe empty?
if p.view.Size() == 0 {
if !p.HasWriters() {
@@ -246,7 +249,10 @@ type writeOps struct {
func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) {
p.mu.Lock()
defer p.mu.Unlock()
+ return p.writeLocked(ctx, ops)
+}
+func (p *Pipe) writeLocked(ctx context.Context, ops writeOps) (int64, error) {
// Can't write to a pipe with no readers.
if !p.HasReaders() {
return 0, syscall.EPIPE
diff --git a/pkg/sentry/kernel/pipe/pipe_unsafe.go b/pkg/sentry/kernel/pipe/pipe_unsafe.go
new file mode 100644
index 000000000..dd60cba24
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/pipe_unsafe.go
@@ -0,0 +1,35 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "unsafe"
+)
+
+// lockTwoPipes locks both x.mu and y.mu in an order that is guaranteed to be
+// consistent for both lockTwoPipes(x, y) and lockTwoPipes(y, x), such that
+// concurrent calls cannot deadlock.
+//
+// Preconditions: x != y.
+func lockTwoPipes(x, y *Pipe) {
+ // Lock the two pipes in order of increasing address.
+ if uintptr(unsafe.Pointer(x)) < uintptr(unsafe.Pointer(y)) {
+ x.mu.Lock()
+ y.mu.Lock()
+ } else {
+ y.mu.Lock()
+ x.mu.Lock()
+ }
+}
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
index b54f08a30..2602bed72 100644
--- a/pkg/sentry/kernel/pipe/vfs.go
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -16,7 +16,9 @@ package pipe
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
@@ -150,7 +152,9 @@ func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) *
return &fd.vfsfd
}
-// VFSPipeFD implements vfs.FileDescriptionImpl for pipes.
+// VFSPipeFD implements vfs.FileDescriptionImpl for pipes. It also implements
+// non-atomic usermem.IO methods, allowing it to be passed as usermem.IO to
+// other FileDescriptions for splice(2) and tee(2).
type VFSPipeFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
@@ -229,3 +233,216 @@ func (fd *VFSPipeFD) PipeSize() int64 {
func (fd *VFSPipeFD) SetPipeSize(size int64) (int64, error) {
return fd.pipe.SetFifoSize(size)
}
+
+// IOSequence returns a useremm.IOSequence that reads up to count bytes from,
+// or writes up to count bytes to, fd.
+func (fd *VFSPipeFD) IOSequence(count int64) usermem.IOSequence {
+ return usermem.IOSequence{
+ IO: fd,
+ Addrs: usermem.AddrRangeSeqOf(usermem.AddrRange{0, usermem.Addr(count)}),
+ }
+}
+
+// CopyIn implements usermem.IO.CopyIn.
+func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, opts usermem.IOOpts) (int, error) {
+ origCount := int64(len(dst))
+ n, err := fd.pipe.read(ctx, readOps{
+ left: func() int64 {
+ return int64(len(dst))
+ },
+ limit: func(l int64) {
+ dst = dst[:l]
+ },
+ read: func(view *buffer.View) (int64, error) {
+ n, err := view.ReadAt(dst, 0)
+ view.TrimFront(int64(n))
+ return int64(n), err
+ },
+ })
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventOut)
+ }
+ if err == nil && n != origCount {
+ return int(n), syserror.ErrWouldBlock
+ }
+ return int(n), err
+}
+
+// CopyOut implements usermem.IO.CopyOut.
+func (fd *VFSPipeFD) CopyOut(ctx context.Context, addr usermem.Addr, src []byte, opts usermem.IOOpts) (int, error) {
+ origCount := int64(len(src))
+ n, err := fd.pipe.write(ctx, writeOps{
+ left: func() int64 {
+ return int64(len(src))
+ },
+ limit: func(l int64) {
+ src = src[:l]
+ },
+ write: func(view *buffer.View) (int64, error) {
+ view.Append(src)
+ return int64(len(src)), nil
+ },
+ })
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventIn)
+ }
+ if err == nil && n != origCount {
+ return int(n), syserror.ErrWouldBlock
+ }
+ return int(n), err
+}
+
+// ZeroOut implements usermem.IO.ZeroOut.
+func (fd *VFSPipeFD) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int64, opts usermem.IOOpts) (int64, error) {
+ origCount := toZero
+ n, err := fd.pipe.write(ctx, writeOps{
+ left: func() int64 {
+ return toZero
+ },
+ limit: func(l int64) {
+ toZero = l
+ },
+ write: func(view *buffer.View) (int64, error) {
+ view.Grow(view.Size()+toZero, true /* zero */)
+ return toZero, nil
+ },
+ })
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventIn)
+ }
+ if err == nil && n != origCount {
+ return n, syserror.ErrWouldBlock
+ }
+ return n, err
+}
+
+// CopyInTo implements usermem.IO.CopyInTo.
+func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) {
+ count := ars.NumBytes()
+ if count == 0 {
+ return 0, nil
+ }
+ origCount := count
+ n, err := fd.pipe.read(ctx, readOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ read: func(view *buffer.View) (int64, error) {
+ n, err := view.ReadToSafememWriter(dst, uint64(count))
+ view.TrimFront(int64(n))
+ return int64(n), err
+ },
+ })
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventOut)
+ }
+ if err == nil && n != origCount {
+ return n, syserror.ErrWouldBlock
+ }
+ return n, err
+}
+
+// CopyOutFrom implements usermem.IO.CopyOutFrom.
+func (fd *VFSPipeFD) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) {
+ count := ars.NumBytes()
+ if count == 0 {
+ return 0, nil
+ }
+ origCount := count
+ n, err := fd.pipe.write(ctx, writeOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ write: func(view *buffer.View) (int64, error) {
+ n, err := view.WriteFromSafememReader(src, uint64(count))
+ return int64(n), err
+ },
+ })
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventIn)
+ }
+ if err == nil && n != origCount {
+ return n, syserror.ErrWouldBlock
+ }
+ return n, err
+}
+
+// SwapUint32 implements usermem.IO.SwapUint32.
+func (fd *VFSPipeFD) SwapUint32(ctx context.Context, addr usermem.Addr, new uint32, opts usermem.IOOpts) (uint32, error) {
+ // How did a pipe get passed as the virtual address space to futex(2)?
+ panic("VFSPipeFD.SwapUint32 called unexpectedly")
+}
+
+// CompareAndSwapUint32 implements usermem.IO.CompareAndSwapUint32.
+func (fd *VFSPipeFD) CompareAndSwapUint32(ctx context.Context, addr usermem.Addr, old, new uint32, opts usermem.IOOpts) (uint32, error) {
+ panic("VFSPipeFD.CompareAndSwapUint32 called unexpectedly")
+}
+
+// LoadUint32 implements usermem.IO.LoadUint32.
+func (fd *VFSPipeFD) LoadUint32(ctx context.Context, addr usermem.Addr, opts usermem.IOOpts) (uint32, error) {
+ panic("VFSPipeFD.LoadUint32 called unexpectedly")
+}
+
+// Splice reads up to count bytes from src and writes them to dst. It returns
+// the number of bytes moved.
+//
+// Preconditions: count > 0.
+func Splice(ctx context.Context, dst, src *VFSPipeFD, count int64) (int64, error) {
+ return spliceOrTee(ctx, dst, src, count, true /* removeFromSrc */)
+}
+
+// Tee reads up to count bytes from src and writes them to dst, without
+// removing the read bytes from src. It returns the number of bytes copied.
+//
+// Preconditions: count > 0.
+func Tee(ctx context.Context, dst, src *VFSPipeFD, count int64) (int64, error) {
+ return spliceOrTee(ctx, dst, src, count, false /* removeFromSrc */)
+}
+
+// Preconditions: count > 0.
+func spliceOrTee(ctx context.Context, dst, src *VFSPipeFD, count int64, removeFromSrc bool) (int64, error) {
+ if dst.pipe == src.pipe {
+ return 0, syserror.EINVAL
+ }
+
+ lockTwoPipes(dst.pipe, src.pipe)
+ defer dst.pipe.mu.Unlock()
+ defer src.pipe.mu.Unlock()
+
+ n, err := dst.pipe.writeLocked(ctx, writeOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ write: func(dstView *buffer.View) (int64, error) {
+ return src.pipe.readLocked(ctx, readOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ read: func(srcView *buffer.View) (int64, error) {
+ n, err := srcView.ReadToSafememWriter(dstView, uint64(count))
+ if n > 0 && removeFromSrc {
+ srcView.TrimFront(int64(n))
+ }
+ return int64(n), err
+ },
+ })
+ },
+ })
+ if n > 0 {
+ dst.pipe.Notify(waiter.EventIn)
+ src.pipe.Notify(waiter.EventOut)
+ }
+ return n, err
+}
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index 73591dab7..a036ce53c 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -25,6 +25,7 @@ go_template_instance(
out = "vma_set.go",
consts = {
"minDegree": "8",
+ "trackGaps": "1",
},
imports = {
"usermem": "gvisor.dev/gvisor/pkg/usermem",
diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go
index 9a14e69e6..16d8207e9 100644
--- a/pkg/sentry/mm/vma.go
+++ b/pkg/sentry/mm/vma.go
@@ -195,7 +195,7 @@ func (mm *MemoryManager) applicationAddrRange() usermem.AddrRange {
// Preconditions: mm.mappingMu must be locked.
func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bounds usermem.AddrRange) (usermem.Addr, error) {
- for gap := mm.vmas.LowerBoundGap(bounds.Start); gap.Ok() && gap.Start() < bounds.End; gap = gap.NextGap() {
+ for gap := mm.vmas.LowerBoundGap(bounds.Start); gap.Ok() && gap.Start() < bounds.End; gap = gap.NextLargeEnoughGap(usermem.Addr(length)) {
if gr := gap.availableRange().Intersect(bounds); uint64(gr.Length()) >= length {
// Can we shift up to match the alignment?
if offset := uint64(gr.Start) % alignment; offset != 0 {
@@ -214,7 +214,7 @@ func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bou
// Preconditions: mm.mappingMu must be locked.
func (mm *MemoryManager) findHighestAvailableLocked(length, alignment uint64, bounds usermem.AddrRange) (usermem.Addr, error) {
- for gap := mm.vmas.UpperBoundGap(bounds.End); gap.Ok() && gap.End() > bounds.Start; gap = gap.PrevGap() {
+ for gap := mm.vmas.UpperBoundGap(bounds.End); gap.Ok() && gap.End() > bounds.Start; gap = gap.PrevLargeEnoughGap(usermem.Addr(length)) {
if gr := gap.availableRange().Intersect(bounds); uint64(gr.Length()) >= length {
// Can we shift down to match the alignment?
start := gr.End - usermem.Addr(length)
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index f882ef840..d56927ff5 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -22,6 +22,7 @@ go_library(
"setstat.go",
"signal.go",
"socket.go",
+ "splice.go",
"stat.go",
"stat_amd64.go",
"stat_arm64.go",
diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go
new file mode 100644
index 000000000..8f3c22a02
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/splice.go
@@ -0,0 +1,286 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// Splice implements Linux syscall splice(2).
+func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ inFD := args[0].Int()
+ inOffsetPtr := args[1].Pointer()
+ outFD := args[2].Int()
+ outOffsetPtr := args[3].Pointer()
+ count := int64(args[4].SizeT())
+ flags := args[5].Int()
+
+ if count == 0 {
+ return 0, nil, nil
+ }
+ if count > int64(kernel.MAX_RW_COUNT) {
+ count = int64(kernel.MAX_RW_COUNT)
+ }
+
+ // Check for invalid flags.
+ if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get file descriptions.
+ inFile := t.GetFileVFS2(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef()
+ outFile := t.GetFileVFS2(outFD)
+ if outFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer outFile.DecRef()
+
+ // Check that both files support the required directionality.
+ if !inFile.IsReadable() || !outFile.IsWritable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ // The operation is non-blocking if anything is non-blocking.
+ //
+ // N.B. This is a rather simplistic heuristic that avoids some
+ // poor edge case behavior since the exact semantics here are
+ // underspecified and vary between versions of Linux itself.
+ nonBlock := ((inFile.StatusFlags()|outFile.StatusFlags())&linux.O_NONBLOCK != 0) || (flags&linux.SPLICE_F_NONBLOCK != 0)
+
+ // At least one file description must represent a pipe.
+ inPipeFD, inIsPipe := inFile.Impl().(*pipe.VFSPipeFD)
+ outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD)
+ if !inIsPipe && !outIsPipe {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Copy in offsets.
+ inOffset := int64(-1)
+ if inOffsetPtr != 0 {
+ if inIsPipe {
+ return 0, nil, syserror.ESPIPE
+ }
+ if inFile.Options().DenyPRead {
+ return 0, nil, syserror.EINVAL
+ }
+ if _, err := t.CopyIn(inOffsetPtr, &inOffset); err != nil {
+ return 0, nil, err
+ }
+ if inOffset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ }
+ outOffset := int64(-1)
+ if outOffsetPtr != 0 {
+ if outIsPipe {
+ return 0, nil, syserror.ESPIPE
+ }
+ if outFile.Options().DenyPWrite {
+ return 0, nil, syserror.EINVAL
+ }
+ if _, err := t.CopyIn(outOffsetPtr, &outOffset); err != nil {
+ return 0, nil, err
+ }
+ if outOffset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ }
+
+ // Move data.
+ var (
+ n int64
+ err error
+ inCh chan struct{}
+ outCh chan struct{}
+ )
+ for {
+ // If both input and output are pipes, delegate to the pipe
+ // implementation. Otherwise, exactly one end is a pipe, which we
+ // ensure is consistently ordered after the non-pipe FD's locks by
+ // passing the pipe FD as usermem.IO to the non-pipe end.
+ switch {
+ case inIsPipe && outIsPipe:
+ n, err = pipe.Splice(t, outPipeFD, inPipeFD, count)
+ case inIsPipe:
+ if outOffset != -1 {
+ n, err = outFile.PWrite(t, inPipeFD.IOSequence(count), outOffset, vfs.WriteOptions{})
+ outOffset += n
+ } else {
+ n, err = outFile.Write(t, inPipeFD.IOSequence(count), vfs.WriteOptions{})
+ }
+ case outIsPipe:
+ if inOffset != -1 {
+ n, err = inFile.PRead(t, outPipeFD.IOSequence(count), inOffset, vfs.ReadOptions{})
+ inOffset += n
+ } else {
+ n, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{})
+ }
+ }
+ if n != 0 || err != syserror.ErrWouldBlock || nonBlock {
+ break
+ }
+
+ // Note that the blocking behavior here is a bit different than the
+ // normal pattern. Because we need to have both data to read and data
+ // to write simultaneously, we actually explicitly block on both of
+ // these cases in turn before returning to the splice operation.
+ if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 {
+ if inCh == nil {
+ inCh = make(chan struct{}, 1)
+ inW, _ := waiter.NewChannelEntry(inCh)
+ inFile.EventRegister(&inW, eventMaskRead)
+ defer inFile.EventUnregister(&inW)
+ continue // Need to refresh readiness.
+ }
+ if err = t.Block(inCh); err != nil {
+ break
+ }
+ }
+ if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 {
+ if outCh == nil {
+ outCh = make(chan struct{}, 1)
+ outW, _ := waiter.NewChannelEntry(outCh)
+ outFile.EventRegister(&outW, eventMaskWrite)
+ defer outFile.EventUnregister(&outW)
+ continue // Need to refresh readiness.
+ }
+ if err = t.Block(outCh); err != nil {
+ break
+ }
+ }
+ }
+
+ // Copy updated offsets out.
+ if inOffsetPtr != 0 {
+ if _, err := t.CopyOut(inOffsetPtr, &inOffset); err != nil {
+ return 0, nil, err
+ }
+ }
+ if outOffsetPtr != 0 {
+ if _, err := t.CopyOut(outOffsetPtr, &outOffset); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ if n == 0 {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Tee implements Linux syscall tee(2).
+func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ inFD := args[0].Int()
+ outFD := args[1].Int()
+ count := int64(args[2].SizeT())
+ flags := args[3].Int()
+
+ if count == 0 {
+ return 0, nil, nil
+ }
+ if count > int64(kernel.MAX_RW_COUNT) {
+ count = int64(kernel.MAX_RW_COUNT)
+ }
+
+ // Check for invalid flags.
+ if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get file descriptions.
+ inFile := t.GetFileVFS2(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef()
+ outFile := t.GetFileVFS2(outFD)
+ if outFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer outFile.DecRef()
+
+ // Check that both files support the required directionality.
+ if !inFile.IsReadable() || !outFile.IsWritable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ // The operation is non-blocking if anything is non-blocking.
+ //
+ // N.B. This is a rather simplistic heuristic that avoids some
+ // poor edge case behavior since the exact semantics here are
+ // underspecified and vary between versions of Linux itself.
+ nonBlock := ((inFile.StatusFlags()|outFile.StatusFlags())&linux.O_NONBLOCK != 0) || (flags&linux.SPLICE_F_NONBLOCK != 0)
+
+ // Both file descriptions must represent pipes.
+ inPipeFD, inIsPipe := inFile.Impl().(*pipe.VFSPipeFD)
+ outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD)
+ if !inIsPipe || !outIsPipe {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Copy data.
+ var (
+ inCh chan struct{}
+ outCh chan struct{}
+ )
+ for {
+ n, err := pipe.Tee(t, outPipeFD, inPipeFD, count)
+ if n != 0 {
+ return uintptr(n), nil, nil
+ }
+ if err != syserror.ErrWouldBlock || nonBlock {
+ return 0, nil, err
+ }
+
+ // Note that the blocking behavior here is a bit different than the
+ // normal pattern. Because we need to have both data to read and data
+ // to write simultaneously, we actually explicitly block on both of
+ // these cases in turn before returning to the tee operation.
+ if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 {
+ if inCh == nil {
+ inCh = make(chan struct{}, 1)
+ inW, _ := waiter.NewChannelEntry(inCh)
+ inFile.EventRegister(&inW, eventMaskRead)
+ defer inFile.EventUnregister(&inW)
+ continue // Need to refresh readiness.
+ }
+ if err := t.Block(inCh); err != nil {
+ return 0, nil, err
+ }
+ }
+ if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 {
+ if outCh == nil {
+ outCh = make(chan struct{}, 1)
+ outW, _ := waiter.NewChannelEntry(outCh)
+ outFile.EventRegister(&outW, eventMaskWrite)
+ defer outFile.EventUnregister(&outW)
+ continue // Need to refresh readiness.
+ }
+ if err := t.Block(outCh); err != nil {
+ return 0, nil, err
+ }
+ }
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
index a332d01bd..083fdcf82 100644
--- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go
+++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
@@ -134,8 +134,8 @@ func Override() {
s.Table[269] = syscalls.Supported("faccessat", Faccessat)
s.Table[270] = syscalls.Supported("pselect", Pselect)
s.Table[271] = syscalls.Supported("ppoll", Ppoll)
- delete(s.Table, 275) // splice
- delete(s.Table, 276) // tee
+ s.Table[275] = syscalls.Supported("splice", Splice)
+ s.Table[276] = syscalls.Supported("tee", Tee)
s.Table[277] = syscalls.Supported("sync_file_range", SyncFileRange)
s.Table[280] = syscalls.Supported("utimensat", Utimensat)
s.Table[281] = syscalls.Supported("epoll_pwait", EpollPwait)
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index cfabd936c..bb294563d 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -210,6 +210,11 @@ func (fd *FileDescription) VirtualDentry() VirtualDentry {
return fd.vd
}
+// Options returns the options passed to fd.Init().
+func (fd *FileDescription) Options() FileDescriptionOptions {
+ return fd.opts
+}
+
// StatusFlags returns file description status flags, as for fcntl(F_GETFL).
func (fd *FileDescription) StatusFlags() uint32 {
return atomic.LoadUint32(&fd.statusFlags)
diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go
index 5f2af9f3b..c45d2ecbc 100644
--- a/pkg/test/dockerutil/dockerutil.go
+++ b/pkg/test/dockerutil/dockerutil.go
@@ -148,6 +148,62 @@ func (m MountMode) String() string {
panic(fmt.Sprintf("invalid mode: %d", m))
}
+// DockerNetwork contains the name of a docker network.
+type DockerNetwork struct {
+ logger testutil.Logger
+ Name string
+ Subnet *net.IPNet
+ containers []*Docker
+}
+
+// NewDockerNetwork sets up the struct for a Docker network. Names of networks
+// will be unique.
+func NewDockerNetwork(logger testutil.Logger) *DockerNetwork {
+ return &DockerNetwork{
+ logger: logger,
+ Name: testutil.RandomID(logger.Name()),
+ }
+}
+
+// Create calls 'docker network create'.
+func (n *DockerNetwork) Create(args ...string) error {
+ a := []string{"docker", "network", "create"}
+ if n.Subnet != nil {
+ a = append(a, fmt.Sprintf("--subnet=%s", n.Subnet))
+ }
+ a = append(a, args...)
+ a = append(a, n.Name)
+ return testutil.Command(n.logger, a...).Run()
+}
+
+// Connect calls 'docker network connect' with the arguments provided.
+func (n *DockerNetwork) Connect(container *Docker, args ...string) error {
+ a := []string{"docker", "network", "connect"}
+ a = append(a, args...)
+ a = append(a, n.Name, container.Name)
+ if err := testutil.Command(n.logger, a...).Run(); err != nil {
+ return err
+ }
+ n.containers = append(n.containers, container)
+ return nil
+}
+
+// Cleanup cleans up the docker network and all the containers attached to it.
+func (n *DockerNetwork) Cleanup() error {
+ for _, c := range n.containers {
+ // Don't propagate the error, it might be that the container
+ // was already cleaned up.
+ if err := c.Kill(); err != nil {
+ n.logger.Logf("unable to kill container during cleanup: %s", err)
+ }
+ }
+
+ if err := testutil.Command(n.logger, "docker", "network", "rm", n.Name).Run(); err != nil {
+ return err
+ }
+ return nil
+}
+
// Docker contains the name and the runtime of a docker container.
type Docker struct {
logger testutil.Logger
@@ -162,9 +218,13 @@ type Docker struct {
//
// Names of containers will be unique.
func MakeDocker(logger testutil.Logger) *Docker {
+ // Slashes are not allowed in container names.
+ name := testutil.RandomID(logger.Name())
+ name = strings.ReplaceAll(name, "/", "-")
+
return &Docker{
logger: logger,
- Name: testutil.RandomID(logger.Name()),
+ Name: name,
Runtime: *runtime,
}
}
@@ -309,7 +369,9 @@ func (d *Docker) argsFor(r *RunOpts, command string, p []string) (rv []string) {
rv = append(rv, d.Name)
} else {
rv = append(rv, d.mounts...)
- rv = append(rv, fmt.Sprintf("--runtime=%s", d.Runtime))
+ if len(d.Runtime) > 0 {
+ rv = append(rv, fmt.Sprintf("--runtime=%s", d.Runtime))
+ }
rv = append(rv, fmt.Sprintf("--name=%s", d.Name))
rv = append(rv, testutil.ImageByName(r.Image))
}
@@ -477,6 +539,56 @@ func (d *Docker) FindIP() (net.IP, error) {
return ip, nil
}
+// A NetworkInterface is container's network interface information.
+type NetworkInterface struct {
+ IPv4 net.IP
+ MAC net.HardwareAddr
+}
+
+// ListNetworks returns the network interfaces of the container, keyed by
+// Docker network name.
+func (d *Docker) ListNetworks() (map[string]NetworkInterface, error) {
+ const format = `{{json .NetworkSettings.Networks}}`
+ out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput()
+ if err != nil {
+ return nil, fmt.Errorf("error network interfaces: %q: %w", string(out), err)
+ }
+
+ networks := map[string]map[string]string{}
+ if err := json.Unmarshal(out, &networks); err != nil {
+ return nil, fmt.Errorf("error decoding network interfaces: %w", err)
+ }
+
+ interfaces := map[string]NetworkInterface{}
+ for name, iface := range networks {
+ var netface NetworkInterface
+
+ rawIP := strings.TrimSpace(iface["IPAddress"])
+ if rawIP != "" {
+ ip := net.ParseIP(rawIP)
+ if ip == nil {
+ return nil, fmt.Errorf("invalid IP: %q", rawIP)
+ }
+ // Docker's IPAddress field is IPv4. The IPv6 address
+ // is stored in the GlobalIPv6Address field.
+ netface.IPv4 = ip
+ }
+
+ rawMAC := strings.TrimSpace(iface["MacAddress"])
+ if rawMAC != "" {
+ mac, err := net.ParseMAC(rawMAC)
+ if err != nil {
+ return nil, fmt.Errorf("invalid MAC: %q: %w", rawMAC, err)
+ }
+ netface.MAC = mac
+ }
+
+ interfaces[name] = netface
+ }
+
+ return interfaces, nil
+}
+
// SandboxPid returns the PID to the sandbox process.
func (d *Docker) SandboxPid() (int, error) {
out, err := testutil.Command(d.logger, "docker", "inspect", "-f={{.State.Pid}}", d.Name).CombinedOutput()
diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go
index ff856883a..9cbb2ed5b 100644
--- a/test/e2e/integration_test.go
+++ b/test/e2e/integration_test.go
@@ -337,27 +337,53 @@ func TestJobControl(t *testing.T) {
}
}
-// TestTmpFile checks that files inside '/tmp' are not overridden. In addition,
-// it checks that working dir is created if it doesn't exit.
+// TestWorkingDirCreation checks that working dir is created if it doesn't exit.
+func TestWorkingDirCreation(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ workingDir string
+ }{
+ {name: "root", workingDir: "/foo"},
+ {name: "tmp", workingDir: "/tmp/foo"},
+ } {
+ for _, readonly := range []bool{true, false} {
+ name := tc.name
+ if readonly {
+ name += "-readonly"
+ }
+ t.Run(name, func(t *testing.T) {
+ d := dockerutil.MakeDocker(t)
+ defer d.CleanUp()
+
+ opts := dockerutil.RunOpts{
+ Image: "basic/alpine",
+ WorkDir: tc.workingDir,
+ ReadOnly: readonly,
+ }
+ got, err := d.Run(opts, "sh", "-c", "echo ${PWD}")
+ if err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ if want := tc.workingDir + "\n"; want != got {
+ t.Errorf("invalid working dir, want: %q, got: %q", want, got)
+ }
+ })
+ }
+ }
+}
+
+// TestTmpFile checks that files inside '/tmp' are not overridden.
func TestTmpFile(t *testing.T) {
d := dockerutil.MakeDocker(t)
defer d.CleanUp()
- // Should work without ReadOnly
- if _, err := d.Run(dockerutil.RunOpts{
- Image: "basic/alpine",
- WorkDir: "/tmp/foo/bar",
- }, "touch", "/tmp/foo/bar/file"); err != nil {
+ opts := dockerutil.RunOpts{Image: "tmpfile"}
+ got, err := d.Run(opts, "cat", "/tmp/foo/file.txt")
+ if err != nil {
t.Fatalf("docker run failed: %v", err)
}
-
- // Expect failure.
- if _, err := d.Run(dockerutil.RunOpts{
- Image: "basic/alpine",
- WorkDir: "/tmp/foo/bar",
- ReadOnly: true,
- }, "touch", "/tmp/foo/bar/file"); err == nil {
- t.Fatalf("docker run expected failure, but succeeded")
+ if want := "123\n"; want != got {
+ t.Errorf("invalid file content, want: %q, got: %q", want, got)
}
}
diff --git a/test/packetimpact/README.md b/test/packetimpact/README.md
index a82ad996a..f46c67a0c 100644
--- a/test/packetimpact/README.md
+++ b/test/packetimpact/README.md
@@ -18,6 +18,27 @@ Packetimpact aims to provide:
* **Control-flow** like for loops, conditionals, and variables.
* **Flexibilty** to specify every byte in a packet or use multiple sockets.
+## How to run packetimpact tests?
+
+Build the test container image by running the following at the root of the
+repository:
+
+```bash
+$ make load-packetimpact
+```
+
+Run a test, e.g. `fin_wait2_timeout`, against Linux:
+
+```bash
+$ bazel test //test/packetimpact/tests:fin_wait2_timeout_linux_test
+```
+
+Run the same test, but against gVisor:
+
+```bash
+$ bazel test //test/packetimpact/tests:fin_wait2_timeout_netstack_test
+```
+
## When to use packetimpact?
There are a few ways to write networking tests for gVisor currently:
diff --git a/test/packetimpact/netdevs/BUILD b/test/packetimpact/netdevs/BUILD
new file mode 100644
index 000000000..422bb9b0c
--- /dev/null
+++ b/test/packetimpact/netdevs/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(
+ licenses = ["notice"],
+)
+
+go_library(
+ name = "netdevs",
+ srcs = ["netdevs.go"],
+ visibility = ["//test/packetimpact:__subpackages__"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ ],
+)
diff --git a/test/packetimpact/netdevs/netdevs.go b/test/packetimpact/netdevs/netdevs.go
new file mode 100644
index 000000000..d2c9cfeaf
--- /dev/null
+++ b/test/packetimpact/netdevs/netdevs.go
@@ -0,0 +1,104 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package netdevs contains utilities for working with network devices.
+package netdevs
+
+import (
+ "fmt"
+ "net"
+ "regexp"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// A DeviceInfo represents a network device.
+type DeviceInfo struct {
+ MAC net.HardwareAddr
+ IPv4Addr net.IP
+ IPv4Net *net.IPNet
+ IPv6Addr net.IP
+ IPv6Net *net.IPNet
+}
+
+var (
+ deviceLine = regexp.MustCompile(`^\s*\d+: (\w+)`)
+ linkLine = regexp.MustCompile(`^\s*link/\w+ ([0-9a-fA-F:]+)`)
+ inetLine = regexp.MustCompile(`^\s*inet ([0-9./]+)`)
+ inet6Line = regexp.MustCompile(`^\s*inet6 ([0-9a-fA-Z:/]+)`)
+)
+
+// ParseDevices parses the output from `ip addr show` into a map from device
+// name to information about the device.
+func ParseDevices(cmdOutput string) (map[string]DeviceInfo, error) {
+ var currentDevice string
+ var currentInfo DeviceInfo
+ deviceInfos := make(map[string]DeviceInfo)
+ for _, line := range strings.Split(cmdOutput, "\n") {
+ if m := deviceLine.FindStringSubmatch(line); m != nil {
+ if currentDevice != "" {
+ deviceInfos[currentDevice] = currentInfo
+ }
+ currentInfo = DeviceInfo{}
+ currentDevice = m[1]
+ } else if m := linkLine.FindStringSubmatch(line); m != nil {
+ mac, err := net.ParseMAC(m[1])
+ if err != nil {
+ return nil, err
+ }
+ currentInfo.MAC = mac
+ } else if m := inetLine.FindStringSubmatch(line); m != nil {
+ ipv4Addr, ipv4Net, err := net.ParseCIDR(m[1])
+ if err != nil {
+ return nil, err
+ }
+ currentInfo.IPv4Addr = ipv4Addr
+ currentInfo.IPv4Net = ipv4Net
+ } else if m := inet6Line.FindStringSubmatch(line); m != nil {
+ ipv6Addr, ipv6Net, err := net.ParseCIDR(m[1])
+ if err != nil {
+ return nil, err
+ }
+ currentInfo.IPv6Addr = ipv6Addr
+ currentInfo.IPv6Net = ipv6Net
+ }
+ }
+ if currentDevice != "" {
+ deviceInfos[currentDevice] = currentInfo
+ }
+ return deviceInfos, nil
+}
+
+// MACToIP converts the MAC address to an IPv6 link local address as described
+// in RFC 4291 page 20: https://tools.ietf.org/html/rfc4291#page-20
+func MACToIP(mac net.HardwareAddr) net.IP {
+ addr := make([]byte, header.IPv6AddressSize)
+ addr[0] = 0xfe
+ addr[1] = 0x80
+ header.EthernetAdddressToModifiedEUI64IntoBuf(tcpip.LinkAddress(mac), addr[8:])
+ return net.IP(addr)
+}
+
+// FindDeviceByIP finds a DeviceInfo and device name from an IP address in the
+// output of ParseDevices.
+func FindDeviceByIP(ip net.IP, devices map[string]DeviceInfo) (string, DeviceInfo, error) {
+ for dev, info := range devices {
+ if info.IPv4Addr.Equal(ip) {
+ return dev, info, nil
+ }
+ }
+ return "", DeviceInfo{}, fmt.Errorf("can't find %s on any interface", ip)
+}
diff --git a/test/packetimpact/runner/BUILD b/test/packetimpact/runner/BUILD
new file mode 100644
index 000000000..0b68a760a
--- /dev/null
+++ b/test/packetimpact/runner/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_test")
+
+package(
+ default_visibility = ["//test/packetimpact:__subpackages__"],
+ licenses = ["notice"],
+)
+
+go_test(
+ name = "packetimpact_test",
+ srcs = ["packetimpact_test.go"],
+ tags = [
+ # Not intended to be run directly.
+ "local",
+ "manual",
+ ],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//test/packetimpact/netdevs",
+ ],
+)
diff --git a/test/packetimpact/tests/defs.bzl b/test/packetimpact/runner/defs.bzl
index 45dce64ab..ea66b9756 100644
--- a/test/packetimpact/tests/defs.bzl
+++ b/test/packetimpact/runner/defs.bzl
@@ -11,12 +11,10 @@ def _packetimpact_test_impl(ctx):
# permission problems, because all runfiles may not be owned by the
# current user, and no other users will be mapped in that namespace.
# Make sure that everything is readable here.
- "find . -type f -exec chmod a+rx {} \\;",
- "find . -type d -exec chmod a+rx {} \\;",
- "%s %s --posix_server_binary %s --testbench_binary %s $@\n" % (
+ "find . -type f -or -type d -exec chmod a+rx {} \\;",
+ "%s %s --testbench_binary %s $@\n" % (
test_runner.short_path,
" ".join(ctx.attr.flags),
- ctx.files._posix_server_binary[0].short_path,
ctx.files.testbench_binary[0].short_path,
),
])
@@ -38,7 +36,7 @@ _packetimpact_test = rule(
"_test_runner": attr.label(
executable = True,
cfg = "target",
- default = ":test_runner",
+ default = ":packetimpact_test",
),
"_posix_server_binary": attr.label(
cfg = "target",
@@ -69,6 +67,7 @@ def packetimpact_linux_test(
Args:
name: name of the test
testbench_binary: the testbench binary
+ expect_failure: the test must fail
**kwargs: all the other args, forwarded to _packetimpact_test
"""
expect_failure_flag = ["--expect_failure"] if expect_failure else []
@@ -113,8 +112,8 @@ def packetimpact_go_test(name, size = "small", pure = True, expect_linux_failure
name: name of the test
size: size of the test
pure: make a static go binary
- expect_linux_failure: expect the test to fail for Linux
- expect_netstack_failure: expect the test to fail for Netstack
+ expect_linux_failure: the test must fail for Linux
+ expect_netstack_failure: the test must fail for Netstack
**kwargs: all the other args, forwarded to go_test
"""
testbench_binary = name + "_test"
diff --git a/test/packetimpact/runner/packetimpact_test.go b/test/packetimpact/runner/packetimpact_test.go
new file mode 100644
index 000000000..ac13c8543
--- /dev/null
+++ b/test/packetimpact/runner/packetimpact_test.go
@@ -0,0 +1,312 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// The runner starts docker containers and networking for a packetimpact test.
+package packetimpact_test
+
+import (
+ "flag"
+ "fmt"
+ "log"
+ "math/rand"
+ "net"
+ "path"
+ "strings"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/packetimpact/netdevs"
+)
+
+// stringList implements flag.Value.
+type stringList []string
+
+// String implements flag.Value.String.
+func (l *stringList) String() string {
+ return strings.Join(*l, ",")
+}
+
+// Set implements flag.Value.Set.
+func (l *stringList) Set(value string) error {
+ *l = append(*l, value)
+ return nil
+}
+
+var (
+ dutPlatform = flag.String("dut_platform", "", "either \"linux\" or \"netstack\"")
+ testbenchBinary = flag.String("testbench_binary", "", "path to the testbench binary")
+ tshark = flag.Bool("tshark", false, "use more verbose tshark in logs instead of tcpdump")
+ extraTestArgs = stringList{}
+ expectFailure = flag.Bool("expect_failure", false, "expect that the test will fail when run")
+
+ dutAddr = net.IPv4(0, 0, 0, 10)
+ testbenchAddr = net.IPv4(0, 0, 0, 20)
+)
+
+const ctrlPort = "40000"
+
+// logger implements testutil.Logger.
+//
+// Labels logs based on their source and formats multi-line logs.
+type logger string
+
+// Name implements testutil.Logger.Name.
+func (l logger) Name() string {
+ return string(l)
+}
+
+// Logf implements testutil.Logger.Logf.
+func (l logger) Logf(format string, args ...interface{}) {
+ lines := strings.Split(fmt.Sprintf(format, args...), "\n")
+ log.Printf("%s: %s", l, lines[0])
+ for _, line := range lines[1:] {
+ log.Printf("%*s %s", len(l), "", line)
+ }
+}
+
+func TestOne(t *testing.T) {
+ flag.Var(&extraTestArgs, "extra_test_arg", "extra arguments to pass to the testbench")
+ flag.Parse()
+ if *dutPlatform != "linux" && *dutPlatform != "netstack" {
+ t.Fatal("--dut_platform should be either linux or netstack")
+ }
+ if *testbenchBinary == "" {
+ t.Fatal("--testbench_binary is missing")
+ }
+ if *dutPlatform == "netstack" {
+ if _, err := dockerutil.RuntimePath(); err != nil {
+ t.Fatal("--runtime is missing or invalid with --dut_platform=netstack:", err)
+ }
+ }
+ dockerutil.EnsureSupportedDockerVersion()
+
+ // Create the networks needed for the test. One control network is needed for
+ // the gRPC control packets and one test network on which to transmit the test
+ // packets.
+ ctrlNet := dockerutil.NewDockerNetwork(logger("ctrlNet"))
+ testNet := dockerutil.NewDockerNetwork(logger("testNet"))
+ for _, dn := range []*dockerutil.DockerNetwork{ctrlNet, testNet} {
+ for {
+ if err := createDockerNetwork(dn); err != nil {
+ t.Log("creating docker network:", err)
+ const wait = 100 * time.Millisecond
+ t.Logf("sleeping %s and will try creating docker network again", wait)
+ // This can fail if another docker network claimed the same IP so we'll
+ // just try again.
+ time.Sleep(wait)
+ continue
+ }
+ break
+ }
+ defer func(dn *dockerutil.DockerNetwork) {
+ if err := dn.Cleanup(); err != nil {
+ t.Errorf("unable to cleanup container %s: %s", dn.Name, err)
+ }
+ }(dn)
+ }
+
+ runOpts := dockerutil.RunOpts{
+ Image: "packetimpact",
+ CapAdd: []string{"NET_ADMIN"},
+ Extra: []string{"--sysctl", "net.ipv6.conf.all.disable_ipv6=0", "--rm"},
+ Foreground: true,
+ }
+
+ // Create the Docker container for the DUT.
+ dut := dockerutil.MakeDocker(logger("dut"))
+ if *dutPlatform == "linux" {
+ dut.Runtime = ""
+ }
+
+ const containerPosixServerBinary = "/packetimpact/posix_server"
+ dut.CopyFiles("/packetimpact", "/test/packetimpact/dut/posix_server")
+
+ if err := dut.Create(runOpts, containerPosixServerBinary, "--ip=0.0.0.0", "--port="+ctrlPort); err != nil {
+ t.Fatalf("unable to create container %s: %s", dut.Name, err)
+ }
+ defer dut.CleanUp()
+
+ // Add ctrlNet as eth1 and testNet as eth2.
+ const testNetDev = "eth2"
+ if err := addNetworks(dut, dutAddr, []*dockerutil.DockerNetwork{ctrlNet, testNet}); err != nil {
+ t.Fatal(err)
+ }
+
+ if err := dut.Start(); err != nil {
+ t.Fatalf("unable to start container %s: %s", dut.Name, err)
+ }
+
+ if _, err := dut.WaitForOutput("Server listening.*\n", 60*time.Second); err != nil {
+ t.Fatalf("%s on container %s never listened: %s", containerPosixServerBinary, dut.Name, err)
+ }
+
+ dutTestDevice, dutDeviceInfo, err := deviceByIP(dut, addressInSubnet(dutAddr, *testNet.Subnet))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ remoteMAC := dutDeviceInfo.MAC
+ remoteIPv6 := dutDeviceInfo.IPv6Addr
+ // Netstack as DUT doesn't assign IPv6 addresses automatically so do it if
+ // needed.
+ if remoteIPv6 == nil {
+ if _, err := dut.Exec(dockerutil.RunOpts{}, "ip", "addr", "add", netdevs.MACToIP(remoteMAC).String(), "scope", "link", "dev", dutTestDevice); err != nil {
+ t.Fatalf("unable to ip addr add on container %s: %s", dut.Name, err)
+ }
+ // Now try again, to make sure that it worked.
+ _, dutDeviceInfo, err = deviceByIP(dut, addressInSubnet(dutAddr, *testNet.Subnet))
+ if err != nil {
+ t.Fatal(err)
+ }
+ remoteIPv6 = dutDeviceInfo.IPv6Addr
+ if remoteIPv6 == nil {
+ t.Fatal("unable to set IPv6 address on container", dut.Name)
+ }
+ }
+
+ // Create the Docker container for the testbench.
+ testbench := dockerutil.MakeDocker(logger("testbench"))
+ testbench.Runtime = "" // The testbench always runs on Linux.
+
+ tbb := path.Base(*testbenchBinary)
+ containerTestbenchBinary := "/packetimpact/" + tbb
+ testbench.CopyFiles("/packetimpact", "/test/packetimpact/tests/"+tbb)
+
+ // Run tcpdump in the test bench unbuffered, without DNS resolution, just on
+ // the interface with the test packets.
+ snifferArgs := []string{
+ "tcpdump", "-S", "-vvv", "-U", "-n", "-i", testNetDev,
+ }
+ snifferRegex := "tcpdump: listening.*\n"
+ if *tshark {
+ // Run tshark in the test bench unbuffered, without DNS resolution, just on
+ // the interface with the test packets.
+ snifferArgs = []string{
+ "tshark", "-V", "-l", "-n", "-i", testNetDev,
+ "-o", "tcp.check_checksum:TRUE",
+ "-o", "udp.check_checksum:TRUE",
+ }
+ snifferRegex = "Capturing on.*\n"
+ }
+
+ if err := testbench.Create(runOpts, snifferArgs...); err != nil {
+ t.Fatalf("unable to create container %s: %s", testbench.Name, err)
+ }
+ defer testbench.CleanUp()
+
+ // Add ctrlNet as eth1 and testNet as eth2.
+ if err := addNetworks(testbench, testbenchAddr, []*dockerutil.DockerNetwork{ctrlNet, testNet}); err != nil {
+ t.Fatal(err)
+ }
+
+ if err := testbench.Start(); err != nil {
+ t.Fatalf("unable to start container %s: %s", testbench.Name, err)
+ }
+
+ // Kill so that it will flush output.
+ defer testbench.Exec(dockerutil.RunOpts{}, "killall", snifferArgs[0])
+
+ if _, err := testbench.WaitForOutput(snifferRegex, 60*time.Second); err != nil {
+ t.Fatalf("sniffer on %s never listened: %s", dut.Name, err)
+ }
+
+ // Because the Linux kernel receives the SYN-ACK but didn't send the SYN it
+ // will issue a RST. To prevent this IPtables can be used to filter out all
+ // incoming packets. The raw socket that packetimpact tests use will still see
+ // everything.
+ if _, err := testbench.Exec(dockerutil.RunOpts{}, "iptables", "-A", "INPUT", "-i", testNetDev, "-j", "DROP"); err != nil {
+ t.Fatalf("unable to Exec iptables on container %s: %s", testbench.Name, err)
+ }
+
+ // FIXME(b/156449515): Some piece of the system has a race. The old
+ // bash script version had a sleep, so we have one too. The race should
+ // be fixed and this sleep removed.
+ time.Sleep(time.Second)
+
+ // Start a packetimpact test on the test bench. The packetimpact test sends
+ // and receives packets and also sends POSIX socket commands to the
+ // posix_server to be executed on the DUT.
+ testArgs := []string{containerTestbenchBinary}
+ testArgs = append(testArgs, extraTestArgs...)
+ testArgs = append(testArgs,
+ "--posix_server_ip", addressInSubnet(dutAddr, *ctrlNet.Subnet).String(),
+ "--posix_server_port", ctrlPort,
+ "--remote_ipv4", addressInSubnet(dutAddr, *testNet.Subnet).String(),
+ "--local_ipv4", addressInSubnet(testbenchAddr, *testNet.Subnet).String(),
+ "--remote_ipv6", remoteIPv6.String(),
+ "--remote_mac", remoteMAC.String(),
+ "--device", testNetDev,
+ )
+ _, err = testbench.Exec(dockerutil.RunOpts{}, testArgs...)
+ if !*expectFailure && err != nil {
+ t.Fatal("test failed:", err)
+ }
+ if *expectFailure && err == nil {
+ t.Fatal("test failure expected but the test succeeded, enable the test and mark the corresponding bug as fixed")
+ }
+}
+
+func addNetworks(d *dockerutil.Docker, addr net.IP, networks []*dockerutil.DockerNetwork) error {
+ for _, dn := range networks {
+ ip := addressInSubnet(addr, *dn.Subnet)
+ // Connect to the network with the specified IP address.
+ if err := dn.Connect(d, "--ip", ip.String()); err != nil {
+ return fmt.Errorf("unable to connect container %s to network %s: %w", d.Name, dn.Name, err)
+ }
+ }
+ return nil
+}
+
+// addressInSubnet combines the subnet provided with the address and returns a
+// new address. The return address bits come from the subnet where the mask is 1
+// and from the ip address where the mask is 0.
+func addressInSubnet(addr net.IP, subnet net.IPNet) net.IP {
+ var octets []byte
+ for i := 0; i < 4; i++ {
+ octets = append(octets, (subnet.IP.To4()[i]&subnet.Mask[i])+(addr.To4()[i]&(^subnet.Mask[i])))
+ }
+ return net.IP(octets)
+}
+
+// makeDockerNetwork makes a randomly-named network that will start with the
+// namePrefix. The network will be a random /24 subnet.
+func createDockerNetwork(n *dockerutil.DockerNetwork) error {
+ randSource := rand.NewSource(time.Now().UnixNano())
+ r1 := rand.New(randSource)
+ // Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24.
+ ip := net.IPv4(byte(r1.Intn(224-192)+192), byte(r1.Intn(256)), byte(r1.Intn(256)), 0)
+ n.Subnet = &net.IPNet{
+ IP: ip,
+ Mask: ip.DefaultMask(),
+ }
+ return n.Create()
+}
+
+// deviceByIP finds a deviceInfo and device name from an IP address.
+func deviceByIP(d *dockerutil.Docker, ip net.IP) (string, netdevs.DeviceInfo, error) {
+ out, err := d.Exec(dockerutil.RunOpts{}, "ip", "addr", "show")
+ if err != nil {
+ return "", netdevs.DeviceInfo{}, fmt.Errorf("listing devices on %s container: %w", d.Name, err)
+ }
+ devs, err := netdevs.ParseDevices(out)
+ if err != nil {
+ return "", netdevs.DeviceInfo{}, fmt.Errorf("parsing devices from %s container: %w", d.Name, err)
+ }
+ testDevice, deviceInfo, err := netdevs.FindDeviceByIP(ip, devs)
+ if err != nil {
+ return "", netdevs.DeviceInfo{}, fmt.Errorf("can't find deviceInfo for container %s: %w", d.Name, err)
+ }
+ return testDevice, deviceInfo, nil
+}
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD
index 682933067..d19ec07d4 100644
--- a/test/packetimpact/testbench/BUILD
+++ b/test/packetimpact/testbench/BUILD
@@ -21,6 +21,7 @@ go_library(
"//pkg/tcpip/header",
"//pkg/tcpip/seqnum",
"//pkg/usermem",
+ "//test/packetimpact/netdevs",
"//test/packetimpact/proto:posix_server_go_proto",
"@com_github_google_go-cmp//cmp:go_default_library",
"@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
index 463fd0556..bf104e5ca 100644
--- a/test/packetimpact/testbench/connections.go
+++ b/test/packetimpact/testbench/connections.go
@@ -114,12 +114,12 @@ var _ layerState = (*etherState)(nil)
func newEtherState(out, in Ether) (*etherState, error) {
lMAC, err := tcpip.ParseMACAddress(LocalMAC)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("parsing local MAC: %q: %w", LocalMAC, err)
}
rMAC, err := tcpip.ParseMACAddress(RemoteMAC)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("parsing remote MAC: %q: %w", RemoteMAC, err)
}
s := etherState{
out: Ether{SrcAddr: &lMAC, DstAddr: &rMAC},
diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go
index a78b7d7ee..b919a3c2e 100644
--- a/test/packetimpact/testbench/dut.go
+++ b/test/packetimpact/testbench/dut.go
@@ -16,6 +16,7 @@ package testbench
import (
"context"
+ "flag"
"net"
"strconv"
"syscall"
@@ -37,6 +38,11 @@ type DUT struct {
// NewDUT creates a new connection with the DUT over gRPC.
func NewDUT(t *testing.T) DUT {
+ flag.Parse()
+ if err := genPseudoFlags(); err != nil {
+ t.Fatal("generating psuedo flags:", err)
+ }
+
posixServerAddress := POSIXServerIP + ":" + strconv.Itoa(POSIXServerPort)
conn, err := grpc.Dial(posixServerAddress, grpc.WithInsecure(), grpc.WithKeepaliveParams(keepalive.ClientParameters{Timeout: RPCKeepalive}))
if err != nil {
diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go
index 4665f60b2..278229b7e 100644
--- a/test/packetimpact/testbench/rawsockets.go
+++ b/test/packetimpact/testbench/rawsockets.go
@@ -16,7 +16,6 @@ package testbench
import (
"encoding/binary"
- "flag"
"fmt"
"math"
"net"
@@ -41,7 +40,6 @@ func htons(x uint16) uint16 {
// NewSniffer creates a Sniffer connected to *device.
func NewSniffer(t *testing.T) (Sniffer, error) {
- flag.Parse()
snifferFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL)))
if err != nil {
return Sniffer{}, err
@@ -136,7 +134,6 @@ type Injector struct {
// NewInjector creates a new injector on *device.
func NewInjector(t *testing.T) (Injector, error) {
- flag.Parse()
ifInfo, err := net.InterfaceByName(Device)
if err != nil {
return Injector{}, err
diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go
index a1242b189..4de2aa1d3 100644
--- a/test/packetimpact/testbench/testbench.go
+++ b/test/packetimpact/testbench/testbench.go
@@ -16,7 +16,12 @@ package testbench
import (
"flag"
+ "fmt"
+ "net"
+ "os/exec"
"time"
+
+ "gvisor.dev/gvisor/test/packetimpact/netdevs"
)
var (
@@ -55,9 +60,31 @@ func RegisterFlags(fs *flag.FlagSet) {
fs.DurationVar(&RPCKeepalive, "rpc_keepalive", RPCKeepalive, "gRPC keepalive")
fs.StringVar(&LocalIPv4, "local_ipv4", LocalIPv4, "local IPv4 address for test packets")
fs.StringVar(&RemoteIPv4, "remote_ipv4", RemoteIPv4, "remote IPv4 address for test packets")
- fs.StringVar(&LocalIPv6, "local_ipv6", LocalIPv6, "local IPv6 address for test packets")
fs.StringVar(&RemoteIPv6, "remote_ipv6", RemoteIPv6, "remote IPv6 address for test packets")
- fs.StringVar(&LocalMAC, "local_mac", LocalMAC, "local mac address for test packets")
fs.StringVar(&RemoteMAC, "remote_mac", RemoteMAC, "remote mac address for test packets")
fs.StringVar(&Device, "device", Device, "local device for test packets")
}
+
+// genPseudoFlags populates flag-like global config based on real flags.
+//
+// genPseudoFlags must only be called after flag.Parse.
+func genPseudoFlags() error {
+ out, err := exec.Command("ip", "addr", "show").CombinedOutput()
+ if err != nil {
+ return fmt.Errorf("listing devices: %q: %w", string(out), err)
+ }
+ devs, err := netdevs.ParseDevices(string(out))
+ if err != nil {
+ return fmt.Errorf("parsing devices: %w", err)
+ }
+
+ _, deviceInfo, err := netdevs.FindDeviceByIP(net.ParseIP(LocalIPv4), devs)
+ if err != nil {
+ return fmt.Errorf("can't find deviceInfo: %w", err)
+ }
+
+ LocalMAC = deviceInfo.MAC.String()
+ LocalIPv6 = deviceInfo.IPv6Addr.String()
+
+ return nil
+}
diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD
index c4ffda17e..3a0e9cb07 100644
--- a/test/packetimpact/tests/BUILD
+++ b/test/packetimpact/tests/BUILD
@@ -1,4 +1,4 @@
-load("defs.bzl", "packetimpact_go_test")
+load("//test/packetimpact/runner:defs.bzl", "packetimpact_go_test")
package(
default_visibility = ["//test/packetimpact:__subpackages__"],
@@ -177,8 +177,3 @@ packetimpact_go_test(
"@org_golang_x_sys//unix:go_default_library",
],
)
-
-sh_binary(
- name = "test_runner",
- srcs = ["test_runner.sh"],
-)
diff --git a/test/packetimpact/tests/test_runner.sh b/test/packetimpact/tests/test_runner.sh
deleted file mode 100755
index 706441cce..000000000
--- a/test/packetimpact/tests/test_runner.sh
+++ /dev/null
@@ -1,325 +0,0 @@
-#!/bin/bash
-
-# Copyright 2020 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Run a packetimpact test. Two docker containers are made, one for the
-# Device-Under-Test (DUT) and one for the test bench. Each is attached with
-# two networks, one for control packets that aid the test and one for test
-# packets which are sent as part of the test and observed for correctness.
-
-set -euxo pipefail
-
-function failure() {
- local lineno=$1
- local msg=$2
- local filename="$0"
- echo "FAIL: $filename:$lineno: $msg"
-}
-trap 'failure ${LINENO} "$BASH_COMMAND"' ERR
-
-declare -r LONGOPTS="dut_platform:,posix_server_binary:,testbench_binary:,runtime:,tshark,extra_test_arg:,expect_failure"
-
-# Don't use declare below so that the error from getopt will end the script.
-PARSED=$(getopt --options "" --longoptions=$LONGOPTS --name "$0" -- "$@")
-
-eval set -- "$PARSED"
-
-declare -a EXTRA_TEST_ARGS
-
-while true; do
- case "$1" in
- --dut_platform)
- # Either "linux" or "netstack".
- declare -r DUT_PLATFORM="$2"
- shift 2
- ;;
- --posix_server_binary)
- declare -r POSIX_SERVER_BINARY="$2"
- shift 2
- ;;
- --testbench_binary)
- declare -r TESTBENCH_BINARY="$2"
- shift 2
- ;;
- --runtime)
- # Not readonly because there might be multiple --runtime arguments and we
- # want to use just the last one. Only used if --dut_platform is
- # "netstack".
- declare RUNTIME="$2"
- shift 2
- ;;
- --tshark)
- declare -r TSHARK="1"
- shift 1
- ;;
- --extra_test_arg)
- EXTRA_TEST_ARGS+="$2"
- shift 2
- ;;
- --expect_failure)
- declare -r EXPECT_FAILURE="1"
- shift 1
- ;;
- --)
- shift
- break
- ;;
- *)
- echo "Programming error"
- exit 3
- esac
-done
-
-# All the other arguments are scripts.
-declare -r scripts="$@"
-
-# Check that the required flags are defined in a way that is safe for "set -u".
-if [[ "${DUT_PLATFORM-}" == "netstack" ]]; then
- if [[ -z "${RUNTIME-}" ]]; then
- echo "FAIL: Missing --runtime argument: ${RUNTIME-}"
- exit 2
- fi
- declare -r RUNTIME_ARG="--runtime ${RUNTIME}"
-elif [[ "${DUT_PLATFORM-}" == "linux" ]]; then
- declare -r RUNTIME_ARG=""
-else
- echo "FAIL: Bad or missing --dut_platform argument: ${DUT_PLATFORM-}"
- exit 2
-fi
-if [[ ! -f "${POSIX_SERVER_BINARY-}" ]]; then
- echo "FAIL: Bad or missing --posix_server_binary: ${POSIX_SERVER-}"
- exit 2
-fi
-if [[ ! -f "${TESTBENCH_BINARY-}" ]]; then
- echo "FAIL: Bad or missing --testbench_binary: ${TESTBENCH_BINARY-}"
- exit 2
-fi
-
-function new_net_prefix() {
- # Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24.
- echo "$(shuf -i 192-223 -n 1).$(shuf -i 0-255 -n 1).$(shuf -i 0-255 -n 1)"
-}
-
-# Variables specific to the control network and interface start with CTRL_.
-# Variables specific to the test network and interface start with TEST_.
-# Variables specific to the DUT start with DUT_.
-# Variables specific to the test bench start with TESTBENCH_.
-# Use random numbers so that test networks don't collide.
-declare CTRL_NET="ctrl_net-${RANDOM}${RANDOM}"
-declare CTRL_NET_PREFIX=$(new_net_prefix)
-declare TEST_NET="test_net-${RANDOM}${RANDOM}"
-declare TEST_NET_PREFIX=$(new_net_prefix)
-# On both DUT and test bench, testing packets are on the eth2 interface.
-declare -r TEST_DEVICE="eth2"
-# Number of bits in the *_NET_PREFIX variables.
-declare -r NET_MASK="24"
-# Last bits of the DUT's IP address.
-declare -r DUT_NET_SUFFIX=".10"
-# Control port.
-declare -r CTRL_PORT="40000"
-# Last bits of the test bench's IP address.
-declare -r TESTBENCH_NET_SUFFIX=".20"
-declare -r TIMEOUT="60"
-declare -r IMAGE_TAG="gcr.io/gvisor-presubmit/packetimpact"
-
-# Make sure that docker is installed.
-docker --version
-
-function finish {
- local cleanup_success=1
-
- if [[ -z "${TSHARK-}" ]]; then
- # Kill tcpdump so that it will flush output.
- docker exec -t "${TESTBENCH}" \
- killall tcpdump || \
- cleanup_success=0
- else
- # Kill tshark so that it will flush output.
- docker exec -t "${TESTBENCH}" \
- killall tshark || \
- cleanup_success=0
- fi
-
- for net in "${CTRL_NET}" "${TEST_NET}"; do
- # Kill all processes attached to ${net}.
- for docker_command in "kill" "rm"; do
- (docker network inspect "${net}" \
- --format '{{range $key, $value := .Containers}}{{$key}} {{end}}' \
- | xargs -r docker "${docker_command}") || \
- cleanup_success=0
- done
- # Remove the network.
- docker network rm "${net}" || \
- cleanup_success=0
- done
-
- if ((!$cleanup_success)); then
- echo "FAIL: Cleanup command failed"
- exit 4
- fi
-}
-trap finish EXIT
-
-# Subnet for control packets between test bench and DUT.
-while ! docker network create \
- "--subnet=${CTRL_NET_PREFIX}.0/${NET_MASK}" "${CTRL_NET}"; do
- sleep 0.1
- CTRL_NET_PREFIX=$(new_net_prefix)
- CTRL_NET="ctrl_net-${RANDOM}${RANDOM}"
-done
-
-# Subnet for the packets that are part of the test.
-while ! docker network create \
- "--subnet=${TEST_NET_PREFIX}.0/${NET_MASK}" "${TEST_NET}"; do
- sleep 0.1
- TEST_NET_PREFIX=$(new_net_prefix)
- TEST_NET="test_net-${RANDOM}${RANDOM}"
-done
-
-docker pull "${IMAGE_TAG}"
-
-# Create the DUT container and connect to network.
-DUT=$(docker create ${RUNTIME_ARG} --privileged --rm \
- --cap-add NET_ADMIN \
- --sysctl net.ipv6.conf.all.disable_ipv6=0 \
- --stop-timeout ${TIMEOUT} -it ${IMAGE_TAG})
-docker network connect "${CTRL_NET}" \
- --ip "${CTRL_NET_PREFIX}${DUT_NET_SUFFIX}" "${DUT}" \
- || (docker kill ${DUT}; docker rm ${DUT}; false)
-docker network connect "${TEST_NET}" \
- --ip "${TEST_NET_PREFIX}${DUT_NET_SUFFIX}" "${DUT}" \
- || (docker kill ${DUT}; docker rm ${DUT}; false)
-docker start "${DUT}"
-
-# Create the test bench container and connect to network.
-TESTBENCH=$(docker create --privileged --rm \
- --cap-add NET_ADMIN \
- --sysctl net.ipv6.conf.all.disable_ipv6=0 \
- --stop-timeout ${TIMEOUT} -it ${IMAGE_TAG})
-docker network connect "${CTRL_NET}" \
- --ip "${CTRL_NET_PREFIX}${TESTBENCH_NET_SUFFIX}" "${TESTBENCH}" \
- || (docker kill ${TESTBENCH}; docker rm ${TESTBENCH}; false)
-docker network connect "${TEST_NET}" \
- --ip "${TEST_NET_PREFIX}${TESTBENCH_NET_SUFFIX}" "${TESTBENCH}" \
- || (docker kill ${TESTBENCH}; docker rm ${TESTBENCH}; false)
-docker start "${TESTBENCH}"
-
-# Start the posix_server in the DUT.
-declare -r DOCKER_POSIX_SERVER_BINARY="/$(basename ${POSIX_SERVER_BINARY})"
-docker cp -L ${POSIX_SERVER_BINARY} "${DUT}:${DOCKER_POSIX_SERVER_BINARY}"
-
-docker exec -t "${DUT}" \
- /bin/bash -c "${DOCKER_POSIX_SERVER_BINARY} \
- --ip ${CTRL_NET_PREFIX}${DUT_NET_SUFFIX} \
- --port ${CTRL_PORT}" &
-
-# Because the Linux kernel receives the SYN-ACK but didn't send the SYN it will
-# issue a RST. To prevent this IPtables can be used to filter those out.
-docker exec "${TESTBENCH}" \
- iptables -A INPUT -i ${TEST_DEVICE} -j DROP
-
-# Wait for the DUT server to come up. Attempt to connect to it from the test
-# bench every 100 milliseconds until success.
-while ! docker exec "${TESTBENCH}" \
- nc -zv "${CTRL_NET_PREFIX}${DUT_NET_SUFFIX}" "${CTRL_PORT}"; do
- sleep 0.1
-done
-
-declare -r REMOTE_MAC=$(docker exec -t "${DUT}" ip link show \
- "${TEST_DEVICE}" | tail -1 | cut -d' ' -f6)
-declare -r LOCAL_MAC=$(docker exec -t "${TESTBENCH}" ip link show \
- "${TEST_DEVICE}" | tail -1 | cut -d' ' -f6)
-declare REMOTE_IPV6=$(docker exec -t "${DUT}" ip addr show scope link \
- "${TEST_DEVICE}" | grep inet6 | cut -d' ' -f6 | cut -d'/' -f1)
-declare -r LOCAL_IPV6=$(docker exec -t "${TESTBENCH}" ip addr show scope link \
- "${TEST_DEVICE}" | grep inet6 | cut -d' ' -f6 | cut -d'/' -f1)
-
-# Netstack as DUT doesn't assign IPv6 addresses automatically so do it if
-# needed. Convert the MAC address to an IPv6 link local address as described in
-# RFC 4291 page 20: https://tools.ietf.org/html/rfc4291#page-20
-if [[ -z "${REMOTE_IPV6}" ]]; then
- # Split the octets of the MAC into an array of strings.
- IFS=":" read -a REMOTE_OCTETS <<< "${REMOTE_MAC}"
- # Flip the global bit.
- REMOTE_OCTETS[0]=$(printf '%x' "$((0x${REMOTE_OCTETS[0]} ^ 2))")
- # Add the IPv6 address.
- docker exec "${DUT}" \
- ip addr add $(printf 'fe80::%02x%02x:%02xff:fe%02x:%02x%02x/64' \
- "0x${REMOTE_OCTETS[0]}" "0x${REMOTE_OCTETS[1]}" "0x${REMOTE_OCTETS[2]}" \
- "0x${REMOTE_OCTETS[3]}" "0x${REMOTE_OCTETS[4]}" "0x${REMOTE_OCTETS[5]}") \
- scope link \
- dev "${TEST_DEVICE}"
- # Re-extract the IPv6 address.
- # TODO(eyalsoha): Add "scope link" below when netstack supports correctly
- # creating link-local IPv6 addresses.
- REMOTE_IPV6=$(docker exec -t "${DUT}" ip addr show \
- "${TEST_DEVICE}" | grep inet6 | cut -d' ' -f6 | cut -d'/' -f1)
-fi
-
-declare -r DOCKER_TESTBENCH_BINARY="/$(basename ${TESTBENCH_BINARY})"
-docker cp -L "${TESTBENCH_BINARY}" "${TESTBENCH}:${DOCKER_TESTBENCH_BINARY}"
-
-if [[ -z "${TSHARK-}" ]]; then
- # Run tcpdump in the test bench unbuffered, without dns resolution, just on
- # the interface with the test packets.
- docker exec -t "${TESTBENCH}" \
- tcpdump -S -vvv -U -n -i "${TEST_DEVICE}" \
- net "${TEST_NET_PREFIX}/24" or \
- host "${REMOTE_IPV6}" or \
- host "${LOCAL_IPV6}" &
-else
- # Run tshark in the test bench unbuffered, without dns resolution, just on the
- # interface with the test packets.
- docker exec -t "${TESTBENCH}" \
- tshark -V -l -n -i "${TEST_DEVICE}" \
- -o tcp.check_checksum:TRUE \
- -o udp.check_checksum:TRUE \
- net "${TEST_NET_PREFIX}/24" or \
- host "${REMOTE_IPV6}" or \
- host "${LOCAL_IPV6}" &
-fi
-
-# tcpdump and tshark take time to startup
-sleep 3
-
-# Start a packetimpact test on the test bench. The packetimpact test sends and
-# receives packets and also sends POSIX socket commands to the posix_server to
-# be executed on the DUT.
-docker exec \
- -e XML_OUTPUT_FILE="/test.xml" \
- -e TEST_TARGET \
- -t "${TESTBENCH}" \
- /bin/bash -c "${DOCKER_TESTBENCH_BINARY} \
- ${EXTRA_TEST_ARGS[@]-} \
- --posix_server_ip=${CTRL_NET_PREFIX}${DUT_NET_SUFFIX} \
- --posix_server_port=${CTRL_PORT} \
- --remote_ipv4=${TEST_NET_PREFIX}${DUT_NET_SUFFIX} \
- --local_ipv4=${TEST_NET_PREFIX}${TESTBENCH_NET_SUFFIX} \
- --remote_ipv6=${REMOTE_IPV6} \
- --local_ipv6=${LOCAL_IPV6} \
- --remote_mac=${REMOTE_MAC} \
- --local_mac=${LOCAL_MAC} \
- --device=${TEST_DEVICE}" && true
-declare -r TEST_RESULT="${?}"
-if [[ -z "${EXPECT_FAILURE-}" && "${TEST_RESULT}" != 0 ]]; then
- echo 'FAIL: This test was expected to pass.'
- exit ${TEST_RESULT}
-fi
-if [[ ! -z "${EXPECT_FAILURE-}" && "${TEST_RESULT}" == 0 ]]; then
- echo 'FAIL: This test was expected to fail but passed. Enable the test and' \
- 'mark the corresponding bug as fixed.'
- exit 1
-fi
-echo PASS: No errors.
diff --git a/test/runner/runner.go b/test/runner/runner.go
index 14c9cbc47..e4f04cd2a 100644
--- a/test/runner/runner.go
+++ b/test/runner/runner.go
@@ -341,11 +341,13 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
}
}
- // Set environment variables that indicate we are
- // running in gVisor with the given platform and network.
+ // Set environment variables that indicate we are running in gVisor with
+ // the given platform, network, and filesystem stack.
+ // TODO(gvisor.dev/issue/1487): Update this when the runner supports VFS2.
platformVar := "TEST_ON_GVISOR"
networkVar := "GVISOR_NETWORK"
- env := append(os.Environ(), platformVar+"="+*platform, networkVar+"="+*network)
+ vfsVar := "GVISOR_VFS"
+ env := append(os.Environ(), platformVar+"="+*platform, networkVar+"="+*network, vfsVar+"=VFS1")
// Remove env variables that cause the gunit binary to write output
// files, since they will stomp on eachother, and on the output files
diff --git a/test/syscalls/linux/socket.cc b/test/syscalls/linux/socket.cc
index afa59c1da..e0a4d0985 100644
--- a/test/syscalls/linux/socket.cc
+++ b/test/syscalls/linux/socket.cc
@@ -62,9 +62,7 @@ TEST(SocketTest, ProtocolInet) {
}
TEST(SocketTest, UnixSocketStat) {
- // TODO(gvisor.dev/issue/1624): Re-enable this test once VFS1 is deleted. It
- // should pass in VFS2.
- SKIP_IF(IsRunningOnGvisor());
+ SKIP_IF(IsRunningWithVFS1());
FileDescriptor bound =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX));
@@ -94,9 +92,7 @@ TEST(SocketTest, UnixSocketStat) {
}
TEST(SocketTest, UnixConnectNeedsWritePerm) {
- // TODO(gvisor.dev/issue/1624): Re-enable this test once VFS1 is deleted. It
- // should succeed in VFS2.
- SKIP_IF(IsRunningOnGvisor());
+ SKIP_IF(IsRunningWithVFS1());
FileDescriptor bound =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX));
@@ -128,10 +124,7 @@ using SocketOpenTest = ::testing::TestWithParam<int>;
// UDS cannot be opened.
TEST_P(SocketOpenTest, Unix) {
// FIXME(b/142001530): Open incorrectly succeeds on gVisor.
- //
- // TODO(gvisor.dev/issue/1624): Re-enable this test once VFS1 is deleted. It
- // should succeed in VFS2.
- SKIP_IF(IsRunningOnGvisor());
+ SKIP_IF(IsRunningWithVFS1());
FileDescriptor bound =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX));
diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc
index f103e2e56..08fc4b1b7 100644
--- a/test/syscalls/linux/splice.cc
+++ b/test/syscalls/linux/splice.cc
@@ -430,6 +430,55 @@ TEST(SpliceTest, TwoPipes) {
EXPECT_EQ(memcmp(rbuf.data(), buf.data(), kPageSize), 0);
}
+TEST(SpliceTest, TwoPipesCircular) {
+ // This test deadlocks the sentry on VFS1 because VFS1 splice ordering is
+ // based on fs.File.UniqueID, which does not prevent circular ordering between
+ // e.g. inode-level locks taken by fs.FileOperations.
+ SKIP_IF(IsRunningWithVFS1());
+
+ // Create two pipes.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor first_rfd(fds[0]);
+ const FileDescriptor first_wfd(fds[1]);
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor second_rfd(fds[0]);
+ const FileDescriptor second_wfd(fds[1]);
+
+ // On Linux, each pipe is normally limited to
+ // include/linux/pipe_fs_i.h:PIPE_DEF_BUFFERS buffers worth of data.
+ constexpr size_t PIPE_DEF_BUFFERS = 16;
+
+ // Write some data to each pipe. Below we splice 1 byte at a time between
+ // pipes, which very quickly causes each byte to be stored in a separate
+ // buffer, so we must ensure that the total amount of data in the system is <=
+ // PIPE_DEF_BUFFERS bytes.
+ std::vector<char> buf(PIPE_DEF_BUFFERS / 2);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(first_wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+ ASSERT_THAT(write(second_wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+
+ // Have another thread splice from the second pipe to the first, while we
+ // splice from the first to the second. The test passes if this does not
+ // deadlock.
+ const int kIterations = 1000;
+ DisableSave ds;
+ ScopedThread t([&]() {
+ for (int i = 0; i < kIterations; i++) {
+ ASSERT_THAT(
+ splice(second_rfd.get(), nullptr, first_wfd.get(), nullptr, 1, 0),
+ SyscallSucceedsWithValue(1));
+ }
+ });
+ for (int i = 0; i < kIterations; i++) {
+ ASSERT_THAT(
+ splice(first_rfd.get(), nullptr, second_wfd.get(), nullptr, 1, 0),
+ SyscallSucceedsWithValue(1));
+ }
+}
+
TEST(SpliceTest, Blocking) {
// Create two new pipes.
int first[2], second[2];
diff --git a/test/util/test_util.cc b/test/util/test_util.cc
index 95e1e0c96..b20758626 100644
--- a/test/util/test_util.cc
+++ b/test/util/test_util.cc
@@ -42,12 +42,13 @@ namespace testing {
#define TEST_ON_GVISOR "TEST_ON_GVISOR"
#define GVISOR_NETWORK "GVISOR_NETWORK"
+#define GVISOR_VFS "GVISOR_VFS"
bool IsRunningOnGvisor() { return GvisorPlatform() != Platform::kNative; }
const std::string GvisorPlatform() {
// Set by runner.go.
- char* env = getenv(TEST_ON_GVISOR);
+ const char* env = getenv(TEST_ON_GVISOR);
if (!env) {
return Platform::kNative;
}
@@ -55,10 +56,19 @@ const std::string GvisorPlatform() {
}
bool IsRunningWithHostinet() {
- char* env = getenv(GVISOR_NETWORK);
+ const char* env = getenv(GVISOR_NETWORK);
return env && strcmp(env, "host") == 0;
}
+bool IsRunningWithVFS1() {
+ const char* env = getenv(GVISOR_VFS);
+ if (env == nullptr) {
+ // If not set, it's running on Linux.
+ return false;
+ }
+ return strcmp(env, "VFS1") == 0;
+}
+
// Inline cpuid instruction. Preserve %ebx/%rbx register. In PIC compilations
// %ebx contains the address of the global offset table. %rbx is occasionally
// used to address stack variables in presence of dynamic allocas.
diff --git a/test/util/test_util.h b/test/util/test_util.h
index c5cb9d6d6..8e3245b27 100644
--- a/test/util/test_util.h
+++ b/test/util/test_util.h
@@ -220,6 +220,7 @@ constexpr char kKVM[] = "kvm";
bool IsRunningOnGvisor();
const std::string GvisorPlatform();
bool IsRunningWithHostinet();
+bool IsRunningWithVFS1();
#ifdef __linux__
void SetupGvisorDeathTest();
diff --git a/tools/go_generics/generics.go b/tools/go_generics/generics.go
index e9cc2c753..0860ca9db 100644
--- a/tools/go_generics/generics.go
+++ b/tools/go_generics/generics.go
@@ -223,7 +223,9 @@ func main() {
} else {
switch kind {
case globals.KindType, globals.KindVar, globals.KindConst, globals.KindFunction:
- ident.Name = *prefix + ident.Name + *suffix
+ if ident.Name != "_" {
+ ident.Name = *prefix + ident.Name + *suffix
+ }
case globals.KindTag:
// Modify the state tag appropriately.
if m := stateTagRegexp.FindStringSubmatch(ident.Name); m != nil {