diff options
38 files changed, 1489 insertions, 302 deletions
diff --git a/g3doc/user_guide/install.md b/g3doc/user_guide/install.md index 0de2b9932..9afdd264d 100644 --- a/g3doc/user_guide/install.md +++ b/g3doc/user_guide/install.md @@ -150,11 +150,8 @@ users, and ensure it is executable by all users**, since `runsc` executes itself as user `nobody` to avoid unnecessary privileges. The `/usr/local/bin` directory is a good place to put the `runsc` binary. -After installation, the`runsc` binary comes with an `install` command that can -optionally automatically configure Docker: - -```bash -runsc install -``` +After installation, try out `runsc` by following the +[Docker Quick Start](./quick_start/docker.md) or +[OCI Quick Start](./quick_start/oci.md). [releases]: https://github.com/google/gvisor/releases diff --git a/g3doc/user_guide/quick_start/docker.md b/g3doc/user_guide/quick_start/docker.md index fa8b9076b..6ad594ecc 100644 --- a/g3doc/user_guide/quick_start/docker.md +++ b/g3doc/user_guide/quick_start/docker.md @@ -14,24 +14,28 @@ the next section and proceed straight to running a container. ## Configuring Docker First you will need to configure Docker to use `runsc` by adding a runtime entry -to your Docker configuration (`/etc/docker/daemon.json`). You may have to create -this file if it does not exist. Also, some Docker versions also require you to -[specify the `storage-driver` field][storage-driver]. - -In the end, the file should look something like: - -```json -{ - "runtimes": { - "runsc": { - "path": "/usr/local/bin/runsc" - } - } -} +to your Docker configuration (e.g. `/etc/docker/daemon.json`). The easiest way +to this is via the `runsc install` command. This will install a docker runtime +named "runsc" by default. + +```bash +sudo runsc install +``` + +You may also wish to install a runtime entry for debugging. The `runsc install` +command can accept options that will be passed to the runtime when it is invoked +by Docker. + +```bash +sudo runsc install --runtime runsc-debug -- \ + --debug \ + --debug-log=/tmp/runsc-debug.log \ + --strace \ + --log-packets ``` -You must restart the Docker daemon after making changes to this file, typically -this is done via `systemd`: +You must restart the Docker daemon after installing the runtime. Typically this +is done via `systemd`: ```bash sudo systemctl restart docker diff --git a/pkg/goid/BUILD b/pkg/goid/BUILD index ea8d2422c..7a82631c5 100644 --- a/pkg/goid/BUILD +++ b/pkg/goid/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "goid.go", "goid_amd64.s", + "goid_arm64.s", "goid_race.go", "goid_unsafe.go", ], diff --git a/pkg/sentry/fsimpl/gofer/pagemath.go b/pkg/goid/goid_arm64.s index 847cb0784..a7465b75d 100644 --- a/pkg/sentry/fsimpl/gofer/pagemath.go +++ b/pkg/goid/goid_arm64.s @@ -1,4 +1,4 @@ -// Copyright 2019 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,20 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -package gofer +#include "textflag.h" -import ( - "gvisor.dev/gvisor/pkg/usermem" -) - -// This are equivalent to usermem.Addr.RoundDown/Up, but without the -// potentially truncating conversion to usermem.Addr. This is necessary because -// there is no way to define generic "PageRoundDown/Up" functions in Go. - -func pageRoundDown(x uint64) uint64 { - return x &^ (usermem.PageSize - 1) -} - -func pageRoundUp(x uint64) uint64 { - return pageRoundDown(x + usermem.PageSize - 1) -} +// func getg() *g +TEXT ·getg(SB),NOSPLIT,$0-8 + MOVD g, R0 // g + MOVD R0, ret+0(FP) + RET diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD index 41bf104d0..f84d03700 100644 --- a/pkg/linewriter/BUILD +++ b/pkg/linewriter/BUILD @@ -5,6 +5,8 @@ package(licenses = ["notice"]) go_library( name = "linewriter", srcs = ["linewriter.go"], + marshal = False, + stateify = False, visibility = ["//visibility:public"], deps = ["//pkg/sync"], ) diff --git a/pkg/log/BUILD b/pkg/log/BUILD index a7c8f7bef..3ed6aba5c 100644 --- a/pkg/log/BUILD +++ b/pkg/log/BUILD @@ -10,6 +10,8 @@ go_library( "json_k8s.go", "log.go", ], + marshal = False, + stateify = False, visibility = [ "//visibility:public", ], diff --git a/pkg/segment/BUILD b/pkg/segment/BUILD index 1b487b887..f57ccc170 100644 --- a/pkg/segment/BUILD +++ b/pkg/segment/BUILD @@ -21,6 +21,8 @@ go_template( ], opt_consts = [ "minDegree", + # trackGaps must either be 0 or 1. + "trackGaps", ], types = [ "Key", diff --git a/pkg/segment/set.go b/pkg/segment/set.go index 03e4f258f..1a17ad9cb 100644 --- a/pkg/segment/set.go +++ b/pkg/segment/set.go @@ -36,6 +36,34 @@ type Range interface{} // Value is a required type parameter. type Value interface{} +// trackGaps is an optional parameter. +// +// If trackGaps is 1, the Set will track maximum gap size recursively, +// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this +// case, Key must be an unsigned integer. +// +// trackGaps must be 0 or 1. +const trackGaps = 0 + +var _ = uint8(trackGaps << 7) // Will fail if not zero or one. + +// dynamicGap is a type that disappears if trackGaps is 0. +type dynamicGap [trackGaps]Key + +// Get returns the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *dynamicGap) Get() Key { + return d[:][0] +} + +// Set sets the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *dynamicGap) Set(v Key) { + d[:][0] = v +} + // Functions is a required type parameter that must be a struct implementing // the methods defined by Functions. type Functions interface { @@ -327,8 +355,12 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator { } if prev.Ok() && prev.End() == r.Start { if mval, ok := (Functions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + shrinkMaxGap := trackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() prev.SetEndUnchecked(r.End) prev.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } if next.Ok() && next.Start() == r.End { val = mval if mval, ok := (Functions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { @@ -342,11 +374,16 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator { } if next.Ok() && next.Start() == r.End { if mval, ok := (Functions{}).Merge(r, val, next.Range(), next.Value()); ok { + shrinkMaxGap := trackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() next.SetStartUnchecked(r.Start) next.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } return next } } + // InsertWithoutMergingUnchecked will maintain maxGap if necessary. return s.InsertWithoutMergingUnchecked(gap, r, val) } @@ -373,11 +410,15 @@ func (s *Set) InsertWithoutMerging(gap GapIterator, r Range, val Value) Iterator // Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). func (s *Set) InsertWithoutMergingUnchecked(gap GapIterator, r Range, val Value) Iterator { gap = gap.node.rebalanceBeforeInsert(gap) + splitMaxGap := trackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get()) copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) gap.node.keys[gap.index] = r gap.node.values[gap.index] = val gap.node.nrSegments++ + if splitMaxGap { + gap.node.updateMaxGapLeaf() + } return Iterator{gap.node, gap.index} } @@ -399,12 +440,23 @@ func (s *Set) Remove(seg Iterator) GapIterator { // overlap. seg.SetRangeUnchecked(victim.Range()) seg.SetValue(victim.Value()) + // Need to update the nextAdjacentNode's maxGap because the gap in between + // must have been modified by updating seg.Range() to victim.Range(). + // seg.NextSegment() must exist since the last segment can't be in a + // non-leaf node. + nextAdjacentNode := seg.NextSegment().node + if trackGaps != 0 { + nextAdjacentNode.updateMaxGapLeaf() + } return s.Remove(victim).NextGap() } copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) Functions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) seg.node.nrSegments-- + if trackGaps != 0 { + seg.node.updateMaxGapLeaf() + } return seg.node.rebalanceAfterRemove(GapIterator{seg.node, seg.index}) } @@ -455,6 +507,7 @@ func (s *Set) MergeUnchecked(first, second Iterator) Iterator { // overlaps second. first.SetEndUnchecked(second.End()) first.SetValue(mval) + // Remove will handle the maxGap update if necessary. return s.Remove(second).PrevSegment() } } @@ -631,6 +684,12 @@ type node struct { // than "isLeaf" because false must be the correct value for an empty root. hasChildren bool + // The longest gap within this node. If the node is a leaf, it's simply the + // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys + // including the 0th and nrSegments-th gap possibly shared with its upper-level + // nodes; if it's a non-leaf node, it's the max of all children's maxGap. + maxGap dynamicGap + // Nodes store keys and values in separate arrays to maximize locality in // the common case (scanning keys for lookup). keys [maxDegree - 1]Range @@ -676,12 +735,12 @@ func (n *node) nextSibling() *node { // required for insertion, and returns an updated iterator to the position // represented by gap. func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator { - if n.parent != nil { - gap = n.parent.rebalanceBeforeInsert(gap) - } if n.nrSegments < maxDegree-1 { return gap } + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } if n.parent == nil { // n is root. Move all segments before and after n's median segment // into new child nodes adjacent to the median segment, which is now @@ -719,6 +778,13 @@ func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator { n.hasChildren = true n.children[0] = left n.children[1] = right + // In this case, n's maxGap won't violated as it's still the root, + // but the left and right children should be updated locally as they + // are newly split from n. + if trackGaps != 0 { + left.updateMaxGapLocal() + right.updateMaxGapLocal() + } if gap.node != n { return gap } @@ -758,6 +824,12 @@ func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator { } } n.nrSegments = minDegree - 1 + // MaxGap of n's parent is not violated because the segments within is not changed. + // n and its sibling's maxGap need to be updated locally as they are two new nodes split from old n. + if trackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } // gap.node can't be n.parent because gaps are always in leaf nodes. if gap.node != n { return gap @@ -821,6 +893,12 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { } n.nrSegments++ sibling.nrSegments-- + // n's parent's maxGap does not need to be updated as its content is unmodified. + // n and its sibling must be updated with (new) maxGap because of the shift of keys. + if trackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } if gap.node == sibling && gap.index == sibling.nrSegments { return GapIterator{n, 0} } @@ -849,6 +927,12 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { } n.nrSegments++ sibling.nrSegments-- + // n's parent's maxGap does not need to be updated as its content is unmodified. + // n and its sibling must be updated with (new) maxGap because of the shift of keys. + if trackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } if gap.node == sibling { if gap.index == 0 { return GapIterator{n, n.nrSegments} @@ -886,6 +970,7 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { p.children[0] = nil p.children[1] = nil } + // No need to update maxGap of p as its content is not changed. if gap.node == left { return GapIterator{p, gap.index} } @@ -932,11 +1017,152 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { } p.children[p.nrSegments] = nil p.nrSegments-- + // Update maxGap of left locally, no need to change p and right because + // p's contents is not changed and right is already invalid. + if trackGaps != 0 { + left.updateMaxGapLocal() + } // This process robs p of one segment, so recurse into rebalancing p. n = p } } +// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no +// necessary update. +// +// Preconditions: n must be a leaf node, trackGaps must be 1. +func (n *node) updateMaxGapLeaf() { + if n.hasChildren { + panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n)) + } + max := n.calculateMaxGapLeaf() + if max == n.maxGap.Get() { + // If new max equals the old maxGap, no update is needed. + return + } + oldMax := n.maxGap.Get() + n.maxGap.Set(max) + if max > oldMax { + // Grow ancestor maxGaps. + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() >= max { + // p and its ancestors already contain an equal or larger gap. + break + } + // Only if new maxGap is larger than parent's + // old maxGap, propagate this update to parent. + p.maxGap.Set(max) + } + return + } + // Shrink ancestor maxGaps. + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() > oldMax { + // p and its ancestors still contain a larger gap. + break + } + // If new max is smaller than the old maxGap, and this gap used + // to be the maxGap of its parent, iterate parent's children + // and calculate parent's new maxGap.(It's probable that parent + // has two children with the old maxGap, but we need to check it anyway.) + parentNewMax := p.calculateMaxGapInternal() + if p.maxGap.Get() == parentNewMax { + // p and its ancestors still contain a gap of at least equal size. + break + } + // If p's new maxGap differs from the old one, propagate this update. + p.maxGap.Set(parentNewMax) + } +} + +// updateMaxGapLocal updates maxGap of the calling node solely with no +// propagation to ancestor nodes. +// +// Precondition: trackGaps must be 1. +func (n *node) updateMaxGapLocal() { + if !n.hasChildren { + // Leaf node iterates its gaps. + n.maxGap.Set(n.calculateMaxGapLeaf()) + } else { + // Non-leaf node iterates its children. + n.maxGap.Set(n.calculateMaxGapInternal()) + } +} + +// calculateMaxGapLeaf iterates the gaps within a leaf node and calculate the +// max. +// +// Preconditions: n must be a leaf node. +func (n *node) calculateMaxGapLeaf() Key { + max := GapIterator{n, 0}.Range().Length() + for i := 1; i <= n.nrSegments; i++ { + if current := (GapIterator{n, i}).Range().Length(); current > max { + max = current + } + } + return max +} + +// calculateMaxGapInternal iterates children's maxGap within an internal node n +// and calculate the max. +// +// Preconditions: n must be a non-leaf node. +func (n *node) calculateMaxGapInternal() Key { + max := n.children[0].maxGap.Get() + for i := 1; i <= n.nrSegments; i++ { + if current := n.children[i].maxGap.Get(); current > max { + max = current + } + } + return max +} + +// searchFirstLargeEnoughGap returns the first gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *node) searchFirstLargeEnoughGap(minSize Key) GapIterator { + if n.maxGap.Get() < minSize { + return GapIterator{} + } + if n.hasChildren { + for i := 0; i <= n.nrSegments; i++ { + if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := 0; i <= n.nrSegments; i++ { + currentGap := GapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// searchLastLargeEnoughGap returns the last gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *node) searchLastLargeEnoughGap(minSize Key) GapIterator { + if n.maxGap.Get() < minSize { + return GapIterator{} + } + if n.hasChildren { + for i := n.nrSegments; i >= 0; i-- { + if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := n.nrSegments; i >= 0; i-- { + currentGap := GapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + // A Iterator is conceptually one of: // // - A pointer to a segment in a set; or @@ -1243,6 +1469,122 @@ func (gap GapIterator) NextGap() GapIterator { return seg.NextGap() } +// NextLargeEnoughGap returns the iterated gap's first next gap with larger +// length than minSize. If not found, return a terminal gap iterator (does NOT +// include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap GapIterator) NextLargeEnoughGap(minSize Key) GapIterator { + if trackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments { + // If gap is the trailing gap of an non-leaf node, + // translate it to the equivalent gap on leaf level. + gap.node = gap.NextSegment().node + gap.index = 0 + return gap.nextLargeEnoughGapHelper(minSize) + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the trailing gap of a non-leaf node. +func (gap GapIterator) nextLargeEnoughGapHelper(minSize Key) GapIterator { + // Crawl up the tree if no large enough gap in current node or the + // current gap is the trailing one on leaf level. + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + // If no large enough gap throughout the whole set, return a terminal + // gap iterator. + if gap.node == nil { + return GapIterator{} + } + // Iterate subsequent gaps. + gap.index++ + for gap.index <= gap.node.nrSegments { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index++ + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == gap.node.nrSegments { + // If gap is the trailing gap of a non-leaf node, crawl up to + // parent again and do recursion. + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// PrevLargeEnoughGap returns the iterated gap's first prev gap with larger or +// equal length than minSize. If not found, return a terminal gap iterator +// (does NOT include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap GapIterator) PrevLargeEnoughGap(minSize Key) GapIterator { + if trackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == 0 { + // If gap is the first gap of an non-leaf node, + // translate it to the equivalent gap on leaf level. + gap.node = gap.PrevSegment().node + gap.index = gap.node.nrSegments + return gap.prevLargeEnoughGapHelper(minSize) + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the first gap of a non-leaf node. +func (gap GapIterator) prevLargeEnoughGapHelper(minSize Key) GapIterator { + // Crawl up the tree if no large enough gap in current node or the + // current gap is the first one on leaf level. + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + // If no large enough gap throughout the whole set, return a terminal + // gap iterator. + if gap.node == nil { + return GapIterator{} + } + // Iterate previous gaps. + gap.index-- + for gap.index >= 0 { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index-- + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == 0 { + // If gap is the first gap of a non-leaf node, crawl up to + // parent again and do recursion. + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.prevLargeEnoughGapHelper(minSize) +} + // segmentBeforePosition returns the predecessor segment of the position given // by n.children[i], which may or may not contain a child. If no such segment // exists, segmentBeforePosition returns a terminal iterator. @@ -1271,7 +1613,7 @@ func segmentAfterPosition(n *node, i int) Iterator { func zeroValueSlice(slice []Value) { // TODO(jamieliu): check if Go is actually smart enough to optimize a - // ClearValue that assigns nil to a memset here + // ClearValue that assigns nil to a memset here. for i := range slice { Functions{}.ClearValue(&slice[i]) } @@ -1310,7 +1652,15 @@ func (n *node) writeDebugString(buf *bytes.Buffer, prefix string) { child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) } buf.WriteString(prefix) - buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + if n.hasChildren { + if trackGaps != 0 { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get())) + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } } if child := n.children[n.nrSegments]; child != nil { child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) @@ -1362,3 +1712,43 @@ func (s *Set) ImportSortedSlices(sds *SegmentDataSlices) error { } return nil } + +// segmentTestCheck returns an error if s is incorrectly sorted, does not +// contain exactly expectedSegments segments, or contains a segment which +// fails the passed check. +// +// This should be used only for testing, and has been added to this package for +// templating convenience. +func (s *Set) segmentTestCheck(expectedSegments int, segFunc func(int, Range, Value) error) error { + havePrev := false + prev := Key(0) + nrSegments := 0 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + next := seg.Start() + if havePrev && prev >= next { + return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) + } + if segFunc != nil { + if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil { + return err + } + } + prev = next + havePrev = true + nrSegments++ + } + if nrSegments != expectedSegments { + return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments) + } + return nil +} + +// countSegments counts the number of segments in the set. +// +// Similar to Check, this should only be used for testing. +func (s *Set) countSegments() (segments int) { + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + segments++ + } + return segments +} diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD index f2d8462d8..131bf09b9 100644 --- a/pkg/segment/test/BUILD +++ b/pkg/segment/test/BUILD @@ -29,10 +29,28 @@ go_template_instance( }, ) +go_template_instance( + name = "gap_set", + out = "gap_set.go", + consts = { + "trackGaps": "1", + }, + package = "segment", + prefix = "gap", + template = "//pkg/segment:generic_set", + types = { + "Key": "int", + "Range": "Range", + "Value": "int", + "Functions": "gapSetFunctions", + }, +) + go_library( name = "segment", testonly = 1, srcs = [ + "gap_set.go", "int_range.go", "int_set.go", "set_functions.go", diff --git a/pkg/segment/test/segment_test.go b/pkg/segment/test/segment_test.go index 97b16c158..85fa19096 100644 --- a/pkg/segment/test/segment_test.go +++ b/pkg/segment/test/segment_test.go @@ -17,6 +17,7 @@ package segment import ( "fmt" "math/rand" + "reflect" "testing" ) @@ -32,61 +33,65 @@ const ( // valueOffset is the difference between the value and start of test // segments. valueOffset = 100000 + + // intervalLength is the interval used by random gap tests. + intervalLength = 10 ) func shuffle(xs []int) { - for i := range xs { - j := rand.Intn(i + 1) - xs[i], xs[j] = xs[j], xs[i] - } + rand.Shuffle(len(xs), func(i, j int) { xs[i], xs[j] = xs[j], xs[i] }) } -func randPermutation(size int) []int { +func randIntervalPermutation(size int) []int { p := make([]int, size) for i := range p { - p[i] = i + p[i] = intervalLength * i } shuffle(p) return p } -// checkSet returns an error if s is incorrectly sorted, does not contain -// exactly expectedSegments segments, or contains a segment for which val != -// key + valueOffset. -func checkSet(s *Set, expectedSegments int) error { - havePrev := false - prev := 0 - nrSegments := 0 - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - next := seg.Start() - if havePrev && prev >= next { - return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) - } - if got, want := seg.Value(), seg.Start()+valueOffset; got != want { - return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nrSegments, seg.Start(), got, want) - } - prev = next - havePrev = true - nrSegments++ - } - if nrSegments != expectedSegments { - return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments) +// validate can be passed to Check. +func validate(nr int, r Range, v int) error { + if got, want := v, r.Start+valueOffset; got != want { + return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nr, r.Start, got, want) } return nil } -// countSegmentsIn returns the number of segments in s. -func countSegmentsIn(s *Set) int { - var count int - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - count++ +// checkSetMaxGap returns an error if maxGap inside all nodes of s is not well +// maintained. +func checkSetMaxGap(s *gapSet) error { + n := s.root + return checkNodeMaxGap(&n) +} + +// checkNodeMaxGap returns an error if maxGap inside the subtree rooted by n is +// not well maintained. +func checkNodeMaxGap(n *gapnode) error { + var max int + if !n.hasChildren { + max = n.calculateMaxGapLeaf() + } else { + for i := 0; i <= n.nrSegments; i++ { + child := n.children[i] + if err := checkNodeMaxGap(child); err != nil { + return err + } + if temp := child.maxGap.Get(); i == 0 || temp > max { + max = temp + } + } + } + if max != n.maxGap.Get() { + return fmt.Errorf("maxGap wrong in node\n%vexpected: %d got: %d", n, max, n.maxGap) } - return count + return nil } func TestAddRandom(t *testing.T) { var s Set - order := randPermutation(testSize) + order := rand.Perm(testSize) var nrInsertions int for i, j := range order { if !s.AddWithoutMerging(Range{j, j + 1}, j+valueOffset) { @@ -94,12 +99,12 @@ func TestAddRandom(t *testing.T) { break } nrInsertions++ - if err := checkSet(&s, nrInsertions); err != nil { + if err := s.segmentTestCheck(nrInsertions, validate); err != nil { t.Errorf("Iteration %d: %v", i, err) break } } - if got, want := countSegmentsIn(&s), nrInsertions; got != want { + if got, want := s.countSegments(), nrInsertions; got != want { t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) } if t.Failed() { @@ -115,7 +120,156 @@ func TestRemoveRandom(t *testing.T) { t.Fatalf("Failed to insert segment %d", i) } } - order := randPermutation(testSize) + order := rand.Perm(testSize) + var nrRemovals int + for i, j := range order { + seg := s.FindSegment(j) + if !seg.Ok() { + t.Errorf("Iteration %d: failed to find segment with key %d", i, j) + break + } + s.Remove(seg) + nrRemovals++ + if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil { + t.Errorf("Iteration %d: %v", i, err) + break + } + } + if got, want := s.countSegments(), testSize-nrRemovals; got != want { + t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) + } + if t.Failed() { + t.Logf("Removal order: %v", order[:nrRemovals]) + t.Logf("Set contents:\n%v", &s) + t.FailNow() + } +} + +func TestMaxGapAddRandom(t *testing.T) { + var s gapSet + order := rand.Perm(testSize) + var nrInsertions int + for i, j := range order { + if !s.AddWithoutMerging(Range{j, j + 1}, j+valueOffset) { + t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) + break + } + nrInsertions++ + if err := s.segmentTestCheck(nrInsertions, validate); err != nil { + t.Errorf("Iteration %d: %v", i, err) + break + } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When inserting %d: %v", j, err) + break + } + } + if got, want := s.countSegments(), nrInsertions; got != want { + t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) + } + if t.Failed() { + t.Logf("Insertion order: %v", order[:nrInsertions]) + t.Logf("Set contents:\n%v", &s) + } +} + +func TestMaxGapAddRandomWithRandomInterval(t *testing.T) { + var s gapSet + order := randIntervalPermutation(testSize) + var nrInsertions int + for i, j := range order { + if !s.AddWithoutMerging(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) { + t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) + break + } + nrInsertions++ + if err := s.segmentTestCheck(nrInsertions, validate); err != nil { + t.Errorf("Iteration %d: %v", i, err) + break + } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When inserting %d: %v", j, err) + break + } + } + if got, want := s.countSegments(), nrInsertions; got != want { + t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) + } + if t.Failed() { + t.Logf("Insertion order: %v", order[:nrInsertions]) + t.Logf("Set contents:\n%v", &s) + } +} + +func TestMaxGapAddRandomWithMerge(t *testing.T) { + var s gapSet + order := randIntervalPermutation(testSize) + nrInsertions := 1 + for i, j := range order { + if !s.Add(Range{j, j + intervalLength}, j+valueOffset) { + t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) + break + } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When inserting %d: %v", j, err) + break + } + } + if got, want := s.countSegments(), nrInsertions; got != want { + t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) + } + if t.Failed() { + t.Logf("Insertion order: %v", order) + t.Logf("Set contents:\n%v", &s) + } +} + +func TestMaxGapRemoveRandom(t *testing.T) { + var s gapSet + for i := 0; i < testSize; i++ { + if !s.AddWithoutMerging(Range{i, i + 1}, i+valueOffset) { + t.Fatalf("Failed to insert segment %d", i) + } + } + order := rand.Perm(testSize) + var nrRemovals int + for i, j := range order { + seg := s.FindSegment(j) + if !seg.Ok() { + t.Errorf("Iteration %d: failed to find segment with key %d", i, j) + break + } + temprange := seg.Range() + s.Remove(seg) + nrRemovals++ + if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil { + t.Errorf("Iteration %d: %v", i, err) + break + } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When removing %v: %v", temprange, err) + break + } + } + if got, want := s.countSegments(), testSize-nrRemovals; got != want { + t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) + } + if t.Failed() { + t.Logf("Removal order: %v", order[:nrRemovals]) + t.Logf("Set contents:\n%v", &s) + t.FailNow() + } +} + +func TestMaxGapRemoveHalfRandom(t *testing.T) { + var s gapSet + for i := 0; i < testSize; i++ { + if !s.AddWithoutMerging(Range{intervalLength * i, intervalLength*i + rand.Intn(intervalLength-1) + 1}, intervalLength*i+valueOffset) { + t.Fatalf("Failed to insert segment %d", i) + } + } + order := randIntervalPermutation(testSize) + order = order[:testSize/2] var nrRemovals int for i, j := range order { seg := s.FindSegment(j) @@ -123,14 +277,19 @@ func TestRemoveRandom(t *testing.T) { t.Errorf("Iteration %d: failed to find segment with key %d", i, j) break } + temprange := seg.Range() s.Remove(seg) nrRemovals++ - if err := checkSet(&s, testSize-nrRemovals); err != nil { + if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil { t.Errorf("Iteration %d: %v", i, err) break } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When removing %v: %v", temprange, err) + break + } } - if got, want := countSegmentsIn(&s), testSize-nrRemovals; got != want { + if got, want := s.countSegments(), testSize-nrRemovals; got != want { t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) } if t.Failed() { @@ -140,6 +299,148 @@ func TestRemoveRandom(t *testing.T) { } } +func TestMaxGapAddRandomRemoveRandomHalfWithMerge(t *testing.T) { + var s gapSet + order := randIntervalPermutation(testSize * 2) + order = order[:testSize] + for i, j := range order { + if !s.Add(Range{j, j + intervalLength}, j+valueOffset) { + t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) + break + } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When inserting %d: %v", j, err) + break + } + } + shuffle(order) + var nrRemovals int + for _, j := range order { + seg := s.FindSegment(j) + if !seg.Ok() { + continue + } + temprange := seg.Range() + s.Remove(seg) + nrRemovals++ + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When removing %v: %v", temprange, err) + break + } + } + if t.Failed() { + t.Logf("Removal order: %v", order[:nrRemovals]) + t.Logf("Set contents:\n%v", &s) + t.FailNow() + } +} + +func TestNextLargeEnoughGap(t *testing.T) { + var s gapSet + order := randIntervalPermutation(testSize * 2) + order = order[:testSize] + for i, j := range order { + if !s.Add(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) { + t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) + break + } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When inserting %d: %v", j, err) + break + } + } + shuffle(order) + order = order[:testSize/2] + for _, j := range order { + seg := s.FindSegment(j) + if !seg.Ok() { + continue + } + temprange := seg.Range() + s.Remove(seg) + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When removing %v: %v", temprange, err) + break + } + } + minSize := 7 + var gapArr1 []int + for gap := s.LowerBoundGap(0).NextLargeEnoughGap(minSize); gap.Ok(); gap = gap.NextLargeEnoughGap(minSize) { + if gap.Range().Length() < minSize { + t.Errorf("NextLargeEnoughGap wrong, gap %v has length %d, wanted %d", gap.Range(), gap.Range().Length(), minSize) + } else { + gapArr1 = append(gapArr1, gap.Range().Start) + } + } + var gapArr2 []int + for gap := s.LowerBoundGap(0).NextGap(); gap.Ok(); gap = gap.NextGap() { + if gap.Range().Length() >= minSize { + gapArr2 = append(gapArr2, gap.Range().Start) + } + } + + if !reflect.DeepEqual(gapArr2, gapArr1) { + t.Errorf("Search result not correct, got: %v, wanted: %v", gapArr1, gapArr2) + } + if t.Failed() { + t.Logf("Set contents:\n%v", &s) + t.FailNow() + } +} + +func TestPrevLargeEnoughGap(t *testing.T) { + var s gapSet + order := randIntervalPermutation(testSize * 2) + order = order[:testSize] + for i, j := range order { + if !s.Add(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) { + t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) + break + } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When inserting %d: %v", j, err) + break + } + } + end := s.LastSegment().End() + shuffle(order) + order = order[:testSize/2] + for _, j := range order { + seg := s.FindSegment(j) + if !seg.Ok() { + continue + } + temprange := seg.Range() + s.Remove(seg) + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When removing %v: %v", temprange, err) + break + } + } + minSize := 7 + var gapArr1 []int + for gap := s.UpperBoundGap(end + intervalLength).PrevLargeEnoughGap(minSize); gap.Ok(); gap = gap.PrevLargeEnoughGap(minSize) { + if gap.Range().Length() < minSize { + t.Errorf("PrevLargeEnoughGap wrong, gap length %d, wanted %d", gap.Range().Length(), minSize) + } else { + gapArr1 = append(gapArr1, gap.Range().Start) + } + } + var gapArr2 []int + for gap := s.UpperBoundGap(end + intervalLength).PrevGap(); gap.Ok(); gap = gap.PrevGap() { + if gap.Range().Length() >= minSize { + gapArr2 = append(gapArr2, gap.Range().Start) + } + } + if !reflect.DeepEqual(gapArr2, gapArr1) { + t.Errorf("Search result not correct, got: %v, wanted: %v", gapArr1, gapArr2) + } + if t.Failed() { + t.Logf("Set contents:\n%v", &s) + t.FailNow() + } +} + func TestAddSequentialAdjacent(t *testing.T) { var s Set var nrInsertions int @@ -148,12 +449,12 @@ func TestAddSequentialAdjacent(t *testing.T) { t.Fatalf("Failed to insert segment %d", i) } nrInsertions++ - if err := checkSet(&s, nrInsertions); err != nil { + if err := s.segmentTestCheck(nrInsertions, validate); err != nil { t.Errorf("Iteration %d: %v", i, err) break } } - if got, want := countSegmentsIn(&s), nrInsertions; got != want { + if got, want := s.countSegments(), nrInsertions; got != want { t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) } if t.Failed() { @@ -202,12 +503,12 @@ func TestAddSequentialNonAdjacent(t *testing.T) { t.Fatalf("Failed to insert segment %d", i) } nrInsertions++ - if err := checkSet(&s, nrInsertions); err != nil { + if err := s.segmentTestCheck(nrInsertions, validate); err != nil { t.Errorf("Iteration %d: %v", i, err) break } } - if got, want := countSegmentsIn(&s), nrInsertions; got != want { + if got, want := s.countSegments(), nrInsertions; got != want { t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) } if t.Failed() { @@ -293,7 +594,7 @@ Tests: var i int for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { if i > len(test.final) { - t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, countSegmentsIn(&s), len(test.final), &s) + t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, s.countSegments(), len(test.final), &s) continue Tests } if got, want := seg.Range(), test.final[i]; got != want { @@ -351,7 +652,7 @@ Tests: var i int for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { if i > len(test.final) { - t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, countSegmentsIn(&s), len(test.final), &s) + t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, s.countSegments(), len(test.final), &s) continue Tests } if got, want := seg.Range(), test.final[i]; got != want { @@ -378,7 +679,7 @@ func benchmarkAddSequential(b *testing.B, size int) { } func benchmarkAddRandom(b *testing.B, size int) { - order := randPermutation(size) + order := rand.Perm(size) b.ResetTimer() for n := 0; n < b.N; n++ { @@ -416,7 +717,7 @@ func benchmarkFindRandom(b *testing.B, size int) { b.Fatalf("Failed to insert segment %d", i) } } - order := randPermutation(size) + order := rand.Perm(size) b.ResetTimer() for n := 0; n < b.N; n++ { @@ -470,7 +771,7 @@ func benchmarkAddFindRemoveSequential(b *testing.B, size int) { } func benchmarkAddFindRemoveRandom(b *testing.B, size int) { - order := randPermutation(size) + order := rand.Perm(size) b.ResetTimer() for n := 0; n < b.N; n++ { diff --git a/pkg/segment/test/set_functions.go b/pkg/segment/test/set_functions.go index bcddb39bb..7cd895cc7 100644 --- a/pkg/segment/test/set_functions.go +++ b/pkg/segment/test/set_functions.go @@ -14,21 +14,16 @@ package segment -// Basic numeric constants that we define because the math package doesn't. -// TODO(nlacasse): These should be Math.MaxInt64/MinInt64? -const ( - maxInt = int(^uint(0) >> 1) - minInt = -maxInt - 1 -) - type setFunctions struct{} -func (setFunctions) MinKey() int { - return minInt +// MinKey returns the minimum key for the set. +func (s setFunctions) MinKey() int { + return -s.MaxKey() - 1 } +// MaxKey returns the maximum key for the set. func (setFunctions) MaxKey() int { - return maxInt + return int(^uint(0) >> 1) } func (setFunctions) ClearValue(*int) {} @@ -40,3 +35,20 @@ func (setFunctions) Merge(_ Range, val1 int, _ Range, _ int) (int, bool) { func (setFunctions) Split(_ Range, val int, _ int) (int, int) { return val, val } + +type gapSetFunctions struct { + setFunctions +} + +// MinKey is adjusted to make sure no add overflow would happen in test cases. +// e.g. A gap with range {MinInt32, 2} would cause overflow in Range().Length(). +// +// Normally Keys should be unsigned to avoid these issues. +func (s gapSetFunctions) MinKey() int { + return s.setFunctions.MinKey() / 2 +} + +// MaxKey returns the maximum key for the set. +func (s gapSetFunctions) MaxKey() int { + return s.setFunctions.MaxKey() / 2 +} diff --git a/pkg/sentry/fs/fsutil/frame_ref_set.go b/pkg/sentry/fs/fsutil/frame_ref_set.go index 6564fd0c6..dd6f5aba6 100644 --- a/pkg/sentry/fs/fsutil/frame_ref_set.go +++ b/pkg/sentry/fs/fsutil/frame_ref_set.go @@ -18,6 +18,7 @@ import ( "math" "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/usage" ) // FrameRefSetFunctions implements segment.Functions for FrameRefSet. @@ -49,3 +50,42 @@ func (FrameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform. func (FrameRefSetFunctions) Split(_ platform.FileRange, val uint64, _ uint64) (uint64, uint64) { return val, val } + +// IncRefAndAccount adds a reference on the range fr. All newly inserted segments +// are accounted as host page cache memory mappings. +func (refs *FrameRefSet) IncRefAndAccount(fr platform.FileRange) { + seg, gap := refs.Find(fr.Start) + for { + switch { + case seg.Ok() && seg.Start() < fr.End: + seg = refs.Isolate(seg, fr) + seg.SetValue(seg.Value() + 1) + seg, gap = seg.NextNonEmpty() + case gap.Ok() && gap.Start() < fr.End: + newRange := gap.Range().Intersect(fr) + usage.MemoryAccounting.Inc(newRange.Length(), usage.Mapped) + seg, gap = refs.InsertWithoutMerging(gap, newRange, 1).NextNonEmpty() + default: + refs.MergeAdjacent(fr) + return + } + } +} + +// DecRefAndAccount removes a reference on the range fr and untracks segments +// that are removed from memory accounting. +func (refs *FrameRefSet) DecRefAndAccount(fr platform.FileRange) { + seg := refs.FindSegment(fr.Start) + + for seg.Ok() && seg.Start() < fr.End { + seg = refs.Isolate(seg, fr) + if old := seg.Value(); old == 1 { + usage.MemoryAccounting.Dec(seg.Range().Length(), usage.Mapped) + seg = refs.Remove(seg).NextSegment() + } else { + seg.SetValue(old - 1) + seg = seg.NextSegment() + } + } + refs.MergeAdjacent(fr) +} diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index 5ce82b793..67e916525 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -36,7 +36,6 @@ go_library( "gofer.go", "handle.go", "p9file.go", - "pagemath.go", "regular_file.go", "socket.go", "special_file.go", diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index ebf063a58..6295f6b54 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -928,8 +928,8 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin // so we can't race with Write or another truncate.) d.dataMu.Unlock() if d.size < oldSize { - oldpgend := pageRoundUp(oldSize) - newpgend := pageRoundUp(d.size) + oldpgend, _ := usermem.PageRoundUp(oldSize) + newpgend, _ := usermem.PageRoundUp(d.size) if oldpgend != newpgend { d.mapsMu.Lock() d.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{ diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index 857f7c74e..0d10cf7ac 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -148,9 +148,9 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off return 0, err } // Remove touched pages from the cache. - pgstart := pageRoundDown(uint64(offset)) - pgend := pageRoundUp(uint64(offset + src.NumBytes())) - if pgend < pgstart { + pgstart := usermem.PageRoundDown(uint64(offset)) + pgend, ok := usermem.PageRoundUp(uint64(offset + src.NumBytes())) + if !ok { return 0, syserror.EINVAL } mr := memmap.MappableRange{pgstart, pgend} @@ -306,9 +306,10 @@ func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) if fillCache { // Read into the cache, then re-enter the loop to read from the // cache. + gapEnd, _ := usermem.PageRoundUp(gapMR.End) reqMR := memmap.MappableRange{ - Start: pageRoundDown(gapMR.Start), - End: pageRoundUp(gapMR.End), + Start: usermem.PageRoundDown(gapMR.Start), + End: gapEnd, } optMR := gap.Range() err := rw.d.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), mf, usage.PageCache, rw.d.handle.readToBlocksAt) @@ -671,7 +672,7 @@ func (d *dentry) Translate(ctx context.Context, required, optional memmap.Mappab // Constrain translations to d.size (rounded up) to prevent translation to // pages that may be concurrently truncated. - pgend := pageRoundUp(d.size) + pgend, _ := usermem.PageRoundUp(d.size) var beyondEOF bool if required.End > pgend { if required.Start >= pgend { @@ -818,43 +819,15 @@ type dentryPlatformFile struct { // IncRef implements platform.File.IncRef. func (d *dentryPlatformFile) IncRef(fr platform.FileRange) { d.dataMu.Lock() - seg, gap := d.fdRefs.Find(fr.Start) - for { - switch { - case seg.Ok() && seg.Start() < fr.End: - seg = d.fdRefs.Isolate(seg, fr) - seg.SetValue(seg.Value() + 1) - seg, gap = seg.NextNonEmpty() - case gap.Ok() && gap.Start() < fr.End: - newRange := gap.Range().Intersect(fr) - usage.MemoryAccounting.Inc(newRange.Length(), usage.Mapped) - seg, gap = d.fdRefs.InsertWithoutMerging(gap, newRange, 1).NextNonEmpty() - default: - d.fdRefs.MergeAdjacent(fr) - d.dataMu.Unlock() - return - } - } + d.fdRefs.IncRefAndAccount(fr) + d.dataMu.Unlock() } // DecRef implements platform.File.DecRef. func (d *dentryPlatformFile) DecRef(fr platform.FileRange) { d.dataMu.Lock() - seg := d.fdRefs.FindSegment(fr.Start) - - for seg.Ok() && seg.Start() < fr.End { - seg = d.fdRefs.Isolate(seg, fr) - if old := seg.Value(); old == 1 { - usage.MemoryAccounting.Dec(seg.Range().Length(), usage.Mapped) - seg = d.fdRefs.Remove(seg).NextSegment() - } else { - seg.SetValue(old - 1) - seg = seg.NextSegment() - } - } - d.fdRefs.MergeAdjacent(fr) + d.fdRefs.DecRefAndAccount(fr) d.dataMu.Unlock() - } // MapInternal implements platform.File.MapInternal. diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index 39509f703..ca0fe6d2b 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -8,6 +8,7 @@ go_library( "control.go", "host.go", "ioctl_unsafe.go", + "mmap.go", "socket.go", "socket_iovec.go", "socket_unsafe.go", @@ -23,12 +24,15 @@ go_library( "//pkg/fspath", "//pkg/log", "//pkg/refs", + "//pkg/safemem", "//pkg/sentry/arch", + "//pkg/sentry/fs/fsutil", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/hostfd", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", + "//pkg/sentry/platform", "//pkg/sentry/socket/control", "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 8caf55a1b..18b127521 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -86,15 +86,13 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions) i := &inode{ hostFD: hostFD, - seekable: seekable, + ino: fs.NextIno(), isTTY: opts.IsTTY, - canMap: canMap(uint32(fileType)), wouldBlock: wouldBlock(uint32(fileType)), - ino: fs.NextIno(), - // For simplicity, set offset to 0. Technically, we should use the existing - // offset on the host if the file is seekable. - offset: 0, + seekable: seekable, + canMap: canMap(uint32(fileType)), } + i.pf.inode = i // Non-seekable files can't be memory mapped, assert this. if !i.seekable && i.canMap { @@ -117,6 +115,10 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions) // i.open will take a reference on d. defer d.DecRef() + + // For simplicity, fileDescription.offset is set to 0. Technically, we + // should only set to 0 on files that are not seekable (sockets, pipes, + // etc.), and use the offset from the host fd otherwise when importing. return i.open(ctx, d.VFSDentry(), mnt, flags) } @@ -189,11 +191,15 @@ type inode struct { // This field is initialized at creation time and is immutable. hostFD int - // wouldBlock is true if the host FD would return EWOULDBLOCK for - // operations that would block. + // ino is an inode number unique within this filesystem. // // This field is initialized at creation time and is immutable. - wouldBlock bool + ino uint64 + + // isTTY is true if this file represents a TTY. + // + // This field is initialized at creation time and is immutable. + isTTY bool // seekable is false if the host fd points to a file representing a stream, // e.g. a socket or a pipe. Such files are not seekable and can return @@ -202,29 +208,29 @@ type inode struct { // This field is initialized at creation time and is immutable. seekable bool - // isTTY is true if this file represents a TTY. + // wouldBlock is true if the host FD would return EWOULDBLOCK for + // operations that would block. // // This field is initialized at creation time and is immutable. - isTTY bool + wouldBlock bool + + // Event queue for blocking operations. + queue waiter.Queue // canMap specifies whether we allow the file to be memory mapped. // // This field is initialized at creation time and is immutable. canMap bool - // ino is an inode number unique within this filesystem. - // - // This field is initialized at creation time and is immutable. - ino uint64 + // mapsMu protects mappings. + mapsMu sync.Mutex - // offsetMu protects offset. - offsetMu sync.Mutex - - // offset specifies the current file offset. - offset int64 + // If canMap is true, mappings tracks mappings of hostFD into + // memmap.MappingSpaces. + mappings memmap.MappingSet - // Event queue for blocking operations. - queue waiter.Queue + // pf implements platform.File for mappings of hostFD. + pf inodePlatformFile } // CheckPermissions implements kernfs.Inode. @@ -388,6 +394,21 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre if err := syscall.Ftruncate(i.hostFD, int64(s.Size)); err != nil { return err } + oldSize := uint64(hostStat.Size) + if s.Size < oldSize { + oldpgend, _ := usermem.PageRoundUp(oldSize) + newpgend, _ := usermem.PageRoundUp(s.Size) + if oldpgend != newpgend { + i.mapsMu.Lock() + i.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{ + // Compare Linux's mm/truncate.c:truncate_setsize() => + // truncate_pagecache() => + // mm/memory.c:unmap_mapping_range(evencows=1). + InvalidatePrivate: true, + }) + i.mapsMu.Unlock() + } + } } if m&(linux.STATX_ATIME|linux.STATX_MTIME) != 0 { ts := [2]syscall.Timespec{ @@ -464,9 +485,6 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u return vfsfd, nil } - // For simplicity, set offset to 0. Technically, we should - // only set to 0 on files that are not seekable (sockets, pipes, etc.), - // and use the offset from the host fd otherwise. fd := &fileDescription{inode: i} vfsfd := &fd.vfsfd if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil { @@ -487,6 +505,13 @@ type fileDescription struct { // // inode is immutable after fileDescription creation. inode *inode + + // offsetMu protects offset. + offsetMu sync.Mutex + + // offset specifies the current file offset. It is only meaningful when + // inode.seekable is true. + offset int64 } // SetStat implements vfs.FileDescriptionImpl. @@ -532,10 +557,10 @@ func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts return n, err } // TODO(gvisor.dev/issue/1672): Cache pages, when forced to do so. - i.offsetMu.Lock() - n, err := readFromHostFD(ctx, i.hostFD, dst, i.offset, opts.Flags) - i.offset += n - i.offsetMu.Unlock() + f.offsetMu.Lock() + n, err := readFromHostFD(ctx, i.hostFD, dst, f.offset, opts.Flags) + f.offset += n + f.offsetMu.Unlock() return n, err } @@ -572,10 +597,10 @@ func (f *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opt } // TODO(gvisor.dev/issue/1672): Cache pages, when forced to do so. // TODO(gvisor.dev/issue/1672): Write to end of file and update offset if O_APPEND is set on this file. - i.offsetMu.Lock() - n, err := writeToHostFD(ctx, i.hostFD, src, i.offset, opts.Flags) - i.offset += n - i.offsetMu.Unlock() + f.offsetMu.Lock() + n, err := writeToHostFD(ctx, i.hostFD, src, f.offset, opts.Flags) + f.offset += n + f.offsetMu.Unlock() return n, err } @@ -600,41 +625,41 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i return 0, syserror.ESPIPE } - i.offsetMu.Lock() - defer i.offsetMu.Unlock() + f.offsetMu.Lock() + defer f.offsetMu.Unlock() switch whence { case linux.SEEK_SET: if offset < 0 { - return i.offset, syserror.EINVAL + return f.offset, syserror.EINVAL } - i.offset = offset + f.offset = offset case linux.SEEK_CUR: - // Check for overflow. Note that underflow cannot occur, since i.offset >= 0. - if offset > math.MaxInt64-i.offset { - return i.offset, syserror.EOVERFLOW + // Check for overflow. Note that underflow cannot occur, since f.offset >= 0. + if offset > math.MaxInt64-f.offset { + return f.offset, syserror.EOVERFLOW } - if i.offset+offset < 0 { - return i.offset, syserror.EINVAL + if f.offset+offset < 0 { + return f.offset, syserror.EINVAL } - i.offset += offset + f.offset += offset case linux.SEEK_END: var s syscall.Stat_t if err := syscall.Fstat(i.hostFD, &s); err != nil { - return i.offset, err + return f.offset, err } size := s.Size // Check for overflow. Note that underflow cannot occur, since size >= 0. if offset > math.MaxInt64-size { - return i.offset, syserror.EOVERFLOW + return f.offset, syserror.EOVERFLOW } if size+offset < 0 { - return i.offset, syserror.EINVAL + return f.offset, syserror.EINVAL } - i.offset = size + offset + f.offset = size + offset case linux.SEEK_DATA, linux.SEEK_HOLE: // Modifying the offset in the host file table should not matter, since @@ -643,16 +668,16 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i // For reading and writing, we always rely on our internal offset. n, err := unix.Seek(i.hostFD, offset, int(whence)) if err != nil { - return i.offset, err + return f.offset, err } - i.offset = n + f.offset = n default: // Invalid whence. - return i.offset, syserror.EINVAL + return f.offset, syserror.EINVAL } - return i.offset, nil + return f.offset, nil } // Sync implements FileDescriptionImpl. @@ -666,8 +691,9 @@ func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts if !f.inode.canMap { return syserror.ENODEV } - // TODO(gvisor.dev/issue/1672): Implement ConfigureMMap and Mappable interface. - return syserror.ENODEV + i := f.inode + i.pf.fileMapperInitOnce.Do(i.pf.fileMapper.Init) + return vfs.GenericConfigureMMap(&f.vfsfd, i, opts) } // EventRegister implements waiter.Waitable.EventRegister. diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/host/mmap.go new file mode 100644 index 000000000..8545a82f0 --- /dev/null +++ b/pkg/sentry/fsimpl/host/mmap.go @@ -0,0 +1,132 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package host + +import ( + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/memmap" + "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/usermem" +) + +// inodePlatformFile implements platform.File. It exists solely because inode +// cannot implement both kernfs.Inode.IncRef and platform.File.IncRef. +// +// inodePlatformFile should only be used if inode.canMap is true. +type inodePlatformFile struct { + *inode + + // fdRefsMu protects fdRefs. + fdRefsMu sync.Mutex + + // fdRefs counts references on platform.File offsets. It is used solely for + // memory accounting. + fdRefs fsutil.FrameRefSet + + // fileMapper caches mappings of the host file represented by this inode. + fileMapper fsutil.HostFileMapper + + // fileMapperInitOnce is used to lazily initialize fileMapper. + fileMapperInitOnce sync.Once +} + +// IncRef implements platform.File.IncRef. +// +// Precondition: i.inode.canMap must be true. +func (i *inodePlatformFile) IncRef(fr platform.FileRange) { + i.fdRefsMu.Lock() + i.fdRefs.IncRefAndAccount(fr) + i.fdRefsMu.Unlock() +} + +// DecRef implements platform.File.DecRef. +// +// Precondition: i.inode.canMap must be true. +func (i *inodePlatformFile) DecRef(fr platform.FileRange) { + i.fdRefsMu.Lock() + i.fdRefs.DecRefAndAccount(fr) + i.fdRefsMu.Unlock() +} + +// MapInternal implements platform.File.MapInternal. +// +// Precondition: i.inode.canMap must be true. +func (i *inodePlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { + return i.fileMapper.MapInternal(fr, i.hostFD, at.Write) +} + +// FD implements platform.File.FD. +func (i *inodePlatformFile) FD() int { + return i.hostFD +} + +// AddMapping implements memmap.Mappable.AddMapping. +// +// Precondition: i.inode.canMap must be true. +func (i *inode) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { + i.mapsMu.Lock() + mapped := i.mappings.AddMapping(ms, ar, offset, writable) + for _, r := range mapped { + i.pf.fileMapper.IncRefOn(r) + } + i.mapsMu.Unlock() + return nil +} + +// RemoveMapping implements memmap.Mappable.RemoveMapping. +// +// Precondition: i.inode.canMap must be true. +func (i *inode) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { + i.mapsMu.Lock() + unmapped := i.mappings.RemoveMapping(ms, ar, offset, writable) + for _, r := range unmapped { + i.pf.fileMapper.DecRefOn(r) + } + i.mapsMu.Unlock() +} + +// CopyMapping implements memmap.Mappable.CopyMapping. +// +// Precondition: i.inode.canMap must be true. +func (i *inode) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { + return i.AddMapping(ctx, ms, dstAR, offset, writable) +} + +// Translate implements memmap.Mappable.Translate. +// +// Precondition: i.inode.canMap must be true. +func (i *inode) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { + mr := optional + return []memmap.Translation{ + { + Source: mr, + File: &i.pf, + Offset: mr.Start, + Perms: usermem.AnyAccess, + }, + }, nil +} + +// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. +// +// Precondition: i.inode.canMap must be true. +func (i *inode) InvalidateUnsavable(ctx context.Context) error { + // We expect the same host fd across save/restore, so all translations + // should be valid. + return nil +} diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index 73591dab7..a036ce53c 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -25,6 +25,7 @@ go_template_instance( out = "vma_set.go", consts = { "minDegree": "8", + "trackGaps": "1", }, imports = { "usermem": "gvisor.dev/gvisor/pkg/usermem", diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go index 9a14e69e6..16d8207e9 100644 --- a/pkg/sentry/mm/vma.go +++ b/pkg/sentry/mm/vma.go @@ -195,7 +195,7 @@ func (mm *MemoryManager) applicationAddrRange() usermem.AddrRange { // Preconditions: mm.mappingMu must be locked. func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bounds usermem.AddrRange) (usermem.Addr, error) { - for gap := mm.vmas.LowerBoundGap(bounds.Start); gap.Ok() && gap.Start() < bounds.End; gap = gap.NextGap() { + for gap := mm.vmas.LowerBoundGap(bounds.Start); gap.Ok() && gap.Start() < bounds.End; gap = gap.NextLargeEnoughGap(usermem.Addr(length)) { if gr := gap.availableRange().Intersect(bounds); uint64(gr.Length()) >= length { // Can we shift up to match the alignment? if offset := uint64(gr.Start) % alignment; offset != 0 { @@ -214,7 +214,7 @@ func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bou // Preconditions: mm.mappingMu must be locked. func (mm *MemoryManager) findHighestAvailableLocked(length, alignment uint64, bounds usermem.AddrRange) (usermem.Addr, error) { - for gap := mm.vmas.UpperBoundGap(bounds.End); gap.Ok() && gap.End() > bounds.Start; gap = gap.PrevGap() { + for gap := mm.vmas.UpperBoundGap(bounds.End); gap.Ok() && gap.End() > bounds.Start; gap = gap.PrevLargeEnoughGap(usermem.Addr(length)) { if gr := gap.availableRange().Intersect(bounds); uint64(gr.Length()) >= length { // Can we shift down to match the alignment? start := gr.End - usermem.Addr(length) diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go index 444a83913..a6345010d 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -38,6 +38,12 @@ func SaveVRegs(*byte) // LoadVRegs loads V0-V31 registers. func LoadVRegs(*byte) +// GetTLS returns the value of TPIDR_EL0 register. +func GetTLS() (value uint64) + +// SetTLS writes the TPIDR_EL0 value. +func SetTLS(value uint64) + // Init sets function pointers based on architectural features. // // This must be called prior to using ring0. diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s index 0e6a6235b..b63e14b41 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.s +++ b/pkg/sentry/platform/ring0/lib_arm64.s @@ -15,6 +15,16 @@ #include "funcdata.h" #include "textflag.h" +TEXT ·GetTLS(SB),NOSPLIT,$0-8 + MRS TPIDR_EL0, R1 + MOVD R1, ret+0(FP) + RET + +TEXT ·SetTLS(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R1 + MSR R1, TPIDR_EL0 + RET + TEXT ·CPACREL1(SB),NOSPLIT,$0-8 WORD $0xd5381041 // MRS CPACR_EL1, R1 MOVD R1, ret+0(FP) diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index b49433326..c11e82c10 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -555,7 +555,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b if uint64(src.NumBytes()) != srcs.NumBytes() { return 0, nil } - if srcs.IsEmpty() { + if srcs.IsEmpty() && len(controlBuf) == 0 { return 0, nil } diff --git a/pkg/state/BUILD b/pkg/state/BUILD index 921af9d63..2b1350135 100644 --- a/pkg/state/BUILD +++ b/pkg/state/BUILD @@ -47,6 +47,7 @@ go_library( "state.go", "stats.go", ], + marshal = False, stateify = False, visibility = ["//:sandbox"], deps = [ diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index 0e35d7d17..d0d77e19c 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -39,6 +39,8 @@ go_library( "seqcount.go", "sync.go", ], + marshal = False, + stateify = False, ) go_test( diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 29454c4b9..4c6f808e5 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -66,6 +66,14 @@ const ( TCPOptionSACK = 5 ) +// Option Lengths. +const ( + TCPOptionMSSLength = 4 + TCPOptionTSLength = 10 + TCPOptionWSLength = 3 + TCPOptionSackPermittedLength = 2 +) + // TCPFields contains the fields of a TCP packet. It is used to describe the // fields of a packet that needs to be encoded. type TCPFields struct { @@ -494,14 +502,11 @@ func ParseTCPOptions(b []byte) TCPOptions { // returns without encoding anything. It returns the number of bytes written to // the provided buffer. func EncodeMSSOption(mss uint32, b []byte) int { - // mssOptionSize is the number of bytes in a valid MSS option. - const mssOptionSize = 4 - - if len(b) < mssOptionSize { + if len(b) < TCPOptionMSSLength { return 0 } - b[0], b[1], b[2], b[3] = TCPOptionMSS, mssOptionSize, byte(mss>>8), byte(mss) - return mssOptionSize + b[0], b[1], b[2], b[3] = TCPOptionMSS, TCPOptionMSSLength, byte(mss>>8), byte(mss) + return TCPOptionMSSLength } // EncodeWSOption encodes the WS TCP option with the WS value in the @@ -509,10 +514,10 @@ func EncodeMSSOption(mss uint32, b []byte) int { // returns without encoding anything. It returns the number of bytes written to // the provided buffer. func EncodeWSOption(ws int, b []byte) int { - if len(b) < 3 { + if len(b) < TCPOptionWSLength { return 0 } - b[0], b[1], b[2] = TCPOptionWS, 3, uint8(ws) + b[0], b[1], b[2] = TCPOptionWS, TCPOptionWSLength, uint8(ws) return int(b[1]) } @@ -521,10 +526,10 @@ func EncodeWSOption(ws int, b []byte) int { // just returns without encoding anything. It returns the number of bytes // written to the provided buffer. func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int { - if len(b) < 10 { + if len(b) < TCPOptionTSLength { return 0 } - b[0], b[1] = TCPOptionTS, 10 + b[0], b[1] = TCPOptionTS, TCPOptionTSLength binary.BigEndian.PutUint32(b[2:], tsVal) binary.BigEndian.PutUint32(b[6:], tsEcr) return int(b[1]) @@ -535,11 +540,11 @@ func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int { // encoding anything. It returns the number of bytes written to the provided // buffer. func EncodeSACKPermittedOption(b []byte) int { - if len(b) < 2 { + if len(b) < TCPOptionSackPermittedLength { return 0 } - b[0], b[1] = TCPOptionSACKPermitted, 2 + b[0], b[1] = TCPOptionSACKPermitted, TCPOptionSackPermittedLength return int(b[1]) } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 45e930ad8..b7b227328 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -110,6 +110,71 @@ var ( ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"} ) +var messageToError map[string]*Error + +var populate sync.Once + +// StringToError converts an error message to the error. +func StringToError(s string) *Error { + populate.Do(func() { + var errors = []*Error{ + ErrUnknownProtocol, + ErrUnknownNICID, + ErrUnknownDevice, + ErrUnknownProtocolOption, + ErrDuplicateNICID, + ErrDuplicateAddress, + ErrNoRoute, + ErrBadLinkEndpoint, + ErrAlreadyBound, + ErrInvalidEndpointState, + ErrAlreadyConnecting, + ErrAlreadyConnected, + ErrNoPortAvailable, + ErrPortInUse, + ErrBadLocalAddress, + ErrClosedForSend, + ErrClosedForReceive, + ErrWouldBlock, + ErrConnectionRefused, + ErrTimeout, + ErrAborted, + ErrConnectStarted, + ErrDestinationRequired, + ErrNotSupported, + ErrQueueSizeNotSupported, + ErrNotConnected, + ErrConnectionReset, + ErrConnectionAborted, + ErrNoSuchFile, + ErrInvalidOptionValue, + ErrNoLinkAddress, + ErrBadAddress, + ErrNetworkUnreachable, + ErrMessageTooLong, + ErrNoBufferSpace, + ErrBroadcastDisabled, + ErrNotPermitted, + ErrAddressFamilyNotSupported, + } + + messageToError = make(map[string]*Error) + for _, e := range errors { + if messageToError[e.String()] != nil { + panic("tcpip errors with duplicated message: " + e.String()) + } + messageToError[e.String()] = e + } + }) + + e, ok := messageToError[s] + if !ok { + panic("unknown error message: " + s) + } + + return e +} + // Errors related to Subnet var ( errSubnetLengthMismatch = errors.New("subnet length of address and mask differ") diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 8b7562396..fc43c11e2 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -314,7 +314,7 @@ func (e *endpoint) loadLastError(s string) { return } - e.lastError = loadError(s) + e.lastError = tcpip.StringToError(s) } // saveHardError is invoked by stateify. @@ -332,71 +332,7 @@ func (e *EndpointInfo) loadHardError(s string) { return } - e.HardError = loadError(s) -} - -var messageToError map[string]*tcpip.Error - -var populate sync.Once - -func loadError(s string) *tcpip.Error { - populate.Do(func() { - var errors = []*tcpip.Error{ - tcpip.ErrUnknownProtocol, - tcpip.ErrUnknownNICID, - tcpip.ErrUnknownDevice, - tcpip.ErrUnknownProtocolOption, - tcpip.ErrDuplicateNICID, - tcpip.ErrDuplicateAddress, - tcpip.ErrNoRoute, - tcpip.ErrBadLinkEndpoint, - tcpip.ErrAlreadyBound, - tcpip.ErrInvalidEndpointState, - tcpip.ErrAlreadyConnecting, - tcpip.ErrAlreadyConnected, - tcpip.ErrNoPortAvailable, - tcpip.ErrPortInUse, - tcpip.ErrBadLocalAddress, - tcpip.ErrClosedForSend, - tcpip.ErrClosedForReceive, - tcpip.ErrWouldBlock, - tcpip.ErrConnectionRefused, - tcpip.ErrTimeout, - tcpip.ErrAborted, - tcpip.ErrConnectStarted, - tcpip.ErrDestinationRequired, - tcpip.ErrNotSupported, - tcpip.ErrQueueSizeNotSupported, - tcpip.ErrNotConnected, - tcpip.ErrConnectionReset, - tcpip.ErrConnectionAborted, - tcpip.ErrNoSuchFile, - tcpip.ErrInvalidOptionValue, - tcpip.ErrNoLinkAddress, - tcpip.ErrBadAddress, - tcpip.ErrNetworkUnreachable, - tcpip.ErrMessageTooLong, - tcpip.ErrNoBufferSpace, - tcpip.ErrBroadcastDisabled, - tcpip.ErrNotPermitted, - tcpip.ErrAddressFamilyNotSupported, - } - - messageToError = make(map[string]*tcpip.Error) - for _, e := range errors { - if messageToError[e.String()] != nil { - panic("tcpip errors with duplicated message: " + e.String()) - } - messageToError[e.String()] = e - } - }) - - e, ok := messageToError[s] - if !ok { - panic("unknown error message: " + s) - } - - return e + e.HardError = tcpip.StringToError(s) } // saveMeasureTime is invoked by stateify. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 756ab913a..647b2067a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -106,6 +106,9 @@ type endpoint struct { bindToDevice tcpip.NICID broadcast bool + lastErrorMu sync.Mutex `state:"nosave"` + lastError *tcpip.Error `state:".(string)"` + // Values used to reserve a port or register a transport endpoint. // (which ever happens first). boundBindToDevice tcpip.NICID @@ -188,6 +191,15 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } +func (e *endpoint) takeLastError() *tcpip.Error { + e.lastErrorMu.Lock() + defer e.lastErrorMu.Unlock() + + err := e.lastError + e.lastError = nil + return err +} + // Abort implements stack.TransportEndpoint.Abort. func (e *endpoint) Abort() { e.Close() @@ -243,6 +255,10 @@ func (e *endpoint) IPTables() (stack.IPTables, error) { // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + if err := e.takeLastError(); err != nil { + return buffer.View{}, tcpip.ControlMessages{}, err + } + e.rcvMu.Lock() if e.rcvList.Empty() { @@ -382,6 +398,10 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + if err := e.takeLastError(); err != nil { + return 0, nil, err + } + // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -853,6 +873,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { case tcpip.ErrorOption: + return e.takeLastError() case *tcpip.MulticastInterfaceOption: e.mu.Lock() *o = tcpip.MulticastInterfaceOption{ @@ -1316,6 +1337,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { + if typ == stack.ControlPortUnreachable { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.state == StateConnected { + e.lastErrorMu.Lock() + defer e.lastErrorMu.Unlock() + + e.lastError = tcpip.ErrConnectionRefused + } + } } // State implements tcpip.Endpoint.State. diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 466bd9381..851e6b635 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -37,6 +37,24 @@ func (u *udpPacket) loadData(data buffer.VectorisedView) { u.data = data } +// saveLastError is invoked by stateify. +func (e *endpoint) saveLastError() string { + if e.lastError == nil { + return "" + } + + return e.lastError.String() +} + +// loadLastError is invoked by stateify. +func (e *endpoint) loadLastError(s string) { + if s == "" { + return + } + + e.lastError = tcpip.StringToError(s) +} + // beforeSave is invoked by stateify. func (e *endpoint) beforeSave() { // Stop incoming packets from being handled (and mutate endpoint state). diff --git a/pkg/usermem/addr.go b/pkg/usermem/addr.go index e79210804..c4100481e 100644 --- a/pkg/usermem/addr.go +++ b/pkg/usermem/addr.go @@ -106,3 +106,20 @@ func (ar AddrRange) IsPageAligned() bool { func (ar AddrRange) String() string { return fmt.Sprintf("[%#x, %#x)", ar.Start, ar.End) } + +// PageRoundDown/Up are equivalent to Addr.RoundDown/Up, but without the +// potentially truncating conversion from uint64 to Addr. This is necessary +// because there is no way to define generic "PageRoundDown/Up" functions in Go. + +// PageRoundDown returns x rounded down to the nearest page boundary. +func PageRoundDown(x uint64) uint64 { + return x &^ (PageSize - 1) +} + +// PageRoundUp returns x rounded up to the nearest page boundary. +// ok is true iff rounding up did not wrap around. +func PageRoundUp(x uint64) (addr uint64, ok bool) { + addr = PageRoundDown(x + PageSize - 1) + ok = addr >= x + return +} diff --git a/runsc/cgroup/cgroup.go b/runsc/cgroup/cgroup.go index fa40ee509..19c8b0db6 100644 --- a/runsc/cgroup/cgroup.go +++ b/runsc/cgroup/cgroup.go @@ -19,6 +19,7 @@ package cgroup import ( "bufio" "context" + "errors" "fmt" "io/ioutil" "os" @@ -38,21 +39,23 @@ const ( cgroupRoot = "/sys/fs/cgroup" ) -var controllers = map[string]controller{ - "blkio": &blockIO{}, - "cpu": &cpu{}, - "cpuset": &cpuSet{}, - "memory": &memory{}, - "net_cls": &networkClass{}, - "net_prio": &networkPrio{}, - "pids": &pids{}, +var controllers = map[string]config{ + "blkio": config{ctrlr: &blockIO{}}, + "cpu": config{ctrlr: &cpu{}}, + "cpuset": config{ctrlr: &cpuSet{}}, + "memory": config{ctrlr: &memory{}}, + "net_cls": config{ctrlr: &networkClass{}}, + "net_prio": config{ctrlr: &networkPrio{}}, + "pids": config{ctrlr: &pids{}}, // These controllers either don't have anything in the OCI spec or is // irrelevant for a sandbox. - "devices": &noop{}, - "freezer": &noop{}, - "perf_event": &noop{}, - "systemd": &noop{}, + "devices": config{ctrlr: &noop{}}, + "freezer": config{ctrlr: &noop{}}, + "hugetlb": config{ctrlr: &noop{}, optional: true}, + "perf_event": config{ctrlr: &noop{}}, + "rdma": config{ctrlr: &noop{}, optional: true}, + "systemd": config{ctrlr: &noop{}}, } func setOptionalValueInt(path, name string, val *int64) error { @@ -196,8 +199,9 @@ func LoadPaths(pid string) (map[string]string, error) { return paths, nil } -// Cgroup represents a group inside all controllers. For example: Name='/foo/bar' -// maps to /sys/fs/cgroup/<controller>/foo/bar on all controllers. +// Cgroup represents a group inside all controllers. For example: +// Name='/foo/bar' maps to /sys/fs/cgroup/<controller>/foo/bar on +// all controllers. type Cgroup struct { Name string `json:"name"` Parents map[string]string `json:"parents"` @@ -245,13 +249,17 @@ func (c *Cgroup) Install(res *specs.LinuxResources) error { clean := specutils.MakeCleanup(func() { _ = c.Uninstall() }) defer clean.Clean() - for key, ctrl := range controllers { + for key, cfg := range controllers { path := c.makePath(key) if err := os.MkdirAll(path, 0755); err != nil { + if cfg.optional && errors.Is(err, syscall.EROFS) { + log.Infof("Skipping cgroup %q", key) + continue + } return err } if res != nil { - if err := ctrl.set(res, path); err != nil { + if err := cfg.ctrlr.set(res, path); err != nil { return err } } @@ -321,10 +329,13 @@ func (c *Cgroup) Join() (func(), error) { } // Now join the cgroups. - for key := range controllers { + for key, cfg := range controllers { path := c.makePath(key) log.Debugf("Joining cgroup %q", path) if err := setValue(path, "cgroup.procs", "0"); err != nil { + if cfg.optional && os.IsNotExist(err) { + continue + } return undo, err } } @@ -375,6 +386,11 @@ func (c *Cgroup) makePath(controllerName string) string { return filepath.Join(cgroupRoot, controllerName, path) } +type config struct { + ctrlr controller + optional bool +} + type controller interface { set(*specs.LinuxResources, string) error } diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index d4fcf31fa..c4ffda17e 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -31,8 +31,6 @@ packetimpact_go_test( packetimpact_go_test( name = "udp_icmp_error_propagation", srcs = ["udp_icmp_error_propagation_test.go"], - # TODO(b/153926291): Fix netstack then remove the line below. - expect_netstack_failure = True, deps = [ "//pkg/tcpip", "//pkg/tcpip/header", @@ -136,6 +134,19 @@ packetimpact_go_test( ) packetimpact_go_test( + name = "tcp_paws_mechanism", + srcs = ["tcp_paws_mechanism_test.go"], + # TODO(b/156682000): Fix netstack then remove the line below. + expect_netstack_failure = True, + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( name = "tcp_user_timeout", srcs = ["tcp_user_timeout_test.go"], deps = [ diff --git a/test/packetimpact/tests/tcp_paws_mechanism_test.go b/test/packetimpact/tests/tcp_paws_mechanism_test.go new file mode 100644 index 000000000..0a668adcf --- /dev/null +++ b/test/packetimpact/tests/tcp_paws_mechanism_test.go @@ -0,0 +1,109 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_paws_mechanism_test + +import ( + "encoding/hex" + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + tb.RegisterFlags(flag.CommandLine) +} + +func TestPAWSMechanism(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFD) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) + defer conn.Close() + + options := make([]byte, header.TCPOptionTSLength) + header.EncodeTSOption(currentTS(), 0, options) + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn), Options: options}) + synAck, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("didn't get synack during handshake: %s", err) + } + parsedSynOpts := header.ParseSynOptions(synAck.Options, true) + if !parsedSynOpts.TS { + t.Fatalf("expected TSOpt from DUT, options we got:\n%s", hex.Dump(synAck.Options)) + } + tsecr := parsedSynOpts.TSVal + header.EncodeTSOption(currentTS(), tsecr, options) + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), Options: options}) + acceptFD, _ := dut.Accept(listenFD) + defer dut.Close(acceptFD) + + sampleData := []byte("Sample Data") + sentTSVal := currentTS() + header.EncodeTSOption(sentTSVal, tsecr, options) + // 3ms here is chosen arbitrarily to make sure we have increasing timestamps + // every time we send one, it should not cause any flakiness because timestamps + // only need to be non-decreasing. + time.Sleep(3 * time.Millisecond) + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), Options: options}, &tb.Payload{Bytes: sampleData}) + + gotTCP, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected an ACK but got none: %s", err) + } + + parsedOpts := header.ParseTCPOptions(gotTCP.Options) + if !parsedOpts.TS { + t.Fatalf("expected TS option in response, options we got:\n%s", hex.Dump(gotTCP.Options)) + } + if parsedOpts.TSVal < tsecr { + t.Fatalf("TSVal should be non-decreasing, but %d < %d", parsedOpts.TSVal, tsecr) + } + if parsedOpts.TSEcr != sentTSVal { + t.Fatalf("TSEcr should match our sent TSVal, %d != %d", parsedOpts.TSEcr, sentTSVal) + } + tsecr = parsedOpts.TSVal + lastAckNum := gotTCP.AckNum + + badTSVal := sentTSVal - 100 + header.EncodeTSOption(badTSVal, tsecr, options) + // 3ms here is chosen arbitrarily and this time.Sleep() should not cause flakiness + // due to the exact same reasoning discussed above. + time.Sleep(3 * time.Millisecond) + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), Options: options}, &tb.Payload{Bytes: sampleData}) + + gotTCP, err = conn.Expect(tb.TCP{AckNum: lastAckNum, Flags: tb.Uint8(header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected segment with AckNum %d but got none: %s", lastAckNum, err) + } + parsedOpts = header.ParseTCPOptions(gotTCP.Options) + if !parsedOpts.TS { + t.Fatalf("expected TS option in response, options we got:\n%s", hex.Dump(gotTCP.Options)) + } + if parsedOpts.TSVal < tsecr { + t.Fatalf("TSVal should be non-decreasing, but %d < %d", parsedOpts.TSVal, tsecr) + } + if parsedOpts.TSEcr != sentTSVal { + t.Fatalf("TSEcr should match our sent TSVal, %d != %d", parsedOpts.TSEcr, sentTSVal) + } +} + +func currentTS() uint32 { + return uint32(time.Now().UnixNano() / 1e6) +} diff --git a/test/syscalls/linux/itimer.cc b/test/syscalls/linux/itimer.cc index dd981a278..e397d5f57 100644 --- a/test/syscalls/linux/itimer.cc +++ b/test/syscalls/linux/itimer.cc @@ -267,8 +267,19 @@ int TestSIGPROFFairness(absl::Duration sleep) { // Random save/restore is disabled as it introduces additional latency and // unpredictable distribution patterns. TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyActive_NoRandomSave) { - // TODO(b/143247272): CPU time accounting is inaccurate for the KVM platform. - SKIP_IF(GvisorPlatform() == Platform::kKVM); + // On the KVM and ptrace platforms, switches between sentry and application + // context are sometimes extremely slow, causing the itimer to send SIGPROF to + // a thread that either already has one pending or has had SIGPROF delivered, + // but hasn't handled it yet (and thus therefore still has SIGPROF masked). In + // either case, since itimer signals are group-directed, signal sending falls + // back to notifying the thread group leader. ItimerSignalTest() fails if "too + // many" signals are delivered to the thread group leader, so these tests are + // flaky on these platforms. + // + // TODO(b/143247272): Clarify why context switches are so slow on KVM. + const auto gvisor_platform = GvisorPlatform(); + SKIP_IF(gvisor_platform == Platform::kKVM || + gvisor_platform == Platform::kPtrace); pid_t child; int execve_errno; @@ -291,8 +302,10 @@ TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyActive_NoRandomSave) { // Random save/restore is disabled as it introduces additional latency and // unpredictable distribution patterns. TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyIdle_NoRandomSave) { - // TODO(b/143247272): CPU time accounting is inaccurate for the KVM platform. - SKIP_IF(GvisorPlatform() == Platform::kKVM); + // See comment in DeliversSIGPROFToThreadsRoughlyFairlyActive. + const auto gvisor_platform = GvisorPlatform(); + SKIP_IF(gvisor_platform == Platform::kKVM || + gvisor_platform == Platform::kPtrace); pid_t child; int execve_errno; diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc index 740c7986d..42521efef 100644 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ b/test/syscalls/linux/udp_socket_test_cases.cc @@ -17,6 +17,7 @@ #include <arpa/inet.h> #include <fcntl.h> #include <netinet/in.h> +#include <poll.h> #include <sys/ioctl.h> #include <sys/socket.h> #include <sys/types.h> @@ -673,6 +674,11 @@ TEST_P(UdpSocketTest, ZerolengthWriteAllowed) { char buf[3]; // Send zero length packet from s_ to t_. ASSERT_THAT(write(s_, buf, 0), SyscallSucceedsWithValue(0)); + + struct pollfd pfd = {t_, POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + // Receive the packet. char received[3]; EXPECT_THAT(read(t_, received, sizeof(received)), @@ -698,6 +704,11 @@ TEST_P(UdpSocketTest, ZerolengthWriteAllowedNonBlockRead) { char buf[3]; // Send zero length packet from s_ to t_. ASSERT_THAT(write(s_, buf, 0), SyscallSucceedsWithValue(0)); + + struct pollfd pfd = {t_, POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + // Receive the packet. char received[3]; EXPECT_THAT(read(t_, received, sizeof(received)), @@ -859,6 +870,10 @@ TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) { EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); + struct pollfd pfd = {s_, POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + // We should get the data even though read has been shutdown. EXPECT_THAT(recv(s_, received, 2, 0), SyscallSucceedsWithValue(2)); @@ -1112,6 +1127,10 @@ TEST_P(UdpSocketTest, FIONREADWriteShutdown) { ASSERT_THAT(send(s_, str, sizeof(str), 0), SyscallSucceedsWithValue(sizeof(str))); + struct pollfd pfd = {s_, POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + n = -1; EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); EXPECT_EQ(n, sizeof(str)); @@ -1123,6 +1142,8 @@ TEST_P(UdpSocketTest, FIONREADWriteShutdown) { EXPECT_EQ(n, sizeof(str)); } +// NOTE: Do not use `FIONREAD` as test name because it will be replaced by the +// corresponding macro and become `0x541B`. TEST_P(UdpSocketTest, Fionread) { // Bind s_ to loopback:TestPort. ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); @@ -1138,10 +1159,14 @@ TEST_P(UdpSocketTest, Fionread) { char buf[3 * psize]; RandomizeBuffer(buf, sizeof(buf)); + struct pollfd pfd = {s_, POLLIN, 0}; for (int i = 0; i < 3; ++i) { ASSERT_THAT(sendto(t_, buf + i * psize, psize, 0, addr_[0], addrlen_), SyscallSucceedsWithValue(psize)); + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + // Check that regardless of how many packets are in the queue, the size // reported is that of a single packet. n = -1; @@ -1165,10 +1190,18 @@ TEST_P(UdpSocketTest, FIONREADZeroLengthPacket) { char buf[3 * psize]; RandomizeBuffer(buf, sizeof(buf)); + struct pollfd pfd = {s_, POLLIN, 0}; for (int i = 0; i < 3; ++i) { ASSERT_THAT(sendto(t_, buf + i * psize, 0, 0, addr_[0], addrlen_), SyscallSucceedsWithValue(0)); + // TODO(gvisor.dev/issue/2726): sending a zero-length message to a hostinet + // socket does not cause a poll event to be triggered. + if (!IsRunningWithHostinet()) { + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + } + // Check that regardless of how many packets are in the queue, the size // reported is that of a single packet. n = -1; @@ -1235,6 +1268,10 @@ TEST_P(UdpSocketTest, SoTimestamp) { // Send zero length packet from t_ to s_. ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0)); + struct pollfd pfd = {s_, POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; msghdr msg; memset(&msg, 0, sizeof(msg)); @@ -1278,6 +1315,10 @@ TEST_P(UdpSocketTest, TimestampIoctl) { ASSERT_THAT(RetryEINTR(write)(t_, buf, sizeof(buf)), SyscallSucceedsWithValue(sizeof(buf))); + struct pollfd pfd = {s_, POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + // There should be no control messages. char recv_buf[sizeof(buf)]; ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, sizeof(recv_buf))); @@ -1315,6 +1356,10 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) { SyscallSucceedsWithValue(sizeof(buf))); ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0)); + struct pollfd pfd = {s_, POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + // There should be no control messages. char recv_buf[sizeof(buf)]; ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, sizeof(recv_buf))); @@ -1330,6 +1375,9 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) { SyscallSucceeds()); ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0)); + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + // There should be a message for SO_TIMESTAMP. char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; msghdr msg = {}; diff --git a/tools/go_generics/generics.go b/tools/go_generics/generics.go index e9cc2c753..0860ca9db 100644 --- a/tools/go_generics/generics.go +++ b/tools/go_generics/generics.go @@ -223,7 +223,9 @@ func main() { } else { switch kind { case globals.KindType, globals.KindVar, globals.KindConst, globals.KindFunction: - ident.Name = *prefix + ident.Name + *suffix + if ident.Name != "_" { + ident.Name = *prefix + ident.Name + *suffix + } case globals.KindTag: // Modify the state tag appropriately. if m := stateTagRegexp.FindStringSubmatch(ident.Name); m != nil { diff --git a/website/cmd/syscalldocs/main.go b/website/cmd/syscalldocs/main.go index 62d293a05..327537214 100644 --- a/website/cmd/syscalldocs/main.go +++ b/website/cmd/syscalldocs/main.go @@ -46,7 +46,7 @@ type SyscallDoc struct { } var mdTemplate = template.Must(template.New("out").Parse(`--- -title: {{.OS}}/{{.Arch}} +title: {{.Title}} description: Syscall Compatibility Reference Documentation for {{.OS}}/{{.Arch}} layout: docs category: Compatibility @@ -134,6 +134,7 @@ func main() { weight += 10 data := struct { + Title string OS string Arch string Weight int @@ -149,7 +150,8 @@ func main() { URLs []string } }{ - OS: strings.Title(osName), + Title: strings.Title(osName) + "/" + archName, + OS: osName, Arch: archName, Weight: weight, Total: 0, |