diff options
Diffstat (limited to 'pkg')
200 files changed, 6315 insertions, 2141 deletions
diff --git a/pkg/abi/linux/ip.go b/pkg/abi/linux/ip.go index 31e56ffa6..ef6d1093e 100644 --- a/pkg/abi/linux/ip.go +++ b/pkg/abi/linux/ip.go @@ -92,6 +92,16 @@ const ( IP_UNICAST_IF = 50 ) +// IP_MTU_DISCOVER values from uapi/linux/in.h +const ( + IP_PMTUDISC_DONT = 0 + IP_PMTUDISC_WANT = 1 + IP_PMTUDISC_DO = 2 + IP_PMTUDISC_PROBE = 3 + IP_PMTUDISC_INTERFACE = 4 + IP_PMTUDISC_OMIT = 5 +) + // Socket options from uapi/linux/in6.h const ( IPV6_ADDRFORM = 1 diff --git a/pkg/buffer/safemem.go b/pkg/buffer/safemem.go index 0e5b86344..b789e56e9 100644 --- a/pkg/buffer/safemem.go +++ b/pkg/buffer/safemem.go @@ -28,12 +28,11 @@ func (b *buffer) ReadBlock() safemem.Block { return safemem.BlockFromSafeSlice(b.ReadSlice()) } -// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. -// -// This will advance the write index. -func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { - need := int(srcs.NumBytes()) - if need == 0 { +// WriteFromSafememReader writes up to count bytes from r to v and advances the +// write index by the number of bytes written. It calls r.ReadToBlocks() at +// most once. +func (v *View) WriteFromSafememReader(r safemem.Reader, count uint64) (uint64, error) { + if count == 0 { return 0, nil } @@ -50,32 +49,33 @@ func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { } // Does the last block have sufficient capacity alone? - if l := firstBuf.WriteSize(); l >= need { - dst = safemem.BlockSeqOf(firstBuf.WriteBlock()) + if l := uint64(firstBuf.WriteSize()); l >= count { + dst = safemem.BlockSeqOf(firstBuf.WriteBlock().TakeFirst64(count)) } else { // Append blocks until sufficient. - need -= l + count -= l blocks = append(blocks, firstBuf.WriteBlock()) - for need > 0 { + for count > 0 { emptyBuf := bufferPool.Get().(*buffer) v.data.PushBack(emptyBuf) - need -= emptyBuf.WriteSize() - blocks = append(blocks, emptyBuf.WriteBlock()) + block := emptyBuf.WriteBlock().TakeFirst64(count) + count -= uint64(block.Len()) + blocks = append(blocks, block) } dst = safemem.BlockSeqFromSlice(blocks) } - // Perform the copy. - n, err := safemem.CopySeq(dst, srcs) + // Perform I/O. + n, err := r.ReadToBlocks(dst) v.size += int64(n) // Update all indices. - for left := int(n); left > 0; firstBuf = firstBuf.Next() { - if l := firstBuf.WriteSize(); left >= l { + for left := n; left > 0; firstBuf = firstBuf.Next() { + if l := firstBuf.WriteSize(); left >= uint64(l) { firstBuf.WriteMove(l) // Whole block. - left -= l + left -= uint64(l) } else { - firstBuf.WriteMove(left) // Partial block. + firstBuf.WriteMove(int(left)) // Partial block. left = 0 } } @@ -83,14 +83,16 @@ func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { return n, err } -// ReadToBlocks implements safemem.Reader.ReadToBlocks. -// -// This will not advance the read index; the caller should follow -// this call with a call to TrimFront in order to remove the read -// data from the buffer. This is done to support pipe sematics. -func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { - need := int(dsts.NumBytes()) - if need == 0 { +// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. It advances the +// write index by the number of bytes written. +func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { + return v.WriteFromSafememReader(&safemem.BlockSeqReader{srcs}, srcs.NumBytes()) +} + +// ReadToSafememWriter reads up to count bytes from v to w. It does not advance +// the read index. It calls w.WriteFromBlocks() at most once. +func (v *View) ReadToSafememWriter(w safemem.Writer, count uint64) (uint64, error) { + if count == 0 { return 0, nil } @@ -105,25 +107,27 @@ func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { } // Is all the data in a single block? - if l := firstBuf.ReadSize(); l >= need { - src = safemem.BlockSeqOf(firstBuf.ReadBlock()) + if l := uint64(firstBuf.ReadSize()); l >= count { + src = safemem.BlockSeqOf(firstBuf.ReadBlock().TakeFirst64(count)) } else { // Build a list of all the buffers. - need -= l + count -= l blocks = append(blocks, firstBuf.ReadBlock()) - for buf := firstBuf.Next(); buf != nil && need > 0; buf = buf.Next() { - need -= buf.ReadSize() - blocks = append(blocks, buf.ReadBlock()) + for buf := firstBuf.Next(); buf != nil && count > 0; buf = buf.Next() { + block := buf.ReadBlock().TakeFirst64(count) + count -= uint64(block.Len()) + blocks = append(blocks, block) } src = safemem.BlockSeqFromSlice(blocks) } - // Perform the copy. - n, err := safemem.CopySeq(dsts, src) - - // See above: we would normally advance the read index here, but we - // don't do that in order to support pipe semantics. We rely on a - // separate call to TrimFront() in this case. + // Perform I/O. As documented, we don't advance the read index. + return w.WriteFromBlocks(src) +} - return n, err +// ReadToBlocks implements safemem.Reader.ReadToBlocks. It does not advance the +// read index by the number of bytes read, such that it's only safe to call if +// the caller guarantees that ReadToBlocks will only be called once. +func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { + return v.ReadToSafememWriter(&safemem.BlockSeqWriter{dsts}, dsts.NumBytes()) } diff --git a/pkg/cleanup/BUILD b/pkg/cleanup/BUILD new file mode 100644 index 000000000..5c34b9872 --- /dev/null +++ b/pkg/cleanup/BUILD @@ -0,0 +1,17 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "cleanup", + srcs = ["cleanup.go"], + visibility = ["//:sandbox"], + deps = [ + ], +) + +go_test( + name = "cleanup_test", + srcs = ["cleanup_test.go"], + library = ":cleanup", +) diff --git a/pkg/cleanup/cleanup.go b/pkg/cleanup/cleanup.go new file mode 100644 index 000000000..14a05f076 --- /dev/null +++ b/pkg/cleanup/cleanup.go @@ -0,0 +1,60 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cleanup provides utilities to clean "stuff" on defers. +package cleanup + +// Cleanup allows defers to be aborted when cleanup needs to happen +// conditionally. Usage: +// cu := cleanup.Make(func() { f.Close() }) +// defer cu.Clean() // failure before release is called will close the file. +// ... +// cu.Add(func() { f2.Close() }) // Adds another cleanup function +// ... +// cu.Release() // on success, aborts closing the file. +// return f +type Cleanup struct { + cleaners []func() +} + +// Make creates a new Cleanup object. +func Make(f func()) Cleanup { + return Cleanup{cleaners: []func(){f}} +} + +// Add adds a new function to be called on Clean(). +func (c *Cleanup) Add(f func()) { + c.cleaners = append(c.cleaners, f) +} + +// Clean calls all cleanup functions in reverse order. +func (c *Cleanup) Clean() { + clean(c.cleaners) + c.cleaners = nil +} + +// Release releases the cleanup from its duties, i.e. cleanup functions are not +// called after this point. Returns a function that calls all registered +// functions in case the caller has use for them. +func (c *Cleanup) Release() func() { + old := c.cleaners + c.cleaners = nil + return func() { clean(old) } +} + +func clean(cleaners []func()) { + for i := len(cleaners) - 1; i >= 0; i-- { + cleaners[i]() + } +} diff --git a/pkg/cleanup/cleanup_test.go b/pkg/cleanup/cleanup_test.go new file mode 100644 index 000000000..ab3d9ed95 --- /dev/null +++ b/pkg/cleanup/cleanup_test.go @@ -0,0 +1,66 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cleanup + +import "testing" + +func testCleanupHelper(clean, cleanAdd *bool, release bool) func() { + cu := Make(func() { + *clean = true + }) + cu.Add(func() { + *cleanAdd = true + }) + defer cu.Clean() + if release { + return cu.Release() + } + return nil +} + +func TestCleanup(t *testing.T) { + clean := false + cleanAdd := false + testCleanupHelper(&clean, &cleanAdd, false) + if !clean { + t.Fatalf("cleanup function was not called.") + } + if !cleanAdd { + t.Fatalf("added cleanup function was not called.") + } +} + +func TestRelease(t *testing.T) { + clean := false + cleanAdd := false + cleaner := testCleanupHelper(&clean, &cleanAdd, true) + + // Check that clean was not called after release. + if clean { + t.Fatalf("cleanup function was called.") + } + if cleanAdd { + t.Fatalf("added cleanup function was called.") + } + + // Call the cleaner function and check that both cleanup functions are called. + cleaner() + if !clean { + t.Fatalf("cleanup function was not called.") + } + if !cleanAdd { + t.Fatalf("added cleanup function was not called.") + } +} diff --git a/pkg/goid/BUILD b/pkg/goid/BUILD index ea8d2422c..7a82631c5 100644 --- a/pkg/goid/BUILD +++ b/pkg/goid/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "goid.go", "goid_amd64.s", + "goid_arm64.s", "goid_race.go", "goid_unsafe.go", ], diff --git a/pkg/goid/goid_arm64.s b/pkg/goid/goid_arm64.s new file mode 100644 index 000000000..a7465b75d --- /dev/null +++ b/pkg/goid/goid_arm64.s @@ -0,0 +1,21 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// func getg() *g +TEXT ·getg(SB),NOSPLIT,$0-8 + MOVD g, R0 // g + MOVD R0, ret+0(FP) + RET diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD index 41bf104d0..f84d03700 100644 --- a/pkg/linewriter/BUILD +++ b/pkg/linewriter/BUILD @@ -5,6 +5,8 @@ package(licenses = ["notice"]) go_library( name = "linewriter", srcs = ["linewriter.go"], + marshal = False, + stateify = False, visibility = ["//visibility:public"], deps = ["//pkg/sync"], ) diff --git a/pkg/log/BUILD b/pkg/log/BUILD index a7c8f7bef..3ed6aba5c 100644 --- a/pkg/log/BUILD +++ b/pkg/log/BUILD @@ -10,6 +10,8 @@ go_library( "json_k8s.go", "log.go", ], + marshal = False, + stateify = False, visibility = [ "//visibility:public", ], diff --git a/pkg/procid/procid_amd64.s b/pkg/procid/procid_amd64.s index 38cea9be3..7c622e5d7 100644 --- a/pkg/procid/procid_amd64.s +++ b/pkg/procid/procid_amd64.s @@ -14,7 +14,7 @@ // +build amd64 // +build go1.8 -// +build !go1.15 +// +build !go1.16 #include "textflag.h" diff --git a/pkg/procid/procid_arm64.s b/pkg/procid/procid_arm64.s index 4f4b70fef..48ebb5fd1 100644 --- a/pkg/procid/procid_arm64.s +++ b/pkg/procid/procid_arm64.s @@ -14,7 +14,7 @@ // +build arm64 // +build go1.8 -// +build !go1.15 +// +build !go1.16 #include "textflag.h" diff --git a/pkg/segment/BUILD b/pkg/segment/BUILD index 1b487b887..f57ccc170 100644 --- a/pkg/segment/BUILD +++ b/pkg/segment/BUILD @@ -21,6 +21,8 @@ go_template( ], opt_consts = [ "minDegree", + # trackGaps must either be 0 or 1. + "trackGaps", ], types = [ "Key", diff --git a/pkg/segment/set.go b/pkg/segment/set.go index 03e4f258f..1a17ad9cb 100644 --- a/pkg/segment/set.go +++ b/pkg/segment/set.go @@ -36,6 +36,34 @@ type Range interface{} // Value is a required type parameter. type Value interface{} +// trackGaps is an optional parameter. +// +// If trackGaps is 1, the Set will track maximum gap size recursively, +// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this +// case, Key must be an unsigned integer. +// +// trackGaps must be 0 or 1. +const trackGaps = 0 + +var _ = uint8(trackGaps << 7) // Will fail if not zero or one. + +// dynamicGap is a type that disappears if trackGaps is 0. +type dynamicGap [trackGaps]Key + +// Get returns the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *dynamicGap) Get() Key { + return d[:][0] +} + +// Set sets the value of the gap. +// +// Precondition: trackGaps must be non-zero. +func (d *dynamicGap) Set(v Key) { + d[:][0] = v +} + // Functions is a required type parameter that must be a struct implementing // the methods defined by Functions. type Functions interface { @@ -327,8 +355,12 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator { } if prev.Ok() && prev.End() == r.Start { if mval, ok := (Functions{}).Merge(prev.Range(), prev.Value(), r, val); ok { + shrinkMaxGap := trackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() prev.SetEndUnchecked(r.End) prev.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } if next.Ok() && next.Start() == r.End { val = mval if mval, ok := (Functions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok { @@ -342,11 +374,16 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator { } if next.Ok() && next.Start() == r.End { if mval, ok := (Functions{}).Merge(r, val, next.Range(), next.Value()); ok { + shrinkMaxGap := trackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get() next.SetStartUnchecked(r.Start) next.SetValue(mval) + if shrinkMaxGap { + gap.node.updateMaxGapLeaf() + } return next } } + // InsertWithoutMergingUnchecked will maintain maxGap if necessary. return s.InsertWithoutMergingUnchecked(gap, r, val) } @@ -373,11 +410,15 @@ func (s *Set) InsertWithoutMerging(gap GapIterator, r Range, val Value) Iterator // Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). func (s *Set) InsertWithoutMergingUnchecked(gap GapIterator, r Range, val Value) Iterator { gap = gap.node.rebalanceBeforeInsert(gap) + splitMaxGap := trackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get()) copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments]) copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments]) gap.node.keys[gap.index] = r gap.node.values[gap.index] = val gap.node.nrSegments++ + if splitMaxGap { + gap.node.updateMaxGapLeaf() + } return Iterator{gap.node, gap.index} } @@ -399,12 +440,23 @@ func (s *Set) Remove(seg Iterator) GapIterator { // overlap. seg.SetRangeUnchecked(victim.Range()) seg.SetValue(victim.Value()) + // Need to update the nextAdjacentNode's maxGap because the gap in between + // must have been modified by updating seg.Range() to victim.Range(). + // seg.NextSegment() must exist since the last segment can't be in a + // non-leaf node. + nextAdjacentNode := seg.NextSegment().node + if trackGaps != 0 { + nextAdjacentNode.updateMaxGapLeaf() + } return s.Remove(victim).NextGap() } copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments]) copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments]) Functions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1]) seg.node.nrSegments-- + if trackGaps != 0 { + seg.node.updateMaxGapLeaf() + } return seg.node.rebalanceAfterRemove(GapIterator{seg.node, seg.index}) } @@ -455,6 +507,7 @@ func (s *Set) MergeUnchecked(first, second Iterator) Iterator { // overlaps second. first.SetEndUnchecked(second.End()) first.SetValue(mval) + // Remove will handle the maxGap update if necessary. return s.Remove(second).PrevSegment() } } @@ -631,6 +684,12 @@ type node struct { // than "isLeaf" because false must be the correct value for an empty root. hasChildren bool + // The longest gap within this node. If the node is a leaf, it's simply the + // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys + // including the 0th and nrSegments-th gap possibly shared with its upper-level + // nodes; if it's a non-leaf node, it's the max of all children's maxGap. + maxGap dynamicGap + // Nodes store keys and values in separate arrays to maximize locality in // the common case (scanning keys for lookup). keys [maxDegree - 1]Range @@ -676,12 +735,12 @@ func (n *node) nextSibling() *node { // required for insertion, and returns an updated iterator to the position // represented by gap. func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator { - if n.parent != nil { - gap = n.parent.rebalanceBeforeInsert(gap) - } if n.nrSegments < maxDegree-1 { return gap } + if n.parent != nil { + gap = n.parent.rebalanceBeforeInsert(gap) + } if n.parent == nil { // n is root. Move all segments before and after n's median segment // into new child nodes adjacent to the median segment, which is now @@ -719,6 +778,13 @@ func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator { n.hasChildren = true n.children[0] = left n.children[1] = right + // In this case, n's maxGap won't violated as it's still the root, + // but the left and right children should be updated locally as they + // are newly split from n. + if trackGaps != 0 { + left.updateMaxGapLocal() + right.updateMaxGapLocal() + } if gap.node != n { return gap } @@ -758,6 +824,12 @@ func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator { } } n.nrSegments = minDegree - 1 + // MaxGap of n's parent is not violated because the segments within is not changed. + // n and its sibling's maxGap need to be updated locally as they are two new nodes split from old n. + if trackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } // gap.node can't be n.parent because gaps are always in leaf nodes. if gap.node != n { return gap @@ -821,6 +893,12 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { } n.nrSegments++ sibling.nrSegments-- + // n's parent's maxGap does not need to be updated as its content is unmodified. + // n and its sibling must be updated with (new) maxGap because of the shift of keys. + if trackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } if gap.node == sibling && gap.index == sibling.nrSegments { return GapIterator{n, 0} } @@ -849,6 +927,12 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { } n.nrSegments++ sibling.nrSegments-- + // n's parent's maxGap does not need to be updated as its content is unmodified. + // n and its sibling must be updated with (new) maxGap because of the shift of keys. + if trackGaps != 0 { + n.updateMaxGapLocal() + sibling.updateMaxGapLocal() + } if gap.node == sibling { if gap.index == 0 { return GapIterator{n, n.nrSegments} @@ -886,6 +970,7 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { p.children[0] = nil p.children[1] = nil } + // No need to update maxGap of p as its content is not changed. if gap.node == left { return GapIterator{p, gap.index} } @@ -932,11 +1017,152 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator { } p.children[p.nrSegments] = nil p.nrSegments-- + // Update maxGap of left locally, no need to change p and right because + // p's contents is not changed and right is already invalid. + if trackGaps != 0 { + left.updateMaxGapLocal() + } // This process robs p of one segment, so recurse into rebalancing p. n = p } } +// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no +// necessary update. +// +// Preconditions: n must be a leaf node, trackGaps must be 1. +func (n *node) updateMaxGapLeaf() { + if n.hasChildren { + panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n)) + } + max := n.calculateMaxGapLeaf() + if max == n.maxGap.Get() { + // If new max equals the old maxGap, no update is needed. + return + } + oldMax := n.maxGap.Get() + n.maxGap.Set(max) + if max > oldMax { + // Grow ancestor maxGaps. + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() >= max { + // p and its ancestors already contain an equal or larger gap. + break + } + // Only if new maxGap is larger than parent's + // old maxGap, propagate this update to parent. + p.maxGap.Set(max) + } + return + } + // Shrink ancestor maxGaps. + for p := n.parent; p != nil; p = p.parent { + if p.maxGap.Get() > oldMax { + // p and its ancestors still contain a larger gap. + break + } + // If new max is smaller than the old maxGap, and this gap used + // to be the maxGap of its parent, iterate parent's children + // and calculate parent's new maxGap.(It's probable that parent + // has two children with the old maxGap, but we need to check it anyway.) + parentNewMax := p.calculateMaxGapInternal() + if p.maxGap.Get() == parentNewMax { + // p and its ancestors still contain a gap of at least equal size. + break + } + // If p's new maxGap differs from the old one, propagate this update. + p.maxGap.Set(parentNewMax) + } +} + +// updateMaxGapLocal updates maxGap of the calling node solely with no +// propagation to ancestor nodes. +// +// Precondition: trackGaps must be 1. +func (n *node) updateMaxGapLocal() { + if !n.hasChildren { + // Leaf node iterates its gaps. + n.maxGap.Set(n.calculateMaxGapLeaf()) + } else { + // Non-leaf node iterates its children. + n.maxGap.Set(n.calculateMaxGapInternal()) + } +} + +// calculateMaxGapLeaf iterates the gaps within a leaf node and calculate the +// max. +// +// Preconditions: n must be a leaf node. +func (n *node) calculateMaxGapLeaf() Key { + max := GapIterator{n, 0}.Range().Length() + for i := 1; i <= n.nrSegments; i++ { + if current := (GapIterator{n, i}).Range().Length(); current > max { + max = current + } + } + return max +} + +// calculateMaxGapInternal iterates children's maxGap within an internal node n +// and calculate the max. +// +// Preconditions: n must be a non-leaf node. +func (n *node) calculateMaxGapInternal() Key { + max := n.children[0].maxGap.Get() + for i := 1; i <= n.nrSegments; i++ { + if current := n.children[i].maxGap.Get(); current > max { + max = current + } + } + return max +} + +// searchFirstLargeEnoughGap returns the first gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *node) searchFirstLargeEnoughGap(minSize Key) GapIterator { + if n.maxGap.Get() < minSize { + return GapIterator{} + } + if n.hasChildren { + for i := 0; i <= n.nrSegments; i++ { + if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := 0; i <= n.nrSegments; i++ { + currentGap := GapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + +// searchLastLargeEnoughGap returns the last gap having at least minSize length +// in the subtree rooted by n. If not found, return a terminal gap iterator. +func (n *node) searchLastLargeEnoughGap(minSize Key) GapIterator { + if n.maxGap.Get() < minSize { + return GapIterator{} + } + if n.hasChildren { + for i := n.nrSegments; i >= 0; i-- { + if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } + } else { + for i := n.nrSegments; i >= 0; i-- { + currentGap := GapIterator{n, i} + if currentGap.Range().Length() >= minSize { + return currentGap + } + } + } + panic(fmt.Sprintf("invalid maxGap in %v", n)) +} + // A Iterator is conceptually one of: // // - A pointer to a segment in a set; or @@ -1243,6 +1469,122 @@ func (gap GapIterator) NextGap() GapIterator { return seg.NextGap() } +// NextLargeEnoughGap returns the iterated gap's first next gap with larger +// length than minSize. If not found, return a terminal gap iterator (does NOT +// include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap GapIterator) NextLargeEnoughGap(minSize Key) GapIterator { + if trackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments { + // If gap is the trailing gap of an non-leaf node, + // translate it to the equivalent gap on leaf level. + gap.node = gap.NextSegment().node + gap.index = 0 + return gap.nextLargeEnoughGapHelper(minSize) + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the trailing gap of a non-leaf node. +func (gap GapIterator) nextLargeEnoughGapHelper(minSize Key) GapIterator { + // Crawl up the tree if no large enough gap in current node or the + // current gap is the trailing one on leaf level. + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + // If no large enough gap throughout the whole set, return a terminal + // gap iterator. + if gap.node == nil { + return GapIterator{} + } + // Iterate subsequent gaps. + gap.index++ + for gap.index <= gap.node.nrSegments { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index++ + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == gap.node.nrSegments { + // If gap is the trailing gap of a non-leaf node, crawl up to + // parent again and do recursion. + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.nextLargeEnoughGapHelper(minSize) +} + +// PrevLargeEnoughGap returns the iterated gap's first prev gap with larger or +// equal length than minSize. If not found, return a terminal gap iterator +// (does NOT include this gap itself). +// +// Precondition: trackGaps must be 1. +func (gap GapIterator) PrevLargeEnoughGap(minSize Key) GapIterator { + if trackGaps != 1 { + panic("set is not tracking gaps") + } + if gap.node != nil && gap.node.hasChildren && gap.index == 0 { + // If gap is the first gap of an non-leaf node, + // translate it to the equivalent gap on leaf level. + gap.node = gap.PrevSegment().node + gap.index = gap.node.nrSegments + return gap.prevLargeEnoughGapHelper(minSize) + } + return gap.prevLargeEnoughGapHelper(minSize) +} + +// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap +// to do the real recursions. +// +// Preconditions: gap is NOT the first gap of a non-leaf node. +func (gap GapIterator) prevLargeEnoughGapHelper(minSize Key) GapIterator { + // Crawl up the tree if no large enough gap in current node or the + // current gap is the first one on leaf level. + for gap.node != nil && + (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) { + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + // If no large enough gap throughout the whole set, return a terminal + // gap iterator. + if gap.node == nil { + return GapIterator{} + } + // Iterate previous gaps. + gap.index-- + for gap.index >= 0 { + if gap.node.hasChildren { + if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() { + return largeEnoughGap + } + } else { + if gap.Range().Length() >= minSize { + return gap + } + } + gap.index-- + } + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + if gap.node != nil && gap.index == 0 { + // If gap is the first gap of a non-leaf node, crawl up to + // parent again and do recursion. + gap.node, gap.index = gap.node.parent, gap.node.parentIndex + } + return gap.prevLargeEnoughGapHelper(minSize) +} + // segmentBeforePosition returns the predecessor segment of the position given // by n.children[i], which may or may not contain a child. If no such segment // exists, segmentBeforePosition returns a terminal iterator. @@ -1271,7 +1613,7 @@ func segmentAfterPosition(n *node, i int) Iterator { func zeroValueSlice(slice []Value) { // TODO(jamieliu): check if Go is actually smart enough to optimize a - // ClearValue that assigns nil to a memset here + // ClearValue that assigns nil to a memset here. for i := range slice { Functions{}.ClearValue(&slice[i]) } @@ -1310,7 +1652,15 @@ func (n *node) writeDebugString(buf *bytes.Buffer, prefix string) { child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i)) } buf.WriteString(prefix) - buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + if n.hasChildren { + if trackGaps != 0 { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get())) + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } + } else { + buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i])) + } } if child := n.children[n.nrSegments]; child != nil { child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments)) @@ -1362,3 +1712,43 @@ func (s *Set) ImportSortedSlices(sds *SegmentDataSlices) error { } return nil } + +// segmentTestCheck returns an error if s is incorrectly sorted, does not +// contain exactly expectedSegments segments, or contains a segment which +// fails the passed check. +// +// This should be used only for testing, and has been added to this package for +// templating convenience. +func (s *Set) segmentTestCheck(expectedSegments int, segFunc func(int, Range, Value) error) error { + havePrev := false + prev := Key(0) + nrSegments := 0 + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + next := seg.Start() + if havePrev && prev >= next { + return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) + } + if segFunc != nil { + if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil { + return err + } + } + prev = next + havePrev = true + nrSegments++ + } + if nrSegments != expectedSegments { + return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments) + } + return nil +} + +// countSegments counts the number of segments in the set. +// +// Similar to Check, this should only be used for testing. +func (s *Set) countSegments() (segments int) { + for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + segments++ + } + return segments +} diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD index f2d8462d8..131bf09b9 100644 --- a/pkg/segment/test/BUILD +++ b/pkg/segment/test/BUILD @@ -29,10 +29,28 @@ go_template_instance( }, ) +go_template_instance( + name = "gap_set", + out = "gap_set.go", + consts = { + "trackGaps": "1", + }, + package = "segment", + prefix = "gap", + template = "//pkg/segment:generic_set", + types = { + "Key": "int", + "Range": "Range", + "Value": "int", + "Functions": "gapSetFunctions", + }, +) + go_library( name = "segment", testonly = 1, srcs = [ + "gap_set.go", "int_range.go", "int_set.go", "set_functions.go", diff --git a/pkg/segment/test/segment_test.go b/pkg/segment/test/segment_test.go index 97b16c158..85fa19096 100644 --- a/pkg/segment/test/segment_test.go +++ b/pkg/segment/test/segment_test.go @@ -17,6 +17,7 @@ package segment import ( "fmt" "math/rand" + "reflect" "testing" ) @@ -32,61 +33,65 @@ const ( // valueOffset is the difference between the value and start of test // segments. valueOffset = 100000 + + // intervalLength is the interval used by random gap tests. + intervalLength = 10 ) func shuffle(xs []int) { - for i := range xs { - j := rand.Intn(i + 1) - xs[i], xs[j] = xs[j], xs[i] - } + rand.Shuffle(len(xs), func(i, j int) { xs[i], xs[j] = xs[j], xs[i] }) } -func randPermutation(size int) []int { +func randIntervalPermutation(size int) []int { p := make([]int, size) for i := range p { - p[i] = i + p[i] = intervalLength * i } shuffle(p) return p } -// checkSet returns an error if s is incorrectly sorted, does not contain -// exactly expectedSegments segments, or contains a segment for which val != -// key + valueOffset. -func checkSet(s *Set, expectedSegments int) error { - havePrev := false - prev := 0 - nrSegments := 0 - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - next := seg.Start() - if havePrev && prev >= next { - return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) - } - if got, want := seg.Value(), seg.Start()+valueOffset; got != want { - return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nrSegments, seg.Start(), got, want) - } - prev = next - havePrev = true - nrSegments++ - } - if nrSegments != expectedSegments { - return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments) +// validate can be passed to Check. +func validate(nr int, r Range, v int) error { + if got, want := v, r.Start+valueOffset; got != want { + return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nr, r.Start, got, want) } return nil } -// countSegmentsIn returns the number of segments in s. -func countSegmentsIn(s *Set) int { - var count int - for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - count++ +// checkSetMaxGap returns an error if maxGap inside all nodes of s is not well +// maintained. +func checkSetMaxGap(s *gapSet) error { + n := s.root + return checkNodeMaxGap(&n) +} + +// checkNodeMaxGap returns an error if maxGap inside the subtree rooted by n is +// not well maintained. +func checkNodeMaxGap(n *gapnode) error { + var max int + if !n.hasChildren { + max = n.calculateMaxGapLeaf() + } else { + for i := 0; i <= n.nrSegments; i++ { + child := n.children[i] + if err := checkNodeMaxGap(child); err != nil { + return err + } + if temp := child.maxGap.Get(); i == 0 || temp > max { + max = temp + } + } + } + if max != n.maxGap.Get() { + return fmt.Errorf("maxGap wrong in node\n%vexpected: %d got: %d", n, max, n.maxGap) } - return count + return nil } func TestAddRandom(t *testing.T) { var s Set - order := randPermutation(testSize) + order := rand.Perm(testSize) var nrInsertions int for i, j := range order { if !s.AddWithoutMerging(Range{j, j + 1}, j+valueOffset) { @@ -94,12 +99,12 @@ func TestAddRandom(t *testing.T) { break } nrInsertions++ - if err := checkSet(&s, nrInsertions); err != nil { + if err := s.segmentTestCheck(nrInsertions, validate); err != nil { t.Errorf("Iteration %d: %v", i, err) break } } - if got, want := countSegmentsIn(&s), nrInsertions; got != want { + if got, want := s.countSegments(), nrInsertions; got != want { t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) } if t.Failed() { @@ -115,7 +120,156 @@ func TestRemoveRandom(t *testing.T) { t.Fatalf("Failed to insert segment %d", i) } } - order := randPermutation(testSize) + order := rand.Perm(testSize) + var nrRemovals int + for i, j := range order { + seg := s.FindSegment(j) + if !seg.Ok() { + t.Errorf("Iteration %d: failed to find segment with key %d", i, j) + break + } + s.Remove(seg) + nrRemovals++ + if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil { + t.Errorf("Iteration %d: %v", i, err) + break + } + } + if got, want := s.countSegments(), testSize-nrRemovals; got != want { + t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) + } + if t.Failed() { + t.Logf("Removal order: %v", order[:nrRemovals]) + t.Logf("Set contents:\n%v", &s) + t.FailNow() + } +} + +func TestMaxGapAddRandom(t *testing.T) { + var s gapSet + order := rand.Perm(testSize) + var nrInsertions int + for i, j := range order { + if !s.AddWithoutMerging(Range{j, j + 1}, j+valueOffset) { + t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) + break + } + nrInsertions++ + if err := s.segmentTestCheck(nrInsertions, validate); err != nil { + t.Errorf("Iteration %d: %v", i, err) + break + } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When inserting %d: %v", j, err) + break + } + } + if got, want := s.countSegments(), nrInsertions; got != want { + t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) + } + if t.Failed() { + t.Logf("Insertion order: %v", order[:nrInsertions]) + t.Logf("Set contents:\n%v", &s) + } +} + +func TestMaxGapAddRandomWithRandomInterval(t *testing.T) { + var s gapSet + order := randIntervalPermutation(testSize) + var nrInsertions int + for i, j := range order { + if !s.AddWithoutMerging(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) { + t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) + break + } + nrInsertions++ + if err := s.segmentTestCheck(nrInsertions, validate); err != nil { + t.Errorf("Iteration %d: %v", i, err) + break + } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When inserting %d: %v", j, err) + break + } + } + if got, want := s.countSegments(), nrInsertions; got != want { + t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) + } + if t.Failed() { + t.Logf("Insertion order: %v", order[:nrInsertions]) + t.Logf("Set contents:\n%v", &s) + } +} + +func TestMaxGapAddRandomWithMerge(t *testing.T) { + var s gapSet + order := randIntervalPermutation(testSize) + nrInsertions := 1 + for i, j := range order { + if !s.Add(Range{j, j + intervalLength}, j+valueOffset) { + t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) + break + } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When inserting %d: %v", j, err) + break + } + } + if got, want := s.countSegments(), nrInsertions; got != want { + t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) + } + if t.Failed() { + t.Logf("Insertion order: %v", order) + t.Logf("Set contents:\n%v", &s) + } +} + +func TestMaxGapRemoveRandom(t *testing.T) { + var s gapSet + for i := 0; i < testSize; i++ { + if !s.AddWithoutMerging(Range{i, i + 1}, i+valueOffset) { + t.Fatalf("Failed to insert segment %d", i) + } + } + order := rand.Perm(testSize) + var nrRemovals int + for i, j := range order { + seg := s.FindSegment(j) + if !seg.Ok() { + t.Errorf("Iteration %d: failed to find segment with key %d", i, j) + break + } + temprange := seg.Range() + s.Remove(seg) + nrRemovals++ + if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil { + t.Errorf("Iteration %d: %v", i, err) + break + } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When removing %v: %v", temprange, err) + break + } + } + if got, want := s.countSegments(), testSize-nrRemovals; got != want { + t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) + } + if t.Failed() { + t.Logf("Removal order: %v", order[:nrRemovals]) + t.Logf("Set contents:\n%v", &s) + t.FailNow() + } +} + +func TestMaxGapRemoveHalfRandom(t *testing.T) { + var s gapSet + for i := 0; i < testSize; i++ { + if !s.AddWithoutMerging(Range{intervalLength * i, intervalLength*i + rand.Intn(intervalLength-1) + 1}, intervalLength*i+valueOffset) { + t.Fatalf("Failed to insert segment %d", i) + } + } + order := randIntervalPermutation(testSize) + order = order[:testSize/2] var nrRemovals int for i, j := range order { seg := s.FindSegment(j) @@ -123,14 +277,19 @@ func TestRemoveRandom(t *testing.T) { t.Errorf("Iteration %d: failed to find segment with key %d", i, j) break } + temprange := seg.Range() s.Remove(seg) nrRemovals++ - if err := checkSet(&s, testSize-nrRemovals); err != nil { + if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil { t.Errorf("Iteration %d: %v", i, err) break } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When removing %v: %v", temprange, err) + break + } } - if got, want := countSegmentsIn(&s), testSize-nrRemovals; got != want { + if got, want := s.countSegments(), testSize-nrRemovals; got != want { t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) } if t.Failed() { @@ -140,6 +299,148 @@ func TestRemoveRandom(t *testing.T) { } } +func TestMaxGapAddRandomRemoveRandomHalfWithMerge(t *testing.T) { + var s gapSet + order := randIntervalPermutation(testSize * 2) + order = order[:testSize] + for i, j := range order { + if !s.Add(Range{j, j + intervalLength}, j+valueOffset) { + t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) + break + } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When inserting %d: %v", j, err) + break + } + } + shuffle(order) + var nrRemovals int + for _, j := range order { + seg := s.FindSegment(j) + if !seg.Ok() { + continue + } + temprange := seg.Range() + s.Remove(seg) + nrRemovals++ + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When removing %v: %v", temprange, err) + break + } + } + if t.Failed() { + t.Logf("Removal order: %v", order[:nrRemovals]) + t.Logf("Set contents:\n%v", &s) + t.FailNow() + } +} + +func TestNextLargeEnoughGap(t *testing.T) { + var s gapSet + order := randIntervalPermutation(testSize * 2) + order = order[:testSize] + for i, j := range order { + if !s.Add(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) { + t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) + break + } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When inserting %d: %v", j, err) + break + } + } + shuffle(order) + order = order[:testSize/2] + for _, j := range order { + seg := s.FindSegment(j) + if !seg.Ok() { + continue + } + temprange := seg.Range() + s.Remove(seg) + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When removing %v: %v", temprange, err) + break + } + } + minSize := 7 + var gapArr1 []int + for gap := s.LowerBoundGap(0).NextLargeEnoughGap(minSize); gap.Ok(); gap = gap.NextLargeEnoughGap(minSize) { + if gap.Range().Length() < minSize { + t.Errorf("NextLargeEnoughGap wrong, gap %v has length %d, wanted %d", gap.Range(), gap.Range().Length(), minSize) + } else { + gapArr1 = append(gapArr1, gap.Range().Start) + } + } + var gapArr2 []int + for gap := s.LowerBoundGap(0).NextGap(); gap.Ok(); gap = gap.NextGap() { + if gap.Range().Length() >= minSize { + gapArr2 = append(gapArr2, gap.Range().Start) + } + } + + if !reflect.DeepEqual(gapArr2, gapArr1) { + t.Errorf("Search result not correct, got: %v, wanted: %v", gapArr1, gapArr2) + } + if t.Failed() { + t.Logf("Set contents:\n%v", &s) + t.FailNow() + } +} + +func TestPrevLargeEnoughGap(t *testing.T) { + var s gapSet + order := randIntervalPermutation(testSize * 2) + order = order[:testSize] + for i, j := range order { + if !s.Add(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) { + t.Errorf("Iteration %d: failed to insert segment with key %d", i, j) + break + } + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When inserting %d: %v", j, err) + break + } + } + end := s.LastSegment().End() + shuffle(order) + order = order[:testSize/2] + for _, j := range order { + seg := s.FindSegment(j) + if !seg.Ok() { + continue + } + temprange := seg.Range() + s.Remove(seg) + if err := checkSetMaxGap(&s); err != nil { + t.Errorf("When removing %v: %v", temprange, err) + break + } + } + minSize := 7 + var gapArr1 []int + for gap := s.UpperBoundGap(end + intervalLength).PrevLargeEnoughGap(minSize); gap.Ok(); gap = gap.PrevLargeEnoughGap(minSize) { + if gap.Range().Length() < minSize { + t.Errorf("PrevLargeEnoughGap wrong, gap length %d, wanted %d", gap.Range().Length(), minSize) + } else { + gapArr1 = append(gapArr1, gap.Range().Start) + } + } + var gapArr2 []int + for gap := s.UpperBoundGap(end + intervalLength).PrevGap(); gap.Ok(); gap = gap.PrevGap() { + if gap.Range().Length() >= minSize { + gapArr2 = append(gapArr2, gap.Range().Start) + } + } + if !reflect.DeepEqual(gapArr2, gapArr1) { + t.Errorf("Search result not correct, got: %v, wanted: %v", gapArr1, gapArr2) + } + if t.Failed() { + t.Logf("Set contents:\n%v", &s) + t.FailNow() + } +} + func TestAddSequentialAdjacent(t *testing.T) { var s Set var nrInsertions int @@ -148,12 +449,12 @@ func TestAddSequentialAdjacent(t *testing.T) { t.Fatalf("Failed to insert segment %d", i) } nrInsertions++ - if err := checkSet(&s, nrInsertions); err != nil { + if err := s.segmentTestCheck(nrInsertions, validate); err != nil { t.Errorf("Iteration %d: %v", i, err) break } } - if got, want := countSegmentsIn(&s), nrInsertions; got != want { + if got, want := s.countSegments(), nrInsertions; got != want { t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) } if t.Failed() { @@ -202,12 +503,12 @@ func TestAddSequentialNonAdjacent(t *testing.T) { t.Fatalf("Failed to insert segment %d", i) } nrInsertions++ - if err := checkSet(&s, nrInsertions); err != nil { + if err := s.segmentTestCheck(nrInsertions, validate); err != nil { t.Errorf("Iteration %d: %v", i, err) break } } - if got, want := countSegmentsIn(&s), nrInsertions; got != want { + if got, want := s.countSegments(), nrInsertions; got != want { t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want) } if t.Failed() { @@ -293,7 +594,7 @@ Tests: var i int for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { if i > len(test.final) { - t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, countSegmentsIn(&s), len(test.final), &s) + t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, s.countSegments(), len(test.final), &s) continue Tests } if got, want := seg.Range(), test.final[i]; got != want { @@ -351,7 +652,7 @@ Tests: var i int for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { if i > len(test.final) { - t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, countSegmentsIn(&s), len(test.final), &s) + t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, s.countSegments(), len(test.final), &s) continue Tests } if got, want := seg.Range(), test.final[i]; got != want { @@ -378,7 +679,7 @@ func benchmarkAddSequential(b *testing.B, size int) { } func benchmarkAddRandom(b *testing.B, size int) { - order := randPermutation(size) + order := rand.Perm(size) b.ResetTimer() for n := 0; n < b.N; n++ { @@ -416,7 +717,7 @@ func benchmarkFindRandom(b *testing.B, size int) { b.Fatalf("Failed to insert segment %d", i) } } - order := randPermutation(size) + order := rand.Perm(size) b.ResetTimer() for n := 0; n < b.N; n++ { @@ -470,7 +771,7 @@ func benchmarkAddFindRemoveSequential(b *testing.B, size int) { } func benchmarkAddFindRemoveRandom(b *testing.B, size int) { - order := randPermutation(size) + order := rand.Perm(size) b.ResetTimer() for n := 0; n < b.N; n++ { diff --git a/pkg/segment/test/set_functions.go b/pkg/segment/test/set_functions.go index bcddb39bb..7cd895cc7 100644 --- a/pkg/segment/test/set_functions.go +++ b/pkg/segment/test/set_functions.go @@ -14,21 +14,16 @@ package segment -// Basic numeric constants that we define because the math package doesn't. -// TODO(nlacasse): These should be Math.MaxInt64/MinInt64? -const ( - maxInt = int(^uint(0) >> 1) - minInt = -maxInt - 1 -) - type setFunctions struct{} -func (setFunctions) MinKey() int { - return minInt +// MinKey returns the minimum key for the set. +func (s setFunctions) MinKey() int { + return -s.MaxKey() - 1 } +// MaxKey returns the maximum key for the set. func (setFunctions) MaxKey() int { - return maxInt + return int(^uint(0) >> 1) } func (setFunctions) ClearValue(*int) {} @@ -40,3 +35,20 @@ func (setFunctions) Merge(_ Range, val1 int, _ Range, _ int) (int, bool) { func (setFunctions) Split(_ Range, val int, _ int) (int, int) { return val, val } + +type gapSetFunctions struct { + setFunctions +} + +// MinKey is adjusted to make sure no add overflow would happen in test cases. +// e.g. A gap with range {MinInt32, 2} would cause overflow in Range().Length(). +// +// Normally Keys should be unsigned to avoid these issues. +func (s gapSetFunctions) MinKey() int { + return s.setFunctions.MinKey() / 2 +} + +// MaxKey returns the maximum key for the set. +func (s gapSetFunctions) MaxKey() int { + return s.setFunctions.MaxKey() / 2 +} diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD index e74275d2d..2c5d14be5 100644 --- a/pkg/sentry/control/BUILD +++ b/pkg/sentry/control/BUILD @@ -16,14 +16,12 @@ go_library( ], deps = [ "//pkg/abi/linux", - "//pkg/context", "//pkg/fd", - "//pkg/fspath", "//pkg/log", "//pkg/sentry/fdimport", "//pkg/sentry/fs", "//pkg/sentry/fs/host", - "//pkg/sentry/fsbridge", + "//pkg/sentry/fs/user", "//pkg/sentry/fsimpl/host", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", @@ -35,7 +33,6 @@ go_library( "//pkg/sentry/vfs", "//pkg/sentry/watchdog", "//pkg/sync", - "//pkg/syserror", "//pkg/tcpip/link/sniffer", "//pkg/urpc", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go index 2ed17ee09..1bae7cfaf 100644 --- a/pkg/sentry/control/proc.go +++ b/pkg/sentry/control/proc.go @@ -18,7 +18,6 @@ import ( "bytes" "encoding/json" "fmt" - "path" "sort" "strings" "text/tabwriter" @@ -26,13 +25,10 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/fspath" - "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fdimport" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/host" - "gvisor.dev/gvisor/pkg/sentry/fsbridge" + "gvisor.dev/gvisor/pkg/sentry/fs/user" hostvfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/host" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -40,7 +36,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/urpc" ) @@ -108,6 +103,9 @@ type ExecArgs struct { // String prints the arguments as a string. func (args ExecArgs) String() string { + if len(args.Argv) == 0 { + return args.Filename + } a := make([]string, len(args.Argv)) copy(a, args.Argv) if args.Filename != "" { @@ -180,42 +178,30 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI } ctx := initArgs.NewContext(proc.Kernel) - if initArgs.Filename == "" { - if kernel.VFS2Enabled { - // Get the full path to the filename from the PATH env variable. - if initArgs.MountNamespaceVFS2 == nil { - // Set initArgs so that 'ctx' returns the namespace. - // - // MountNamespaceVFS2 adds a reference to the namespace, which is - // transferred to the new process. - initArgs.MountNamespaceVFS2 = proc.Kernel.GlobalInit().Leader().MountNamespaceVFS2() - } + if kernel.VFS2Enabled { + // Get the full path to the filename from the PATH env variable. + if initArgs.MountNamespaceVFS2 == nil { + // Set initArgs so that 'ctx' returns the namespace. + // + // MountNamespaceVFS2 adds a reference to the namespace, which is + // transferred to the new process. + initArgs.MountNamespaceVFS2 = proc.Kernel.GlobalInit().Leader().MountNamespaceVFS2() + } + } else { + if initArgs.MountNamespace == nil { + // Set initArgs so that 'ctx' returns the namespace. + initArgs.MountNamespace = proc.Kernel.GlobalInit().Leader().MountNamespace() - paths := fs.GetPath(initArgs.Envv) - vfsObj := proc.Kernel.VFS() - file, err := ResolveExecutablePath(ctx, vfsObj, initArgs.WorkingDirectory, initArgs.Argv[0], paths) - if err != nil { - return nil, 0, nil, nil, fmt.Errorf("error finding executable %q in PATH %v: %v", initArgs.Argv[0], paths, err) - } - initArgs.File = fsbridge.NewVFSFile(file) - } else { - // Get the full path to the filename from the PATH env variable. - paths := fs.GetPath(initArgs.Envv) - if initArgs.MountNamespace == nil { - // Set initArgs so that 'ctx' returns the namespace. - initArgs.MountNamespace = proc.Kernel.GlobalInit().Leader().MountNamespace() - - // initArgs must hold a reference on MountNamespace, which will - // be donated to the new process in CreateProcess. - initArgs.MountNamespace.IncRef() - } - f, err := initArgs.MountNamespace.ResolveExecutablePath(ctx, initArgs.WorkingDirectory, initArgs.Argv[0], paths) - if err != nil { - return nil, 0, nil, nil, fmt.Errorf("error finding executable %q in PATH %v: %v", initArgs.Argv[0], paths, err) - } - initArgs.Filename = f + // initArgs must hold a reference on MountNamespace, which will + // be donated to the new process in CreateProcess. + initArgs.MountNamespace.IncRef() } } + resolved, err := user.ResolveExecutablePath(ctx, &initArgs) + if err != nil { + return nil, 0, nil, nil, err + } + initArgs.Filename = resolved fds := make([]int, len(args.FilePayload.Files)) for i, file := range args.FilePayload.Files { @@ -428,67 +414,3 @@ func ttyName(tty *kernel.TTY) string { } return fmt.Sprintf("pts/%d", tty.Index) } - -// ResolveExecutablePath resolves the given executable name given a set of -// paths that might contain it. -func ResolveExecutablePath(ctx context.Context, vfsObj *vfs.VirtualFilesystem, wd, name string, paths []string) (*vfs.FileDescription, error) { - root := vfs.RootFromContext(ctx) - defer root.DecRef() - creds := auth.CredentialsFromContext(ctx) - - // Absolute paths can be used directly. - if path.IsAbs(name) { - return openExecutable(ctx, vfsObj, creds, root, name) - } - - // Paths with '/' in them should be joined to the working directory, or - // to the root if working directory is not set. - if strings.IndexByte(name, '/') > 0 { - if len(wd) == 0 { - wd = "/" - } - if !path.IsAbs(wd) { - return nil, fmt.Errorf("working directory %q must be absolute", wd) - } - return openExecutable(ctx, vfsObj, creds, root, path.Join(wd, name)) - } - - // Otherwise, we must lookup the name in the paths, starting from the - // calling context's root directory. - for _, p := range paths { - if !path.IsAbs(p) { - // Relative paths aren't safe, no one should be using them. - log.Warningf("Skipping relative path %q in $PATH", p) - continue - } - - binPath := path.Join(p, name) - f, err := openExecutable(ctx, vfsObj, creds, root, binPath) - if err != nil { - return nil, err - } - if f == nil { - continue // Not found/no access. - } - return f, nil - } - return nil, syserror.ENOENT -} - -func openExecutable(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, root vfs.VirtualDentry, path string) (*vfs.FileDescription, error) { - pop := vfs.PathOperation{ - Root: root, - Start: root, // binPath is absolute, Start can be anything. - Path: fspath.Parse(path), - FollowFinalSymlink: true, - } - opts := &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - FileExec: true, - } - f, err := vfsObj.OpenAt(ctx, creds, &pop, opts) - if err == syserror.ENOENT || err == syserror.EACCES { - return nil, nil - } - return f, err -} diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index 846252c89..2a278fbe3 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -310,7 +310,6 @@ func (f *File) Writev(ctx context.Context, src usermem.IOSequence) (int64, error if !f.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - unlockAppendMu := f.Dirent.Inode.lockAppendMu(f.Flags().Append) // Handle append mode. if f.Flags().Append { @@ -355,7 +354,6 @@ func (f *File) Pwritev(ctx context.Context, src usermem.IOSequence, offset int64 // offset." unlockAppendMu := f.Dirent.Inode.lockAppendMu(f.Flags().Append) defer unlockAppendMu() - if f.Flags().Append { if err := f.offsetForAppend(ctx, &offset); err != nil { return 0, err @@ -374,9 +372,10 @@ func (f *File) Pwritev(ctx context.Context, src usermem.IOSequence, offset int64 return f.FileOperations.Write(ctx, f, src, offset) } -// offsetForAppend sets the given offset to the end of the file. +// offsetForAppend atomically sets the given offset to the end of the file. // -// Precondition: the file.Dirent.Inode.appendMu mutex should be held for writing. +// Precondition: the file.Dirent.Inode.appendMu mutex should be held for +// writing. func (f *File) offsetForAppend(ctx context.Context, offset *int64) error { uattr, err := f.Dirent.Inode.UnstableAttr(ctx) if err != nil { @@ -386,7 +385,7 @@ func (f *File) offsetForAppend(ctx context.Context, offset *int64) error { } // Update the offset. - *offset = uattr.Size + atomic.StoreInt64(offset, uattr.Size) return nil } diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go index bdba6efe5..d2dbff268 100644 --- a/pkg/sentry/fs/fs.go +++ b/pkg/sentry/fs/fs.go @@ -42,9 +42,10 @@ // Dirent.dirMu // Dirent.mu // DirentCache.mu -// Locks in InodeOperations implementations or overlayEntry // Inode.Watches.mu (see `Inotify` for other lock ordering) // MountSource.mu +// Inode.appendMu +// Locks in InodeOperations implementations or overlayEntry // // If multiple Dirent or MountSource locks must be taken, locks in the parent must be // taken before locks in their children. diff --git a/pkg/sentry/fs/fsutil/frame_ref_set.go b/pkg/sentry/fs/fsutil/frame_ref_set.go index 6564fd0c6..dd6f5aba6 100644 --- a/pkg/sentry/fs/fsutil/frame_ref_set.go +++ b/pkg/sentry/fs/fsutil/frame_ref_set.go @@ -18,6 +18,7 @@ import ( "math" "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/usage" ) // FrameRefSetFunctions implements segment.Functions for FrameRefSet. @@ -49,3 +50,42 @@ func (FrameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform. func (FrameRefSetFunctions) Split(_ platform.FileRange, val uint64, _ uint64) (uint64, uint64) { return val, val } + +// IncRefAndAccount adds a reference on the range fr. All newly inserted segments +// are accounted as host page cache memory mappings. +func (refs *FrameRefSet) IncRefAndAccount(fr platform.FileRange) { + seg, gap := refs.Find(fr.Start) + for { + switch { + case seg.Ok() && seg.Start() < fr.End: + seg = refs.Isolate(seg, fr) + seg.SetValue(seg.Value() + 1) + seg, gap = seg.NextNonEmpty() + case gap.Ok() && gap.Start() < fr.End: + newRange := gap.Range().Intersect(fr) + usage.MemoryAccounting.Inc(newRange.Length(), usage.Mapped) + seg, gap = refs.InsertWithoutMerging(gap, newRange, 1).NextNonEmpty() + default: + refs.MergeAdjacent(fr) + return + } + } +} + +// DecRefAndAccount removes a reference on the range fr and untracks segments +// that are removed from memory accounting. +func (refs *FrameRefSet) DecRefAndAccount(fr platform.FileRange) { + seg := refs.FindSegment(fr.Start) + + for seg.Ok() && seg.Start() < fr.End { + seg = refs.Isolate(seg, fr) + if old := seg.Value(); old == 1 { + usage.MemoryAccounting.Dec(seg.Range().Length(), usage.Mapped) + seg = refs.Remove(seg).NextSegment() + } else { + seg.SetValue(old - 1) + seg = seg.NextSegment() + } + } + refs.MergeAdjacent(fr) +} diff --git a/pkg/sentry/fs/g3doc/.gitignore b/pkg/sentry/fs/g3doc/.gitignore new file mode 100644 index 000000000..2d19fc766 --- /dev/null +++ b/pkg/sentry/fs/g3doc/.gitignore @@ -0,0 +1 @@ +*.html diff --git a/pkg/sentry/fs/g3doc/fuse.md b/pkg/sentry/fs/g3doc/fuse.md new file mode 100644 index 000000000..2ca84dd74 --- /dev/null +++ b/pkg/sentry/fs/g3doc/fuse.md @@ -0,0 +1,263 @@ +# Foreword + +This document describes an on-going project to support FUSE filesystems within +the sentry. This is intended to become the final documentation for this +subsystem, and is therefore written in the past tense. However FUSE support is +currently incomplete and the document will be updated as things progress. + +# FUSE: Filesystem in Userspace + +The sentry supports dispatching filesystem operations to a FUSE server, allowing +FUSE filesystem to be used with a sandbox. + +## Overview + +FUSE has two main components: + +1. A client kernel driver (canonically `fuse.ko` in Linux), which forwards + filesystem operations (usually initiated by syscalls) to the server. + +2. A server, which is a userspace daemon that implements the actual filesystem. + +The sentry implements the client component, which allows a server daemon running +within the sandbox to implement a filesystem within the sandbox. + +A FUSE filesystem is initialized with `mount(2)`, typically with the help of a +utility like `fusermount(1)`. Various mount options exist for establishing +ownership and access permissions on the filesystem, but the most important mount +option is a file descriptor used to establish communication between the client +and server. + +The FUSE device FD is obtained by opening `/dev/fuse`. During regular operation, +the client and server use the FUSE protocol described in `fuse(4)` to service +filesystem operations. See the "Protocol" section below for more information +about this protocol. The core of the sentry support for FUSE is the client-side +implementation of this protocol. + +## FUSE in the Sentry + +The sentry's FUSE client targets VFS2 and has the following components: + +- An implementation of `/dev/fuse`. + +- A VFS2 filesystem for mapping syscalls to FUSE ops. Since we're targeting + VFS2, one point of contention may be the lack of inodes in VFS2. We can + tentatively implement a kernfs-based filesystem to bridge the gap in APIs. + The kernfs base functionality can serve the role of the Linux inode cache + and, the filesystem can map VFS2 syscalls to kernfs inode operations; see + the `kernfs.Inode` interface. + +The FUSE protocol lends itself well to marshaling with `go_marshal`. The various +request and response packets can be defined in the ABI package and converted to +and from the wire format using `go_marshal`. + +### Design Goals + +- While filesystem performance is always important, the sentry's FUSE support + is primarily concerned with compatibility, with performance as a secondary + concern. + +- Avoiding deadlocks from a hung server daemon. + +- Consider the potential for denial of service from a malicious server daemon. + Protecting itself from userspace is already a design goal for the sentry, + but needs additional consideration for FUSE. Normally, an operating system + doesn't rely on userspace to make progress with filesystem operations. Since + this changes with FUSE, it opens up the possibility of creating a chain of + dependencies controlled by userspace, which could affect an entire sandbox. + For example: a FUSE op can block a syscall, which could be holding a + subsystem lock, which can then block another task goroutine. + +### Milestones + +Below are some broad goals to aim for while implementing FUSE in the sentry. +Many FUSE ops can be grouped into broad categories of functionality, and most +ops can be implemented in parallel. + +#### Minimal client that can mount a trivial FUSE filesystem. + +- Implement `/dev/fuse` - a character device used to establish an FD for + communication between the sentry and the server daemon. + +- Implement basic FUSE ops like `FUSE_INIT`, `FUSE_DESTROY`. + +#### Read-only mount with basic file operations + +- Implement the majority of file, directory and file descriptor FUSE ops. For + this milestone, we can skip uncommon or complex operations like mmap, mknod, + file locking, poll, and extended attributes. We can stub these out along + with any ops that modify the filesystem. The exact list of required ops are + to be determined, but the goal is to mount a real filesystem as read-only, + and be able to read contents from the filesystem in the sentry. + +#### Full read-write support + +- Implement the remaining FUSE ops and decide if we can omit rarely used + operations like ioctl. + +# Appendix + +## FUSE Protocol + +The FUSE protocol is a request-response protocol. All requests are initiated by +the client. The wire-format for the protocol is raw C structs serialized to +memory. + +All FUSE requests begin with the following request header: + +```c +struct fuse_in_header { + uint32_t len; // Length of the request, including this header. + uint32_t opcode; // Requested operation. + uint64_t unique; // A unique identifier for this request. + uint64_t nodeid; // ID of the filesystem object being operated on. + uint32_t uid; // UID of the requesting process. + uint32_t gid; // GID of the requesting process. + uint32_t pid; // PID of the requesting process. + uint32_t padding; +}; +``` + +The request is then followed by a payload specific to the `opcode`. + +All responses begin with this response header: + +```c +struct fuse_out_header { + uint32_t len; // Length of the response, including this header. + int32_t error; // Status of the request, 0 if success. + uint64_t unique; // The unique identifier from the corresponding request. +}; +``` + +The response payload also depends on the request `opcode`. If `error != 0`, the +response payload must be empty. + +### Operations + +The following is a list of all FUSE operations used in `fuse_in_header.opcode` +as of Linux v4.4, and a brief description of their purpose. These are defined in +`uapi/linux/fuse.h`. Many of these have a corresponding request and response +payload struct; `fuse(4)` has details for some of these. We also note how these +operations map to the sentry virtual filesystem. + +#### FUSE meta-operations + +These operations are specific to FUSE and don't have a corresponding action in a +generic filesystem. + +- `FUSE_INIT`: This operation initializes a new FUSE filesystem, and is the + first message sent by the client after mount. This is used for version and + feature negotiation. This is related to `mount(2)`. +- `FUSE_DESTROY`: Teardown a FUSE filesystem, related to `unmount(2)`. +- `FUSE_INTERRUPT`: Interrupts an in-flight operation, specified by the + `fuse_in_header.unique` value provided in the corresponding request header. + The client can send at most one of these per request, and will enter an + uninterruptible wait for a reply. The server is expected to reply promptly. +- `FUSE_FORGET`: A hint to the server that server should evict the indicate + node from any caches. This is wired up to `(struct + super_operations).evict_inode` in Linux, which is in turned hooked as the + inode cache shrinker which is typically triggered by system memory pressure. +- `FUSE_BATCH_FORGET`: Batch version of `FUSE_FORGET`. + +#### Filesystem Syscalls + +These FUSE ops map directly to an equivalent filesystem syscall, or family of +syscalls. The relevant syscalls have a similar name to the operation, unless +otherwise noted. + +Node creation: + +- `FUSE_MKNOD` +- `FUSE_MKDIR` +- `FUSE_CREATE`: This is equivalent to `open(2)` and `creat(2)`, which + atomically creates and opens a node. + +Node attributes and extended attributes: + +- `FUSE_GETATTR` +- `FUSE_SETATTR` +- `FUSE_SETXATTR` +- `FUSE_GETXATTR` +- `FUSE_LISTXATTR` +- `FUSE_REMOVEXATTR` + +Node link manipulation: + +- `FUSE_READLINK` +- `FUSE_LINK` +- `FUSE_SYMLINK` +- `FUSE_UNLINK` + +Directory operations: + +- `FUSE_RMDIR` +- `FUSE_RENAME` +- `FUSE_RENAME2` +- `FUSE_OPENDIR`: `open(2)` for directories. +- `FUSE_RELEASEDIR`: `close(2)` for directories. +- `FUSE_READDIR` +- `FUSE_READDIRPLUS` +- `FUSE_FSYNCDIR`: `fsync(2)` for directories. +- `FUSE_LOOKUP`: Establishes a unique identifier for a FS node. This is + reminiscent of `VirtualFilesystem.GetDentryAt` in that it resolves a path + component to a node. However the returned identifier is opaque to the + client. The server must remember this mapping, as this is how the client + will reference the node in the future. + +File operations: + +- `FUSE_OPEN`: `open(2)` for files. +- `FUSE_RELEASE`: `close(2)` for files. +- `FUSE_FSYNC` +- `FUSE_FALLOCATE` +- `FUSE_SETUPMAPPING`: Creates a memory map on a file for `mmap(2)`. +- `FUSE_REMOVEMAPPING`: Removes a memory map for `munmap(2)`. + +File locking: + +- `FUSE_GETLK` +- `FUSE_SETLK` +- `FUSE_SETLKW` +- `FUSE_COPY_FILE_RANGE` + +File descriptor operations: + +- `FUSE_IOCTL` +- `FUSE_POLL` +- `FUSE_LSEEK` + +Filesystem operations: + +- `FUSE_STATFS` + +#### Permissions + +- `FUSE_ACCESS` is used to check if a node is accessible, as part of many + syscall implementations. Maps to `vfs.FilesystemImpl.AccessAt` in the + sentry. + +#### I/O Operations + +These ops are used to read and write file pages. They're used to implement both +I/O syscalls like `read(2)`, `write(2)` and `mmap(2)`. + +- `FUSE_READ` +- `FUSE_WRITE` + +#### Miscellaneous + +- `FUSE_FLUSH`: Used by the client to indicate when a file descriptor is + closed. Distinct from `FUSE_FSYNC`, which corresponds to an `fsync(2)` + syscall from the user. Maps to `vfs.FileDescriptorImpl.Release` in the + sentry. +- `FUSE_BMAP`: Old address space API for block defrag. Probably not needed. +- `FUSE_NOTIFY_REPLY`: [TODO: what does this do?] + +# References + +- [fuse(4) Linux manual page](https://www.man7.org/linux/man-pages/man4/fuse.4.html) +- [Linux kernel FUSE documentation](https://www.kernel.org/doc/html/latest/filesystems/fuse.html) +- [The reference implementation of the Linux FUSE (Filesystem in Userspace) + interface](https://github.com/libfuse/libfuse) +- [The kernel interface of FUSE](https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/include/uapi/linux/fuse.h) diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index a016c896e..51d7368a1 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -640,7 +640,7 @@ func (i *inodeOperations) Allocate(ctx context.Context, inode *fs.Inode, offset, // WriteOut implements fs.InodeOperations.WriteOut. func (i *inodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error { - if !i.session().cachePolicy.cacheUAttrs(inode) { + if inode.MountSource.Flags.ReadOnly || !i.session().cachePolicy.cacheUAttrs(inode) { return nil } diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go index 62f1246aa..fbfba1b58 100644 --- a/pkg/sentry/fs/host/inode.go +++ b/pkg/sentry/fs/host/inode.go @@ -368,6 +368,9 @@ func (i *inodeOperations) Allocate(ctx context.Context, inode *fs.Inode, offset, // WriteOut implements fs.InodeOperations.WriteOut. func (i *inodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error { + if inode.MountSource.Flags.ReadOnly { + return nil + } // Have we been using host kernel metadata caches? if !inode.MountSource.Flags.ForcePageCache || !canMap(inode) { // Then the metadata is already up to date on the host. diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go index b414ddaee..3f2bd0e87 100644 --- a/pkg/sentry/fs/mounts.go +++ b/pkg/sentry/fs/mounts.go @@ -17,13 +17,9 @@ package fs import ( "fmt" "math" - "path" - "strings" "syscall" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sync" @@ -625,71 +621,3 @@ func (mns *MountNamespace) SyncAll(ctx context.Context) { defer mns.mu.Unlock() mns.root.SyncAll(ctx) } - -// ResolveExecutablePath resolves the given executable name given a set of -// paths that might contain it. -func (mns *MountNamespace) ResolveExecutablePath(ctx context.Context, wd, name string, paths []string) (string, error) { - // Absolute paths can be used directly. - if path.IsAbs(name) { - return name, nil - } - - // Paths with '/' in them should be joined to the working directory, or - // to the root if working directory is not set. - if strings.IndexByte(name, '/') > 0 { - if wd == "" { - wd = "/" - } - if !path.IsAbs(wd) { - return "", fmt.Errorf("working directory %q must be absolute", wd) - } - return path.Join(wd, name), nil - } - - // Otherwise, We must lookup the name in the paths, starting from the - // calling context's root directory. - root := RootFromContext(ctx) - if root == nil { - // Caller has no root. Don't bother traversing anything. - return "", syserror.ENOENT - } - defer root.DecRef() - for _, p := range paths { - binPath := path.Join(p, name) - traversals := uint(linux.MaxSymlinkTraversals) - d, err := mns.FindInode(ctx, root, nil, binPath, &traversals) - if err == syserror.ENOENT || err == syserror.EACCES { - // Didn't find it here. - continue - } - if err != nil { - return "", err - } - defer d.DecRef() - - // Check that it is a regular file. - if !IsRegular(d.Inode.StableAttr) { - continue - } - - // Check whether we can read and execute the found file. - if err := d.Inode.CheckPermission(ctx, PermMask{Read: true, Execute: true}); err != nil { - log.Infof("Found executable at %q, but user cannot execute it: %v", binPath, err) - continue - } - return path.Join("/", p, name), nil - } - return "", syserror.ENOENT -} - -// GetPath returns the PATH as a slice of strings given the environment -// variables. -func GetPath(env []string) []string { - const prefix = "PATH=" - for _, e := range env { - if strings.HasPrefix(e, prefix) { - return strings.Split(strings.TrimPrefix(e, prefix), ":") - } - } - return nil -} diff --git a/pkg/sentry/fs/user/BUILD b/pkg/sentry/fs/user/BUILD index f37f979f1..66e949c95 100644 --- a/pkg/sentry/fs/user/BUILD +++ b/pkg/sentry/fs/user/BUILD @@ -4,15 +4,21 @@ package(licenses = ["notice"]) go_library( name = "user", - srcs = ["user.go"], + srcs = [ + "path.go", + "user.go", + ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/context", "//pkg/fspath", + "//pkg/log", "//pkg/sentry/fs", + "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", + "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/fs/user/path.go b/pkg/sentry/fs/user/path.go new file mode 100644 index 000000000..397e96045 --- /dev/null +++ b/pkg/sentry/fs/user/path.go @@ -0,0 +1,170 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package user + +import ( + "fmt" + "path" + "strings" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// ResolveExecutablePath resolves the given executable name given the working +// dir and environment. +func ResolveExecutablePath(ctx context.Context, args *kernel.CreateProcessArgs) (string, error) { + name := args.Filename + if len(name) == 0 { + if len(args.Argv) == 0 { + return "", fmt.Errorf("no filename or command provided") + } + name = args.Argv[0] + } + + // Absolute paths can be used directly. + if path.IsAbs(name) { + return name, nil + } + + // Paths with '/' in them should be joined to the working directory, or + // to the root if working directory is not set. + if strings.IndexByte(name, '/') > 0 { + wd := args.WorkingDirectory + if wd == "" { + wd = "/" + } + if !path.IsAbs(wd) { + return "", fmt.Errorf("working directory %q must be absolute", wd) + } + return path.Join(wd, name), nil + } + + // Otherwise, We must lookup the name in the paths. + paths := getPath(args.Envv) + if kernel.VFS2Enabled { + f, err := resolveVFS2(ctx, args.Credentials, args.MountNamespaceVFS2, paths, name) + if err != nil { + return "", fmt.Errorf("error finding executable %q in PATH %v: %v", name, paths, err) + } + return f, nil + } + + f, err := resolve(ctx, args.MountNamespace, paths, name) + if err != nil { + return "", fmt.Errorf("error finding executable %q in PATH %v: %v", name, paths, err) + } + return f, nil +} + +func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name string) (string, error) { + root := fs.RootFromContext(ctx) + if root == nil { + // Caller has no root. Don't bother traversing anything. + return "", syserror.ENOENT + } + defer root.DecRef() + for _, p := range paths { + if !path.IsAbs(p) { + // Relative paths aren't safe, no one should be using them. + log.Warningf("Skipping relative path %q in $PATH", p) + continue + } + + binPath := path.Join(p, name) + traversals := uint(linux.MaxSymlinkTraversals) + d, err := mns.FindInode(ctx, root, nil, binPath, &traversals) + if err == syserror.ENOENT || err == syserror.EACCES { + // Didn't find it here. + continue + } + if err != nil { + return "", err + } + defer d.DecRef() + + // Check that it is a regular file. + if !fs.IsRegular(d.Inode.StableAttr) { + continue + } + + // Check whether we can read and execute the found file. + if err := d.Inode.CheckPermission(ctx, fs.PermMask{Read: true, Execute: true}); err != nil { + log.Infof("Found executable at %q, but user cannot execute it: %v", binPath, err) + continue + } + return path.Join("/", p, name), nil + } + + // Couldn't find it. + return "", syserror.ENOENT +} + +func resolveVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNamespace, paths []string, name string) (string, error) { + root := mns.Root() + defer root.DecRef() + for _, p := range paths { + if !path.IsAbs(p) { + // Relative paths aren't safe, no one should be using them. + log.Warningf("Skipping relative path %q in $PATH", p) + continue + } + + binPath := path.Join(p, name) + pop := &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(binPath), + FollowFinalSymlink: true, + } + opts := &vfs.OpenOptions{ + FileExec: true, + Flags: linux.O_RDONLY, + } + dentry, err := root.Mount().Filesystem().VirtualFilesystem().OpenAt(ctx, creds, pop, opts) + if err == syserror.ENOENT || err == syserror.EACCES { + // Didn't find it here. + continue + } + if err != nil { + return "", err + } + dentry.DecRef() + + return binPath, nil + } + + // Couldn't find it. + return "", syserror.ENOENT +} + +// getPath returns the PATH as a slice of strings given the environment +// variables. +func getPath(env []string) []string { + const prefix = "PATH=" + for _, e := range env { + if strings.HasPrefix(e, prefix) { + return strings.Split(strings.TrimPrefix(e, prefix), ":") + } + } + return nil +} diff --git a/pkg/sentry/fs/user/user.go b/pkg/sentry/fs/user/user.go index fe7f67c00..f4d525523 100644 --- a/pkg/sentry/fs/user/user.go +++ b/pkg/sentry/fs/user/user.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package user contains methods for resolving filesystem paths based on the +// user and their environment. package user import ( diff --git a/pkg/sentry/fsimpl/devpts/line_discipline.go b/pkg/sentry/fsimpl/devpts/line_discipline.go index e201801d6..f7bc325d1 100644 --- a/pkg/sentry/fsimpl/devpts/line_discipline.go +++ b/pkg/sentry/fsimpl/devpts/line_discipline.go @@ -27,8 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// LINT.IfChange - const ( // canonMaxBytes is the number of bytes that fit into a single line of // terminal input in canonical mode. This corresponds to N_TTY_BUF_SIZE @@ -445,5 +443,3 @@ func (l *lineDiscipline) peek(b []byte) int { } return size } - -// LINT.ThenChange(../../fs/tty/line_discipline.go) diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go index 04a292927..7a7ce5d81 100644 --- a/pkg/sentry/fsimpl/devpts/master.go +++ b/pkg/sentry/fsimpl/devpts/master.go @@ -27,8 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// LINT.IfChange - // masterInode is the inode for the master end of the Terminal. type masterInode struct { kernfs.InodeAttrs @@ -222,5 +220,3 @@ func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) { unimpl.EmitUnimplementedEvent(ctx) } } - -// LINT.ThenChange(../../fs/tty/master.go) diff --git a/pkg/sentry/fsimpl/devpts/queue.go b/pkg/sentry/fsimpl/devpts/queue.go index 29a6be858..dffb4232c 100644 --- a/pkg/sentry/fsimpl/devpts/queue.go +++ b/pkg/sentry/fsimpl/devpts/queue.go @@ -25,8 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// LINT.IfChange - // waitBufMaxBytes is the maximum size of a wait buffer. It is based on // TTYB_DEFAULT_MEM_LIMIT. const waitBufMaxBytes = 131072 @@ -236,5 +234,3 @@ func (q *queue) waitBufAppend(b []byte) { q.waitBuf = append(q.waitBuf, b) q.waitBufLen += uint64(len(b)) } - -// LINT.ThenChange(../../fs/tty/queue.go) diff --git a/pkg/sentry/fsimpl/devpts/slave.go b/pkg/sentry/fsimpl/devpts/slave.go index 0a98dc896..526cd406c 100644 --- a/pkg/sentry/fsimpl/devpts/slave.go +++ b/pkg/sentry/fsimpl/devpts/slave.go @@ -26,8 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// LINT.IfChange - // slaveInode is the inode for the slave end of the Terminal. type slaveInode struct { kernfs.InodeAttrs @@ -182,5 +180,3 @@ func (sfd *slaveFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem() return sfd.inode.Stat(fs, opts) } - -// LINT.ThenChange(../../fs/tty/slave.go) diff --git a/pkg/sentry/fsimpl/devpts/terminal.go b/pkg/sentry/fsimpl/devpts/terminal.go index b44e673d8..7d2781c54 100644 --- a/pkg/sentry/fsimpl/devpts/terminal.go +++ b/pkg/sentry/fsimpl/devpts/terminal.go @@ -22,8 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// LINT.IfChanges - // Terminal is a pseudoterminal. // // +stateify savable @@ -120,5 +118,3 @@ func (tm *Terminal) tty(isMaster bool) *kernel.TTY { } return tm.slaveKTTY } - -// LINT.ThenChange(../../fs/tty/terminal.go) diff --git a/pkg/sentry/fsimpl/ext/dentry.go b/pkg/sentry/fsimpl/ext/dentry.go index bfbd7c3d4..6bd1a9fc6 100644 --- a/pkg/sentry/fsimpl/ext/dentry.go +++ b/pkg/sentry/fsimpl/ext/dentry.go @@ -60,3 +60,15 @@ func (d *dentry) DecRef() { // inode.decRef(). d.inode.decRef() } + +// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. +// +// TODO(gvisor.dev/issue/1479): Implement inotify. +func (d *dentry) InotifyWithParent(events uint32, cookie uint32, et vfs.EventType) {} + +// Watches implements vfs.DentryImpl.Watches. +// +// TODO(gvisor.dev/issue/1479): Implement inotify. +func (d *dentry) Watches() *vfs.Watches { + return nil +} diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index 5ce82b793..f5f35a3bc 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -35,8 +35,8 @@ go_library( "fstree.go", "gofer.go", "handle.go", + "host_named_pipe.go", "p9file.go", - "pagemath.go", "regular_file.go", "socket.go", "special_file.go", @@ -48,6 +48,7 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/fd", + "//pkg/fdnotifier", "//pkg/fspath", "//pkg/log", "//pkg/p9", @@ -72,6 +73,7 @@ go_library( "//pkg/unet", "//pkg/usermem", "//pkg/waiter", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 7f2181216..36e0e1856 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -760,7 +760,7 @@ afterTrailingSymlink: parent.dirMu.Unlock() return nil, syserror.EPERM } - fd, err := parent.createAndOpenChildLocked(ctx, rp, &opts) + fd, err := parent.createAndOpenChildLocked(ctx, rp, &opts, &ds) parent.dirMu.Unlock() return fd, err } @@ -873,19 +873,37 @@ func (d *dentry) openSpecialFileLocked(ctx context.Context, mnt *vfs.Mount, opts if opts.Flags&linux.O_DIRECT != 0 { return nil, syserror.EINVAL } - h, err := openHandle(ctx, d.file, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, opts.Flags&linux.O_TRUNC != 0) + // We assume that the server silently inserts O_NONBLOCK in the open flags + // for all named pipes (because all existing gofers do this). + // + // NOTE(b/133875563): This makes named pipe opens racy, because the + // mechanisms for translating nonblocking to blocking opens can only detect + // the instantaneous presence of a peer holding the other end of the pipe + // open, not whether the pipe was *previously* opened by a peer that has + // since closed its end. + isBlockingOpenOfNamedPipe := d.fileType() == linux.S_IFIFO && opts.Flags&linux.O_NONBLOCK == 0 +retry: + h, err := openHandle(ctx, d.file, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0) if err != nil { + if isBlockingOpenOfNamedPipe && ats == vfs.MayWrite && err == syserror.ENXIO { + // An attempt to open a named pipe with O_WRONLY|O_NONBLOCK fails + // with ENXIO if opening the same named pipe with O_WRONLY would + // block because there are no readers of the pipe. + if err := sleepBetweenNamedPipeOpenChecks(ctx); err != nil { + return nil, err + } + goto retry + } return nil, err } - seekable := d.fileType() == linux.S_IFREG - fd := &specialFileFD{ - handle: h, - seekable: seekable, + if isBlockingOpenOfNamedPipe && ats == vfs.MayRead && h.fd >= 0 { + if err := blockUntilNonblockingPipeHasWriter(ctx, h.fd); err != nil { + h.close(ctx) + return nil, err + } } - if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{ - DenyPRead: !seekable, - DenyPWrite: !seekable, - }); err != nil { + fd, err := newSpecialFileFD(h, mnt, d, opts.Flags) + if err != nil { h.close(ctx) return nil, err } @@ -894,7 +912,7 @@ func (d *dentry) openSpecialFileLocked(ctx context.Context, mnt *vfs.Mount, opts // Preconditions: d.fs.renameMu must be locked. d.dirMu must be locked. // !d.isSynthetic(). -func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { +func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions, ds **[]*dentry) (*vfs.FileDescription, error) { if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil { return nil, err } @@ -947,6 +965,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } return nil, err } + *ds = appendDentry(*ds, child) // Incorporate the fid that was opened by lcreate. useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD if useRegularFileFD { @@ -959,10 +978,6 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving child.handleWritable = vfs.MayWriteFileWithOpenFlags(opts.Flags) child.handleMu.Unlock() } - // Take a reference on the new dentry to be held by the new file - // description. (This reference also means that the new dentry is not - // eligible for caching yet, so we don't need to append to a dentry slice.) - child.refs = 1 // Insert the dentry into the tree. d.cacheNewChildLocked(child, name) if d.cachedMetadataAuthoritative() { @@ -981,22 +996,16 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } childVFSFD = &fd.vfsfd } else { - seekable := child.fileType() == linux.S_IFREG - fd := &specialFileFD{ - handle: handle{ - file: openFile, - fd: -1, - }, - seekable: seekable, + h := handle{ + file: openFile, + fd: -1, } if fdobj != nil { - fd.handle.fd = int32(fdobj.Release()) + h.fd = int32(fdobj.Release()) } - if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &child.vfsd, &vfs.FileDescriptionOptions{ - DenyPRead: !seekable, - DenyPWrite: !seekable, - }); err != nil { - fd.handle.close(ctx) + fd, err := newSpecialFileFD(h, mnt, child, opts.Flags) + if err != nil { + h.close(ctx) return nil, err } childVFSFD = &fd.vfsfd diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index ebf063a58..3f3bd56f0 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -84,12 +84,6 @@ type filesystem struct { // devMinor is the filesystem's minor device number. devMinor is immutable. devMinor uint32 - // uid and gid are the effective KUID and KGID of the filesystem's creator, - // and are used as the owner and group for files that don't specify one. - // uid and gid are immutable. - uid auth.KUID - gid auth.KGID - // renameMu serves two purposes: // // - It synchronizes path resolution with renaming initiated by this @@ -122,6 +116,8 @@ type filesystemOptions struct { fd int aname string interop InteropMode // derived from the "cache" mount option + dfltuid auth.KUID + dfltgid auth.KGID msize uint32 version string @@ -230,6 +226,15 @@ type InternalFilesystemOptions struct { OpenSocketsByConnecting bool } +// _V9FS_DEFUID and _V9FS_DEFGID (from Linux's fs/9p/v9fs.h) are the default +// UIDs and GIDs used for files that do not provide a specific owner or group +// respectively. +const ( + // uint32(-2) doesn't work in Go. + _V9FS_DEFUID = auth.KUID(4294967294) + _V9FS_DEFGID = auth.KGID(4294967294) +) + // Name implements vfs.FilesystemType.Name. func (FilesystemType) Name() string { return Name @@ -315,6 +320,31 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } } + // Parse the default UID and GID. + fsopts.dfltuid = _V9FS_DEFUID + if dfltuidstr, ok := mopts["dfltuid"]; ok { + delete(mopts, "dfltuid") + dfltuid, err := strconv.ParseUint(dfltuidstr, 10, 32) + if err != nil { + ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: dfltuid=%s", dfltuidstr) + return nil, nil, syserror.EINVAL + } + // In Linux, dfltuid is interpreted as a UID and is converted to a KUID + // in the caller's user namespace, but goferfs isn't + // application-mountable. + fsopts.dfltuid = auth.KUID(dfltuid) + } + fsopts.dfltgid = _V9FS_DEFGID + if dfltgidstr, ok := mopts["dfltgid"]; ok { + delete(mopts, "dfltgid") + dfltgid, err := strconv.ParseUint(dfltgidstr, 10, 32) + if err != nil { + ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: dfltgid=%s", dfltgidstr) + return nil, nil, syserror.EINVAL + } + fsopts.dfltgid = auth.KGID(dfltgid) + } + // Parse the 9P message size. fsopts.msize = 1024 * 1024 // 1M, tested to give good enough performance up to 64M if msizestr, ok := mopts["msize"]; ok { @@ -422,8 +452,6 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt client: client, clock: ktime.RealtimeClockFromContext(ctx), devMinor: devMinor, - uid: creds.EffectiveKUID, - gid: creds.EffectiveKGID, syncableDentries: make(map[*dentry]struct{}), specialFileFDs: make(map[*specialFileFD]struct{}), } @@ -672,8 +700,8 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma file: file, ino: qid.Path, mode: uint32(attr.Mode), - uid: uint32(fs.uid), - gid: uint32(fs.gid), + uid: uint32(fs.opts.dfltuid), + gid: uint32(fs.opts.dfltgid), blockSize: usermem.PageSize, handle: handle{ fd: -1, @@ -928,8 +956,8 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin // so we can't race with Write or another truncate.) d.dataMu.Unlock() if d.size < oldSize { - oldpgend := pageRoundUp(oldSize) - newpgend := pageRoundUp(d.size) + oldpgend, _ := usermem.PageRoundUp(oldSize) + newpgend, _ := usermem.PageRoundUp(d.size) if oldpgend != newpgend { d.mapsMu.Lock() d.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{ @@ -1011,6 +1039,18 @@ func (d *dentry) decRefLocked() { } } +// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. +// +// TODO(gvisor.dev/issue/1479): Implement inotify. +func (d *dentry) InotifyWithParent(events uint32, cookie uint32, et vfs.EventType) {} + +// Watches implements vfs.DentryImpl.Watches. +// +// TODO(gvisor.dev/issue/1479): Implement inotify. +func (d *dentry) Watches() *vfs.Watches { + return nil +} + // checkCachingLocked should be called after d's reference count becomes 0 or it // becomes disowned. // diff --git a/pkg/sentry/fsimpl/gofer/host_named_pipe.go b/pkg/sentry/fsimpl/gofer/host_named_pipe.go new file mode 100644 index 000000000..7294de7d6 --- /dev/null +++ b/pkg/sentry/fsimpl/gofer/host_named_pipe.go @@ -0,0 +1,97 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gofer + +import ( + "fmt" + "sync" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/syserror" +) + +// Global pipe used by blockUntilNonblockingPipeHasWriter since we can't create +// pipes after sentry initialization due to syscall filters. +var ( + tempPipeMu sync.Mutex + tempPipeReadFD int + tempPipeWriteFD int + tempPipeBuf [1]byte +) + +func init() { + var pipeFDs [2]int + if err := unix.Pipe(pipeFDs[:]); err != nil { + panic(fmt.Sprintf("failed to create pipe for gofer.blockUntilNonblockingPipeHasWriter: %v", err)) + } + tempPipeReadFD = pipeFDs[0] + tempPipeWriteFD = pipeFDs[1] +} + +func blockUntilNonblockingPipeHasWriter(ctx context.Context, fd int32) error { + for { + ok, err := nonblockingPipeHasWriter(fd) + if err != nil { + return err + } + if ok { + return nil + } + if err := sleepBetweenNamedPipeOpenChecks(ctx); err != nil { + return err + } + } +} + +func nonblockingPipeHasWriter(fd int32) (bool, error) { + tempPipeMu.Lock() + defer tempPipeMu.Unlock() + // Copy 1 byte from fd into the temporary pipe. + n, err := unix.Tee(int(fd), tempPipeWriteFD, 1, unix.SPLICE_F_NONBLOCK) + if err == syserror.EAGAIN { + // The pipe represented by fd is empty, but has a writer. + return true, nil + } + if err != nil { + return false, err + } + if n == 0 { + // The pipe represented by fd is empty and has no writer. + return false, nil + } + // The pipe represented by fd is non-empty, so it either has, or has + // previously had, a writer. Remove the byte copied to the temporary pipe + // before returning. + if n, err := unix.Read(tempPipeReadFD, tempPipeBuf[:]); err != nil || n != 1 { + panic(fmt.Sprintf("failed to drain pipe for gofer.blockUntilNonblockingPipeHasWriter: got (%d, %v), wanted (1, nil)", n, err)) + } + return true, nil +} + +func sleepBetweenNamedPipeOpenChecks(ctx context.Context) error { + t := time.NewTimer(100 * time.Millisecond) + defer t.Stop() + cancel := ctx.SleepStart() + select { + case <-t.C: + ctx.SleepFinish(true) + return nil + case <-cancel: + ctx.SleepFinish(false) + return syserror.ErrInterrupted + } +} diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index 857f7c74e..0d10cf7ac 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -148,9 +148,9 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off return 0, err } // Remove touched pages from the cache. - pgstart := pageRoundDown(uint64(offset)) - pgend := pageRoundUp(uint64(offset + src.NumBytes())) - if pgend < pgstart { + pgstart := usermem.PageRoundDown(uint64(offset)) + pgend, ok := usermem.PageRoundUp(uint64(offset + src.NumBytes())) + if !ok { return 0, syserror.EINVAL } mr := memmap.MappableRange{pgstart, pgend} @@ -306,9 +306,10 @@ func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) if fillCache { // Read into the cache, then re-enter the loop to read from the // cache. + gapEnd, _ := usermem.PageRoundUp(gapMR.End) reqMR := memmap.MappableRange{ - Start: pageRoundDown(gapMR.Start), - End: pageRoundUp(gapMR.End), + Start: usermem.PageRoundDown(gapMR.Start), + End: gapEnd, } optMR := gap.Range() err := rw.d.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), mf, usage.PageCache, rw.d.handle.readToBlocksAt) @@ -671,7 +672,7 @@ func (d *dentry) Translate(ctx context.Context, required, optional memmap.Mappab // Constrain translations to d.size (rounded up) to prevent translation to // pages that may be concurrently truncated. - pgend := pageRoundUp(d.size) + pgend, _ := usermem.PageRoundUp(d.size) var beyondEOF bool if required.End > pgend { if required.Start >= pgend { @@ -818,43 +819,15 @@ type dentryPlatformFile struct { // IncRef implements platform.File.IncRef. func (d *dentryPlatformFile) IncRef(fr platform.FileRange) { d.dataMu.Lock() - seg, gap := d.fdRefs.Find(fr.Start) - for { - switch { - case seg.Ok() && seg.Start() < fr.End: - seg = d.fdRefs.Isolate(seg, fr) - seg.SetValue(seg.Value() + 1) - seg, gap = seg.NextNonEmpty() - case gap.Ok() && gap.Start() < fr.End: - newRange := gap.Range().Intersect(fr) - usage.MemoryAccounting.Inc(newRange.Length(), usage.Mapped) - seg, gap = d.fdRefs.InsertWithoutMerging(gap, newRange, 1).NextNonEmpty() - default: - d.fdRefs.MergeAdjacent(fr) - d.dataMu.Unlock() - return - } - } + d.fdRefs.IncRefAndAccount(fr) + d.dataMu.Unlock() } // DecRef implements platform.File.DecRef. func (d *dentryPlatformFile) DecRef(fr platform.FileRange) { d.dataMu.Lock() - seg := d.fdRefs.FindSegment(fr.Start) - - for seg.Ok() && seg.Start() < fr.End { - seg = d.fdRefs.Isolate(seg, fr) - if old := seg.Value(); old == 1 { - usage.MemoryAccounting.Dec(seg.Range().Length(), usage.Mapped) - seg = d.fdRefs.Remove(seg).NextSegment() - } else { - seg.SetValue(old - 1) - seg = seg.NextSegment() - } - } - d.fdRefs.MergeAdjacent(fr) + d.fdRefs.DecRefAndAccount(fr) d.dataMu.Unlock() - } // MapInternal implements platform.File.MapInternal. diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index a464e6a94..ff6126b87 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -19,17 +19,18 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" ) -// specialFileFD implements vfs.FileDescriptionImpl for files other than -// regular files, directories, and symlinks: pipes, sockets, etc. It is also -// used for regular files when filesystemOptions.specialRegularFiles is in -// effect. specialFileFD differs from regularFileFD by using per-FD handles -// instead of shared per-dentry handles, and never buffering I/O. +// specialFileFD implements vfs.FileDescriptionImpl for pipes, sockets, device +// special files, and (when filesystemOptions.specialRegularFiles is in effect) +// regular files. specialFileFD differs from regularFileFD by using per-FD +// handles instead of shared per-dentry handles, and never buffering I/O. type specialFileFD struct { fileDescription @@ -40,13 +41,47 @@ type specialFileFD struct { // file offset is significant, i.e. a regular file. seekable is immutable. seekable bool + // mayBlock is true if this file description represents a file for which + // queue may send I/O readiness events. mayBlock is immutable. + mayBlock bool + queue waiter.Queue + // If seekable is true, off is the file offset. off is protected by mu. mu sync.Mutex off int64 } +func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, flags uint32) (*specialFileFD, error) { + ftype := d.fileType() + seekable := ftype == linux.S_IFREG + mayBlock := ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK + fd := &specialFileFD{ + handle: h, + seekable: seekable, + mayBlock: mayBlock, + } + if mayBlock && h.fd >= 0 { + if err := fdnotifier.AddFD(h.fd, &fd.queue); err != nil { + return nil, err + } + } + if err := fd.vfsfd.Init(fd, flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{ + DenyPRead: !seekable, + DenyPWrite: !seekable, + }); err != nil { + if mayBlock && h.fd >= 0 { + fdnotifier.RemoveFD(h.fd) + } + return nil, err + } + return fd, nil +} + // Release implements vfs.FileDescriptionImpl.Release. func (fd *specialFileFD) Release() { + if fd.mayBlock && fd.handle.fd >= 0 { + fdnotifier.RemoveFD(fd.handle.fd) + } fd.handle.close(context.Background()) fs := fd.vfsfd.Mount().Filesystem().Impl().(*filesystem) fs.syncMu.Lock() @@ -62,6 +97,32 @@ func (fd *specialFileFD) OnClose(ctx context.Context) error { return fd.handle.file.flush(ctx) } +// Readiness implements waiter.Waitable.Readiness. +func (fd *specialFileFD) Readiness(mask waiter.EventMask) waiter.EventMask { + if fd.mayBlock { + return fdnotifier.NonBlockingPoll(fd.handle.fd, mask) + } + return fd.fileDescription.Readiness(mask) +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (fd *specialFileFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + if fd.mayBlock { + fd.queue.EventRegister(e, mask) + return + } + fd.fileDescription.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (fd *specialFileFD) EventUnregister(e *waiter.Entry) { + if fd.mayBlock { + fd.queue.EventUnregister(e) + return + } + fd.fileDescription.EventUnregister(e) +} + // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { if fd.seekable && offset < 0 { @@ -81,6 +142,9 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs } buf := make([]byte, dst.NumBytes()) n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset)) + if err == syserror.EAGAIN { + err = syserror.ErrWouldBlock + } if n == 0 { return 0, err } @@ -130,6 +194,9 @@ func (fd *specialFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off return 0, err } n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset)) + if err == syserror.EAGAIN { + err = syserror.ErrWouldBlock + } return int64(n), err } diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go index 2608e7e1d..1d5aa82dc 100644 --- a/pkg/sentry/fsimpl/gofer/time.go +++ b/pkg/sentry/fsimpl/gofer/time.go @@ -38,6 +38,9 @@ func statxTimestampFromDentry(ns int64) linux.StatxTimestamp { // Preconditions: fs.interop != InteropModeShared. func (d *dentry) touchAtime(mnt *vfs.Mount) { + if mnt.Flags.NoATime { + return + } if err := mnt.CheckBeginWrite(); err != nil { return } diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index 39509f703..ca0fe6d2b 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -8,6 +8,7 @@ go_library( "control.go", "host.go", "ioctl_unsafe.go", + "mmap.go", "socket.go", "socket_iovec.go", "socket_unsafe.go", @@ -23,12 +24,15 @@ go_library( "//pkg/fspath", "//pkg/log", "//pkg/refs", + "//pkg/safemem", "//pkg/sentry/arch", + "//pkg/sentry/fs/fsutil", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/hostfd", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", + "//pkg/sentry/platform", "//pkg/sentry/socket/control", "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 8caf55a1b..18b127521 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -86,15 +86,13 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions) i := &inode{ hostFD: hostFD, - seekable: seekable, + ino: fs.NextIno(), isTTY: opts.IsTTY, - canMap: canMap(uint32(fileType)), wouldBlock: wouldBlock(uint32(fileType)), - ino: fs.NextIno(), - // For simplicity, set offset to 0. Technically, we should use the existing - // offset on the host if the file is seekable. - offset: 0, + seekable: seekable, + canMap: canMap(uint32(fileType)), } + i.pf.inode = i // Non-seekable files can't be memory mapped, assert this. if !i.seekable && i.canMap { @@ -117,6 +115,10 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions) // i.open will take a reference on d. defer d.DecRef() + + // For simplicity, fileDescription.offset is set to 0. Technically, we + // should only set to 0 on files that are not seekable (sockets, pipes, + // etc.), and use the offset from the host fd otherwise when importing. return i.open(ctx, d.VFSDentry(), mnt, flags) } @@ -189,11 +191,15 @@ type inode struct { // This field is initialized at creation time and is immutable. hostFD int - // wouldBlock is true if the host FD would return EWOULDBLOCK for - // operations that would block. + // ino is an inode number unique within this filesystem. // // This field is initialized at creation time and is immutable. - wouldBlock bool + ino uint64 + + // isTTY is true if this file represents a TTY. + // + // This field is initialized at creation time and is immutable. + isTTY bool // seekable is false if the host fd points to a file representing a stream, // e.g. a socket or a pipe. Such files are not seekable and can return @@ -202,29 +208,29 @@ type inode struct { // This field is initialized at creation time and is immutable. seekable bool - // isTTY is true if this file represents a TTY. + // wouldBlock is true if the host FD would return EWOULDBLOCK for + // operations that would block. // // This field is initialized at creation time and is immutable. - isTTY bool + wouldBlock bool + + // Event queue for blocking operations. + queue waiter.Queue // canMap specifies whether we allow the file to be memory mapped. // // This field is initialized at creation time and is immutable. canMap bool - // ino is an inode number unique within this filesystem. - // - // This field is initialized at creation time and is immutable. - ino uint64 + // mapsMu protects mappings. + mapsMu sync.Mutex - // offsetMu protects offset. - offsetMu sync.Mutex - - // offset specifies the current file offset. - offset int64 + // If canMap is true, mappings tracks mappings of hostFD into + // memmap.MappingSpaces. + mappings memmap.MappingSet - // Event queue for blocking operations. - queue waiter.Queue + // pf implements platform.File for mappings of hostFD. + pf inodePlatformFile } // CheckPermissions implements kernfs.Inode. @@ -388,6 +394,21 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre if err := syscall.Ftruncate(i.hostFD, int64(s.Size)); err != nil { return err } + oldSize := uint64(hostStat.Size) + if s.Size < oldSize { + oldpgend, _ := usermem.PageRoundUp(oldSize) + newpgend, _ := usermem.PageRoundUp(s.Size) + if oldpgend != newpgend { + i.mapsMu.Lock() + i.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{ + // Compare Linux's mm/truncate.c:truncate_setsize() => + // truncate_pagecache() => + // mm/memory.c:unmap_mapping_range(evencows=1). + InvalidatePrivate: true, + }) + i.mapsMu.Unlock() + } + } } if m&(linux.STATX_ATIME|linux.STATX_MTIME) != 0 { ts := [2]syscall.Timespec{ @@ -464,9 +485,6 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u return vfsfd, nil } - // For simplicity, set offset to 0. Technically, we should - // only set to 0 on files that are not seekable (sockets, pipes, etc.), - // and use the offset from the host fd otherwise. fd := &fileDescription{inode: i} vfsfd := &fd.vfsfd if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil { @@ -487,6 +505,13 @@ type fileDescription struct { // // inode is immutable after fileDescription creation. inode *inode + + // offsetMu protects offset. + offsetMu sync.Mutex + + // offset specifies the current file offset. It is only meaningful when + // inode.seekable is true. + offset int64 } // SetStat implements vfs.FileDescriptionImpl. @@ -532,10 +557,10 @@ func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts return n, err } // TODO(gvisor.dev/issue/1672): Cache pages, when forced to do so. - i.offsetMu.Lock() - n, err := readFromHostFD(ctx, i.hostFD, dst, i.offset, opts.Flags) - i.offset += n - i.offsetMu.Unlock() + f.offsetMu.Lock() + n, err := readFromHostFD(ctx, i.hostFD, dst, f.offset, opts.Flags) + f.offset += n + f.offsetMu.Unlock() return n, err } @@ -572,10 +597,10 @@ func (f *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opt } // TODO(gvisor.dev/issue/1672): Cache pages, when forced to do so. // TODO(gvisor.dev/issue/1672): Write to end of file and update offset if O_APPEND is set on this file. - i.offsetMu.Lock() - n, err := writeToHostFD(ctx, i.hostFD, src, i.offset, opts.Flags) - i.offset += n - i.offsetMu.Unlock() + f.offsetMu.Lock() + n, err := writeToHostFD(ctx, i.hostFD, src, f.offset, opts.Flags) + f.offset += n + f.offsetMu.Unlock() return n, err } @@ -600,41 +625,41 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i return 0, syserror.ESPIPE } - i.offsetMu.Lock() - defer i.offsetMu.Unlock() + f.offsetMu.Lock() + defer f.offsetMu.Unlock() switch whence { case linux.SEEK_SET: if offset < 0 { - return i.offset, syserror.EINVAL + return f.offset, syserror.EINVAL } - i.offset = offset + f.offset = offset case linux.SEEK_CUR: - // Check for overflow. Note that underflow cannot occur, since i.offset >= 0. - if offset > math.MaxInt64-i.offset { - return i.offset, syserror.EOVERFLOW + // Check for overflow. Note that underflow cannot occur, since f.offset >= 0. + if offset > math.MaxInt64-f.offset { + return f.offset, syserror.EOVERFLOW } - if i.offset+offset < 0 { - return i.offset, syserror.EINVAL + if f.offset+offset < 0 { + return f.offset, syserror.EINVAL } - i.offset += offset + f.offset += offset case linux.SEEK_END: var s syscall.Stat_t if err := syscall.Fstat(i.hostFD, &s); err != nil { - return i.offset, err + return f.offset, err } size := s.Size // Check for overflow. Note that underflow cannot occur, since size >= 0. if offset > math.MaxInt64-size { - return i.offset, syserror.EOVERFLOW + return f.offset, syserror.EOVERFLOW } if size+offset < 0 { - return i.offset, syserror.EINVAL + return f.offset, syserror.EINVAL } - i.offset = size + offset + f.offset = size + offset case linux.SEEK_DATA, linux.SEEK_HOLE: // Modifying the offset in the host file table should not matter, since @@ -643,16 +668,16 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i // For reading and writing, we always rely on our internal offset. n, err := unix.Seek(i.hostFD, offset, int(whence)) if err != nil { - return i.offset, err + return f.offset, err } - i.offset = n + f.offset = n default: // Invalid whence. - return i.offset, syserror.EINVAL + return f.offset, syserror.EINVAL } - return i.offset, nil + return f.offset, nil } // Sync implements FileDescriptionImpl. @@ -666,8 +691,9 @@ func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts if !f.inode.canMap { return syserror.ENODEV } - // TODO(gvisor.dev/issue/1672): Implement ConfigureMMap and Mappable interface. - return syserror.ENODEV + i := f.inode + i.pf.fileMapperInitOnce.Do(i.pf.fileMapper.Init) + return vfs.GenericConfigureMMap(&f.vfsfd, i, opts) } // EventRegister implements waiter.Waitable.EventRegister. diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/host/mmap.go new file mode 100644 index 000000000..8545a82f0 --- /dev/null +++ b/pkg/sentry/fsimpl/host/mmap.go @@ -0,0 +1,132 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package host + +import ( + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/memmap" + "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/usermem" +) + +// inodePlatformFile implements platform.File. It exists solely because inode +// cannot implement both kernfs.Inode.IncRef and platform.File.IncRef. +// +// inodePlatformFile should only be used if inode.canMap is true. +type inodePlatformFile struct { + *inode + + // fdRefsMu protects fdRefs. + fdRefsMu sync.Mutex + + // fdRefs counts references on platform.File offsets. It is used solely for + // memory accounting. + fdRefs fsutil.FrameRefSet + + // fileMapper caches mappings of the host file represented by this inode. + fileMapper fsutil.HostFileMapper + + // fileMapperInitOnce is used to lazily initialize fileMapper. + fileMapperInitOnce sync.Once +} + +// IncRef implements platform.File.IncRef. +// +// Precondition: i.inode.canMap must be true. +func (i *inodePlatformFile) IncRef(fr platform.FileRange) { + i.fdRefsMu.Lock() + i.fdRefs.IncRefAndAccount(fr) + i.fdRefsMu.Unlock() +} + +// DecRef implements platform.File.DecRef. +// +// Precondition: i.inode.canMap must be true. +func (i *inodePlatformFile) DecRef(fr platform.FileRange) { + i.fdRefsMu.Lock() + i.fdRefs.DecRefAndAccount(fr) + i.fdRefsMu.Unlock() +} + +// MapInternal implements platform.File.MapInternal. +// +// Precondition: i.inode.canMap must be true. +func (i *inodePlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { + return i.fileMapper.MapInternal(fr, i.hostFD, at.Write) +} + +// FD implements platform.File.FD. +func (i *inodePlatformFile) FD() int { + return i.hostFD +} + +// AddMapping implements memmap.Mappable.AddMapping. +// +// Precondition: i.inode.canMap must be true. +func (i *inode) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { + i.mapsMu.Lock() + mapped := i.mappings.AddMapping(ms, ar, offset, writable) + for _, r := range mapped { + i.pf.fileMapper.IncRefOn(r) + } + i.mapsMu.Unlock() + return nil +} + +// RemoveMapping implements memmap.Mappable.RemoveMapping. +// +// Precondition: i.inode.canMap must be true. +func (i *inode) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { + i.mapsMu.Lock() + unmapped := i.mappings.RemoveMapping(ms, ar, offset, writable) + for _, r := range unmapped { + i.pf.fileMapper.DecRefOn(r) + } + i.mapsMu.Unlock() +} + +// CopyMapping implements memmap.Mappable.CopyMapping. +// +// Precondition: i.inode.canMap must be true. +func (i *inode) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { + return i.AddMapping(ctx, ms, dstAR, offset, writable) +} + +// Translate implements memmap.Mappable.Translate. +// +// Precondition: i.inode.canMap must be true. +func (i *inode) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { + mr := optional + return []memmap.Translation{ + { + Source: mr, + File: &i.pf, + Offset: mr.Start, + Perms: usermem.AnyAccess, + }, + }, nil +} + +// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. +// +// Precondition: i.inode.canMap must be true. +func (i *inode) InvalidateUnsavable(ctx context.Context) error { + // We expect the same host fd across save/restore, so all translations + // should be valid. + return nil +} diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index a83151ad3..bbee8ccda 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -225,9 +225,21 @@ func (d *Dentry) destroy() { } } +// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. +// +// TODO(gvisor.dev/issue/1479): Implement inotify. +func (d *Dentry) InotifyWithParent(events uint32, cookie uint32, et vfs.EventType) {} + +// Watches implements vfs.DentryImpl.Watches. +// +// TODO(gvisor.dev/issue/1479): Implement inotify. +func (d *Dentry) Watches() *vfs.Watches { + return nil +} + // InsertChild inserts child into the vfs dentry cache with the given name under // this dentry. This does not update the directory inode, so calling this on -// it's own isn't sufficient to insert a child into a directory. InsertChild +// its own isn't sufficient to insert a child into a directory. InsertChild // updates the link count on d if required. // // Precondition: d must represent a directory inode. diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index 007be1572..062321cbc 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -59,6 +59,7 @@ go_library( "//pkg/sentry/pgalloc", "//pkg/sentry/platform", "//pkg/sentry/socket/unix/transport", + "//pkg/sentry/uniqueid", "//pkg/sentry/usage", "//pkg/sentry/vfs", "//pkg/sentry/vfs/lock", diff --git a/pkg/sentry/fsimpl/tmpfs/device_file.go b/pkg/sentry/fsimpl/tmpfs/device_file.go index 83bf885ee..ac54d420d 100644 --- a/pkg/sentry/fsimpl/tmpfs/device_file.go +++ b/pkg/sentry/fsimpl/tmpfs/device_file.go @@ -29,7 +29,7 @@ type deviceFile struct { minor uint32 } -func (fs *filesystem) newDeviceFile(creds *auth.Credentials, mode linux.FileMode, kind vfs.DeviceKind, major, minor uint32) *inode { +func (fs *filesystem) newDeviceFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode, kind vfs.DeviceKind, major, minor uint32) *inode { file := &deviceFile{ kind: kind, major: major, @@ -43,7 +43,7 @@ func (fs *filesystem) newDeviceFile(creds *auth.Credentials, mode linux.FileMode default: panic(fmt.Sprintf("invalid DeviceKind: %v", kind)) } - file.inode.init(file, fs, creds, mode) + file.inode.init(file, fs, kuid, kgid, mode) file.inode.nlink = 1 // from parent directory return &file.inode } diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go index f2399981b..913b8a6c5 100644 --- a/pkg/sentry/fsimpl/tmpfs/directory.go +++ b/pkg/sentry/fsimpl/tmpfs/directory.go @@ -48,9 +48,9 @@ type directory struct { childList dentryList } -func (fs *filesystem) newDirectory(creds *auth.Credentials, mode linux.FileMode) *directory { +func (fs *filesystem) newDirectory(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *directory { dir := &directory{} - dir.inode.init(dir, fs, creds, linux.S_IFDIR|mode) + dir.inode.init(dir, fs, kuid, kgid, linux.S_IFDIR|mode) dir.inode.nlink = 2 // from "." and parent directory or ".." for root dir.dentry.inode = &dir.inode dir.dentry.vfsd.Init(&dir.dentry) @@ -79,6 +79,7 @@ func (dir *directory) removeChildLocked(child *dentry) { dir.iterMu.Lock() dir.childList.Remove(child) dir.iterMu.Unlock() + child.unlinked = true } type directoryFD struct { @@ -112,6 +113,7 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba dir.iterMu.Lock() defer dir.iterMu.Unlock() + fd.dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) fd.inode().touchAtime(fd.vfsfd.Mount()) if fd.off == 0 { diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index 80fa7b29d..e801680e8 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -177,6 +177,12 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa if err := create(parentDir, name); err != nil { return err } + + ev := linux.IN_CREATE + if dir { + ev |= linux.IN_ISDIR + } + parentDir.inode.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent) parentDir.inode.touchCMtime() return nil } @@ -241,6 +247,7 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. return syserror.EMLINK } d.inode.incLinksLocked() + d.inode.watches.Notify("", linux.IN_ATTRIB, 0, vfs.InodeEvent) parentDir.insertChildLocked(fs.newDentry(d.inode), name) return nil }) @@ -249,11 +256,12 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. // MkdirAt implements vfs.FilesystemImpl.MkdirAt. func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { return fs.doCreateAt(rp, true /* dir */, func(parentDir *directory, name string) error { + creds := rp.Credentials() if parentDir.inode.nlink == maxLinks { return syserror.EMLINK } parentDir.inode.incLinksLocked() // from child's ".." - childDir := fs.newDirectory(rp.Credentials(), opts.Mode) + childDir := fs.newDirectory(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode) parentDir.insertChildLocked(&childDir.dentry, name) return nil }) @@ -262,18 +270,19 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v // MknodAt implements vfs.FilesystemImpl.MknodAt. func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { return fs.doCreateAt(rp, false /* dir */, func(parentDir *directory, name string) error { + creds := rp.Credentials() var childInode *inode switch opts.Mode.FileType() { case 0, linux.S_IFREG: - childInode = fs.newRegularFile(rp.Credentials(), opts.Mode) + childInode = fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode) case linux.S_IFIFO: - childInode = fs.newNamedPipe(rp.Credentials(), opts.Mode) + childInode = fs.newNamedPipe(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode) case linux.S_IFBLK: - childInode = fs.newDeviceFile(rp.Credentials(), opts.Mode, vfs.BlockDevice, opts.DevMajor, opts.DevMinor) + childInode = fs.newDeviceFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode, vfs.BlockDevice, opts.DevMajor, opts.DevMinor) case linux.S_IFCHR: - childInode = fs.newDeviceFile(rp.Credentials(), opts.Mode, vfs.CharDevice, opts.DevMajor, opts.DevMinor) + childInode = fs.newDeviceFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode, vfs.CharDevice, opts.DevMajor, opts.DevMinor) case linux.S_IFSOCK: - childInode = fs.newSocketFile(rp.Credentials(), opts.Mode, opts.Endpoint) + childInode = fs.newSocketFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode, opts.Endpoint) default: return syserror.EINVAL } @@ -348,12 +357,14 @@ afterTrailingSymlink: } defer rp.Mount().EndWrite() // Create and open the child. - child := fs.newDentry(fs.newRegularFile(rp.Credentials(), opts.Mode)) + creds := rp.Credentials() + child := fs.newDentry(fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode)) parentDir.insertChildLocked(child, name) fd, err := child.open(ctx, rp, &opts, true) if err != nil { return nil, err } + parentDir.inode.watches.Notify(name, linux.IN_CREATE, 0, vfs.PathEvent) parentDir.inode.touchCMtime() return fd, nil } @@ -559,6 +570,8 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa newParentDir.inode.touchCMtime() } renamed.inode.touchCtime() + + vfs.InotifyRename(ctx, &renamed.inode.watches, &oldParentDir.inode.watches, &newParentDir.inode.watches, oldName, newName, renamed.inode.isDir()) return nil } @@ -603,8 +616,11 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error return err } parentDir.removeChildLocked(child) - parentDir.inode.decLinksLocked() // from child's ".." + parentDir.inode.watches.Notify(name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent) + // Remove links for child, child/., and child/.. child.inode.decLinksLocked() + child.inode.decLinksLocked() + parentDir.inode.decLinksLocked() vfsObj.CommitDeleteDentry(&child.vfsd) parentDir.inode.touchCMtime() return nil @@ -618,7 +634,14 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts if err != nil { return err } - return d.inode.setStat(ctx, rp.Credentials(), &opts.Stat) + if err := d.inode.setStat(ctx, rp.Credentials(), &opts.Stat); err != nil { + return err + } + + if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { + d.InotifyWithParent(ev, 0, vfs.InodeEvent) + } + return nil } // StatAt implements vfs.FilesystemImpl.StatAt. @@ -656,7 +679,8 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu // SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { return fs.doCreateAt(rp, false /* dir */, func(parentDir *directory, name string) error { - child := fs.newDentry(fs.newSymlink(rp.Credentials(), target)) + creds := rp.Credentials() + child := fs.newDentry(fs.newSymlink(creds.EffectiveKUID, creds.EffectiveKGID, 0777, target)) parentDir.insertChildLocked(child, name) return nil }) @@ -698,6 +722,12 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil { return err } + + // Generate inotify events. Note that this must take place before the link + // count of the child is decremented, or else the watches may be dropped + // before these events are added. + vfs.InotifyRemoveChild(&child.inode.watches, &parentDir.inode.watches, name) + parentDir.removeChildLocked(child) child.inode.decLinksLocked() vfsObj.CommitDeleteDentry(&child.vfsd) @@ -754,7 +784,12 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt if err != nil { return err } - return d.inode.setxattr(rp.Credentials(), &opts) + if err := d.inode.setxattr(rp.Credentials(), &opts); err != nil { + return err + } + + d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + return nil } // RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. @@ -765,7 +800,12 @@ func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, if err != nil { return err } - return d.inode.removexattr(rp.Credentials(), name) + if err := d.inode.removexattr(rp.Credentials(), name); err != nil { + return err + } + + d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + return nil } // PrependPath implements vfs.FilesystemImpl.PrependPath. diff --git a/pkg/sentry/fsimpl/tmpfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go index 8d77b3fa8..739350cf0 100644 --- a/pkg/sentry/fsimpl/tmpfs/named_pipe.go +++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go @@ -30,9 +30,9 @@ type namedPipe struct { // Preconditions: // * fs.mu must be locked. // * rp.Mount().CheckBeginWrite() has been called successfully. -func (fs *filesystem) newNamedPipe(creds *auth.Credentials, mode linux.FileMode) *inode { +func (fs *filesystem) newNamedPipe(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode { file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)} - file.inode.init(file, fs, creds, linux.S_IFIFO|mode) + file.inode.init(file, fs, kuid, kgid, linux.S_IFIFO|mode) file.inode.nlink = 1 // Only the parent has a link. return &file.inode } diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index 3f433d666..4f2ae04d2 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -85,12 +85,12 @@ type regularFile struct { size uint64 } -func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode { +func (fs *filesystem) newRegularFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode { file := ®ularFile{ memFile: fs.memFile, seals: linux.F_SEAL_SEAL, } - file.inode.init(file, fs, creds, linux.S_IFREG|mode) + file.inode.init(file, fs, kuid, kgid, linux.S_IFREG|mode) file.inode.nlink = 1 // from parent directory return &file.inode } @@ -312,7 +312,7 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off f := fd.inode().impl.(*regularFile) if end := offset + srclen; end < offset { // Overflow. - return 0, syserror.EFBIG + return 0, syserror.EINVAL } var err error diff --git a/pkg/sentry/fsimpl/tmpfs/socket_file.go b/pkg/sentry/fsimpl/tmpfs/socket_file.go index 25c2321af..3ed650474 100644 --- a/pkg/sentry/fsimpl/tmpfs/socket_file.go +++ b/pkg/sentry/fsimpl/tmpfs/socket_file.go @@ -26,9 +26,9 @@ type socketFile struct { ep transport.BoundEndpoint } -func (fs *filesystem) newSocketFile(creds *auth.Credentials, mode linux.FileMode, ep transport.BoundEndpoint) *inode { +func (fs *filesystem) newSocketFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode, ep transport.BoundEndpoint) *inode { file := &socketFile{ep: ep} - file.inode.init(file, fs, creds, mode) + file.inode.init(file, fs, kuid, kgid, mode) file.inode.nlink = 1 // from parent directory return &file.inode } diff --git a/pkg/sentry/fsimpl/tmpfs/symlink.go b/pkg/sentry/fsimpl/tmpfs/symlink.go index 47e075ed4..b0de5fabe 100644 --- a/pkg/sentry/fsimpl/tmpfs/symlink.go +++ b/pkg/sentry/fsimpl/tmpfs/symlink.go @@ -24,11 +24,11 @@ type symlink struct { target string // immutable } -func (fs *filesystem) newSymlink(creds *auth.Credentials, target string) *inode { +func (fs *filesystem) newSymlink(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode, target string) *inode { link := &symlink{ target: target, } - link.inode.init(link, fs, creds, linux.S_IFLNK|0777) + link.inode.init(link, fs, kuid, kgid, linux.S_IFLNK|mode) link.inode.nlink = 1 // from parent directory return &link.inode } diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 1e781aecd..7ce1b86c7 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -30,6 +30,7 @@ package tmpfs import ( "fmt" "math" + "strconv" "strings" "sync/atomic" @@ -112,6 +113,58 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } } + mopts := vfs.GenericParseMountOptions(opts.Data) + rootMode := linux.FileMode(0777) + if rootFileType == linux.S_IFDIR { + rootMode = 01777 + } + modeStr, ok := mopts["mode"] + if ok { + delete(mopts, "mode") + mode, err := strconv.ParseUint(modeStr, 8, 32) + if err != nil { + ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: invalid mode: %q", modeStr) + return nil, nil, syserror.EINVAL + } + rootMode = linux.FileMode(mode & 07777) + } + rootKUID := creds.EffectiveKUID + uidStr, ok := mopts["uid"] + if ok { + delete(mopts, "uid") + uid, err := strconv.ParseUint(uidStr, 10, 32) + if err != nil { + ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: invalid uid: %q", uidStr) + return nil, nil, syserror.EINVAL + } + kuid := creds.UserNamespace.MapToKUID(auth.UID(uid)) + if !kuid.Ok() { + ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: unmapped uid: %d", uid) + return nil, nil, syserror.EINVAL + } + rootKUID = kuid + } + rootKGID := creds.EffectiveKGID + gidStr, ok := mopts["gid"] + if ok { + delete(mopts, "gid") + gid, err := strconv.ParseUint(gidStr, 10, 32) + if err != nil { + ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: invalid gid: %q", gidStr) + return nil, nil, syserror.EINVAL + } + kgid := creds.UserNamespace.MapToKGID(auth.GID(gid)) + if !kgid.Ok() { + ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: unmapped gid: %d", gid) + return nil, nil, syserror.EINVAL + } + rootKGID = kgid + } + if len(mopts) != 0 { + ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: unknown options: %v", mopts) + return nil, nil, syserror.EINVAL + } + devMinor, err := vfsObj.GetAnonBlockDevMinor() if err != nil { return nil, nil, err @@ -127,11 +180,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt var root *dentry switch rootFileType { case linux.S_IFREG: - root = fs.newDentry(fs.newRegularFile(creds, 0777)) + root = fs.newDentry(fs.newRegularFile(rootKUID, rootKGID, rootMode)) case linux.S_IFLNK: - root = fs.newDentry(fs.newSymlink(creds, tmpfsOpts.RootSymlinkTarget)) + root = fs.newDentry(fs.newSymlink(rootKUID, rootKGID, rootMode, tmpfsOpts.RootSymlinkTarget)) case linux.S_IFDIR: - root = &fs.newDirectory(creds, 01777).dentry + root = &fs.newDirectory(rootKUID, rootKGID, rootMode).dentry default: fs.vfsfs.DecRef() return nil, nil, fmt.Errorf("invalid tmpfs root file type: %#o", rootFileType) @@ -163,6 +216,11 @@ type dentry struct { // filesystem.mu. name string + // unlinked indicates whether this dentry has been unlinked from its parent. + // It is only set to true on an unlink operation, and never set from true to + // false. unlinked is protected by filesystem.mu. + unlinked bool + // dentryEntry (ugh) links dentries into their parent directory.childList. dentryEntry @@ -201,6 +259,26 @@ func (d *dentry) DecRef() { d.inode.decRef() } +// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. +func (d *dentry) InotifyWithParent(events uint32, cookie uint32, et vfs.EventType) { + if d.inode.isDir() { + events |= linux.IN_ISDIR + } + + // The ordering below is important, Linux always notifies the parent first. + if d.parent != nil { + // Note that d.parent or d.name may be stale if there is a concurrent + // rename operation. Inotify does not provide consistency guarantees. + d.parent.inode.watches.NotifyWithExclusions(d.name, events, cookie, et, d.unlinked) + } + d.inode.watches.Notify("", events, cookie, et) +} + +// Watches implements vfs.DentryImpl.Watches. +func (d *dentry) Watches() *vfs.Watches { + return &d.inode.watches +} + // inode represents a filesystem object. type inode struct { // fs is the owning filesystem. fs is immutable. @@ -209,11 +287,9 @@ type inode struct { // refs is a reference count. refs is accessed using atomic memory // operations. // - // A reference is held on all inodes that are reachable in the filesystem - // tree. For non-directories (which may have multiple hard links), this - // means that a reference is dropped when nlink reaches 0. For directories, - // nlink never reaches 0 due to the "." entry; instead, - // filesystem.RmdirAt() drops the reference. + // A reference is held on all inodes as long as they are reachable in the + // filesystem tree, i.e. nlink is nonzero. This reference is dropped when + // nlink reaches 0. refs int64 // xattrs implements extended attributes. @@ -238,20 +314,23 @@ type inode struct { // Advisory file locks, which lock at the inode level. locks lock.FileLocks + // Inotify watches for this inode. + watches vfs.Watches + impl interface{} // immutable } const maxLinks = math.MaxUint32 -func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode linux.FileMode) { +func (i *inode) init(impl interface{}, fs *filesystem, kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) { if mode.FileType() == 0 { panic("file type is required in FileMode") } i.fs = fs i.refs = 1 i.mode = uint32(mode) - i.uid = uint32(creds.EffectiveKUID) - i.gid = uint32(creds.EffectiveKGID) + i.uid = uint32(kuid) + i.gid = uint32(kgid) i.ino = atomic.AddUint64(&fs.nextInoMinusOne, 1) // Tmpfs creation sets atime, ctime, and mtime to current time. now := fs.clock.Now().Nanoseconds() @@ -259,6 +338,7 @@ func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, i.ctime = now i.mtime = now // i.nlink initialized by caller + i.watches = vfs.Watches{} i.impl = impl } @@ -276,14 +356,17 @@ func (i *inode) incLinksLocked() { atomic.AddUint32(&i.nlink, 1) } -// decLinksLocked decrements i's link count. +// decLinksLocked decrements i's link count. If the link count reaches 0, we +// remove a reference on i as well. // // Preconditions: filesystem.mu must be locked for writing. i.nlink != 0. func (i *inode) decLinksLocked() { if i.nlink == 0 { panic("tmpfs.inode.decLinksLocked() called with no existing links") } - atomic.AddUint32(&i.nlink, ^uint32(0)) + if atomic.AddUint32(&i.nlink, ^uint32(0)) == 0 { + i.decRef() + } } func (i *inode) incRef() { @@ -306,6 +389,7 @@ func (i *inode) tryIncRef() bool { func (i *inode) decRef() { if refs := atomic.AddInt64(&i.refs, -1); refs == 0 { + i.watches.HandleDeletion() if regFile, ok := i.impl.(*regularFile); ok { // Release memory used by regFile to store data. Since regFile is // no longer usable, we don't need to grab any locks or update any @@ -531,6 +615,9 @@ func (i *inode) isDir() bool { } func (i *inode) touchAtime(mnt *vfs.Mount) { + if mnt.Flags.NoATime { + return + } if err := mnt.CheckBeginWrite(); err != nil { return } @@ -627,8 +714,12 @@ func (fd *fileDescription) filesystem() *filesystem { return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem) } +func (fd *fileDescription) dentry() *dentry { + return fd.vfsfd.Dentry().Impl().(*dentry) +} + func (fd *fileDescription) inode() *inode { - return fd.vfsfd.Dentry().Impl().(*dentry).inode + return fd.dentry().inode } // Stat implements vfs.FileDescriptionImpl.Stat. @@ -641,7 +732,15 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu // SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { creds := auth.CredentialsFromContext(ctx) - return fd.inode().setStat(ctx, creds, &opts.Stat) + d := fd.dentry() + if err := d.inode.setStat(ctx, creds, &opts.Stat); err != nil { + return err + } + + if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { + d.InotifyWithParent(ev, 0, vfs.InodeEvent) + } + return nil } // Listxattr implements vfs.FileDescriptionImpl.Listxattr. @@ -656,12 +755,26 @@ func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOption // Setxattr implements vfs.FileDescriptionImpl.Setxattr. func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error { - return fd.inode().setxattr(auth.CredentialsFromContext(ctx), &opts) + d := fd.dentry() + if err := d.inode.setxattr(auth.CredentialsFromContext(ctx), &opts); err != nil { + return err + } + + // Generate inotify events. + d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + return nil } // Removexattr implements vfs.FileDescriptionImpl.Removexattr. func (fd *fileDescription) Removexattr(ctx context.Context, name string) error { - return fd.inode().removexattr(auth.CredentialsFromContext(ctx), name) + d := fd.dentry() + if err := d.inode.removexattr(auth.CredentialsFromContext(ctx), name); err != nil { + return err + } + + // Generate inotify events. + d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + return nil } // NewMemfd creates a new tmpfs regular file and file description that can back @@ -674,8 +787,7 @@ func NewMemfd(mount *vfs.Mount, creds *auth.Credentials, allowSeals bool, name s // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd inodes are set up with // S_IRWXUGO. - mode := linux.FileMode(0777) - inode := fs.newRegularFile(creds, mode) + inode := fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, 0777) rf := inode.impl.(*regularFile) if allowSeals { rf.seals = 0 diff --git a/pkg/sentry/kernel/auth/credentials.go b/pkg/sentry/kernel/auth/credentials.go index e057d2c6d..6862f2ef5 100644 --- a/pkg/sentry/kernel/auth/credentials.go +++ b/pkg/sentry/kernel/auth/credentials.go @@ -232,3 +232,31 @@ func (c *Credentials) UseGID(gid GID) (KGID, error) { } return NoID, syserror.EPERM } + +// SetUID translates the provided uid to the root user namespace and updates c's +// uids to it. This performs no permissions or capabilities checks, the caller +// is responsible for ensuring the calling context is permitted to modify c. +func (c *Credentials) SetUID(uid UID) error { + kuid := c.UserNamespace.MapToKUID(uid) + if !kuid.Ok() { + return syserror.EINVAL + } + c.RealKUID = kuid + c.EffectiveKUID = kuid + c.SavedKUID = kuid + return nil +} + +// SetGID translates the provided gid to the root user namespace and updates c's +// gids to it. This performs no permissions or capabilities checks, the caller +// is responsible for ensuring the calling context is permitted to modify c. +func (c *Credentials) SetGID(gid GID) error { + kgid := c.UserNamespace.MapToKGID(gid) + if !kgid.Ok() { + return syserror.EINVAL + } + c.RealKGID = kgid + c.EffectiveKGID = kgid + c.SavedKGID = kgid + return nil +} diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index ed40b5303..dbfcef0fa 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -152,7 +152,13 @@ func (f *FDTable) drop(file *fs.File) { // dropVFS2 drops the table reference. func (f *FDTable) dropVFS2(file *vfs.FileDescription) { // TODO(gvisor.dev/issue/1480): Release locks. - // TODO(gvisor.dev/issue/1479): Send inotify events. + + // Generate inotify events. + ev := uint32(linux.IN_CLOSE_NOWRITE) + if file.IsWritable() { + ev = linux.IN_CLOSE_WRITE + } + file.Dentry().InotifyWithParent(ev, 0, vfs.PathEvent) // Drop the table reference. file.DecRef() diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index f29dc0472..7bfa9075a 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -8,6 +8,7 @@ go_library( "device.go", "node.go", "pipe.go", + "pipe_unsafe.go", "pipe_util.go", "reader.go", "reader_writer.go", @@ -20,6 +21,7 @@ go_library( "//pkg/amutex", "//pkg/buffer", "//pkg/context", + "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/device", "//pkg/sentry/fs", diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 62c8691f1..79645d7d2 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -207,7 +207,10 @@ func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) { p.mu.Lock() defer p.mu.Unlock() + return p.readLocked(ctx, ops) +} +func (p *Pipe) readLocked(ctx context.Context, ops readOps) (int64, error) { // Is the pipe empty? if p.view.Size() == 0 { if !p.HasWriters() { @@ -246,7 +249,10 @@ type writeOps struct { func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) { p.mu.Lock() defer p.mu.Unlock() + return p.writeLocked(ctx, ops) +} +func (p *Pipe) writeLocked(ctx context.Context, ops writeOps) (int64, error) { // Can't write to a pipe with no readers. if !p.HasReaders() { return 0, syscall.EPIPE diff --git a/pkg/sentry/fsimpl/gofer/pagemath.go b/pkg/sentry/kernel/pipe/pipe_unsafe.go index 847cb0784..dd60cba24 100644 --- a/pkg/sentry/fsimpl/gofer/pagemath.go +++ b/pkg/sentry/kernel/pipe/pipe_unsafe.go @@ -12,20 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -package gofer +package pipe import ( - "gvisor.dev/gvisor/pkg/usermem" + "unsafe" ) -// This are equivalent to usermem.Addr.RoundDown/Up, but without the -// potentially truncating conversion to usermem.Addr. This is necessary because -// there is no way to define generic "PageRoundDown/Up" functions in Go. - -func pageRoundDown(x uint64) uint64 { - return x &^ (usermem.PageSize - 1) -} - -func pageRoundUp(x uint64) uint64 { - return pageRoundDown(x + usermem.PageSize - 1) +// lockTwoPipes locks both x.mu and y.mu in an order that is guaranteed to be +// consistent for both lockTwoPipes(x, y) and lockTwoPipes(y, x), such that +// concurrent calls cannot deadlock. +// +// Preconditions: x != y. +func lockTwoPipes(x, y *Pipe) { + // Lock the two pipes in order of increasing address. + if uintptr(unsafe.Pointer(x)) < uintptr(unsafe.Pointer(y)) { + x.mu.Lock() + y.mu.Lock() + } else { + y.mu.Lock() + x.mu.Lock() + } } diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index b54f08a30..2602bed72 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -16,7 +16,9 @@ package pipe import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -150,7 +152,9 @@ func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) * return &fd.vfsfd } -// VFSPipeFD implements vfs.FileDescriptionImpl for pipes. +// VFSPipeFD implements vfs.FileDescriptionImpl for pipes. It also implements +// non-atomic usermem.IO methods, allowing it to be passed as usermem.IO to +// other FileDescriptions for splice(2) and tee(2). type VFSPipeFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -229,3 +233,216 @@ func (fd *VFSPipeFD) PipeSize() int64 { func (fd *VFSPipeFD) SetPipeSize(size int64) (int64, error) { return fd.pipe.SetFifoSize(size) } + +// IOSequence returns a useremm.IOSequence that reads up to count bytes from, +// or writes up to count bytes to, fd. +func (fd *VFSPipeFD) IOSequence(count int64) usermem.IOSequence { + return usermem.IOSequence{ + IO: fd, + Addrs: usermem.AddrRangeSeqOf(usermem.AddrRange{0, usermem.Addr(count)}), + } +} + +// CopyIn implements usermem.IO.CopyIn. +func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, opts usermem.IOOpts) (int, error) { + origCount := int64(len(dst)) + n, err := fd.pipe.read(ctx, readOps{ + left: func() int64 { + return int64(len(dst)) + }, + limit: func(l int64) { + dst = dst[:l] + }, + read: func(view *buffer.View) (int64, error) { + n, err := view.ReadAt(dst, 0) + view.TrimFront(int64(n)) + return int64(n), err + }, + }) + if n > 0 { + fd.pipe.Notify(waiter.EventOut) + } + if err == nil && n != origCount { + return int(n), syserror.ErrWouldBlock + } + return int(n), err +} + +// CopyOut implements usermem.IO.CopyOut. +func (fd *VFSPipeFD) CopyOut(ctx context.Context, addr usermem.Addr, src []byte, opts usermem.IOOpts) (int, error) { + origCount := int64(len(src)) + n, err := fd.pipe.write(ctx, writeOps{ + left: func() int64 { + return int64(len(src)) + }, + limit: func(l int64) { + src = src[:l] + }, + write: func(view *buffer.View) (int64, error) { + view.Append(src) + return int64(len(src)), nil + }, + }) + if n > 0 { + fd.pipe.Notify(waiter.EventIn) + } + if err == nil && n != origCount { + return int(n), syserror.ErrWouldBlock + } + return int(n), err +} + +// ZeroOut implements usermem.IO.ZeroOut. +func (fd *VFSPipeFD) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int64, opts usermem.IOOpts) (int64, error) { + origCount := toZero + n, err := fd.pipe.write(ctx, writeOps{ + left: func() int64 { + return toZero + }, + limit: func(l int64) { + toZero = l + }, + write: func(view *buffer.View) (int64, error) { + view.Grow(view.Size()+toZero, true /* zero */) + return toZero, nil + }, + }) + if n > 0 { + fd.pipe.Notify(waiter.EventIn) + } + if err == nil && n != origCount { + return n, syserror.ErrWouldBlock + } + return n, err +} + +// CopyInTo implements usermem.IO.CopyInTo. +func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) { + count := ars.NumBytes() + if count == 0 { + return 0, nil + } + origCount := count + n, err := fd.pipe.read(ctx, readOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + read: func(view *buffer.View) (int64, error) { + n, err := view.ReadToSafememWriter(dst, uint64(count)) + view.TrimFront(int64(n)) + return int64(n), err + }, + }) + if n > 0 { + fd.pipe.Notify(waiter.EventOut) + } + if err == nil && n != origCount { + return n, syserror.ErrWouldBlock + } + return n, err +} + +// CopyOutFrom implements usermem.IO.CopyOutFrom. +func (fd *VFSPipeFD) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) { + count := ars.NumBytes() + if count == 0 { + return 0, nil + } + origCount := count + n, err := fd.pipe.write(ctx, writeOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + write: func(view *buffer.View) (int64, error) { + n, err := view.WriteFromSafememReader(src, uint64(count)) + return int64(n), err + }, + }) + if n > 0 { + fd.pipe.Notify(waiter.EventIn) + } + if err == nil && n != origCount { + return n, syserror.ErrWouldBlock + } + return n, err +} + +// SwapUint32 implements usermem.IO.SwapUint32. +func (fd *VFSPipeFD) SwapUint32(ctx context.Context, addr usermem.Addr, new uint32, opts usermem.IOOpts) (uint32, error) { + // How did a pipe get passed as the virtual address space to futex(2)? + panic("VFSPipeFD.SwapUint32 called unexpectedly") +} + +// CompareAndSwapUint32 implements usermem.IO.CompareAndSwapUint32. +func (fd *VFSPipeFD) CompareAndSwapUint32(ctx context.Context, addr usermem.Addr, old, new uint32, opts usermem.IOOpts) (uint32, error) { + panic("VFSPipeFD.CompareAndSwapUint32 called unexpectedly") +} + +// LoadUint32 implements usermem.IO.LoadUint32. +func (fd *VFSPipeFD) LoadUint32(ctx context.Context, addr usermem.Addr, opts usermem.IOOpts) (uint32, error) { + panic("VFSPipeFD.LoadUint32 called unexpectedly") +} + +// Splice reads up to count bytes from src and writes them to dst. It returns +// the number of bytes moved. +// +// Preconditions: count > 0. +func Splice(ctx context.Context, dst, src *VFSPipeFD, count int64) (int64, error) { + return spliceOrTee(ctx, dst, src, count, true /* removeFromSrc */) +} + +// Tee reads up to count bytes from src and writes them to dst, without +// removing the read bytes from src. It returns the number of bytes copied. +// +// Preconditions: count > 0. +func Tee(ctx context.Context, dst, src *VFSPipeFD, count int64) (int64, error) { + return spliceOrTee(ctx, dst, src, count, false /* removeFromSrc */) +} + +// Preconditions: count > 0. +func spliceOrTee(ctx context.Context, dst, src *VFSPipeFD, count int64, removeFromSrc bool) (int64, error) { + if dst.pipe == src.pipe { + return 0, syserror.EINVAL + } + + lockTwoPipes(dst.pipe, src.pipe) + defer dst.pipe.mu.Unlock() + defer src.pipe.mu.Unlock() + + n, err := dst.pipe.writeLocked(ctx, writeOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + write: func(dstView *buffer.View) (int64, error) { + return src.pipe.readLocked(ctx, readOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + read: func(srcView *buffer.View) (int64, error) { + n, err := srcView.ReadToSafememWriter(dstView, uint64(count)) + if n > 0 && removeFromSrc { + srcView.TrimFront(int64(n)) + } + return int64(n), err + }, + }) + }, + }) + if n > 0 { + dst.pipe.Notify(waiter.EventIn) + src.pipe.Notify(waiter.EventOut) + } + return n, err +} diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index 00c425cca..9b69f3cbe 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -198,6 +198,10 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { t.tg.oldRSeqCritical.Store(&OldRSeqCriticalRegion{}) t.tg.pidns.owner.mu.Unlock() + oldFDTable := t.fdTable + t.fdTable = t.fdTable.Fork() + oldFDTable.DecRef() + // Remove FDs with the CloseOnExec flag set. t.fdTable.RemoveIf(func(_ *fs.File, _ *vfs.FileDescription, flags FDFlags) bool { return flags.CloseOnExec diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index 73591dab7..a036ce53c 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -25,6 +25,7 @@ go_template_instance( out = "vma_set.go", consts = { "minDegree": "8", + "trackGaps": "1", }, imports = { "usermem": "gvisor.dev/gvisor/pkg/usermem", diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go index 9a14e69e6..16d8207e9 100644 --- a/pkg/sentry/mm/vma.go +++ b/pkg/sentry/mm/vma.go @@ -195,7 +195,7 @@ func (mm *MemoryManager) applicationAddrRange() usermem.AddrRange { // Preconditions: mm.mappingMu must be locked. func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bounds usermem.AddrRange) (usermem.Addr, error) { - for gap := mm.vmas.LowerBoundGap(bounds.Start); gap.Ok() && gap.Start() < bounds.End; gap = gap.NextGap() { + for gap := mm.vmas.LowerBoundGap(bounds.Start); gap.Ok() && gap.Start() < bounds.End; gap = gap.NextLargeEnoughGap(usermem.Addr(length)) { if gr := gap.availableRange().Intersect(bounds); uint64(gr.Length()) >= length { // Can we shift up to match the alignment? if offset := uint64(gr.Start) % alignment; offset != 0 { @@ -214,7 +214,7 @@ func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bou // Preconditions: mm.mappingMu must be locked. func (mm *MemoryManager) findHighestAvailableLocked(length, alignment uint64, bounds usermem.AddrRange) (usermem.Addr, error) { - for gap := mm.vmas.UpperBoundGap(bounds.End); gap.Ok() && gap.End() > bounds.Start; gap = gap.PrevGap() { + for gap := mm.vmas.UpperBoundGap(bounds.End); gap.Ok() && gap.End() > bounds.Start; gap = gap.PrevLargeEnoughGap(usermem.Addr(length)) { if gr := gap.availableRange().Intersect(bounds); uint64(gr.Length()) >= length { // Can we shift down to match the alignment? start := gr.End - usermem.Addr(length) diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index 1eeb9f317..a9836ba71 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -33,6 +33,7 @@ go_template_instance( out = "usage_set.go", consts = { "minDegree": "10", + "trackGaps": "1", }, imports = { "platform": "gvisor.dev/gvisor/pkg/sentry/platform", @@ -48,6 +49,26 @@ go_template_instance( }, ) +go_template_instance( + name = "reclaim_set", + out = "reclaim_set.go", + consts = { + "minDegree": "10", + }, + imports = { + "platform": "gvisor.dev/gvisor/pkg/sentry/platform", + }, + package = "pgalloc", + prefix = "reclaim", + template = "//pkg/segment:generic_set", + types = { + "Key": "uint64", + "Range": "platform.FileRange", + "Value": "reclaimSetValue", + "Functions": "reclaimSetFunctions", + }, +) + go_library( name = "pgalloc", srcs = [ @@ -56,6 +77,7 @@ go_library( "evictable_range_set.go", "pgalloc.go", "pgalloc_unsafe.go", + "reclaim_set.go", "save_restore.go", "usage_set.go", ], diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index 2b11ea4ae..46f19d218 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -108,12 +108,6 @@ type MemoryFile struct { usageSwapped uint64 usageLast time.Time - // minUnallocatedPage is the minimum page that may be unallocated. - // i.e., there are no unallocated pages below minUnallocatedPage. - // - // minUnallocatedPage is protected by mu. - minUnallocatedPage uint64 - // fileSize is the size of the backing memory file in bytes. fileSize is // always a power-of-two multiple of chunkSize. // @@ -146,11 +140,9 @@ type MemoryFile struct { // is protected by mu. reclaimable bool - // minReclaimablePage is the minimum page that may be reclaimable. - // i.e., all reclaimable pages are >= minReclaimablePage. - // - // minReclaimablePage is protected by mu. - minReclaimablePage uint64 + // relcaim is the collection of regions for reclaim. relcaim is protected + // by mu. + reclaim reclaimSet // reclaimCond is signaled (with mu locked) when reclaimable or destroyed // transitions from false to true. @@ -273,12 +265,10 @@ type evictableMemoryUserInfo struct { } const ( - chunkShift = 24 - chunkSize = 1 << chunkShift // 16 MB + chunkShift = 30 + chunkSize = 1 << chunkShift // 1 GB chunkMask = chunkSize - 1 - initialSize = chunkSize - // maxPage is the highest 64-bit page. maxPage = math.MaxUint64 &^ (usermem.PageSize - 1) ) @@ -302,19 +292,12 @@ func NewMemoryFile(file *os.File, opts MemoryFileOpts) (*MemoryFile, error) { if err := file.Truncate(0); err != nil { return nil, err } - if err := file.Truncate(initialSize); err != nil { - return nil, err - } f := &MemoryFile{ - opts: opts, - fileSize: initialSize, - file: file, - // No pages are reclaimable. DecRef will always be able to - // decrease minReclaimablePage from this point. - minReclaimablePage: maxPage, - evictable: make(map[EvictableMemoryUser]*evictableMemoryUserInfo), + opts: opts, + file: file, + evictable: make(map[EvictableMemoryUser]*evictableMemoryUserInfo), } - f.mappings.Store(make([]uintptr, initialSize/chunkSize)) + f.mappings.Store(make([]uintptr, 0)) f.reclaimCond.L = &f.mu if f.opts.DelayedEviction == DelayedEvictionEnabled && f.opts.UseHostMemcgPressure { @@ -404,39 +387,29 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi alignment = usermem.HugePageSize } - start, minUnallocatedPage := findUnallocatedRange(&f.usage, f.minUnallocatedPage, length, alignment) - end := start + length - // File offsets are int64s. Since length must be strictly positive, end - // cannot legitimately be 0. - if end < start || int64(end) <= 0 { + // Find a range in the underlying file. + fr, ok := findAvailableRange(&f.usage, f.fileSize, length, alignment) + if !ok { return platform.FileRange{}, syserror.ENOMEM } - // Expand the file if needed. Double the file size on each expansion; - // uncommitted pages have effectively no cost. - fileSize := f.fileSize - for int64(end) > fileSize { - if fileSize >= 2*fileSize { - // fileSize overflow. - return platform.FileRange{}, syserror.ENOMEM - } - fileSize *= 2 - } - if fileSize > f.fileSize { - if err := f.file.Truncate(fileSize); err != nil { + // Expand the file if needed. + if int64(fr.End) > f.fileSize { + // Round the new file size up to be chunk-aligned. + newFileSize := (int64(fr.End) + chunkMask) &^ chunkMask + if err := f.file.Truncate(newFileSize); err != nil { return platform.FileRange{}, err } - f.fileSize = fileSize + f.fileSize = newFileSize f.mappingsMu.Lock() oldMappings := f.mappings.Load().([]uintptr) - newMappings := make([]uintptr, fileSize>>chunkShift) + newMappings := make([]uintptr, newFileSize>>chunkShift) copy(newMappings, oldMappings) f.mappings.Store(newMappings) f.mappingsMu.Unlock() } // Mark selected pages as in use. - fr := platform.FileRange{start, end} if f.opts.ManualZeroing { if err := f.forEachMappingSlice(fr, func(bs []byte) { for i := range bs { @@ -453,49 +426,71 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi panic(fmt.Sprintf("allocating %v: failed to insert into usage set:\n%v", fr, &f.usage)) } - if minUnallocatedPage < start { - f.minUnallocatedPage = minUnallocatedPage - } else { - // start was the first unallocated page. The next must be - // somewhere beyond end. - f.minUnallocatedPage = end - } - return fr, nil } -// findUnallocatedRange returns the first unallocated page in usage of the -// specified length and alignment beginning at page start and the first single -// unallocated page. -func findUnallocatedRange(usage *usageSet, start, length, alignment uint64) (uint64, uint64) { - // Only searched until the first page is found. - firstPage := start - foundFirstPage := false - alignMask := alignment - 1 - for seg := usage.LowerBoundSegment(start); seg.Ok(); seg = seg.NextSegment() { - r := seg.Range() +// findAvailableRange returns an available range in the usageSet. +// +// Note that scanning for available slots takes place from end first backwards, +// then forwards. This heuristic has important consequence for how sequential +// mappings can be merged in the host VMAs, given that addresses for both +// application and sentry mappings are allocated top-down (from higher to +// lower addresses). The file is also grown expoentially in order to create +// space for mappings to be allocated downwards. +// +// Precondition: alignment must be a power of 2. +func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint64) (platform.FileRange, bool) { + alignmentMask := alignment - 1 + for gap := usage.UpperBoundGap(uint64(fileSize)); gap.Ok(); gap = gap.PrevLargeEnoughGap(length) { + // Start searching only at end of file. + end := gap.End() + if end > uint64(fileSize) { + end = uint64(fileSize) + } - if !foundFirstPage && r.Start > firstPage { - foundFirstPage = true + // Start at the top and align downwards. + start := end - length + if start > end { + break // Underflow. } + start &^= alignmentMask - if start >= r.End { - // start was rounded up to an alignment boundary from the end - // of a previous segment and is now beyond r.End. + // Is the gap still sufficient? + if start < gap.Start() { continue } - // This segment represents allocated or reclaimable pages; only the - // range from start to the segment's beginning is allocatable, and the - // next allocatable range begins after the segment. - if r.Start > start && r.Start-start >= length { - break + + // Allocate in the given gap. + return platform.FileRange{start, start + length}, true + } + + // Check that it's possible to fit this allocation at the end of a file of any size. + min := usage.LastGap().Start() + min = (min + alignmentMask) &^ alignmentMask + if min+length < min { + // Overflow. + return platform.FileRange{}, false + } + + // Determine the minimum file size required to fit this allocation at its end. + for { + if fileSize >= 2*fileSize { + // Is this because it's initially empty? + if fileSize == 0 { + fileSize += chunkSize + } else { + // fileSize overflow. + return platform.FileRange{}, false + } + } else { + // Double the current fileSize. + fileSize *= 2 } - start = (r.End + alignMask) &^ alignMask - if !foundFirstPage { - firstPage = r.End + start := (uint64(fileSize) - length) &^ alignmentMask + if start >= min { + return platform.FileRange{start, start + length}, true } } - return start, firstPage } // AllocateAndFill allocates memory of the given kind and fills it by calling @@ -616,6 +611,7 @@ func (f *MemoryFile) DecRef(fr platform.FileRange) { } val.refs-- if val.refs == 0 { + f.reclaim.Add(seg.Range(), reclaimSetValue{}) freed = true // Reclassify memory as System, until it's freed by the reclaim // goroutine. @@ -628,10 +624,6 @@ func (f *MemoryFile) DecRef(fr platform.FileRange) { f.usage.MergeAdjacent(fr) if freed { - if fr.Start < f.minReclaimablePage { - // We've freed at least one lower page. - f.minReclaimablePage = fr.Start - } f.reclaimable = true f.reclaimCond.Signal() } @@ -1030,6 +1022,7 @@ func (f *MemoryFile) String() string { // for allocation. func (f *MemoryFile) runReclaim() { for { + // N.B. We must call f.markReclaimed on the returned FrameRange. fr, ok := f.findReclaimable() if !ok { break @@ -1085,6 +1078,10 @@ func (f *MemoryFile) runReclaim() { } } +// findReclaimable finds memory that has been marked for reclaim. +// +// Note that there returned range will be removed from tracking. It +// must be reclaimed (removed from f.usage) at this point. func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) { f.mu.Lock() defer f.mu.Unlock() @@ -1103,18 +1100,15 @@ func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) { } f.reclaimCond.Wait() } - // Allocate returns the first usable range in offset order and is - // currently a linear scan, so reclaiming from the beginning of the - // file minimizes the expected latency of Allocate. - for seg := f.usage.LowerBoundSegment(f.minReclaimablePage); seg.Ok(); seg = seg.NextSegment() { - if seg.ValuePtr().refs == 0 { - f.minReclaimablePage = seg.End() - return seg.Range(), true - } + // Allocate works from the back of the file inwards, so reclaim + // preserves this order to minimize the cost of the search. + if seg := f.reclaim.LastSegment(); seg.Ok() { + fr := seg.Range() + f.reclaim.Remove(seg) + return fr, true } - // No pages are reclaimable. + // Nothing is reclaimable. f.reclaimable = false - f.minReclaimablePage = maxPage } } @@ -1122,8 +1116,8 @@ func (f *MemoryFile) markReclaimed(fr platform.FileRange) { f.mu.Lock() defer f.mu.Unlock() seg := f.usage.FindSegment(fr.Start) - // All of fr should be mapped to a single uncommitted reclaimable segment - // accounted to System. + // All of fr should be mapped to a single uncommitted reclaimable + // segment accounted to System. if !seg.Ok() { panic(fmt.Sprintf("reclaimed pages %v include unreferenced pages:\n%v", fr, &f.usage)) } @@ -1137,14 +1131,10 @@ func (f *MemoryFile) markReclaimed(fr platform.FileRange) { }); got != want { panic(fmt.Sprintf("reclaimed pages %v in segment %v has incorrect state %v, wanted %v:\n%v", fr, seg.Range(), got, want, &f.usage)) } - // Deallocate reclaimed pages. Even though all of seg is reclaimable, the - // caller of markReclaimed may not have decommitted it, so we can only mark - // fr as reclaimed. + // Deallocate reclaimed pages. Even though all of seg is reclaimable, + // the caller of markReclaimed may not have decommitted it, so we can + // only mark fr as reclaimed. f.usage.Remove(f.usage.Isolate(seg, fr)) - if fr.Start < f.minUnallocatedPage { - // We've deallocated at least one lower page. - f.minUnallocatedPage = fr.Start - } } // StartEvictions requests that f evict all evictable allocations. It does not @@ -1255,3 +1245,27 @@ func (evictableRangeSetFunctions) Merge(_ EvictableRange, _ evictableRangeSetVal func (evictableRangeSetFunctions) Split(_ EvictableRange, _ evictableRangeSetValue, _ uint64) (evictableRangeSetValue, evictableRangeSetValue) { return evictableRangeSetValue{}, evictableRangeSetValue{} } + +// reclaimSetValue is the value type of reclaimSet. +type reclaimSetValue struct{} + +type reclaimSetFunctions struct{} + +func (reclaimSetFunctions) MinKey() uint64 { + return 0 +} + +func (reclaimSetFunctions) MaxKey() uint64 { + return math.MaxUint64 +} + +func (reclaimSetFunctions) ClearValue(val *reclaimSetValue) { +} + +func (reclaimSetFunctions) Merge(_ platform.FileRange, _ reclaimSetValue, _ platform.FileRange, _ reclaimSetValue) (reclaimSetValue, bool) { + return reclaimSetValue{}, true +} + +func (reclaimSetFunctions) Split(_ platform.FileRange, _ reclaimSetValue, _ uint64) (reclaimSetValue, reclaimSetValue) { + return reclaimSetValue{}, reclaimSetValue{} +} diff --git a/pkg/sentry/pgalloc/pgalloc_test.go b/pkg/sentry/pgalloc/pgalloc_test.go index 293f22c6b..b5b68eb52 100644 --- a/pkg/sentry/pgalloc/pgalloc_test.go +++ b/pkg/sentry/pgalloc/pgalloc_test.go @@ -23,39 +23,49 @@ import ( const ( page = usermem.PageSize hugepage = usermem.HugePageSize + topPage = (1 << 63) - page ) func TestFindUnallocatedRange(t *testing.T) { for _, test := range []struct { - desc string - usage *usageSegmentDataSlices - start uint64 - length uint64 - alignment uint64 - unallocated uint64 - minUnallocated uint64 + desc string + usage *usageSegmentDataSlices + fileSize int64 + length uint64 + alignment uint64 + start uint64 + expectFail bool }{ { - desc: "Initial allocation succeeds", - usage: &usageSegmentDataSlices{}, - start: 0, - length: page, - alignment: page, - unallocated: 0, - minUnallocated: 0, + desc: "Initial allocation succeeds", + usage: &usageSegmentDataSlices{}, + length: page, + alignment: page, + start: chunkSize - page, // Grows by chunkSize, allocate down. }, { - desc: "Allocation begins at start of file", + desc: "Allocation finds empty space at start of file", usage: &usageSegmentDataSlices{ Start: []uint64{page}, End: []uint64{2 * page}, Values: []usageInfo{{refs: 1}}, }, - start: 0, - length: page, - alignment: page, - unallocated: 0, - minUnallocated: 0, + fileSize: 2 * page, + length: page, + alignment: page, + start: 0, + }, + { + desc: "Allocation finds empty space at end of file", + usage: &usageSegmentDataSlices{ + Start: []uint64{0}, + End: []uint64{page}, + Values: []usageInfo{{refs: 1}}, + }, + fileSize: 2 * page, + length: page, + alignment: page, + start: page, }, { desc: "In-use frames are not allocatable", @@ -64,11 +74,10 @@ func TestFindUnallocatedRange(t *testing.T) { End: []uint64{page, 2 * page}, Values: []usageInfo{{refs: 1}, {refs: 2}}, }, - start: 0, - length: page, - alignment: page, - unallocated: 2 * page, - minUnallocated: 2 * page, + fileSize: 2 * page, + length: page, + alignment: page, + start: 3 * page, // Double fileSize, allocate top-down. }, { desc: "Reclaimable frames are not allocatable", @@ -77,11 +86,10 @@ func TestFindUnallocatedRange(t *testing.T) { End: []uint64{page, 2 * page, 3 * page}, Values: []usageInfo{{refs: 1}, {refs: 0}, {refs: 1}}, }, - start: 0, - length: page, - alignment: page, - unallocated: 3 * page, - minUnallocated: 3 * page, + fileSize: 3 * page, + length: page, + alignment: page, + start: 5 * page, // Double fileSize, grow down. }, { desc: "Gaps between in-use frames are allocatable", @@ -90,11 +98,10 @@ func TestFindUnallocatedRange(t *testing.T) { End: []uint64{page, 3 * page}, Values: []usageInfo{{refs: 1}, {refs: 1}}, }, - start: 0, - length: page, - alignment: page, - unallocated: page, - minUnallocated: page, + fileSize: 3 * page, + length: page, + alignment: page, + start: page, }, { desc: "Inadequately-sized gaps are rejected", @@ -103,14 +110,13 @@ func TestFindUnallocatedRange(t *testing.T) { End: []uint64{page, 3 * page}, Values: []usageInfo{{refs: 1}, {refs: 1}}, }, - start: 0, - length: 2 * page, - alignment: page, - unallocated: 3 * page, - minUnallocated: page, + fileSize: 3 * page, + length: 2 * page, + alignment: page, + start: 4 * page, // Double fileSize, grow down. }, { - desc: "Hugepage alignment is honored", + desc: "Alignment is honored at end of file", usage: &usageSegmentDataSlices{ Start: []uint64{0, hugepage + page}, // Hugepage-sized gap here that shouldn't be allocated from @@ -118,37 +124,95 @@ func TestFindUnallocatedRange(t *testing.T) { End: []uint64{page, hugepage + 2*page}, Values: []usageInfo{{refs: 1}, {refs: 1}}, }, - start: 0, - length: hugepage, - alignment: hugepage, - unallocated: 2 * hugepage, - minUnallocated: page, + fileSize: hugepage + 2*page, + length: hugepage, + alignment: hugepage, + start: 3 * hugepage, // Double fileSize until alignment is satisfied, grow down. + }, + { + desc: "Alignment is honored before end of file", + usage: &usageSegmentDataSlices{ + Start: []uint64{0, 2*hugepage + page}, + // Page will need to be shifted down from top. + End: []uint64{page, 2*hugepage + 2*page}, + Values: []usageInfo{{refs: 1}, {refs: 1}}, + }, + fileSize: 2*hugepage + 2*page, + length: hugepage, + alignment: hugepage, + start: hugepage, }, { - desc: "Pages before start ignored", + desc: "Allocations are compact if possible", usage: &usageSegmentDataSlices{ Start: []uint64{page, 3 * page}, End: []uint64{2 * page, 4 * page}, Values: []usageInfo{{refs: 1}, {refs: 2}}, }, - start: page, - length: page, - alignment: page, - unallocated: 2 * page, - minUnallocated: 2 * page, + fileSize: 4 * page, + length: page, + alignment: page, + start: 2 * page, + }, + { + desc: "Top-down allocation within one gap", + usage: &usageSegmentDataSlices{ + Start: []uint64{page, 4 * page, 7 * page}, + End: []uint64{2 * page, 5 * page, 8 * page}, + Values: []usageInfo{{refs: 1}, {refs: 2}, {refs: 1}}, + }, + fileSize: 8 * page, + length: page, + alignment: page, + start: 6 * page, + }, + { + desc: "Top-down allocation between multiple gaps", + usage: &usageSegmentDataSlices{ + Start: []uint64{page, 3 * page, 5 * page}, + End: []uint64{2 * page, 4 * page, 6 * page}, + Values: []usageInfo{{refs: 1}, {refs: 2}, {refs: 1}}, + }, + fileSize: 6 * page, + length: page, + alignment: page, + start: 4 * page, }, { - desc: "start may be in the middle of segment", + desc: "Top-down allocation with large top gap", usage: &usageSegmentDataSlices{ - Start: []uint64{0, 3 * page}, + Start: []uint64{page, 3 * page}, End: []uint64{2 * page, 4 * page}, Values: []usageInfo{{refs: 1}, {refs: 2}}, }, - start: page, - length: page, - alignment: page, - unallocated: 2 * page, - minUnallocated: 2 * page, + fileSize: 8 * page, + length: page, + alignment: page, + start: 7 * page, + }, + { + desc: "Gaps found with possible overflow", + usage: &usageSegmentDataSlices{ + Start: []uint64{page, topPage - page}, + End: []uint64{2 * page, topPage}, + Values: []usageInfo{{refs: 1}, {refs: 1}}, + }, + fileSize: topPage, + length: page, + alignment: page, + start: topPage - 2*page, + }, + { + desc: "Overflow detected", + usage: &usageSegmentDataSlices{ + Start: []uint64{page}, + End: []uint64{topPage}, + Values: []usageInfo{{refs: 1}}, + }, + fileSize: topPage, + length: 2 * page, + alignment: page, + expectFail: true, }, } { t.Run(test.desc, func(t *testing.T) { @@ -156,12 +220,18 @@ func TestFindUnallocatedRange(t *testing.T) { if err := usage.ImportSortedSlices(test.usage); err != nil { t.Fatalf("Failed to initialize usage from %v: %v", test.usage, err) } - unallocated, minUnallocated := findUnallocatedRange(&usage, test.start, test.length, test.alignment) - if unallocated != test.unallocated { - t.Errorf("findUnallocatedRange(%v, %x, %x, %x): got unallocated %x, wanted %x", test.usage, test.start, test.length, test.alignment, unallocated, test.unallocated) + fr, ok := findAvailableRange(&usage, test.fileSize, test.length, test.alignment) + if !test.expectFail && !ok { + t.Fatalf("findAvailableRange(%v, %x, %x, %x): got %x, false wanted %x, true", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start) + } + if test.expectFail && ok { + t.Fatalf("findAvailableRange(%v, %x, %x, %x): got %x, true wanted %x, false", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start) + } + if ok && fr.Start != test.start { + t.Errorf("findAvailableRange(%v, %x, %x, %x): got start=%x, wanted %x", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start) } - if minUnallocated != test.minUnallocated { - t.Errorf("findUnallocatedRange(%v, %x, %x, %x): got minUnallocated %x, wanted %x", test.usage, test.start, test.length, test.alignment, minUnallocated, test.minUnallocated) + if ok && fr.End != test.start+test.length { + t.Errorf("findAvailableRange(%v, %x, %x, %x): got end=%x, wanted %x", test.usage, test.fileSize, test.length, test.alignment, fr.End, test.start+test.length) } }) } diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 159f7eafd..4792454c4 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -6,8 +6,8 @@ go_library( name = "kvm", srcs = [ "address_space.go", - "allocator.go", "bluepill.go", + "bluepill_allocator.go", "bluepill_amd64.go", "bluepill_amd64.s", "bluepill_amd64_unsafe.go", diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go index be213bfe8..faf1d5e1c 100644 --- a/pkg/sentry/platform/kvm/address_space.go +++ b/pkg/sentry/platform/kvm/address_space.go @@ -26,16 +26,15 @@ import ( // dirtySet tracks vCPUs for invalidation. type dirtySet struct { - vCPUs []uint64 + vCPUMasks []uint64 } // forEach iterates over all CPUs in the dirty set. +// +//go:nosplit func (ds *dirtySet) forEach(m *machine, fn func(c *vCPU)) { - m.mu.RLock() - defer m.mu.RUnlock() - - for index := range ds.vCPUs { - mask := atomic.SwapUint64(&ds.vCPUs[index], 0) + for index := range ds.vCPUMasks { + mask := atomic.SwapUint64(&ds.vCPUMasks[index], 0) if mask != 0 { for bit := 0; bit < 64; bit++ { if mask&(1<<uint64(bit)) == 0 { @@ -54,7 +53,7 @@ func (ds *dirtySet) mark(c *vCPU) bool { index := uint64(c.id) / 64 bit := uint64(1) << uint(c.id%64) - oldValue := atomic.LoadUint64(&ds.vCPUs[index]) + oldValue := atomic.LoadUint64(&ds.vCPUMasks[index]) if oldValue&bit != 0 { return false // Not clean. } @@ -62,7 +61,7 @@ func (ds *dirtySet) mark(c *vCPU) bool { // Set the bit unilaterally, and ensure that a flush takes place. Note // that it's possible for races to occur here, but since the flush is // taking place long after these lines there's no race in practice. - atomicbitops.OrUint64(&ds.vCPUs[index], bit) + atomicbitops.OrUint64(&ds.vCPUMasks[index], bit) return true // Previously clean. } @@ -113,7 +112,12 @@ type hostMapEntry struct { length uintptr } -func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.AccessType) (inv bool) { +// mapLocked maps the given host entry. +// +// +checkescape:hard,stack +// +//go:nosplit +func (as *addressSpace) mapLocked(addr usermem.Addr, m hostMapEntry, at usermem.AccessType) (inv bool) { for m.length > 0 { physical, length, ok := translateToPhysical(m.addr) if !ok { @@ -133,18 +137,10 @@ func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.Ac // important; if the pagetable mappings were installed before // ensuring the physical pages were available, then some other // thread could theoretically access them. - // - // Due to the way KVM's shadow paging implementation works, - // modifications to the page tables while in host mode may not - // be trapped, leading to the shadow pages being out of sync. - // Therefore, we need to ensure that we are in guest mode for - // page table modifications. See the call to bluepill, below. - as.machine.retryInGuest(func() { - inv = as.pageTables.Map(addr, length, pagetables.MapOpts{ - AccessType: at, - User: true, - }, physical) || inv - }) + inv = as.pageTables.Map(addr, length, pagetables.MapOpts{ + AccessType: at, + User: true, + }, physical) || inv m.addr += length m.length -= length addr += usermem.Addr(length) @@ -176,6 +172,10 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform. return err } + // See block in mapLocked. + as.pageTables.Allocator.(*allocator).cpu = as.machine.Get() + defer as.machine.Put(as.pageTables.Allocator.(*allocator).cpu) + // Map the mappings in the sentry's address space (guest physical memory) // into the application's address space (guest virtual memory). inv := false @@ -190,7 +190,12 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform. _ = s[i] // Touch to commit. } } - prev := as.mapHost(addr, hostMapEntry{ + + // See bluepill_allocator.go. + bluepill(as.pageTables.Allocator.(*allocator).cpu) + + // Perform the mapping. + prev := as.mapLocked(addr, hostMapEntry{ addr: b.Addr(), length: uintptr(b.Len()), }, at) @@ -204,17 +209,27 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform. return nil } +// unmapLocked is an escape-checked wrapped around Unmap. +// +// +checkescape:hard,stack +// +//go:nosplit +func (as *addressSpace) unmapLocked(addr usermem.Addr, length uint64) bool { + return as.pageTables.Unmap(addr, uintptr(length)) +} + // Unmap unmaps the given range by calling pagetables.PageTables.Unmap. func (as *addressSpace) Unmap(addr usermem.Addr, length uint64) { as.mu.Lock() defer as.mu.Unlock() - // See above re: retryInGuest. - var prev bool - as.machine.retryInGuest(func() { - prev = as.pageTables.Unmap(addr, uintptr(length)) || prev - }) - if prev { + // See above & bluepill_allocator.go. + as.pageTables.Allocator.(*allocator).cpu = as.machine.Get() + defer as.machine.Put(as.pageTables.Allocator.(*allocator).cpu) + bluepill(as.pageTables.Allocator.(*allocator).cpu) + + if prev := as.unmapLocked(addr, length); prev { + // Invalidate all active vCPUs. as.invalidate() // Recycle any freed intermediate pages. @@ -227,7 +242,7 @@ func (as *addressSpace) Release() { as.Unmap(0, ^uint64(0)) // Free all pages from the allocator. - as.pageTables.Allocator.(allocator).base.Drain() + as.pageTables.Allocator.(*allocator).base.Drain() // Drop all cached machine references. as.machine.dropPageTables(as.pageTables) diff --git a/pkg/sentry/platform/kvm/allocator.go b/pkg/sentry/platform/kvm/bluepill_allocator.go index 3f35414bb..9485e1301 100644 --- a/pkg/sentry/platform/kvm/allocator.go +++ b/pkg/sentry/platform/kvm/bluepill_allocator.go @@ -21,56 +21,80 @@ import ( ) type allocator struct { - base *pagetables.RuntimeAllocator + base pagetables.RuntimeAllocator + + // cpu must be set prior to any pagetable operation. + // + // Due to the way KVM's shadow paging implementation works, + // modifications to the page tables while in host mode may not be + // trapped, leading to the shadow pages being out of sync. Therefore, + // we need to ensure that we are in guest mode for page table + // modifications. See the call to bluepill, below. + cpu *vCPU } // newAllocator is used to define the allocator. -func newAllocator() allocator { - return allocator{ - base: pagetables.NewRuntimeAllocator(), - } +func newAllocator() *allocator { + a := new(allocator) + a.base.Init() + return a } // NewPTEs implements pagetables.Allocator.NewPTEs. // +// +checkescape:all +// //go:nosplit -func (a allocator) NewPTEs() *pagetables.PTEs { - return a.base.NewPTEs() +func (a *allocator) NewPTEs() *pagetables.PTEs { + ptes := a.base.NewPTEs() // escapes: bluepill below. + if a.cpu != nil { + bluepill(a.cpu) + } + return ptes } // PhysicalFor returns the physical address for a set of PTEs. // +// +checkescape:all +// //go:nosplit -func (a allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr { +func (a *allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr { virtual := a.base.PhysicalFor(ptes) physical, _, ok := translateToPhysical(virtual) if !ok { - panic(fmt.Sprintf("PhysicalFor failed for %p", ptes)) + panic(fmt.Sprintf("PhysicalFor failed for %p", ptes)) // escapes: panic. } return physical } // LookupPTEs implements pagetables.Allocator.LookupPTEs. // +// +checkescape:all +// //go:nosplit -func (a allocator) LookupPTEs(physical uintptr) *pagetables.PTEs { +func (a *allocator) LookupPTEs(physical uintptr) *pagetables.PTEs { virtualStart, physicalStart, _, ok := calculateBluepillFault(physical, physicalRegions) if !ok { - panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical)) + panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical)) // escapes: panic. } return a.base.LookupPTEs(virtualStart + (physical - physicalStart)) } // FreePTEs implements pagetables.Allocator.FreePTEs. // +// +checkescape:all +// //go:nosplit -func (a allocator) FreePTEs(ptes *pagetables.PTEs) { - a.base.FreePTEs(ptes) +func (a *allocator) FreePTEs(ptes *pagetables.PTEs) { + a.base.FreePTEs(ptes) // escapes: bluepill below. + if a.cpu != nil { + bluepill(a.cpu) + } } // Recycle implements pagetables.Allocator.Recycle. // //go:nosplit -func (a allocator) Recycle() { +func (a *allocator) Recycle() { a.base.Recycle() } diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go index 133c2203d..ddc1554d5 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64.go @@ -63,6 +63,8 @@ func bluepillArchEnter(context *arch.SignalContext64) *vCPU { // KernelSyscall handles kernel syscalls. // +// +checkescape:all +// //go:nosplit func (c *vCPU) KernelSyscall() { regs := c.Registers() @@ -72,13 +74,15 @@ func (c *vCPU) KernelSyscall() { // We only trigger a bluepill entry in the bluepill function, and can // therefore be guaranteed that there is no floating point state to be // loaded on resuming from halt. We only worry about saving on exit. - ring0.SaveFloatingPoint((*byte)(c.floatingPointState)) + ring0.SaveFloatingPoint((*byte)(c.floatingPointState)) // escapes: no. ring0.Halt() - ring0.WriteFS(uintptr(regs.Fs_base)) // Reload host segment. + ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no, reload host segment. } // KernelException handles kernel exceptions. // +// +checkescape:all +// //go:nosplit func (c *vCPU) KernelException(vector ring0.Vector) { regs := c.Registers() @@ -89,9 +93,9 @@ func (c *vCPU) KernelException(vector ring0.Vector) { regs.Rip = 0 } // See above. - ring0.SaveFloatingPoint((*byte)(c.floatingPointState)) + ring0.SaveFloatingPoint((*byte)(c.floatingPointState)) // escapes: no. ring0.Halt() - ring0.WriteFS(uintptr(regs.Fs_base)) // Reload host segment. + ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no; reload host segment. } // bluepillArchExit is called during bluepillEnter. diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index c215d443c..83643c602 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -66,6 +66,8 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) { // KernelSyscall handles kernel syscalls. // +// +checkescape:all +// //go:nosplit func (c *vCPU) KernelSyscall() { regs := c.Registers() @@ -88,6 +90,8 @@ func (c *vCPU) KernelSyscall() { // KernelException handles kernel exceptions. // +// +checkescape:all +// //go:nosplit func (c *vCPU) KernelException(vector ring0.Vector) { regs := c.Registers() diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index 9add7c944..c025aa0bb 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.15 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. @@ -64,6 +64,8 @@ func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 { // signal stack. It should only execute raw system calls and functions that are // explicitly marked go:nosplit. // +// +checkescape:all +// //go:nosplit func bluepillHandler(context unsafe.Pointer) { // Sanitize the registers; interrupts must always be disabled. @@ -82,7 +84,8 @@ func bluepillHandler(context unsafe.Pointer) { } for { - switch _, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(c.fd), _KVM_RUN, 0); errno { + _, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(c.fd), _KVM_RUN, 0) // escapes: no. + switch errno { case 0: // Expected case. case syscall.EINTR: // First, we process whatever pending signal @@ -90,7 +93,7 @@ func bluepillHandler(context unsafe.Pointer) { // currently, all signals are masked and the signal // must have been delivered directly to this thread. timeout := syscall.Timespec{} - sig, _, errno := syscall.RawSyscall6( + sig, _, errno := syscall.RawSyscall6( // escapes: no. syscall.SYS_RT_SIGTIMEDWAIT, uintptr(unsafe.Pointer(&bounceSignalMask)), 0, // siginfo. @@ -125,7 +128,7 @@ func bluepillHandler(context unsafe.Pointer) { // MMIO exit we receive EFAULT from the run ioctl. We // always inject an NMI here since we may be in kernel // mode and have interrupts disabled. - if _, _, errno := syscall.RawSyscall( + if _, _, errno := syscall.RawSyscall( // escapes: no. syscall.SYS_IOCTL, uintptr(c.fd), _KVM_NMI, 0); errno != 0 { diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index f1afc74dc..6c54712d1 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -52,16 +52,19 @@ type machine struct { // available is notified when vCPUs are available. available sync.Cond - // vCPUs are the machine vCPUs. + // vCPUsByTID are the machine vCPUs. // // These are populated dynamically. - vCPUs map[uint64]*vCPU + vCPUsByTID map[uint64]*vCPU // vCPUsByID are the machine vCPUs, can be indexed by the vCPU's ID. - vCPUsByID map[int]*vCPU + vCPUsByID []*vCPU // maxVCPUs is the maximum number of vCPUs supported by the machine. maxVCPUs int + + // nextID is the next vCPU ID. + nextID uint32 } const ( @@ -137,9 +140,8 @@ type dieState struct { // // Precondition: mu must be held. func (m *machine) newVCPU() *vCPU { - id := len(m.vCPUs) - // Create the vCPU. + id := int(atomic.AddUint32(&m.nextID, 1) - 1) fd, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CREATE_VCPU, uintptr(id)) if errno != 0 { panic(fmt.Sprintf("error creating new vCPU: %v", errno)) @@ -176,11 +178,7 @@ func (m *machine) newVCPU() *vCPU { // newMachine returns a new VM context. func newMachine(vm int) (*machine, error) { // Create the machine. - m := &machine{ - fd: vm, - vCPUs: make(map[uint64]*vCPU), - vCPUsByID: make(map[int]*vCPU), - } + m := &machine{fd: vm} m.available.L = &m.mu m.kernel.Init(ring0.KernelOpts{ PageTables: pagetables.New(newAllocator()), @@ -194,6 +192,10 @@ func newMachine(vm int) (*machine, error) { } log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs) + // Create the vCPUs map/slices. + m.vCPUsByTID = make(map[uint64]*vCPU) + m.vCPUsByID = make([]*vCPU, m.maxVCPUs) + // Apply the physical mappings. Note that these mappings may point to // guest physical addresses that are not actually available. These // physical pages are mapped on demand, see kernel_unsafe.go. @@ -274,6 +276,8 @@ func newMachine(vm int) (*machine, error) { // not available. This attempts to be efficient for calls in the hot path. // // This panics on error. +// +//go:nosplit func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalRegion, flags uint32) { for end := physical + length; physical < end; { _, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions) @@ -304,7 +308,11 @@ func (m *machine) Destroy() { runtime.SetFinalizer(m, nil) // Destroy vCPUs. - for _, c := range m.vCPUs { + for _, c := range m.vCPUsByID { + if c == nil { + continue + } + // Ensure the vCPU is not still running in guest mode. This is // possible iff teardown has been done by other threads, and // somehow a single thread has not executed any system calls. @@ -337,7 +345,7 @@ func (m *machine) Get() *vCPU { tid := procid.Current() // Check for an exact match. - if c := m.vCPUs[tid]; c != nil { + if c := m.vCPUsByTID[tid]; c != nil { c.lock() m.mu.RUnlock() return c @@ -356,7 +364,7 @@ func (m *machine) Get() *vCPU { tid = procid.Current() // Recheck for an exact match. - if c := m.vCPUs[tid]; c != nil { + if c := m.vCPUsByTID[tid]; c != nil { c.lock() m.mu.Unlock() return c @@ -364,10 +372,10 @@ func (m *machine) Get() *vCPU { for { // Scan for an available vCPU. - for origTID, c := range m.vCPUs { + for origTID, c := range m.vCPUsByTID { if atomic.CompareAndSwapUint32(&c.state, vCPUReady, vCPUUser) { - delete(m.vCPUs, origTID) - m.vCPUs[tid] = c + delete(m.vCPUsByTID, origTID) + m.vCPUsByTID[tid] = c m.mu.Unlock() c.loadSegments(tid) return c @@ -375,17 +383,17 @@ func (m *machine) Get() *vCPU { } // Create a new vCPU (maybe). - if len(m.vCPUs) < m.maxVCPUs { + if int(m.nextID) < m.maxVCPUs { c := m.newVCPU() c.lock() - m.vCPUs[tid] = c + m.vCPUsByTID[tid] = c m.mu.Unlock() c.loadSegments(tid) return c } // Scan for something not in user mode. - for origTID, c := range m.vCPUs { + for origTID, c := range m.vCPUsByTID { if !atomic.CompareAndSwapUint32(&c.state, vCPUGuest, vCPUGuest|vCPUWaiter) { continue } @@ -403,8 +411,8 @@ func (m *machine) Get() *vCPU { } // Steal the vCPU. - delete(m.vCPUs, origTID) - m.vCPUs[tid] = c + delete(m.vCPUsByTID, origTID) + m.vCPUsByTID[tid] = c m.mu.Unlock() c.loadSegments(tid) return c @@ -431,7 +439,7 @@ func (m *machine) Put(c *vCPU) { // newDirtySet returns a new dirty set. func (m *machine) newDirtySet() *dirtySet { return &dirtySet{ - vCPUs: make([]uint64, (m.maxVCPUs+63)/64, (m.maxVCPUs+63)/64), + vCPUMasks: make([]uint64, (m.maxVCPUs+63)/64, (m.maxVCPUs+63)/64), } } diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index 923ce3909..acc823ba6 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -51,9 +51,10 @@ func (m *machine) initArchState() error { recover() debug.SetPanicOnFault(old) }() - m.retryInGuest(func() { - ring0.SetCPUIDFaulting(true) - }) + c := m.Get() + defer m.Put(c) + bluepill(c) + ring0.SetCPUIDFaulting(true) return nil } @@ -89,8 +90,8 @@ func (m *machine) dropPageTables(pt *pagetables.PageTables) { defer m.mu.Unlock() // Clear from all PCIDs. - for _, c := range m.vCPUs { - if c.PCIDs != nil { + for _, c := range m.vCPUsByID { + if c != nil && c.PCIDs != nil { c.PCIDs.Drop(pt) } } @@ -335,29 +336,6 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) } } -// retryInGuest runs the given function in guest mode. -// -// If the function does not complete in guest mode (due to execution of a -// system call due to a GC stall, for example), then it will be retried. The -// given function must be idempotent as a result of the retry mechanism. -func (m *machine) retryInGuest(fn func()) { - c := m.Get() - defer m.Put(c) - for { - c.ClearErrorCode() // See below. - bluepill(c) // Force guest mode. - fn() // Execute the given function. - _, user := c.ErrorCode() - if user { - // If user is set, then we haven't bailed back to host - // mode via a kernel exception or system call. We - // consider the full function to have executed in guest - // mode and we can return. - break - } - } -} - // On x86 platform, the flags for "setMemoryRegion" can always be set as 0. // There is no need to return read-only physicalRegions. func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go index 7156c245f..290f035dd 100644 --- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go @@ -154,7 +154,7 @@ func (c *vCPU) setUserRegisters(uregs *userRegs) error { // //go:nosplit func (c *vCPU) getUserRegisters(uregs *userRegs) syscall.Errno { - if _, _, errno := syscall.RawSyscall( + if _, _, errno := syscall.RawSyscall( // escapes: no. syscall.SYS_IOCTL, uintptr(c.fd), _KVM_GET_REGS, diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index e42505542..f3bf973de 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -60,6 +60,12 @@ func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { if !vr.accessType.Write && vr.accessType.Read { rdonlyRegions = append(rdonlyRegions, vr.region) } + + // TODO(gvisor.dev/issue/2686): PROT_NONE should be specially treated. + // Workaround: treated as rdonly temporarily. + if !vr.accessType.Write && !vr.accessType.Read && !vr.accessType.Execute { + rdonlyRegions = append(rdonlyRegions, vr.region) + } }) for _, r := range rdonlyRegions { @@ -100,7 +106,7 @@ func (m *machine) dropPageTables(pt *pagetables.PageTables) { defer m.mu.Unlock() // Clear from all PCIDs. - for _, c := range m.vCPUs { + for _, c := range m.vCPUsByID { if c.PCIDs != nil { c.PCIDs.Drop(pt) } diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go index f04be2ab5..9f86f6a7a 100644 --- a/pkg/sentry/platform/kvm/machine_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.15 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. @@ -115,7 +115,7 @@ func (a *atomicAddressSpace) get() *addressSpace { // //go:nosplit func (c *vCPU) notify() { - _, _, errno := syscall.RawSyscall6( + _, _, errno := syscall.RawSyscall6( // escapes: no. syscall.SYS_FUTEX, uintptr(unsafe.Pointer(&c.state)), linux.FUTEX_WAKE|linux.FUTEX_PRIVATE_FLAG, diff --git a/pkg/sentry/platform/ptrace/subprocess_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_unsafe.go index 2ae6b9f9d..0bee995e4 100644 --- a/pkg/sentry/platform/ptrace/subprocess_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.15 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go index 900c0bba7..021693791 100644 --- a/pkg/sentry/platform/ring0/kernel.go +++ b/pkg/sentry/platform/ring0/kernel.go @@ -31,23 +31,39 @@ type defaultHooks struct{} // KernelSyscall implements Hooks.KernelSyscall. // +// +checkescape:all +// //go:nosplit -func (defaultHooks) KernelSyscall() { Halt() } +func (defaultHooks) KernelSyscall() { + Halt() +} // KernelException implements Hooks.KernelException. // +// +checkescape:all +// //go:nosplit -func (defaultHooks) KernelException(Vector) { Halt() } +func (defaultHooks) KernelException(Vector) { + Halt() +} // kernelSyscall is a trampoline. // +// +checkescape:hard,stack +// //go:nosplit -func kernelSyscall(c *CPU) { c.hooks.KernelSyscall() } +func kernelSyscall(c *CPU) { + c.hooks.KernelSyscall() +} // kernelException is a trampoline. // +// +checkescape:hard,stack +// //go:nosplit -func kernelException(c *CPU, vector Vector) { c.hooks.KernelException(vector) } +func kernelException(c *CPU, vector Vector) { + c.hooks.KernelException(vector) +} // Init initializes a new CPU. // diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go index 0feff8778..d37981dbf 100644 --- a/pkg/sentry/platform/ring0/kernel_amd64.go +++ b/pkg/sentry/platform/ring0/kernel_amd64.go @@ -178,6 +178,8 @@ func IsCanonical(addr uint64) bool { // // Precondition: the Rip, Rsp, Fs and Gs registers must be canonical. // +// +checkescape:all +// //go:nosplit func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { userCR3 := switchOpts.PageTables.CR3(!switchOpts.Flush, switchOpts.UserPCID) @@ -192,9 +194,9 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { // Perform the switch. swapgs() // GS will be swapped on return. - WriteFS(uintptr(regs.Fs_base)) // Set application FS. - WriteGS(uintptr(regs.Gs_base)) // Set application GS. - LoadFloatingPoint(switchOpts.FloatingPointState) // Copy in floating point. + WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS. + WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS. + LoadFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy in floating point. jumpToKernel() // Switch to upper half. writeCR3(uintptr(userCR3)) // Change to user address space. if switchOpts.FullRestore { @@ -204,8 +206,8 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { } writeCR3(uintptr(kernelCR3)) // Return to kernel address space. jumpToUser() // Return to lower half. - SaveFloatingPoint(switchOpts.FloatingPointState) // Copy out floating point. - WriteFS(uintptr(c.registers.Fs_base)) // Restore kernel FS. + SaveFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy out floating point. + WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS. return } diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go index 444a83913..a6345010d 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -38,6 +38,12 @@ func SaveVRegs(*byte) // LoadVRegs loads V0-V31 registers. func LoadVRegs(*byte) +// GetTLS returns the value of TPIDR_EL0 register. +func GetTLS() (value uint64) + +// SetTLS writes the TPIDR_EL0 value. +func SetTLS(value uint64) + // Init sets function pointers based on architectural features. // // This must be called prior to using ring0. diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s index 0e6a6235b..b63e14b41 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.s +++ b/pkg/sentry/platform/ring0/lib_arm64.s @@ -15,6 +15,16 @@ #include "funcdata.h" #include "textflag.h" +TEXT ·GetTLS(SB),NOSPLIT,$0-8 + MRS TPIDR_EL0, R1 + MOVD R1, ret+0(FP) + RET + +TEXT ·SetTLS(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R1 + MSR R1, TPIDR_EL0 + RET + TEXT ·CPACREL1(SB),NOSPLIT,$0-8 WORD $0xd5381041 // MRS CPACR_EL1, R1 MOVD R1, ret+0(FP) diff --git a/pkg/sentry/platform/ring0/pagetables/allocator.go b/pkg/sentry/platform/ring0/pagetables/allocator.go index 23fd5c352..8d75b7599 100644 --- a/pkg/sentry/platform/ring0/pagetables/allocator.go +++ b/pkg/sentry/platform/ring0/pagetables/allocator.go @@ -53,9 +53,14 @@ type RuntimeAllocator struct { // NewRuntimeAllocator returns an allocator that uses runtime allocation. func NewRuntimeAllocator() *RuntimeAllocator { - return &RuntimeAllocator{ - used: make(map[*PTEs]struct{}), - } + r := new(RuntimeAllocator) + r.Init() + return r +} + +// Init initializes a RuntimeAllocator. +func (r *RuntimeAllocator) Init() { + r.used = make(map[*PTEs]struct{}) } // Recycle returns freed pages to the pool. diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go index 87e88e97d..7f18ac296 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go @@ -86,6 +86,8 @@ func (*mapVisitor) requiresSplit() bool { return true } // // Precondition: addr & length must be page-aligned, their sum must not overflow. // +// +checkescape:hard,stack +// //go:nosplit func (p *PageTables) Map(addr usermem.Addr, length uintptr, opts MapOpts, physical uintptr) bool { if !opts.AccessType.Any() { @@ -128,6 +130,8 @@ func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) { // // Precondition: addr & length must be page-aligned. // +// +checkescape:hard,stack +// //go:nosplit func (p *PageTables) Unmap(addr usermem.Addr, length uintptr) bool { w := unmapWalker{ @@ -162,6 +166,8 @@ func (v *emptyVisitor) visit(start uintptr, pte *PTE, align uintptr) { // // Precondition: addr & length must be page-aligned. // +// +checkescape:hard,stack +// //go:nosplit func (p *PageTables) IsEmpty(addr usermem.Addr, length uintptr) bool { w := emptyWalker{ @@ -197,6 +203,8 @@ func (*lookupVisitor) requiresSplit() bool { return false } // Lookup returns the physical address for the given virtual address. // +// +checkescape:hard,stack +// //go:nosplit func (p *PageTables) Lookup(addr usermem.Addr) (physical uintptr, opts MapOpts) { mask := uintptr(usermem.PageSize - 1) diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index b49433326..c11e82c10 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -555,7 +555,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b if uint64(src.NumBytes()) != srcs.NumBytes() { return 0, nil } - if srcs.IsEmpty() { + if srcs.IsEmpty() && len(controlBuf) == 0 { return 0, nil } diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 789bb94c8..66015e2bc 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -64,6 +64,8 @@ const enableLogging = false var emptyFilter = stack.IPHeaderFilter{ Dst: "\x00\x00\x00\x00", DstMask: "\x00\x00\x00\x00", + Src: "\x00\x00\x00\x00", + SrcMask: "\x00\x00\x00\x00", } // nflog logs messages related to the writing and reading of iptables. @@ -142,31 +144,27 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen } func findTable(stk *stack.Stack, tablename linux.TableName) (stack.Table, error) { - ipt := stk.IPTables() - table, ok := ipt.Tables[tablename.String()] + table, ok := stk.IPTables().GetTable(tablename.String()) if !ok { return stack.Table{}, fmt.Errorf("couldn't find table %q", tablename) } return table, nil } -// FillDefaultIPTables sets stack's IPTables to the default tables and -// populates them with metadata. -func FillDefaultIPTables(stk *stack.Stack) { - ipt := stack.DefaultTables() - - // In order to fill in the metadata, we have to translate ipt from its - // netstack format to Linux's giant-binary-blob format. - for name, table := range ipt.Tables { - _, metadata, err := convertNetstackToBinary(name, table) - if err != nil { - panic(fmt.Errorf("Unable to set default IP tables: %v", err)) +// FillIPTablesMetadata populates stack's IPTables with metadata. +func FillIPTablesMetadata(stk *stack.Stack) { + stk.IPTables().ModifyTables(func(tables map[string]stack.Table) { + // In order to fill in the metadata, we have to translate ipt from its + // netstack format to Linux's giant-binary-blob format. + for name, table := range tables { + _, metadata, err := convertNetstackToBinary(name, table) + if err != nil { + panic(fmt.Errorf("Unable to set default IP tables: %v", err)) + } + table.SetMetadata(metadata) + tables[name] = table } - table.SetMetadata(metadata) - ipt.Tables[name] = table - } - - stk.SetIPTables(ipt) + }) } // convertNetstackToBinary converts the iptables as stored in netstack to the @@ -214,11 +212,16 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI } copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst) copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask) + copy(entry.IPTEntry.IP.Src[:], rule.Filter.Src) + copy(entry.IPTEntry.IP.SrcMask[:], rule.Filter.SrcMask) copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface) copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) if rule.Filter.DstInvert { entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP } + if rule.Filter.SrcInvert { + entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_SRCIP + } if rule.Filter.OutputInterfaceInvert { entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT } @@ -566,15 +569,13 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - ipt := stk.IPTables() table.SetMetadata(metadata{ HookEntry: replace.HookEntry, Underflow: replace.Underflow, NumEntries: replace.NumEntries, Size: replace.Size, }) - ipt.Tables[replace.Name.String()] = table - stk.SetIPTables(ipt) + stk.IPTables().ReplaceTable(replace.Name.String(), table) return nil } @@ -737,6 +738,9 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize { return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask)) } + if len(iptip.Src) != header.IPv4AddressSize || len(iptip.SrcMask) != header.IPv4AddressSize { + return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask)) + } n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) if n == -1 { @@ -755,6 +759,9 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { Dst: tcpip.Address(iptip.Dst[:]), DstMask: tcpip.Address(iptip.DstMask[:]), DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0, + Src: tcpip.Address(iptip.Src[:]), + SrcMask: tcpip.Address(iptip.SrcMask[:]), + SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0, OutputInterface: ifname, OutputInterfaceMask: ifnameMask, OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0, @@ -765,15 +772,13 @@ func containsUnsupportedFields(iptip linux.IPTIP) bool { // The following features are supported: // - Protocol // - Dst and DstMask + // - Src and SrcMask // - The inverse destination IP check flag // - OutputInterface, OutputInterfaceMask and its inverse. - var emptyInetAddr = linux.InetAddr{} var emptyInterface = [linux.IFNAMSIZ]byte{} // Disable any supported inverse flags. - inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_VIA_OUT) - return iptip.Src != emptyInetAddr || - iptip.SrcMask != emptyInetAddr || - iptip.InputInterface != emptyInterface || + inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT) + return iptip.InputInterface != emptyInterface || iptip.InputInterfaceMask != emptyInterface || iptip.Flags != 0 || iptip.InverseFlags&^inverseMask != 0 diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go index 3863293c7..1b4e0ad79 100644 --- a/pkg/sentry/socket/netfilter/owner_matcher.go +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -111,7 +111,7 @@ func (*OwnerMatcher) Name() string { } // Match implements Matcher.Match. -func (om *OwnerMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { +func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { // Support only for OUTPUT chain. // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also. if hook != stack.Output { diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 57a1e1c12..4f98ee2d5 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -96,7 +96,7 @@ func (*TCPMatcher) Name() string { } // Match implements Matcher.Match. -func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { +func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { netHeader := header.IPv4(pkt.NetworkHeader) if netHeader.TransportProtocol() != header.TCPProtocolNumber { @@ -111,36 +111,10 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa return false, false } - // Now we need the transport header. However, this may not have been set - // yet. - // TODO(gvisor.dev/issue/170): Parsing the transport header should - // ultimately be moved into the stack.Check codepath as matchers are - // added. - var tcpHeader header.TCP - if pkt.TransportHeader != nil { - tcpHeader = header.TCP(pkt.TransportHeader) - } else { - var length int - if hook == stack.Prerouting { - // The network header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - // There's no valid TCP header here, so we hotdrop the - // packet. - return false, true - } - h := header.IPv4(hdr) - pkt.NetworkHeader = hdr - length = int(h.HeaderLength()) - } - // The TCP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(length + header.TCPMinimumSize) - if !ok { - // There's no valid TCP header here, so we hotdrop the - // packet. - return false, true - } - tcpHeader = header.TCP(hdr[length:]) + tcpHeader := header.TCP(pkt.TransportHeader) + if len(tcpHeader) < header.TCPMinimumSize { + // There's no valid TCP header here, so we drop the packet immediately. + return false, true } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index cfa9e621d..3f20fc891 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -93,7 +93,7 @@ func (*UDPMatcher) Name() string { } // Match implements Matcher.Match. -func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { +func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { netHeader := header.IPv4(pkt.NetworkHeader) // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved @@ -110,36 +110,10 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa return false, false } - // Now we need the transport header. However, this may not have been set - // yet. - // TODO(gvisor.dev/issue/170): Parsing the transport header should - // ultimately be moved into the stack.Check codepath as matchers are - // added. - var udpHeader header.UDP - if pkt.TransportHeader != nil { - udpHeader = header.UDP(pkt.TransportHeader) - } else { - var length int - if hook == stack.Prerouting { - // The network header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - // There's no valid UDP header here, so we hotdrop the - // packet. - return false, true - } - h := header.IPv4(hdr) - pkt.NetworkHeader = hdr - length = int(h.HeaderLength()) - } - // The UDP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(length + header.UDPMinimumSize) - if !ok { - // There's no valid UDP header here, so we hotdrop the - // packet. - return false, true - } - udpHeader = header.UDP(hdr[length:]) + udpHeader := header.UDP(pkt.TransportHeader) + if len(udpHeader) < header.UDPMinimumSize { + // There's no valid UDP header here, so we drop the packet immediately. + return false, true } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index 333e0042e..8f0f5466e 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -50,5 +50,6 @@ go_library( "//pkg/tcpip/transport/udp", "//pkg/usermem", "//pkg/waiter", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 60df51dae..e1e0c5931 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -33,6 +33,7 @@ import ( "syscall" "time" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/binary" @@ -719,6 +720,14 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool defer s.EventUnregister(&e) if err := s.Endpoint.Connect(addr); err != tcpip.ErrConnectStarted && err != tcpip.ErrAlreadyConnecting { + if (s.family == unix.AF_INET || s.family == unix.AF_INET6) && s.skType == linux.SOCK_STREAM { + // TCP unlike UDP returns EADDRNOTAVAIL when it can't + // find an available local ephemeral port. + if err == tcpip.ErrNoPortAvailable { + return syserr.ErrAddressNotAvailable + } + } + return syserr.TranslateNetstackError(err) } diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index f5fa18136..9b44c2b89 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -362,14 +362,13 @@ func (s *Stack) RouteTable() []inet.Route { } // IPTables returns the stack's iptables. -func (s *Stack) IPTables() (stack.IPTables, error) { +func (s *Stack) IPTables() (*stack.IPTables, error) { return s.Stack.IPTables(), nil } -// FillDefaultIPTables sets the stack's iptables to the default tables, which -// allow and do not modify all traffic. -func (s *Stack) FillDefaultIPTables() { - netfilter.FillDefaultIPTables(s.Stack) +// FillIPTablesMetadata populates stack's IPTables with metadata. +func (s *Stack) FillIPTablesMetadata() { + netfilter.FillIPTablesMetadata(s.Stack) } // Resume implements inet.Stack.Resume. diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index ce5b94ee7..09c6d3b27 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -252,7 +252,7 @@ func (e *connectionedEndpoint) Close() { // BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error { if ce.Type() != e.stype { - return syserr.ErrConnectionRefused + return syserr.ErrWrongProtocolForSocket } // Check if ce is e to avoid a deadlock. diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 5b29e9d7f..c4c9db81b 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -417,7 +417,18 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool defer ep.Release() // Connect the server endpoint. - return s.ep.Connect(t, ep) + err = s.ep.Connect(t, ep) + + if err == syserr.ErrWrongProtocolForSocket { + // Linux for abstract sockets returns ErrConnectionRefused + // instead of ErrWrongProtocolForSocket. + path, _ := extractPath(sockaddr) + if len(path) > 0 && path[0] == 0 { + err = syserr.ErrConnectionRefused + } + } + + return err } // Write implements fs.FileOperations.Write. diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index 39f2b79ec..77c78889d 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -80,6 +80,12 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB } } + if total > 0 { + // On Linux, inotify behavior is not very consistent with splice(2). We try + // our best to emulate Linux for very basic calls to splice, where for some + // reason, events are generated for output files, but not input files. + outFile.Dirent.InotifyEvent(linux.IN_MODIFY, 0) + } return total, err } diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go index 2de5e3422..c24946160 100644 --- a/pkg/sentry/syscalls/linux/sys_xattr.go +++ b/pkg/sentry/syscalls/linux/sys_xattr.go @@ -207,7 +207,11 @@ func setXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr usermem.Addr, si return syserror.EOPNOTSUPP } - return d.Inode.SetXattr(t, d, name, value, flags) + if err := d.Inode.SetXattr(t, d, name, value, flags); err != nil { + return err + } + d.InotifyEvent(linux.IN_ATTRIB, 0) + return nil } func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) { @@ -418,7 +422,11 @@ func removeXattr(t *kernel.Task, d *fs.Dirent, nameAddr usermem.Addr) error { return syserror.EOPNOTSUPP } - return d.Inode.RemoveXattr(t, d, name) + if err := d.Inode.RemoveXattr(t, d, name); err != nil { + return err + } + d.InotifyEvent(linux.IN_ATTRIB, 0) + return nil } // LINT.ThenChange(vfs2/xattr.go) diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index f882ef840..c0d005247 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -12,9 +12,11 @@ go_library( "filesystem.go", "fscontext.go", "getdents.go", + "inotify.go", "ioctl.go", "memfd.go", "mmap.go", + "mount.go", "path.go", "pipe.go", "poll.go", @@ -22,6 +24,7 @@ go_library( "setstat.go", "signal.go", "socket.go", + "splice.go", "stat.go", "stat_amd64.go", "stat_arm64.go", diff --git a/pkg/sentry/syscalls/linux/vfs2/inotify.go b/pkg/sentry/syscalls/linux/vfs2/inotify.go new file mode 100644 index 000000000..7d50b6a16 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/inotify.go @@ -0,0 +1,134 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +const allFlags = linux.IN_NONBLOCK | linux.IN_CLOEXEC + +// InotifyInit1 implements the inotify_init1() syscalls. +func InotifyInit1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + flags := args[0].Int() + if flags&^allFlags != 0 { + return 0, nil, syserror.EINVAL + } + + ino, err := vfs.NewInotifyFD(t, t.Kernel().VFS(), uint32(flags)) + if err != nil { + return 0, nil, err + } + defer ino.DecRef() + + fd, err := t.NewFDFromVFS2(0, ino, kernel.FDFlags{ + CloseOnExec: flags&linux.IN_CLOEXEC != 0, + }) + + if err != nil { + return 0, nil, err + } + + return uintptr(fd), nil, nil +} + +// InotifyInit implements the inotify_init() syscalls. +func InotifyInit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + args[0].Value = 0 + return InotifyInit1(t, args) +} + +// fdToInotify resolves an fd to an inotify object. If successful, the file will +// have an extra ref and the caller is responsible for releasing the ref. +func fdToInotify(t *kernel.Task, fd int32) (*vfs.Inotify, *vfs.FileDescription, error) { + f := t.GetFileVFS2(fd) + if f == nil { + // Invalid fd. + return nil, nil, syserror.EBADF + } + + ino, ok := f.Impl().(*vfs.Inotify) + if !ok { + // Not an inotify fd. + f.DecRef() + return nil, nil, syserror.EINVAL + } + + return ino, f, nil +} + +// InotifyAddWatch implements the inotify_add_watch() syscall. +func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + addr := args[1].Pointer() + mask := args[2].Uint() + + // "EINVAL: The given event mask contains no valid events." + // -- inotify_add_watch(2) + if validBits := mask & linux.ALL_INOTIFY_BITS; validBits == 0 { + return 0, nil, syserror.EINVAL + } + + // "IN_DONT_FOLLOW: Don't dereference pathname if it is a symbolic link." + // -- inotify(7) + follow := followFinalSymlink + if mask&linux.IN_DONT_FOLLOW == 0 { + follow = nofollowFinalSymlink + } + + ino, f, err := fdToInotify(t, fd) + if err != nil { + return 0, nil, err + } + defer f.DecRef() + + path, err := copyInPath(t, addr) + if err != nil { + return 0, nil, err + } + if mask&linux.IN_ONLYDIR != 0 { + path.Dir = true + } + tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, follow) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + d, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{}) + if err != nil { + return 0, nil, err + } + defer d.DecRef() + + fd = ino.AddWatch(d.Dentry(), mask) + return uintptr(fd), nil, err +} + +// InotifyRmWatch implements the inotify_rm_watch() syscall. +func InotifyRmWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + wd := args[1].Int() + + ino, f, err := fdToInotify(t, fd) + if err != nil { + return 0, nil, err + } + defer f.DecRef() + return 0, nil, ino.RmWatch(wd) +} diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go new file mode 100644 index 000000000..adeaa39cc --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/mount.go @@ -0,0 +1,145 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Mount implements Linux syscall mount(2). +func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + sourceAddr := args[0].Pointer() + targetAddr := args[1].Pointer() + typeAddr := args[2].Pointer() + flags := args[3].Uint64() + dataAddr := args[4].Pointer() + + // For null-terminated strings related to mount(2), Linux copies in at most + // a page worth of data. See fs/namespace.c:copy_mount_string(). + fsType, err := t.CopyInString(typeAddr, usermem.PageSize) + if err != nil { + return 0, nil, err + } + source, err := t.CopyInString(sourceAddr, usermem.PageSize) + if err != nil { + return 0, nil, err + } + + targetPath, err := copyInPath(t, targetAddr) + if err != nil { + return 0, nil, err + } + + data := "" + if dataAddr != 0 { + // In Linux, a full page is always copied in regardless of null + // character placement, and the address is passed to each file system. + // Most file systems always treat this data as a string, though, and so + // do all of the ones we implement. + data, err = t.CopyInString(dataAddr, usermem.PageSize) + if err != nil { + return 0, nil, err + } + } + + // Ignore magic value that was required before Linux 2.4. + if flags&linux.MS_MGC_MSK == linux.MS_MGC_VAL { + flags = flags &^ linux.MS_MGC_MSK + } + + // Must have CAP_SYS_ADMIN in the current mount namespace's associated user + // namespace. + creds := t.Credentials() + if !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.MountNamespaceVFS2().Owner) { + return 0, nil, syserror.EPERM + } + + const unsupportedOps = linux.MS_REMOUNT | linux.MS_BIND | + linux.MS_SHARED | linux.MS_PRIVATE | linux.MS_SLAVE | + linux.MS_UNBINDABLE | linux.MS_MOVE + + // Silently allow MS_NOSUID, since we don't implement set-id bits + // anyway. + const unsupportedFlags = linux.MS_NODEV | + linux.MS_NODIRATIME | linux.MS_STRICTATIME + + // Linux just allows passing any flags to mount(2) - it won't fail when + // unknown or unsupported flags are passed. Since we don't implement + // everything, we fail explicitly on flags that are unimplemented. + if flags&(unsupportedOps|unsupportedFlags) != 0 { + return 0, nil, syserror.EINVAL + } + + var opts vfs.MountOptions + if flags&linux.MS_NOATIME == linux.MS_NOATIME { + opts.Flags.NoATime = true + } + if flags&linux.MS_NOEXEC == linux.MS_NOEXEC { + opts.Flags.NoExec = true + } + if flags&linux.MS_RDONLY == linux.MS_RDONLY { + opts.ReadOnly = true + } + opts.GetFilesystemOptions.Data = data + + target, err := getTaskPathOperation(t, linux.AT_FDCWD, targetPath, disallowEmptyPath, nofollowFinalSymlink) + if err != nil { + return 0, nil, err + } + defer target.Release() + + return 0, nil, t.Kernel().VFS().MountAt(t, creds, source, &target.pop, fsType, &opts) +} + +// Umount2 implements Linux syscall umount2(2). +func Umount2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + flags := args[1].Int() + + // Must have CAP_SYS_ADMIN in the mount namespace's associated user + // namespace. + // + // Currently, this is always the init task's user namespace. + creds := t.Credentials() + if !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.MountNamespaceVFS2().Owner) { + return 0, nil, syserror.EPERM + } + + const unsupported = linux.MNT_FORCE | linux.MNT_EXPIRE + if flags&unsupported != 0 { + return 0, nil, syserror.EINVAL + } + + path, err := copyInPath(t, addr) + if err != nil { + return 0, nil, err + } + tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, nofollowFinalSymlink) + if err != nil { + return 0, nil, err + } + defer tpop.Release() + + opts := vfs.UmountOptions{ + Flags: uint32(flags), + } + + return 0, nil, t.Kernel().VFS().UmountAt(t, creds, &tpop.pop, &opts) +} diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go index 3a7ef24f5..7f9debd4a 100644 --- a/pkg/sentry/syscalls/linux/vfs2/read_write.go +++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go @@ -93,11 +93,17 @@ func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { n, err := file.Read(t, dst, opts) if err != syserror.ErrWouldBlock { + if n > 0 { + file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + } return n, err } allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { + if n > 0 { + file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + } return n, err } @@ -128,6 +134,9 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt } file.EventUnregister(&w) + if total > 0 { + file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + } return total, err } @@ -248,11 +257,17 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { n, err := file.PRead(t, dst, offset, opts) if err != syserror.ErrWouldBlock { + if n > 0 { + file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + } return n, err } allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { + if n > 0 { + file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + } return n, err } @@ -283,6 +298,9 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of } file.EventUnregister(&w) + if total > 0 { + file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + } return total, err } @@ -345,11 +363,17 @@ func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { n, err := file.Write(t, src, opts) if err != syserror.ErrWouldBlock { + if n > 0 { + file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + } return n, err } allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { + if n > 0 { + file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + } return n, err } @@ -380,6 +404,9 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op } file.EventUnregister(&w) + if total > 0 { + file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + } return total, err } @@ -500,11 +527,17 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { n, err := file.PWrite(t, src, offset, opts) if err != syserror.ErrWouldBlock { + if n > 0 { + file.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + } return n, err } allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { + if n > 0 { + file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + } return n, err } @@ -535,6 +568,9 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o } file.EventUnregister(&w) + if total > 0 { + file.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + } return total, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go new file mode 100644 index 000000000..945a364a7 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -0,0 +1,291 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/waiter" +) + +// Splice implements Linux syscall splice(2). +func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + inFD := args[0].Int() + inOffsetPtr := args[1].Pointer() + outFD := args[2].Int() + outOffsetPtr := args[3].Pointer() + count := int64(args[4].SizeT()) + flags := args[5].Int() + + if count == 0 { + return 0, nil, nil + } + if count > int64(kernel.MAX_RW_COUNT) { + count = int64(kernel.MAX_RW_COUNT) + } + + // Check for invalid flags. + if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 { + return 0, nil, syserror.EINVAL + } + + // Get file descriptions. + inFile := t.GetFileVFS2(inFD) + if inFile == nil { + return 0, nil, syserror.EBADF + } + defer inFile.DecRef() + outFile := t.GetFileVFS2(outFD) + if outFile == nil { + return 0, nil, syserror.EBADF + } + defer outFile.DecRef() + + // Check that both files support the required directionality. + if !inFile.IsReadable() || !outFile.IsWritable() { + return 0, nil, syserror.EBADF + } + + // The operation is non-blocking if anything is non-blocking. + // + // N.B. This is a rather simplistic heuristic that avoids some + // poor edge case behavior since the exact semantics here are + // underspecified and vary between versions of Linux itself. + nonBlock := ((inFile.StatusFlags()|outFile.StatusFlags())&linux.O_NONBLOCK != 0) || (flags&linux.SPLICE_F_NONBLOCK != 0) + + // At least one file description must represent a pipe. + inPipeFD, inIsPipe := inFile.Impl().(*pipe.VFSPipeFD) + outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD) + if !inIsPipe && !outIsPipe { + return 0, nil, syserror.EINVAL + } + + // Copy in offsets. + inOffset := int64(-1) + if inOffsetPtr != 0 { + if inIsPipe { + return 0, nil, syserror.ESPIPE + } + if inFile.Options().DenyPRead { + return 0, nil, syserror.EINVAL + } + if _, err := t.CopyIn(inOffsetPtr, &inOffset); err != nil { + return 0, nil, err + } + if inOffset < 0 { + return 0, nil, syserror.EINVAL + } + } + outOffset := int64(-1) + if outOffsetPtr != 0 { + if outIsPipe { + return 0, nil, syserror.ESPIPE + } + if outFile.Options().DenyPWrite { + return 0, nil, syserror.EINVAL + } + if _, err := t.CopyIn(outOffsetPtr, &outOffset); err != nil { + return 0, nil, err + } + if outOffset < 0 { + return 0, nil, syserror.EINVAL + } + } + + // Move data. + var ( + n int64 + err error + inCh chan struct{} + outCh chan struct{} + ) + for { + // If both input and output are pipes, delegate to the pipe + // implementation. Otherwise, exactly one end is a pipe, which we + // ensure is consistently ordered after the non-pipe FD's locks by + // passing the pipe FD as usermem.IO to the non-pipe end. + switch { + case inIsPipe && outIsPipe: + n, err = pipe.Splice(t, outPipeFD, inPipeFD, count) + case inIsPipe: + if outOffset != -1 { + n, err = outFile.PWrite(t, inPipeFD.IOSequence(count), outOffset, vfs.WriteOptions{}) + outOffset += n + } else { + n, err = outFile.Write(t, inPipeFD.IOSequence(count), vfs.WriteOptions{}) + } + case outIsPipe: + if inOffset != -1 { + n, err = inFile.PRead(t, outPipeFD.IOSequence(count), inOffset, vfs.ReadOptions{}) + inOffset += n + } else { + n, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{}) + } + } + if n != 0 || err != syserror.ErrWouldBlock || nonBlock { + break + } + + // Note that the blocking behavior here is a bit different than the + // normal pattern. Because we need to have both data to read and data + // to write simultaneously, we actually explicitly block on both of + // these cases in turn before returning to the splice operation. + if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { + if inCh == nil { + inCh = make(chan struct{}, 1) + inW, _ := waiter.NewChannelEntry(inCh) + inFile.EventRegister(&inW, eventMaskRead) + defer inFile.EventUnregister(&inW) + continue // Need to refresh readiness. + } + if err = t.Block(inCh); err != nil { + break + } + } + if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { + if outCh == nil { + outCh = make(chan struct{}, 1) + outW, _ := waiter.NewChannelEntry(outCh) + outFile.EventRegister(&outW, eventMaskWrite) + defer outFile.EventUnregister(&outW) + continue // Need to refresh readiness. + } + if err = t.Block(outCh); err != nil { + break + } + } + } + + // Copy updated offsets out. + if inOffsetPtr != 0 { + if _, err := t.CopyOut(inOffsetPtr, &inOffset); err != nil { + return 0, nil, err + } + } + if outOffsetPtr != 0 { + if _, err := t.CopyOut(outOffsetPtr, &outOffset); err != nil { + return 0, nil, err + } + } + + if n == 0 { + return 0, nil, err + } + + // On Linux, inotify behavior is not very consistent with splice(2). We try + // our best to emulate Linux for very basic calls to splice, where for some + // reason, events are generated for output files, but not input files. + outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + return uintptr(n), nil, nil +} + +// Tee implements Linux syscall tee(2). +func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + inFD := args[0].Int() + outFD := args[1].Int() + count := int64(args[2].SizeT()) + flags := args[3].Int() + + if count == 0 { + return 0, nil, nil + } + if count > int64(kernel.MAX_RW_COUNT) { + count = int64(kernel.MAX_RW_COUNT) + } + + // Check for invalid flags. + if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 { + return 0, nil, syserror.EINVAL + } + + // Get file descriptions. + inFile := t.GetFileVFS2(inFD) + if inFile == nil { + return 0, nil, syserror.EBADF + } + defer inFile.DecRef() + outFile := t.GetFileVFS2(outFD) + if outFile == nil { + return 0, nil, syserror.EBADF + } + defer outFile.DecRef() + + // Check that both files support the required directionality. + if !inFile.IsReadable() || !outFile.IsWritable() { + return 0, nil, syserror.EBADF + } + + // The operation is non-blocking if anything is non-blocking. + // + // N.B. This is a rather simplistic heuristic that avoids some + // poor edge case behavior since the exact semantics here are + // underspecified and vary between versions of Linux itself. + nonBlock := ((inFile.StatusFlags()|outFile.StatusFlags())&linux.O_NONBLOCK != 0) || (flags&linux.SPLICE_F_NONBLOCK != 0) + + // Both file descriptions must represent pipes. + inPipeFD, inIsPipe := inFile.Impl().(*pipe.VFSPipeFD) + outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD) + if !inIsPipe || !outIsPipe { + return 0, nil, syserror.EINVAL + } + + // Copy data. + var ( + inCh chan struct{} + outCh chan struct{} + ) + for { + n, err := pipe.Tee(t, outPipeFD, inPipeFD, count) + if n != 0 { + return uintptr(n), nil, nil + } + if err != syserror.ErrWouldBlock || nonBlock { + return 0, nil, err + } + + // Note that the blocking behavior here is a bit different than the + // normal pattern. Because we need to have both data to read and data + // to write simultaneously, we actually explicitly block on both of + // these cases in turn before returning to the tee operation. + if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { + if inCh == nil { + inCh = make(chan struct{}, 1) + inW, _ := waiter.NewChannelEntry(inCh) + inFile.EventRegister(&inW, eventMaskRead) + defer inFile.EventUnregister(&inW) + continue // Need to refresh readiness. + } + if err := t.Block(inCh); err != nil { + return 0, nil, err + } + } + if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { + if outCh == nil { + outCh = make(chan struct{}, 1) + outW, _ := waiter.NewChannelEntry(outCh) + outFile.EventRegister(&outW, eventMaskWrite) + defer outFile.EventUnregister(&outW) + continue // Need to refresh readiness. + } + if err := t.Block(outCh); err != nil { + return 0, nil, err + } + } + } +} diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go index a332d01bd..7b6e7571a 100644 --- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go @@ -90,8 +90,8 @@ func Override() { s.Table[138] = syscalls.Supported("fstatfs", Fstatfs) s.Table[161] = syscalls.Supported("chroot", Chroot) s.Table[162] = syscalls.Supported("sync", Sync) - delete(s.Table, 165) // mount - delete(s.Table, 166) // umount2 + s.Table[165] = syscalls.Supported("mount", Mount) + s.Table[166] = syscalls.Supported("umount2", Umount2) delete(s.Table, 187) // readahead s.Table[188] = syscalls.Supported("setxattr", Setxattr) s.Table[189] = syscalls.Supported("lsetxattr", Lsetxattr) @@ -116,9 +116,9 @@ func Override() { s.Table[232] = syscalls.Supported("epoll_wait", EpollWait) s.Table[233] = syscalls.Supported("epoll_ctl", EpollCtl) s.Table[235] = syscalls.Supported("utimes", Utimes) - delete(s.Table, 253) // inotify_init - delete(s.Table, 254) // inotify_add_watch - delete(s.Table, 255) // inotify_rm_watch + s.Table[253] = syscalls.PartiallySupported("inotify_init", InotifyInit, "inotify events are only available inside the sandbox.", nil) + s.Table[254] = syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil) + s.Table[255] = syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil) s.Table[257] = syscalls.Supported("openat", Openat) s.Table[258] = syscalls.Supported("mkdirat", Mkdirat) s.Table[259] = syscalls.Supported("mknodat", Mknodat) @@ -134,8 +134,8 @@ func Override() { s.Table[269] = syscalls.Supported("faccessat", Faccessat) s.Table[270] = syscalls.Supported("pselect", Pselect) s.Table[271] = syscalls.Supported("ppoll", Ppoll) - delete(s.Table, 275) // splice - delete(s.Table, 276) // tee + s.Table[275] = syscalls.Supported("splice", Splice) + s.Table[276] = syscalls.Supported("tee", Tee) s.Table[277] = syscalls.Supported("sync_file_range", SyncFileRange) s.Table[280] = syscalls.Supported("utimensat", Utimensat) s.Table[281] = syscalls.Supported("epoll_pwait", EpollPwait) @@ -151,7 +151,7 @@ func Override() { s.Table[291] = syscalls.Supported("epoll_create1", EpollCreate1) s.Table[292] = syscalls.Supported("dup3", Dup3) s.Table[293] = syscalls.Supported("pipe2", Pipe2) - delete(s.Table, 294) // inotify_init1 + s.Table[294] = syscalls.PartiallySupported("inotify_init1", InotifyInit1, "inotify events are only available inside the sandbox.", nil) s.Table[295] = syscalls.Supported("preadv", Preadv) s.Table[296] = syscalls.Supported("pwritev", Pwritev) s.Table[299] = syscalls.Supported("recvmmsg", RecvMMsg) diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 94d69c1cc..774cc66cc 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -15,6 +15,18 @@ go_template_instance( }, ) +go_template_instance( + name = "event_list", + out = "event_list.go", + package = "vfs", + prefix = "event", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*Event", + "Linker": "*Event", + }, +) + go_library( name = "vfs", srcs = [ @@ -25,11 +37,13 @@ go_library( "device.go", "epoll.go", "epoll_interest_list.go", + "event_list.go", "file_description.go", "file_description_impl_util.go", "filesystem.go", "filesystem_impl_util.go", "filesystem_type.go", + "inotify.go", "mount.go", "mount_unsafe.go", "options.go", @@ -57,6 +71,7 @@ go_library( "//pkg/sentry/limits", "//pkg/sentry/memmap", "//pkg/sentry/socket/unix/transport", + "//pkg/sentry/uniqueid", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/vfs/README.md b/pkg/sentry/vfs/README.md index 9aa133bcb..66f3105bd 100644 --- a/pkg/sentry/vfs/README.md +++ b/pkg/sentry/vfs/README.md @@ -39,8 +39,8 @@ Mount references are held by: - Mount: Each referenced Mount holds a reference on its parent, which is the mount containing its mount point. -- VirtualFilesystem: A reference is held on each Mount that has not been - umounted. +- VirtualFilesystem: A reference is held on each Mount that has been connected + to a mount point, but not yet umounted. MountNamespace and FileDescription references are held by users of VFS. The expectation is that each `kernel.Task` holds a reference on its corresponding diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go index caf770fd5..b7c6b60b8 100644 --- a/pkg/sentry/vfs/anonfs.go +++ b/pkg/sentry/vfs/anonfs.go @@ -297,3 +297,15 @@ func (d *anonDentry) TryIncRef() bool { func (d *anonDentry) DecRef() { // no-op } + +// InotifyWithParent implements DentryImpl.InotifyWithParent. +// +// TODO(gvisor.dev/issue/1479): Implement inotify. +func (d *anonDentry) InotifyWithParent(events uint32, cookie uint32, et EventType) {} + +// Watches implements DentryImpl.Watches. +// +// TODO(gvisor.dev/issue/1479): Implement inotify. +func (d *anonDentry) Watches() *Watches { + return nil +} diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go index 8624dbd5d..24af13eb1 100644 --- a/pkg/sentry/vfs/dentry.go +++ b/pkg/sentry/vfs/dentry.go @@ -103,6 +103,22 @@ type DentryImpl interface { // DecRef decrements the Dentry's reference count. DecRef() + + // InotifyWithParent notifies all watches on the targets represented by this + // dentry and its parent. The parent's watches are notified first, followed + // by this dentry's. + // + // InotifyWithParent automatically adds the IN_ISDIR flag for dentries + // representing directories. + // + // Note that the events may not actually propagate up to the user, depending + // on the event masks. + InotifyWithParent(events uint32, cookie uint32, et EventType) + + // Watches returns the set of inotify watches for the file corresponding to + // the Dentry. Dentries that are hard links to the same underlying file + // share the same watches. + Watches() *Watches } // IncRef increments d's reference count. @@ -133,6 +149,17 @@ func (d *Dentry) isMounted() bool { return atomic.LoadUint32(&d.mounts) != 0 } +// InotifyWithParent notifies all watches on the inodes for this dentry and +// its parent of events. +func (d *Dentry) InotifyWithParent(events uint32, cookie uint32, et EventType) { + d.impl.InotifyWithParent(events, cookie, et) +} + +// Watches returns the set of inotify watches associated with d. +func (d *Dentry) Watches() *Watches { + return d.impl.Watches() +} + // The following functions are exported so that filesystem implementations can // use them. The vfs package, and users of VFS, should not call these // functions. diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index cfabd936c..bb294563d 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -210,6 +210,11 @@ func (fd *FileDescription) VirtualDentry() VirtualDentry { return fd.vd } +// Options returns the options passed to fd.Init(). +func (fd *FileDescription) Options() FileDescriptionOptions { + return fd.opts +} + // StatusFlags returns file description status flags, as for fcntl(F_GETFL). func (fd *FileDescription) StatusFlags() uint32 { return atomic.LoadUint32(&fd.statusFlags) diff --git a/pkg/sentry/vfs/genericfstree/genericfstree.go b/pkg/sentry/vfs/genericfstree/genericfstree.go index 286510195..8882fa84a 100644 --- a/pkg/sentry/vfs/genericfstree/genericfstree.go +++ b/pkg/sentry/vfs/genericfstree/genericfstree.go @@ -43,7 +43,7 @@ type Dentry struct { // IsAncestorDentry returns true if d is an ancestor of d2; that is, d is // either d2's parent or an ancestor of d2's parent. func IsAncestorDentry(d, d2 *Dentry) bool { - for { + for d2 != nil { // Stop at root, where d2.parent == nil. if d2.parent == d { return true } @@ -52,6 +52,7 @@ func IsAncestorDentry(d, d2 *Dentry) bool { } d2 = d2.parent } + return false } // ParentOrSelf returns d.parent. If d.parent is nil, ParentOrSelf returns d. diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go new file mode 100644 index 000000000..05a3051a4 --- /dev/null +++ b/pkg/sentry/vfs/inotify.go @@ -0,0 +1,697 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs + +import ( + "bytes" + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/uniqueid" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +// inotifyEventBaseSize is the base size of linux's struct inotify_event. This +// must be a power 2 for rounding below. +const inotifyEventBaseSize = 16 + +// EventType defines different kinds of inotfiy events. +// +// The way events are labelled appears somewhat arbitrary, but they must match +// Linux so that IN_EXCL_UNLINK behaves as it does in Linux. +type EventType uint8 + +// PathEvent and InodeEvent correspond to FSNOTIFY_EVENT_PATH and +// FSNOTIFY_EVENT_INODE in Linux. +const ( + PathEvent EventType = iota + InodeEvent EventType = iota +) + +// Inotify represents an inotify instance created by inotify_init(2) or +// inotify_init1(2). Inotify implements FileDescriptionImpl. +// +// Lock ordering: +// Inotify.mu -> Watches.mu -> Inotify.evMu +// +// +stateify savable +type Inotify struct { + vfsfd FileDescription + FileDescriptionDefaultImpl + DentryMetadataFileDescriptionImpl + + // Unique identifier for this inotify instance. We don't just reuse the + // inotify fd because fds can be duped. These should not be exposed to the + // user, since we may aggressively reuse an id on S/R. + id uint64 + + // queue is used to notify interested parties when the inotify instance + // becomes readable or writable. + queue waiter.Queue `state:"nosave"` + + // evMu *only* protects the events list. We need a separate lock while + // queuing events: using mu may violate lock ordering, since at that point + // the calling goroutine may already hold Watches.mu. + evMu sync.Mutex `state:"nosave"` + + // A list of pending events for this inotify instance. Protected by evMu. + events eventList + + // A scratch buffer, used to serialize inotify events. Allocate this + // ahead of time for the sake of performance. Protected by evMu. + scratch []byte + + // mu protects the fields below. + mu sync.Mutex `state:"nosave"` + + // nextWatchMinusOne is used to allocate watch descriptors on this Inotify + // instance. Note that Linux starts numbering watch descriptors from 1. + nextWatchMinusOne int32 + + // Map from watch descriptors to watch objects. + watches map[int32]*Watch +} + +var _ FileDescriptionImpl = (*Inotify)(nil) + +// NewInotifyFD constructs a new Inotify instance. +func NewInotifyFD(ctx context.Context, vfsObj *VirtualFilesystem, flags uint32) (*FileDescription, error) { + // O_CLOEXEC affects file descriptors, so it must be handled outside of vfs. + flags &^= linux.O_CLOEXEC + if flags&^linux.O_NONBLOCK != 0 { + return nil, syserror.EINVAL + } + + id := uniqueid.GlobalFromContext(ctx) + vd := vfsObj.NewAnonVirtualDentry(fmt.Sprintf("[inotifyfd:%d]", id)) + defer vd.DecRef() + fd := &Inotify{ + id: id, + scratch: make([]byte, inotifyEventBaseSize), + watches: make(map[int32]*Watch), + } + if err := fd.vfsfd.Init(fd, flags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{ + UseDentryMetadata: true, + DenyPRead: true, + DenyPWrite: true, + }); err != nil { + return nil, err + } + return &fd.vfsfd, nil +} + +// Release implements FileDescriptionImpl.Release. Release removes all +// watches and frees all resources for an inotify instance. +func (i *Inotify) Release() { + // We need to hold i.mu to avoid a race with concurrent calls to + // Inotify.handleDeletion from Watches. There's no risk of Watches + // accessing this Inotify after the destructor ends, because we remove all + // references to it below. + i.mu.Lock() + defer i.mu.Unlock() + for _, w := range i.watches { + // Remove references to the watch from the watches set on the target. We + // don't need to worry about the references from i.watches, since this + // file description is about to be destroyed. + w.set.Remove(i.id) + } +} + +// EventRegister implements waiter.Waitable. +func (i *Inotify) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + i.queue.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable. +func (i *Inotify) EventUnregister(e *waiter.Entry) { + i.queue.EventUnregister(e) +} + +// Readiness implements waiter.Waitable.Readiness. +// +// Readiness indicates whether there are pending events for an inotify instance. +func (i *Inotify) Readiness(mask waiter.EventMask) waiter.EventMask { + ready := waiter.EventMask(0) + + i.evMu.Lock() + defer i.evMu.Unlock() + + if !i.events.Empty() { + ready |= waiter.EventIn + } + + return mask & ready +} + +// PRead implements FileDescriptionImpl. +func (*Inotify) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// PWrite implements FileDescriptionImpl. +func (*Inotify) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Write implements FileDescriptionImpl.Write. +func (*Inotify) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) { + return 0, syserror.EBADF +} + +// Read implements FileDescriptionImpl.Read. +func (i *Inotify) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) { + if dst.NumBytes() < inotifyEventBaseSize { + return 0, syserror.EINVAL + } + + i.evMu.Lock() + defer i.evMu.Unlock() + + if i.events.Empty() { + // Nothing to read yet, tell caller to block. + return 0, syserror.ErrWouldBlock + } + + var writeLen int64 + for it := i.events.Front(); it != nil; { + // Advance `it` before the element is removed from the list, or else + // it.Next() will always be nil. + event := it + it = it.Next() + + // Does the buffer have enough remaining space to hold the event we're + // about to write out? + if dst.NumBytes() < int64(event.sizeOf()) { + if writeLen > 0 { + // Buffer wasn't big enough for all pending events, but we did + // write some events out. + return writeLen, nil + } + return 0, syserror.EINVAL + } + + // Linux always dequeues an available event as long as there's enough + // buffer space to copy it out, even if the copy below fails. Emulate + // this behaviour. + i.events.Remove(event) + + // Buffer has enough space, copy event to the read buffer. + n, err := event.CopyTo(ctx, i.scratch, dst) + if err != nil { + return 0, err + } + + writeLen += n + dst = dst.DropFirst64(n) + } + return writeLen, nil +} + +// Ioctl implements fs.FileOperations.Ioctl. +func (i *Inotify) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + switch args[1].Int() { + case linux.FIONREAD: + i.evMu.Lock() + defer i.evMu.Unlock() + var n uint32 + for e := i.events.Front(); e != nil; e = e.Next() { + n += uint32(e.sizeOf()) + } + var buf [4]byte + usermem.ByteOrder.PutUint32(buf[:], n) + _, err := uio.CopyOut(ctx, args[2].Pointer(), buf[:], usermem.IOOpts{}) + return 0, err + + default: + return 0, syserror.ENOTTY + } +} + +func (i *Inotify) queueEvent(ev *Event) { + i.evMu.Lock() + + // Check if we should coalesce the event we're about to queue with the last + // one currently in the queue. Events are coalesced if they are identical. + if last := i.events.Back(); last != nil { + if ev.equals(last) { + // "Coalesce" the two events by simply not queuing the new one. We + // don't need to raise a waiter.EventIn notification because no new + // data is available for reading. + i.evMu.Unlock() + return + } + } + + i.events.PushBack(ev) + + // Release mutex before notifying waiters because we don't control what they + // can do. + i.evMu.Unlock() + + i.queue.Notify(waiter.EventIn) +} + +// newWatchLocked creates and adds a new watch to target. +// +// Precondition: i.mu must be locked. +func (i *Inotify) newWatchLocked(target *Dentry, mask uint32) *Watch { + targetWatches := target.Watches() + w := &Watch{ + owner: i, + wd: i.nextWatchIDLocked(), + set: targetWatches, + mask: mask, + } + + // Hold the watch in this inotify instance as well as the watch set on the + // target. + i.watches[w.wd] = w + targetWatches.Add(w) + return w +} + +// newWatchIDLocked allocates and returns a new watch descriptor. +// +// Precondition: i.mu must be locked. +func (i *Inotify) nextWatchIDLocked() int32 { + i.nextWatchMinusOne++ + return i.nextWatchMinusOne +} + +// handleDeletion handles the deletion of the target of watch w. It removes w +// from i.watches and a watch removal event is generated. +func (i *Inotify) handleDeletion(w *Watch) { + i.mu.Lock() + _, found := i.watches[w.wd] + delete(i.watches, w.wd) + i.mu.Unlock() + + if found { + i.queueEvent(newEvent(w.wd, "", linux.IN_IGNORED, 0)) + } +} + +// AddWatch constructs a new inotify watch and adds it to the target. It +// returns the watch descriptor returned by inotify_add_watch(2). +func (i *Inotify) AddWatch(target *Dentry, mask uint32) int32 { + // Note: Locking this inotify instance protects the result returned by + // Lookup() below. With the lock held, we know for sure the lookup result + // won't become stale because it's impossible for *this* instance to + // add/remove watches on target. + i.mu.Lock() + defer i.mu.Unlock() + + // Does the target already have a watch from this inotify instance? + if existing := target.Watches().Lookup(i.id); existing != nil { + newmask := mask + if mask&linux.IN_MASK_ADD != 0 { + // "Add (OR) events to watch mask for this pathname if it already + // exists (instead of replacing mask)." -- inotify(7) + newmask |= atomic.LoadUint32(&existing.mask) + } + atomic.StoreUint32(&existing.mask, newmask) + return existing.wd + } + + // No existing watch, create a new watch. + w := i.newWatchLocked(target, mask) + return w.wd +} + +// RmWatch looks up an inotify watch for the given 'wd' and configures the +// target to stop sending events to this inotify instance. +func (i *Inotify) RmWatch(wd int32) error { + i.mu.Lock() + + // Find the watch we were asked to removed. + w, ok := i.watches[wd] + if !ok { + i.mu.Unlock() + return syserror.EINVAL + } + + // Remove the watch from this instance. + delete(i.watches, wd) + + // Remove the watch from the watch target. + w.set.Remove(w.OwnerID()) + i.mu.Unlock() + + // Generate the event for the removal. + i.queueEvent(newEvent(wd, "", linux.IN_IGNORED, 0)) + + return nil +} + +// Watches is the collection of all inotify watches on a single file. +// +// +stateify savable +type Watches struct { + // mu protects the fields below. + mu sync.RWMutex `state:"nosave"` + + // ws is the map of active watches in this collection, keyed by the inotify + // instance id of the owner. + ws map[uint64]*Watch +} + +// Lookup returns the watch owned by an inotify instance with the given id. +// Returns nil if no such watch exists. +// +// Precondition: the inotify instance with the given id must be locked to +// prevent the returned watch from being concurrently modified or replaced in +// Inotify.watches. +func (w *Watches) Lookup(id uint64) *Watch { + w.mu.Lock() + defer w.mu.Unlock() + return w.ws[id] +} + +// Add adds watch into this set of watches. +// +// Precondition: the inotify instance with the given id must be locked. +func (w *Watches) Add(watch *Watch) { + w.mu.Lock() + defer w.mu.Unlock() + + owner := watch.OwnerID() + // Sanity check, we should never have two watches for one owner on the + // same target. + if _, exists := w.ws[owner]; exists { + panic(fmt.Sprintf("Watch collision with ID %+v", owner)) + } + if w.ws == nil { + w.ws = make(map[uint64]*Watch) + } + w.ws[owner] = watch +} + +// Remove removes a watch with the given id from this set of watches and +// releases it. The caller is responsible for generating any watch removal +// event, as appropriate. The provided id must match an existing watch in this +// collection. +// +// Precondition: the inotify instance with the given id must be locked. +func (w *Watches) Remove(id uint64) { + w.mu.Lock() + defer w.mu.Unlock() + + if w.ws == nil { + // This watch set is being destroyed. The thread executing the + // destructor is already in the process of deleting all our watches. We + // got here with no references on the target because we raced with the + // destructor notifying all the watch owners of destruction. See the + // comment in Watches.HandleDeletion for why this race exists. + return + } + + if _, ok := w.ws[id]; !ok { + // While there's technically no problem with silently ignoring a missing + // watch, this is almost certainly a bug. + panic(fmt.Sprintf("Attempt to remove a watch, but no watch found with provided id %+v.", id)) + } + delete(w.ws, id) +} + +// Notify queues a new event with all watches in this set. +func (w *Watches) Notify(name string, events, cookie uint32, et EventType) { + w.NotifyWithExclusions(name, events, cookie, et, false) +} + +// NotifyWithExclusions queues a new event with watches in this set. Watches +// with IN_EXCL_UNLINK are skipped if the event is coming from a child that +// has been unlinked. +func (w *Watches) NotifyWithExclusions(name string, events, cookie uint32, et EventType, unlinked bool) { + // N.B. We don't defer the unlocks because Notify is in the hot path of + // all IO operations, and the defer costs too much for small IO + // operations. + w.mu.RLock() + for _, watch := range w.ws { + if unlinked && watch.ExcludeUnlinkedChildren() && et == PathEvent { + continue + } + watch.Notify(name, events, cookie) + } + w.mu.RUnlock() +} + +// HandleDeletion is called when the watch target is destroyed to emit +// the appropriate events. +func (w *Watches) HandleDeletion() { + w.Notify("", linux.IN_DELETE_SELF, 0, InodeEvent) + + // TODO(gvisor.dev/issue/1479): This doesn't work because maps are not copied + // by value. Ideally, we wouldn't have this circular locking so we can just + // notify of IN_DELETE_SELF in the same loop below. + // + // We can't hold w.mu while calling watch.handleDeletion to preserve lock + // ordering w.r.t to the owner inotify instances. Instead, atomically move + // the watches map into a local variable so we can iterate over it safely. + // + // Because of this however, it is possible for the watches' owners to reach + // this inode while the inode has no refs. This is still safe because the + // owners can only reach the inode until this function finishes calling + // watch.handleDeletion below and the inode is guaranteed to exist in the + // meantime. But we still have to be very careful not to rely on inode state + // that may have been already destroyed. + var ws map[uint64]*Watch + w.mu.Lock() + ws = w.ws + w.ws = nil + w.mu.Unlock() + + for _, watch := range ws { + // TODO(gvisor.dev/issue/1479): consider refactoring this. + watch.handleDeletion() + } +} + +// Watch represent a particular inotify watch created by inotify_add_watch. +// +// +stateify savable +type Watch struct { + // Inotify instance which owns this watch. + owner *Inotify + + // Descriptor for this watch. This is unique across an inotify instance. + wd int32 + + // set is the watch set containing this watch. It belongs to the target file + // of this watch. + set *Watches + + // Events being monitored via this watch. Must be accessed with atomic + // memory operations. + mask uint32 +} + +// OwnerID returns the id of the inotify instance that owns this watch. +func (w *Watch) OwnerID() uint64 { + return w.owner.id +} + +// ExcludeUnlinkedChildren indicates whether the watched object should continue +// to be notified of events of its children after they have been unlinked, e.g. +// for an open file descriptor. +// +// TODO(gvisor.dev/issue/1479): Implement IN_EXCL_UNLINK. +// We can do this by keeping track of the set of unlinked children in Watches +// to skip notification. +func (w *Watch) ExcludeUnlinkedChildren() bool { + return atomic.LoadUint32(&w.mask)&linux.IN_EXCL_UNLINK != 0 +} + +// Notify queues a new event on this watch. +func (w *Watch) Notify(name string, events uint32, cookie uint32) { + mask := atomic.LoadUint32(&w.mask) + if mask&events == 0 { + // We weren't watching for this event. + return + } + + // Event mask should include bits matched from the watch plus all control + // event bits. + unmaskableBits := ^uint32(0) &^ linux.IN_ALL_EVENTS + effectiveMask := unmaskableBits | mask + matchedEvents := effectiveMask & events + w.owner.queueEvent(newEvent(w.wd, name, matchedEvents, cookie)) +} + +// handleDeletion handles the deletion of w's target. +func (w *Watch) handleDeletion() { + w.owner.handleDeletion(w) +} + +// Event represents a struct inotify_event from linux. +// +// +stateify savable +type Event struct { + eventEntry + + wd int32 + mask uint32 + cookie uint32 + + // len is computed based on the name field is set automatically by + // Event.setName. It should be 0 when no name is set; otherwise it is the + // length of the name slice. + len uint32 + + // The name field has special padding requirements and should only be set by + // calling Event.setName. + name []byte +} + +func newEvent(wd int32, name string, events, cookie uint32) *Event { + e := &Event{ + wd: wd, + mask: events, + cookie: cookie, + } + if name != "" { + e.setName(name) + } + return e +} + +// paddedBytes converts a go string to a null-terminated c-string, padded with +// null bytes to a total size of 'l'. 'l' must be large enough for all the bytes +// in the 's' plus at least one null byte. +func paddedBytes(s string, l uint32) []byte { + if l < uint32(len(s)+1) { + panic("Converting string to byte array results in truncation, this can lead to buffer-overflow due to the missing null-byte!") + } + b := make([]byte, l) + copy(b, s) + + // b was zero-value initialized during make(), so the rest of the slice is + // already filled with null bytes. + + return b +} + +// setName sets the optional name for this event. +func (e *Event) setName(name string) { + // We need to pad the name such that the entire event length ends up a + // multiple of inotifyEventBaseSize. + unpaddedLen := len(name) + 1 + // Round up to nearest multiple of inotifyEventBaseSize. + e.len = uint32((unpaddedLen + inotifyEventBaseSize - 1) & ^(inotifyEventBaseSize - 1)) + // Make sure we haven't overflowed and wrapped around when rounding. + if unpaddedLen > int(e.len) { + panic("Overflow when rounding inotify event size, the 'name' field was too big.") + } + e.name = paddedBytes(name, e.len) +} + +func (e *Event) sizeOf() int { + s := inotifyEventBaseSize + int(e.len) + if s < inotifyEventBaseSize { + panic("overflow") + } + return s +} + +// CopyTo serializes this event to dst. buf is used as a scratch buffer to +// construct the output. We use a buffer allocated ahead of time for +// performance. buf must be at least inotifyEventBaseSize bytes. +func (e *Event) CopyTo(ctx context.Context, buf []byte, dst usermem.IOSequence) (int64, error) { + usermem.ByteOrder.PutUint32(buf[0:], uint32(e.wd)) + usermem.ByteOrder.PutUint32(buf[4:], e.mask) + usermem.ByteOrder.PutUint32(buf[8:], e.cookie) + usermem.ByteOrder.PutUint32(buf[12:], e.len) + + writeLen := 0 + + n, err := dst.CopyOut(ctx, buf) + if err != nil { + return 0, err + } + writeLen += n + dst = dst.DropFirst(n) + + if e.len > 0 { + n, err = dst.CopyOut(ctx, e.name) + if err != nil { + return 0, err + } + writeLen += n + } + + // Santiy check. + if writeLen != e.sizeOf() { + panic(fmt.Sprintf("Serialized unexpected amount of data for an event, expected %d, wrote %d.", e.sizeOf(), writeLen)) + } + + return int64(writeLen), nil +} + +func (e *Event) equals(other *Event) bool { + return e.wd == other.wd && + e.mask == other.mask && + e.cookie == other.cookie && + e.len == other.len && + bytes.Equal(e.name, other.name) +} + +// InotifyEventFromStatMask generates the appropriate events for an operation +// that set the stats specified in mask. +func InotifyEventFromStatMask(mask uint32) uint32 { + var ev uint32 + if mask&(linux.STATX_UID|linux.STATX_GID|linux.STATX_MODE) != 0 { + ev |= linux.IN_ATTRIB + } + if mask&linux.STATX_SIZE != 0 { + ev |= linux.IN_MODIFY + } + + if (mask & (linux.STATX_ATIME | linux.STATX_MTIME)) == (linux.STATX_ATIME | linux.STATX_MTIME) { + // Both times indicates a utime(s) call. + ev |= linux.IN_ATTRIB + } else if mask&linux.STATX_ATIME != 0 { + ev |= linux.IN_ACCESS + } else if mask&linux.STATX_MTIME != 0 { + mask |= linux.IN_MODIFY + } + return ev +} + +// InotifyRemoveChild sends the appriopriate notifications to the watch sets of +// the child being removed and its parent. +func InotifyRemoveChild(self, parent *Watches, name string) { + self.Notify("", linux.IN_ATTRIB, 0, InodeEvent) + parent.Notify(name, linux.IN_DELETE, 0, InodeEvent) + // TODO(gvisor.dev/issue/1479): implement IN_EXCL_UNLINK. +} + +// InotifyRename sends the appriopriate notifications to the watch sets of the +// file being renamed and its old/new parents. +func InotifyRename(ctx context.Context, renamed, oldParent, newParent *Watches, oldName, newName string, isDir bool) { + var dirEv uint32 + if isDir { + dirEv = linux.IN_ISDIR + } + cookie := uniqueid.InotifyCookie(ctx) + oldParent.Notify(oldName, dirEv|linux.IN_MOVED_FROM, cookie, InodeEvent) + newParent.Notify(newName, dirEv|linux.IN_MOVED_TO, cookie, InodeEvent) + // Somewhat surprisingly, self move events do not have a cookie. + renamed.Notify("", linux.IN_MOVE_SELF, 0, InodeEvent) +} diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 02850b65c..32f901bd8 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -28,9 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) -// lastMountID is used to allocate mount ids. Must be accessed atomically. -var lastMountID uint64 - // A Mount is a replacement of a Dentry (Mount.key.point) from one Filesystem // (Mount.key.parent.fs) with a Dentry (Mount.root) from another Filesystem // (Mount.fs), which applies to path resolution in the context of a particular @@ -58,6 +55,10 @@ type Mount struct { // ID is the immutable mount ID. ID uint64 + // Flags contains settings as specified for mount(2), e.g. MS_NOEXEC, except + // for MS_RDONLY which is tracked in "writers". Immutable. + Flags MountFlags + // key is protected by VirtualFilesystem.mountMu and // VirtualFilesystem.mounts.seq, and may be nil. References are held on // key.parent and key.point if they are not nil. @@ -84,10 +85,6 @@ type Mount struct { // umounted is true. umounted is protected by VirtualFilesystem.mountMu. umounted bool - // flags contains settings as specified for mount(2), e.g. MS_NOEXEC, except - // for MS_RDONLY which is tracked in "writers". - flags MountFlags - // The lower 63 bits of writers is the number of calls to // Mount.CheckBeginWrite() that have not yet been paired with a call to // Mount.EndWrite(). The MSB of writers is set if MS_RDONLY is in effect. @@ -97,11 +94,11 @@ type Mount struct { func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *MountNamespace, opts *MountOptions) *Mount { mnt := &Mount{ - ID: atomic.AddUint64(&lastMountID, 1), + ID: atomic.AddUint64(&vfs.lastMountID, 1), + Flags: opts.Flags, vfs: vfs, fs: fs, root: root, - flags: opts.Flags, ns: mntns, refs: 1, } @@ -111,8 +108,17 @@ func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *Mount return mnt } -// A MountNamespace is a collection of Mounts. -// +// Options returns a copy of the MountOptions currently applicable to mnt. +func (mnt *Mount) Options() MountOptions { + mnt.vfs.mountMu.Lock() + defer mnt.vfs.mountMu.Unlock() + return MountOptions{ + Flags: mnt.Flags, + ReadOnly: mnt.readOnly(), + } +} + +// A MountNamespace is a collection of Mounts.// // MountNamespaces are reference-counted. Unless otherwise specified, all // MountNamespace methods require that a reference is held. // @@ -120,6 +126,9 @@ func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *Mount // // +stateify savable type MountNamespace struct { + // Owner is the usernamespace that owns this mount namespace. + Owner *auth.UserNamespace + // root is the MountNamespace's root mount. root is immutable. root *Mount @@ -148,7 +157,7 @@ type MountNamespace struct { func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *GetFilesystemOptions) (*MountNamespace, error) { rft := vfs.getFilesystemType(fsTypeName) if rft == nil { - ctx.Warningf("Unknown filesystem: %s", fsTypeName) + ctx.Warningf("Unknown filesystem type: %s", fsTypeName) return nil, syserror.ENODEV } fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, *opts) @@ -156,6 +165,7 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth return nil, err } mntns := &MountNamespace{ + Owner: creds.UserNamespace, refs: 1, mountpoints: make(map[*Dentry]uint32), } @@ -175,26 +185,34 @@ func (vfs *VirtualFilesystem) NewDisconnectedMount(fs *Filesystem, root *Dentry, return newMount(vfs, fs, root, nil /* mntns */, opts), nil } -// MountAt creates and mounts a Filesystem configured by the given arguments. -func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error { +// MountDisconnected creates a Filesystem configured by the given arguments, +// then returns a Mount representing it. The new Mount is not associated with +// any MountNamespace and is not connected to any other Mounts. +func (vfs *VirtualFilesystem) MountDisconnected(ctx context.Context, creds *auth.Credentials, source string, fsTypeName string, opts *MountOptions) (*Mount, error) { rft := vfs.getFilesystemType(fsTypeName) if rft == nil { - return syserror.ENODEV + return nil, syserror.ENODEV } if !opts.InternalMount && !rft.opts.AllowUserMount { - return syserror.ENODEV + return nil, syserror.ENODEV } fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, opts.GetFilesystemOptions) if err != nil { - return err + return nil, err } + defer root.DecRef() + defer fs.DecRef() + return vfs.NewDisconnectedMount(fs, root, opts) +} +// ConnectMountAt connects mnt at the path represented by target. +// +// Preconditions: mnt must be disconnected. +func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Credentials, mnt *Mount, target *PathOperation) error { // We can't hold vfs.mountMu while calling FilesystemImpl methods due to // lock ordering. vd, err := vfs.GetDentryAt(ctx, creds, target, &GetDentryOptions{}) if err != nil { - root.DecRef() - fs.DecRef() return err } vfs.mountMu.Lock() @@ -204,8 +222,6 @@ func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentia vd.dentry.mu.Unlock() vfs.mountMu.Unlock() vd.DecRef() - root.DecRef() - fs.DecRef() return syserror.ENOENT } // vd might have been mounted over between vfs.GetDentryAt() and @@ -238,7 +254,6 @@ func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentia // point and the mount root are directories, or neither are, and returns // ENOTDIR if this is not the case. mntns := vd.mount.ns - mnt := newMount(vfs, fs, root, mntns, opts) vfs.mounts.seq.BeginWrite() vfs.connectLocked(mnt, vd, mntns) vfs.mounts.seq.EndWrite() @@ -247,6 +262,19 @@ func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentia return nil } +// MountAt creates and mounts a Filesystem configured by the given arguments. +func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error { + mnt, err := vfs.MountDisconnected(ctx, creds, source, fsTypeName, opts) + if err != nil { + return err + } + defer mnt.DecRef() + if err := vfs.ConnectMountAt(ctx, creds, mnt, target); err != nil { + return err + } + return nil +} + // UmountAt removes the Mount at the given path. func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *UmountOptions) error { if opts.Flags&^(linux.MNT_FORCE|linux.MNT_DETACH) != 0 { @@ -254,6 +282,9 @@ func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credenti } // MNT_FORCE is currently unimplemented except for the permission check. + // Force unmounting specifically requires CAP_SYS_ADMIN in the root user + // namespace, and not in the owner user namespace for the target mount. See + // fs/namespace.c:SYSCALL_DEFINE2(umount, ...) if opts.Flags&linux.MNT_FORCE != 0 && creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, creds.UserNamespace.Root()) { return syserror.EPERM } @@ -369,14 +400,22 @@ func (vfs *VirtualFilesystem) umountRecursiveLocked(mnt *Mount, opts *umountRecu // references held by vd. // // Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a -// writer critical section. d.mu must be locked. mnt.parent() == nil. +// writer critical section. d.mu must be locked. mnt.parent() == nil, i.e. mnt +// must not already be connected. func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns *MountNamespace) { + if checkInvariants { + if mnt.parent() != nil { + panic("VFS.connectLocked called on connected mount") + } + } + mnt.IncRef() // dropped by callers of umountRecursiveLocked mnt.storeKey(vd) if vd.mount.children == nil { vd.mount.children = make(map[*Mount]struct{}) } vd.mount.children[mnt] = struct{}{} atomic.AddUint32(&vd.dentry.mounts, 1) + mnt.ns = mntns mntns.mountpoints[vd.dentry]++ vfs.mounts.insertSeqed(mnt) vfsmpmounts, ok := vfs.mountpoints[vd.dentry] @@ -394,6 +433,11 @@ func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns // writer critical section. mnt.parent() != nil. func (vfs *VirtualFilesystem) disconnectLocked(mnt *Mount) VirtualDentry { vd := mnt.loadKey() + if checkInvariants { + if vd.mount != nil { + panic("VFS.disconnectLocked called on disconnected mount") + } + } mnt.storeKey(VirtualDentry{}) delete(vd.mount.children, mnt) atomic.AddUint32(&vd.dentry.mounts, math.MaxUint32) // -1 @@ -715,7 +759,10 @@ func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDi if mnt.readOnly() { opts = "ro" } - if mnt.flags.NoExec { + if mnt.Flags.NoATime { + opts = ",noatime" + } + if mnt.Flags.NoExec { opts += ",noexec" } @@ -800,11 +847,12 @@ func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRoo if mnt.readOnly() { opts = "ro" } - if mnt.flags.NoExec { + if mnt.Flags.NoATime { + opts = ",noatime" + } + if mnt.Flags.NoExec { opts += ",noexec" } - // TODO(gvisor.dev/issue/1193): Add "noatime" if MS_NOATIME is - // set. fmt.Fprintf(buf, "%s ", opts) // (7) Optional fields: zero or more fields of the form "tag[:value]". diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index bc7581698..70f850ca4 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.15 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index 53d364c5c..f223aeda8 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -75,6 +75,10 @@ type MknodOptions struct { type MountFlags struct { // NoExec is equivalent to MS_NOEXEC. NoExec bool + + // NoATime is equivalent to MS_NOATIME and indicates that the + // filesystem should not update access time in-place. + NoATime bool } // MountOptions contains options to VirtualFilesystem.MountAt(). diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 8d7f8f8af..9acca8bc7 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -82,6 +82,10 @@ type VirtualFilesystem struct { // mountpoints is analogous to Linux's mountpoint_hashtable. mountpoints map[*Dentry]map[*Mount]struct{} + // lastMountID is the last allocated mount ID. lastMountID is accessed + // using atomic memory operations. + lastMountID uint64 + // anonMount is a Mount, not included in mounts or mountpoints, // representing an anonFilesystem. anonMount is used to back // VirtualDentries returned by VirtualFilesystem.NewAnonVirtualDentry(). @@ -401,7 +405,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential vfs.putResolvingPath(rp) if opts.FileExec { - if fd.Mount().flags.NoExec { + if fd.Mount().Flags.NoExec { fd.DecRef() return nil, syserror.EACCES } @@ -418,6 +422,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential } } + fd.Dentry().InotifyWithParent(linux.IN_OPEN, 0, PathEvent) return fd, nil } if !rp.handleError(err) { diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index 101497ed6..e2894f9f5 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -77,7 +77,10 @@ var DefaultOpts = Opts{ // trigger it. const descheduleThreshold = 1 * time.Second -var stuckTasks = metric.MustCreateNewUint64Metric("/watchdog/stuck_tasks_detected", true /* sync */, "Cumulative count of stuck tasks detected") +var ( + stuckStartup = metric.MustCreateNewUint64Metric("/watchdog/stuck_startup_detected", true /* sync */, "Incremented once on startup watchdog timeout") + stuckTasks = metric.MustCreateNewUint64Metric("/watchdog/stuck_tasks_detected", true /* sync */, "Cumulative count of stuck tasks detected") +) // Amount of time to wait before dumping the stack to the log again when the same task(s) remains stuck. var stackDumpSameTaskPeriod = time.Minute @@ -220,6 +223,9 @@ func (w *Watchdog) waitForStart() { // We are fine. return } + + stuckStartup.Increment() + var buf bytes.Buffer buf.WriteString(fmt.Sprintf("Watchdog.Start() not called within %s", w.StartupTimeout)) w.doAction(w.StartupTimeoutAction, false, &buf) @@ -328,8 +334,8 @@ func (w *Watchdog) reportStuckWatchdog() { } // doAction will take the given action. If the action is LogWarning, the stack -// is not always dumpped to the log to prevent log flooding. "forceStack" -// guarantees that the stack will be dumped regarless. +// is not always dumped to the log to prevent log flooding. "forceStack" +// guarantees that the stack will be dumped regardless. func (w *Watchdog) doAction(action Action, forceStack bool, msg *bytes.Buffer) { switch action { case LogWarning: diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go index 65bfcf778..f68c12620 100644 --- a/pkg/sleep/sleep_unsafe.go +++ b/pkg/sleep/sleep_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.11 -// +build !go1.15 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/state/BUILD b/pkg/state/BUILD index 921af9d63..2b1350135 100644 --- a/pkg/state/BUILD +++ b/pkg/state/BUILD @@ -47,6 +47,7 @@ go_library( "state.go", "stats.go", ], + marshal = False, stateify = False, visibility = ["//:sandbox"], deps = [ diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index 0e35d7d17..d0d77e19c 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -39,6 +39,8 @@ go_library( "seqcount.go", "sync.go", ], + marshal = False, + stateify = False, ) go_test( diff --git a/pkg/sync/memmove_unsafe.go b/pkg/sync/memmove_unsafe.go index ad4a3a37e..1d7780695 100644 --- a/pkg/sync/memmove_unsafe.go +++ b/pkg/sync/memmove_unsafe.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file. // +build go1.12 -// +build !go1.15 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sync/mutex_unsafe.go b/pkg/sync/mutex_unsafe.go index 3dd15578b..dc034d561 100644 --- a/pkg/sync/mutex_unsafe.go +++ b/pkg/sync/mutex_unsafe.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file. // +build go1.13 -// +build !go1.15 +// +build !go1.16 // When updating the build constraint (above), check that syncMutex matches the // standard library sync.Mutex definition. diff --git a/pkg/sync/rwmutex_unsafe.go b/pkg/sync/rwmutex_unsafe.go index ea6cdc447..995c0346e 100644 --- a/pkg/sync/rwmutex_unsafe.go +++ b/pkg/sync/rwmutex_unsafe.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file. // +build go1.13 -// +build !go1.15 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/syncevent/waiter_unsafe.go b/pkg/syncevent/waiter_unsafe.go index 112e0e604..ad271e1a0 100644 --- a/pkg/syncevent/waiter_unsafe.go +++ b/pkg/syncevent/waiter_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.11 -// +build !go1.15 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index e57d45f2a..a984f1712 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -22,7 +22,6 @@ go_test( size = "small", srcs = ["gonet_test.go"], library = ":gonet", - tags = ["flaky"], deps = [ "//pkg/tcpip", "//pkg/tcpip/header", diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 6e0db2741..d82ed5205 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -335,6 +335,11 @@ func (c *TCPConn) Read(b []byte) (int, error) { deadline := c.readCancel() numRead := 0 + defer func() { + if numRead != 0 { + c.ep.ModerateRecvBuf(numRead) + } + }() for numRead != len(b) { if len(c.read) == 0 { var err error diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 76839eb92..62ac932bb 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -159,6 +159,11 @@ func (b IPv4) Flags() uint8 { return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13) } +// More returns whether the more fragments flag is set. +func (b IPv4) More() bool { + return b.Flags()&IPv4FlagMoreFragments != 0 +} + // TTL returns the "TTL" field of the ipv4 header. func (b IPv4) TTL() uint8 { return b[ttl] diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go index 2c4591409..3499d8399 100644 --- a/pkg/tcpip/header/ipv6_extension_headers.go +++ b/pkg/tcpip/header/ipv6_extension_headers.go @@ -354,6 +354,13 @@ func (b IPv6FragmentExtHdr) ID() uint32 { return binary.BigEndian.Uint32(b[ipv6FragmentExtHdrIdentificationOffset:]) } +// IsAtomic returns whether the fragment header indicates an atomic fragment. An +// atomic fragment is a fragment that contains all the data required to +// reassemble a full packet. +func (b IPv6FragmentExtHdr) IsAtomic() bool { + return !b.More() && b.FragmentOffset() == 0 +} + // IPv6PayloadIterator is an iterator over the contents of an IPv6 payload. // // The IPv6 payload may contain IPv6 extension headers before any upper layer diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 29454c4b9..4c6f808e5 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -66,6 +66,14 @@ const ( TCPOptionSACK = 5 ) +// Option Lengths. +const ( + TCPOptionMSSLength = 4 + TCPOptionTSLength = 10 + TCPOptionWSLength = 3 + TCPOptionSackPermittedLength = 2 +) + // TCPFields contains the fields of a TCP packet. It is used to describe the // fields of a packet that needs to be encoded. type TCPFields struct { @@ -494,14 +502,11 @@ func ParseTCPOptions(b []byte) TCPOptions { // returns without encoding anything. It returns the number of bytes written to // the provided buffer. func EncodeMSSOption(mss uint32, b []byte) int { - // mssOptionSize is the number of bytes in a valid MSS option. - const mssOptionSize = 4 - - if len(b) < mssOptionSize { + if len(b) < TCPOptionMSSLength { return 0 } - b[0], b[1], b[2], b[3] = TCPOptionMSS, mssOptionSize, byte(mss>>8), byte(mss) - return mssOptionSize + b[0], b[1], b[2], b[3] = TCPOptionMSS, TCPOptionMSSLength, byte(mss>>8), byte(mss) + return TCPOptionMSSLength } // EncodeWSOption encodes the WS TCP option with the WS value in the @@ -509,10 +514,10 @@ func EncodeMSSOption(mss uint32, b []byte) int { // returns without encoding anything. It returns the number of bytes written to // the provided buffer. func EncodeWSOption(ws int, b []byte) int { - if len(b) < 3 { + if len(b) < TCPOptionWSLength { return 0 } - b[0], b[1], b[2] = TCPOptionWS, 3, uint8(ws) + b[0], b[1], b[2] = TCPOptionWS, TCPOptionWSLength, uint8(ws) return int(b[1]) } @@ -521,10 +526,10 @@ func EncodeWSOption(ws int, b []byte) int { // just returns without encoding anything. It returns the number of bytes // written to the provided buffer. func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int { - if len(b) < 10 { + if len(b) < TCPOptionTSLength { return 0 } - b[0], b[1] = TCPOptionTS, 10 + b[0], b[1] = TCPOptionTS, TCPOptionTSLength binary.BigEndian.PutUint32(b[2:], tsVal) binary.BigEndian.PutUint32(b[6:], tsEcr) return int(b[1]) @@ -535,11 +540,11 @@ func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int { // encoding anything. It returns the number of bytes written to the provided // buffer. func EncodeSACKPermittedOption(b []byte) int { - if len(b) < 2 { + if len(b) < TCPOptionSackPermittedLength { return 0 } - b[0], b[1] = TCPOptionSACKPermitted, 2 + b[0], b[1] = TCPOptionSACKPermitted, TCPOptionSackPermittedLength return int(b[1]) } diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 9bf67686d..20b183da0 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -181,13 +181,13 @@ func (e *Endpoint) NumQueued() int { } // InjectInbound injects an inbound packet. -func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.InjectLinkAddr(protocol, "", pkt) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt stack.PacketBuffer) { - e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt) +func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *stack.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt) } // Attach saves the stack network-layer dispatcher for use later when packets @@ -229,13 +229,13 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { // Clone r then release its resource so we only get the relevant fields from // stack.Route without holding a reference to a NIC's endpoint. route := r.Clone() route.Release() p := PacketInfo{ - Pkt: &pkt, + Pkt: pkt, Proto: protocol, GSO: gso, Route: route, diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index affa1bbdf..f34082e1a 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -387,7 +387,7 @@ const ( // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { if e.hdrSize > 0 { // Add ethernet header if needed. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) @@ -641,8 +641,8 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) { } // InjectInbound injects an inbound packet. -func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, pkt) +func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt) } // NewInjectable creates a new fd-based InjectableEndpoint. diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 3bfb15a8e..eaee7e5d7 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -45,7 +45,7 @@ const ( type packetInfo struct { raddr tcpip.LinkAddress proto tcpip.NetworkProtocolNumber - contents stack.PacketBuffer + contents *stack.PacketBuffer } type context struct { @@ -103,7 +103,7 @@ func (c *context) cleanup() { } } -func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { c.ch <- packetInfo{remote, protocol, pkt} } @@ -179,7 +179,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u L3HdrLen: header.IPv4MaximumHeaderSize, } } - if err := c.ep.WritePacket(r, gso, proto, stack.PacketBuffer{ + if err := c.ep.WritePacket(r, gso, proto, &stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), Hash: hash, @@ -295,7 +295,7 @@ func TestPreserveSrcAddress(t *testing.T) { // WritePacket panics given a prependable with anything less than // the minimum size of the ethernet header. hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, nil /* gso */, proto, stack.PacketBuffer{ + if err := c.ep.WritePacket(r, nil /* gso */, proto, &stack.PacketBuffer{ Header: hdr, Data: buffer.VectorisedView{}, }); err != nil { @@ -358,7 +358,7 @@ func TestDeliverPacket(t *testing.T) { want := packetInfo{ raddr: raddr, proto: proto, - contents: stack.PacketBuffer{ + contents: &stack.PacketBuffer{ Data: buffer.View(b).ToVectorisedView(), LinkHeader: buffer.View(hdr), }, diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index fe2bf3b0b..2dfd29aa9 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -191,7 +191,7 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { } pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, stack.PacketBuffer{ + d.e.dispatcher.DeliverNetworkPacket(remote, local, p, &stack.PacketBuffer{ Data: buffer.View(pkt).ToVectorisedView(), LinkHeader: buffer.View(eth), }) diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index cb4cbea69..f04738cfb 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -139,13 +139,13 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(n, BufConfig) - pkt := stack.PacketBuffer{ + pkt := &stack.PacketBuffer{ Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), LinkHeader: buffer.View(eth), } pkt.Data.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt) + d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) // Prepare e.views for another packet: release used views. for i := 0; i < used; i++ { @@ -169,7 +169,7 @@ type recvMMsgDispatcher struct { // iovecs is an array of array of iovec records where each iovec base // pointer and length are initialzed to the corresponding view above, - // except when GSO is neabled then the first iovec in each array of + // except when GSO is enabled then the first iovec in each array of // iovecs points to a buffer for the vnet header which is stripped // before the views are passed up the stack for further processing. iovecs [][]syscall.Iovec @@ -296,12 +296,12 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(k, int(n), BufConfig) - pkt := stack.PacketBuffer{ + pkt := &stack.PacketBuffer{ Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), LinkHeader: buffer.View(eth), } pkt.Data.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt) + d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) // Prepare e.views for another packet: release used views. for i := 0; i < used; i++ { diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 073c84ef9..568c6874f 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -76,7 +76,7 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. -func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) views[0] = pkt.Header.View() views = append(views, pkt.Data.Views()...) @@ -84,7 +84,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw // Because we're immediately turning around and writing the packet back // to the rx path, we intentionally don't preserve the remote and local // link addresses from the stack.Route we're passed. - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, stack.PacketBuffer{ + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) @@ -106,7 +106,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { } linkHeader := header.Ethernet(hdr) vv.TrimFront(len(linkHeader)) - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), &stack.PacketBuffer{ Data: vv, LinkHeader: buffer.View(linkHeader), }) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index a5478ce17..c69d6b7e9 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -80,8 +80,8 @@ func (m *InjectableEndpoint) IsAttached() bool { } // InjectInbound implements stack.InjectableLinkEndpoint. -func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { - m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, pkt) +func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + m.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt) } // WritePackets writes outbound packets to the appropriate @@ -98,7 +98,7 @@ func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts s // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint // based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a // route registered in this endpoint. -func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { if endpoint, ok := m.routes[r.RemoteAddress]; ok { return endpoint.WritePacket(r, gso, protocol, pkt) } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 87c734c1f..0744f66d6 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -50,7 +50,7 @@ func TestInjectableEndpointDispatch(t *testing.T) { hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{ + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), }) @@ -70,7 +70,7 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { hdr := buffer.NewPrependable(1) hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{ + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buffer.NewView(0).ToVectorisedView(), }) diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index 54432194d..b5dfb7850 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -102,8 +102,8 @@ func (q *queueDispatcher) dispatchLoop() { } // DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket. -func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) +func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) } // Attach implements stack.LinkEndpoint.Attach. @@ -146,7 +146,7 @@ func (e *endpoint) GSOMaxSize() uint32 { } // WritePacket implements stack.LinkEndpoint.WritePacket. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { // WritePacket caller's do not set the following fields in PacketBuffer // so we populate them here. newRoute := r.Clone() @@ -154,7 +154,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne pkt.GSOOptions = gso pkt.NetworkProtocolNumber = protocol d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] - if !d.q.enqueue(&pkt) { + if !d.q.enqueue(pkt) { return tcpip.ErrNoBufferSpace } d.newPacketWaker.Assert() diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go index 0b5a6cf49..99313ee25 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go @@ -14,7 +14,7 @@ // +build linux,amd64 linux,arm64 // +build go1.12 -// +build !go1.15 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 0796d717e..0374a2441 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -185,7 +185,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { // Add the ethernet header here. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) pkt.LinkHeader = buffer.View(eth) @@ -275,7 +275,7 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { // Send packet up the stack. eth := header.Ethernet(b[:header.EthernetMinimumSize]) - d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), stack.PacketBuffer{ + d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), &stack.PacketBuffer{ Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), LinkHeader: buffer.View(eth), }) diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 33f640b85..28a2e88ba 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -131,7 +131,7 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress return c } -func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { c.mu.Lock() c.packets = append(c.packets, packetInfo{ addr: remoteLinkAddr, @@ -273,7 +273,7 @@ func TestSimpleSend(t *testing.T) { randomFill(buf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -345,7 +345,7 @@ func TestPreserveSrcAddressInSend(t *testing.T) { hdr := buffer.NewPrependable(header.EthernetMinimumSize) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{ Header: hdr, }); err != nil { t.Fatalf("WritePacket failed: %v", err) @@ -401,7 +401,7 @@ func TestFillTxQueue(t *testing.T) { for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -419,7 +419,7 @@ func TestFillTxQueue(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != want { @@ -447,7 +447,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Send two packets so that the id slice has at least two slots. for i := 2; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -470,7 +470,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -488,7 +488,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != want { @@ -514,7 +514,7 @@ func TestFillTxMemory(t *testing.T) { ids := make(map[uint64]struct{}) for i := queueDataSize / bufferSize; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -533,7 +533,7 @@ func TestFillTxMemory(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }) @@ -561,7 +561,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // until there is only one buffer left. for i := queueDataSize/bufferSize - 1; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -577,7 +577,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: uu, }); err != want { @@ -588,7 +588,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Attempt to write the one-buffer packet again. It must succeed. { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index da1c520ae..ae3186314 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -120,9 +120,9 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) ( // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. -func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { - e.dumpPacket("recv", nil, protocol, &pkt) - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) +func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.dumpPacket("recv", nil, protocol, pkt) + e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) } // Attach implements the stack.LinkEndpoint interface. It saves the dispatcher @@ -208,8 +208,8 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw // WritePacket implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { - e.dumpPacket("send", gso, protocol, &pkt) +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.dumpPacket("send", gso, protocol, pkt) return e.lower.WritePacket(r, gso, protocol, pkt) } diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 617446ea2..6bc9033d0 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -213,7 +213,7 @@ func (d *Device) Write(data []byte) (int64, error) { remote = tcpip.LinkAddress(zeroMAC[:]) } - pkt := stack.PacketBuffer{ + pkt := &stack.PacketBuffer{ Data: buffer.View(data).ToVectorisedView(), } if ethHdr != nil { diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 2b3741276..949b3f2b2 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -50,12 +50,12 @@ func New(lower stack.LinkEndpoint) *Endpoint { // It is called by the link-layer endpoint being wrapped when a packet arrives, // and only forwards to the actual dispatcher if Wait or WaitDispatch haven't // been called. -func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if !e.dispatchGate.Enter() { return } - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) + e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) e.dispatchGate.Leave() } @@ -99,7 +99,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket implements stack.LinkEndpoint.WritePacket. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { if !e.writeGate.Enter() { return nil } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 54eb5322b..63bf40562 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -35,7 +35,7 @@ type countedEndpoint struct { dispatcher stack.NetworkDispatcher } -func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.dispatchCount++ } @@ -65,7 +65,7 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { e.writeCount++ return nil } @@ -89,21 +89,21 @@ func TestWaitWrite(t *testing.T) { wep := New(ep) // Write and check that it goes through. - wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) if want := 1; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on dispatches, then try to write. It must go through. wep.WaitDispatch() - wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on writes, then try to write. It must not go through. wep.WaitWrite() - wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } @@ -120,21 +120,21 @@ func TestWaitDispatch(t *testing.T) { } // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) if want := 1; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on writes, then try to dispatch. It must go through. wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on dispatches, then try to dispatch. It must not go through. wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 9d0797af7..7f27a840d 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -80,7 +80,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -94,16 +94,12 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList return 0, tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - v, ok := pkt.Data.PullUp(header.ARPSize) - if !ok { - return - } - h := header.ARP(v) +func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { + h := header.ARP(pkt.NetworkHeader) if !h.IsValid() { return } @@ -122,7 +118,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ + e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ Header: hdr, }) fallthrough // also fill the cache from requests @@ -177,7 +173,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. copy(h.ProtocolAddressSender(), localAddr) copy(h.ProtocolAddressTarget(), addr) - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ Header: hdr, }) } @@ -209,6 +205,17 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// Parse implements stack.NetworkProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { + hdr, ok := pkt.Data.PullUp(header.ARPSize) + if !ok { + return 0, false, false + } + pkt.NetworkHeader = hdr + pkt.Data.TrimFront(header.ARPSize) + return 0, false, true +} + var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) // NewProtocol returns an ARP network protocol. diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 1646d9cde..66e67429c 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -103,7 +103,7 @@ func TestDirectRequest(t *testing.T) { inject := func(addr tcpip.Address) { copy(h.ProtocolAddressTarget(), addr) - c.linkEP.InjectInbound(arp.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{ Data: v.ToVectorisedView(), }) } diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index f42abc4bb..2982450f8 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -81,8 +81,8 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t } } -// Process processes an incoming fragment belonging to an ID -// and returns a complete packet when all the packets belonging to that ID have been received. +// Process processes an incoming fragment belonging to an ID and returns a +// complete packet when all the packets belonging to that ID have been received. func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) { f.mu.Lock() r, ok := f.reassemblers[id] diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 4c20301c6..7c8fb3e0a 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -96,7 +96,7 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt stack.PacketBuffer) { +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) { t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) t.dataCalls++ } @@ -104,7 +104,7 @@ func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.Trans // DeliverTransportControlPacket is called by network endpoints after parsing // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. -func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { t.checkValues(trans, pkt.Data, remote, local) if typ != t.typ { t.t.Errorf("typ = %v, want %v", typ, t.typ) @@ -150,7 +150,7 @@ func (*testObject) Wait() {} // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { var prot tcpip.TransportProtocolNumber var srcAddr tcpip.Address var dstAddr tcpip.Address @@ -246,7 +246,11 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + Protocol: 123, + TTL: 123, + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { @@ -289,9 +293,9 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, stack.PacketBuffer{ - Data: view.ToVectorisedView(), - }) + pkt := stack.PacketBuffer{Data: view.ToVectorisedView()} + proto.Parse(&pkt) + ep.HandlePacket(&r, &pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -378,10 +382,7 @@ func TestIPv4ReceiveControl(t *testing.T) { o.typ = c.expectedTyp o.extra = c.expectedExtra - vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, stack.PacketBuffer{ - Data: vv, - }) + ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize)) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } @@ -444,17 +445,17 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - ep.HandlePacket(&r, stack.PacketBuffer{ - Data: frag1.ToVectorisedView(), - }) + pkt := stack.PacketBuffer{Data: frag1.ToVectorisedView()} + proto.Parse(&pkt) + ep.HandlePacket(&r, &pkt) if o.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) } // Send second segment. - ep.HandlePacket(&r, stack.PacketBuffer{ - Data: frag2.ToVectorisedView(), - }) + pkt = stack.PacketBuffer{Data: frag2.ToVectorisedView()} + proto.Parse(&pkt) + ep.HandlePacket(&r, &pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -487,7 +488,11 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + Protocol: 123, + TTL: 123, + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { @@ -530,9 +535,9 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, stack.PacketBuffer{ - Data: view.ToVectorisedView(), - }) + pkt := stack.PacketBuffer{Data: view.ToVectorisedView()} + proto.Parse(&pkt) + ep.HandlePacket(&r, &pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -644,12 +649,25 @@ func TestIPv6ReceiveControl(t *testing.T) { // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) - ep.HandlePacket(&r, stack.PacketBuffer{ - Data: view[:len(view)-c.trunc].ToVectorisedView(), - }) + ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize)) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } }) } } + +// truncatedPacket returns a PacketBuffer based on a truncated view. If view, +// after truncation, is large enough to hold a network header, it makes part of +// view the packet's NetworkHeader and the rest its Data. Otherwise all of view +// becomes Data. +func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer { + v := view[:len(view)-trunc] + if len(v) < netHdrLen { + return &stack.PacketBuffer{Data: v.ToVectorisedView()} + } + return &stack.PacketBuffer{ + NetworkHeader: v[:netHdrLen], + Data: v[netHdrLen:].ToVectorisedView(), + } +} diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 4cbefe5ab..1b67aa066 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -24,7 +24,7 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { return @@ -56,9 +56,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { +func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived + // TODO(gvisor.dev/issue/170): ICMP packets don't have their + // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + // full explanation. v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) if !ok { received.Invalid.Increment() @@ -88,7 +91,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, stack.PacketBuffer{ + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, &stack.PacketBuffer{ Data: pkt.Data.Clone(nil), NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), }) @@ -102,7 +105,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: header.ICMPv4ProtocolNumber, + TTL: r.DefaultTTL(), + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: vv, TransportHeader: buffer.View(pkt), diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 64046cbbf..7e9f16c90 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -21,6 +21,7 @@ package ipv4 import ( + "fmt" "sync/atomic" "gvisor.dev/gvisor/pkg/tcpip" @@ -129,7 +130,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { // packet's stated length matches the length of the header+payload. mtu // includes the IP header and options. This does not support the DontFragment // IP flag. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error { // This packet is too big, it needs to be fragmented. ip := header.IPv4(pkt.Header.View()) flags := ip.Flags() @@ -169,7 +170,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, if i > 0 { newPayload := pkt.Data.Clone(nil) newPayload.CapLength(innerMTU) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ Header: pkt.Header, Data: newPayload, NetworkHeader: buffer.View(h), @@ -188,7 +189,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, newPayload := pkt.Data.Clone(nil) newPayloadLength := outerMTU - pkt.Header.UsedLength() newPayload.CapLength(newPayloadLength) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ Header: pkt.Header, Data: newPayload, NetworkHeader: buffer.View(h), @@ -202,7 +203,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, startOfHdr := pkt.Header startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ Header: startOfHdr, Data: emptyVV, NetworkHeader: buffer.View(h), @@ -245,7 +246,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) @@ -253,43 +254,29 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Output, &pkt, gso, r, "", nicName); !ok { + if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. return nil } + // If the packet is manipulated as per NAT Ouput rules, handle packet + // based on destination address and do not send the packet to link layer. + // TODO(gvisor.dev/issue/170): We should do this for every packet, rather than + // only NATted packets, but removing this check short circuits broadcasts + // before they are sent out to other hosts. if pkt.NatDone { - // If the packet is manipulated as per NAT Ouput rules, handle packet - // based on destination address and do not send the packet to link layer. netHeader := header.IPv4(pkt.NetworkHeader) ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) if err == nil { - src := netHeader.SourceAddress() - dst := netHeader.DestinationAddress() - route := r.ReverseRoute(src, dst) - - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) - packet := stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)} - ep.HandlePacket(&route, packet) + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + ep.HandlePacket(&route, pkt) return nil } } if r.Loop&stack.PacketLoop != 0 { - // The inbound path expects the network header to still be in - // the PacketBuffer's Data field. - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - - e.HandlePacket(&loopedR, stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) - + e.HandlePacket(&loopedR, pkt) loopedR.Release() } if r.Loop&stack.PacketOut == 0 { @@ -342,23 +329,16 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if _, ok := natPkts[pkt]; ok { netHeader := header.IPv4(pkt.NetworkHeader) - ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) - if err == nil { + if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil { src := netHeader.SourceAddress() dst := netHeader.DestinationAddress() route := r.ReverseRoute(src, dst) - - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) - packet := stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)} - ep.HandlePacket(&route, packet) + ep.HandlePacket(&route, pkt) n++ continue } } - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) return n, err } @@ -370,7 +350,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) @@ -426,35 +406,23 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { +func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { + h := header.IPv4(pkt.NetworkHeader) + if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) { r.Stats().IP.MalformedPacketsReceived.Increment() return } - h := header.IPv4(headerView) - if !h.IsValid(pkt.Data.Size()) { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } - pkt.NetworkHeader = headerView[:h.HeaderLength()] - - hlen := int(h.HeaderLength()) - tlen := int(h.TotalLength()) - pkt.Data.TrimFront(hlen) - pkt.Data.CapLength(tlen - hlen) // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Input, &pkt, nil, nil, "", ""); !ok { + if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. return } - more := (h.Flags() & header.IPv4FlagMoreFragments) != 0 - if more || h.FragmentOffset() != 0 { - if pkt.Data.Size() == 0 { + if h.More() || h.FragmentOffset() != 0 { + if pkt.Data.Size()+len(pkt.TransportHeader) == 0 { // Drop the packet as it's marked as a fragment but has // no payload. r.Stats().IP.MalformedPacketsReceived.Increment() @@ -473,7 +441,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { } var ready bool var err error - pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data) + pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, h.More(), pkt.Data) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -485,7 +453,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { } p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { - headerView.CapLength(hlen) + pkt.NetworkHeader.CapLength(int(h.HeaderLength())) e.handleICMP(r, pkt) return } @@ -565,6 +533,35 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { + hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return 0, false, false + } + ipHdr := header.IPv4(hdr) + + // If there are options, pull those into hdr as well. + if headerLen := int(ipHdr.HeaderLength()); headerLen > header.IPv4MinimumSize && headerLen <= pkt.Data.Size() { + hdr, ok = pkt.Data.PullUp(headerLen) + if !ok { + panic(fmt.Sprintf("There are only %d bytes in pkt.Data, but there should be at least %d", pkt.Data.Size(), headerLen)) + } + ipHdr = header.IPv4(hdr) + } + + // If this is a fragment, don't bother parsing the transport header. + parseTransportHeader := true + if ipHdr.More() || ipHdr.FragmentOffset() != 0 { + parseTransportHeader = false + } + + pkt.NetworkHeader = hdr + pkt.Data.TrimFront(len(hdr)) + pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) + return ipHdr.TransportProtocol(), parseTransportHeader, true +} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 36035c820..11e579c4b 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -114,7 +114,7 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. // comparePayloads compared the contents of all the packets against the contents // of the source packet. -func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketInfo stack.PacketBuffer, mtu uint32) { +func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) { t.Helper() // Make a complete array of the sourcePacketInfo packet. source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) @@ -174,7 +174,7 @@ func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketIn type errorChannel struct { *channel.Endpoint - Ch chan stack.PacketBuffer + Ch chan *stack.PacketBuffer packetCollectorErrors []*tcpip.Error } @@ -184,7 +184,7 @@ type errorChannel struct { func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { return &errorChannel{ Endpoint: channel.New(size, mtu, linkAddr), - Ch: make(chan stack.PacketBuffer, size), + Ch: make(chan *stack.PacketBuffer, size), packetCollectorErrors: packetCollectorErrors, } } @@ -203,7 +203,7 @@ func (e *errorChannel) Drain() int { } // WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { select { case e.Ch <- pkt: default: @@ -282,13 +282,17 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := stack.PacketBuffer{ + source := &stack.PacketBuffer{ Header: hdr, // Save the source payload because WritePacket will modify it. Data: payload.Clone(nil), } c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: 42, + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: payload, }) @@ -296,7 +300,7 @@ func TestFragmentation(t *testing.T) { t.Errorf("err got %v, want %v", err, nil) } - var results []stack.PacketBuffer + var results []*stack.PacketBuffer L: for { select { @@ -338,7 +342,11 @@ func TestFragmentationErrors(t *testing.T) { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: 42, + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: payload, }) @@ -460,7 +468,7 @@ func TestInvalidFragments(t *testing.T) { s.CreateNIC(nicID, sniffer.New(ep)) for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.PacketBuffer{ + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), }) } @@ -644,6 +652,18 @@ func TestReceiveFragments(t *testing.T) { }, expectedPayloads: [][]byte{udpPayload1, udpPayload2}, }, + { + name: "Fragment without followup", + fragments: []fragmentData{ + { + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1[:64], + }, + }, + expectedPayloads: nil, + }, } for _, test := range tests { @@ -698,7 +718,7 @@ func TestReceiveFragments(t *testing.T) { vv := hdr.View().ToVectorisedView() vv.AppendView(frag.payload) - e.InjectInbound(header.IPv4ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ Data: vv, }) } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index bdf3a0d25..2ff7eedf4 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -27,7 +27,7 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) if !ok { return @@ -70,17 +70,20 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { +func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived + // TODO(gvisor.dev/issue/170): ICMP packets don't have their + // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + // full explanation. v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) if !ok { received.Invalid.Increment() return } h := header.ICMPv6(v) - iph := header.IPv6(netHeader) + iph := header.IPv6(pkt.NetworkHeader) // Validate ICMPv6 checksum before processing the packet. // @@ -288,7 +291,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, }); err != nil { sent.Dropped.Increment() @@ -390,7 +393,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: pkt.Data, }); err != nil { @@ -532,7 +535,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. }) // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ Header: hdr, }) } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d412ff688..52a01b44e 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -57,7 +57,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } -func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, stack.PacketBuffer) *tcpip.Error { +func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { return nil } @@ -67,7 +67,7 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, stack.PacketBuffer) { +func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) { } type stubLinkAddressCache struct { @@ -179,36 +179,32 @@ func TestICMPCounts(t *testing.T) { }, } - handleIPv6Payload := func(hdr buffer.Prependable) { - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + handleIPv6Payload := func(icmp header.ICMPv6) { + ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), + PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, stack.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), + ep.HandlePacket(&r, &stack.PacketBuffer{ + NetworkHeader: buffer.View(ip), + Data: buffer.View(icmp).ToVectorisedView(), }) } for _, typ := range types { - extraDataLen := len(typ.extraData) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) - extraData := buffer.View(hdr.Prepend(extraDataLen)) - copy(extraData, typ.extraData) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView())) - - handleIPv6Payload(hdr) + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + handleIPv6Payload(icmp) } // Construct an empty ICMP packet so that // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. - handleIPv6Payload(buffer.NewPrependable(header.IPv6MinimumSize)) + handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) icmpv6Stats := s.Stats().ICMP.V6PacketsReceived visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { @@ -328,7 +324,7 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), stack.PacketBuffer{ + args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), &stack.PacketBuffer{ Data: vv, }) } @@ -546,25 +542,22 @@ func TestICMPChecksumValidationSimple(t *testing.T) { } handleIPv6Payload := func(checksum bool) { - extraDataLen := len(typ.extraData) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) - extraData := buffer.View(hdr.Prepend(extraDataLen)) - copy(extraData, typ.extraData) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, extraData.ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView())) } - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(typ.size + extraDataLen), + PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), }) } @@ -740,7 +733,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -918,7 +911,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), }) } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index daf1fcbc6..95fbcf2d1 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -116,7 +116,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) @@ -128,7 +128,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, stack.PacketBuffer{ + e.HandlePacket(&loopedR, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) @@ -163,30 +163,28 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet // supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { // TODO(b/146666412): Support IPv6 header-included packets. return tcpip.ErrNotSupported } // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) - if !ok { +func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { + h := header.IPv6(pkt.NetworkHeader) + if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) { r.Stats().IP.MalformedPacketsReceived.Increment() return } - h := header.IPv6(headerView) - if !h.IsValid(pkt.Data.Size()) { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } - - pkt.NetworkHeader = headerView[:header.IPv6MinimumSize] - pkt.Data.TrimFront(header.IPv6MinimumSize) - pkt.Data.CapLength(int(h.PayloadLength())) - it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), pkt.Data) + // vv consists of: + // - Any IPv6 header bytes after the first 40 (i.e. extensions). + // - The transport header, if present. + // - Any other payload data. + vv := pkt.NetworkHeader[header.IPv6MinimumSize:].ToVectorisedView() + vv.AppendView(pkt.TransportHeader) + vv.Append(pkt.Data) + it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv) hasFragmentHeader := false for firstHeader := true; ; firstHeader = false { @@ -262,9 +260,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { case header.IPv6FragmentExtHdr: hasFragmentHeader = true - fragmentOffset := extHdr.FragmentOffset() - more := extHdr.More() - if !more && fragmentOffset == 0 { + if extHdr.IsAtomic() { // This fragment extension header indicates that this packet is an // atomic fragment. An atomic fragment is a fragment that contains // all the data required to reassemble a full packet. As per RFC 6946, @@ -277,9 +273,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { // Don't consume the iterator if we have the first fragment because we // will use it to validate that the first fragment holds the upper layer // header. - rawPayload := it.AsRawHeader(fragmentOffset != 0 /* consume */) + rawPayload := it.AsRawHeader(extHdr.FragmentOffset() != 0 /* consume */) - if fragmentOffset == 0 { + if extHdr.FragmentOffset() == 0 { // Check that the iterator ends with a raw payload as the first fragment // should include all headers up to and including any upper layer // headers, as per RFC 8200 section 4.5; only upper layer data @@ -332,7 +328,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { } // The packet is a fragment, let's try to reassemble it. - start := fragmentOffset * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit + start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit last := start + uint16(fragmentPayloadLen) - 1 // Drop the packet if the fragmentOffset is incorrect. i.e the @@ -345,7 +341,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { } var ready bool - pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, more, rawPayload.Buf) + // Note that pkt doesn't have its transport header set after reassembly, + // and won't until DeliverNetworkPacket sets it. + pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, extHdr.More(), rawPayload.Buf) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -394,10 +392,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { case header.IPv6RawPayloadHeader: // If the last header in the payload isn't a known IPv6 extension header, // handle it as if it is transport layer data. + + // For unfragmented packets, extHdr still contains the transport header. + // Get rid of it. + // + // For reassembled fragments, pkt.TransportHeader is unset, so this is a + // no-op and pkt.Data begins with the transport header. + extHdr.Buf.TrimFront(len(pkt.TransportHeader)) pkt.Data = extHdr.Buf if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { - e.handleICMP(r, headerView, pkt, hasFragmentHeader) + e.handleICMP(r, pkt, hasFragmentHeader) } else { r.Stats().IP.PacketsDelivered.Increment() // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error @@ -505,6 +510,79 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { + hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return 0, false, false + } + ipHdr := header.IPv6(hdr) + + // dataClone consists of: + // - Any IPv6 header bytes after the first 40 (i.e. extensions). + // - The transport header, if present. + // - Any other payload data. + views := [8]buffer.View{} + dataClone := pkt.Data.Clone(views[:]) + dataClone.TrimFront(header.IPv6MinimumSize) + it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone) + + // Iterate over the IPv6 extensions to find their length. + // + // Parsing occurs again in HandlePacket because we don't track the + // extensions in PacketBuffer. Unfortunately, that means HandlePacket + // has to do the parsing work again. + var nextHdr tcpip.TransportProtocolNumber + foundNext := true + extensionsSize := 0 +traverseExtensions: + for extHdr, done, err := it.Next(); ; extHdr, done, err = it.Next() { + if err != nil { + break + } + // If we exhaust the extension list, the entire packet is the IPv6 header + // and (possibly) extensions. + if done { + extensionsSize = dataClone.Size() + foundNext = false + break + } + + switch extHdr := extHdr.(type) { + case header.IPv6FragmentExtHdr: + // If this is an atomic fragment, we don't have to treat it specially. + if !extHdr.More() && extHdr.FragmentOffset() == 0 { + continue + } + // This is a non-atomic fragment and has to be re-assembled before we can + // examine the payload for a transport header. + foundNext = false + + case header.IPv6RawPayloadHeader: + // We've found the payload after any extensions. + extensionsSize = dataClone.Size() - extHdr.Buf.Size() + nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier) + break traverseExtensions + + default: + // Any other extension is a no-op, keep looping until we find the payload. + } + } + + // Put the IPv6 header with extensions in pkt.NetworkHeader. + hdr, ok = pkt.Data.PullUp(header.IPv6MinimumSize + extensionsSize) + if !ok { + panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size())) + } + ipHdr = header.IPv6(hdr) + + pkt.NetworkHeader = hdr + pkt.Data.TrimFront(len(hdr)) + pkt.Data.CapLength(int(ipHdr.PayloadLength())) + + return nextHdr, foundNext, true +} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 841a0cb7a..213ff64f2 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -65,7 +65,7 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -123,7 +123,7 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -637,7 +637,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { DstAddr: addr2, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -1238,7 +1238,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { vv := hdr.View().ToVectorisedView() vv.Append(f.data) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: vv, }) } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 12b70f7e9..64239ce9a 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -136,7 +136,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -380,7 +380,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.PacketBuffer{ + e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -497,7 +497,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -551,25 +551,29 @@ func TestNDPValidation(t *testing.T) { return s, ep, r } - handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { + handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { nextHdr := uint8(header.ICMPv6ProtocolNumber) + var extensions buffer.View if atomicFragment { - bytes := hdr.Prepend(header.IPv6FragmentExtHdrLength) - bytes[0] = nextHdr + extensions = buffer.NewView(header.IPv6FragmentExtHdrLength) + extensions[0] = nextHdr nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) } - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize + len(extensions))) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), + PayloadLength: uint16(len(payload) + len(extensions)), NextHeader: nextHdr, HopLimit: hopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(r, stack.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), + if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { + t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) + } + ep.HandlePacket(r, &stack.PacketBuffer{ + NetworkHeader: buffer.View(ip), + Data: payload.ToVectorisedView(), }) } @@ -676,14 +680,11 @@ func TestNDPValidation(t *testing.T) { invalid := stats.Invalid typStat := typ.statCounter(stats) - extraDataLen := len(typ.extraData) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen + header.IPv6FragmentExtHdrLength) - extraData := buffer.View(hdr.Prepend(extraDataLen)) - copy(extraData, typ.extraData) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) - pkt.SetCode(test.code) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView())) + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + icmp.SetCode(test.code) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) // Rx count of the NDP message should initially be 0. if got := typStat.Value(); got != 0 { @@ -699,7 +700,7 @@ func TestNDPValidation(t *testing.T) { t.FailNow() } - handleIPv6Payload(hdr, test.hopLimit, test.atomicFragment, ep, &r) + handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r) // Rx count of the NDP packet should have increased. if got := typStat.Value(); got != 1 { @@ -884,7 +885,7 @@ func TestRouterAdvertValidation(t *testing.T) { t.Fatalf("got rxRA = %d, want = 0", got) } - e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index f71073207..afca925ad 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -110,5 +110,6 @@ go_test( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", ], ) diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 7d1ede1f2..05bf62788 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -20,7 +20,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" @@ -147,46 +146,8 @@ type ConnTrackTable struct { Seed uint32 } -// parseHeaders sets headers in the packet. -func parseHeaders(pkt *PacketBuffer) { - newPkt := pkt.Clone() - - // Set network header. - hdr, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return - } - netHeader := header.IPv4(hdr) - newPkt.NetworkHeader = hdr - length := int(netHeader.HeaderLength()) - - // TODO(gvisor.dev/issue/170): Need to support for other - // protocols as well. - // Set transport header. - switch protocol := netHeader.TransportProtocol(); protocol { - case header.UDPProtocolNumber: - if newPkt.TransportHeader == nil { - h, ok := newPkt.Data.PullUp(length + header.UDPMinimumSize) - if !ok { - return - } - newPkt.TransportHeader = buffer.View(header.UDP(h[length:])) - } - case header.TCPProtocolNumber: - if newPkt.TransportHeader == nil { - h, ok := newPkt.Data.PullUp(length + header.TCPMinimumSize) - if !ok { - return - } - newPkt.TransportHeader = buffer.View(header.TCP(h[length:])) - } - } - pkt.NetworkHeader = newPkt.NetworkHeader - pkt.TransportHeader = newPkt.TransportHeader -} - // packetToTuple converts packet to a tuple in original direction. -func packetToTuple(pkt PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) { +func packetToTuple(pkt *PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) { var tuple connTrackTuple netHeader := header.IPv4(pkt.NetworkHeader) @@ -257,15 +218,8 @@ func (ct *ConnTrackTable) getTupleHash(tuple connTrackTuple) uint32 { // TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support other // transport protocols. func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, createConn bool) (*connTrack, ctDirection) { - if hook == Prerouting { - // Headers will not be set in Prerouting. - // TODO(gvisor.dev/issue/170): Change this after parsing headers - // code is added. - parseHeaders(pkt) - } - var dir ctDirection - tuple, err := packetToTuple(*pkt, hook) + tuple, err := packetToTuple(pkt, hook) if err != nil { return nil, dir } diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go index 6b64cd37f..3eff141e6 100644 --- a/pkg/tcpip/stack/forwarder.go +++ b/pkg/tcpip/stack/forwarder.go @@ -32,7 +32,7 @@ type pendingPacket struct { nic *NIC route *Route proto tcpip.NetworkProtocolNumber - pkt PacketBuffer + pkt *PacketBuffer } type forwardQueue struct { @@ -50,7 +50,7 @@ func newForwardQueue() *forwardQueue { return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)} } -func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { shouldWait := false f.Lock() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 8084d50bc..a6546cef0 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -33,6 +33,10 @@ const ( // except where another value is explicitly used. It is chosen to match // the MTU of loopback interfaces on linux systems. fwdTestNetDefaultMTU = 65536 + + dstAddrOffset = 0 + srcAddrOffset = 1 + protocolNumberOffset = 2 ) // fwdTestNetworkEndpoint is a network-layer protocol endpoint. @@ -68,16 +72,9 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { return &f.id } -func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { - // Consume the network header. - b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) - if !ok { - return - } - pkt.Data.TrimFront(fwdTestNetHeaderLen) - +func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { @@ -96,13 +93,13 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu return f.proto.Number() } -func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error { +func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. b := pkt.Header.Prepend(fwdTestNetHeaderLen) - b[0] = r.RemoteAddress[0] - b[1] = f.id.LocalAddress[0] - b[2] = byte(params.Protocol) + b[dstAddrOffset] = r.RemoteAddress[0] + b[srcAddrOffset] = f.id.LocalAddress[0] + b[protocolNumberOffset] = byte(params.Protocol) return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt) } @@ -112,7 +109,7 @@ func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuf panic("not implemented") } -func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error { +func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -140,7 +137,17 @@ func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int { } func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) + return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) +} + +func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { + netHeader, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + if !ok { + return 0, false, false + } + pkt.NetworkHeader = netHeader + pkt.Data.TrimFront(fwdTestNetHeaderLen) + return tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), true, true } func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { @@ -190,7 +197,7 @@ func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumb type fwdTestPacketInfo struct { RemoteLinkAddress tcpip.LinkAddress LocalLinkAddress tcpip.LinkAddress - Pkt PacketBuffer + Pkt *PacketBuffer } type fwdTestLinkEndpoint struct { @@ -203,13 +210,13 @@ type fwdTestLinkEndpoint struct { } // InjectInbound injects an inbound packet. -func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { e.InjectLinkAddr(protocol, "", pkt) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt PacketBuffer) { - e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt) +func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt) } // Attach saves the stack network-layer dispatcher for use later when packets @@ -251,7 +258,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -270,7 +277,7 @@ func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.Netw func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.WritePacket(r, gso, protocol, *pkt) + e.WritePacket(r, gso, protocol, pkt) n++ } @@ -280,7 +287,7 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := fwdTestPacketInfo{ - Pkt: PacketBuffer{Data: vv}, + Pkt: &PacketBuffer{Data: vv}, } select { @@ -361,8 +368,8 @@ func TestForwardingWithStaticResolver(t *testing.T) { // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -398,8 +405,8 @@ func TestForwardingWithFakeResolver(t *testing.T) { // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -429,8 +436,8 @@ func TestForwardingWithNoResolver(t *testing.T) { // inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -459,16 +466,16 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // Inject an inbound packet to address 4 on NIC 1. This packet should // not be forwarded. buf := buffer.NewView(30) - buf[0] = 4 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + buf[dstAddrOffset] = 4 + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf = buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -480,9 +487,8 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() - if b[0] != 3 { - t.Fatalf("got b[0] = %d, want = 3", b[0]) + if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) } // Test that the address resolution happened correctly. @@ -509,8 +515,8 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { // Inject two inbound packets to address 3 on NIC 1. for i := 0; i < 2; i++ { buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -524,9 +530,8 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() - if b[0] != 3 { - t.Fatalf("got b[0] = %d, want = 3", b[0]) + if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) } // Test that the address resolution happened correctly. @@ -554,10 +559,10 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { for i := 0; i < maxPendingPacketsPerResolution+5; i++ { // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. buf := buffer.NewView(30) - buf[0] = 3 + buf[dstAddrOffset] = 3 // Set the packet sequence number. binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -571,14 +576,18 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() - if b[0] != 3 { - t.Fatalf("got b[0] = %d, want = 3", b[0]) + if b := p.Pkt.Header.View(); b[dstAddrOffset] != 3 { + t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) + } + seqNumBuf, ok := p.Pkt.Data.PullUp(2) // The sequence number is a uint16 (2 bytes). + if !ok { + t.Fatalf("p.Pkt.Data is too short to hold a sequence number: %d", p.Pkt.Data.Size()) } - // The first 5 packets should not be forwarded so the the - // sequemnce number should start with 5. + + // The first 5 packets should not be forwarded so the sequence number should + // start with 5. want := uint16(i + 5) - if n := binary.BigEndian.Uint16(b[fwdTestNetHeaderLen:]); n != want { + if n := binary.BigEndian.Uint16(seqNumBuf); n != want { t.Fatalf("got the packet #%d, want = #%d", n, want) } @@ -609,8 +618,8 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // Each packet has a different destination address (3 to // maxPendingResolutions + 7). buf := buffer.NewView(30) - buf[0] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + buf[dstAddrOffset] = byte(3 + i) + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -626,9 +635,8 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Data.ToView() - if b[0] < 8 { - t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) + if p.Pkt.NetworkHeader[dstAddrOffset] < 8 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", p.Pkt.NetworkHeader[dstAddrOffset]) } // Test that the address resolution happened correctly. diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 443423b3c..4e9b404c8 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -16,7 +16,6 @@ package stack import ( "fmt" - "strings" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -44,11 +43,11 @@ const HookUnset = -1 // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. -func DefaultTables() IPTables { +func DefaultTables() *IPTables { // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for // iotas. - return IPTables{ - Tables: map[string]Table{ + return &IPTables{ + tables: map[string]Table{ TablenameNat: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, @@ -107,7 +106,7 @@ func DefaultTables() IPTables { UserChains: map[string]int{}, }, }, - Priorities: map[Hook][]string{ + priorities: map[Hook][]string{ Input: []string{TablenameNat, TablenameFilter}, Prerouting: []string{TablenameMangle, TablenameNat}, Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, @@ -159,6 +158,36 @@ func EmptyNatTable() Table { } } +// GetTable returns table by name. +func (it *IPTables) GetTable(name string) (Table, bool) { + it.mu.RLock() + defer it.mu.RUnlock() + t, ok := it.tables[name] + return t, ok +} + +// ReplaceTable replaces or inserts table by name. +func (it *IPTables) ReplaceTable(name string, table Table) { + it.mu.Lock() + defer it.mu.Unlock() + it.tables[name] = table +} + +// ModifyTables acquires write-lock and calls fn with internal name-to-table +// map. This function can be used to update multiple tables atomically. +func (it *IPTables) ModifyTables(fn func(map[string]Table)) { + it.mu.Lock() + defer it.mu.Unlock() + fn(it.tables) +} + +// GetPriorities returns slice of priorities associated with hook. +func (it *IPTables) GetPriorities(hook Hook) []string { + it.mu.RLock() + defer it.mu.RUnlock() + return it.priorities[hook] +} + // A chainVerdict is what a table decides should be done with a packet. type chainVerdict int @@ -185,8 +214,8 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr it.connections.HandlePacket(pkt, hook, gso, r) // Go through each table containing the hook. - for _, tablename := range it.Priorities[hook] { - table := it.Tables[tablename] + for _, tablename := range it.GetPriorities(hook) { + table, _ := it.GetTable(tablename) ruleIdx := table.BuiltinChains[hook] switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { // If the table returns Accept, move on to the next table. @@ -314,7 +343,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // Check whether the packet matches the IP header filter. - if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader), hook, nicName) { + if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } @@ -322,7 +351,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // Go through each rule matcher. If they all match, run // the rule target. for _, matcher := range rule.Matchers { - matches, hotdrop := matcher.Match(hook, *pkt, "") + matches, hotdrop := matcher.Match(hook, pkt, "") if hotdrop { return RuleDrop, 0 } @@ -335,47 +364,3 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // All the matchers matched, so run the target. return rule.Target.Action(pkt, &it.connections, hook, gso, r, address) } - -func filterMatch(filter IPHeaderFilter, hdr header.IPv4, hook Hook, nicName string) bool { - // TODO(gvisor.dev/issue/170): Support other fields of the filter. - // Check the transport protocol. - if filter.Protocol != 0 && filter.Protocol != hdr.TransportProtocol() { - return false - } - - // Check the destination IP. - dest := hdr.DestinationAddress() - matches := true - for i := range filter.Dst { - if dest[i]&filter.DstMask[i] != filter.Dst[i] { - matches = false - break - } - } - if matches == filter.DstInvert { - return false - } - - // Check the output interface. - // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING - // hooks after supported. - if hook == Output { - n := len(filter.OutputInterface) - if n == 0 { - return true - } - - // If the interface name ends with '+', any interface which begins - // with the name should be matched. - ifName := filter.OutputInterface - matches = true - if strings.HasSuffix(ifName, "+") { - matches = strings.HasPrefix(nicName, ifName[:n-1]) - } else { - matches = nicName == ifName - } - return filter.OutputInterfaceInvert != matches - } - - return true -} diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 36cc6275d..92e31643e 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -98,11 +98,6 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook return RuleAccept, 0 } - // Set network header. - if hook == Prerouting { - parseHeaders(pkt) - } - // Drop the packet if network and transport header are not set. if pkt.NetworkHeader == nil || pkt.TransportHeader == nil { return RuleDrop, 0 diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index fe06007ae..4a6a5c6f1 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -15,7 +15,11 @@ package stack import ( + "strings" + "sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" ) // A Hook specifies one of the hooks built into the network stack. @@ -75,13 +79,17 @@ const ( // IPTables holds all the tables for a netstack. type IPTables struct { - // Tables maps table names to tables. User tables have arbitrary names. - Tables map[string]Table + // mu protects tables and priorities. + mu sync.RWMutex + + // tables maps table names to tables. User tables have arbitrary names. mu + // needs to be locked for accessing. + tables map[string]Table - // Priorities maps each hook to a list of table names. The order of the + // priorities maps each hook to a list of table names. The order of the // list is the order in which each table should be visited for that - // hook. - Priorities map[Hook][]string + // hook. mu needs to be locked for accessing. + priorities map[Hook][]string connections ConnTrackTable } @@ -159,6 +167,16 @@ type IPHeaderFilter struct { // comparison. DstInvert bool + // Src matches the source IP address. + Src tcpip.Address + + // SrcMask masks bits of the source IP address when comparing with Src. + SrcMask tcpip.Address + + // SrcInvert inverts the meaning of the source IP check, i.e. when true the + // filter will match packets that fail the source comparison. + SrcInvert bool + // OutputInterface matches the name of the outgoing interface for the // packet. OutputInterface string @@ -173,6 +191,55 @@ type IPHeaderFilter struct { OutputInterfaceInvert bool } +// match returns whether hdr matches the filter. +func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool { + // TODO(gvisor.dev/issue/170): Support other fields of the filter. + // Check the transport protocol. + if fl.Protocol != 0 && fl.Protocol != hdr.TransportProtocol() { + return false + } + + // Check the source and destination IPs. + if !filterAddress(hdr.DestinationAddress(), fl.DstMask, fl.Dst, fl.DstInvert) || !filterAddress(hdr.SourceAddress(), fl.SrcMask, fl.Src, fl.SrcInvert) { + return false + } + + // Check the output interface. + // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING + // hooks after supported. + if hook == Output { + n := len(fl.OutputInterface) + if n == 0 { + return true + } + + // If the interface name ends with '+', any interface which begins + // with the name should be matched. + ifName := fl.OutputInterface + matches := true + if strings.HasSuffix(ifName, "+") { + matches = strings.HasPrefix(nicName, ifName[:n-1]) + } else { + matches = nicName == ifName + } + return fl.OutputInterfaceInvert != matches + } + + return true +} + +// filterAddress returns whether addr matches the filter. +func filterAddress(addr, mask, filterAddr tcpip.Address, invert bool) bool { + matches := true + for i := range filterAddr { + if addr[i]&mask[i] != filterAddr[i] { + matches = false + break + } + } + return matches != invert +} + // A Matcher is the interface for matching packets. type Matcher interface { // Name returns the name of the Matcher. @@ -183,7 +250,7 @@ type Matcher interface { // used for suspicious packets. // // Precondition: packet.NetworkHeader is set. - Match(hook Hook, packet PacketBuffer, interfaceName string) (matches bool, hotdrop bool) + Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) } // A Target is the interface for taking an action for a packet. diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 526c7d6ff..e28c23d66 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -467,8 +467,17 @@ type ndpState struct { // The default routers discovered through Router Advertisements. defaultRouters map[tcpip.Address]defaultRouterState - // The timer used to send the next router solicitation message. - rtrSolicitTimer *time.Timer + rtrSolicit struct { + // The timer used to send the next router solicitation message. + timer *time.Timer + + // Used to let the Router Solicitation timer know that it has been stopped. + // + // Must only be read from or written to while protected by the lock of + // the NIC this ndpState is associated with. MUST be set when the timer is + // set. + done *bool + } // The on-link prefixes discovered through Router Advertisements' Prefix // Information option. @@ -648,13 +657,14 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // as starting a goroutine but we use a timer that fires immediately so we can // reset it for the next DAD iteration. timer = time.AfterFunc(0, func() { - ndp.nic.mu.RLock() + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + if done { // If we reach this point, it means that the DAD timer fired after // another goroutine already obtained the NIC lock and stopped DAD // before this function obtained the NIC lock. Simply return here and do // nothing further. - ndp.nic.mu.RUnlock() return } @@ -665,15 +675,23 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref } dadDone := remaining == 0 - ndp.nic.mu.RUnlock() var err *tcpip.Error if !dadDone { - err = ndp.sendDADPacket(addr) + // Use the unspecified address as the source address when performing DAD. + ref := ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) + + // Do not hold the lock when sending packets which may be a long running + // task or may block link address resolution. We know this is safe + // because immediately after obtaining the lock again, we check if DAD + // has been stopped before doing any work with the NIC. Note, DAD would be + // stopped if the NIC was disabled or removed, or if the address was + // removed. + ndp.nic.mu.Unlock() + err = ndp.sendDADPacket(addr, ref) + ndp.nic.mu.Lock() } - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() if done { // If we reach this point, it means that DAD was stopped after we released // the NIC's read lock and before we obtained the write lock. @@ -721,17 +739,24 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // addr. // // addr must be a tentative IPv6 address on ndp's NIC. -func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { +// +// The NIC ndp belongs to MUST NOT be locked. +func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { snmc := header.SolicitedNodeAddr(addr) - // Use the unspecified address as the source address when performing DAD. - ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) - r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) + r := makeRoute(header.IPv6ProtocolNumber, ref.ep.ID().LocalAddress, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() // Route should resolve immediately since snmc is a multicast address so a // remote link address can be calculated without a resolution process. if c, err := r.Resolve(nil); err != nil { + // Do not consider the NIC being unknown or disabled as a fatal error. + // Since this method is required to be called when the NIC is not locked, + // the NIC could have been disabled or removed by another goroutine. + if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState { + return err + } + panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err)) } else if c != nil { panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID())) @@ -750,7 +775,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, PacketBuffer{Header: hdr}, + }, &PacketBuffer{Header: hdr}, ); err != nil { sent.Dropped.Increment() return err @@ -1816,7 +1841,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { // // The NIC ndp belongs to MUST be locked. func (ndp *ndpState) startSolicitingRouters() { - if ndp.rtrSolicitTimer != nil { + if ndp.rtrSolicit.timer != nil { // We are already soliciting routers. return } @@ -1833,14 +1858,27 @@ func (ndp *ndpState) startSolicitingRouters() { delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) } - ndp.rtrSolicitTimer = time.AfterFunc(delay, func() { + var done bool + ndp.rtrSolicit.done = &done + ndp.rtrSolicit.timer = time.AfterFunc(delay, func() { + ndp.nic.mu.Lock() + if done { + // If we reach this point, it means that the RS timer fired after another + // goroutine already obtained the NIC lock and stopped solicitations. + // Simply return here and do nothing further. + ndp.nic.mu.Unlock() + return + } + // As per RFC 4861 section 4.1, the source of the RS is an address assigned // to the sending interface, or the unspecified address if no address is // assigned to the sending interface. - ref := ndp.nic.primaryIPv6Endpoint(header.IPv6AllRoutersMulticastAddress) + ref := ndp.nic.primaryIPv6EndpointRLocked(header.IPv6AllRoutersMulticastAddress) if ref == nil { - ref = ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) + ref = ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) } + ndp.nic.mu.Unlock() + localAddr := ref.ep.ID().LocalAddress r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() @@ -1849,6 +1887,13 @@ func (ndp *ndpState) startSolicitingRouters() { // header.IPv6AllRoutersMulticastAddress is a multicast address so a // remote link address can be calculated without a resolution process. if c, err := r.Resolve(nil); err != nil { + // Do not consider the NIC being unknown or disabled as a fatal error. + // Since this method is required to be called when the NIC is not locked, + // the NIC could have been disabled or removed by another goroutine. + if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState { + return + } + panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err)) } else if c != nil { panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID())) @@ -1881,7 +1926,7 @@ func (ndp *ndpState) startSolicitingRouters() { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, PacketBuffer{Header: hdr}, + }, &PacketBuffer{Header: hdr}, ); err != nil { sent.Dropped.Increment() log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err) @@ -1893,17 +1938,18 @@ func (ndp *ndpState) startSolicitingRouters() { } ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() - if remaining == 0 { - ndp.rtrSolicitTimer = nil - } else if ndp.rtrSolicitTimer != nil { + if done || remaining == 0 { + ndp.rtrSolicit.timer = nil + ndp.rtrSolicit.done = nil + } else if ndp.rtrSolicit.timer != nil { // Note, we need to explicitly check to make sure that // the timer field is not nil because if it was nil but // we still reached this point, then we know the NIC // was requested to stop soliciting routers so we don't // need to send the next Router Solicitation message. - ndp.rtrSolicitTimer.Reset(ndp.configs.RtrSolicitationInterval) + ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval) } + ndp.nic.mu.Unlock() }) } @@ -1913,13 +1959,15 @@ func (ndp *ndpState) startSolicitingRouters() { // // The NIC ndp belongs to MUST be locked. func (ndp *ndpState) stopSolicitingRouters() { - if ndp.rtrSolicitTimer == nil { + if ndp.rtrSolicit.timer == nil { // Nothing to do. return } - ndp.rtrSolicitTimer.Stop() - ndp.rtrSolicitTimer = nil + *ndp.rtrSolicit.done = true + ndp.rtrSolicit.timer.Stop() + ndp.rtrSolicit.timer = nil + ndp.rtrSolicit.done = nil } // initializeTempAddrState initializes state related to temporary SLAAC diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index b3d174cdd..58f1ebf60 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -613,7 +613,7 @@ func TestDADFail(t *testing.T) { // Receive a packet to simulate multiple nodes owning or // attempting to own the same address. hdr := test.makeBuf(addr1) - e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -935,7 +935,7 @@ func TestSetNDPConfigurations(t *testing.T) { // raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options // and DHCPv6 configurations specified. -func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) stack.PacketBuffer { +func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) pkt := header.ICMPv6(hdr.Prepend(icmpSize)) @@ -970,14 +970,14 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo DstAddr: header.IPv6AllNodesMulticastAddress, }) - return stack.PacketBuffer{Data: hdr.View().ToVectorisedView()} + return &stack.PacketBuffer{Data: hdr.View().ToVectorisedView()} } // raBufWithOpts returns a valid NDP Router Advertisement with options. // // Note, raBufWithOpts does not populate any of the RA fields other than the // Router Lifetime. -func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) stack.PacketBuffer { +func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer) } @@ -986,7 +986,7 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ // // Note, raBufWithDHCPv6 does not populate any of the RA fields other than the // DHCPv6 related ones. -func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) stack.PacketBuffer { +func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) *stack.PacketBuffer { return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{}) } @@ -994,7 +994,7 @@ func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bo // // Note, raBuf does not populate any of the RA fields other than the // Router Lifetime. -func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer { +func raBuf(ip tcpip.Address, rl uint16) *stack.PacketBuffer { return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) } @@ -1003,7 +1003,7 @@ func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer { // // Note, raBufWithPI does not populate any of the RA fields other than the // Router Lifetime. -func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) stack.PacketBuffer { +func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) *stack.PacketBuffer { flags := uint8(0) if onLink { // The OnLink flag is the 7th bit in the flags byte. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 54103fdb3..644c0d437 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -457,8 +457,20 @@ type ipv6AddrCandidate struct { // remoteAddr must be a valid IPv6 address. func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint { n.mu.RLock() - defer n.mu.RUnlock() + ref := n.primaryIPv6EndpointRLocked(remoteAddr) + n.mu.RUnlock() + return ref +} +// primaryIPv6EndpointLocked returns an IPv6 endpoint following Source Address +// Selection (RFC 6724 section 5). +// +// Note, only rules 1-3 and 7 are followed. +// +// remoteAddr must be a valid IPv6 address. +// +// n.mu MUST be read locked. +func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNetworkEndpoint { primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber] if len(primaryAddrs) == 0 { @@ -568,11 +580,6 @@ const ( // promiscuous indicates that the NIC's promiscuous flag should be observed // when getting a NIC's referenced network endpoint. promiscuous - - // forceSpoofing indicates that the NIC should be assumed to be spoofing, - // regardless of what the NIC's spoofing flag is when getting a NIC's - // referenced network endpoint. - forceSpoofing ) func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { @@ -591,8 +598,6 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A // or spoofing. Promiscuous mode will only be checked if promiscuous is true. // Similarly, spoofing will only be checked if spoofing is true. func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint { - id := NetworkEndpointID{address} - n.mu.RLock() var spoofingOrPromiscuous bool @@ -601,11 +606,9 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t spoofingOrPromiscuous = n.mu.spoofing case promiscuous: spoofingOrPromiscuous = n.mu.promiscuous - case forceSpoofing: - spoofingOrPromiscuous = true } - if ref, ok := n.mu.endpoints[id]; ok { + if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { // An endpoint with this id exists, check if it can be used and return it. switch ref.getKind() { case permanentExpired: @@ -654,11 +657,18 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // endpoint, create a new "temporary" endpoint. It will only exist while // there's a route through it. n.mu.Lock() - if ref, ok := n.mu.endpoints[id]; ok { + ref := n.getRefOrCreateTempLocked(protocol, address, peb) + n.mu.Unlock() + return ref +} + +/// getRefOrCreateTempLocked returns an existing endpoint for address or creates +/// and returns a temporary endpoint. +func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { + if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { // No need to check the type as we are ok with expired endpoints at this // point. if ref.tryIncRef() { - n.mu.Unlock() return ref } // tryIncRef failing means the endpoint is scheduled to be removed once the @@ -670,7 +680,6 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // Add a new temporary endpoint. netProto, ok := n.stack.networkProtocols[protocol] if !ok { - n.mu.Unlock() return nil } ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{ @@ -681,7 +690,6 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t }, }, peb, temporary, static, false) - n.mu.Unlock() return ref } @@ -1153,7 +1161,7 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { return joins != 0 } -func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt PacketBuffer) { +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt *PacketBuffer) { r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr @@ -1167,7 +1175,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. -func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { n.mu.RLock() enabled := n.mu.enabled // If the NIC is not yet enabled, don't receive any packets. @@ -1212,12 +1220,21 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.PacketsReceived.Increment() } - netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) + // Parse headers. + transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) if !ok { + // The packet is too small to contain a network header. n.stack.stats.MalformedRcvdPackets.Increment() return } - src, dst := netProto.ParseAddresses(netHeader) + if hasTransportHdr { + // Parse the transport header if present. + if state, ok := n.stack.transportProtocols[transProtoNum]; ok { + state.proto.Parse(pkt) + } + } + + src, dst := netProto.ParseAddresses(pkt.NetworkHeader) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1229,18 +1246,19 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. - if protocol == header.IPv4ProtocolNumber { + // Loopback traffic skips the prerouting chain. + if protocol == header.IPv4ProtocolNumber && !n.isLoopback() { // iptables filtering. ipt := n.stack.IPTables() address := n.primaryAddress(protocol) - if ok := ipt.Check(Prerouting, &pkt, nil, nil, address.Address, ""); !ok { + if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok { // iptables is telling us to drop the packet. return } } if ref := n.getRef(protocol, dst); ref != nil { - handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt) + handlePacket(protocol, dst, src, n.linkEP.LinkAddress(), remote, ref, pkt) return } @@ -1298,24 +1316,37 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } } -func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { - pkt.Header = buffer.NewPrependable(linkHeaderLen) + // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this + // by having lower layers explicity write each header instead of just + // pkt.Header. + + // pkt may have set its NetworkHeader and TransportHeader. If we're + // forwarding, we'll have to copy them into pkt.Header. + pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength()) + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) + if n := copy(pkt.Header.Prepend(len(pkt.TransportHeader)), pkt.TransportHeader); n != len(pkt.TransportHeader) { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.TransportHeader))) + } + if n := copy(pkt.Header.Prepend(len(pkt.NetworkHeader)), pkt.NetworkHeader); n != len(pkt.NetworkHeader) { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.NetworkHeader))) } + // WritePacket takes ownership of pkt, calculate numBytes first. + numBytes := pkt.Header.UsedLength() + pkt.Data.Size() + if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return } n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) + n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) } // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -1329,13 +1360,31 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) - if !ok { + // TransportHeader is nil only when pkt is an ICMP packet or was reassembled + // from fragments. + if pkt.TransportHeader == nil { + // TODO(gvisor.dev/issue/170): ICMP packets don't have their + // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + // full explanation. + if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { + transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) + if !ok { + n.stack.stats.MalformedRcvdPackets.Increment() + return + } + pkt.TransportHeader = transHeader + } else { + // This is either a bad packet or was re-assembled from fragments. + transProto.Parse(pkt) + } + } + + if len(pkt.TransportHeader) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(transHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -1362,7 +1411,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // DeliverTransportControlPacket delivers control packets to the appropriate // transport protocol endpoint. -func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) { +func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index d672fc157..31f865260 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -15,11 +15,268 @@ package stack import ( + "math" "testing" + "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" ) +var _ LinkEndpoint = (*testLinkEndpoint)(nil) + +// A LinkEndpoint that throws away outgoing packets. +// +// We use this instead of the channel endpoint as the channel package depends on +// the stack package which this test lives in, causing a cyclic dependency. +type testLinkEndpoint struct { + dispatcher NetworkDispatcher +} + +// Attach implements LinkEndpoint.Attach. +func (e *testLinkEndpoint) Attach(dispatcher NetworkDispatcher) { + e.dispatcher = dispatcher +} + +// IsAttached implements LinkEndpoint.IsAttached. +func (e *testLinkEndpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// MTU implements LinkEndpoint.MTU. +func (*testLinkEndpoint) MTU() uint32 { + return math.MaxUint16 +} + +// Capabilities implements LinkEndpoint.Capabilities. +func (*testLinkEndpoint) Capabilities() LinkEndpointCapabilities { + return CapabilityResolutionRequired +} + +// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. +func (*testLinkEndpoint) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress returns the link address of this endpoint. +func (*testLinkEndpoint) LinkAddress() tcpip.LinkAddress { + return "" +} + +// Wait implements LinkEndpoint.Wait. +func (*testLinkEndpoint) Wait() {} + +// WritePacket implements LinkEndpoint.WritePacket. +func (e *testLinkEndpoint) WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error { + return nil +} + +// WritePackets implements LinkEndpoint.WritePackets. +func (e *testLinkEndpoint) WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + // Our tests don't use this so we don't support it. + return 0, tcpip.ErrNotSupported +} + +// WriteRawPacket implements LinkEndpoint.WriteRawPacket. +func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { + // Our tests don't use this so we don't support it. + return tcpip.ErrNotSupported +} + +var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) + +// An IPv6 NetworkEndpoint that throws away outgoing packets. +// +// We use this instead of ipv6.endpoint because the ipv6 package depends on +// the stack package which this test lives in, causing a cyclic dependency. +type testIPv6Endpoint struct { + nicID tcpip.NICID + id NetworkEndpointID + prefixLen int + linkEP LinkEndpoint + protocol *testIPv6Protocol +} + +// DefaultTTL implements NetworkEndpoint.DefaultTTL. +func (*testIPv6Endpoint) DefaultTTL() uint8 { + return 0 +} + +// MTU implements NetworkEndpoint.MTU. +func (e *testIPv6Endpoint) MTU() uint32 { + return e.linkEP.MTU() - header.IPv6MinimumSize +} + +// Capabilities implements NetworkEndpoint.Capabilities. +func (e *testIPv6Endpoint) Capabilities() LinkEndpointCapabilities { + return e.linkEP.Capabilities() +} + +// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength. +func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { + return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize +} + +// WritePacket implements NetworkEndpoint.WritePacket. +func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) *tcpip.Error { + return nil +} + +// WritePackets implements NetworkEndpoint.WritePackets. +func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, *tcpip.Error) { + // Our tests don't use this so we don't support it. + return 0, tcpip.ErrNotSupported +} + +// WriteHeaderIncludedPacket implements +// NetworkEndpoint.WriteHeaderIncludedPacket. +func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip.Error { + // Our tests don't use this so we don't support it. + return tcpip.ErrNotSupported +} + +// ID implements NetworkEndpoint.ID. +func (e *testIPv6Endpoint) ID() *NetworkEndpointID { + return &e.id +} + +// PrefixLen implements NetworkEndpoint.PrefixLen. +func (e *testIPv6Endpoint) PrefixLen() int { + return e.prefixLen +} + +// NICID implements NetworkEndpoint.NICID. +func (e *testIPv6Endpoint) NICID() tcpip.NICID { + return e.nicID +} + +// HandlePacket implements NetworkEndpoint.HandlePacket. +func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) { +} + +// Close implements NetworkEndpoint.Close. +func (*testIPv6Endpoint) Close() {} + +// NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber. +func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return header.IPv6ProtocolNumber +} + +var _ NetworkProtocol = (*testIPv6Protocol)(nil) + +// An IPv6 NetworkProtocol that supports the bare minimum to make a stack +// believe it supports IPv6. +// +// We use this instead of ipv6.protocol because the ipv6 package depends on +// the stack package which this test lives in, causing a cyclic dependency. +type testIPv6Protocol struct{} + +// Number implements NetworkProtocol.Number. +func (*testIPv6Protocol) Number() tcpip.NetworkProtocolNumber { + return header.IPv6ProtocolNumber +} + +// MinimumPacketSize implements NetworkProtocol.MinimumPacketSize. +func (*testIPv6Protocol) MinimumPacketSize() int { + return header.IPv6MinimumSize +} + +// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen. +func (*testIPv6Protocol) DefaultPrefixLen() int { + return header.IPv6AddressSize * 8 +} + +// ParseAddresses implements NetworkProtocol.ParseAddresses. +func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + h := header.IPv6(v) + return h.SourceAddress(), h.DestinationAddress() +} + +// NewEndpoint implements NetworkProtocol.NewEndpoint. +func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { + return &testIPv6Endpoint{ + nicID: nicID, + id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, + prefixLen: addrWithPrefix.PrefixLen, + linkEP: linkEP, + protocol: p, + }, nil +} + +// SetOption implements NetworkProtocol.SetOption. +func (*testIPv6Protocol) SetOption(interface{}) *tcpip.Error { + return nil +} + +// Option implements NetworkProtocol.Option. +func (*testIPv6Protocol) Option(interface{}) *tcpip.Error { + return nil +} + +// Close implements NetworkProtocol.Close. +func (*testIPv6Protocol) Close() {} + +// Wait implements NetworkProtocol.Wait. +func (*testIPv6Protocol) Wait() {} + +// Parse implements NetworkProtocol.Parse. +func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { + return 0, false, false +} + +var _ LinkAddressResolver = (*testIPv6Protocol)(nil) + +// LinkAddressProtocol implements LinkAddressResolver. +func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return header.IPv6ProtocolNumber +} + +// LinkAddressRequest implements LinkAddressResolver. +func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { + return nil +} + +// ResolveStaticAddress implements LinkAddressResolver. +func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if header.IsV6MulticastAddress(addr) { + return header.EthernetAddressFromMulticastIPv6Address(addr), true + } + return "", false +} + +// Test the race condition where a NIC is removed and an RS timer fires at the +// same time. +func TestRemoveNICWhileHandlingRSTimer(t *testing.T) { + const ( + nicID = 1 + + maxRtrSolicitations = 5 + ) + + e := testLinkEndpoint{} + s := New(Options{ + NetworkProtocols: []NetworkProtocol{&testIPv6Protocol{}}, + NDPConfigs: NDPConfigurations{ + MaxRtrSolicitations: maxRtrSolicitations, + RtrSolicitationInterval: minimumRtrSolicitationInterval, + }, + }) + + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("s.CreateNIC(%d, _) = %s", nicID, err) + } + + s.mu.Lock() + // Wait for the router solicitation timer to fire and block trying to obtain + // the stack lock when doing link address resolution. + time.Sleep(minimumRtrSolicitationInterval * 2) + if err := s.removeNICLocked(nicID); err != nil { + t.Fatalf("s.removeNICLocked(%d) = %s", nicID, err) + } + s.mu.Unlock() +} + func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. @@ -44,7 +301,7 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { t.FailNow() } - nic.DeliverNetworkPacket(nil, "", "", 0, PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) + nic.DeliverNetworkPacket("", "", 0, &PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 926df4d7b..1b5da6017 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -24,6 +24,8 @@ import ( // multiple endpoints. Clone() should be called in such cases so that // modifications to the Data field do not affect other copies. type PacketBuffer struct { + _ noCopy + // PacketBufferEntry is used to build an intrusive list of // PacketBuffers. PacketBufferEntry @@ -82,7 +84,32 @@ type PacketBuffer struct { // VectorisedView but does not deep copy the underlying bytes. // // Clone also does not deep copy any of its other fields. -func (pk PacketBuffer) Clone() PacketBuffer { - pk.Data = pk.Data.Clone(nil) - return pk +// +// FIXME(b/153685824): Data gets copied but not other header references. +func (pk *PacketBuffer) Clone() *PacketBuffer { + return &PacketBuffer{ + PacketBufferEntry: pk.PacketBufferEntry, + Data: pk.Data.Clone(nil), + Header: pk.Header, + LinkHeader: pk.LinkHeader, + NetworkHeader: pk.NetworkHeader, + TransportHeader: pk.TransportHeader, + Hash: pk.Hash, + Owner: pk.Owner, + EgressRoute: pk.EgressRoute, + GSOOptions: pk.GSOOptions, + NetworkProtocolNumber: pk.NetworkProtocolNumber, + NatDone: pk.NatDone, + } } + +// noCopy may be embedded into structs which must not be copied +// after the first use. +// +// See https://golang.org/issues/8005#issuecomment-190753527 +// for details. +type noCopy struct{} + +// Lock is a no-op used by -copylocks checker from `go vet`. +func (*noCopy) Lock() {} +func (*noCopy) Unlock() {} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index b331427c6..5cbc946b6 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -67,12 +67,12 @@ type TransportEndpoint interface { // this transport endpoint. It sets pkt.TransportHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) + HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. // HandleControlPacket takes ownership of pkt. - HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) + HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) // Abort initiates an expedited endpoint teardown. It puts the endpoint // in a closed state and frees all resources associated with it. This @@ -100,7 +100,7 @@ type RawTransportEndpoint interface { // layer up. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt PacketBuffer) + HandlePacket(r *Route, pkt *PacketBuffer) } // PacketEndpoint is the interface that needs to be implemented by packet @@ -118,7 +118,7 @@ type PacketEndpoint interface { // should construct its own ethernet header for applications. // // HandlePacket takes ownership of pkt. - HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt PacketBuffer) + HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // TransportProtocol is the interface that needs to be implemented by transport @@ -150,7 +150,7 @@ type TransportProtocol interface { // stats purposes only). // // HandleUnknownDestinationPacket takes ownership of pkt. - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt PacketBuffer) bool + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -168,6 +168,11 @@ type TransportProtocol interface { // Wait waits for any worker goroutines owned by the protocol to stop. Wait() + + // Parse sets pkt.TransportHeader and trims pkt.Data appropriately. It does + // neither and returns false if pkt.Data is too small, i.e. pkt.Data.Size() < + // MinimumPacketSize() + Parse(pkt *PacketBuffer) (ok bool) } // TransportDispatcher contains the methods used by the network stack to deliver @@ -180,7 +185,7 @@ type TransportDispatcher interface { // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // // DeliverTransportPacket takes ownership of pkt. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. @@ -189,7 +194,7 @@ type TransportDispatcher interface { // DeliverTransportControlPacket. // // DeliverTransportControlPacket takes ownership of pkt. - DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) + DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -240,17 +245,18 @@ type NetworkEndpoint interface { MaxHeaderLength() uint16 // WritePacket writes a packet to the given destination address and - // protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have - // already been set. - WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error + // protocol. It takes ownership of pkt. pkt.TransportHeader must have already + // been set. + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and - // protocol. pkts must not be zero length. + // protocol. pkts must not be zero length. It takes ownership of pkts and + // underlying packets. WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network - // header to the given destination address. - WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error + // header to the given destination address. It takes ownership of pkt. + WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID @@ -265,7 +271,7 @@ type NetworkEndpoint interface { // this network endpoint. It sets pkt.NetworkHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt PacketBuffer) + HandlePacket(r *Route, pkt *PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() @@ -312,6 +318,14 @@ type NetworkProtocol interface { // Wait waits for any worker goroutines owned by the protocol to stop. Wait() + + // Parse sets pkt.NetworkHeader and trims pkt.Data appropriately. It + // returns: + // - The encapsulated protocol, if present. + // - Whether there is an encapsulated transport protocol payload (e.g. ARP + // does not encapsulate anything). + // - Whether pkt.Data was large enough to parse and set pkt.NetworkHeader. + Parse(pkt *PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) } // NetworkDispatcher contains the methods used by the network stack to deliver @@ -326,7 +340,7 @@ type NetworkDispatcher interface { // packets sent via loopback), and won't have the field set. // // DeliverNetworkPacket takes ownership of pkt. - DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) + DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -382,17 +396,17 @@ type LinkEndpoint interface { LinkAddress() tcpip.LinkAddress // WritePacket writes a packet with the given protocol through the - // given route. It sets pkt.LinkHeader if a link layer header exists. - // pkt.NetworkHeader and pkt.TransportHeader must have already been - // set. + // given route. It takes ownership of pkt. pkt.NetworkHeader and + // pkt.TransportHeader must have already been set. // // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. - WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error // WritePackets writes packets with the given protocol through the - // given route. pkts must not be zero length. + // given route. pkts must not be zero length. It takes ownership of pkts and + // underlying packets. // // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, it may @@ -400,7 +414,7 @@ type LinkEndpoint interface { WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) // WriteRawPacket writes a packet directly to the link. The packet - // should already have an ethernet header. + // should already have an ethernet header. It takes ownership of vv. WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error // Attach attaches the data link layer endpoint to the network-layer @@ -430,7 +444,7 @@ type InjectableLinkEndpoint interface { LinkEndpoint // InjectInbound injects an inbound packet. - InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) + InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) // InjectOutbound writes a fully formed outbound packet directly to the // link. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 150297ab9..d65f8049e 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -113,6 +113,8 @@ func (r *Route) GSOMaxSize() uint32 { // If address resolution is required, ErrNoLinkAddress and a notification channel is // returned for the top level caller to block. Channel is closed once address resolution // is complete (success or not). +// +// The NIC r uses must not be locked. func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { if !r.IsResolutionRequired() { // Nothing to do if there is no cache (which does the resolution on cache miss) or @@ -148,22 +150,27 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { // IsResolutionRequired returns true if Resolve() must be called to resolve // the link address before the this route can be written to. +// +// The NIC r uses must not be locked. func (r *Route) IsResolutionRequired() bool { return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == "" } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } + // WritePacket takes ownership of pkt, calculate numBytes first. + numBytes := pkt.Header.UsedLength() + pkt.Data.Size() + err := r.ref.ep.WritePacket(r, gso, params, pkt) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) } return err } @@ -175,9 +182,12 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead return 0, tcpip.ErrInvalidEndpointState } + // WritePackets takes ownership of pkt, calculate length first. + numPkts := pkts.Len() + n, err := r.ref.ep.WritePackets(r, gso, pkts, params) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(numPkts - n)) } r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n)) @@ -193,17 +203,20 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (r *Route) WriteHeaderIncludedPacket(pkt PacketBuffer) *tcpip.Error { +func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } + // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first. + numBytes := pkt.Data.Size() + if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Data.Size())) + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) return nil } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 0ab4c3e19..648791302 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -52,7 +52,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, pkt PacketBuffer) bool + defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -424,12 +424,8 @@ type Stack struct { // handleLocal allows non-loopback interfaces to loop packets. handleLocal bool - // tablesMu protects iptables. - tablesMu sync.RWMutex - - // tables are the iptables packet filtering and manipulation rules. The are - // protected by tablesMu.` - tables IPTables + // tables are the iptables packet filtering and manipulation rules. + tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the // stack is being restored. @@ -676,6 +672,7 @@ func New(opts Options) *Stack { clock: clock, stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, + tables: DefaultTables(), icmpRateLimiter: NewICMPRateLimiter(), seed: generateRandUint32(), ndpConfigs: opts.NDPConfigs, @@ -778,7 +775,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, PacketBuffer) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h @@ -1020,6 +1017,13 @@ func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() + return s.removeNICLocked(id) +} + +// removeNICLocked removes NIC and all related routes from the network stack. +// +// s.mu must be locked. +func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error { nic, ok := s.nics[id] if !ok { return tcpip.ErrUnknownNICID @@ -1741,18 +1745,8 @@ func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, } // IPTables returns the stack's iptables. -func (s *Stack) IPTables() IPTables { - s.tablesMu.RLock() - t := s.tables - s.tablesMu.RUnlock() - return t -} - -// SetIPTables sets the stack's iptables. -func (s *Stack) SetIPTables(ipt IPTables) { - s.tablesMu.Lock() - s.tables = ipt - s.tablesMu.Unlock() +func (s *Stack) IPTables() *IPTables { + return s.tables } // ICMPLimit returns the maximum number of ICMP messages that can be sent diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 1a2cf007c..ffef9bc2c 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -52,6 +52,10 @@ const ( // where another value is explicitly used. It is chosen to match the MTU // of loopback interfaces on linux systems. defaultMTU = 65536 + + dstAddrOffset = 0 + srcAddrOffset = 1 + protocolNumberOffset = 2 ) // fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and @@ -90,30 +94,28 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { return &f.id } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { +func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ - // Consume the network header. - b, ok := pkt.Data.PullUp(fakeNetHeaderLen) - if !ok { - return - } - pkt.Data.TrimFront(fakeNetHeaderLen) - // Handle control packets. - if b[2] == uint8(fakeControlProtocol) { + if pkt.NetworkHeader[protocolNumberOffset] == uint8(fakeControlProtocol) { nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) if !ok { return } pkt.Data.TrimFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) + f.dispatcher.DeliverTransportControlPacket( + tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), + tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), + fakeNetNumber, + tcpip.TransportProtocolNumber(nb[protocolNumberOffset]), + stack.ControlPortUnreachable, 0, pkt) return } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -132,24 +134,19 @@ func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumbe return f.proto.Number() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ // Add the protocol's header to the packet and send it to the link // endpoint. - b := pkt.Header.Prepend(fakeNetHeaderLen) - b[0] = r.RemoteAddress[0] - b[1] = f.id.LocalAddress[0] - b[2] = byte(params.Protocol) + pkt.NetworkHeader = pkt.Header.Prepend(fakeNetHeaderLen) + pkt.NetworkHeader[dstAddrOffset] = r.RemoteAddress[0] + pkt.NetworkHeader[srcAddrOffset] = f.id.LocalAddress[0] + pkt.NetworkHeader[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) - f.HandlePacket(r, stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) + f.HandlePacket(r, pkt) } if r.Loop&stack.PacketOut == 0 { return nil @@ -163,7 +160,7 @@ func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -205,7 +202,7 @@ func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { } func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) + return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { @@ -247,6 +244,17 @@ func (*fakeNetworkProtocol) Close() {} // Wait implements TransportProtocol.Wait. func (*fakeNetworkProtocol) Wait() {} +// Parse implements TransportProtocol.Parse. +func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { + hdr, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { + return 0, false, false + } + pkt.NetworkHeader = hdr + pkt.Data.TrimFront(fakeNetHeaderLen) + return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true +} + func fakeNetFactory() stack.NetworkProtocol { return &fakeNetworkProtocol{} } @@ -292,8 +300,8 @@ func TestNetworkReceive(t *testing.T) { buf := buffer.NewView(30) // Make sure packet with wrong address is not delivered. - buf[0] = 3 - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + buf[dstAddrOffset] = 3 + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 0 { @@ -304,8 +312,8 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is delivered to first endpoint. - buf[0] = 1 - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + buf[dstAddrOffset] = 1 + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -316,8 +324,8 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is delivered to second endpoint. - buf[0] = 2 - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + buf[dstAddrOffset] = 2 + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -328,7 +336,7 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - ep.InjectInbound(fakeNetNumber-1, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber-1, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -340,7 +348,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -362,7 +370,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }) @@ -420,7 +428,7 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if got := fakeNet.PacketCount(localAddrByte); got != want { @@ -982,7 +990,7 @@ func TestAddressRemoval(t *testing.T) { buf := buffer.NewView(30) // Send and receive packets, and verify they are received. - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) @@ -1032,7 +1040,7 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { } // Send and receive packets, and verify they are received. - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte testRecv(t, fakeNet, localAddrByte, ep, buf) testSend(t, r, ep, nil) testSendTo(t, s, remoteAddr, ep, nil) @@ -1114,7 +1122,7 @@ func TestEndpointExpiration(t *testing.T) { fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte if promiscuous { if err := s.SetPromiscuousMode(nicID, true); err != nil { @@ -1277,7 +1285,7 @@ func TestPromiscuousMode(t *testing.T) { // Write a packet, and check that it doesn't get delivered as we don't // have a matching endpoint. const localAddrByte byte = 0x01 - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte testFailingRecv(t, fakeNet, localAddrByte, ep, buf) // Set promiscuous mode, then check that packet is delivered. @@ -1658,7 +1666,7 @@ func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { buf := buffer.NewView(30) const localAddrByte byte = 0x01 - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) @@ -1766,7 +1774,7 @@ func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { buf := buffer.NewView(30) const localAddrByte byte = 0x01 - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) @@ -2263,7 +2271,7 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { @@ -2344,8 +2352,8 @@ func TestNICForwarding(t *testing.T) { // Send a packet to dstAddr. buf := buffer.NewView(30) - buf[0] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + buf[dstAddrOffset] = dstAddr[0] + ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 9a33ed375..e09866405 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -152,7 +152,7 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) { +func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) { epsByNIC.mu.RLock() mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] @@ -183,7 +183,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) { +func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -251,7 +251,7 @@ type transportDemuxer struct { // the dispatcher to delivery packets to the QueuePacket method instead of // calling HandlePacket directly on the endpoint. type queuedTransportProtocol interface { - QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt PacketBuffer) + QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { @@ -379,7 +379,7 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[idx] } -func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt PacketBuffer) { +func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) { ep.mu.RLock() queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] // HandlePacket takes ownership of pkt, so each endpoint needs @@ -470,7 +470,7 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if // the packet no longer needs to be handled. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -520,7 +520,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) bool { +func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -544,7 +544,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 2474a7db3..67d778137 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -127,7 +127,7 @@ func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), @@ -165,7 +165,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index a611e44ab..ad61c09d6 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -83,12 +83,13 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return 0, nil, tcpip.ErrNoRoute } - hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) + hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()) + fakeTransHeaderLen) + hdr.Prepend(fakeTransHeaderLen) v, err := p.FullPayload() if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: buffer.View(v).ToVectorisedView(), }); err != nil { @@ -215,7 +216,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { @@ -232,7 +233,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE } } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } @@ -289,7 +290,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { return true } @@ -324,6 +325,17 @@ func (*fakeTransportProtocol) Close() {} // Wait implements TransportProtocol.Wait. func (*fakeTransportProtocol) Wait() {} +// Parse implements TransportProtocol.Parse. +func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { + hdr, ok := pkt.Data.PullUp(fakeTransHeaderLen) + if !ok { + return false + } + pkt.TransportHeader = hdr + pkt.Data.TrimFront(fakeTransHeaderLen) + return true +} + func fakeTransFactory() stack.TransportProtocol { return &fakeTransportProtocol{} } @@ -369,7 +381,7 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 0 { @@ -380,7 +392,7 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 0 { @@ -391,7 +403,7 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 1 { @@ -446,7 +458,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 0 { @@ -457,7 +469,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 0 { @@ -468,7 +480,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 1 { @@ -623,7 +635,7 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep2.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: req.ToVectorisedView(), }) @@ -642,11 +654,10 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - hdrs := p.Pkt.Data.ToView() - if dst := hdrs[0]; dst != 3 { + if dst := p.Pkt.NetworkHeader[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := hdrs[1]; src != 1 { + if src := p.Pkt.NetworkHeader[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 45e930ad8..b7b227328 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -110,6 +110,71 @@ var ( ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"} ) +var messageToError map[string]*Error + +var populate sync.Once + +// StringToError converts an error message to the error. +func StringToError(s string) *Error { + populate.Do(func() { + var errors = []*Error{ + ErrUnknownProtocol, + ErrUnknownNICID, + ErrUnknownDevice, + ErrUnknownProtocolOption, + ErrDuplicateNICID, + ErrDuplicateAddress, + ErrNoRoute, + ErrBadLinkEndpoint, + ErrAlreadyBound, + ErrInvalidEndpointState, + ErrAlreadyConnecting, + ErrAlreadyConnected, + ErrNoPortAvailable, + ErrPortInUse, + ErrBadLocalAddress, + ErrClosedForSend, + ErrClosedForReceive, + ErrWouldBlock, + ErrConnectionRefused, + ErrTimeout, + ErrAborted, + ErrConnectStarted, + ErrDestinationRequired, + ErrNotSupported, + ErrQueueSizeNotSupported, + ErrNotConnected, + ErrConnectionReset, + ErrConnectionAborted, + ErrNoSuchFile, + ErrInvalidOptionValue, + ErrNoLinkAddress, + ErrBadAddress, + ErrNetworkUnreachable, + ErrMessageTooLong, + ErrNoBufferSpace, + ErrBroadcastDisabled, + ErrNotPermitted, + ErrAddressFamilyNotSupported, + } + + messageToError = make(map[string]*Error) + for _, e := range errors { + if messageToError[e.String()] != nil { + panic("tcpip errors with duplicated message: " + e.String()) + } + messageToError[e.String()] = e + } + }) + + e, ok := messageToError[s] + if !ok { + panic("unknown error message: " + s) + } + + return e +} + // Errors related to Subnet var ( errSubnetLengthMismatch = errors.New("subnet length of address and mask differ") diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go index 2f98a996f..7f172f978 100644 --- a/pkg/tcpip/time_unsafe.go +++ b/pkg/tcpip/time_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.9 -// +build !go1.15 +// +build !go1.16 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index b1d820372..57e0a069b 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -140,11 +140,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { @@ -450,7 +445,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: data.ToVectorisedView(), TransportHeader: buffer.View(icmpv4), @@ -481,7 +476,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: dataVV, TransportHeader: buffer.View(icmpv6), @@ -511,6 +506,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { nicID := addr.NIC localPort := uint16(0) switch e.state { + case stateInitial: case stateBound, stateConnected: localPort = e.ID.LocalPort if e.BindNICID == 0 { @@ -743,7 +739,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: @@ -805,7 +801,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { } // State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 3c47692b2..74ef6541e 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -104,7 +104,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool { +func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { return true } @@ -124,6 +124,16 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) bool { + // TODO(gvisor.dev/issue/170): Implement parsing of ICMP. + // + // Right now, the Parse() method is tied to enabled protocols passed into + // stack.New. This works for UDP and TCP, but we handle ICMP traffic even + // when netstack users don't pass ICMP as a supported protocol. + return false +} + // NewProtocol4 returns an ICMPv4 transport protocol. func NewProtocol4() stack.TransportProtocol { return &protocol{ProtocolNumber4} diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 23158173d..baf08eda6 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -132,11 +132,6 @@ func (ep *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (ep *endpoint) ModerateRecvBuf(copied int) {} -// IPTables implements tcpip.Endpoint.IPTables. -func (ep *endpoint) IPTables() (stack.IPTables, error) { - return ep.stack.IPTables(), nil -} - // Read implements tcpip.Endpoint.Read. func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { ep.rcvMu.Lock() @@ -298,7 +293,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // HandlePacket implements stack.PacketEndpoint.HandlePacket. -func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { ep.rcvMu.Lock() // Drop the packet if our buffer is currently full. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index eee754a5a..a406d815e 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -166,11 +166,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read implements tcpip.Endpoint.Read. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { if !e.associated { @@ -348,7 +343,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, switch e.NetProto { case header.IPv4ProtocolNumber: if !e.associated { - if err := route.WriteHeaderIncludedPacket(stack.PacketBuffer{ + if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ Data: buffer.View(payloadBytes).ToVectorisedView(), }); err != nil { return 0, nil, err @@ -357,7 +352,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: buffer.View(payloadBytes).ToVectorisedView(), Owner: e.owner, @@ -584,7 +579,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. -func (e *endpoint) HandlePacket(route *stack.Route, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full. @@ -632,8 +627,9 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt stack.PacketBuffer) { }, } - networkHeader := append(buffer.View(nil), pkt.NetworkHeader...) - combinedVV := networkHeader.ToVectorisedView() + headers := append(buffer.View(nil), pkt.NetworkHeader...) + headers = append(headers, pkt.TransportHeader...) + combinedVV := headers.ToVectorisedView() combinedVV.Append(pkt.Data) packet.data = combinedVV packet.timestampNS = e.stack.NowNanoseconds() diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index f38eb6833..e26f01fae 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -86,10 +86,6 @@ go_test( "tcp_test.go", "tcp_timestamp_test.go", ], - # FIXME(b/68809571) - tags = [ - "flaky", - ], deps = [ ":tcp", "//pkg/sync", diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index a7e088d4e..7da93dcc4 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -833,13 +833,13 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac return sendTCPBatch(r, tf, data, gso, owner) } - pkt := stack.PacketBuffer{ + pkt := &stack.PacketBuffer{ Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), Data: data, Hash: tf.txHash, Owner: owner, } - buildTCPHdr(r, tf, &pkt, gso) + buildTCPHdr(r, tf, pkt, gso) if tf.ttl == 0 { tf.ttl = r.DefaultTTL() @@ -1347,6 +1347,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.setEndpointState(StateError) e.HardError = err + e.workerCleanup = true // Lock released below. epilogue() return err diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 6062ca916..047704c80 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -186,7 +186,7 @@ func (d *dispatcher) wait() { } } -func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { ep := stackEP.(*endpoint) s := newSegment(r, id, pkt) if !s.parse() { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index b5ba972f1..19f7bf449 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -63,7 +63,8 @@ const ( StateClosing ) -// connected is the set of states where an endpoint is connected to a peer. +// connected returns true when s is one of the states representing an +// endpoint connected to a peer. func (s EndpointState) connected() bool { switch s { case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: @@ -73,6 +74,40 @@ func (s EndpointState) connected() bool { } } +// connecting returns true when s is one of the states representing a +// connection in progress, but not yet fully established. +func (s EndpointState) connecting() bool { + switch s { + case StateConnecting, StateSynSent, StateSynRecv: + return true + default: + return false + } +} + +// handshake returns true when s is one of the states representing an endpoint +// in the middle of a TCP handshake. +func (s EndpointState) handshake() bool { + switch s { + case StateSynSent, StateSynRecv: + return true + default: + return false + } +} + +// closed returns true when s is one of the states an endpoint transitions to +// when closed or when it encounters an error. This is distinct from a newly +// initialized endpoint that was never connected. +func (s EndpointState) closed() bool { + switch s { + case StateClose, StateError: + return true + default: + return false + } +} + // String implements fmt.Stringer.String. func (s EndpointState) String() string { switch s { @@ -1172,11 +1207,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() @@ -2462,7 +2492,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { }, nil } -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // TCP HandlePacket is not required anymore as inbound packets first // land at the Dispatcher which then can either delivery using the // worker go routine or directly do the invoke the tcp processing inline @@ -2481,7 +2511,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool { } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { switch typ { case stack.ControlPacketTooBig: e.sndBufMu.Lock() diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 8b7562396..cbb779666 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -49,11 +49,10 @@ func (e *endpoint) beforeSave() { e.mu.Lock() defer e.mu.Unlock() - switch e.EndpointState() { - case StateInitial, StateBound: - // TODO(b/138137272): this enumeration duplicates - // EndpointState.connected. remove it. - case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: + epState := e.EndpointState() + switch { + case epState == StateInitial || epState == StateBound: + case epState.connected() || epState.handshake(): if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)}) @@ -69,15 +68,16 @@ func (e *endpoint) beforeSave() { break } fallthrough - case StateListen, StateConnecting: + case epState == StateListen || epState == StateConnecting: e.drainSegmentLocked() - if e.EndpointState() != StateClose && e.EndpointState() != StateError { + // Refresh epState, since drainSegmentLocked may have changed it. + epState = e.EndpointState() + if !epState.closed() { if !e.workerRunning { panic("endpoint has no worker running in listen, connecting, or connected state") } - break } - case StateError, StateClose: + case epState.closed(): for e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) @@ -148,23 +148,23 @@ var connectingLoading sync.WaitGroup // Bound endpoint loading happens last. // loadState is invoked by stateify. -func (e *endpoint) loadState(state EndpointState) { +func (e *endpoint) loadState(epState EndpointState) { // This is to ensure that the loading wait groups include all applicable // endpoints before any asynchronous calls to the Wait() methods. // For restore purposes we treat TimeWait like a connected endpoint. - if state.connected() || state == StateTimeWait { + if epState.connected() || epState == StateTimeWait { connectedLoading.Add(1) } - switch state { - case StateListen: + switch { + case epState == StateListen: listenLoading.Add(1) - case StateConnecting, StateSynSent, StateSynRecv: + case epState.connecting(): connectingLoading.Add(1) } // Directly update the state here rather than using e.setEndpointState // as the endpoint is still being loaded and the stack reference is not // yet initialized. - atomic.StoreUint32((*uint32)(&e.state), uint32(state)) + atomic.StoreUint32((*uint32)(&e.state), uint32(epState)) } // afterLoad is invoked by stateify. @@ -183,8 +183,8 @@ func (e *endpoint) afterLoad() { func (e *endpoint) Resume(s *stack.Stack) { e.stack = s e.segmentQueue.setLimit(MaxUnprocessedSegments) - state := e.origEndpointState - switch state { + epState := e.origEndpointState + switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { @@ -208,8 +208,8 @@ func (e *endpoint) Resume(s *stack.Stack) { } } - switch state { - case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: + switch { + case epState.connected(): bind() if len(e.connectingAddress) == 0 { e.connectingAddress = e.ID.RemoteAddress @@ -232,13 +232,13 @@ func (e *endpoint) Resume(s *stack.Stack) { closed := e.closed e.mu.Unlock() e.notifyProtocolGoroutine(notifyTickleWorker) - if state == StateFinWait2 && closed { + if epState == StateFinWait2 && closed { // If the endpoint has been closed then make sure we notify so // that the FIN_WAIT2 timer is started after a restore. e.notifyProtocolGoroutine(notifyClose) } connectedLoading.Done() - case StateListen: + case epState == StateListen: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -255,7 +255,7 @@ func (e *endpoint) Resume(s *stack.Stack) { listenLoading.Done() tcpip.AsyncLoading.Done() }() - case StateConnecting, StateSynSent, StateSynRecv: + case epState.connecting(): tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -267,7 +267,7 @@ func (e *endpoint) Resume(s *stack.Stack) { connectingLoading.Done() tcpip.AsyncLoading.Done() }() - case StateBound: + case epState == StateBound: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -276,7 +276,7 @@ func (e *endpoint) Resume(s *stack.Stack) { bind() tcpip.AsyncLoading.Done() }() - case StateClose: + case epState == StateClose: if e.isPortReserved { tcpip.AsyncLoading.Add(1) go func() { @@ -291,12 +291,11 @@ func (e *endpoint) Resume(s *stack.Stack) { e.state = StateClose e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) - case StateError: + case epState == StateError: e.state = StateError e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } - } // saveLastError is invoked by stateify. @@ -314,7 +313,7 @@ func (e *endpoint) loadLastError(s string) { return } - e.lastError = loadError(s) + e.lastError = tcpip.StringToError(s) } // saveHardError is invoked by stateify. @@ -332,71 +331,7 @@ func (e *EndpointInfo) loadHardError(s string) { return } - e.HardError = loadError(s) -} - -var messageToError map[string]*tcpip.Error - -var populate sync.Once - -func loadError(s string) *tcpip.Error { - populate.Do(func() { - var errors = []*tcpip.Error{ - tcpip.ErrUnknownProtocol, - tcpip.ErrUnknownNICID, - tcpip.ErrUnknownDevice, - tcpip.ErrUnknownProtocolOption, - tcpip.ErrDuplicateNICID, - tcpip.ErrDuplicateAddress, - tcpip.ErrNoRoute, - tcpip.ErrBadLinkEndpoint, - tcpip.ErrAlreadyBound, - tcpip.ErrInvalidEndpointState, - tcpip.ErrAlreadyConnecting, - tcpip.ErrAlreadyConnected, - tcpip.ErrNoPortAvailable, - tcpip.ErrPortInUse, - tcpip.ErrBadLocalAddress, - tcpip.ErrClosedForSend, - tcpip.ErrClosedForReceive, - tcpip.ErrWouldBlock, - tcpip.ErrConnectionRefused, - tcpip.ErrTimeout, - tcpip.ErrAborted, - tcpip.ErrConnectStarted, - tcpip.ErrDestinationRequired, - tcpip.ErrNotSupported, - tcpip.ErrQueueSizeNotSupported, - tcpip.ErrNotConnected, - tcpip.ErrConnectionReset, - tcpip.ErrConnectionAborted, - tcpip.ErrNoSuchFile, - tcpip.ErrInvalidOptionValue, - tcpip.ErrNoLinkAddress, - tcpip.ErrBadAddress, - tcpip.ErrNetworkUnreachable, - tcpip.ErrMessageTooLong, - tcpip.ErrNoBufferSpace, - tcpip.ErrBroadcastDisabled, - tcpip.ErrNotPermitted, - tcpip.ErrAddressFamilyNotSupported, - } - - messageToError = make(map[string]*tcpip.Error) - for _, e := range errors { - if messageToError[e.String()] != nil { - panic("tcpip errors with duplicated message: " + e.String()) - } - messageToError[e.String()] = e - } - }) - - e, ok := messageToError[s] - if !ok { - panic("unknown error message: " + s) - } - - return e + e.HardError = tcpip.StringToError(s) } // saveMeasureTime is invoked by stateify. diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 704d01c64..070b634b4 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -61,7 +61,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { s := newSegment(r, id, pkt) defer s.decRef() diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 2a2a7ddeb..73b8a6782 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -21,6 +21,7 @@ package tcp import ( + "fmt" "runtime" "strings" "time" @@ -206,7 +207,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // to a specific processing queue. Each queue is serviced by its own processor // goroutine which is responsible for dequeuing and doing full TCP dispatch of // the packet. -func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { p.dispatcher.queuePacket(r, ep, id, pkt) } @@ -217,7 +218,7 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { s := newSegment(r, id, pkt) defer s.decRef() @@ -490,6 +491,26 @@ func (p *protocol) SynRcvdCounter() *synRcvdCounter { return &p.synRcvdCount } +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) bool { + hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) + if !ok { + return false + } + + // If the header has options, pull those up as well. + if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { + hdr, ok = pkt.Data.PullUp(offset) + if !ok { + panic(fmt.Sprintf("There should be at least %d bytes in pkt.Data.", offset)) + } + } + + pkt.TransportHeader = hdr + pkt.Data.TrimFront(len(hdr)) + return true +} + // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{ diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 074edded6..0280892a8 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -35,6 +35,7 @@ type segment struct { id stack.TransportEndpointID `state:"manual"` route stack.Route `state:"manual"` data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + hdr header.TCP // views is used as buffer for data when its length is large // enough to store a VectorisedView. views [8]buffer.View `state:"nosave"` @@ -60,13 +61,14 @@ type segment struct { xmitCount uint32 } -func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) *segment { +func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { s := &segment{ refCnt: 1, id: id, route: r.Clone(), } s.data = pkt.Data.Clone(s.views[:]) + s.hdr = header.TCP(pkt.TransportHeader) s.rcvdTime = time.Now() return s } @@ -146,12 +148,6 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h, ok := s.data.PullUp(header.TCPMinimumSize) - if !ok { - return false - } - hdr := header.TCP(h) - // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: // 1. That it's at least the minimum header size; if we don't do this @@ -162,16 +158,12 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(hdr.DataOffset()) - if offset < header.TCPMinimumSize { - return false - } - hdrWithOpts, ok := s.data.PullUp(offset) - if !ok { + offset := int(s.hdr.DataOffset()) + if offset < header.TCPMinimumSize || offset > len(s.hdr) { return false } - s.options = []byte(hdrWithOpts[header.TCPMinimumSize:]) + s.options = []byte(s.hdr[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -180,22 +172,19 @@ func (s *segment) parse() bool { if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 { s.csumValid = true verifyChecksum = false - s.data.TrimFront(offset) } if verifyChecksum { - hdr = header.TCP(hdrWithOpts) - s.csum = hdr.Checksum() - xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = hdr.CalculateChecksum(xsum) - s.data.TrimFront(offset) + s.csum = s.hdr.Checksum() + xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()+len(s.hdr))) + xsum = s.hdr.CalculateChecksum(xsum) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(hdr.SequenceNumber()) - s.ackNumber = seqnum.Value(hdr.AckNumber()) - s.flags = hdr.Flags() - s.window = seqnum.Size(hdr.WindowSize()) + s.sequenceNumber = seqnum.Value(s.hdr.SequenceNumber()) + s.ackNumber = seqnum.Value(s.hdr.AckNumber()) + s.flags = s.hdr.Flags() + s.window = seqnum.Size(s.hdr.WindowSize()) return true } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 06dc9b7d7..acacb42e4 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -618,6 +618,20 @@ func (s *sender) splitSeg(seg *segment, size int) { nSeg.data.TrimFront(size) nSeg.sequenceNumber.UpdateForward(seqnum.Size(size)) s.writeList.InsertAfter(seg, nSeg) + + // The segment being split does not carry PUSH flag because it is + // followed by the newly split segment. + // RFC1122 section 4.2.2.2: MUST set the PSH bit in the last buffered + // segment (i.e., when there is no more queued data to be sent). + // Linux removes PSH flag only when the segment is being split over MSS + // and retains it when we are splitting the segment over lack of sender + // window space. + // ref: net/ipv4/tcp_output.c::tcp_write_xmit(), tcp_mss_split_point() + // ref: net/ipv4/tcp_output.c::tcp_write_wakeup(), tcp_snd_wnd_test() + if seg.data.Size() > s.maxPayloadSize { + seg.flags ^= header.TCPFlagPsh + } + seg.data.CapLength(size) } @@ -739,7 +753,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se if !s.isAssignedSequenceNumber(seg) { // Merge segments if allowed. if seg.data.Size() != 0 { - available := int(seg.sequenceNumber.Size(end)) + available := int(s.sndNxt.Size(end)) if available > limit { available = limit } @@ -782,8 +796,11 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // sent all at once. return false } - if atomic.LoadUint32(&s.ep.cork) != 0 { - // Hold back the segment until full. + // With TCP_CORK, hold back until minimum of the available + // send space and MSS. + // TODO(gvisor.dev/issue/2833): Drain the held segments after a + // timeout. + if seg.data.Size() < s.maxPayloadSize && atomic.LoadUint32(&s.ep.cork) != 0 { return false } } @@ -816,6 +833,25 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se panic("Netstack queues FIN segments without data.") } + segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) + // If the entire segment cannot be accomodated in the receiver + // advertized window, skip splitting and sending of the segment. + // ref: net/ipv4/tcp_output.c::tcp_snd_wnd_test() + // + // Linux checks this for all segment transmits not triggered + // by a probe timer. On this condition, it defers the segment + // split and transmit to a short probe timer. + // ref: include/net/tcp.h::tcp_check_probe_timer() + // ref: net/ipv4/tcp_output.c::tcp_write_wakeup() + // + // Instead of defining a new transmit timer, we attempt to split the + // segment right here if there are no pending segments. + // If there are pending segments, segment transmits are deferred + // to the retransmit timer handler. + if s.sndUna != s.sndNxt && !segEnd.LessThan(end) { + return false + } + if !seg.sequenceNumber.LessThan(end) { return false } @@ -824,9 +860,17 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se if available == 0 { return false } + + // The segment size limit is computed as a function of sender congestion + // window and MSS. When sender congestion window is > 1, this limit can + // be larger than MSS. Ensure that the currently available send space + // is not greater than minimum of this limit and MSS. if available > limit { available = limit } + if available > s.maxPayloadSize { + available = s.maxPayloadSize + } if seg.data.Size() > available { s.splitSeg(seg, available) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 7b1d72cf4..9721f6caf 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -316,7 +316,7 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -372,7 +372,7 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp // SendSegment sends a TCP segment that has already been built and written to a // buffer.VectorisedView. func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: s, }) } @@ -380,7 +380,7 @@ func (c *Context) SendSegment(s buffer.VectorisedView) { // SendPacket builds and sends a TCP segment(with the provided payload & TCP // headers) in an IPv4 packet via the link layer endpoint. func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: c.BuildSegment(payload, h), }) } @@ -389,7 +389,7 @@ func (c *Context) SendPacket(payload []byte, h *Headers) { // & TCPheaders) in an IPv4 packet via the link layer endpoint using the // provided source and destination IPv4 addresses. func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: c.BuildSegmentWithAddrs(payload, h, src, dst), }) } @@ -564,7 +564,7 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp t.SetChecksum(^t.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 756ab913a..c5e3c73ef 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,6 +15,7 @@ package udp import ( + "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -106,6 +107,9 @@ type endpoint struct { bindToDevice tcpip.NICID broadcast bool + lastErrorMu sync.Mutex `state:"nosave"` + lastError *tcpip.Error `state:".(string)"` + // Values used to reserve a port or register a transport endpoint. // (which ever happens first). boundBindToDevice tcpip.NICID @@ -188,6 +192,15 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } +func (e *endpoint) takeLastError() *tcpip.Error { + e.lastErrorMu.Lock() + defer e.lastErrorMu.Unlock() + + err := e.lastError + e.lastError = nil + return err +} + // Abort implements stack.TransportEndpoint.Abort. func (e *endpoint) Abort() { e.Close() @@ -235,14 +248,13 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + if err := e.takeLastError(); err != nil { + return buffer.View{}, tcpip.ControlMessages{}, err + } + e.rcvMu.Lock() if e.rcvList.Empty() { @@ -382,6 +394,10 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + if err := e.takeLastError(); err != nil { + return 0, nil, err + } + // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -410,24 +426,33 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } var route *stack.Route + var resolve func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) var dstPort uint16 if to == nil { route = &e.route dstPort = e.dstPort - - if route.IsResolutionRequired() { - // Promote lock to exclusive if using a shared route, given that it may need to - // change in Route.Resolve() call below. + resolve = func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) { + // Promote lock to exclusive if using a shared route, given that it may + // need to change in Route.Resolve() call below. e.mu.RUnlock() - defer e.mu.RLock() - e.mu.Lock() - defer e.mu.Unlock() // Recheck state after lock was re-acquired. if e.state != StateConnected { - return 0, nil, tcpip.ErrInvalidEndpointState + err = tcpip.ErrInvalidEndpointState + } + if err == nil && route.IsResolutionRequired() { + ch, err = route.Resolve(waker) + } + + e.mu.Unlock() + e.mu.RLock() + + // Recheck state after lock was re-acquired. + if e.state != StateConnected { + err = tcpip.ErrInvalidEndpointState } + return } } else { // Reject destination address if it goes through a different @@ -458,10 +483,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c route = &r dstPort = dst.Port + resolve = route.Resolve } if route.IsResolutionRequired() { - if ch, err := route.Resolve(nil); err != nil { + if ch, err := resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { return 0, ch, tcpip.ErrNoLinkAddress } @@ -853,6 +879,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { case tcpip.ErrorOption: + return e.takeLastError() case *tcpip.MulticastInterfaceOption: e.mu.Lock() *o = tcpip.MulticastInterfaceOption{ @@ -900,7 +927,11 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u if useDefaultTTL { ttl = r.DefaultTTL() } - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, stack.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: ProtocolNumber, + TTL: ttl, + TOS: tos, + }, &stack.PacketBuffer{ Header: hdr, Data: data, TransportHeader: buffer.View(udp), @@ -1248,18 +1279,16 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Get the header then trim it from the view. - hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { + hdr := header.UDP(pkt.TransportHeader) + if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } - pkt.Data.TrimFront(header.UDPMinimumSize) - e.rcvMu.Lock() e.stack.Stats().UDP.PacketsReceived.Increment() e.stats.PacketsReceived.Increment() @@ -1315,7 +1344,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { + if typ == stack.ControlPortUnreachable { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.state == StateConnected { + e.lastErrorMu.Lock() + defer e.lastErrorMu.Unlock() + + e.lastError = tcpip.ErrConnectionRefused + } + } } // State implements tcpip.Endpoint.State. diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 466bd9381..851e6b635 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -37,6 +37,24 @@ func (u *udpPacket) loadData(data buffer.VectorisedView) { u.data = data } +// saveLastError is invoked by stateify. +func (e *endpoint) saveLastError() string { + if e.lastError == nil { + return "" + } + + return e.lastError.String() +} + +// loadLastError is invoked by stateify. +func (e *endpoint) loadLastError(s string) { + if s == "" { + return + } + + e.lastError = tcpip.StringToError(s) +} + // beforeSave is invoked by stateify. func (e *endpoint) beforeSave() { // Stop incoming packets from being handled (and mutate endpoint state). diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index a674ceb68..7abfa0ed2 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -43,7 +43,7 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { f.handler(&ForwarderRequest{ stack: f.stack, route: r, @@ -61,7 +61,7 @@ type ForwarderRequest struct { stack *stack.Stack route *stack.Route id stack.TransportEndpointID - pkt stack.PacketBuffer + pkt *stack.PacketBuffer } // ID returns the 4-tuple (src address, src port, dst address, dst port) that diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 52af6de22..4218e7d03 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -66,15 +66,9 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { - // Get the header then trim it from the view. - h, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { - // Malformed packet. - r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() - return true - } - if int(header.UDP(h).Length()) > pkt.Data.Size() { +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + hdr := header.UDP(pkt.TransportHeader) + if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true @@ -121,7 +115,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize available := int(mtu) - headerLen - payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size() + payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size() if payloadLen > available { payloadLen = available } @@ -130,9 +124,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans // For example, a raw or packet socket may use what UDP // considers an unreachable destination. Thus we deep copy pkt // to prevent multiple ownership and SR errors. - newNetHeader := append(buffer.View(nil), pkt.NetworkHeader...) - payload := newNetHeader.ToVectorisedView() - payload.Append(pkt.Data.ToView().ToVectorisedView()) + newHeader := append(buffer.View(nil), pkt.NetworkHeader...) + newHeader = append(newHeader, pkt.TransportHeader...) + payload := newHeader.ToVectorisedView() + payload.AppendView(pkt.Data.ToView()) payload.CapLength(payloadLen) hdr := buffer.NewPrependable(headerLen) @@ -140,9 +135,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv4DstUnreachable) pkt.SetCode(header.ICMPv4PortUnreachable) pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ - Header: hdr, - Data: payload, + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ + Header: hdr, + TransportHeader: buffer.View(pkt), + Data: payload, }) case header.IPv6AddressSize: @@ -164,11 +160,11 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize available := int(mtu) - headerLen - payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size() + payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size() if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(len(pkt.NetworkHeader), []buffer.View{pkt.NetworkHeader}) + payload := buffer.NewVectorisedView(len(pkt.NetworkHeader)+len(pkt.TransportHeader), []buffer.View{pkt.NetworkHeader, pkt.TransportHeader}) payload.Append(pkt.Data) payload.CapLength(payloadLen) @@ -177,9 +173,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv6DstUnreachable) pkt.SetCode(header.ICMPv6PortUnreachable) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ - Header: hdr, - Data: payload, + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ + Header: hdr, + TransportHeader: buffer.View(pkt), + Data: payload, }) } return true @@ -201,6 +198,18 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) bool { + h, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { + // Packet is too small + return false + } + pkt.TransportHeader = h + pkt.Data.TrimFront(header.UDPMinimumSize) + return true +} + // NewProtocol returns a UDP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{} diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 8acaa607a..313a3f117 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -440,10 +440,8 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), }) } @@ -487,10 +485,8 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool // Inject packet. - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), }) } @@ -1720,6 +1716,58 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) { } } +// TestShortHeader verifies that when a packet with a too-short UDP header is +// received, the malformed received global stat gets incremented. +func TestShortHeader(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + c.t.Helper() + h := unicastV6.header4Tuple(incoming) + + // Allocate a buffer for an IPv6 and too-short UDP header. + const udpSize = header.UDPMinimumSize - 1 + buf := buffer.NewView(header.IPv6MinimumSize + udpSize) + // Initialize the IP header. + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + TrafficClass: testTOS, + PayloadLength: uint16(udpSize), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 65, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, + }) + + // Initialize the UDP header. + udpHdr := header.UDP(buffer.NewView(header.UDPMinimumSize)) + udpHdr.Encode(&header.UDPFields{ + SrcPort: h.srcAddr.Port, + DstPort: h.dstAddr.Port, + Length: header.UDPMinimumSize, + }) + // Calculate the UDP pseudo-header checksum. + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(udpHdr))) + udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum)) + // Copy all but the last byte of the UDP header into the packet. + copy(buf[header.IPv6MinimumSize:], udpHdr) + + // Inject packet. + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want { + t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want) + } +} + // TestShutdownRead verifies endpoint read shutdown and error // stats increment on packet receive. func TestShutdownRead(t *testing.T) { diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go index 5f2af9f3b..c45d2ecbc 100644 --- a/pkg/test/dockerutil/dockerutil.go +++ b/pkg/test/dockerutil/dockerutil.go @@ -148,6 +148,62 @@ func (m MountMode) String() string { panic(fmt.Sprintf("invalid mode: %d", m)) } +// DockerNetwork contains the name of a docker network. +type DockerNetwork struct { + logger testutil.Logger + Name string + Subnet *net.IPNet + containers []*Docker +} + +// NewDockerNetwork sets up the struct for a Docker network. Names of networks +// will be unique. +func NewDockerNetwork(logger testutil.Logger) *DockerNetwork { + return &DockerNetwork{ + logger: logger, + Name: testutil.RandomID(logger.Name()), + } +} + +// Create calls 'docker network create'. +func (n *DockerNetwork) Create(args ...string) error { + a := []string{"docker", "network", "create"} + if n.Subnet != nil { + a = append(a, fmt.Sprintf("--subnet=%s", n.Subnet)) + } + a = append(a, args...) + a = append(a, n.Name) + return testutil.Command(n.logger, a...).Run() +} + +// Connect calls 'docker network connect' with the arguments provided. +func (n *DockerNetwork) Connect(container *Docker, args ...string) error { + a := []string{"docker", "network", "connect"} + a = append(a, args...) + a = append(a, n.Name, container.Name) + if err := testutil.Command(n.logger, a...).Run(); err != nil { + return err + } + n.containers = append(n.containers, container) + return nil +} + +// Cleanup cleans up the docker network and all the containers attached to it. +func (n *DockerNetwork) Cleanup() error { + for _, c := range n.containers { + // Don't propagate the error, it might be that the container + // was already cleaned up. + if err := c.Kill(); err != nil { + n.logger.Logf("unable to kill container during cleanup: %s", err) + } + } + + if err := testutil.Command(n.logger, "docker", "network", "rm", n.Name).Run(); err != nil { + return err + } + return nil +} + // Docker contains the name and the runtime of a docker container. type Docker struct { logger testutil.Logger @@ -162,9 +218,13 @@ type Docker struct { // // Names of containers will be unique. func MakeDocker(logger testutil.Logger) *Docker { + // Slashes are not allowed in container names. + name := testutil.RandomID(logger.Name()) + name = strings.ReplaceAll(name, "/", "-") + return &Docker{ logger: logger, - Name: testutil.RandomID(logger.Name()), + Name: name, Runtime: *runtime, } } @@ -309,7 +369,9 @@ func (d *Docker) argsFor(r *RunOpts, command string, p []string) (rv []string) { rv = append(rv, d.Name) } else { rv = append(rv, d.mounts...) - rv = append(rv, fmt.Sprintf("--runtime=%s", d.Runtime)) + if len(d.Runtime) > 0 { + rv = append(rv, fmt.Sprintf("--runtime=%s", d.Runtime)) + } rv = append(rv, fmt.Sprintf("--name=%s", d.Name)) rv = append(rv, testutil.ImageByName(r.Image)) } @@ -477,6 +539,56 @@ func (d *Docker) FindIP() (net.IP, error) { return ip, nil } +// A NetworkInterface is container's network interface information. +type NetworkInterface struct { + IPv4 net.IP + MAC net.HardwareAddr +} + +// ListNetworks returns the network interfaces of the container, keyed by +// Docker network name. +func (d *Docker) ListNetworks() (map[string]NetworkInterface, error) { + const format = `{{json .NetworkSettings.Networks}}` + out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput() + if err != nil { + return nil, fmt.Errorf("error network interfaces: %q: %w", string(out), err) + } + + networks := map[string]map[string]string{} + if err := json.Unmarshal(out, &networks); err != nil { + return nil, fmt.Errorf("error decoding network interfaces: %w", err) + } + + interfaces := map[string]NetworkInterface{} + for name, iface := range networks { + var netface NetworkInterface + + rawIP := strings.TrimSpace(iface["IPAddress"]) + if rawIP != "" { + ip := net.ParseIP(rawIP) + if ip == nil { + return nil, fmt.Errorf("invalid IP: %q", rawIP) + } + // Docker's IPAddress field is IPv4. The IPv6 address + // is stored in the GlobalIPv6Address field. + netface.IPv4 = ip + } + + rawMAC := strings.TrimSpace(iface["MacAddress"]) + if rawMAC != "" { + mac, err := net.ParseMAC(rawMAC) + if err != nil { + return nil, fmt.Errorf("invalid MAC: %q: %w", rawMAC, err) + } + netface.MAC = mac + } + + interfaces[name] = netface + } + + return interfaces, nil +} + // SandboxPid returns the PID to the sandbox process. func (d *Docker) SandboxPid() (int, error) { out, err := testutil.Command(d.logger, "docker", "inspect", "-f={{.State.Pid}}", d.Name).CombinedOutput() diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD deleted file mode 100644 index 2dcba84ae..000000000 --- a/pkg/tmutex/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "tmutex", - srcs = ["tmutex.go"], - visibility = ["//:sandbox"], -) - -go_test( - name = "tmutex_test", - size = "medium", - srcs = ["tmutex_test.go"], - library = ":tmutex", - deps = ["//pkg/sync"], -) diff --git a/pkg/tmutex/tmutex.go b/pkg/tmutex/tmutex.go deleted file mode 100644 index c4685020d..000000000 --- a/pkg/tmutex/tmutex.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package tmutex provides the implementation of a mutex that implements an -// efficient TryLock function in addition to Lock and Unlock. -package tmutex - -import ( - "sync/atomic" -) - -// Mutex is a mutual exclusion primitive that implements TryLock in addition -// to Lock and Unlock. -type Mutex struct { - v int32 - ch chan struct{} -} - -// Init initializes the mutex. -func (m *Mutex) Init() { - m.v = 1 - m.ch = make(chan struct{}, 1) -} - -// Lock acquires the mutex. If it is currently held by another goroutine, Lock -// will wait until it has a chance to acquire it. -func (m *Mutex) Lock() { - // Uncontended case. - if atomic.AddInt32(&m.v, -1) == 0 { - return - } - - for { - // Try to acquire the mutex again, at the same time making sure - // that m.v is negative, which indicates to the owner of the - // lock that it is contended, which will force it to try to wake - // someone up when it releases the mutex. - if v := atomic.LoadInt32(&m.v); v >= 0 && atomic.SwapInt32(&m.v, -1) == 1 { - return - } - - // Wait for the mutex to be released before trying again. - <-m.ch - } -} - -// TryLock attempts to acquire the mutex without blocking. If the mutex is -// currently held by another goroutine, it fails to acquire it and returns -// false. -func (m *Mutex) TryLock() bool { - v := atomic.LoadInt32(&m.v) - if v <= 0 { - return false - } - return atomic.CompareAndSwapInt32(&m.v, 1, 0) -} - -// Unlock releases the mutex. -func (m *Mutex) Unlock() { - if atomic.SwapInt32(&m.v, 1) == 0 { - // There were no pending waiters. - return - } - - // Wake some waiter up. - select { - case m.ch <- struct{}{}: - default: - } -} diff --git a/pkg/tmutex/tmutex_test.go b/pkg/tmutex/tmutex_test.go deleted file mode 100644 index 05540696a..000000000 --- a/pkg/tmutex/tmutex_test.go +++ /dev/null @@ -1,258 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tmutex - -import ( - "fmt" - "runtime" - "sync/atomic" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" -) - -func TestBasicLock(t *testing.T) { - var m Mutex - m.Init() - - m.Lock() - - // Try blocking lock the mutex from a different goroutine. This must - // not block because the mutex is held. - ch := make(chan struct{}, 1) - go func() { - m.Lock() - ch <- struct{}{} - m.Unlock() - ch <- struct{}{} - }() - - select { - case <-ch: - t.Fatalf("Lock succeeded on locked mutex") - case <-time.After(100 * time.Millisecond): - } - - // Unlock the mutex and make sure that the goroutine waiting on Lock() - // unblocks and succeeds. - m.Unlock() - - select { - case <-ch: - case <-time.After(100 * time.Millisecond): - t.Fatalf("Lock failed to acquire unlocked mutex") - } - - // Make sure we can lock and unlock again. - m.Lock() - m.Unlock() -} - -func TestTryLock(t *testing.T) { - var m Mutex - m.Init() - - // Try to lock. It should succeed. - if !m.TryLock() { - t.Fatalf("TryLock failed on unlocked mutex") - } - - // Try to lock again, it should now fail. - if m.TryLock() { - t.Fatalf("TryLock succeeded on locked mutex") - } - - // Try blocking lock the mutex from a different goroutine. This must - // not block because the mutex is held. - ch := make(chan struct{}, 1) - go func() { - m.Lock() - ch <- struct{}{} - m.Unlock() - }() - - select { - case <-ch: - t.Fatalf("Lock succeeded on locked mutex") - case <-time.After(100 * time.Millisecond): - } - - // Unlock the mutex and make sure that the goroutine waiting on Lock() - // unblocks and succeeds. - m.Unlock() - - select { - case <-ch: - case <-time.After(100 * time.Millisecond): - t.Fatalf("Lock failed to acquire unlocked mutex") - } -} - -func TestMutualExclusion(t *testing.T) { - var m Mutex - m.Init() - - // Test mutual exclusion by running "gr" goroutines concurrently, and - // have each one increment a counter "iters" times within the critical - // section established by the mutex. - // - // If at the end the counter is not gr * iters, then we know that - // goroutines ran concurrently within the critical section. - // - // If one of the goroutines doesn't complete, it's likely a bug that - // causes to it to wait forever. - const gr = 1000 - const iters = 100000 - v := 0 - var wg sync.WaitGroup - for i := 0; i < gr; i++ { - wg.Add(1) - go func() { - for j := 0; j < iters; j++ { - m.Lock() - v++ - m.Unlock() - } - wg.Done() - }() - } - - wg.Wait() - - if v != gr*iters { - t.Fatalf("Bad count: got %v, want %v", v, gr*iters) - } -} - -func TestMutualExclusionWithTryLock(t *testing.T) { - var m Mutex - m.Init() - - // Similar to the previous, with the addition of some goroutines that - // only increment the count if TryLock succeeds. - const gr = 1000 - const iters = 100000 - total := int64(gr * iters) - var tryTotal int64 - v := int64(0) - var wg sync.WaitGroup - for i := 0; i < gr; i++ { - wg.Add(2) - go func() { - for j := 0; j < iters; j++ { - m.Lock() - v++ - m.Unlock() - } - wg.Done() - }() - go func() { - local := int64(0) - for j := 0; j < iters; j++ { - if m.TryLock() { - v++ - m.Unlock() - local++ - } - } - atomic.AddInt64(&tryTotal, local) - wg.Done() - }() - } - - wg.Wait() - - t.Logf("tryTotal = %d", tryTotal) - total += tryTotal - - if v != total { - t.Fatalf("Bad count: got %v, want %v", v, total) - } -} - -// BenchmarkTmutex is equivalent to TestMutualExclusion, with the following -// differences: -// -// - The number of goroutines is variable, with the maximum value depending on -// GOMAXPROCS. -// -// - The number of iterations per benchmark is controlled by the benchmarking -// framework. -// -// - Care is taken to ensure that all goroutines participating in the benchmark -// have been created before the benchmark begins. -func BenchmarkTmutex(b *testing.B) { - for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var m Mutex - m.Init() - - var ready sync.WaitGroup - begin := make(chan struct{}) - var end sync.WaitGroup - for i := 0; i < n; i++ { - ready.Add(1) - end.Add(1) - go func() { - ready.Done() - <-begin - for j := 0; j < b.N; j++ { - m.Lock() - m.Unlock() - } - end.Done() - }() - } - - ready.Wait() - b.ResetTimer() - close(begin) - end.Wait() - }) - } -} - -// BenchmarkSyncMutex is equivalent to BenchmarkTmutex, but uses sync.Mutex as -// a comparison point. -func BenchmarkSyncMutex(b *testing.B) { - for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { - var m sync.Mutex - - var ready sync.WaitGroup - begin := make(chan struct{}) - var end sync.WaitGroup - for i := 0; i < n; i++ { - ready.Add(1) - end.Add(1) - go func() { - ready.Done() - <-begin - for j := 0; j < b.N; j++ { - m.Lock() - m.Unlock() - } - end.Done() - }() - } - - ready.Wait() - b.ResetTimer() - close(begin) - end.Wait() - }) - } -} diff --git a/pkg/usermem/addr.go b/pkg/usermem/addr.go index e79210804..c4100481e 100644 --- a/pkg/usermem/addr.go +++ b/pkg/usermem/addr.go @@ -106,3 +106,20 @@ func (ar AddrRange) IsPageAligned() bool { func (ar AddrRange) String() string { return fmt.Sprintf("[%#x, %#x)", ar.Start, ar.End) } + +// PageRoundDown/Up are equivalent to Addr.RoundDown/Up, but without the +// potentially truncating conversion from uint64 to Addr. This is necessary +// because there is no way to define generic "PageRoundDown/Up" functions in Go. + +// PageRoundDown returns x rounded down to the nearest page boundary. +func PageRoundDown(x uint64) uint64 { + return x &^ (PageSize - 1) +} + +// PageRoundUp returns x rounded up to the nearest page boundary. +// ok is true iff rounding up did not wrap around. +func PageRoundUp(x uint64) (addr uint64, ok bool) { + addr = PageRoundDown(x + PageSize - 1) + ok = addr >= x + return +} |