diff options
Diffstat (limited to 'pkg')
36 files changed, 1149 insertions, 417 deletions
diff --git a/pkg/buffer/BUILD b/pkg/buffer/BUILD new file mode 100644 index 000000000..a77a3beea --- /dev/null +++ b/pkg/buffer/BUILD @@ -0,0 +1,39 @@ +load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "buffer_list", + out = "buffer_list.go", + package = "buffer", + prefix = "buffer", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*Buffer", + "Linker": "*Buffer", + }, +) + +go_library( + name = "buffer", + srcs = [ + "buffer.go", + "buffer_list.go", + "safemem.go", + "view.go", + "view_unsafe.go", + ], + visibility = ["//visibility:public"], + deps = [ + "//pkg/log", + "//pkg/safemem", + ], +) + +go_test( + name = "buffer_test", + size = "small", + srcs = ["view_test.go"], + library = ":buffer", +) diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go new file mode 100644 index 000000000..d5f64609b --- /dev/null +++ b/pkg/buffer/buffer.go @@ -0,0 +1,67 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package buffer provides the implementation of a buffer view. +package buffer + +import ( + "sync" +) + +const bufferSize = 8144 // See below. + +// Buffer encapsulates a queueable byte buffer. +// +// Note that the total size is slightly less than two pages. This is done +// intentionally to ensure that the buffer object aligns with runtime +// internals. We have no hard size or alignment requirements. This two page +// size will effectively minimize internal fragmentation, but still have a +// large enough chunk to limit excessive segmentation. +// +// +stateify savable +type Buffer struct { + data [bufferSize]byte + read int + write int + bufferEntry +} + +// Reset resets internal data. +// +// This must be called before use. +func (b *Buffer) Reset() { + b.read = 0 + b.write = 0 +} + +// Empty indicates the buffer is empty. +// +// This indicates there is no data left to read. +func (b *Buffer) Empty() bool { + return b.read == b.write +} + +// Full indicates the buffer is full. +// +// This indicates there is no capacity left to write. +func (b *Buffer) Full() bool { + return b.write == len(b.data) +} + +// bufferPool is a pool for buffers. +var bufferPool = sync.Pool{ + New: func() interface{} { + return new(Buffer) + }, +} diff --git a/pkg/buffer/safemem.go b/pkg/buffer/safemem.go new file mode 100644 index 000000000..071aaa488 --- /dev/null +++ b/pkg/buffer/safemem.go @@ -0,0 +1,131 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "io" + + "gvisor.dev/gvisor/pkg/safemem" +) + +// WriteBlock returns this buffer as a write Block. +func (b *Buffer) WriteBlock() safemem.Block { + return safemem.BlockFromSafeSlice(b.data[b.write:]) +} + +// ReadBlock returns this buffer as a read Block. +func (b *Buffer) ReadBlock() safemem.Block { + return safemem.BlockFromSafeSlice(b.data[b.read:b.write]) +} + +// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. +// +// This will advance the write index. +func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { + need := int(srcs.NumBytes()) + if need == 0 { + return 0, nil + } + + var ( + dst safemem.BlockSeq + blocks []safemem.Block + ) + + // Need at least one buffer. + firstBuf := v.data.Back() + if firstBuf == nil { + firstBuf = bufferPool.Get().(*Buffer) + v.data.PushBack(firstBuf) + } + + // Does the last block have sufficient capacity alone? + if l := len(firstBuf.data) - firstBuf.write; l >= need { + dst = safemem.BlockSeqOf(firstBuf.WriteBlock()) + } else { + // Append blocks until sufficient. + need -= l + blocks = append(blocks, firstBuf.WriteBlock()) + for need > 0 { + emptyBuf := bufferPool.Get().(*Buffer) + v.data.PushBack(emptyBuf) + need -= len(emptyBuf.data) // Full block. + blocks = append(blocks, emptyBuf.WriteBlock()) + } + dst = safemem.BlockSeqFromSlice(blocks) + } + + // Perform the copy. + n, err := safemem.CopySeq(dst, srcs) + v.size += int64(n) + + // Update all indices. + for left := int(n); left > 0; firstBuf = firstBuf.Next() { + if l := len(firstBuf.data) - firstBuf.write; left >= l { + firstBuf.write += l // Whole block. + left -= l + } else { + firstBuf.write += left // Partial block. + left = 0 + } + } + + return n, err +} + +// ReadToBlocks implements safemem.Reader.ReadToBlocks. +// +// This will not advance the read index; the caller should follow +// this call with a call to TrimFront in order to remove the read +// data from the buffer. This is done to support pipe sematics. +func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { + need := int(dsts.NumBytes()) + if need == 0 { + return 0, nil + } + + var ( + src safemem.BlockSeq + blocks []safemem.Block + ) + + firstBuf := v.data.Front() + if firstBuf == nil { + return 0, io.EOF + } + + // Is all the data in a single block? + if l := firstBuf.write - firstBuf.read; l >= need { + src = safemem.BlockSeqOf(firstBuf.ReadBlock()) + } else { + // Build a list of all the buffers. + need -= l + blocks = append(blocks, firstBuf.ReadBlock()) + for buf := firstBuf.Next(); buf != nil && need > 0; buf = buf.Next() { + need -= buf.write - buf.read + blocks = append(blocks, buf.ReadBlock()) + } + src = safemem.BlockSeqFromSlice(blocks) + } + + // Perform the copy. + n, err := safemem.CopySeq(dsts, src) + + // See above: we would normally advance the read index here, but we + // don't do that in order to support pipe semantics. We rely on a + // separate call to TrimFront() in this case. + + return n, err +} diff --git a/pkg/buffer/view.go b/pkg/buffer/view.go new file mode 100644 index 000000000..00fc11e9c --- /dev/null +++ b/pkg/buffer/view.go @@ -0,0 +1,382 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "fmt" + "io" +) + +// View is a non-linear buffer. +// +// All methods are thread compatible. +// +// +stateify savable +type View struct { + data bufferList + size int64 +} + +// TrimFront removes the first count bytes from the buffer. +func (v *View) TrimFront(count int64) { + if count >= v.size { + v.advanceRead(v.size) + } else { + v.advanceRead(count) + } +} + +// Read implements io.Reader.Read. +// +// Note that reading does not advance the read index. This must be done +// manually using TrimFront or other methods. +func (v *View) Read(p []byte) (int, error) { + return v.ReadAt(p, 0) +} + +// ReadAt implements io.ReaderAt.ReadAt. +func (v *View) ReadAt(p []byte, offset int64) (int, error) { + var ( + skipped int64 + done int64 + ) + for buf := v.data.Front(); buf != nil && done < int64(len(p)); buf = buf.Next() { + needToSkip := int(offset - skipped) + if l := buf.write - buf.read; l <= needToSkip { + skipped += int64(l) + continue + } + + // Actually read data. + n := copy(p[done:], buf.data[buf.read+needToSkip:buf.write]) + skipped += int64(needToSkip) + done += int64(n) + } + if int(done) < len(p) { + return int(done), io.EOF + } + return int(done), nil +} + +// Write implements io.Writer.Write. +func (v *View) Write(p []byte) (int, error) { + v.Append(p) // Does not fail. + return len(p), nil +} + +// advanceRead advances the view's read index. +// +// Precondition: there must be sufficient bytes in the buffer. +func (v *View) advanceRead(count int64) { + for buf := v.data.Front(); buf != nil && count > 0; { + l := int64(buf.write - buf.read) + if l > count { + // There is still data for reading. + buf.read += int(count) + v.size -= count + count = 0 + break + } + + // Read from this buffer. + buf.read += int(l) + count -= l + v.size -= l + + // When all data has been read from a buffer, we push + // it into the empty buffer pool for reuse. + oldBuf := buf + buf = buf.Next() // Iterate. + v.data.Remove(oldBuf) + oldBuf.Reset() + bufferPool.Put(oldBuf) + } + if count > 0 { + panic(fmt.Sprintf("advanceRead still has %d bytes remaining", count)) + } +} + +// Truncate truncates the view to the given bytes. +func (v *View) Truncate(length int64) { + if length < 0 || length >= v.size { + return // Nothing to do. + } + for buf := v.data.Back(); buf != nil && v.size > length; buf = v.data.Back() { + l := int64(buf.write - buf.read) // Local bytes. + switch { + case v.size-l >= length: + // Drop the buffer completely; see above. + v.data.Remove(buf) + v.size -= l + buf.Reset() + bufferPool.Put(buf) + + case v.size > length && v.size-l < length: + // Just truncate the buffer locally. + delta := (length - (v.size - l)) + buf.write = buf.read + int(delta) + v.size = length + + default: + // Should never happen. + panic("invalid buffer during truncation") + } + } + v.size = length // Save the new size. +} + +// Grow grows the given view to the number of bytes. If zero +// is true, all these bytes will be zero. If zero is false, +// then this is the caller's responsibility. +// +// Precondition: length must be >= 0. +func (v *View) Grow(length int64, zero bool) { + if length < 0 { + panic("negative length provided") + } + for v.size < length { + buf := v.data.Back() + + // Is there at least one buffer? + if buf == nil || buf.Full() { + buf = bufferPool.Get().(*Buffer) + v.data.PushBack(buf) + } + + // Write up to length bytes. + l := len(buf.data) - buf.write + if int64(l) > length-v.size { + l = int(length - v.size) + } + + // Zero the written section; note that this pattern is + // specifically recognized and optimized by the compiler. + if zero { + for i := buf.write; i < buf.write+l; i++ { + buf.data[i] = 0 + } + } + + // Advance the index. + buf.write += l + v.size += int64(l) + } +} + +// Prepend prepends the given data. +func (v *View) Prepend(data []byte) { + // Is there any space in the first buffer? + if buf := v.data.Front(); buf != nil && buf.read > 0 { + // Fill up before the first write. + avail := buf.read + copy(buf.data[0:], data[len(data)-avail:]) + data = data[:len(data)-avail] + v.size += int64(avail) + } + + for len(data) > 0 { + // Do we need an empty buffer? + buf := bufferPool.Get().(*Buffer) + v.data.PushFront(buf) + + // The buffer is empty; copy last chunk. + start := len(data) - len(buf.data) + if start < 0 { + start = 0 // Everything. + } + + // We have to put the data at the end of the current + // buffer in order to ensure that the next prepend will + // correctly fill up the beginning of this buffer. + bStart := len(buf.data) - len(data[start:]) + n := copy(buf.data[bStart:], data[start:]) + buf.read = bStart + buf.write = len(buf.data) + data = data[:start] + v.size += int64(n) + } +} + +// Append appends the given data. +func (v *View) Append(data []byte) { + for done := 0; done < len(data); { + buf := v.data.Back() + + // Find the first empty buffer. + if buf == nil || buf.Full() { + buf = bufferPool.Get().(*Buffer) + v.data.PushBack(buf) + } + + // Copy in to the given buffer. + n := copy(buf.data[buf.write:], data[done:]) + done += n + buf.write += n + v.size += int64(n) + } +} + +// Flatten returns a flattened copy of this data. +// +// This method should not be used in any performance-sensitive paths. It may +// allocate a fresh byte slice sufficiently large to contain all the data in +// the buffer. +// +// N.B. Tee data still belongs to this view, as if there is a single buffer +// present, then it will be returned directly. This should be used for +// temporary use only, and a reference to the given slice should not be held. +func (v *View) Flatten() []byte { + if buf := v.data.Front(); buf.Next() == nil { + return buf.data[buf.read:buf.write] // Only one buffer. + } + data := make([]byte, 0, v.size) // Need to flatten. + for buf := v.data.Front(); buf != nil; buf = buf.Next() { + // Copy to the allocated slice. + data = append(data, buf.data[buf.read:buf.write]...) + } + return data +} + +// Size indicates the total amount of data available in this view. +func (v *View) Size() (sz int64) { + sz = v.size // Pre-calculated. + return sz +} + +// Copy makes a strict copy of this view. +func (v *View) Copy() (other View) { + for buf := v.data.Front(); buf != nil; buf = buf.Next() { + other.Append(buf.data[buf.read:buf.write]) + } + return other +} + +// Apply applies the given function across all valid data. +func (v *View) Apply(fn func([]byte)) { + for buf := v.data.Front(); buf != nil; buf = buf.Next() { + if l := int64(buf.write - buf.read); l > 0 { + fn(buf.data[buf.read:buf.write]) + } + } +} + +// Merge merges the provided View with this one. +// +// The other view will be empty after this operation. +func (v *View) Merge(other *View) { + // Copy over all buffers. + for buf := other.data.Front(); buf != nil && !buf.Empty(); buf = other.data.Front() { + other.data.Remove(buf) + v.data.PushBack(buf) + } + + // Adjust sizes. + v.size += other.size + other.size = 0 +} + +// WriteFromReader writes to the buffer from an io.Reader. +func (v *View) WriteFromReader(r io.Reader, count int64) (int64, error) { + var ( + done int64 + n int + err error + ) + for done < count { + buf := v.data.Back() + + // Find the first empty buffer. + if buf == nil || buf.Full() { + buf = bufferPool.Get().(*Buffer) + v.data.PushBack(buf) + } + + // Is this less than the minimum batch? + if len(buf.data[buf.write:]) < minBatch && (count-done) >= int64(minBatch) { + tmp := make([]byte, minBatch) + n, err = r.Read(tmp) + v.Write(tmp[:n]) + done += int64(n) + if err != nil { + break + } + continue + } + + // Limit the read, if necessary. + end := len(buf.data) + if int64(end-buf.write) > (count - done) { + end = buf.write + int(count-done) + } + + // Pass the relevant portion of the buffer. + n, err = r.Read(buf.data[buf.write:end]) + buf.write += n + done += int64(n) + v.size += int64(n) + if err == io.EOF { + err = nil // Short write allowed. + break + } else if err != nil { + break + } + } + return done, err +} + +// ReadToWriter reads from the buffer into an io.Writer. +// +// N.B. This does not consume the bytes read. TrimFront should +// be called appropriately after this call in order to do so. +func (v *View) ReadToWriter(w io.Writer, count int64) (int64, error) { + var ( + done int64 + n int + err error + ) + offset := 0 // Spill-over for batching. + for buf := v.data.Front(); buf != nil && done < count; buf = buf.Next() { + l := buf.write - buf.read - offset + + // Is this less than the minimum batch? + if l < minBatch && (count-done) >= int64(minBatch) && (v.size-done) >= int64(minBatch) { + tmp := make([]byte, minBatch) + n, err = v.ReadAt(tmp, done) + w.Write(tmp[:n]) + done += int64(n) + offset = n - l // Reset below. + if err != nil { + break + } + continue + } + + // Limit the write if necessary. + if int64(l) >= (count - done) { + l = int(count - done) + } + + // Perform the actual write. + n, err = w.Write(buf.data[buf.read+offset : buf.read+offset+l]) + done += int64(n) + if err != nil { + break + } + + // Reset spill-over. + offset = 0 + } + return done, err +} diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go new file mode 100644 index 000000000..37e652f16 --- /dev/null +++ b/pkg/buffer/view_test.go @@ -0,0 +1,233 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "bytes" + "strings" + "testing" +) + +func TestView(t *testing.T) { + testCases := []struct { + name string + input string + output string + ops []func(*View) + }{ + // Prepend. + { + name: "prepend", + input: "world", + ops: []func(*View){ + func(v *View) { + v.Prepend([]byte("hello ")) + }, + }, + output: "hello world", + }, + { + name: "prepend fill", + input: strings.Repeat("1", bufferSize-1), + ops: []func(*View){ + func(v *View) { + v.Prepend([]byte("0")) + }, + }, + output: "0" + strings.Repeat("1", bufferSize-1), + }, + { + name: "prepend overflow", + input: strings.Repeat("1", bufferSize), + ops: []func(*View){ + func(v *View) { + v.Prepend([]byte("0")) + }, + }, + output: "0" + strings.Repeat("1", bufferSize), + }, + { + name: "prepend multiple buffers", + input: strings.Repeat("1", bufferSize-1), + ops: []func(*View){ + func(v *View) { + v.Prepend([]byte(strings.Repeat("0", bufferSize*3))) + }, + }, + output: strings.Repeat("0", bufferSize*3) + strings.Repeat("1", bufferSize-1), + }, + + // Append. + { + name: "append", + input: "hello", + ops: []func(*View){ + func(v *View) { + v.Append([]byte(" world")) + }, + }, + output: "hello world", + }, + { + name: "append fill", + input: strings.Repeat("1", bufferSize-1), + ops: []func(*View){ + func(v *View) { + v.Append([]byte("0")) + }, + }, + output: strings.Repeat("1", bufferSize-1) + "0", + }, + { + name: "append overflow", + input: strings.Repeat("1", bufferSize), + ops: []func(*View){ + func(v *View) { + v.Append([]byte("0")) + }, + }, + output: strings.Repeat("1", bufferSize) + "0", + }, + { + name: "append multiple buffers", + input: strings.Repeat("1", bufferSize-1), + ops: []func(*View){ + func(v *View) { + v.Append([]byte(strings.Repeat("0", bufferSize*3))) + }, + }, + output: strings.Repeat("1", bufferSize-1) + strings.Repeat("0", bufferSize*3), + }, + + // Truncate. + { + name: "truncate", + input: "hello world", + ops: []func(*View){ + func(v *View) { + v.Truncate(5) + }, + }, + output: "hello", + }, + { + name: "truncate multiple buffers", + input: strings.Repeat("1", bufferSize*2), + ops: []func(*View){ + func(v *View) { + v.Truncate(bufferSize*2 - 1) + }, + }, + output: strings.Repeat("1", bufferSize*2-1), + }, + { + name: "truncate multiple buffers to one buffer", + input: strings.Repeat("1", bufferSize*2), + ops: []func(*View){ + func(v *View) { + v.Truncate(5) + }, + }, + output: "11111", + }, + + // TrimFront. + { + name: "trim", + input: "hello world", + ops: []func(*View){ + func(v *View) { + v.TrimFront(6) + }, + }, + output: "world", + }, + { + name: "trim multiple buffers", + input: strings.Repeat("1", bufferSize*2), + ops: []func(*View){ + func(v *View) { + v.TrimFront(1) + }, + }, + output: strings.Repeat("1", bufferSize*2-1), + }, + { + name: "trim multiple buffers to one buffer", + input: strings.Repeat("1", bufferSize*2), + ops: []func(*View){ + func(v *View) { + v.TrimFront(bufferSize*2 - 1) + }, + }, + output: "1", + }, + + // Grow. + { + name: "grow", + input: "hello world", + ops: []func(*View){ + func(v *View) { + v.Grow(1, true) + }, + }, + output: "hello world", + }, + { + name: "grow from zero", + ops: []func(*View){ + func(v *View) { + v.Grow(1024, true) + }, + }, + output: strings.Repeat("\x00", 1024), + }, + { + name: "grow from non-zero", + input: strings.Repeat("1", bufferSize), + ops: []func(*View){ + func(v *View) { + v.Grow(bufferSize*2, true) + }, + }, + output: strings.Repeat("1", bufferSize) + strings.Repeat("\x00", bufferSize), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Construct the new view. + var view View + view.Append([]byte(tc.input)) + + // Run all operations. + for _, op := range tc.ops { + op(&view) + } + + // Flatten and validate. + out := view.Flatten() + if !bytes.Equal([]byte(tc.output), out) { + t.Errorf("expected %q, got %q", tc.output, string(out)) + } + + // Ensure the size is correct. + if len(out) != int(view.Size()) { + t.Errorf("size is wrong: expected %d, got %d", len(out), view.Size()) + } + }) + } +} diff --git a/pkg/sentry/kernel/pipe/buffer_test.go b/pkg/buffer/view_unsafe.go index 4d54b8b8f..d1ef39b26 100644 --- a/pkg/sentry/kernel/pipe/buffer_test.go +++ b/pkg/buffer/view_unsafe.go @@ -1,4 +1,4 @@ -// Copyright 2019 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,21 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pipe +package buffer import ( - "testing" "unsafe" - - "gvisor.dev/gvisor/pkg/usermem" ) -func TestBufferSize(t *testing.T) { - bufferSize := unsafe.Sizeof(buffer{}) - if bufferSize < usermem.PageSize { - t.Errorf("buffer is less than a page") - } - if bufferSize > (2 * usermem.PageSize) { - t.Errorf("buffer is greater than two pages") - } -} +// minBatch is the smallest Read or Write operation that the +// WriteFromReader and ReadToWriter functions will use. +// +// This is defined as the size of a native pointer. +const minBatch = int(unsafe.Sizeof(uintptr(0))) diff --git a/pkg/ilist/list.go b/pkg/ilist/list.go index 019caadca..f3a609b57 100644 --- a/pkg/ilist/list.go +++ b/pkg/ilist/list.go @@ -88,8 +88,9 @@ func (l *List) Back() Element { // PushFront inserts the element e at the front of list l. func (l *List) PushFront(e Element) { - ElementMapper{}.linkerFor(e).SetNext(l.head) - ElementMapper{}.linkerFor(e).SetPrev(nil) + linker := ElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) if l.head != nil { ElementMapper{}.linkerFor(l.head).SetPrev(e) @@ -102,8 +103,9 @@ func (l *List) PushFront(e Element) { // PushBack inserts the element e at the back of list l. func (l *List) PushBack(e Element) { - ElementMapper{}.linkerFor(e).SetNext(nil) - ElementMapper{}.linkerFor(e).SetPrev(l.tail) + linker := ElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) if l.tail != nil { ElementMapper{}.linkerFor(l.tail).SetNext(e) @@ -132,10 +134,14 @@ func (l *List) PushBackList(m *List) { // InsertAfter inserts e after b. func (l *List) InsertAfter(b, e Element) { - a := ElementMapper{}.linkerFor(b).Next() - ElementMapper{}.linkerFor(e).SetNext(a) - ElementMapper{}.linkerFor(e).SetPrev(b) - ElementMapper{}.linkerFor(b).SetNext(e) + bLinker := ElementMapper{}.linkerFor(b) + eLinker := ElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) if a != nil { ElementMapper{}.linkerFor(a).SetPrev(e) @@ -146,10 +152,13 @@ func (l *List) InsertAfter(b, e Element) { // InsertBefore inserts e before a. func (l *List) InsertBefore(a, e Element) { - b := ElementMapper{}.linkerFor(a).Prev() - ElementMapper{}.linkerFor(e).SetNext(a) - ElementMapper{}.linkerFor(e).SetPrev(b) - ElementMapper{}.linkerFor(a).SetPrev(e) + aLinker := ElementMapper{}.linkerFor(a) + eLinker := ElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) if b != nil { ElementMapper{}.linkerFor(b).SetNext(e) diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go index 372b650b9..885115ae2 100644 --- a/pkg/sentry/arch/arch_arm64.go +++ b/pkg/sentry/arch/arch_arm64.go @@ -53,6 +53,11 @@ const ( preferredPIELoadAddr usermem.Addr = maxAddr64 / 6 * 5 ) +var ( + // CPUIDInstruction doesn't exist on ARM64. + CPUIDInstruction = []byte{} +) + // These constants are selected as heuristics to help make the Platform's // potentially limited address space conform as closely to Linux as possible. const ( diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD index 9b6bb26d0..9379a4d7b 100644 --- a/pkg/sentry/fs/dev/BUILD +++ b/pkg/sentry/fs/dev/BUILD @@ -26,6 +26,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/ramfs", "//pkg/sentry/fs/tmpfs", + "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/memmap", "//pkg/sentry/mm", diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go index 7e66c29b0..acbd401a0 100644 --- a/pkg/sentry/fs/dev/dev.go +++ b/pkg/sentry/fs/dev/dev.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs" + "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/usermem" ) @@ -124,10 +125,12 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode { "ptmx": newSymlink(ctx, "pts/ptmx", msrc), "tty": newCharacterDevice(ctx, newTTYDevice(ctx, fs.RootOwner, 0666), msrc, ttyDevMajor, ttyDevMinor), + } - "net": newDirectory(ctx, map[string]*fs.Inode{ + if isNetTunSupported(inet.StackFromContext(ctx)) { + contents["net"] = newDirectory(ctx, map[string]*fs.Inode{ "tun": newCharacterDevice(ctx, newNetTunDevice(ctx, fs.RootOwner, 0666), msrc, netTunDevMajor, netTunDevMinor), - }, msrc), + }, msrc) } iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)) diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go index 755644488..dc7ad075a 100644 --- a/pkg/sentry/fs/dev/net_tun.go +++ b/pkg/sentry/fs/dev/net_tun.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/syserror" @@ -168,3 +169,9 @@ func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.Eve func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) { fops.device.EventUnregister(e) } + +// isNetTunSupported returns whether /dev/net/tun device is supported for s. +func isNetTunSupported(s inet.Stack) bool { + _, ok := s.(*netstack.Stack) + return ok +} diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go index daecc4ffe..1922ff08c 100644 --- a/pkg/sentry/fs/fsutil/inode.go +++ b/pkg/sentry/fs/fsutil/inode.go @@ -259,8 +259,8 @@ func (i *InodeSimpleExtendedAttributes) ListXattr(context.Context, *fs.Inode, ui // RemoveXattr implements fs.InodeOperations.RemoveXattr. func (i *InodeSimpleExtendedAttributes) RemoveXattr(_ context.Context, _ *fs.Inode, name string) error { - i.mu.RLock() - defer i.mu.RUnlock() + i.mu.Lock() + defer i.mu.Unlock() if _, ok := i.xattrs[name]; ok { delete(i.xattrs, name) return nil diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index d00850e25..c4a8f0b38 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -1045,13 +1045,13 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool // using the old file descriptor, preventing us from safely // closing it. We could handle this by invalidating existing // memmap.Translations, but this is expensive. Instead, use - // dup2() to make the old file descriptor refer to the new file + // dup3 to make the old file descriptor refer to the new file // description, then close the new file descriptor (which is no // longer needed). Racing callers may use the old or new file // description, but this doesn't matter since they refer to the // same file (unless d.fs.opts.overlayfsStaleRead is true, // which we handle separately). - if err := syscall.Dup2(int(h.fd), int(d.handle.fd)); err != nil { + if err := syscall.Dup3(int(h.fd), int(d.handle.fd), 0); err != nil { d.handleMu.Unlock() ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, d.handle.fd, err) h.close(ctx) diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index 54c1031a7..e95209661 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -361,8 +361,15 @@ func (rw *dentryReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, erro rw.d.handleMu.RLock() if (rw.d.handle.fd >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct { n, err := rw.d.handle.writeFromBlocksAt(rw.ctx, srcs, rw.off) - rw.d.handleMu.RUnlock() rw.off += n + rw.d.dataMu.Lock() + if rw.off > rw.d.size { + atomic.StoreUint64(&rw.d.size, rw.off) + // The remote file's size will implicitly be extended to the correct + // value when we write back to it. + } + rw.d.dataMu.Unlock() + rw.d.handleMu.RUnlock() return n, err } diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 8b76750e9..1d627564f 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -755,6 +755,8 @@ func (ctx *createProcessContext) Value(key interface{}) interface{} { return ctx.k.GlobalInit().Leader().MountNamespaceVFS2() case fs.CtxDirentCacheLimiter: return ctx.k.DirentCacheLimiter + case inet.CtxStack: + return ctx.k.RootNetworkNamespace().Stack() case ktime.CtxRealtimeClock: return ctx.k.RealtimeClock() case limits.CtxLimits: @@ -1481,6 +1483,8 @@ func (ctx supervisorContext) Value(key interface{}) interface{} { return ctx.k.GlobalInit().Leader().MountNamespaceVFS2() case fs.CtxDirentCacheLimiter: return ctx.k.DirentCacheLimiter + case inet.CtxStack: + return ctx.k.RootNetworkNamespace().Stack() case ktime.CtxRealtimeClock: return ctx.k.RealtimeClock() case limits.CtxLimits: diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 4c049d5b4..f29dc0472 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -1,25 +1,10 @@ load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) -go_template_instance( - name = "buffer_list", - out = "buffer_list.go", - package = "pipe", - prefix = "buffer", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*buffer", - "Linker": "*buffer", - }, -) - go_library( name = "pipe", srcs = [ - "buffer.go", - "buffer_list.go", "device.go", "node.go", "pipe.go", @@ -33,8 +18,8 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/amutex", + "//pkg/buffer", "//pkg/context", - "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/device", "//pkg/sentry/fs", @@ -51,7 +36,6 @@ go_test( name = "pipe_test", size = "small", srcs = [ - "buffer_test.go", "node_test.go", "pipe_test.go", ], diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go deleted file mode 100644 index fe3be5dbd..000000000 --- a/pkg/sentry/kernel/pipe/buffer.go +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pipe - -import ( - "io" - - "gvisor.dev/gvisor/pkg/safemem" - "gvisor.dev/gvisor/pkg/sync" -) - -// buffer encapsulates a queueable byte buffer. -// -// Note that the total size is slightly less than two pages. This -// is done intentionally to ensure that the buffer object aligns -// with runtime internals. We have no hard size or alignment -// requirements. This two page size will effectively minimize -// internal fragmentation, but still have a large enough chunk -// to limit excessive segmentation. -// -// +stateify savable -type buffer struct { - data [8144]byte - read int - write int - bufferEntry -} - -// Reset resets internal data. -// -// This must be called before use. -func (b *buffer) Reset() { - b.read = 0 - b.write = 0 -} - -// Empty indicates the buffer is empty. -// -// This indicates there is no data left to read. -func (b *buffer) Empty() bool { - return b.read == b.write -} - -// Full indicates the buffer is full. -// -// This indicates there is no capacity left to write. -func (b *buffer) Full() bool { - return b.write == len(b.data) -} - -// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. -func (b *buffer) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { - dst := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.write:])) - n, err := safemem.CopySeq(dst, srcs) - b.write += int(n) - return n, err -} - -// WriteFromReader writes to the buffer from an io.Reader. -func (b *buffer) WriteFromReader(r io.Reader, count int64) (int64, error) { - dst := b.data[b.write:] - if count < int64(len(dst)) { - dst = b.data[b.write:][:count] - } - n, err := r.Read(dst) - b.write += n - return int64(n), err -} - -// ReadToBlocks implements safemem.Reader.ReadToBlocks. -func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { - src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.read:b.write])) - n, err := safemem.CopySeq(dsts, src) - b.read += int(n) - return n, err -} - -// ReadToWriter reads from the buffer into an io.Writer. -func (b *buffer) ReadToWriter(w io.Writer, count int64, dup bool) (int64, error) { - src := b.data[b.read:b.write] - if count < int64(len(src)) { - src = b.data[b.read:][:count] - } - n, err := w.Write(src) - if !dup { - b.read += n - } - return int64(n), err -} - -// bufferPool is a pool for buffers. -var bufferPool = sync.Pool{ - New: func() interface{} { - return new(buffer) - }, -} - -// newBuffer grabs a new buffer from the pool. -func newBuffer() *buffer { - b := bufferPool.Get().(*buffer) - b.Reset() - return b -} diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 08410283f..725e9db7d 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -20,6 +20,7 @@ import ( "sync/atomic" "syscall" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sync" @@ -70,10 +71,10 @@ type Pipe struct { // mu protects all pipe internal state below. mu sync.Mutex `state:"nosave"` - // data is the buffer queue of pipe contents. + // view is the underlying set of buffers. // // This is protected by mu. - data bufferList + view buffer.View // max is the maximum size of the pipe in bytes. When this max has been // reached, writers will get EWOULDBLOCK. @@ -81,11 +82,6 @@ type Pipe struct { // This is protected by mu. max int64 - // size is the current size of the pipe in bytes. - // - // This is protected by mu. - size int64 - // hadWriter indicates if this pipe ever had a writer. Note that this // does not necessarily indicate there is *currently* a writer, just // that there has been a writer at some point since the pipe was @@ -196,7 +192,7 @@ type readOps struct { limit func(int64) // read performs the actual read operation. - read func(*buffer) (int64, error) + read func(*buffer.View) (int64, error) } // read reads data from the pipe into dst and returns the number of bytes @@ -213,7 +209,7 @@ func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) { defer p.mu.Unlock() // Is the pipe empty? - if p.size == 0 { + if p.view.Size() == 0 { if !p.HasWriters() { // There are no writers, return EOF. return 0, nil @@ -222,71 +218,13 @@ func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) { } // Limit how much we consume. - if ops.left() > p.size { - ops.limit(p.size) + if ops.left() > p.view.Size() { + ops.limit(p.view.Size()) } - done := int64(0) - for ops.left() > 0 { - // Pop the first buffer. - first := p.data.Front() - if first == nil { - break - } - - // Copy user data. - n, err := ops.read(first) - done += int64(n) - p.size -= n - - // Empty buffer? - if first.Empty() { - // Push to the free list. - p.data.Remove(first) - bufferPool.Put(first) - } - - // Handle errors. - if err != nil { - return done, err - } - } - - return done, nil -} - -// dup duplicates all data from this pipe into the given writer. -// -// There is no blocking behavior implemented here. The writer may propagate -// some blocking error. All the writes must be complete writes. -func (p *Pipe) dup(ctx context.Context, ops readOps) (int64, error) { - p.mu.Lock() - defer p.mu.Unlock() - - // Is the pipe empty? - if p.size == 0 { - if !p.HasWriters() { - // See above. - return 0, nil - } - return 0, syserror.ErrWouldBlock - } - - // Limit how much we consume. - if ops.left() > p.size { - ops.limit(p.size) - } - - done := int64(0) - for buf := p.data.Front(); buf != nil; buf = buf.Next() { - n, err := ops.read(buf) - done += n - if err != nil { - return done, err - } - } - - return done, nil + // Copy user data; the read op is responsible for trimming. + done, err := ops.read(&p.view) + return done, err } type writeOps struct { @@ -297,7 +235,7 @@ type writeOps struct { limit func(int64) // write should write to the provided buffer. - write func(*buffer) (int64, error) + write func(*buffer.View) (int64, error) } // write writes data from sv into the pipe and returns the number of bytes @@ -317,33 +255,19 @@ func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) { // POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be // atomic, but requires no atomicity for writes larger than this. wanted := ops.left() - if avail := p.max - p.size; wanted > avail { + if avail := p.max - p.view.Size(); wanted > avail { if wanted <= p.atomicIOBytes { return 0, syserror.ErrWouldBlock } ops.limit(avail) } - done := int64(0) - for ops.left() > 0 { - // Need a new buffer? - last := p.data.Back() - if last == nil || last.Full() { - // Add a new buffer to the data list. - last = newBuffer() - p.data.PushBack(last) - } - - // Copy user data. - n, err := ops.write(last) - done += int64(n) - p.size += n - - // Handle errors. - if err != nil { - return done, err - } + // Copy user data. + done, err := ops.write(&p.view) + if err != nil { + return done, err } + if wanted > done { // Partial write due to full pipe. return done, syserror.ErrWouldBlock @@ -396,7 +320,7 @@ func (p *Pipe) HasWriters() bool { // Precondition: mu must be held. func (p *Pipe) rReadinessLocked() waiter.EventMask { ready := waiter.EventMask(0) - if p.HasReaders() && p.data.Front() != nil { + if p.HasReaders() && p.view.Size() != 0 { ready |= waiter.EventIn } if !p.HasWriters() && p.hadWriter { @@ -422,7 +346,7 @@ func (p *Pipe) rReadiness() waiter.EventMask { // Precondition: mu must be held. func (p *Pipe) wReadinessLocked() waiter.EventMask { ready := waiter.EventMask(0) - if p.HasWriters() && p.size < p.max { + if p.HasWriters() && p.view.Size() < p.max { ready |= waiter.EventOut } if !p.HasReaders() { @@ -451,7 +375,7 @@ func (p *Pipe) rwReadiness() waiter.EventMask { func (p *Pipe) queued() int64 { p.mu.Lock() defer p.mu.Unlock() - return p.size + return p.view.Size() } // FifoSize implements fs.FifoSizer.FifoSize. @@ -474,7 +398,7 @@ func (p *Pipe) SetFifoSize(size int64) (int64, error) { } p.mu.Lock() defer p.mu.Unlock() - if size < p.size { + if size < p.view.Size() { return 0, syserror.EBUSY } p.max = size diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index 80158239e..5a1d4fd57 100644 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/amutex" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sync" @@ -49,9 +50,10 @@ func (p *Pipe) Read(ctx context.Context, dst usermem.IOSequence) (int64, error) limit: func(l int64) { dst = dst.TakeFirst64(l) }, - read: func(buf *buffer) (int64, error) { - n, err := dst.CopyOutFrom(ctx, buf) + read: func(view *buffer.View) (int64, error) { + n, err := dst.CopyOutFrom(ctx, view) dst = dst.DropFirst64(n) + view.TrimFront(n) return n, err }, }) @@ -70,16 +72,15 @@ func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool) limit: func(l int64) { count = l }, - read: func(buf *buffer) (int64, error) { - n, err := buf.ReadToWriter(w, count, dup) + read: func(view *buffer.View) (int64, error) { + n, err := view.ReadToWriter(w, count) + if !dup { + view.TrimFront(n) + } count -= n return n, err }, } - if dup { - // There is no notification for dup operations. - return p.dup(ctx, ops) - } n, err := p.read(ctx, ops) if n > 0 { p.Notify(waiter.EventOut) @@ -96,8 +97,8 @@ func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error) limit: func(l int64) { src = src.TakeFirst64(l) }, - write: func(buf *buffer) (int64, error) { - n, err := src.CopyInTo(ctx, buf) + write: func(view *buffer.View) (int64, error) { + n, err := src.CopyInTo(ctx, view) src = src.DropFirst64(n) return n, err }, @@ -117,8 +118,8 @@ func (p *Pipe) ReadFrom(ctx context.Context, r io.Reader, count int64) (int64, e limit: func(l int64) { count = l }, - write: func(buf *buffer) (int64, error) { - n, err := buf.WriteFromReader(r, count) + write: func(view *buffer.View) (int64, error) { + n, err := view.WriteFromReader(r, count) count -= n return n, err }, diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go index 5568c91bc..799cbcd93 100644 --- a/pkg/sentry/kernel/task_run.go +++ b/pkg/sentry/kernel/task_run.go @@ -126,13 +126,39 @@ func (t *Task) doStop() { } } +func (*runApp) handleCPUIDInstruction(t *Task) error { + if len(arch.CPUIDInstruction) == 0 { + // CPUID emulation isn't supported, but this code can be + // executed, because the ptrace platform returns + // ErrContextSignalCPUID on page faults too. Look at + // pkg/sentry/platform/ptrace/ptrace.go:context.Switch for more + // details. + return platform.ErrContextSignal + } + // Is this a CPUID instruction? + region := trace.StartRegion(t.traceContext, cpuidRegion) + expected := arch.CPUIDInstruction[:] + found := make([]byte, len(expected)) + _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found) + if err == nil && bytes.Equal(expected, found) { + // Skip the cpuid instruction. + t.Arch().CPUIDEmulate(t) + t.Arch().SetIP(t.Arch().IP() + uintptr(len(expected))) + region.End() + + return nil + } + region.End() // Not an actual CPUID, but required copy-in. + return platform.ErrContextSignal +} + // The runApp state checks for interrupts before executing untrusted // application code. // // +stateify savable type runApp struct{} -func (*runApp) execute(t *Task) taskRunState { +func (app *runApp) execute(t *Task) taskRunState { if t.interrupted() { // Checkpointing instructs tasks to stop by sending an interrupt, so we // must check for stops before entering runInterrupt (instead of @@ -237,21 +263,10 @@ func (*runApp) execute(t *Task) taskRunState { return (*runApp)(nil) case platform.ErrContextSignalCPUID: - // Is this a CPUID instruction? - region := trace.StartRegion(t.traceContext, cpuidRegion) - expected := arch.CPUIDInstruction[:] - found := make([]byte, len(expected)) - _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found) - if err == nil && bytes.Equal(expected, found) { - // Skip the cpuid instruction. - t.Arch().CPUIDEmulate(t) - t.Arch().SetIP(t.Arch().IP() + uintptr(len(expected))) - region.End() - + if err := app.handleCPUIDInstruction(t); err == nil { // Resume execution. return (*runApp)(nil) } - region.End() // Not an actual CPUID, but required copy-in. // The instruction at the given RIP was not a CPUID, and we // fallthrough to the default signal deliver behavior below. diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go index e99798c56..cd74945e7 100644 --- a/pkg/sentry/platform/ptrace/subprocess_amd64.go +++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go @@ -21,6 +21,7 @@ import ( "strings" "syscall" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/seccomp" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -183,13 +184,76 @@ func enableCpuidFault() { // appendArchSeccompRules append architecture specific seccomp rules when creating BPF program. // Ref attachedThread() for more detail. -func appendArchSeccompRules(rules []seccomp.RuleSet) []seccomp.RuleSet { - return append(rules, seccomp.RuleSet{ - Rules: seccomp.SyscallRules{ - syscall.SYS_ARCH_PRCTL: []seccomp.Rule{ - {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)}, +func appendArchSeccompRules(rules []seccomp.RuleSet, defaultAction linux.BPFAction) []seccomp.RuleSet { + rules = append(rules, + // Rules for trapping vsyscall access. + seccomp.RuleSet{ + Rules: seccomp.SyscallRules{ + syscall.SYS_GETTIMEOFDAY: {}, + syscall.SYS_TIME: {}, + unix.SYS_GETCPU: {}, // SYS_GETCPU was not defined in package syscall on amd64. }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }) + Action: linux.SECCOMP_RET_TRAP, + Vsyscall: true, + }) + if defaultAction != linux.SECCOMP_RET_ALLOW { + rules = append(rules, + seccomp.RuleSet{ + Rules: seccomp.SyscallRules{ + syscall.SYS_ARCH_PRCTL: []seccomp.Rule{ + {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)}, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }) + } + return rules +} + +// probeSeccomp returns true iff seccomp is run after ptrace notifications, +// which is generally the case for kernel version >= 4.8. This check is dynamic +// because kernels have be backported behavior. +// +// See createStub for more information. +// +// Precondition: the runtime OS thread must be locked. +func probeSeccomp() bool { + // Create a completely new, destroyable process. + t, err := attachedThread(0, linux.SECCOMP_RET_ERRNO) + if err != nil { + panic(fmt.Sprintf("seccomp probe failed: %v", err)) + } + defer t.destroy() + + // Set registers to the yield system call. This call is not allowed + // by the filters specified in the attachThread function. + regs := createSyscallRegs(&t.initRegs, syscall.SYS_SCHED_YIELD) + if err := t.setRegs(®s); err != nil { + panic(fmt.Sprintf("ptrace set regs failed: %v", err)) + } + + for { + // Attempt an emulation. + if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, unix.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 { + panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno)) + } + + sig := t.wait(stopped) + if sig == (syscallEvent | syscall.SIGTRAP) { + // Did the seccomp errno hook already run? This would + // indicate that seccomp is first in line and we're + // less than 4.8. + if err := t.getRegs(®s); err != nil { + panic(fmt.Sprintf("ptrace get-regs failed: %v", err)) + } + if _, err := syscallReturnValue(®s); err == nil { + // The seccomp errno mode ran first, and reset + // the error in the registers. + return false + } + // The seccomp hook did not run yet, and therefore it + // is safe to use RET_KILL mode for dispatched calls. + return true + } + } } diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go index 7b975137f..7f5c393f0 100644 --- a/pkg/sentry/platform/ptrace/subprocess_arm64.go +++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go @@ -160,6 +160,15 @@ func enableCpuidFault() { // appendArchSeccompRules append architecture specific seccomp rules when creating BPF program. // Ref attachedThread() for more detail. -func appendArchSeccompRules(rules []seccomp.RuleSet) []seccomp.RuleSet { +func appendArchSeccompRules(rules []seccomp.RuleSet, defaultAction linux.BPFAction) []seccomp.RuleSet { return rules } + +// probeSeccomp returns true if seccomp is run after ptrace notifications, +// which is generally the case for kernel version >= 4.8. +// +// On arm64, the support of PTRACE_SYSEMU was added in the 5.3 kernel, so +// probeSeccomp can always return true. +func probeSeccomp() bool { + return true +} diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go index 74968dfdf..2ce528601 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux.go @@ -20,7 +20,6 @@ import ( "fmt" "syscall" - "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/procid" @@ -30,54 +29,6 @@ import ( const syscallEvent syscall.Signal = 0x80 -// probeSeccomp returns true iff seccomp is run after ptrace notifications, -// which is generally the case for kernel version >= 4.8. This check is dynamic -// because kernels have be backported behavior. -// -// See createStub for more information. -// -// Precondition: the runtime OS thread must be locked. -func probeSeccomp() bool { - // Create a completely new, destroyable process. - t, err := attachedThread(0, linux.SECCOMP_RET_ERRNO) - if err != nil { - panic(fmt.Sprintf("seccomp probe failed: %v", err)) - } - defer t.destroy() - - // Set registers to the yield system call. This call is not allowed - // by the filters specified in the attachThread function. - regs := createSyscallRegs(&t.initRegs, syscall.SYS_SCHED_YIELD) - if err := t.setRegs(®s); err != nil { - panic(fmt.Sprintf("ptrace set regs failed: %v", err)) - } - - for { - // Attempt an emulation. - if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, unix.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 { - panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno)) - } - - sig := t.wait(stopped) - if sig == (syscallEvent | syscall.SIGTRAP) { - // Did the seccomp errno hook already run? This would - // indicate that seccomp is first in line and we're - // less than 4.8. - if err := t.getRegs(®s); err != nil { - panic(fmt.Sprintf("ptrace get-regs failed: %v", err)) - } - if _, err := syscallReturnValue(®s); err == nil { - // The seccomp errno mode ran first, and reset - // the error in the registers. - return false - } - // The seccomp hook did not run yet, and therefore it - // is safe to use RET_KILL mode for dispatched calls. - return true - } - } -} - // createStub creates a fresh stub processes. // // Precondition: the runtime OS thread must be locked. @@ -123,18 +74,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro // stub and all its children. This is used to create child stubs // (below), so we must include the ability to fork, but otherwise lock // down available calls only to what is needed. - rules := []seccomp.RuleSet{ - // Rules for trapping vsyscall access. - { - Rules: seccomp.SyscallRules{ - syscall.SYS_GETTIMEOFDAY: {}, - syscall.SYS_TIME: {}, - unix.SYS_GETCPU: {}, // SYS_GETCPU was not defined in package syscall on amd64. - }, - Action: linux.SECCOMP_RET_TRAP, - Vsyscall: true, - }, - } + rules := []seccomp.RuleSet{} if defaultAction != linux.SECCOMP_RET_ALLOW { rules = append(rules, seccomp.RuleSet{ Rules: seccomp.SyscallRules{ @@ -173,9 +113,8 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro }, Action: linux.SECCOMP_RET_ALLOW, }) - - rules = appendArchSeccompRules(rules) } + rules = appendArchSeccompRules(rules, defaultAction) instrs, err := seccomp.BuildProgram(rules, defaultAction) if err != nil { return nil, err diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD index 4f2406ce3..581841555 100644 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ b/pkg/sentry/platform/ring0/pagetables/BUILD @@ -80,7 +80,7 @@ go_library( "pagetables_amd64.go", "pagetables_arm64.go", "pagetables_x86.go", - "pcids_x86.go", + "pcids.go", "walker_amd64.go", "walker_arm64.go", "walker_empty.go", diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go b/pkg/sentry/platform/ring0/pagetables/pcids.go index e199bae18..9206030bf 100644 --- a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go +++ b/pkg/sentry/platform/ring0/pagetables/pcids.go @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build i386 amd64 - package pagetables import ( diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 48c268bfa..13a9a60b4 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -712,6 +712,10 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo // Bind implements the linux syscall bind(2) for sockets backed by // tcpip.Endpoint. func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { + if len(sockaddr) < 2 { + return syserr.ErrInvalidArgument + } + family := usermem.ByteOrder.Uint16(sockaddr) var addr tcpip.FullAddress @@ -2663,7 +2667,9 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, } // Add bytes removed from the endpoint but not yet sent to the caller. + s.readMu.Lock() v += len(s.readView) + s.readMu.Unlock() if v > math.MaxInt32 { v = math.MaxInt32 diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index bfb2fac26..f7d6009a0 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -221,7 +221,7 @@ func (w *Watchdog) waitForStart() { return } var buf bytes.Buffer - buf.WriteString("Watchdog.Start() not called within %s:\n") + buf.WriteString(fmt.Sprintf("Watchdog.Start() not called within %s", w.StartupTimeout)) w.doAction(w.StartupTimeoutAction, false, &buf) } @@ -325,7 +325,7 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo func (w *Watchdog) reportStuckWatchdog() { var buf bytes.Buffer - buf.WriteString("Watchdog goroutine is stuck:\n") + buf.WriteString("Watchdog goroutine is stuck:") w.doAction(w.TaskTimeoutAction, false, &buf) } @@ -359,7 +359,7 @@ func (w *Watchdog) doAction(action Action, skipStack bool, msg *bytes.Buffer) { case <-metricsEmitted: case <-time.After(1 * time.Second): } - panic(fmt.Sprintf("Stack for running G's are skipped while panicking.\n%s", msg.String())) + panic(fmt.Sprintf("%s\nStack for running G's are skipped while panicking.", msg.String())) default: panic(fmt.Sprintf("Unknown watchdog action %v", action)) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 46d3a6646..3e6196aee 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -451,7 +451,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs)) for _, r := range primaryAddrs { // If r is not valid for outgoing connections, it is not a valid endpoint. - if !r.isValidForOutgoing() { + if !r.isValidForOutgoingRLocked() { continue } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index ebb6c5e3b..13354d884 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -551,11 +551,13 @@ type TransportEndpointInfo struct { RegisterNICID tcpip.NICID } -// AddrNetProto unwraps the specified address if it is a V4-mapped V6 address -// and returns the network protocol number to be used to communicate with the -// specified address. It returns an error if the passed address is incompatible -// with the receiver. -func (e *TransportEndpointInfo) AddrNetProto(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { +// AddrNetProtoLocked unwraps the specified address if it is a V4-mapped V6 +// address and returns the network protocol number to be used to communicate +// with the specified address. It returns an error if the passed address is +// incompatible with the receiver. +// +// Preconditon: the parent endpoint mu must be held while calling this method. +func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { netProto := e.NetProto switch len(addr.Addr) { case header.IPv4AddressSize: diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 426da1ee6..2a396e9bc 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -291,15 +291,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c nicID = e.BindNICID } - toCopy := *to - to = &toCopy - netProto, err := e.checkV4Mapped(to) + dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { return 0, nil, err } - // Find the enpoint. - r, err := e.stack.FindRoute(nicID, e.BindAddr, to.Addr, netProto, false /* multicastLoop */) + // Find the endpoint. + r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */) if err != nil { return 0, nil, err } @@ -480,13 +478,14 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err }) } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, false /* v6only */) +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */) if err != nil { - return 0, err + return tcpip.FullAddress{}, 0, err } - *addr = unwrapped - return netProto, nil + return unwrapped, netProto, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -517,7 +516,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } @@ -630,7 +629,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index cd247f3e1..ae4f3f3a9 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -295,6 +295,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { h.state = handshakeSynRcvd h.ep.mu.Lock() ttl := h.ep.ttl + amss := h.ep.amss h.ep.setEndpointState(StateSynRecv) h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ @@ -307,7 +308,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // permits SACK. This is not explicitly defined in the RFC but // this is the behaviour implemented by Linux. SACKPermitted: rcvSynOpts.SACKPermitted, - MSS: h.ep.amss, + MSS: amss, } if ttl == 0 { ttl = s.route.DefaultTTL() @@ -356,6 +357,10 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { return tcpip.ErrInvalidEndpointState } + h.ep.mu.RLock() + amss := h.ep.amss + h.ep.mu.RUnlock() + h.resetState() synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, @@ -363,7 +368,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { TSVal: h.ep.timestamp(), TSEcr: h.ep.recentTimestamp(), SACKPermitted: h.ep.sackPermitted, - MSS: h.ep.amss, + MSS: amss, } h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil @@ -530,6 +535,7 @@ func (h *handshake) execute() *tcpip.Error { // Send the initial SYN segment and loop until the handshake is // completed. + h.ep.mu.Lock() h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) synOpts := header.TCPSynOptions{ @@ -540,6 +546,7 @@ func (h *handshake) execute() *tcpip.Error { SACKPermitted: bool(sackEnabled), MSS: h.ep.amss, } + h.ep.mu.Unlock() // Execute is also called in a listen context so we want to make sure we // only send the TS/SACK option when we received the TS/SACK in the diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 9e72730bd..40cc664c0 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -959,15 +959,18 @@ func (e *endpoint) initialReceiveWindow() int { // ModerateRecvBuf adjusts the receive buffer and the advertised window // based on the number of bytes copied to user space. func (e *endpoint) ModerateRecvBuf(copied int) { + e.mu.RLock() e.rcvListMu.Lock() if e.rcvAutoParams.disabled { e.rcvListMu.Unlock() + e.mu.RUnlock() return } now := time.Now() if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt { e.rcvAutoParams.copied += copied e.rcvListMu.Unlock() + e.mu.RUnlock() return } prevRTTCopied := e.rcvAutoParams.copied + copied @@ -1008,7 +1011,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvBufSize = rcvWnd availAfter := e.receiveBufferAvailableLocked() mask := uint32(notifyReceiveWindowChanged) - if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { mask |= notifyNonZeroReceiveWindow } e.notifyProtocolGoroutine(mask) @@ -1023,6 +1026,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvAutoParams.measureTime = now e.rcvAutoParams.copied = 0 e.rcvListMu.Unlock() + e.mu.RUnlock() } // IPTables implements tcpip.Endpoint.IPTables. @@ -1052,7 +1056,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, v, err := e.readLocked() e.rcvListMu.Unlock() - e.mu.RUnlock() if err == tcpip.ErrClosedForReceive { @@ -1085,7 +1088,7 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { // enough buffer space, to either fit an aMSS or half a receive buffer // (whichever smaller), then notify the protocol goroutine to send a // window update. - if crossed, above := e.windowCrossedACKThreshold(len(v)); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } @@ -1303,9 +1306,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro return num, tcpip.ControlMessages{}, nil } -// windowCrossedACKThreshold checks if the receive window to be announced now -// would be under aMSS or under half receive buffer, whichever smaller. This is -// useful as a receive side silly window syndrome prevention mechanism. If +// windowCrossedACKThresholdLocked checks if the receive window to be announced +// now would be under aMSS or under half receive buffer, whichever smaller. This +// is useful as a receive side silly window syndrome prevention mechanism. If // window grows to reasonable value, we should send ACK to the sender to inform // the rx space is now large. We also want ensure a series of small read()'s // won't trigger a flood of spurious tiny ACK's. @@ -1316,7 +1319,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // crossed will be true if the window size crossed the ACK threshold. // above will be true if the new window is >= ACK threshold and false // otherwise. -func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) (crossed bool, above bool) { +// +// Precondition: e.mu and e.rcvListMu must be held. +func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) { newAvail := e.receiveBufferAvailableLocked() oldAvail := newAvail - deltaBefore if oldAvail < 0 { @@ -1379,6 +1384,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { mask := uint32(notifyReceiveWindowChanged) + e.mu.RLock() e.rcvListMu.Lock() // Make sure the receive buffer size allows us to send a @@ -1405,11 +1411,11 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // Immediately send an ACK to uncork the sender silly window // syndrome prevetion, when our available space grows above aMSS // or half receive buffer, whichever smaller. - if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { mask |= notifyNonZeroReceiveWindow } e.rcvListMu.Unlock() - + e.mu.RUnlock() e.notifyProtocolGoroutine(mask) return nil @@ -1868,13 +1874,14 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only) +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) if err != nil { - return 0, err + return tcpip.FullAddress{}, 0, err } - *addr = unwrapped - return netProto, nil + return unwrapped, netProto, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -1904,7 +1911,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc connectingAddr := addr.Addr - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } @@ -2270,7 +2277,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } e.BindAddr = addr.Addr - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } @@ -2414,13 +2421,14 @@ func (e *endpoint) updateSndBufferUsage(v int) { // to be read, or when the connection is closed for receiving (in which case // s will be nil). func (e *endpoint) readyToRead(s *segment) { + e.mu.RLock() e.rcvListMu.Lock() if s != nil { s.incRef() e.rcvBufUsed += s.data.Size() // Increase counter if the receive window falls down below MSS // or half receive buffer size, whichever smaller. - if crossed, above := e.windowCrossedACKThreshold(-s.data.Size()); crossed && !above { + if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above { e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() } e.rcvList.PushBack(s) @@ -2428,7 +2436,7 @@ func (e *endpoint) readyToRead(s *segment) { e.rcvClosed = true } e.rcvListMu.Unlock() - + e.mu.RUnlock() e.waiterQueue.Notify(waiter.EventIn) } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 958f03ac1..d80aff1b6 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -195,6 +195,10 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum for i := first; i < len(r.pendingRcvdSegments); i++ { r.pendingRcvdSegments[i].decRef() + // Note that slice truncation does not allow garbage collection of + // truncated items, thus truncated items must be set to nil to avoid + // memory leaks. + r.pendingRcvdSegments[i] = nil } r.pendingRcvdSegments = r.pendingRcvdSegments[:first] diff --git a/pkg/tcpip/transport/tcp/segment_heap.go b/pkg/tcpip/transport/tcp/segment_heap.go index 9fd061d7d..e28f213ba 100644 --- a/pkg/tcpip/transport/tcp/segment_heap.go +++ b/pkg/tcpip/transport/tcp/segment_heap.go @@ -41,6 +41,7 @@ func (h *segmentHeap) Pop() interface{} { old := *h n := len(old) x := old[n-1] + old[n-1] = nil *h = old[:n-1] return x } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 1c6a600b8..0af4514e1 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -443,19 +443,19 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, tcpip.ErrBroadcastDisabled } - netProto, err := e.checkV4Mapped(to) + dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { return 0, nil, err } - r, _, err := e.connectRoute(nicID, *to, netProto) + r, _, err := e.connectRoute(nicID, dst, netProto) if err != nil { return 0, nil, err } defer r.Release() route = &r - dstPort = to.Port + dstPort = dst.Port } if route.IsResolutionRequired() { @@ -566,7 +566,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { defer e.mu.Unlock() fa := tcpip.FullAddress{Addr: v.InterfaceAddr} - netProto, err := e.checkV4Mapped(&fa) + fa, netProto, err := e.checkV4MappedLocked(fa) if err != nil { return err } @@ -927,13 +927,14 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u return nil } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only) +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) if err != nil { - return 0, err + return tcpip.FullAddress{}, 0, err } - *addr = unwrapped - return netProto, nil + return unwrapped, netProto, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -981,10 +982,6 @@ func (e *endpoint) Disconnect() *tcpip.Error { // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - netProto, err := e.checkV4Mapped(&addr) - if err != nil { - return err - } if addr.Port == 0 { // We don't support connecting to port zero. return tcpip.ErrInvalidEndpointState @@ -1012,6 +1009,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } + addr, netProto, err := e.checkV4MappedLocked(addr) + if err != nil { + return err + } + r, nicID, err := e.connectRoute(nicID, addr, netProto) if err != nil { return err @@ -1139,7 +1141,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 43fb047ed..466bd9381 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -69,6 +69,9 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.mu.Lock() + defer e.mu.Unlock() + e.stack = s for _, m := range e.multicastMemberships { |