summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/buffer/BUILD39
-rw-r--r--pkg/buffer/buffer.go67
-rw-r--r--pkg/buffer/safemem.go131
-rw-r--r--pkg/buffer/view.go382
-rw-r--r--pkg/buffer/view_test.go233
-rw-r--r--pkg/buffer/view_unsafe.go (renamed from pkg/sentry/kernel/pipe/buffer_test.go)21
-rw-r--r--pkg/sentry/arch/arch_arm64.go5
-rw-r--r--pkg/sentry/fs/dev/BUILD1
-rw-r--r--pkg/sentry/fs/dev/dev.go7
-rw-r--r--pkg/sentry/fs/dev/net_tun.go7
-rw-r--r--pkg/sentry/fs/fsutil/inode.go4
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go4
-rw-r--r--pkg/sentry/kernel/kernel.go4
-rw-r--r--pkg/sentry/kernel/pipe/BUILD18
-rw-r--r--pkg/sentry/kernel/pipe/buffer.go115
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go118
-rw-r--r--pkg/sentry/kernel/pipe/pipe_util.go25
-rw-r--r--pkg/sentry/kernel/task_run.go41
-rw-r--r--pkg/sentry/platform/ring0/pagetables/BUILD2
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pcids.go (renamed from pkg/sentry/platform/ring0/pagetables/pcids_x86.go)2
-rw-r--r--pkg/sentry/socket/netstack/netstack.go6
-rw-r--r--pkg/sentry/watchdog/watchdog.go6
-rw-r--r--pkg/tcpip/stack/nic.go2
-rw-r--r--pkg/tcpip/stack/stack.go12
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go23
-rw-r--r--pkg/tcpip/transport/tcp/connect.go11
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go44
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go30
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go3
29 files changed, 1031 insertions, 332 deletions
diff --git a/pkg/buffer/BUILD b/pkg/buffer/BUILD
new file mode 100644
index 000000000..a77a3beea
--- /dev/null
+++ b/pkg/buffer/BUILD
@@ -0,0 +1,39 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "buffer_list",
+ out = "buffer_list.go",
+ package = "buffer",
+ prefix = "buffer",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*Buffer",
+ "Linker": "*Buffer",
+ },
+)
+
+go_library(
+ name = "buffer",
+ srcs = [
+ "buffer.go",
+ "buffer_list.go",
+ "safemem.go",
+ "view.go",
+ "view_unsafe.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/log",
+ "//pkg/safemem",
+ ],
+)
+
+go_test(
+ name = "buffer_test",
+ size = "small",
+ srcs = ["view_test.go"],
+ library = ":buffer",
+)
diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go
new file mode 100644
index 000000000..d5f64609b
--- /dev/null
+++ b/pkg/buffer/buffer.go
@@ -0,0 +1,67 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package buffer provides the implementation of a buffer view.
+package buffer
+
+import (
+ "sync"
+)
+
+const bufferSize = 8144 // See below.
+
+// Buffer encapsulates a queueable byte buffer.
+//
+// Note that the total size is slightly less than two pages. This is done
+// intentionally to ensure that the buffer object aligns with runtime
+// internals. We have no hard size or alignment requirements. This two page
+// size will effectively minimize internal fragmentation, but still have a
+// large enough chunk to limit excessive segmentation.
+//
+// +stateify savable
+type Buffer struct {
+ data [bufferSize]byte
+ read int
+ write int
+ bufferEntry
+}
+
+// Reset resets internal data.
+//
+// This must be called before use.
+func (b *Buffer) Reset() {
+ b.read = 0
+ b.write = 0
+}
+
+// Empty indicates the buffer is empty.
+//
+// This indicates there is no data left to read.
+func (b *Buffer) Empty() bool {
+ return b.read == b.write
+}
+
+// Full indicates the buffer is full.
+//
+// This indicates there is no capacity left to write.
+func (b *Buffer) Full() bool {
+ return b.write == len(b.data)
+}
+
+// bufferPool is a pool for buffers.
+var bufferPool = sync.Pool{
+ New: func() interface{} {
+ return new(Buffer)
+ },
+}
diff --git a/pkg/buffer/safemem.go b/pkg/buffer/safemem.go
new file mode 100644
index 000000000..071aaa488
--- /dev/null
+++ b/pkg/buffer/safemem.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/safemem"
+)
+
+// WriteBlock returns this buffer as a write Block.
+func (b *Buffer) WriteBlock() safemem.Block {
+ return safemem.BlockFromSafeSlice(b.data[b.write:])
+}
+
+// ReadBlock returns this buffer as a read Block.
+func (b *Buffer) ReadBlock() safemem.Block {
+ return safemem.BlockFromSafeSlice(b.data[b.read:b.write])
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+//
+// This will advance the write index.
+func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ need := int(srcs.NumBytes())
+ if need == 0 {
+ return 0, nil
+ }
+
+ var (
+ dst safemem.BlockSeq
+ blocks []safemem.Block
+ )
+
+ // Need at least one buffer.
+ firstBuf := v.data.Back()
+ if firstBuf == nil {
+ firstBuf = bufferPool.Get().(*Buffer)
+ v.data.PushBack(firstBuf)
+ }
+
+ // Does the last block have sufficient capacity alone?
+ if l := len(firstBuf.data) - firstBuf.write; l >= need {
+ dst = safemem.BlockSeqOf(firstBuf.WriteBlock())
+ } else {
+ // Append blocks until sufficient.
+ need -= l
+ blocks = append(blocks, firstBuf.WriteBlock())
+ for need > 0 {
+ emptyBuf := bufferPool.Get().(*Buffer)
+ v.data.PushBack(emptyBuf)
+ need -= len(emptyBuf.data) // Full block.
+ blocks = append(blocks, emptyBuf.WriteBlock())
+ }
+ dst = safemem.BlockSeqFromSlice(blocks)
+ }
+
+ // Perform the copy.
+ n, err := safemem.CopySeq(dst, srcs)
+ v.size += int64(n)
+
+ // Update all indices.
+ for left := int(n); left > 0; firstBuf = firstBuf.Next() {
+ if l := len(firstBuf.data) - firstBuf.write; left >= l {
+ firstBuf.write += l // Whole block.
+ left -= l
+ } else {
+ firstBuf.write += left // Partial block.
+ left = 0
+ }
+ }
+
+ return n, err
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+//
+// This will not advance the read index; the caller should follow
+// this call with a call to TrimFront in order to remove the read
+// data from the buffer. This is done to support pipe sematics.
+func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ need := int(dsts.NumBytes())
+ if need == 0 {
+ return 0, nil
+ }
+
+ var (
+ src safemem.BlockSeq
+ blocks []safemem.Block
+ )
+
+ firstBuf := v.data.Front()
+ if firstBuf == nil {
+ return 0, io.EOF
+ }
+
+ // Is all the data in a single block?
+ if l := firstBuf.write - firstBuf.read; l >= need {
+ src = safemem.BlockSeqOf(firstBuf.ReadBlock())
+ } else {
+ // Build a list of all the buffers.
+ need -= l
+ blocks = append(blocks, firstBuf.ReadBlock())
+ for buf := firstBuf.Next(); buf != nil && need > 0; buf = buf.Next() {
+ need -= buf.write - buf.read
+ blocks = append(blocks, buf.ReadBlock())
+ }
+ src = safemem.BlockSeqFromSlice(blocks)
+ }
+
+ // Perform the copy.
+ n, err := safemem.CopySeq(dsts, src)
+
+ // See above: we would normally advance the read index here, but we
+ // don't do that in order to support pipe semantics. We rely on a
+ // separate call to TrimFront() in this case.
+
+ return n, err
+}
diff --git a/pkg/buffer/view.go b/pkg/buffer/view.go
new file mode 100644
index 000000000..00fc11e9c
--- /dev/null
+++ b/pkg/buffer/view.go
@@ -0,0 +1,382 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "fmt"
+ "io"
+)
+
+// View is a non-linear buffer.
+//
+// All methods are thread compatible.
+//
+// +stateify savable
+type View struct {
+ data bufferList
+ size int64
+}
+
+// TrimFront removes the first count bytes from the buffer.
+func (v *View) TrimFront(count int64) {
+ if count >= v.size {
+ v.advanceRead(v.size)
+ } else {
+ v.advanceRead(count)
+ }
+}
+
+// Read implements io.Reader.Read.
+//
+// Note that reading does not advance the read index. This must be done
+// manually using TrimFront or other methods.
+func (v *View) Read(p []byte) (int, error) {
+ return v.ReadAt(p, 0)
+}
+
+// ReadAt implements io.ReaderAt.ReadAt.
+func (v *View) ReadAt(p []byte, offset int64) (int, error) {
+ var (
+ skipped int64
+ done int64
+ )
+ for buf := v.data.Front(); buf != nil && done < int64(len(p)); buf = buf.Next() {
+ needToSkip := int(offset - skipped)
+ if l := buf.write - buf.read; l <= needToSkip {
+ skipped += int64(l)
+ continue
+ }
+
+ // Actually read data.
+ n := copy(p[done:], buf.data[buf.read+needToSkip:buf.write])
+ skipped += int64(needToSkip)
+ done += int64(n)
+ }
+ if int(done) < len(p) {
+ return int(done), io.EOF
+ }
+ return int(done), nil
+}
+
+// Write implements io.Writer.Write.
+func (v *View) Write(p []byte) (int, error) {
+ v.Append(p) // Does not fail.
+ return len(p), nil
+}
+
+// advanceRead advances the view's read index.
+//
+// Precondition: there must be sufficient bytes in the buffer.
+func (v *View) advanceRead(count int64) {
+ for buf := v.data.Front(); buf != nil && count > 0; {
+ l := int64(buf.write - buf.read)
+ if l > count {
+ // There is still data for reading.
+ buf.read += int(count)
+ v.size -= count
+ count = 0
+ break
+ }
+
+ // Read from this buffer.
+ buf.read += int(l)
+ count -= l
+ v.size -= l
+
+ // When all data has been read from a buffer, we push
+ // it into the empty buffer pool for reuse.
+ oldBuf := buf
+ buf = buf.Next() // Iterate.
+ v.data.Remove(oldBuf)
+ oldBuf.Reset()
+ bufferPool.Put(oldBuf)
+ }
+ if count > 0 {
+ panic(fmt.Sprintf("advanceRead still has %d bytes remaining", count))
+ }
+}
+
+// Truncate truncates the view to the given bytes.
+func (v *View) Truncate(length int64) {
+ if length < 0 || length >= v.size {
+ return // Nothing to do.
+ }
+ for buf := v.data.Back(); buf != nil && v.size > length; buf = v.data.Back() {
+ l := int64(buf.write - buf.read) // Local bytes.
+ switch {
+ case v.size-l >= length:
+ // Drop the buffer completely; see above.
+ v.data.Remove(buf)
+ v.size -= l
+ buf.Reset()
+ bufferPool.Put(buf)
+
+ case v.size > length && v.size-l < length:
+ // Just truncate the buffer locally.
+ delta := (length - (v.size - l))
+ buf.write = buf.read + int(delta)
+ v.size = length
+
+ default:
+ // Should never happen.
+ panic("invalid buffer during truncation")
+ }
+ }
+ v.size = length // Save the new size.
+}
+
+// Grow grows the given view to the number of bytes. If zero
+// is true, all these bytes will be zero. If zero is false,
+// then this is the caller's responsibility.
+//
+// Precondition: length must be >= 0.
+func (v *View) Grow(length int64, zero bool) {
+ if length < 0 {
+ panic("negative length provided")
+ }
+ for v.size < length {
+ buf := v.data.Back()
+
+ // Is there at least one buffer?
+ if buf == nil || buf.Full() {
+ buf = bufferPool.Get().(*Buffer)
+ v.data.PushBack(buf)
+ }
+
+ // Write up to length bytes.
+ l := len(buf.data) - buf.write
+ if int64(l) > length-v.size {
+ l = int(length - v.size)
+ }
+
+ // Zero the written section; note that this pattern is
+ // specifically recognized and optimized by the compiler.
+ if zero {
+ for i := buf.write; i < buf.write+l; i++ {
+ buf.data[i] = 0
+ }
+ }
+
+ // Advance the index.
+ buf.write += l
+ v.size += int64(l)
+ }
+}
+
+// Prepend prepends the given data.
+func (v *View) Prepend(data []byte) {
+ // Is there any space in the first buffer?
+ if buf := v.data.Front(); buf != nil && buf.read > 0 {
+ // Fill up before the first write.
+ avail := buf.read
+ copy(buf.data[0:], data[len(data)-avail:])
+ data = data[:len(data)-avail]
+ v.size += int64(avail)
+ }
+
+ for len(data) > 0 {
+ // Do we need an empty buffer?
+ buf := bufferPool.Get().(*Buffer)
+ v.data.PushFront(buf)
+
+ // The buffer is empty; copy last chunk.
+ start := len(data) - len(buf.data)
+ if start < 0 {
+ start = 0 // Everything.
+ }
+
+ // We have to put the data at the end of the current
+ // buffer in order to ensure that the next prepend will
+ // correctly fill up the beginning of this buffer.
+ bStart := len(buf.data) - len(data[start:])
+ n := copy(buf.data[bStart:], data[start:])
+ buf.read = bStart
+ buf.write = len(buf.data)
+ data = data[:start]
+ v.size += int64(n)
+ }
+}
+
+// Append appends the given data.
+func (v *View) Append(data []byte) {
+ for done := 0; done < len(data); {
+ buf := v.data.Back()
+
+ // Find the first empty buffer.
+ if buf == nil || buf.Full() {
+ buf = bufferPool.Get().(*Buffer)
+ v.data.PushBack(buf)
+ }
+
+ // Copy in to the given buffer.
+ n := copy(buf.data[buf.write:], data[done:])
+ done += n
+ buf.write += n
+ v.size += int64(n)
+ }
+}
+
+// Flatten returns a flattened copy of this data.
+//
+// This method should not be used in any performance-sensitive paths. It may
+// allocate a fresh byte slice sufficiently large to contain all the data in
+// the buffer.
+//
+// N.B. Tee data still belongs to this view, as if there is a single buffer
+// present, then it will be returned directly. This should be used for
+// temporary use only, and a reference to the given slice should not be held.
+func (v *View) Flatten() []byte {
+ if buf := v.data.Front(); buf.Next() == nil {
+ return buf.data[buf.read:buf.write] // Only one buffer.
+ }
+ data := make([]byte, 0, v.size) // Need to flatten.
+ for buf := v.data.Front(); buf != nil; buf = buf.Next() {
+ // Copy to the allocated slice.
+ data = append(data, buf.data[buf.read:buf.write]...)
+ }
+ return data
+}
+
+// Size indicates the total amount of data available in this view.
+func (v *View) Size() (sz int64) {
+ sz = v.size // Pre-calculated.
+ return sz
+}
+
+// Copy makes a strict copy of this view.
+func (v *View) Copy() (other View) {
+ for buf := v.data.Front(); buf != nil; buf = buf.Next() {
+ other.Append(buf.data[buf.read:buf.write])
+ }
+ return other
+}
+
+// Apply applies the given function across all valid data.
+func (v *View) Apply(fn func([]byte)) {
+ for buf := v.data.Front(); buf != nil; buf = buf.Next() {
+ if l := int64(buf.write - buf.read); l > 0 {
+ fn(buf.data[buf.read:buf.write])
+ }
+ }
+}
+
+// Merge merges the provided View with this one.
+//
+// The other view will be empty after this operation.
+func (v *View) Merge(other *View) {
+ // Copy over all buffers.
+ for buf := other.data.Front(); buf != nil && !buf.Empty(); buf = other.data.Front() {
+ other.data.Remove(buf)
+ v.data.PushBack(buf)
+ }
+
+ // Adjust sizes.
+ v.size += other.size
+ other.size = 0
+}
+
+// WriteFromReader writes to the buffer from an io.Reader.
+func (v *View) WriteFromReader(r io.Reader, count int64) (int64, error) {
+ var (
+ done int64
+ n int
+ err error
+ )
+ for done < count {
+ buf := v.data.Back()
+
+ // Find the first empty buffer.
+ if buf == nil || buf.Full() {
+ buf = bufferPool.Get().(*Buffer)
+ v.data.PushBack(buf)
+ }
+
+ // Is this less than the minimum batch?
+ if len(buf.data[buf.write:]) < minBatch && (count-done) >= int64(minBatch) {
+ tmp := make([]byte, minBatch)
+ n, err = r.Read(tmp)
+ v.Write(tmp[:n])
+ done += int64(n)
+ if err != nil {
+ break
+ }
+ continue
+ }
+
+ // Limit the read, if necessary.
+ end := len(buf.data)
+ if int64(end-buf.write) > (count - done) {
+ end = buf.write + int(count-done)
+ }
+
+ // Pass the relevant portion of the buffer.
+ n, err = r.Read(buf.data[buf.write:end])
+ buf.write += n
+ done += int64(n)
+ v.size += int64(n)
+ if err == io.EOF {
+ err = nil // Short write allowed.
+ break
+ } else if err != nil {
+ break
+ }
+ }
+ return done, err
+}
+
+// ReadToWriter reads from the buffer into an io.Writer.
+//
+// N.B. This does not consume the bytes read. TrimFront should
+// be called appropriately after this call in order to do so.
+func (v *View) ReadToWriter(w io.Writer, count int64) (int64, error) {
+ var (
+ done int64
+ n int
+ err error
+ )
+ offset := 0 // Spill-over for batching.
+ for buf := v.data.Front(); buf != nil && done < count; buf = buf.Next() {
+ l := buf.write - buf.read - offset
+
+ // Is this less than the minimum batch?
+ if l < minBatch && (count-done) >= int64(minBatch) && (v.size-done) >= int64(minBatch) {
+ tmp := make([]byte, minBatch)
+ n, err = v.ReadAt(tmp, done)
+ w.Write(tmp[:n])
+ done += int64(n)
+ offset = n - l // Reset below.
+ if err != nil {
+ break
+ }
+ continue
+ }
+
+ // Limit the write if necessary.
+ if int64(l) >= (count - done) {
+ l = int(count - done)
+ }
+
+ // Perform the actual write.
+ n, err = w.Write(buf.data[buf.read+offset : buf.read+offset+l])
+ done += int64(n)
+ if err != nil {
+ break
+ }
+
+ // Reset spill-over.
+ offset = 0
+ }
+ return done, err
+}
diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go
new file mode 100644
index 000000000..37e652f16
--- /dev/null
+++ b/pkg/buffer/view_test.go
@@ -0,0 +1,233 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "bytes"
+ "strings"
+ "testing"
+)
+
+func TestView(t *testing.T) {
+ testCases := []struct {
+ name string
+ input string
+ output string
+ ops []func(*View)
+ }{
+ // Prepend.
+ {
+ name: "prepend",
+ input: "world",
+ ops: []func(*View){
+ func(v *View) {
+ v.Prepend([]byte("hello "))
+ },
+ },
+ output: "hello world",
+ },
+ {
+ name: "prepend fill",
+ input: strings.Repeat("1", bufferSize-1),
+ ops: []func(*View){
+ func(v *View) {
+ v.Prepend([]byte("0"))
+ },
+ },
+ output: "0" + strings.Repeat("1", bufferSize-1),
+ },
+ {
+ name: "prepend overflow",
+ input: strings.Repeat("1", bufferSize),
+ ops: []func(*View){
+ func(v *View) {
+ v.Prepend([]byte("0"))
+ },
+ },
+ output: "0" + strings.Repeat("1", bufferSize),
+ },
+ {
+ name: "prepend multiple buffers",
+ input: strings.Repeat("1", bufferSize-1),
+ ops: []func(*View){
+ func(v *View) {
+ v.Prepend([]byte(strings.Repeat("0", bufferSize*3)))
+ },
+ },
+ output: strings.Repeat("0", bufferSize*3) + strings.Repeat("1", bufferSize-1),
+ },
+
+ // Append.
+ {
+ name: "append",
+ input: "hello",
+ ops: []func(*View){
+ func(v *View) {
+ v.Append([]byte(" world"))
+ },
+ },
+ output: "hello world",
+ },
+ {
+ name: "append fill",
+ input: strings.Repeat("1", bufferSize-1),
+ ops: []func(*View){
+ func(v *View) {
+ v.Append([]byte("0"))
+ },
+ },
+ output: strings.Repeat("1", bufferSize-1) + "0",
+ },
+ {
+ name: "append overflow",
+ input: strings.Repeat("1", bufferSize),
+ ops: []func(*View){
+ func(v *View) {
+ v.Append([]byte("0"))
+ },
+ },
+ output: strings.Repeat("1", bufferSize) + "0",
+ },
+ {
+ name: "append multiple buffers",
+ input: strings.Repeat("1", bufferSize-1),
+ ops: []func(*View){
+ func(v *View) {
+ v.Append([]byte(strings.Repeat("0", bufferSize*3)))
+ },
+ },
+ output: strings.Repeat("1", bufferSize-1) + strings.Repeat("0", bufferSize*3),
+ },
+
+ // Truncate.
+ {
+ name: "truncate",
+ input: "hello world",
+ ops: []func(*View){
+ func(v *View) {
+ v.Truncate(5)
+ },
+ },
+ output: "hello",
+ },
+ {
+ name: "truncate multiple buffers",
+ input: strings.Repeat("1", bufferSize*2),
+ ops: []func(*View){
+ func(v *View) {
+ v.Truncate(bufferSize*2 - 1)
+ },
+ },
+ output: strings.Repeat("1", bufferSize*2-1),
+ },
+ {
+ name: "truncate multiple buffers to one buffer",
+ input: strings.Repeat("1", bufferSize*2),
+ ops: []func(*View){
+ func(v *View) {
+ v.Truncate(5)
+ },
+ },
+ output: "11111",
+ },
+
+ // TrimFront.
+ {
+ name: "trim",
+ input: "hello world",
+ ops: []func(*View){
+ func(v *View) {
+ v.TrimFront(6)
+ },
+ },
+ output: "world",
+ },
+ {
+ name: "trim multiple buffers",
+ input: strings.Repeat("1", bufferSize*2),
+ ops: []func(*View){
+ func(v *View) {
+ v.TrimFront(1)
+ },
+ },
+ output: strings.Repeat("1", bufferSize*2-1),
+ },
+ {
+ name: "trim multiple buffers to one buffer",
+ input: strings.Repeat("1", bufferSize*2),
+ ops: []func(*View){
+ func(v *View) {
+ v.TrimFront(bufferSize*2 - 1)
+ },
+ },
+ output: "1",
+ },
+
+ // Grow.
+ {
+ name: "grow",
+ input: "hello world",
+ ops: []func(*View){
+ func(v *View) {
+ v.Grow(1, true)
+ },
+ },
+ output: "hello world",
+ },
+ {
+ name: "grow from zero",
+ ops: []func(*View){
+ func(v *View) {
+ v.Grow(1024, true)
+ },
+ },
+ output: strings.Repeat("\x00", 1024),
+ },
+ {
+ name: "grow from non-zero",
+ input: strings.Repeat("1", bufferSize),
+ ops: []func(*View){
+ func(v *View) {
+ v.Grow(bufferSize*2, true)
+ },
+ },
+ output: strings.Repeat("1", bufferSize) + strings.Repeat("\x00", bufferSize),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ // Construct the new view.
+ var view View
+ view.Append([]byte(tc.input))
+
+ // Run all operations.
+ for _, op := range tc.ops {
+ op(&view)
+ }
+
+ // Flatten and validate.
+ out := view.Flatten()
+ if !bytes.Equal([]byte(tc.output), out) {
+ t.Errorf("expected %q, got %q", tc.output, string(out))
+ }
+
+ // Ensure the size is correct.
+ if len(out) != int(view.Size()) {
+ t.Errorf("size is wrong: expected %d, got %d", len(out), view.Size())
+ }
+ })
+ }
+}
diff --git a/pkg/sentry/kernel/pipe/buffer_test.go b/pkg/buffer/view_unsafe.go
index 4d54b8b8f..d1ef39b26 100644
--- a/pkg/sentry/kernel/pipe/buffer_test.go
+++ b/pkg/buffer/view_unsafe.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,21 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package pipe
+package buffer
import (
- "testing"
"unsafe"
-
- "gvisor.dev/gvisor/pkg/usermem"
)
-func TestBufferSize(t *testing.T) {
- bufferSize := unsafe.Sizeof(buffer{})
- if bufferSize < usermem.PageSize {
- t.Errorf("buffer is less than a page")
- }
- if bufferSize > (2 * usermem.PageSize) {
- t.Errorf("buffer is greater than two pages")
- }
-}
+// minBatch is the smallest Read or Write operation that the
+// WriteFromReader and ReadToWriter functions will use.
+//
+// This is defined as the size of a native pointer.
+const minBatch = int(unsafe.Sizeof(uintptr(0)))
diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go
index 372b650b9..885115ae2 100644
--- a/pkg/sentry/arch/arch_arm64.go
+++ b/pkg/sentry/arch/arch_arm64.go
@@ -53,6 +53,11 @@ const (
preferredPIELoadAddr usermem.Addr = maxAddr64 / 6 * 5
)
+var (
+ // CPUIDInstruction doesn't exist on ARM64.
+ CPUIDInstruction = []byte{}
+)
+
// These constants are selected as heuristics to help make the Platform's
// potentially limited address space conform as closely to Linux as possible.
const (
diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD
index 9b6bb26d0..9379a4d7b 100644
--- a/pkg/sentry/fs/dev/BUILD
+++ b/pkg/sentry/fs/dev/BUILD
@@ -26,6 +26,7 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/ramfs",
"//pkg/sentry/fs/tmpfs",
+ "//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/memmap",
"//pkg/sentry/mm",
diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go
index 7e66c29b0..acbd401a0 100644
--- a/pkg/sentry/fs/dev/dev.go
+++ b/pkg/sentry/fs/dev/dev.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -124,10 +125,12 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
"ptmx": newSymlink(ctx, "pts/ptmx", msrc),
"tty": newCharacterDevice(ctx, newTTYDevice(ctx, fs.RootOwner, 0666), msrc, ttyDevMajor, ttyDevMinor),
+ }
- "net": newDirectory(ctx, map[string]*fs.Inode{
+ if isNetTunSupported(inet.StackFromContext(ctx)) {
+ contents["net"] = newDirectory(ctx, map[string]*fs.Inode{
"tun": newCharacterDevice(ctx, newNetTunDevice(ctx, fs.RootOwner, 0666), msrc, netTunDevMajor, netTunDevMinor),
- }, msrc),
+ }, msrc)
}
iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go
index 755644488..dc7ad075a 100644
--- a/pkg/sentry/fs/dev/net_tun.go
+++ b/pkg/sentry/fs/dev/net_tun.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
"gvisor.dev/gvisor/pkg/syserror"
@@ -168,3 +169,9 @@ func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.Eve
func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) {
fops.device.EventUnregister(e)
}
+
+// isNetTunSupported returns whether /dev/net/tun device is supported for s.
+func isNetTunSupported(s inet.Stack) bool {
+ _, ok := s.(*netstack.Stack)
+ return ok
+}
diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go
index daecc4ffe..1922ff08c 100644
--- a/pkg/sentry/fs/fsutil/inode.go
+++ b/pkg/sentry/fs/fsutil/inode.go
@@ -259,8 +259,8 @@ func (i *InodeSimpleExtendedAttributes) ListXattr(context.Context, *fs.Inode, ui
// RemoveXattr implements fs.InodeOperations.RemoveXattr.
func (i *InodeSimpleExtendedAttributes) RemoveXattr(_ context.Context, _ *fs.Inode, name string) error {
- i.mu.RLock()
- defer i.mu.RUnlock()
+ i.mu.Lock()
+ defer i.mu.Unlock()
if _, ok := i.xattrs[name]; ok {
delete(i.xattrs, name)
return nil
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index d00850e25..c4a8f0b38 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -1045,13 +1045,13 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
// using the old file descriptor, preventing us from safely
// closing it. We could handle this by invalidating existing
// memmap.Translations, but this is expensive. Instead, use
- // dup2() to make the old file descriptor refer to the new file
+ // dup3 to make the old file descriptor refer to the new file
// description, then close the new file descriptor (which is no
// longer needed). Racing callers may use the old or new file
// description, but this doesn't matter since they refer to the
// same file (unless d.fs.opts.overlayfsStaleRead is true,
// which we handle separately).
- if err := syscall.Dup2(int(h.fd), int(d.handle.fd)); err != nil {
+ if err := syscall.Dup3(int(h.fd), int(d.handle.fd), 0); err != nil {
d.handleMu.Unlock()
ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, d.handle.fd, err)
h.close(ctx)
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 8b76750e9..1d627564f 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -755,6 +755,8 @@ func (ctx *createProcessContext) Value(key interface{}) interface{} {
return ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
case fs.CtxDirentCacheLimiter:
return ctx.k.DirentCacheLimiter
+ case inet.CtxStack:
+ return ctx.k.RootNetworkNamespace().Stack()
case ktime.CtxRealtimeClock:
return ctx.k.RealtimeClock()
case limits.CtxLimits:
@@ -1481,6 +1483,8 @@ func (ctx supervisorContext) Value(key interface{}) interface{} {
return ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
case fs.CtxDirentCacheLimiter:
return ctx.k.DirentCacheLimiter
+ case inet.CtxStack:
+ return ctx.k.RootNetworkNamespace().Stack()
case ktime.CtxRealtimeClock:
return ctx.k.RealtimeClock()
case limits.CtxLimits:
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index 4c049d5b4..f29dc0472 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -1,25 +1,10 @@
load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
-go_template_instance(
- name = "buffer_list",
- out = "buffer_list.go",
- package = "pipe",
- prefix = "buffer",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*buffer",
- "Linker": "*buffer",
- },
-)
-
go_library(
name = "pipe",
srcs = [
- "buffer.go",
- "buffer_list.go",
"device.go",
"node.go",
"pipe.go",
@@ -33,8 +18,8 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/amutex",
+ "//pkg/buffer",
"//pkg/context",
- "//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
@@ -51,7 +36,6 @@ go_test(
name = "pipe_test",
size = "small",
srcs = [
- "buffer_test.go",
"node_test.go",
"pipe_test.go",
],
diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go
deleted file mode 100644
index fe3be5dbd..000000000
--- a/pkg/sentry/kernel/pipe/buffer.go
+++ /dev/null
@@ -1,115 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package pipe
-
-import (
- "io"
-
- "gvisor.dev/gvisor/pkg/safemem"
- "gvisor.dev/gvisor/pkg/sync"
-)
-
-// buffer encapsulates a queueable byte buffer.
-//
-// Note that the total size is slightly less than two pages. This
-// is done intentionally to ensure that the buffer object aligns
-// with runtime internals. We have no hard size or alignment
-// requirements. This two page size will effectively minimize
-// internal fragmentation, but still have a large enough chunk
-// to limit excessive segmentation.
-//
-// +stateify savable
-type buffer struct {
- data [8144]byte
- read int
- write int
- bufferEntry
-}
-
-// Reset resets internal data.
-//
-// This must be called before use.
-func (b *buffer) Reset() {
- b.read = 0
- b.write = 0
-}
-
-// Empty indicates the buffer is empty.
-//
-// This indicates there is no data left to read.
-func (b *buffer) Empty() bool {
- return b.read == b.write
-}
-
-// Full indicates the buffer is full.
-//
-// This indicates there is no capacity left to write.
-func (b *buffer) Full() bool {
- return b.write == len(b.data)
-}
-
-// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
-func (b *buffer) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
- dst := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.write:]))
- n, err := safemem.CopySeq(dst, srcs)
- b.write += int(n)
- return n, err
-}
-
-// WriteFromReader writes to the buffer from an io.Reader.
-func (b *buffer) WriteFromReader(r io.Reader, count int64) (int64, error) {
- dst := b.data[b.write:]
- if count < int64(len(dst)) {
- dst = b.data[b.write:][:count]
- }
- n, err := r.Read(dst)
- b.write += n
- return int64(n), err
-}
-
-// ReadToBlocks implements safemem.Reader.ReadToBlocks.
-func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
- src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.read:b.write]))
- n, err := safemem.CopySeq(dsts, src)
- b.read += int(n)
- return n, err
-}
-
-// ReadToWriter reads from the buffer into an io.Writer.
-func (b *buffer) ReadToWriter(w io.Writer, count int64, dup bool) (int64, error) {
- src := b.data[b.read:b.write]
- if count < int64(len(src)) {
- src = b.data[b.read:][:count]
- }
- n, err := w.Write(src)
- if !dup {
- b.read += n
- }
- return int64(n), err
-}
-
-// bufferPool is a pool for buffers.
-var bufferPool = sync.Pool{
- New: func() interface{} {
- return new(buffer)
- },
-}
-
-// newBuffer grabs a new buffer from the pool.
-func newBuffer() *buffer {
- b := bufferPool.Get().(*buffer)
- b.Reset()
- return b
-}
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 08410283f..725e9db7d 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -20,6 +20,7 @@ import (
"sync/atomic"
"syscall"
+ "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sync"
@@ -70,10 +71,10 @@ type Pipe struct {
// mu protects all pipe internal state below.
mu sync.Mutex `state:"nosave"`
- // data is the buffer queue of pipe contents.
+ // view is the underlying set of buffers.
//
// This is protected by mu.
- data bufferList
+ view buffer.View
// max is the maximum size of the pipe in bytes. When this max has been
// reached, writers will get EWOULDBLOCK.
@@ -81,11 +82,6 @@ type Pipe struct {
// This is protected by mu.
max int64
- // size is the current size of the pipe in bytes.
- //
- // This is protected by mu.
- size int64
-
// hadWriter indicates if this pipe ever had a writer. Note that this
// does not necessarily indicate there is *currently* a writer, just
// that there has been a writer at some point since the pipe was
@@ -196,7 +192,7 @@ type readOps struct {
limit func(int64)
// read performs the actual read operation.
- read func(*buffer) (int64, error)
+ read func(*buffer.View) (int64, error)
}
// read reads data from the pipe into dst and returns the number of bytes
@@ -213,7 +209,7 @@ func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) {
defer p.mu.Unlock()
// Is the pipe empty?
- if p.size == 0 {
+ if p.view.Size() == 0 {
if !p.HasWriters() {
// There are no writers, return EOF.
return 0, nil
@@ -222,71 +218,13 @@ func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) {
}
// Limit how much we consume.
- if ops.left() > p.size {
- ops.limit(p.size)
+ if ops.left() > p.view.Size() {
+ ops.limit(p.view.Size())
}
- done := int64(0)
- for ops.left() > 0 {
- // Pop the first buffer.
- first := p.data.Front()
- if first == nil {
- break
- }
-
- // Copy user data.
- n, err := ops.read(first)
- done += int64(n)
- p.size -= n
-
- // Empty buffer?
- if first.Empty() {
- // Push to the free list.
- p.data.Remove(first)
- bufferPool.Put(first)
- }
-
- // Handle errors.
- if err != nil {
- return done, err
- }
- }
-
- return done, nil
-}
-
-// dup duplicates all data from this pipe into the given writer.
-//
-// There is no blocking behavior implemented here. The writer may propagate
-// some blocking error. All the writes must be complete writes.
-func (p *Pipe) dup(ctx context.Context, ops readOps) (int64, error) {
- p.mu.Lock()
- defer p.mu.Unlock()
-
- // Is the pipe empty?
- if p.size == 0 {
- if !p.HasWriters() {
- // See above.
- return 0, nil
- }
- return 0, syserror.ErrWouldBlock
- }
-
- // Limit how much we consume.
- if ops.left() > p.size {
- ops.limit(p.size)
- }
-
- done := int64(0)
- for buf := p.data.Front(); buf != nil; buf = buf.Next() {
- n, err := ops.read(buf)
- done += n
- if err != nil {
- return done, err
- }
- }
-
- return done, nil
+ // Copy user data; the read op is responsible for trimming.
+ done, err := ops.read(&p.view)
+ return done, err
}
type writeOps struct {
@@ -297,7 +235,7 @@ type writeOps struct {
limit func(int64)
// write should write to the provided buffer.
- write func(*buffer) (int64, error)
+ write func(*buffer.View) (int64, error)
}
// write writes data from sv into the pipe and returns the number of bytes
@@ -317,33 +255,19 @@ func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) {
// POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be
// atomic, but requires no atomicity for writes larger than this.
wanted := ops.left()
- if avail := p.max - p.size; wanted > avail {
+ if avail := p.max - p.view.Size(); wanted > avail {
if wanted <= p.atomicIOBytes {
return 0, syserror.ErrWouldBlock
}
ops.limit(avail)
}
- done := int64(0)
- for ops.left() > 0 {
- // Need a new buffer?
- last := p.data.Back()
- if last == nil || last.Full() {
- // Add a new buffer to the data list.
- last = newBuffer()
- p.data.PushBack(last)
- }
-
- // Copy user data.
- n, err := ops.write(last)
- done += int64(n)
- p.size += n
-
- // Handle errors.
- if err != nil {
- return done, err
- }
+ // Copy user data.
+ done, err := ops.write(&p.view)
+ if err != nil {
+ return done, err
}
+
if wanted > done {
// Partial write due to full pipe.
return done, syserror.ErrWouldBlock
@@ -396,7 +320,7 @@ func (p *Pipe) HasWriters() bool {
// Precondition: mu must be held.
func (p *Pipe) rReadinessLocked() waiter.EventMask {
ready := waiter.EventMask(0)
- if p.HasReaders() && p.data.Front() != nil {
+ if p.HasReaders() && p.view.Size() != 0 {
ready |= waiter.EventIn
}
if !p.HasWriters() && p.hadWriter {
@@ -422,7 +346,7 @@ func (p *Pipe) rReadiness() waiter.EventMask {
// Precondition: mu must be held.
func (p *Pipe) wReadinessLocked() waiter.EventMask {
ready := waiter.EventMask(0)
- if p.HasWriters() && p.size < p.max {
+ if p.HasWriters() && p.view.Size() < p.max {
ready |= waiter.EventOut
}
if !p.HasReaders() {
@@ -451,7 +375,7 @@ func (p *Pipe) rwReadiness() waiter.EventMask {
func (p *Pipe) queued() int64 {
p.mu.Lock()
defer p.mu.Unlock()
- return p.size
+ return p.view.Size()
}
// FifoSize implements fs.FifoSizer.FifoSize.
@@ -474,7 +398,7 @@ func (p *Pipe) SetFifoSize(size int64) (int64, error) {
}
p.mu.Lock()
defer p.mu.Unlock()
- if size < p.size {
+ if size < p.view.Size() {
return 0, syserror.EBUSY
}
p.max = size
diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go
index 80158239e..5a1d4fd57 100644
--- a/pkg/sentry/kernel/pipe/pipe_util.go
+++ b/pkg/sentry/kernel/pipe/pipe_util.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/amutex"
+ "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sync"
@@ -49,9 +50,10 @@ func (p *Pipe) Read(ctx context.Context, dst usermem.IOSequence) (int64, error)
limit: func(l int64) {
dst = dst.TakeFirst64(l)
},
- read: func(buf *buffer) (int64, error) {
- n, err := dst.CopyOutFrom(ctx, buf)
+ read: func(view *buffer.View) (int64, error) {
+ n, err := dst.CopyOutFrom(ctx, view)
dst = dst.DropFirst64(n)
+ view.TrimFront(n)
return n, err
},
})
@@ -70,16 +72,15 @@ func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool)
limit: func(l int64) {
count = l
},
- read: func(buf *buffer) (int64, error) {
- n, err := buf.ReadToWriter(w, count, dup)
+ read: func(view *buffer.View) (int64, error) {
+ n, err := view.ReadToWriter(w, count)
+ if !dup {
+ view.TrimFront(n)
+ }
count -= n
return n, err
},
}
- if dup {
- // There is no notification for dup operations.
- return p.dup(ctx, ops)
- }
n, err := p.read(ctx, ops)
if n > 0 {
p.Notify(waiter.EventOut)
@@ -96,8 +97,8 @@ func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error)
limit: func(l int64) {
src = src.TakeFirst64(l)
},
- write: func(buf *buffer) (int64, error) {
- n, err := src.CopyInTo(ctx, buf)
+ write: func(view *buffer.View) (int64, error) {
+ n, err := src.CopyInTo(ctx, view)
src = src.DropFirst64(n)
return n, err
},
@@ -117,8 +118,8 @@ func (p *Pipe) ReadFrom(ctx context.Context, r io.Reader, count int64) (int64, e
limit: func(l int64) {
count = l
},
- write: func(buf *buffer) (int64, error) {
- n, err := buf.WriteFromReader(r, count)
+ write: func(view *buffer.View) (int64, error) {
+ n, err := view.WriteFromReader(r, count)
count -= n
return n, err
},
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
index 5568c91bc..799cbcd93 100644
--- a/pkg/sentry/kernel/task_run.go
+++ b/pkg/sentry/kernel/task_run.go
@@ -126,13 +126,39 @@ func (t *Task) doStop() {
}
}
+func (*runApp) handleCPUIDInstruction(t *Task) error {
+ if len(arch.CPUIDInstruction) == 0 {
+ // CPUID emulation isn't supported, but this code can be
+ // executed, because the ptrace platform returns
+ // ErrContextSignalCPUID on page faults too. Look at
+ // pkg/sentry/platform/ptrace/ptrace.go:context.Switch for more
+ // details.
+ return platform.ErrContextSignal
+ }
+ // Is this a CPUID instruction?
+ region := trace.StartRegion(t.traceContext, cpuidRegion)
+ expected := arch.CPUIDInstruction[:]
+ found := make([]byte, len(expected))
+ _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found)
+ if err == nil && bytes.Equal(expected, found) {
+ // Skip the cpuid instruction.
+ t.Arch().CPUIDEmulate(t)
+ t.Arch().SetIP(t.Arch().IP() + uintptr(len(expected)))
+ region.End()
+
+ return nil
+ }
+ region.End() // Not an actual CPUID, but required copy-in.
+ return platform.ErrContextSignal
+}
+
// The runApp state checks for interrupts before executing untrusted
// application code.
//
// +stateify savable
type runApp struct{}
-func (*runApp) execute(t *Task) taskRunState {
+func (app *runApp) execute(t *Task) taskRunState {
if t.interrupted() {
// Checkpointing instructs tasks to stop by sending an interrupt, so we
// must check for stops before entering runInterrupt (instead of
@@ -237,21 +263,10 @@ func (*runApp) execute(t *Task) taskRunState {
return (*runApp)(nil)
case platform.ErrContextSignalCPUID:
- // Is this a CPUID instruction?
- region := trace.StartRegion(t.traceContext, cpuidRegion)
- expected := arch.CPUIDInstruction[:]
- found := make([]byte, len(expected))
- _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found)
- if err == nil && bytes.Equal(expected, found) {
- // Skip the cpuid instruction.
- t.Arch().CPUIDEmulate(t)
- t.Arch().SetIP(t.Arch().IP() + uintptr(len(expected)))
- region.End()
-
+ if err := app.handleCPUIDInstruction(t); err == nil {
// Resume execution.
return (*runApp)(nil)
}
- region.End() // Not an actual CPUID, but required copy-in.
// The instruction at the given RIP was not a CPUID, and we
// fallthrough to the default signal deliver behavior below.
diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD
index 4f2406ce3..581841555 100644
--- a/pkg/sentry/platform/ring0/pagetables/BUILD
+++ b/pkg/sentry/platform/ring0/pagetables/BUILD
@@ -80,7 +80,7 @@ go_library(
"pagetables_amd64.go",
"pagetables_arm64.go",
"pagetables_x86.go",
- "pcids_x86.go",
+ "pcids.go",
"walker_amd64.go",
"walker_arm64.go",
"walker_empty.go",
diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go b/pkg/sentry/platform/ring0/pagetables/pcids.go
index e199bae18..9206030bf 100644
--- a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go
+++ b/pkg/sentry/platform/ring0/pagetables/pcids.go
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64
-
package pagetables
import (
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 48c268bfa..13a9a60b4 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -712,6 +712,10 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ if len(sockaddr) < 2 {
+ return syserr.ErrInvalidArgument
+ }
+
family := usermem.ByteOrder.Uint16(sockaddr)
var addr tcpip.FullAddress
@@ -2663,7 +2667,9 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
}
// Add bytes removed from the endpoint but not yet sent to the caller.
+ s.readMu.Lock()
v += len(s.readView)
+ s.readMu.Unlock()
if v > math.MaxInt32 {
v = math.MaxInt32
diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go
index bfb2fac26..f7d6009a0 100644
--- a/pkg/sentry/watchdog/watchdog.go
+++ b/pkg/sentry/watchdog/watchdog.go
@@ -221,7 +221,7 @@ func (w *Watchdog) waitForStart() {
return
}
var buf bytes.Buffer
- buf.WriteString("Watchdog.Start() not called within %s:\n")
+ buf.WriteString(fmt.Sprintf("Watchdog.Start() not called within %s", w.StartupTimeout))
w.doAction(w.StartupTimeoutAction, false, &buf)
}
@@ -325,7 +325,7 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo
func (w *Watchdog) reportStuckWatchdog() {
var buf bytes.Buffer
- buf.WriteString("Watchdog goroutine is stuck:\n")
+ buf.WriteString("Watchdog goroutine is stuck:")
w.doAction(w.TaskTimeoutAction, false, &buf)
}
@@ -359,7 +359,7 @@ func (w *Watchdog) doAction(action Action, skipStack bool, msg *bytes.Buffer) {
case <-metricsEmitted:
case <-time.After(1 * time.Second):
}
- panic(fmt.Sprintf("Stack for running G's are skipped while panicking.\n%s", msg.String()))
+ panic(fmt.Sprintf("%s\nStack for running G's are skipped while panicking.", msg.String()))
default:
panic(fmt.Sprintf("Unknown watchdog action %v", action))
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 46d3a6646..3e6196aee 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -451,7 +451,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn
cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs))
for _, r := range primaryAddrs {
// If r is not valid for outgoing connections, it is not a valid endpoint.
- if !r.isValidForOutgoing() {
+ if !r.isValidForOutgoingRLocked() {
continue
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index ebb6c5e3b..13354d884 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -551,11 +551,13 @@ type TransportEndpointInfo struct {
RegisterNICID tcpip.NICID
}
-// AddrNetProto unwraps the specified address if it is a V4-mapped V6 address
-// and returns the network protocol number to be used to communicate with the
-// specified address. It returns an error if the passed address is incompatible
-// with the receiver.
-func (e *TransportEndpointInfo) AddrNetProto(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+// AddrNetProtoLocked unwraps the specified address if it is a V4-mapped V6
+// address and returns the network protocol number to be used to communicate
+// with the specified address. It returns an error if the passed address is
+// incompatible with the receiver.
+//
+// Preconditon: the parent endpoint mu must be held while calling this method.
+func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
netProto := e.NetProto
switch len(addr.Addr) {
case header.IPv4AddressSize:
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 426da1ee6..2a396e9bc 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -291,15 +291,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
nicID = e.BindNICID
}
- toCopy := *to
- to = &toCopy
- netProto, err := e.checkV4Mapped(to)
+ dst, netProto, err := e.checkV4MappedLocked(*to)
if err != nil {
return 0, nil, err
}
- // Find the enpoint.
- r, err := e.stack.FindRoute(nicID, e.BindAddr, to.Addr, netProto, false /* multicastLoop */)
+ // Find the endpoint.
+ r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */)
if err != nil {
return 0, nil, err
}
@@ -480,13 +478,14 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
})
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, false /* v6only */)
+// checkV4MappedLocked determines the effective network protocol and converts
+// addr to its canonical form.
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */)
if err != nil {
- return 0, err
+ return tcpip.FullAddress{}, 0, err
}
- *addr = unwrapped
- return netProto, nil
+ return unwrapped, netProto, nil
}
// Disconnect implements tcpip.Endpoint.Disconnect.
@@ -517,7 +516,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
@@ -630,7 +629,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index cd247f3e1..ae4f3f3a9 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -295,6 +295,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
h.state = handshakeSynRcvd
h.ep.mu.Lock()
ttl := h.ep.ttl
+ amss := h.ep.amss
h.ep.setEndpointState(StateSynRecv)
h.ep.mu.Unlock()
synOpts := header.TCPSynOptions{
@@ -307,7 +308,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// permits SACK. This is not explicitly defined in the RFC but
// this is the behaviour implemented by Linux.
SACKPermitted: rcvSynOpts.SACKPermitted,
- MSS: h.ep.amss,
+ MSS: amss,
}
if ttl == 0 {
ttl = s.route.DefaultTTL()
@@ -356,6 +357,10 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+ h.ep.mu.RLock()
+ amss := h.ep.amss
+ h.ep.mu.RUnlock()
+
h.resetState()
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
@@ -363,7 +368,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
TSVal: h.ep.timestamp(),
TSEcr: h.ep.recentTimestamp(),
SACKPermitted: h.ep.sackPermitted,
- MSS: h.ep.amss,
+ MSS: amss,
}
h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
return nil
@@ -530,6 +535,7 @@ func (h *handshake) execute() *tcpip.Error {
// Send the initial SYN segment and loop until the handshake is
// completed.
+ h.ep.mu.Lock()
h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
synOpts := header.TCPSynOptions{
@@ -540,6 +546,7 @@ func (h *handshake) execute() *tcpip.Error {
SACKPermitted: bool(sackEnabled),
MSS: h.ep.amss,
}
+ h.ep.mu.Unlock()
// Execute is also called in a listen context so we want to make sure we
// only send the TS/SACK option when we received the TS/SACK in the
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 9e72730bd..40cc664c0 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -959,15 +959,18 @@ func (e *endpoint) initialReceiveWindow() int {
// ModerateRecvBuf adjusts the receive buffer and the advertised window
// based on the number of bytes copied to user space.
func (e *endpoint) ModerateRecvBuf(copied int) {
+ e.mu.RLock()
e.rcvListMu.Lock()
if e.rcvAutoParams.disabled {
e.rcvListMu.Unlock()
+ e.mu.RUnlock()
return
}
now := time.Now()
if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt {
e.rcvAutoParams.copied += copied
e.rcvListMu.Unlock()
+ e.mu.RUnlock()
return
}
prevRTTCopied := e.rcvAutoParams.copied + copied
@@ -1008,7 +1011,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvBufSize = rcvWnd
availAfter := e.receiveBufferAvailableLocked()
mask := uint32(notifyReceiveWindowChanged)
- if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
mask |= notifyNonZeroReceiveWindow
}
e.notifyProtocolGoroutine(mask)
@@ -1023,6 +1026,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvAutoParams.measureTime = now
e.rcvAutoParams.copied = 0
e.rcvListMu.Unlock()
+ e.mu.RUnlock()
}
// IPTables implements tcpip.Endpoint.IPTables.
@@ -1052,7 +1056,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
v, err := e.readLocked()
e.rcvListMu.Unlock()
-
e.mu.RUnlock()
if err == tcpip.ErrClosedForReceive {
@@ -1085,7 +1088,7 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
// enough buffer space, to either fit an aMSS or half a receive buffer
// (whichever smaller), then notify the protocol goroutine to send a
// window update.
- if crossed, above := e.windowCrossedACKThreshold(len(v)); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above {
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
@@ -1303,9 +1306,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
return num, tcpip.ControlMessages{}, nil
}
-// windowCrossedACKThreshold checks if the receive window to be announced now
-// would be under aMSS or under half receive buffer, whichever smaller. This is
-// useful as a receive side silly window syndrome prevention mechanism. If
+// windowCrossedACKThresholdLocked checks if the receive window to be announced
+// now would be under aMSS or under half receive buffer, whichever smaller. This
+// is useful as a receive side silly window syndrome prevention mechanism. If
// window grows to reasonable value, we should send ACK to the sender to inform
// the rx space is now large. We also want ensure a series of small read()'s
// won't trigger a flood of spurious tiny ACK's.
@@ -1316,7 +1319,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
// crossed will be true if the window size crossed the ACK threshold.
// above will be true if the new window is >= ACK threshold and false
// otherwise.
-func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) (crossed bool, above bool) {
+//
+// Precondition: e.mu and e.rcvListMu must be held.
+func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
newAvail := e.receiveBufferAvailableLocked()
oldAvail := newAvail - deltaBefore
if oldAvail < 0 {
@@ -1379,6 +1384,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
mask := uint32(notifyReceiveWindowChanged)
+ e.mu.RLock()
e.rcvListMu.Lock()
// Make sure the receive buffer size allows us to send a
@@ -1405,11 +1411,11 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// Immediately send an ACK to uncork the sender silly window
// syndrome prevetion, when our available space grows above aMSS
// or half receive buffer, whichever smaller.
- if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
mask |= notifyNonZeroReceiveWindow
}
e.rcvListMu.Unlock()
-
+ e.mu.RUnlock()
e.notifyProtocolGoroutine(mask)
return nil
@@ -1868,13 +1874,14 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only)
+// checkV4MappedLocked determines the effective network protocol and converts
+// addr to its canonical form.
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
if err != nil {
- return 0, err
+ return tcpip.FullAddress{}, 0, err
}
- *addr = unwrapped
- return netProto, nil
+ return unwrapped, netProto, nil
}
// Disconnect implements tcpip.Endpoint.Disconnect.
@@ -1904,7 +1911,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
connectingAddr := addr.Addr
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
@@ -2270,7 +2277,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
}
e.BindAddr = addr.Addr
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
@@ -2414,13 +2421,14 @@ func (e *endpoint) updateSndBufferUsage(v int) {
// to be read, or when the connection is closed for receiving (in which case
// s will be nil).
func (e *endpoint) readyToRead(s *segment) {
+ e.mu.RLock()
e.rcvListMu.Lock()
if s != nil {
s.incRef()
e.rcvBufUsed += s.data.Size()
// Increase counter if the receive window falls down below MSS
// or half receive buffer size, whichever smaller.
- if crossed, above := e.windowCrossedACKThreshold(-s.data.Size()); crossed && !above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above {
e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
}
e.rcvList.PushBack(s)
@@ -2428,7 +2436,7 @@ func (e *endpoint) readyToRead(s *segment) {
e.rcvClosed = true
}
e.rcvListMu.Unlock()
-
+ e.mu.RUnlock()
e.waiterQueue.Notify(waiter.EventIn)
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 1c6a600b8..0af4514e1 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -443,19 +443,19 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, tcpip.ErrBroadcastDisabled
}
- netProto, err := e.checkV4Mapped(to)
+ dst, netProto, err := e.checkV4MappedLocked(*to)
if err != nil {
return 0, nil, err
}
- r, _, err := e.connectRoute(nicID, *to, netProto)
+ r, _, err := e.connectRoute(nicID, dst, netProto)
if err != nil {
return 0, nil, err
}
defer r.Release()
route = &r
- dstPort = to.Port
+ dstPort = dst.Port
}
if route.IsResolutionRequired() {
@@ -566,7 +566,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
defer e.mu.Unlock()
fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
- netProto, err := e.checkV4Mapped(&fa)
+ fa, netProto, err := e.checkV4MappedLocked(fa)
if err != nil {
return err
}
@@ -927,13 +927,14 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
return nil
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only)
+// checkV4MappedLocked determines the effective network protocol and converts
+// addr to its canonical form.
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
if err != nil {
- return 0, err
+ return tcpip.FullAddress{}, 0, err
}
- *addr = unwrapped
- return netProto, nil
+ return unwrapped, netProto, nil
}
// Disconnect implements tcpip.Endpoint.Disconnect.
@@ -981,10 +982,6 @@ func (e *endpoint) Disconnect() *tcpip.Error {
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- netProto, err := e.checkV4Mapped(&addr)
- if err != nil {
- return err
- }
if addr.Port == 0 {
// We don't support connecting to port zero.
return tcpip.ErrInvalidEndpointState
@@ -1012,6 +1009,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+ addr, netProto, err := e.checkV4MappedLocked(addr)
+ if err != nil {
+ return err
+ }
+
r, nicID, err := e.connectRoute(nicID, addr, netProto)
if err != nil {
return err
@@ -1139,7 +1141,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 43fb047ed..466bd9381 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -69,6 +69,9 @@ func (e *endpoint) afterLoad() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
e.stack = s
for _, m := range e.multicastMemberships {