diff options
30 files changed, 958 insertions, 298 deletions
diff --git a/pkg/buffer/BUILD b/pkg/buffer/BUILD index 1186f788e..2a2e3d1aa 100644 --- a/pkg/buffer/BUILD +++ b/pkg/buffer/BUILD @@ -38,6 +38,7 @@ go_test( name = "buffer_test", size = "small", srcs = [ + "buffer_test.go", "pool_test.go", "safemem_test.go", "view_test.go", diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go index 311808ae9..5b77a6a3f 100644 --- a/pkg/buffer/buffer.go +++ b/pkg/buffer/buffer.go @@ -33,12 +33,40 @@ func (b *buffer) init(size int) { b.data = make([]byte, size) } +// initWithData initializes b with data, taking ownership. +func (b *buffer) initWithData(data []byte) { + b.data = data + b.read = 0 + b.write = len(data) +} + // Reset resets read and write locations, effectively emptying the buffer. func (b *buffer) Reset() { b.read = 0 b.write = 0 } +// Remove removes r from the unread portion. It returns false if r does not +// fully reside in b. +func (b *buffer) Remove(r Range) bool { + sz := b.ReadSize() + switch { + case r.Len() != r.Intersect(Range{end: sz}).Len(): + return false + case r.Len() == 0: + // Noop + case r.begin == 0: + b.read += r.end + case r.end == sz: + b.write -= r.Len() + default: + // Remove from the middle of b.data. + copy(b.data[b.read+r.begin:], b.data[b.read+r.end:b.write]) + b.write -= r.Len() + } + return true +} + // Full indicates the buffer is full. // // This indicates there is no capacity left to write. diff --git a/pkg/buffer/buffer_test.go b/pkg/buffer/buffer_test.go new file mode 100644 index 000000000..32db841e4 --- /dev/null +++ b/pkg/buffer/buffer_test.go @@ -0,0 +1,111 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "bytes" + "testing" +) + +func TestBufferRemove(t *testing.T) { + sample := []byte("01234567") + + // Success cases + for _, tc := range []struct { + desc string + data []byte + rng Range + want []byte + }{ + { + desc: "empty slice", + }, + { + desc: "empty range", + data: sample, + want: sample, + }, + { + desc: "empty range with positive begin", + data: sample, + rng: Range{begin: 1, end: 1}, + want: sample, + }, + { + desc: "range at beginning", + data: sample, + rng: Range{begin: 0, end: 1}, + want: sample[1:], + }, + { + desc: "range in middle", + data: sample, + rng: Range{begin: 2, end: 4}, + want: []byte("014567"), + }, + { + desc: "range at end", + data: sample, + rng: Range{begin: 7, end: 8}, + want: sample[:7], + }, + { + desc: "range all", + data: sample, + rng: Range{begin: 0, end: 8}, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + var buf buffer + buf.initWithData(tc.data) + if ok := buf.Remove(tc.rng); !ok { + t.Errorf("buf.Remove(%#v) = false, want true", tc.rng) + } else if got := buf.ReadSlice(); !bytes.Equal(got, tc.want) { + t.Errorf("buf.ReadSlice() = %q, want %q", got, tc.want) + } + }) + } + + // Failure cases + for _, tc := range []struct { + desc string + data []byte + rng Range + }{ + { + desc: "begin out-of-range", + data: sample, + rng: Range{begin: -1, end: 4}, + }, + { + desc: "end out-of-range", + data: sample, + rng: Range{begin: 4, end: 9}, + }, + { + desc: "both out-of-range", + data: sample, + rng: Range{begin: -100, end: 100}, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + var buf buffer + buf.initWithData(tc.data) + if ok := buf.Remove(tc.rng); ok { + t.Errorf("buf.Remove(%#v) = true, want false", tc.rng) + } + }) + } +} diff --git a/pkg/buffer/pool.go b/pkg/buffer/pool.go index 7ad6132ab..2ec41dd4f 100644 --- a/pkg/buffer/pool.go +++ b/pkg/buffer/pool.go @@ -42,6 +42,13 @@ type pool struct { // get gets a new buffer from p. func (p *pool) get() *buffer { + buf := p.getNoInit() + buf.init(p.bufferSize) + return buf +} + +// get gets a new buffer from p without initializing it. +func (p *pool) getNoInit() *buffer { if p.avail == nil { p.avail = p.embeddedStorage[:] } @@ -52,7 +59,6 @@ func (p *pool) get() *buffer { p.bufferSize = defaultBufferSize } buf := &p.avail[0] - buf.init(p.bufferSize) p.avail = p.avail[1:] return buf } @@ -62,6 +68,7 @@ func (p *pool) put(buf *buffer) { // Remove reference to the underlying storage, allowing it to be garbage // collected. buf.data = nil + buf.Reset() } // setBufferSize sets the size of underlying storage buffer for future diff --git a/pkg/buffer/view.go b/pkg/buffer/view.go index 00652d675..7bcfcd543 100644 --- a/pkg/buffer/view.go +++ b/pkg/buffer/view.go @@ -19,6 +19,9 @@ import ( "io" ) +// Buffer is an alias to View. +type Buffer = View + // View is a non-linear buffer. // // All methods are thread compatible. @@ -39,6 +42,51 @@ func (v *View) TrimFront(count int64) { } } +// Remove deletes data at specified location in v. It returns false if specified +// range does not fully reside in v. +func (v *View) Remove(offset, length int) bool { + if offset < 0 || length < 0 { + return false + } + tgt := Range{begin: offset, end: offset + length} + if tgt.Len() != tgt.Intersect(Range{end: int(v.size)}).Len() { + return false + } + + // Scan through each buffer and remove intersections. + var curr Range + for buf := v.data.Front(); buf != nil; { + origLen := buf.ReadSize() + curr.end = curr.begin + origLen + + if x := curr.Intersect(tgt); x.Len() > 0 { + if !buf.Remove(x.Offset(-curr.begin)) { + panic("buf.Remove() failed") + } + if buf.ReadSize() == 0 { + // buf fully removed, removing it from the list. + oldBuf := buf + buf = buf.Next() + v.data.Remove(oldBuf) + v.pool.put(oldBuf) + } else { + // Only partial data intersects, moving on to next one. + buf = buf.Next() + } + v.size -= int64(x.Len()) + } else { + // This buffer is not in range, moving on to next one. + buf = buf.Next() + } + + curr.begin += origLen + if curr.begin >= tgt.end { + break + } + } + return true +} + // ReadAt implements io.ReaderAt.ReadAt. func (v *View) ReadAt(p []byte, offset int64) (int, error) { var ( @@ -81,7 +129,6 @@ func (v *View) advanceRead(count int64) { oldBuf := buf buf = buf.Next() // Iterate. v.data.Remove(oldBuf) - oldBuf.Reset() v.pool.put(oldBuf) // Update counts. @@ -118,7 +165,6 @@ func (v *View) Truncate(length int64) { // Drop the buffer completely; see above. v.data.Remove(buf) - buf.Reset() v.pool.put(buf) v.size -= sz } @@ -224,6 +270,78 @@ func (v *View) Append(data []byte) { } } +// AppendOwned takes ownership of data and appends it to v. +func (v *View) AppendOwned(data []byte) { + if len(data) > 0 { + buf := v.pool.getNoInit() + buf.initWithData(data) + v.data.PushBack(buf) + v.size += int64(len(data)) + } +} + +// PullUp makes the specified range contiguous and returns the backing memory. +func (v *View) PullUp(offset, length int) ([]byte, bool) { + if length == 0 { + return nil, true + } + tgt := Range{begin: offset, end: offset + length} + if tgt.Intersect(Range{end: int(v.size)}).Len() != length { + return nil, false + } + + curr := Range{} + buf := v.data.Front() + for ; buf != nil; buf = buf.Next() { + origLen := buf.ReadSize() + curr.end = curr.begin + origLen + + if x := curr.Intersect(tgt); x.Len() == tgt.Len() { + // buf covers the whole requested target range. + sub := x.Offset(-curr.begin) + return buf.ReadSlice()[sub.begin:sub.end], true + } else if x.Len() > 0 { + // buf is pointing at the starting buffer we want to merge. + break + } + + curr.begin += origLen + } + + // Calculate the total merged length. + totLen := 0 + for n := buf; n != nil; n = n.Next() { + totLen += n.ReadSize() + if curr.begin+totLen >= tgt.end { + break + } + } + + // Merge the buffers. + data := make([]byte, totLen) + off := 0 + for n := buf; n != nil && off < totLen; { + copy(data[off:], n.ReadSlice()) + off += n.ReadSize() + + // Remove buffers except for the first one, which will be reused. + if n == buf { + n = n.Next() + } else { + old := n + n = n.Next() + v.data.Remove(old) + v.pool.put(old) + } + } + + // Update the first buffer with merged data. + buf.initWithData(data) + + r := tgt.Offset(-curr.begin) + return buf.data[r.begin:r.end], true +} + // Flatten returns a flattened copy of this data. // // This method should not be used in any performance-sensitive paths. It may @@ -267,6 +385,27 @@ func (v *View) Apply(fn func([]byte)) { } } +// SubApply applies fn to a given range of data in v. Any part of the range +// outside of v is ignored. +func (v *View) SubApply(offset, length int, fn func([]byte)) { + for buf := v.data.Front(); length > 0 && buf != nil; buf = buf.Next() { + d := buf.ReadSlice() + if offset >= len(d) { + offset -= len(d) + continue + } + if offset > 0 { + d = d[offset:] + offset = 0 + } + if length < len(d) { + d = d[:length] + } + fn(d) + length -= len(d) + } +} + // Merge merges the provided View with this one. // // The other view will be appended to v, and other will be empty after this @@ -389,3 +528,39 @@ func (v *View) ReadToWriter(w io.Writer, count int64) (int64, error) { } return done, err } + +// A Range specifies a range of buffer. +type Range struct { + begin int + end int +} + +// Intersect returns the intersection of x and y. +func (x Range) Intersect(y Range) Range { + if x.begin < y.begin { + x.begin = y.begin + } + if x.end > y.end { + x.end = y.end + } + if x.begin >= x.end { + return Range{} + } + return x +} + +// Offset returns x offset by off. +func (x Range) Offset(off int) Range { + x.begin += off + x.end += off + return x +} + +// Len returns the length of x. +func (x Range) Len() int { + l := x.end - x.begin + if l < 0 { + l = 0 + } + return l +} diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go index 839af0223..796efa240 100644 --- a/pkg/buffer/view_test.go +++ b/pkg/buffer/view_test.go @@ -17,7 +17,9 @@ package buffer import ( "bytes" "context" + "fmt" "io" + "reflect" "strings" "testing" @@ -237,6 +239,18 @@ func TestView(t *testing.T) { }, }, + // AppendOwned. + { + name: "append-owned", + input: "hello", + output: "hello world", + op: func(t *testing.T, v *View) { + b := []byte("Xworld") + v.AppendOwned(b) + b[0] = ' ' + }, + }, + // Truncate. { name: "truncate", @@ -495,6 +509,267 @@ func TestView(t *testing.T) { } } +func TestViewPullUp(t *testing.T) { + for _, tc := range []struct { + desc string + inputs []string + offset int + length int + output string + failed bool + // lengths is the lengths of each buffer node after the pull up. + lengths []int + }{ + { + desc: "whole empty view", + }, + { + desc: "zero pull", + inputs: []string{"hello", " world"}, + lengths: []int{5, 6}, + }, + { + desc: "whole view", + inputs: []string{"hello", " world"}, + offset: 0, + length: 11, + output: "hello world", + lengths: []int{11}, + }, + { + desc: "middle to end aligned", + inputs: []string{"0123", "45678", "9abcd"}, + offset: 4, + length: 10, + output: "456789abcd", + lengths: []int{4, 10}, + }, + { + desc: "middle to end unaligned", + inputs: []string{"0123", "45678", "9abcd"}, + offset: 6, + length: 8, + output: "6789abcd", + lengths: []int{4, 10}, + }, + { + desc: "middle aligned", + inputs: []string{"0123", "45678", "9abcd", "efgh"}, + offset: 6, + length: 5, + output: "6789a", + lengths: []int{4, 10, 4}, + }, + + // Failed cases. + { + desc: "empty view - length too long", + offset: 0, + length: 1, + failed: true, + }, + { + desc: "empty view - offset too large", + offset: 1, + length: 1, + failed: true, + }, + { + desc: "length too long", + inputs: []string{"0123", "45678", "9abcd"}, + offset: 4, + length: 100, + failed: true, + lengths: []int{4, 5, 5}, + }, + { + desc: "offset too large", + inputs: []string{"0123", "45678", "9abcd"}, + offset: 100, + length: 1, + failed: true, + lengths: []int{4, 5, 5}, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + var v View + for _, s := range tc.inputs { + v.AppendOwned([]byte(s)) + } + + got, gotOk := v.PullUp(tc.offset, tc.length) + want, wantOk := []byte(tc.output), !tc.failed + if gotOk != wantOk || !bytes.Equal(got, want) { + t.Errorf("v.PullUp(%d, %d) = %q, %t; %q, %t", tc.offset, tc.length, got, gotOk, want, wantOk) + } + + var gotLengths []int + for buf := v.data.Front(); buf != nil; buf = buf.Next() { + gotLengths = append(gotLengths, buf.ReadSize()) + } + if !reflect.DeepEqual(gotLengths, tc.lengths) { + t.Errorf("lengths = %v; want %v", gotLengths, tc.lengths) + } + }) + } +} + +func TestViewRemove(t *testing.T) { + // Success cases + for _, tc := range []struct { + desc string + // before is the contents for each buffer node initially. + before []string + // after is the contents for each buffer node after removal. + after []string + offset int + length int + }{ + { + desc: "empty view", + }, + { + desc: "nothing removed", + before: []string{"hello", " world"}, + after: []string{"hello", " world"}, + }, + { + desc: "whole view", + before: []string{"hello", " world"}, + offset: 0, + length: 11, + }, + { + desc: "beginning to middle aligned", + before: []string{"0123", "45678", "9abcd"}, + after: []string{"9abcd"}, + offset: 0, + length: 9, + }, + { + desc: "beginning to middle unaligned", + before: []string{"0123", "45678", "9abcd"}, + after: []string{"678", "9abcd"}, + offset: 0, + length: 6, + }, + { + desc: "middle to end aligned", + before: []string{"0123", "45678", "9abcd"}, + after: []string{"0123"}, + offset: 4, + length: 10, + }, + { + desc: "middle to end unaligned", + before: []string{"0123", "45678", "9abcd"}, + after: []string{"0123", "45"}, + offset: 6, + length: 8, + }, + { + desc: "middle aligned", + before: []string{"0123", "45678", "9abcd"}, + after: []string{"0123", "9abcd"}, + offset: 4, + length: 5, + }, + { + desc: "middle unaligned", + before: []string{"0123", "45678", "9abcd"}, + after: []string{"0123", "4578", "9abcd"}, + offset: 6, + length: 1, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + var v View + for _, s := range tc.before { + v.AppendOwned([]byte(s)) + } + + if ok := v.Remove(tc.offset, tc.length); !ok { + t.Errorf("v.Remove(%d, %d) = false, want true", tc.offset, tc.length) + } + + var got []string + for buf := v.data.Front(); buf != nil; buf = buf.Next() { + got = append(got, string(buf.ReadSlice())) + } + if !reflect.DeepEqual(got, tc.after) { + t.Errorf("after = %v; want %v", got, tc.after) + } + }) + } + + // Failure cases + for _, tc := range []struct { + desc string + // before is the contents for each buffer node initially. + before []string + offset int + length int + }{ + { + desc: "offset out-of-range", + before: []string{"hello", " world"}, + offset: -1, + length: 3, + }, + { + desc: "length too long", + before: []string{"hello", " world"}, + offset: 0, + length: 12, + }, + { + desc: "length too long with positive offset", + before: []string{"hello", " world"}, + offset: 3, + length: 9, + }, + { + desc: "length negative", + before: []string{"hello", " world"}, + offset: 0, + length: -1, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + var v View + for _, s := range tc.before { + v.AppendOwned([]byte(s)) + } + if ok := v.Remove(tc.offset, tc.length); ok { + t.Errorf("v.Remove(%d, %d) = true, want false", tc.offset, tc.length) + } + }) + } +} + +func TestViewSubApply(t *testing.T) { + var v View + v.AppendOwned([]byte("0123")) + v.AppendOwned([]byte("45678")) + v.AppendOwned([]byte("9abcd")) + + data := []byte("0123456789abcd") + + for i := 0; i <= len(data); i++ { + for j := i; j <= len(data); j++ { + t.Run(fmt.Sprintf("SubApply(%d,%d)", i, j), func(t *testing.T) { + var got []byte + v.SubApply(i, j-i, func(b []byte) { + got = append(got, b...) + }) + if want := data[i:j]; !bytes.Equal(got, want) { + t.Errorf("got = %q; want %q", got, want) + } + }) + } + } +} + func doSaveAndLoad(t *testing.T, toSave, toLoad *View) { t.Helper() var buf bytes.Buffer @@ -542,3 +817,84 @@ func TestSaveRestoreView(t *testing.T) { t.Errorf("v.Flatten() = %x, want %x", got, data) } } + +func TestRangeIntersect(t *testing.T) { + for _, tc := range []struct { + desc string + x, y, want Range + }{ + { + desc: "empty intersects empty", + }, + { + desc: "empty intersection", + x: Range{end: 10}, + y: Range{begin: 10, end: 20}, + }, + { + desc: "some intersection", + x: Range{begin: 5, end: 20}, + y: Range{end: 10}, + want: Range{begin: 5, end: 10}, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + if got := tc.x.Intersect(tc.y); got != tc.want { + t.Errorf("(%#v).Intersect(%#v) = %#v; want %#v", tc.x, tc.y, got, tc.want) + } + if got := tc.y.Intersect(tc.x); got != tc.want { + t.Errorf("(%#v).Intersect(%#v) = %#v; want %#v", tc.y, tc.x, got, tc.want) + } + }) + } +} + +func TestRangeOffset(t *testing.T) { + for _, tc := range []struct { + input Range + offset int + output Range + }{ + { + input: Range{}, + offset: 0, + output: Range{}, + }, + { + input: Range{}, + offset: -1, + output: Range{begin: -1, end: -1}, + }, + { + input: Range{begin: 10, end: 20}, + offset: -1, + output: Range{begin: 9, end: 19}, + }, + { + input: Range{begin: 10, end: 20}, + offset: 2, + output: Range{begin: 12, end: 22}, + }, + } { + if got := tc.input.Offset(tc.offset); got != tc.output { + t.Errorf("(%#v).Offset(%d) = %#v, want %#v", tc.input, tc.offset, got, tc.output) + } + } +} + +func TestRangeLen(t *testing.T) { + for _, tc := range []struct { + r Range + want int + }{ + {r: Range{}, want: 0}, + {r: Range{begin: 1, end: 1}, want: 0}, + {r: Range{begin: -1, end: -1}, want: 0}, + {r: Range{end: 10}, want: 10}, + {r: Range{begin: 5, end: 10}, want: 5}, + } { + if got := tc.r.Len(); got != tc.want { + t.Errorf("(%#v).Len() = %d, want %d", tc.r, got, tc.want) + } + } +} diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go index e6fe0fc0d..daff40cd5 100644 --- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go +++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go @@ -36,7 +36,7 @@ const Name = "devtmpfs" // // +stateify savable type FilesystemType struct { - initOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1664): not yet supported. + initOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. initErr error // fs is the tmpfs filesystem that backs all mounts of this FilesystemType. diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index e94aceb92..fa5456eee 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -564,7 +564,7 @@ func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize] natRange.UnmarshalUnsafe(buf) - // TODO(gvisor.dev/issue/5689): Support port or address ranges. + // TODO(gvisor.dev/issue/5697): Support port or address ranges. if natRange.MinAddr != natRange.MaxAddr { nflog("snatTargetMakerV6: MinAddr and MaxAddr are different") return nil, syserr.ErrInvalidArgument diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index b215067cf..9cc1c57d7 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -470,11 +470,8 @@ func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { // SetForwarding implements inet.Stack.SetForwarding. func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { - switch protocol { - case ipv4.ProtocolNumber, ipv6.ProtocolNumber: - s.Stack.SetForwarding(protocol, enable) - default: - panic(fmt.Sprintf("SetForwarding(%v) failed: unsupported protocol", protocol)) + if err := s.Stack.SetForwardingDefaultAndAllNICs(protocol, enable); err != nil { + return fmt.Errorf("SetForwardingDefaultAndAllNICs(%d, %t): %s", protocol, enable, err) } return nil } diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go index ebb4b2c1d..1c913b5e1 100644 --- a/pkg/tcpip/header/parse/parse.go +++ b/pkg/tcpip/header/parse/parse.go @@ -60,9 +60,13 @@ func IPv4(pkt *stack.PacketBuffer) bool { return false } ipHdr = header.IPv4(hdr) + length := int(ipHdr.TotalLength()) - len(hdr) + if length < 0 { + return false + } pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - pkt.Data().CapLength(int(ipHdr.TotalLength()) - len(hdr)) + pkt.Data().CapLength(length) return true } diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go index 90075a70c..56b76a284 100644 --- a/pkg/tcpip/network/internal/fragmentation/reassembler.go +++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go @@ -167,8 +167,7 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s resPkt := r.holes[0].pkt for i := 1; i < len(r.holes); i++ { - fragData := r.holes[i].pkt.Data() - resPkt.Data().ReadFromData(fragData, fragData.Size()) + stack.MergeFragment(resPkt, r.holes[i].pkt) } return resPkt, r.proto, true, memConsumed, nil } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 74aad126c..bd63e0289 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -1996,8 +1996,8 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) { t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) } - if err := s.SetForwarding(test.netProto, true); err != nil { - t.Fatalf("s.SetForwarding(%d, true): %s", test.netProto, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netProto, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", test.netProto, err) } if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) @@ -2005,8 +2005,8 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) { t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr) } - if err := s.SetForwarding(test.netProto, false); err != nil { - t.Fatalf("s.SetForwarding(%d, false): %s", test.netProto, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netProto, false); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, false): %s", test.netProto, err) } if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 3c8a39973..5f45b9ee6 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -389,8 +389,8 @@ func TestForwarding(t *testing.T) { }, }) - if err := s.SetForwarding(header.IPv4ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(header.IPv4ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", header.IPv4ProtocolNumber, err) } ipHeaderLength := header.IPv4MinimumSize + len(test.options) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index e457be3cf..040cd4bc8 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -673,8 +673,9 @@ func TestICMPChecksumValidationSimple(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) + } } if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(_, _) = %s", err) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 880290b4b..febbb3f38 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -1472,13 +1472,19 @@ func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffe // If the last header in the payload isn't a known IPv6 extension header, // handle it as if it is transport layer data. + // Calculate the number of octets parsed from data. We want to remove all + // the data except the unparsed portion located at the end, which its size + // is extHdr.Buf.Size(). + trim := pkt.Data().Size() - extHdr.Buf.Size() + // For unfragmented packets, extHdr still contains the transport header. // Get rid of it. // // For reassembled fragments, pkt.TransportHeader is unset, so this is a // no-op and pkt.Data begins with the transport header. - extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) - pkt.Data().Replace(extHdr.Buf) + trim += pkt.TransportHeader().View().Size() + + pkt.Data().DeleteFront(trim) stats.PacketsDelivered.Increment() if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index faf6a782e..30325160a 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -3329,8 +3329,8 @@ func TestForwarding(t *testing.T) { }, }) - if err := s.SetForwarding(ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) } transportProtocol := header.ICMPv6ProtocolNumber diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 52b9a200c..570c6c00c 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -878,8 +878,9 @@ func TestNDPValidation(t *testing.T) { s, ep := setup(t) if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) + } } stats := s.Stats().ICMP.V6.PacketsReceived diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 63ab31083..84aa6a9e4 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -74,6 +74,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/atomicbitops", + "//pkg/buffer", "//pkg/ilist", "//pkg/log", "//pkg/rand", diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 7d3725681..ff555722e 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -367,8 +367,10 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f }}, }) - // Enable forwarding. - s.SetForwarding(proto.Number(), true) + protoNum := proto.Number() + if err := s.SetForwardingDefaultAndAllNICs(protoNum, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", protoNum, err) + } // NIC 1 has the link address "a", and added the network address 1. ep1 = &fwdTestLinkEndpoint{ diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index c585b81b2..d4ac9e1f8 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1220,8 +1220,8 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { NDPDisp: &ndpDisp, })}, }) - if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } e := channel.New(1, 1280, linkAddr1) @@ -1424,8 +1424,8 @@ func TestRouterDiscovery(t *testing.T) { } } - if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } if err := s.CreateNIC(1, e); err != nil { @@ -1626,8 +1626,8 @@ func TestPrefixDiscovery(t *testing.T) { } } - if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } // Receive an RA with prefix1 in an NDP Prefix Information option (PI) @@ -1893,8 +1893,8 @@ func TestAutoGenAddr(t *testing.T) { })}, }) - if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } if err := s.CreateNIC(1, e); err != nil { @@ -4771,8 +4771,8 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { // or routers, or auto-generated address. for _, forwarding := range [...]bool{true, false} { t.Run(fmt.Sprintf("Transition forwarding to %t", forwarding), func(t *testing.T) { - if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } select { case e := <-ndpDisp.routerC: @@ -5353,8 +5353,8 @@ func TestRouterSolicitation(t *testing.T) { name: "Handle RAs always", handleRAs: ipv6.HandlingRAsAlwaysEnabled, afterFirstRS: func(t *testing.T, s *stack.Stack) { - if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) } }, }, @@ -5481,11 +5481,17 @@ func TestStopStartSolicitingRouters(t *testing.T) { name: "Enable and disable forwarding", startFn: func(t *testing.T, s *stack.Stack) { t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, false) + + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, false); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", ipv6.ProtocolNumber, err) + } }, stopFn: func(t *testing.T, s *stack.Stack, _ bool) { t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, true) + + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) + } }, }, diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index fc3c54e34..e2e073091 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -16,9 +16,10 @@ package stack import ( "fmt" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" + tcpipbuffer "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -39,7 +40,7 @@ type PacketBufferOptions struct { // Data is the initial unparsed data for the new packet. If set, it will be // owned by the new packet. - Data buffer.VectorisedView + Data tcpipbuffer.VectorisedView // IsForwardedPacket identifies that the PacketBuffer being created is for a // forwarded packet. @@ -56,6 +57,34 @@ type PacketBufferOptions struct { // empty. Use of PacketBuffer in any other order is unsupported. // // PacketBuffer must be created with NewPacketBuffer. +// +// Internal structure: A PacketBuffer holds a pointer to buffer.Buffer, which +// exposes a logically-contiguous byte storage. The underlying storage structure +// is abstracted out, and should not be a concern here for most of the time. +// +// |- reserved ->| +// |--->| consumed (incoming) +// 0 V V +// +--------+----+----+--------------------+ +// | | | | current data ... | (buf) +// +--------+----+----+--------------------+ +// ^ | +// |<---| pushed (outgoing) +// +// When a PacketBuffer is created, a `reserved` header region can be specified, +// which stack pushes headers in this region for an outgoing packet. There could +// be no such region for an incoming packet, and `reserved` is 0. The value of +// `reserved` never changes in the entire lifetime of the packet. +// +// Outgoing Packet: When a header is pushed, `pushed` gets incremented by the +// pushed length, and the current value is stored for each header. PacketBuffer +// substracts this value from `reserved` to compute the starting offset of each +// header in `buf`. +// +// Incoming Packet: When a header is consumed (a.k.a. parsed), the current +// `consumed` value is stored for each header, and it gets incremented by the +// consumed length. PacketBuffer adds this value to `reserved` to compute the +// starting offset of each header in `buf`. type PacketBuffer struct { _ sync.NoCopy @@ -63,28 +92,16 @@ type PacketBuffer struct { // PacketBuffers. PacketBufferEntry - // data holds the payload of the packet. - // - // For inbound packets, Data is initially the whole packet. Then gets moved to - // headers via PacketHeader.Consume, when the packet is being parsed. - // - // For outbound packets, Data is the innermost layer, defined by the protocol. - // Headers are pushed in front of it via PacketHeader.Push. - // - // The bytes backing Data are immutable, a.k.a. users shouldn't write to its - // backing storage. - data buffer.VectorisedView + // buf is the underlying buffer for the packet. See struct level docs for + // details. + buf *buffer.Buffer + reserved int + pushed int + consumed int // headers stores metadata about each header. headers [numHeaderType]headerInfo - // header is the internal storage for outbound packets. Headers will be pushed - // (prepended) on this storage as the packet is being constructed. - // - // TODO(gvisor.dev/issue/2404): Switch to an implementation that header and - // data are held in the same underlying buffer storage. - header buffer.Prependable - // NetworkProtocolNumber is only valid when NetworkHeader().View().IsEmpty() // returns false. // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol @@ -131,10 +148,14 @@ type PacketBuffer struct { // NewPacketBuffer creates a new PacketBuffer with opts. func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { pk := &PacketBuffer{ - data: opts.Data, + buf: &buffer.Buffer{}, } if opts.ReserveHeaderBytes != 0 { - pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes) + pk.buf.AppendOwned(make([]byte, opts.ReserveHeaderBytes)) + pk.reserved = opts.ReserveHeaderBytes + } + for _, v := range opts.Data.Views() { + pk.buf.AppendOwned(v) } if opts.IsForwardedPacket { pk.NetworkPacketInfo.IsForwardedPacket = opts.IsForwardedPacket @@ -145,13 +166,13 @@ func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { // ReservedHeaderBytes returns the number of bytes initially reserved for // headers. func (pk *PacketBuffer) ReservedHeaderBytes() int { - return pk.header.UsedLength() + pk.header.AvailableLength() + return pk.reserved } // AvailableHeaderBytes returns the number of bytes currently available for // headers. This is relevant to PacketHeader.Push method only. func (pk *PacketBuffer) AvailableHeaderBytes() int { - return pk.header.AvailableLength() + return pk.reserved - pk.pushed } // LinkHeader returns the handle to link-layer header. @@ -180,24 +201,18 @@ func (pk *PacketBuffer) TransportHeader() PacketHeader { // HeaderSize returns the total size of all headers in bytes. func (pk *PacketBuffer) HeaderSize() int { - // Note for inbound packets (Consume called), headers are not stored in - // pk.header. Thus, calculation of size of each header is needed. - var size int - for i := range pk.headers { - size += len(pk.headers[i].buf) - } - return size + return pk.pushed + pk.consumed } // Size returns the size of packet in bytes. func (pk *PacketBuffer) Size() int { - return pk.HeaderSize() + pk.data.Size() + return int(pk.buf.Size()) - pk.headerOffset() } // MemSize returns the estimation size of the pk in memory, including backing // buffer data. func (pk *PacketBuffer) MemSize() int { - return pk.HeaderSize() + pk.data.MemSize() + packetBufferStructSize + return int(pk.buf.Size()) + packetBufferStructSize } // Data returns the handle to data portion of pk. @@ -206,61 +221,65 @@ func (pk *PacketBuffer) Data() PacketData { } // Views returns the underlying storage of the whole packet. -func (pk *PacketBuffer) Views() []buffer.View { - // Optimization for outbound packets that headers are in pk.header. - useHeader := true - for i := range pk.headers { - if !canUseHeader(&pk.headers[i]) { - useHeader = false - break - } - } +func (pk *PacketBuffer) Views() []tcpipbuffer.View { + var views []tcpipbuffer.View + offset := pk.headerOffset() + pk.buf.SubApply(offset, int(pk.buf.Size())-offset, func(v []byte) { + views = append(views, v) + }) + return views +} - dataViews := pk.data.Views() - - var vs []buffer.View - if useHeader { - vs = make([]buffer.View, 0, 1+len(dataViews)) - vs = append(vs, pk.header.View()) - } else { - vs = make([]buffer.View, 0, len(pk.headers)+len(dataViews)) - for i := range pk.headers { - if v := pk.headers[i].buf; len(v) > 0 { - vs = append(vs, v) - } - } - } - return append(vs, dataViews...) +func (pk *PacketBuffer) headerOffset() int { + return pk.reserved - pk.pushed +} + +func (pk *PacketBuffer) headerOffsetOf(typ headerType) int { + return pk.reserved + pk.headers[typ].offset } -func canUseHeader(h *headerInfo) bool { - // h.offset will be negative if the header was pushed in to prependable - // portion, or doesn't matter when it's empty. - return len(h.buf) == 0 || h.offset < 0 +func (pk *PacketBuffer) dataOffset() int { + return pk.reserved + pk.consumed } -func (pk *PacketBuffer) push(typ headerType, size int) buffer.View { +func (pk *PacketBuffer) push(typ headerType, size int) tcpipbuffer.View { h := &pk.headers[typ] - if h.buf != nil { + if h.length > 0 { panic(fmt.Sprintf("push must not be called twice: type %s", typ)) } - h.buf = buffer.View(pk.header.Prepend(size)) - h.offset = -pk.header.UsedLength() - return h.buf + if pk.pushed+size > pk.reserved { + panic("not enough headroom reserved") + } + pk.pushed += size + h.offset = -pk.pushed + h.length = size + return pk.headerView(typ) } -func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consumed bool) { +func (pk *PacketBuffer) consume(typ headerType, size int) (v tcpipbuffer.View, consumed bool) { h := &pk.headers[typ] - if h.buf != nil { + if h.length > 0 { panic(fmt.Sprintf("consume must not be called twice: type %s", typ)) } - v, ok := pk.data.PullUp(size) + if pk.headerOffset()+pk.consumed+size > int(pk.buf.Size()) { + return nil, false + } + h.offset = pk.consumed + h.length = size + pk.consumed += size + return pk.headerView(typ), true +} + +func (pk *PacketBuffer) headerView(typ headerType) tcpipbuffer.View { + h := &pk.headers[typ] + if h.length == 0 { + return nil + } + v, ok := pk.buf.PullUp(pk.headerOffsetOf(typ), h.length) if !ok { - return + panic("PullUp failed") } - pk.data.TrimFront(size) - h.buf = v - return h.buf, true + return v } // Clone makes a shallow copy of pk. @@ -270,9 +289,11 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum func (pk *PacketBuffer) Clone() *PacketBuffer { return &PacketBuffer{ PacketBufferEntry: pk.PacketBufferEntry, - data: pk.data.Clone(nil), + buf: pk.buf, + reserved: pk.reserved, + pushed: pk.pushed, + consumed: pk.consumed, headers: pk.headers, - header: pk.header, Hash: pk.Hash, Owner: pk.Owner, GSOOptions: pk.GSOOptions, @@ -306,9 +327,11 @@ func (pk *PacketBuffer) Network() header.Network { // See PacketBuffer.Data for details about how a packet buffer holds an inbound // packet. func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { - newPk := NewPacketBuffer(PacketBufferOptions{ - Data: buffer.NewVectorisedView(pk.Size(), pk.Views()), - }) + newPk := &PacketBuffer{ + buf: pk.buf, + // Treat unfilled header portion as reserved. + reserved: pk.AvailableHeaderBytes(), + } // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to // maintain this flag in the packet. Currently conntrack needs this flag to // tell if a noop connection should be inserted at Input hook. Once conntrack @@ -322,15 +345,12 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { // headerInfo stores metadata about a header in a packet. type headerInfo struct { - // buf is the memorized slice for both prepended and consumed header. - // When header is prepended, buf serves as memorized value, which is a slice - // of pk.header. When header is consumed, buf is the slice pulled out from - // pk.Data, which is the only place to hold this header. - buf buffer.View - - // offset will be a negative number denoting the offset where this header is - // from the end of pk.header, if it is prepended. Otherwise, zero. + // offset is the offset of the header in pk.buf relative to + // pk.buf[pk.reserved]. See the PacketBuffer struct for details. offset int + + // length is the length of this header. + length int } // PacketHeader is a handle object to a header in the underlying packet. @@ -340,14 +360,14 @@ type PacketHeader struct { } // View returns the underlying storage of h. -func (h PacketHeader) View() buffer.View { - return h.pk.headers[h.typ].buf +func (h PacketHeader) View() tcpipbuffer.View { + return h.pk.headerView(h.typ) } // Push pushes size bytes in the front of its residing packet, and returns the // backing storage. Callers may only call one of Push or Consume once on each // header in the lifetime of the underlying packet. -func (h PacketHeader) Push(size int) buffer.View { +func (h PacketHeader) Push(size int) tcpipbuffer.View { return h.pk.push(h.typ, size) } @@ -356,7 +376,7 @@ func (h PacketHeader) Push(size int) buffer.View { // size, consumed will be false, and the state of h will not be affected. // Callers may only call one of Push or Consume once on each header in the // lifetime of the underlying packet. -func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) { +func (h PacketHeader) Consume(size int) (v tcpipbuffer.View, consumed bool) { return h.pk.consume(h.typ, size) } @@ -367,55 +387,84 @@ type PacketData struct { // PullUp returns a contiguous view of size bytes from the beginning of d. // Callers should not write to or keep the view for later use. -func (d PacketData) PullUp(size int) (buffer.View, bool) { - return d.pk.data.PullUp(size) +func (d PacketData) PullUp(size int) (tcpipbuffer.View, bool) { + return d.pk.buf.PullUp(d.pk.dataOffset(), size) } // DeleteFront removes count from the beginning of d. It panics if count > // d.Size(). All backing storage references after the front of the d are // invalidated. func (d PacketData) DeleteFront(count int) { - d.pk.data.TrimFront(count) + if !d.pk.buf.Remove(d.pk.dataOffset(), count) { + panic("count > d.Size()") + } } // CapLength reduces d to at most length bytes. func (d PacketData) CapLength(length int) { - d.pk.data.CapLength(length) + if length < 0 { + panic("length < 0") + } + if currLength := d.Size(); currLength > length { + trim := currLength - length + d.pk.buf.Remove(int(d.pk.buf.Size())-trim, trim) + } } // Views returns the underlying storage of d in a slice of Views. Caller should // not modify the returned slice. -func (d PacketData) Views() []buffer.View { - return d.pk.data.Views() +func (d PacketData) Views() []tcpipbuffer.View { + var views []tcpipbuffer.View + offset := d.pk.dataOffset() + d.pk.buf.SubApply(offset, int(d.pk.buf.Size())-offset, func(v []byte) { + views = append(views, v) + }) + return views } // AppendView appends v into d, taking the ownership of v. -func (d PacketData) AppendView(v buffer.View) { - d.pk.data.AppendView(v) +func (d PacketData) AppendView(v tcpipbuffer.View) { + d.pk.buf.AppendOwned(v) } -// ReadFromData moves at most count bytes from the beginning of srcData to the -// end of d and returns the number of bytes moved. -func (d PacketData) ReadFromData(srcData PacketData, count int) int { - return srcData.pk.data.ReadToVV(&d.pk.data, count) +// MergeFragment appends the data portion of frag to dst. It takes ownership of +// frag and frag should not be used again. +func MergeFragment(dst, frag *PacketBuffer) { + frag.buf.TrimFront(int64(frag.dataOffset())) + dst.buf.Merge(frag.buf) } // ReadFromVV moves at most count bytes from the beginning of srcVV to the end // of d and returns the number of bytes moved. -func (d PacketData) ReadFromVV(srcVV *buffer.VectorisedView, count int) int { - return srcVV.ReadToVV(&d.pk.data, count) +func (d PacketData) ReadFromVV(srcVV *tcpipbuffer.VectorisedView, count int) int { + done := 0 + for _, v := range srcVV.Views() { + if len(v) < count { + count -= len(v) + done += len(v) + d.pk.buf.AppendOwned(v) + } else { + v = v[:count] + count -= len(v) + done += len(v) + d.pk.buf.Append(v) + break + } + } + srcVV.TrimFront(done) + return done } // Size returns the number of bytes in the data payload of the packet. func (d PacketData) Size() int { - return d.pk.data.Size() + return int(d.pk.buf.Size()) - d.pk.dataOffset() } // AsRange returns a Range representing the current data payload of the packet. func (d PacketData) AsRange() Range { return Range{ pk: d.pk, - offset: d.pk.HeaderSize(), + offset: d.pk.dataOffset(), length: d.Size(), } } @@ -425,17 +474,12 @@ func (d PacketData) AsRange() Range { // // This method exists for compatibility between PacketBuffer and VectorisedView. // It may be removed later and should be used with care. -func (d PacketData) ExtractVV() buffer.VectorisedView { - return d.pk.data -} - -// Replace replaces the data portion of the packet with vv, taking the ownership -// of vv. -// -// This method exists for compatibility between PacketBuffer and VectorisedView. -// It may be removed later and should be used with care. -func (d PacketData) Replace(vv buffer.VectorisedView) { - d.pk.data = vv +func (d PacketData) ExtractVV() tcpipbuffer.VectorisedView { + var vv tcpipbuffer.VectorisedView + d.pk.buf.SubApply(d.pk.dataOffset(), d.pk.Size(), func(v []byte) { + vv.AppendView(v) + }) + return vv } // Range represents a contiguous subportion of a PacketBuffer. @@ -479,9 +523,9 @@ func (r Range) Capped(max int) Range { // AsView returns the backing storage of r if possible. It will allocate a new // View if r spans multiple pieces internally. Caller should not write to the // returned View in any way. -func (r Range) AsView() buffer.View { +func (r Range) AsView() tcpipbuffer.View { var allocated bool - var v buffer.View + var v tcpipbuffer.View r.iterate(func(b []byte) { if v == nil { // v has not been assigned, allowing first view to be returned. @@ -502,7 +546,7 @@ func (r Range) AsView() buffer.View { } // ToOwnedView returns a owned copy of data in r. -func (r Range) ToOwnedView() buffer.View { +func (r Range) ToOwnedView() tcpipbuffer.View { if r.length == 0 { return nil } @@ -523,63 +567,7 @@ func (r Range) Checksum() uint16 { // iterate calls fn for each piece in r. fn is always called with a non-empty // slice. func (r Range) iterate(fn func([]byte)) { - w := window{ - offset: r.offset, - length: r.length, - } - // Header portion. - for i := range r.pk.headers { - if b := w.process(r.pk.headers[i].buf); len(b) > 0 { - fn(b) - } - if w.isDone() { - break - } - } - // Data portion. - if !w.isDone() { - for _, v := range r.pk.data.Views() { - if b := w.process(v); len(b) > 0 { - fn(b) - } - if w.isDone() { - break - } - } - } -} - -// window represents contiguous region of byte stream. User would call process() -// to input bytes, and obtain a subslice that is inside the window. -type window struct { - offset int - length int -} - -// isDone returns true if the window has passed and further process() calls will -// always return an empty slice. This can be used to end processing early. -func (w *window) isDone() bool { - return w.length == 0 -} - -// process feeds b in and returns a subslice that is inside the window. The -// returned slice will be a subslice of b, and it does not keep b after method -// returns. This method may return an empty slice if nothing in b is inside the -// window. -func (w *window) process(b []byte) (inWindow []byte) { - if w.offset >= len(b) { - w.offset -= len(b) - return nil - } - if w.offset > 0 { - b = b[w.offset:] - w.offset = 0 - } - if w.length < len(b) { - b = b[:w.length] - } - w.length -= len(b) - return b + r.pk.buf.SubApply(r.offset, r.length, fn) } // PayloadSince returns packet payload starting from and including a particular @@ -587,21 +575,14 @@ func (w *window) process(b []byte) (inWindow []byte) { // // The returned View is owned by the caller - its backing buffer is separate // from the packet header's underlying packet buffer. -func PayloadSince(h PacketHeader) buffer.View { - size := h.pk.data.Size() - for _, hinfo := range h.pk.headers[h.typ:] { - size += len(hinfo.buf) +func PayloadSince(h PacketHeader) tcpipbuffer.View { + offset := h.pk.headerOffset() + for i := headerType(0); i < h.typ; i++ { + offset += h.pk.headers[i].length } - - v := make(buffer.View, 0, size) - - for _, hinfo := range h.pk.headers[h.typ:] { - v = append(v, hinfo.buf...) - } - - for _, view := range h.pk.data.Views() { - v = append(v, view...) - } - - return v + return Range{ + pk: h.pk, + offset: offset, + length: int(h.pk.buf.Size()) - offset, + }.ToOwnedView() } diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go index bd4eb4fed..1c1aeb950 100644 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -444,23 +444,8 @@ func TestPacketBufferData(t *testing.T) { checkData(t, pkt, []byte(tc.data+s)) }) - // ReadFromData/VV + // ReadFromVV for _, n := range []int{0, 1, 2, 7, 10, 14, 20} { - t.Run(fmt.Sprintf("ReadFromData%d", n), func(t *testing.T) { - s := "TO READ" - otherPkt := NewPacketBuffer(PacketBufferOptions{ - Data: vv(s, s), - }) - s += s - - pkt := tc.makePkt(t) - pkt.Data().ReadFromData(otherPkt.Data(), n) - - if n < len(s) { - s = s[:n] - } - checkData(t, pkt, []byte(tc.data+s)) - }) t.Run(fmt.Sprintf("ReadFromVV%d", n), func(t *testing.T) { s := "TO READ" srcVV := vv(s, s) @@ -487,16 +472,6 @@ func TestPacketBufferData(t *testing.T) { t.Errorf("pkt.Data().ExtractVV().ToOwnedView() = %q, want %q", got, want) } }) - - // Replace - t.Run("Replace", func(t *testing.T) { - s := "REPLACED" - - pkt := tc.makePkt(t) - pkt.Data().Replace(vv(s)) - - checkData(t, pkt, []byte(s)) - }) }) } } @@ -568,7 +543,7 @@ func checkViewEqual(t *testing.T, what string, got, want buffer.View) { func checkData(t *testing.T, pkt *PacketBuffer, want []byte) { t.Helper() if got := concatViews(pkt.Data().Views()...); !bytes.Equal(got, want) { - t.Errorf("pkt.Data().Views() = %x, want %x", got, want) + t.Errorf("pkt.Data().Views() = 0x%x, want 0x%x", got, want) } if got := pkt.Data().Size(); got != len(want) { t.Errorf("pkt.Data().Size() = %d, want %d", got, len(want)) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 3d9e1e286..483a960c8 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -492,9 +492,9 @@ func (s *Stack) Stats() tcpip.Stats { return s.stats } -// SetForwarding enables or disables packet forwarding between NICs for the -// passed protocol. -func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { +// SetForwardingDefaultAndAllNICs sets packet forwarding for all NICs for the +// passed protocol and sets the default setting for newly created NICs. +func (s *Stack) SetForwardingDefaultAndAllNICs(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { protocol, ok := s.networkProtocols[protocolNum] if !ok { return &tcpip.ErrUnknownProtocol{} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index d2c40cc43..ff88b1bd3 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -4220,8 +4220,8 @@ func TestFindRouteWithForwarding(t *testing.T) { t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err) } - if err := s.SetForwarding(test.netCfg.proto, test.forwardingEnabled); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) } s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) @@ -4275,8 +4275,8 @@ func TestFindRouteWithForwarding(t *testing.T) { // Disabling forwarding when the route is dependent on forwarding being // enabled should make the route invalid. - if err := s.SetForwarding(test.netCfg.proto, false); err != nil { - t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, false); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", test.netCfg.proto, err) } { err := send(r, data) diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index dbd279c94..42bc53328 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -475,11 +475,11 @@ func TestMulticastForwarding(t *testing.T) { t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) } - if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) } - if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) } s.SetRouteTable([]tcpip.Route{ diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 3df1bbd68..87d36e1dd 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -714,11 +714,11 @@ func TestExternalLoopbackTraffic(t *testing.T) { } if test.forwarding { - if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) } - if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) } } diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index 8fd9be32b..c8b9c9b5c 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -224,11 +224,11 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err) } - if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) + if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv4.ProtocolNumber, err) } - if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) + if err := routerStack.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv6.ProtocolNumber, err) } if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr); err != nil { diff --git a/runsc/cgroup/cgroup.go b/runsc/cgroup/cgroup.go index 335e46499..66a6a0f68 100644 --- a/runsc/cgroup/cgroup.go +++ b/runsc/cgroup/cgroup.go @@ -342,25 +342,31 @@ func new(pid, cgroupsPath string) (*Cgroup, error) { // already exists, it means that the caller has already provided a // pre-configured cgroups, and 'res' is ignored. func (c *Cgroup) Install(res *specs.LinuxResources) error { - log.Debugf("Creating cgroup %q", c.Name) + log.Debugf("Installing cgroup path %q", c.Name) - // The Cleanup object cleans up partially created cgroups when an error occurs. - // Errors occuring during cleanup itself are ignored. + // Clean up partially created cgroups on error. Errors during cleanup itself + // are ignored. clean := cleanup.Make(func() { _ = c.Uninstall() }) defer clean.Clean() - for key, ctrlr := range controllers { + // Controllers can be symlinks to a group of controllers (e.g. cpu,cpuacct). + // So first check what directories need to be created. Otherwise, when + // the directory for one of the controllers in a group is created, it will + // make it seem like the directory already existed and it's not owned by the + // other controllers in the group. + var missing []string + for key := range controllers { path := c.MakePath(key) - if _, err := os.Stat(path); err == nil { - // If cgroup has already been created; it has been setup by caller. Don't - // make any changes to configuration, just join when sandbox/gofer starts. - log.Debugf("Using pre-created cgroup %q", path) - continue + if _, err := os.Stat(path); err != nil { + missing = append(missing, key) + } else { + log.Debugf("Using pre-created cgroup %q: %q", key, path) } - - // Mark that cgroup resources are owned by me. - c.Own[key] = true - + } + for _, key := range missing { + ctrlr := controllers[key] + path := c.MakePath(key) + log.Debugf("Creating cgroup %q: %q", key, path) if err := os.MkdirAll(path, 0755); err != nil { if ctrlr.optional() && errors.Is(err, unix.EROFS) { if err := ctrlr.skip(res); err != nil { @@ -371,6 +377,9 @@ func (c *Cgroup) Install(res *specs.LinuxResources) error { } return err } + + // Only set controllers that were created by me. + c.Own[key] = true if err := ctrlr.set(res, path); err != nil { return err } diff --git a/runsc/cgroup/cgroup_test.go b/runsc/cgroup/cgroup_test.go index 02cadeef4..eba40621e 100644 --- a/runsc/cgroup/cgroup_test.go +++ b/runsc/cgroup/cgroup_test.go @@ -60,10 +60,10 @@ var dindMountinfo = ` func TestUninstallEnoent(t *testing.T) { c := Cgroup{ - // set a non-existent name + // Use a non-existent name. Name: "runsc-test-uninstall-656e6f656e740a", + Own: make(map[string]bool), } - c.Own = make(map[string]bool) for key := range controllers { c.Own[key] = true } diff --git a/tools/github/reviver/github.go b/tools/github/reviver/github.go index d9773500e..b360f0544 100644 --- a/tools/github/reviver/github.go +++ b/tools/github/reviver/github.go @@ -92,7 +92,7 @@ func (b *GitHubBugger) Activate(todo *Todo) (bool, error) { fmt.Fprintln(&comment, "There are TODOs still referencing this issue:") for _, l := range todo.Locations { fmt.Fprintf(&comment, - "1. [%s:%d](https://github.com/%s/%s/blob/HEAD/%s#%d): %s\n", + "1. [%s:%d](https://github.com/%s/%s/blob/HEAD/%s#L%d): %s\n", l.File, l.Line, b.owner, b.repo, l.File, l.Line, l.Comment) } fmt.Fprintf(&comment, |