summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--Makefile9
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go82
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go7
-rw-r--r--pkg/sentry/fsimpl/overlay/filesystem.go17
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go9
-rw-r--r--pkg/sentry/kernel/pipe/BUILD2
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go198
-rw-r--r--pkg/sentry/kernel/pipe/pipe_util.go99
-rw-r--r--pkg/sentry/kernel/pipe/save_restore.go26
-rw-r--r--pkg/sentry/kernel/pipe/vfs.go202
-rw-r--r--pkg/sentry/socket/netstack/BUILD1
-rw-r--r--pkg/sentry/socket/netstack/netstack.go50
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go11
-rw-r--r--pkg/sentry/syscalls/linux/sys_splice.go15
-rw-r--r--pkg/syserr/netstack.go1
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go29
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go4
-rw-r--r--pkg/tcpip/network/arp/arp.go27
-rw-r--r--pkg/tcpip/network/arp/arp_test.go278
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go16
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go33
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go13
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go2
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go10
-rw-r--r--pkg/tcpip/stack/nic.go10
-rw-r--r--pkg/tcpip/stack/packet_buffer.go11
-rw-r--r--pkg/tcpip/stack/pending_packets.go8
-rw-r--r--pkg/tcpip/stack/registration.go21
-rw-r--r--pkg/tcpip/stack/stack_test.go52
-rw-r--r--pkg/tcpip/stack/transport_test.go14
-rw-r--r--pkg/tcpip/tcpip.go68
-rw-r--r--pkg/tcpip/tests/integration/BUILD1
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go471
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go279
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go4
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go4
-rw-r--r--pkg/tcpip/tests/integration/route_test.go14
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go37
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go4
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go52
-rw-r--r--pkg/tcpip/transport/tcp/accept.go7
-rw-r--r--pkg/tcpip/transport/tcp/connect.go65
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go70
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go1
-rw-r--r--pkg/tcpip/transport/tcp/segment.go23
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go12
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go6
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go2
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go74
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go4
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go45
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go1
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go20
-rw-r--r--pkg/test/testutil/testutil.go38
-rw-r--r--runsc/cmd/gofer_test.go7
-rw-r--r--runsc/config/flags.go2
-rw-r--r--test/root/crictl_test.go4
-rw-r--r--test/syscalls/linux/mkdir.cc33
-rw-r--r--test/syscalls/linux/splice.cc106
-rw-r--r--website/cmd/server/main.go2
60 files changed, 1617 insertions, 1096 deletions
diff --git a/Makefile b/Makefile
index 284491c4f..8a1c4321c 100644
--- a/Makefile
+++ b/Makefile
@@ -323,6 +323,7 @@ containerd-tests: containerd-test-1.4.3
## BENCHMARKS_PLATFORMS - platforms to run benchmarks (e.g. ptrace kvm).
## BENCHMARKS_FILTER - filter to be applied to the test suite.
## BENCHMARKS_OPTIONS - options to be passed to the test.
+## BENCHMARKS_PROFILE - profile options to be passed to the test.
##
BENCHMARKS_PROJECT ?= gvisor-benchmarks
BENCHMARKS_DATASET ?= kokoro
@@ -334,7 +335,8 @@ BENCHMARKS_PLATFORMS ?= ptrace
BENCHMARKS_TARGETS := //test/benchmarks/media:ffmpeg_test
BENCHMARKS_FILTER := .
BENCHMARKS_OPTIONS := -test.benchtime=30s
-BENCHMARKS_ARGS := -test.v -test.bench=$(BENCHMARKS_FILTER) -pprof-dir=/tmp/profile -pprof-cpu -pprof-heap -pprof-block -pprof-mutex $(BENCHMARKS_OPTIONS)
+BENCHMARKS_ARGS := -test.v -test.bench=$(BENCHMARKS_FILTER) $(BENCHMARKS_OPTIONS)
+BENCHMARKS_PROFILE := -pprof-dir=/tmp/profile -pprof-cpu -pprof-heap -pprof-block -pprof-mutex
init-benchmark-table: ## Initializes a BigQuery table with the benchmark schema.
@$(call run,//tools/parsers:parser,init --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE))
@@ -344,9 +346,10 @@ init-benchmark-table: ## Initializes a BigQuery table with the benchmark schema.
run_benchmark = \
($(call header,BENCHMARK $(1) $(2)); \
set -euo pipefail; \
- if test "$(1)" != "runc"; then $(call install_runtime,$(1),--profile $(2)); fi; \
export T=$$(mktemp --tmpdir logs.$(1).XXXXXX); \
- $(call sudo,$(BENCHMARKS_TARGETS),-runtime=$(1) $(BENCHMARKS_ARGS)) | tee $$T; \
+ if test "$(1)" = "runc"; then $(call sudo,$(BENCHMARKS_TARGETS),-runtime=$(1) $(BENCHMARKS_ARGS)) | tee $$T; fi; \
+ if test "$(1)" != "runc"; then $(call install_runtime,$(1),--profile $(2)); \
+ $(call sudo,$(BENCHMARKS_TARGETS),-runtime=$(1) $(BENCHMARKS_ARGS) $(BENCHMARKS_PROFILE)) | tee $$T; fi; \
if test "$(BENCHMARKS_UPLOAD)" = "true"; then \
$(call run,tools/parsers:parser,parse --debug --file=$$T --runtime=$(1) --suite_name=$(BENCHMARKS_SUITE) --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE) --official=$(BENCHMARKS_OFFICIAL)); \
fi; \
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index df27554d3..91d5dc174 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -407,33 +407,44 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if err != nil {
return err
}
- if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+
+ // Order of checks is important. First check if parent directory can be
+ // executed, then check for existence, and lastly check if mount is writable.
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return err
}
name := rp.Component()
if name == "." || name == ".." {
return syserror.EEXIST
}
- if len(name) > maxFilenameLen {
- return syserror.ENAMETOOLONG
- }
if parent.isDeleted() {
return syserror.ENOENT
}
+
+ parent.dirMu.Lock()
+ defer parent.dirMu.Unlock()
+
+ child, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), parent, name, &ds)
+ switch {
+ case err != nil && err != syserror.ENOENT:
+ return err
+ case child != nil:
+ return syserror.EEXIST
+ }
+
mnt := rp.Mount()
if err := mnt.CheckBeginWrite(); err != nil {
return err
}
defer mnt.EndWrite()
- parent.dirMu.Lock()
- defer parent.dirMu.Unlock()
+
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ return err
+ }
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
if parent.isSynthetic() {
- if child := parent.children[name]; child != nil {
- return syserror.EEXIST
- }
- if !dir && rp.MustBeDir() {
- return syserror.ENOENT
- }
if createInSyntheticDir == nil {
return syserror.EPERM
}
@@ -449,47 +460,20 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
parent.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
return nil
}
- if fs.opts.interop == InteropModeShared {
- if child := parent.children[name]; child != nil && child.isSynthetic() {
- return syserror.EEXIST
- }
- if !dir && rp.MustBeDir() {
- return syserror.ENOENT
- }
- // The existence of a non-synthetic dentry at name would be inconclusive
- // because the file it represents may have been deleted from the remote
- // filesystem, so we would need to make an RPC to revalidate the dentry.
- // Just attempt the file creation RPC instead. If a file does exist, the
- // RPC will fail with EEXIST like we would have. If the RPC succeeds, and a
- // stale dentry exists, the dentry will fail revalidation next time it's
- // used.
- if err := createInRemoteDir(parent, name, &ds); err != nil {
- return err
- }
- ev := linux.IN_CREATE
- if dir {
- ev |= linux.IN_ISDIR
- }
- parent.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
- return nil
- }
- if child := parent.children[name]; child != nil {
- return syserror.EEXIST
- }
- if !dir && rp.MustBeDir() {
- return syserror.ENOENT
- }
- // No cached dentry exists; however, there might still be an existing file
- // at name. As above, we attempt the file creation RPC anyway.
+ // No cached dentry exists; however, in InteropModeShared there might still be
+ // an existing file at name. Just attempt the file creation RPC anyways. If a
+ // file does exist, the RPC will fail with EEXIST like we would have.
if err := createInRemoteDir(parent, name, &ds); err != nil {
return err
}
- if child, ok := parent.children[name]; ok && child == nil {
- // Delete the now-stale negative dentry.
- delete(parent.children, name)
+ if fs.opts.interop != InteropModeShared {
+ if child, ok := parent.children[name]; ok && child == nil {
+ // Delete the now-stale negative dentry.
+ delete(parent.children, name)
+ }
+ parent.touchCMtime()
+ parent.dirents = nil
}
- parent.touchCMtime()
- parent.dirents = nil
ev := linux.IN_CREATE
if dir {
ev |= linux.IN_ISDIR
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index e77523f22..a7a553619 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -208,7 +208,9 @@ func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving
// * Filesystem.mu must be locked for at least reading.
// * isDir(parentInode) == true.
func checkCreateLocked(ctx context.Context, creds *auth.Credentials, name string, parent *Dentry) error {
- if err := parent.inode.CheckPermissions(ctx, creds, vfs.MayWrite|vfs.MayExec); err != nil {
+ // Order of checks is important. First check if parent directory can be
+ // executed, then check for existence, and lastly check if mount is writable.
+ if err := parent.inode.CheckPermissions(ctx, creds, vfs.MayExec); err != nil {
return err
}
if name == "." || name == ".." {
@@ -223,6 +225,9 @@ func checkCreateLocked(ctx context.Context, creds *auth.Credentials, name string
if parent.VFSDentry().IsDead() {
return syserror.ENOENT
}
+ if err := parent.inode.CheckPermissions(ctx, creds, vfs.MayWrite); err != nil {
+ return err
+ }
return nil
}
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index d55bdc97f..e46f593c7 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -480,9 +480,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if err != nil {
return err
}
- if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
- return err
- }
name := rp.Component()
if name == "." || name == ".." {
return syserror.EEXIST
@@ -490,11 +487,11 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if parent.vfsd.IsDead() {
return syserror.ENOENT
}
- mnt := rp.Mount()
- if err := mnt.CheckBeginWrite(); err != nil {
+
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return err
}
- defer mnt.EndWrite()
+
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
@@ -514,6 +511,14 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
return syserror.ENOENT
}
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
// Ensure that the parent directory is copied-up so that we can create the
// new file in the upper layer.
if err := parent.copyUpLocked(ctx); err != nil {
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index 9296db2fb..453e41d11 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -153,7 +153,10 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if err != nil {
return err
}
- if err := parentDir.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+
+ // Order of checks is important. First check if parent directory can be
+ // executed, then check for existence, and lastly check if mount is writable.
+ if err := parentDir.inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
return err
}
name := rp.Component()
@@ -179,6 +182,10 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
return err
}
defer mnt.EndWrite()
+
+ if err := parentDir.inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ return err
+ }
if err := create(parentDir, name); err != nil {
return err
}
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index 99134e634..2c32d017d 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -12,6 +12,7 @@ go_library(
"pipe_util.go",
"reader.go",
"reader_writer.go",
+ "save_restore.go",
"vfs.go",
"writer.go",
],
@@ -19,7 +20,6 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/amutex",
- "//pkg/buffer",
"//pkg/context",
"//pkg/marshal/primitive",
"//pkg/safemem",
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index b989e14c7..c551acd99 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -21,8 +21,8 @@ import (
"sync/atomic"
"syscall"
- "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -75,10 +75,18 @@ type Pipe struct {
// mu protects all pipe internal state below.
mu sync.Mutex `state:"nosave"`
- // view is the underlying set of buffers.
+ // buf holds the pipe's data. buf is a circular buffer; the first valid
+ // byte in buf is at offset off, and the pipe contains size valid bytes.
+ // bufBlocks contains two identical safemem.Blocks representing buf; this
+ // avoids needing to heap-allocate a new safemem.Block slice when buf is
+ // resized. bufBlockSeq is a safemem.BlockSeq representing bufBlocks.
//
- // This is protected by mu.
- view buffer.View
+ // These fields are protected by mu.
+ buf []byte
+ bufBlocks [2]safemem.Block `state:"nosave"`
+ bufBlockSeq safemem.BlockSeq `state:"nosave"`
+ off int64
+ size int64
// max is the maximum size of the pipe in bytes. When this max has been
// reached, writers will get EWOULDBLOCK.
@@ -99,12 +107,6 @@ type Pipe struct {
//
// N.B. The size will be bounded.
func NewPipe(isNamed bool, sizeBytes int64) *Pipe {
- if sizeBytes < MinimumPipeSize {
- sizeBytes = MinimumPipeSize
- }
- if sizeBytes > MaximumPipeSize {
- sizeBytes = MaximumPipeSize
- }
var p Pipe
initPipe(&p, isNamed, sizeBytes)
return &p
@@ -175,75 +177,71 @@ func (p *Pipe) Open(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) *fs.F
}
}
-type readOps struct {
- // left returns the bytes remaining.
- left func() int64
-
- // limit limits subsequence reads.
- limit func(int64)
-
- // read performs the actual read operation.
- read func(*buffer.View) (int64, error)
-}
-
-// read reads data from the pipe into dst and returns the number of bytes
-// read, or returns ErrWouldBlock if the pipe is empty.
+// peekLocked passes the first count bytes in the pipe to f and returns its
+// result. If fewer than count bytes are available, the safemem.BlockSeq passed
+// to f will be less than count bytes in length.
//
-// Precondition: this pipe must have readers.
-func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) {
- p.mu.Lock()
- defer p.mu.Unlock()
- return p.readLocked(ctx, ops)
-}
-
-func (p *Pipe) readLocked(ctx context.Context, ops readOps) (int64, error) {
+// peekLocked does not mutate the pipe; if the read consumes bytes from the
+// pipe, then the caller is responsible for calling p.consumeLocked() and
+// p.Notify(waiter.EventOut). (The latter must be called with p.mu unlocked.)
+//
+// Preconditions:
+// * p.mu must be locked.
+// * This pipe must have readers.
+func (p *Pipe) peekLocked(count int64, f func(safemem.BlockSeq) (uint64, error)) (int64, error) {
// Don't block for a zero-length read even if the pipe is empty.
- if ops.left() == 0 {
+ if count == 0 {
return 0, nil
}
- // Is the pipe empty?
- if p.view.Size() == 0 {
- if !p.HasWriters() {
- // There are no writers, return EOF.
- return 0, io.EOF
+ // Limit the amount of data read to the amount of data in the pipe.
+ if count > p.size {
+ if p.size == 0 {
+ if !p.HasWriters() {
+ return 0, io.EOF
+ }
+ return 0, syserror.ErrWouldBlock
}
- return 0, syserror.ErrWouldBlock
+ count = p.size
}
- // Limit how much we consume.
- if ops.left() > p.view.Size() {
- ops.limit(p.view.Size())
- }
+ // Prepare the view of the data to be read.
+ bs := p.bufBlockSeq.DropFirst64(uint64(p.off)).TakeFirst64(uint64(count))
- // Copy user data; the read op is responsible for trimming.
- done, err := ops.read(&p.view)
- return done, err
+ // Perform the read.
+ done, err := f(bs)
+ return int64(done), err
}
-type writeOps struct {
- // left returns the bytes remaining.
- left func() int64
-
- // limit should limit subsequent writes.
- limit func(int64)
-
- // write should write to the provided buffer.
- write func(*buffer.View) (int64, error)
-}
-
-// write writes data from sv into the pipe and returns the number of bytes
-// written. If no bytes are written because the pipe is full (or has less than
-// atomicIOBytes free capacity), write returns ErrWouldBlock.
+// consumeLocked consumes the first n bytes in the pipe, such that they will no
+// longer be visible to future reads.
//
-// Precondition: this pipe must have writers.
-func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) {
- p.mu.Lock()
- defer p.mu.Unlock()
- return p.writeLocked(ctx, ops)
+// Preconditions:
+// * p.mu must be locked.
+// * The pipe must contain at least n bytes.
+func (p *Pipe) consumeLocked(n int64) {
+ p.off += n
+ if max := int64(len(p.buf)); p.off >= max {
+ p.off -= max
+ }
+ p.size -= n
}
-func (p *Pipe) writeLocked(ctx context.Context, ops writeOps) (int64, error) {
+// writeLocked passes a safemem.BlockSeq representing the first count bytes of
+// unused space in the pipe to f and returns the result. If fewer than count
+// bytes are free, the safemem.BlockSeq passed to f will be less than count
+// bytes in length. If the pipe is full or otherwise cannot accommodate a write
+// of any number of bytes up to count, writeLocked returns ErrWouldBlock
+// without calling f.
+//
+// Unlike peekLocked, writeLocked assumes that f returns the number of bytes
+// written to the pipe, and increases the number of bytes stored in the pipe
+// accordingly. Callers are still responsible for calling
+// p.Notify(waiter.EventIn) with p.mu unlocked.
+//
+// Preconditions:
+// * p.mu must be locked.
+func (p *Pipe) writeLocked(count int64, f func(safemem.BlockSeq) (uint64, error)) (int64, error) {
// Can't write to a pipe with no readers.
if !p.HasReaders() {
return 0, syscall.EPIPE
@@ -251,29 +249,59 @@ func (p *Pipe) writeLocked(ctx context.Context, ops writeOps) (int64, error) {
// POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be
// atomic, but requires no atomicity for writes larger than this.
- wanted := ops.left()
- avail := p.max - p.view.Size()
- if wanted > avail {
- if wanted <= atomicIOBytes {
+ avail := p.max - p.size
+ short := false
+ if count > avail {
+ if count <= atomicIOBytes {
return 0, syserror.ErrWouldBlock
}
- ops.limit(avail)
+ count = avail
+ short = true
}
- // Copy user data.
- done, err := ops.write(&p.view)
- if err != nil {
- return done, err
+ // Ensure that the buffer is big enough.
+ if newLen, oldCap := p.size+count, int64(len(p.buf)); newLen > oldCap {
+ // Allocate a new buffer.
+ newCap := oldCap * 2
+ if oldCap == 0 {
+ newCap = 8 // arbitrary; sending individual integers across pipes is relatively common
+ }
+ for newLen > newCap {
+ newCap *= 2
+ }
+ if newCap > p.max {
+ newCap = p.max
+ }
+ newBuf := make([]byte, newCap)
+ // Copy the old buffer's contents to the beginning of the new one.
+ safemem.CopySeq(
+ safemem.BlockSeqOf(safemem.BlockFromSafeSlice(newBuf)),
+ p.bufBlockSeq.DropFirst64(uint64(p.off)).TakeFirst64(uint64(p.size)))
+ // Switch to the new buffer.
+ p.buf = newBuf
+ p.bufBlocks[0] = safemem.BlockFromSafeSlice(newBuf)
+ p.bufBlocks[1] = p.bufBlocks[0]
+ p.bufBlockSeq = safemem.BlockSeqFromSlice(p.bufBlocks[:])
+ p.off = 0
}
- if done < avail {
- // Non-failure, but short write.
- return done, nil
+ // Prepare the view of the space to be written.
+ woff := p.off + p.size
+ if woff >= int64(len(p.buf)) {
+ woff -= int64(len(p.buf))
}
- if done < wanted {
- // Partial write due to full pipe. Note that this could also be
- // the short write case above, we would expect a second call
- // and the write to return zero bytes in this case.
+ bs := p.bufBlockSeq.DropFirst64(uint64(woff)).TakeFirst64(uint64(count))
+
+ // Perform the write.
+ doneU64, err := f(bs)
+ done := int64(doneU64)
+ p.size += done
+ if done < count || err != nil {
+ return done, err
+ }
+
+ // If we shortened the write, adjust the returned error appropriately.
+ if short {
return done, syserror.ErrWouldBlock
}
@@ -324,7 +352,7 @@ func (p *Pipe) HasWriters() bool {
// Precondition: mu must be held.
func (p *Pipe) rReadinessLocked() waiter.EventMask {
ready := waiter.EventMask(0)
- if p.HasReaders() && p.view.Size() != 0 {
+ if p.HasReaders() && p.size != 0 {
ready |= waiter.EventIn
}
if !p.HasWriters() && p.hadWriter {
@@ -350,7 +378,7 @@ func (p *Pipe) rReadiness() waiter.EventMask {
// Precondition: mu must be held.
func (p *Pipe) wReadinessLocked() waiter.EventMask {
ready := waiter.EventMask(0)
- if p.HasWriters() && p.view.Size() < p.max {
+ if p.HasWriters() && p.size < p.max {
ready |= waiter.EventOut
}
if !p.HasReaders() {
@@ -383,7 +411,7 @@ func (p *Pipe) queued() int64 {
}
func (p *Pipe) queuedLocked() int64 {
- return p.view.Size()
+ return p.size
}
// FifoSize implements fs.FifoSizer.FifoSize.
@@ -406,7 +434,7 @@ func (p *Pipe) SetFifoSize(size int64) (int64, error) {
}
p.mu.Lock()
defer p.mu.Unlock()
- if size < p.view.Size() {
+ if size < p.size {
return 0, syserror.EBUSY
}
p.max = size
diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go
index f665920cb..77246edbe 100644
--- a/pkg/sentry/kernel/pipe/pipe_util.go
+++ b/pkg/sentry/kernel/pipe/pipe_util.go
@@ -21,9 +21,9 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/amutex"
- "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/marshal/primitive"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
@@ -44,46 +44,37 @@ func (p *Pipe) Release(context.Context) {
// Read reads from the Pipe into dst.
func (p *Pipe) Read(ctx context.Context, dst usermem.IOSequence) (int64, error) {
- n, err := p.read(ctx, readOps{
- left: func() int64 {
- return dst.NumBytes()
- },
- limit: func(l int64) {
- dst = dst.TakeFirst64(l)
- },
- read: func(view *buffer.View) (int64, error) {
- n, err := dst.CopyOutFrom(ctx, view)
- dst = dst.DropFirst64(n)
- view.TrimFront(n)
- return n, err
- },
- })
+ n, err := dst.CopyOutFrom(ctx, p)
if n > 0 {
p.Notify(waiter.EventOut)
}
return n, err
}
+// ReadToBlocks implements safemem.Reader.ReadToBlocks for Pipe.Read.
+func (p *Pipe) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ n, err := p.read(int64(dsts.NumBytes()), func(srcs safemem.BlockSeq) (uint64, error) {
+ return safemem.CopySeq(dsts, srcs)
+ }, true /* removeFromSrc */)
+ return uint64(n), err
+}
+
+func (p *Pipe) read(count int64, f func(srcs safemem.BlockSeq) (uint64, error), removeFromSrc bool) (int64, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ n, err := p.peekLocked(count, f)
+ if n > 0 && removeFromSrc {
+ p.consumeLocked(n)
+ }
+ return n, err
+}
+
// WriteTo writes to w from the Pipe.
func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool) (int64, error) {
- ops := readOps{
- left: func() int64 {
- return count
- },
- limit: func(l int64) {
- count = l
- },
- read: func(view *buffer.View) (int64, error) {
- n, err := view.ReadToWriter(w, count)
- if !dup {
- view.TrimFront(n)
- }
- count -= n
- return n, err
- },
- }
- n, err := p.read(ctx, ops)
- if n > 0 {
+ n, err := p.read(count, func(srcs safemem.BlockSeq) (uint64, error) {
+ return safemem.FromIOWriter{w}.WriteFromBlocks(srcs)
+ }, !dup /* removeFromSrc */)
+ if n > 0 && !dup {
p.Notify(waiter.EventOut)
}
return n, err
@@ -91,39 +82,31 @@ func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool)
// Write writes to the Pipe from src.
func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error) {
- n, err := p.write(ctx, writeOps{
- left: func() int64 {
- return src.NumBytes()
- },
- limit: func(l int64) {
- src = src.TakeFirst64(l)
- },
- write: func(view *buffer.View) (int64, error) {
- n, err := src.CopyInTo(ctx, view)
- src = src.DropFirst64(n)
- return n, err
- },
- })
+ n, err := src.CopyInTo(ctx, p)
if n > 0 {
p.Notify(waiter.EventIn)
}
return n, err
}
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks for Pipe.Write.
+func (p *Pipe) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ n, err := p.write(int64(srcs.NumBytes()), func(dsts safemem.BlockSeq) (uint64, error) {
+ return safemem.CopySeq(dsts, srcs)
+ })
+ return uint64(n), err
+}
+
+func (p *Pipe) write(count int64, f func(safemem.BlockSeq) (uint64, error)) (int64, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.writeLocked(count, f)
+}
+
// ReadFrom reads from r to the Pipe.
func (p *Pipe) ReadFrom(ctx context.Context, r io.Reader, count int64) (int64, error) {
- n, err := p.write(ctx, writeOps{
- left: func() int64 {
- return count
- },
- limit: func(l int64) {
- count = l
- },
- write: func(view *buffer.View) (int64, error) {
- n, err := view.WriteFromReader(r, count)
- count -= n
- return n, err
- },
+ n, err := p.write(count, func(dsts safemem.BlockSeq) (uint64, error) {
+ return safemem.FromIOReader{r}.ReadToBlocks(dsts)
})
if n > 0 {
p.Notify(waiter.EventIn)
diff --git a/pkg/sentry/kernel/pipe/save_restore.go b/pkg/sentry/kernel/pipe/save_restore.go
new file mode 100644
index 000000000..f135827de
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/save_restore.go
@@ -0,0 +1,26 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "gvisor.dev/gvisor/pkg/safemem"
+)
+
+// afterLoad is called by stateify.
+func (p *Pipe) afterLoad() {
+ p.bufBlocks[0] = safemem.BlockFromSafeSlice(p.buf)
+ p.bufBlocks[1] = p.bufBlocks[0]
+ p.bufBlockSeq = safemem.BlockSeqFromSlice(p.bufBlocks[:])
+}
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
index 2d47d2e82..d5a91730d 100644
--- a/pkg/sentry/kernel/pipe/vfs.go
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -16,7 +16,6 @@ package pipe
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -269,12 +268,10 @@ func (fd *VFSPipeFD) SetPipeSize(size int64) (int64, error) {
// SpliceToNonPipe performs a splice operation from fd to a non-pipe file.
func (fd *VFSPipeFD) SpliceToNonPipe(ctx context.Context, out *vfs.FileDescription, off, count int64) (int64, error) {
fd.pipe.mu.Lock()
- defer fd.pipe.mu.Unlock()
// Cap the sequence at number of bytes actually available.
- v := fd.pipe.queuedLocked()
- if v < count {
- count = v
+ if count > fd.pipe.size {
+ count = fd.pipe.size
}
src := usermem.IOSequence{
IO: fd,
@@ -291,154 +288,97 @@ func (fd *VFSPipeFD) SpliceToNonPipe(ctx context.Context, out *vfs.FileDescripti
n, err = out.PWrite(ctx, src, off, vfs.WriteOptions{})
}
if n > 0 {
- fd.pipe.view.TrimFront(n)
+ fd.pipe.consumeLocked(n)
+ }
+
+ fd.pipe.mu.Unlock()
+
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventOut)
}
return n, err
}
// SpliceFromNonPipe performs a splice operation from a non-pipe file to fd.
func (fd *VFSPipeFD) SpliceFromNonPipe(ctx context.Context, in *vfs.FileDescription, off, count int64) (int64, error) {
- fd.pipe.mu.Lock()
- defer fd.pipe.mu.Unlock()
-
dst := usermem.IOSequence{
IO: fd,
Addrs: usermem.AddrRangeSeqOf(usermem.AddrRange{0, usermem.Addr(count)}),
}
+ var (
+ n int64
+ err error
+ )
+ fd.pipe.mu.Lock()
if off == -1 {
- return in.Read(ctx, dst, vfs.ReadOptions{})
+ n, err = in.Read(ctx, dst, vfs.ReadOptions{})
+ } else {
+ n, err = in.PRead(ctx, dst, off, vfs.ReadOptions{})
+ }
+ fd.pipe.mu.Unlock()
+
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventIn)
}
- return in.PRead(ctx, dst, off, vfs.ReadOptions{})
+ return n, err
}
// CopyIn implements usermem.IO.CopyIn. Note that it is the caller's
-// responsibility to trim fd.pipe.view after the read is completed.
+// responsibility to call fd.pipe.consumeLocked() and
+// fd.pipe.Notify(waiter.EventOut) after the read is completed.
+//
+// Preconditions: fd.pipe.mu must be locked.
func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, opts usermem.IOOpts) (int, error) {
- origCount := int64(len(dst))
- n, err := fd.pipe.readLocked(ctx, readOps{
- left: func() int64 {
- return int64(len(dst))
- },
- limit: func(l int64) {
- dst = dst[:l]
- },
- read: func(view *buffer.View) (int64, error) {
- n, err := view.ReadAt(dst, 0)
- return int64(n), err
- },
+ n, err := fd.pipe.peekLocked(int64(len(dst)), func(srcs safemem.BlockSeq) (uint64, error) {
+ return safemem.CopySeq(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(dst)), srcs)
})
- if n > 0 {
- fd.pipe.Notify(waiter.EventOut)
- }
- if err == nil && n != origCount {
- return int(n), syserror.ErrWouldBlock
- }
return int(n), err
}
-// CopyOut implements usermem.IO.CopyOut.
+// CopyOut implements usermem.IO.CopyOut. Note that it is the caller's
+// responsibility to call fd.pipe.Notify(waiter.EventIn) after the
+// write is completed.
+//
+// Preconditions: fd.pipe.mu must be locked.
func (fd *VFSPipeFD) CopyOut(ctx context.Context, addr usermem.Addr, src []byte, opts usermem.IOOpts) (int, error) {
- origCount := int64(len(src))
- n, err := fd.pipe.writeLocked(ctx, writeOps{
- left: func() int64 {
- return int64(len(src))
- },
- limit: func(l int64) {
- src = src[:l]
- },
- write: func(view *buffer.View) (int64, error) {
- view.Append(src)
- return int64(len(src)), nil
- },
+ n, err := fd.pipe.writeLocked(int64(len(src)), func(dsts safemem.BlockSeq) (uint64, error) {
+ return safemem.CopySeq(dsts, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(src)))
})
- if n > 0 {
- fd.pipe.Notify(waiter.EventIn)
- }
- if err == nil && n != origCount {
- return int(n), syserror.ErrWouldBlock
- }
return int(n), err
}
// ZeroOut implements usermem.IO.ZeroOut.
+//
+// Preconditions: fd.pipe.mu must be locked.
func (fd *VFSPipeFD) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int64, opts usermem.IOOpts) (int64, error) {
- origCount := toZero
- n, err := fd.pipe.writeLocked(ctx, writeOps{
- left: func() int64 {
- return toZero
- },
- limit: func(l int64) {
- toZero = l
- },
- write: func(view *buffer.View) (int64, error) {
- view.Grow(view.Size()+toZero, true /* zero */)
- return toZero, nil
- },
+ n, err := fd.pipe.writeLocked(toZero, func(dsts safemem.BlockSeq) (uint64, error) {
+ return safemem.ZeroSeq(dsts)
})
- if n > 0 {
- fd.pipe.Notify(waiter.EventIn)
- }
- if err == nil && n != origCount {
- return n, syserror.ErrWouldBlock
- }
return n, err
}
// CopyInTo implements usermem.IO.CopyInTo. Note that it is the caller's
-// responsibility to trim fd.pipe.view after the read is completed.
+// responsibility to call fd.pipe.consumeLocked() and
+// fd.pipe.Notify(waiter.EventOut) after the read is completed.
+//
+// Preconditions: fd.pipe.mu must be locked.
func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) {
- count := ars.NumBytes()
- if count == 0 {
- return 0, nil
- }
- origCount := count
- n, err := fd.pipe.readLocked(ctx, readOps{
- left: func() int64 {
- return count
- },
- limit: func(l int64) {
- count = l
- },
- read: func(view *buffer.View) (int64, error) {
- n, err := view.ReadToSafememWriter(dst, uint64(count))
- return int64(n), err
- },
+ return fd.pipe.peekLocked(ars.NumBytes(), func(srcs safemem.BlockSeq) (uint64, error) {
+ return dst.WriteFromBlocks(srcs)
})
- if n > 0 {
- fd.pipe.Notify(waiter.EventOut)
- }
- if err == nil && n != origCount {
- return n, syserror.ErrWouldBlock
- }
- return n, err
}
// CopyOutFrom implements usermem.IO.CopyOutFrom.
+//
+// Preconditions: fd.pipe.mu must be locked.
func (fd *VFSPipeFD) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) {
- count := ars.NumBytes()
- if count == 0 {
- return 0, nil
- }
- origCount := count
- n, err := fd.pipe.writeLocked(ctx, writeOps{
- left: func() int64 {
- return count
- },
- limit: func(l int64) {
- count = l
- },
- write: func(view *buffer.View) (int64, error) {
- n, err := view.WriteFromSafememReader(src, uint64(count))
- return int64(n), err
- },
+ n, err := fd.pipe.writeLocked(ars.NumBytes(), func(dsts safemem.BlockSeq) (uint64, error) {
+ return src.ReadToBlocks(dsts)
})
if n > 0 {
fd.pipe.Notify(waiter.EventIn)
}
- if err == nil && n != origCount {
- return n, syserror.ErrWouldBlock
- }
return n, err
}
@@ -481,37 +421,23 @@ func spliceOrTee(ctx context.Context, dst, src *VFSPipeFD, count int64, removeFr
}
lockTwoPipes(dst.pipe, src.pipe)
- defer dst.pipe.mu.Unlock()
- defer src.pipe.mu.Unlock()
-
- n, err := dst.pipe.writeLocked(ctx, writeOps{
- left: func() int64 {
- return count
- },
- limit: func(l int64) {
- count = l
- },
- write: func(dstView *buffer.View) (int64, error) {
- return src.pipe.readLocked(ctx, readOps{
- left: func() int64 {
- return count
- },
- limit: func(l int64) {
- count = l
- },
- read: func(srcView *buffer.View) (int64, error) {
- n, err := srcView.ReadToSafememWriter(dstView, uint64(count))
- if n > 0 && removeFromSrc {
- srcView.TrimFront(int64(n))
- }
- return int64(n), err
- },
- })
- },
+ n, err := dst.pipe.writeLocked(count, func(dsts safemem.BlockSeq) (uint64, error) {
+ n, err := src.pipe.peekLocked(int64(dsts.NumBytes()), func(srcs safemem.BlockSeq) (uint64, error) {
+ return safemem.CopySeq(dsts, srcs)
+ })
+ if n > 0 && removeFromSrc {
+ src.pipe.consumeLocked(n)
+ }
+ return uint64(n), err
})
+ dst.pipe.mu.Unlock()
+ src.pipe.mu.Unlock()
+
if n > 0 {
dst.pipe.Notify(waiter.EventIn)
- src.pipe.Notify(waiter.EventOut)
+ if removeFromSrc {
+ src.pipe.Notify(waiter.EventOut)
+ }
}
return n, err
}
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index b2206900b..22abca120 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -18,7 +18,6 @@ go_library(
],
deps = [
"//pkg/abi/linux",
- "//pkg/amutex",
"//pkg/binary",
"//pkg/context",
"//pkg/log",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 57f224120..03749a8bf 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -36,7 +36,6 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
@@ -187,6 +186,21 @@ var Metrics = tcpip.Stats{
IPTablesInputDropped: mustCreateMetric("/netstack/ip/iptables/input_dropped", "Total number of IP packets dropped in the Input chain."),
IPTablesOutputDropped: mustCreateMetric("/netstack/ip/iptables/output_dropped", "Total number of IP packets dropped in the Output chain."),
},
+ ARP: tcpip.ARPStats{
+ PacketsReceived: mustCreateMetric("/netstack/arp/packets_received", "Number of ARP packets received from the link layer."),
+ DisabledPacketsReceived: mustCreateMetric("/netstack/arp/disabled_packets_received", "Number of ARP packets received from the link layer when the ARP layer is disabled."),
+ MalformedPacketsReceived: mustCreateMetric("/netstack/arp/malformed_packets_received", "Number of ARP packets which failed ARP header validation checks."),
+ RequestsReceived: mustCreateMetric("/netstack/arp/requests_received", "Number of ARP requests received."),
+ RequestsReceivedUnknownTargetAddress: mustCreateMetric("/netstack/arp/requests_received_unknown_addr", "Number of ARP requests received with an unknown target address."),
+ OutgoingRequestInterfaceHasNoLocalAddressErrors: mustCreateMetric("/netstack/arp/outgoing_requests_iface_has_no_addr", "Number of failed attempts to send an ARP request with an interface that has no network address."),
+ OutgoingRequestBadLocalAddressErrors: mustCreateMetric("/netstack/arp/outgoing_requests_invalid_local_addr", "Number of failed attempts to send an ARP request with a provided local address that is invalid."),
+ OutgoingRequestNetworkUnreachableErrors: mustCreateMetric("/netstack/arp/outgoing_requests_network_unreachable", "Number of failed attempts to send an ARP request with a network unreachable error."),
+ OutgoingRequestsDropped: mustCreateMetric("/netstack/arp/outgoing_requests_dropped", "Number of ARP requests which failed to write to a link-layer endpoint."),
+ OutgoingRequestsSent: mustCreateMetric("/netstack/arp/outgoing_requests_sent", "Number of ARP requests sent."),
+ RepliesReceived: mustCreateMetric("/netstack/arp/replies_received", "Number of ARP replies received."),
+ OutgoingRepliesDropped: mustCreateMetric("/netstack/arp/outgoing_replies_dropped", "Number of ARP replies which failed to write to a link-layer endpoint."),
+ OutgoingRepliesSent: mustCreateMetric("/netstack/arp/outgoing_replies_sent", "Number of ARP replies sent."),
+ },
TCP: tcpip.TCPStats{
ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."),
PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."),
@@ -459,18 +473,10 @@ func (i *ioSequencePayload) DropFirst(n int) {
// Write implements fs.FileOperations.Write.
func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
f := &ioSequencePayload{ctx: ctx, src: src}
- n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
+ n, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
if err == tcpip.ErrWouldBlock {
return 0, syserror.ErrWouldBlock
}
-
- if resCh != nil {
- if err := amutex.Block(ctx, resCh); err != nil {
- return 0, err
- }
- n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{})
- }
-
if err != nil {
return 0, syserr.TranslateNetstackError(err).ToError()
}
@@ -526,24 +532,12 @@ func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) {
// ReadFrom implements fs.FileOperations.ReadFrom.
func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
f := &readerPayload{ctx: ctx, r: r, count: count}
- n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{
+ n, err := s.Endpoint.Write(f, tcpip.WriteOptions{
// Reads may be destructive but should be very fast,
// so we can't release the lock while copying data.
Atomic: true,
})
if err == tcpip.ErrWouldBlock {
- return 0, syserror.ErrWouldBlock
- }
-
- if resCh != nil {
- if err := amutex.Block(ctx, resCh); err != nil {
- return 0, err
- }
- n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{
- Atomic: true, // See above.
- })
- }
- if err == tcpip.ErrWouldBlock {
return n, syserror.ErrWouldBlock
} else if err != nil {
return int64(n), f.err // Propagate error.
@@ -2836,13 +2830,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
}
v := &ioSequencePayload{t, src}
- n, resCh, err := s.Endpoint.Write(v, opts)
- if resCh != nil {
- if err := t.Block(resCh); err != nil {
- return 0, syserr.FromError(err)
- }
- n, _, err = s.Endpoint.Write(v, opts)
- }
+ n, err := s.Endpoint.Write(v, opts)
dontWait := flags&linux.MSG_DONTWAIT != 0
if err == nil && (n >= v.src.NumBytes() || dontWait) {
// Complete write.
@@ -2861,7 +2849,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
v.DropFirst(int(n))
total := n
for {
- n, _, err = s.Endpoint.Write(v, opts)
+ n, err = s.Endpoint.Write(v, opts)
v.DropFirst(int(n))
total += n
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index b756bfca0..6f70b02fc 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -16,7 +16,6 @@ package netstack
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
@@ -131,18 +130,10 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs
}
f := &ioSequencePayload{ctx: ctx, src: src}
- n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
+ n, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
if err == tcpip.ErrWouldBlock {
return 0, syserror.ErrWouldBlock
}
-
- if resCh != nil {
- if err := amutex.Block(ctx, resCh); err != nil {
- return 0, err
- }
- n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{})
- }
-
if err != nil {
return 0, syserr.TranslateNetstackError(err).ToError()
}
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index 1c4cdb0dd..134051124 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -29,24 +29,23 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB
if opts.Length < 0 || opts.SrcStart < 0 || opts.DstStart < 0 || (opts.SrcStart+opts.Length < 0) {
return 0, syserror.EINVAL
}
-
+ if opts.Length == 0 {
+ return 0, nil
+ }
if opts.Length > int64(kernel.MAX_RW_COUNT) {
opts.Length = int64(kernel.MAX_RW_COUNT)
}
var (
- total int64
n int64
err error
inCh chan struct{}
outCh chan struct{}
)
- for opts.Length > 0 {
+ for {
n, err = fs.Splice(t, outFile, inFile, opts)
- opts.Length -= n
- total += n
- if err != syserror.ErrWouldBlock {
+ if n != 0 || err != syserror.ErrWouldBlock {
break
} else if err == syserror.ErrWouldBlock && nonBlocking {
break
@@ -87,13 +86,13 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB
}
}
- if total > 0 {
+ if n > 0 {
// On Linux, inotify behavior is not very consistent with splice(2). We try
// our best to emulate Linux for very basic calls to splice, where for some
// reason, events are generated for output files, but not input files.
outFile.Dirent.InotifyEvent(linux.IN_MODIFY, 0)
}
- return total, err
+ return n, err
}
// Sendfile implements linux system call sendfile(2).
diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go
index 2756d4471..cb8981633 100644
--- a/pkg/syserr/netstack.go
+++ b/pkg/syserr/netstack.go
@@ -93,7 +93,6 @@ func init() {
addErrMapping(tcpip.ErrConnectionAborted, ErrConnectionAborted)
addErrMapping(tcpip.ErrNoSuchFile, ErrNoSuchFile)
addErrMapping(tcpip.ErrInvalidOptionValue, ErrInvalidOptionValue)
- addErrMapping(tcpip.ErrNoLinkAddress, ErrHostDown)
addErrMapping(tcpip.ErrBadAddress, ErrBadAddress)
addErrMapping(tcpip.ErrNetworkUnreachable, ErrNetworkUnreachable)
addErrMapping(tcpip.ErrMessageTooLong, ErrMessageTooLong)
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 7193f56ad..85a0b8b90 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -397,22 +397,9 @@ func (c *TCPConn) Write(b []byte) (int, error) {
}
var n int64
- var resCh <-chan struct{}
- n, resCh, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
+ n, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
nbytes += int(n)
v.TrimFront(int(n))
-
- if resCh != nil {
- select {
- case <-deadline:
- return nbytes, c.newOpError("write", &timeoutError{})
- case <-resCh:
- }
-
- n, _, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
- nbytes += int(n)
- v.TrimFront(int(n))
- }
}
if err == nil {
@@ -666,17 +653,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
v := buffer.NewView(len(b))
copy(v, b)
- n, resCh, err := c.ep.Write(tcpip.SlicePayload(v), wopts)
- if resCh != nil {
- select {
- case <-deadline:
- return int(n), c.newRemoteOpError("write", addr, &timeoutError{})
- case <-resCh:
- }
-
- n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts)
- }
-
+ n, err := c.ep.Write(tcpip.SlicePayload(v), wopts)
if err == tcpip.ErrWouldBlock {
// Create wait queue entry that notifies a channel.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
@@ -689,7 +666,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
case <-notifyCh:
}
- n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts)
+ n, err = c.ep.Write(tcpip.SlicePayload(v), wopts)
if err != tcpip.ErrWouldBlock {
break
}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index dd2e1a125..9f5ee43d7 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -191,8 +191,8 @@ func shuffle(b []int) {
}
func createFile(t *testing.T, size int64, initQueue bool) int {
- tmpDir := os.Getenv("TEST_TMPDIR")
- if tmpDir == "" {
+ tmpDir, ok := os.LookupEnv("TEST_TMPDIR")
+ if !ok {
tmpDir = os.Getenv("TMPDIR")
}
f, err := ioutil.TempFile(tmpDir, "sharedmem_test")
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 3d5c0d270..3259d052f 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -119,21 +119,28 @@ func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) *t
}
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
+ stats := e.protocol.stack.Stats().ARP
+ stats.PacketsReceived.Increment()
+
if !e.isEnabled() {
+ stats.DisabledPacketsReceived.Increment()
return
}
h := header.ARP(pkt.NetworkHeader().View())
if !h.IsValid() {
+ stats.MalformedPacketsReceived.Increment()
return
}
switch h.Op() {
case header.ARPRequest:
+ stats.RequestsReceived.Increment()
localAddr := tcpip.Address(h.ProtocolAddressTarget())
if e.nud == nil {
if e.linkAddrCache.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
+ stats.RequestsReceivedUnknownTargetAddress.Increment()
return // we have no useful answer, ignore the request
}
@@ -142,6 +149,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
} else {
if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
+ stats.RequestsReceivedUnknownTargetAddress.Increment()
return // we have no useful answer, ignore the request
}
@@ -177,9 +185,14 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
//
// Send the packet to the (new) target hardware address on the same
// hardware on which the request was received.
- _ = e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt)
+ if err := e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt); err != nil {
+ stats.OutgoingRepliesDropped.Increment()
+ } else {
+ stats.OutgoingRepliesSent.Increment()
+ }
case header.ARPReply:
+ stats.RepliesReceived.Increment()
addr := tcpip.Address(h.ProtocolAddressSender())
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
@@ -233,6 +246,8 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error {
+ stats := p.stack.Stats().ARP
+
if len(remoteLinkAddr) == 0 {
remoteLinkAddr = header.EthernetBroadcastAddress
}
@@ -241,15 +256,18 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
if len(localAddr) == 0 {
addr, err := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
if err != nil {
+ stats.OutgoingRequestInterfaceHasNoLocalAddressErrors.Increment()
return err
}
if len(addr.Address) == 0 {
+ stats.OutgoingRequestNetworkUnreachableErrors.Increment()
return tcpip.ErrNetworkUnreachable
}
localAddr = addr.Address
} else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ stats.OutgoingRequestBadLocalAddressErrors.Increment()
return tcpip.ErrBadLocalAddress
}
@@ -269,7 +287,12 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize {
panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
}
- return nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt)
+ if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
+ stats.OutgoingRequestsDropped.Increment()
+ return err
+ }
+ stats.OutgoingRequestsSent.Increment()
+ return nil
}
// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index a25cba513..6b61f57ad 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -240,6 +240,10 @@ func TestDirectRequest(t *testing.T) {
for i, address := range []tcpip.Address{stackAddr, remoteAddr} {
t.Run(strconv.Itoa(i), func(t *testing.T) {
+ expectedPacketsReceived := c.s.Stats().ARP.PacketsReceived.Value() + 1
+ expectedRequestsReceived := c.s.Stats().ARP.RequestsReceived.Value() + 1
+ expectedRepliesSent := c.s.Stats().ARP.OutgoingRepliesSent.Value() + 1
+
inject(address)
pi, _ := c.linkEP.ReadContext(context.Background())
if pi.Proto != arp.ProtocolNumber {
@@ -249,6 +253,9 @@ func TestDirectRequest(t *testing.T) {
if !rep.IsValid() {
t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep)
}
+ if got := rep.Op(); got != header.ARPReply {
+ t.Fatalf("got Op = %d, want = %d", got, header.ARPReply)
+ }
if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
t.Errorf("got HardwareAddressSender = %s, want = %s", got, want)
}
@@ -261,6 +268,16 @@ func TestDirectRequest(t *testing.T) {
if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want {
t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, want)
}
+
+ if got := c.s.Stats().ARP.PacketsReceived.Value(); got != expectedPacketsReceived {
+ t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, expectedPacketsReceived)
+ }
+ if got := c.s.Stats().ARP.RequestsReceived.Value(); got != expectedRequestsReceived {
+ t.Errorf("got c.s.Stats().ARP.RequestsReceived.Value() = %d, want = %d", got, expectedRequestsReceived)
+ }
+ if got := c.s.Stats().ARP.OutgoingRepliesSent.Value(); got != expectedRepliesSent {
+ t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, expectedRepliesSent)
+ }
})
}
@@ -273,6 +290,84 @@ func TestDirectRequest(t *testing.T) {
if pkt, ok := c.linkEP.ReadContext(ctx); ok {
t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
}
+ if got := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value(); got != 1 {
+ t.Errorf("got c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value() = %d, want = 1", got)
+ }
+}
+
+func TestMalformedPacket(t *testing.T) {
+ c := newTestContext(t, false)
+ defer c.cleanup()
+
+ v := make(buffer.View, header.ARPSize)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: v.ToVectorisedView(),
+ })
+
+ c.linkEP.InjectInbound(arp.ProtocolNumber, pkt)
+
+ if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 {
+ t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got := c.s.Stats().ARP.MalformedPacketsReceived.Value(); got != 1 {
+ t.Errorf("got c.s.Stats().ARP.MalformedPacketsReceived.Value() = %d, want = 1", got)
+ }
+}
+
+func TestDisabledEndpoint(t *testing.T) {
+ c := newTestContext(t, false)
+ defer c.cleanup()
+
+ ep, err := c.s.GetNetworkEndpoint(nicID, header.ARPProtocolNumber)
+ if err != nil {
+ t.Fatalf("GetNetworkEndpoint(%d, header.ARPProtocolNumber) failed: %s", nicID, err)
+ }
+ ep.Disable()
+
+ v := make(buffer.View, header.ARPSize)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: v.ToVectorisedView(),
+ })
+
+ c.linkEP.InjectInbound(arp.ProtocolNumber, pkt)
+
+ if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 {
+ t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got := c.s.Stats().ARP.DisabledPacketsReceived.Value(); got != 1 {
+ t.Errorf("got c.s.Stats().ARP.DisabledPacketsReceived.Value() = %d, want = 1", got)
+ }
+}
+
+func TestDirectReply(t *testing.T) {
+ c := newTestContext(t, false)
+ defer c.cleanup()
+
+ const senderMAC = "\x01\x02\x03\x04\x05\x06"
+ const senderIPv4 = "\x0a\x00\x00\x02"
+
+ v := make(buffer.View, header.ARPSize)
+ h := header.ARP(v)
+ h.SetIPv4OverEthernet()
+ h.SetOp(header.ARPReply)
+
+ copy(h.HardwareAddressSender(), senderMAC)
+ copy(h.ProtocolAddressSender(), senderIPv4)
+ copy(h.HardwareAddressTarget(), stackLinkAddr)
+ copy(h.ProtocolAddressTarget(), stackAddr)
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: v.ToVectorisedView(),
+ })
+
+ c.linkEP.InjectInbound(arp.ProtocolNumber, pkt)
+
+ if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 {
+ t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got := c.s.Stats().ARP.RepliesReceived.Value(); got != 1 {
+ t.Errorf("got c.s.Stats().ARP.RepliesReceived.Value() = %d, want = 1", got)
+ }
}
func TestDirectRequestWithNeighborCache(t *testing.T) {
@@ -311,6 +406,11 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ packetsRecv := c.s.Stats().ARP.PacketsReceived.Value()
+ requestsRecv := c.s.Stats().ARP.RequestsReceived.Value()
+ requestsRecvUnknownAddr := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value()
+ outgoingReplies := c.s.Stats().ARP.OutgoingRepliesSent.Value()
+
// Inject an incoming ARP request.
v := make(buffer.View, header.ARPSize)
h := header.ARP(v)
@@ -323,6 +423,13 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
Data: v.ToVectorisedView(),
}))
+ if got, want := c.s.Stats().ARP.PacketsReceived.Value(), packetsRecv+1; got != want {
+ t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, want)
+ }
+ if got, want := c.s.Stats().ARP.RequestsReceived.Value(), requestsRecv+1; got != want {
+ t.Errorf("got c.s.Stats().ARP.RequestsReceived.Value() = %d, want = %d", got, want)
+ }
+
if !test.isValid {
// No packets should be sent after receiving an invalid ARP request.
// There is no need to perform a blocking read here, since packets are
@@ -330,9 +437,20 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
if pkt, ok := c.linkEP.Read(); ok {
t.Errorf("unexpected packet sent with network protocol number %d", pkt.Proto)
}
+ if got, want := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value(), requestsRecvUnknownAddr+1; got != want {
+ t.Errorf("got c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value() = %d, want = %d", got, want)
+ }
+ if got, want := c.s.Stats().ARP.OutgoingRepliesSent.Value(), outgoingReplies; got != want {
+ t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, want)
+ }
+
return
}
+ if got, want := c.s.Stats().ARP.OutgoingRepliesSent.Value(), outgoingReplies+1; got != want {
+ t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, want)
+ }
+
// Verify an ARP response was sent.
pi, ok := c.linkEP.Read()
if !ok {
@@ -418,6 +536,8 @@ type testInterface struct {
stack.LinkEndpoint
nicID tcpip.NICID
+
+ writeErr *tcpip.Error
}
func (t *testInterface) ID() tcpip.NICID {
@@ -441,6 +561,10 @@ func (*testInterface) Promiscuous() bool {
}
func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if t.writeErr != nil {
+ return t.writeErr
+ }
+
var r stack.Route
r.NetProto = protocol
r.ResolveWith(remoteLinkAddr)
@@ -458,61 +582,99 @@ func TestLinkAddressRequest(t *testing.T) {
localAddr tcpip.Address
remoteLinkAddr tcpip.LinkAddress
- expectedErr *tcpip.Error
- expectedLocalAddr tcpip.Address
- expectedRemoteLinkAddr tcpip.LinkAddress
+ linkErr *tcpip.Error
+ expectedErr *tcpip.Error
+ expectedLocalAddr tcpip.Address
+ expectedRemoteLinkAddr tcpip.LinkAddress
+ expectedRequestsSent uint64
+ expectedRequestBadLocalAddressErrors uint64
+ expectedRequestNetworkUnreachableErrors uint64
+ expectedRequestDroppedErrors uint64
}{
{
- name: "Unicast",
- nicAddr: stackAddr,
- localAddr: stackAddr,
- remoteLinkAddr: remoteLinkAddr,
- expectedLocalAddr: stackAddr,
- expectedRemoteLinkAddr: remoteLinkAddr,
+ name: "Unicast",
+ nicAddr: stackAddr,
+ localAddr: stackAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: remoteLinkAddr,
+ expectedRequestsSent: 1,
+ expectedRequestBadLocalAddressErrors: 0,
+ expectedRequestNetworkUnreachableErrors: 0,
},
{
- name: "Multicast",
- nicAddr: stackAddr,
- localAddr: stackAddr,
- remoteLinkAddr: "",
- expectedLocalAddr: stackAddr,
- expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
+ name: "Multicast",
+ nicAddr: stackAddr,
+ localAddr: stackAddr,
+ remoteLinkAddr: "",
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
+ expectedRequestsSent: 1,
+ expectedRequestBadLocalAddressErrors: 0,
+ expectedRequestNetworkUnreachableErrors: 0,
},
{
- name: "Unicast with unspecified source",
- nicAddr: stackAddr,
- remoteLinkAddr: remoteLinkAddr,
- expectedLocalAddr: stackAddr,
- expectedRemoteLinkAddr: remoteLinkAddr,
+ name: "Unicast with unspecified source",
+ nicAddr: stackAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: remoteLinkAddr,
+ expectedRequestsSent: 1,
+ expectedRequestBadLocalAddressErrors: 0,
+ expectedRequestNetworkUnreachableErrors: 0,
},
{
- name: "Multicast with unspecified source",
- nicAddr: stackAddr,
- remoteLinkAddr: "",
- expectedLocalAddr: stackAddr,
- expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
+ name: "Multicast with unspecified source",
+ nicAddr: stackAddr,
+ remoteLinkAddr: "",
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
+ expectedRequestsSent: 1,
+ expectedRequestBadLocalAddressErrors: 0,
+ expectedRequestNetworkUnreachableErrors: 0,
},
{
- name: "Unicast with unassigned address",
- localAddr: testAddr,
- remoteLinkAddr: remoteLinkAddr,
- expectedErr: tcpip.ErrBadLocalAddress,
+ name: "Unicast with unassigned address",
+ localAddr: testAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ expectedErr: tcpip.ErrBadLocalAddress,
+ expectedRequestsSent: 0,
+ expectedRequestBadLocalAddressErrors: 1,
+ expectedRequestNetworkUnreachableErrors: 0,
},
{
- name: "Multicast with unassigned address",
- localAddr: testAddr,
- remoteLinkAddr: "",
- expectedErr: tcpip.ErrBadLocalAddress,
+ name: "Multicast with unassigned address",
+ localAddr: testAddr,
+ remoteLinkAddr: "",
+ expectedErr: tcpip.ErrBadLocalAddress,
+ expectedRequestsSent: 0,
+ expectedRequestBadLocalAddressErrors: 1,
+ expectedRequestNetworkUnreachableErrors: 0,
},
{
- name: "Unicast with no local address available",
- remoteLinkAddr: remoteLinkAddr,
- expectedErr: tcpip.ErrNetworkUnreachable,
+ name: "Unicast with no local address available",
+ remoteLinkAddr: remoteLinkAddr,
+ expectedErr: tcpip.ErrNetworkUnreachable,
+ expectedRequestsSent: 0,
+ expectedRequestBadLocalAddressErrors: 0,
+ expectedRequestNetworkUnreachableErrors: 1,
},
{
- name: "Multicast with no local address available",
- remoteLinkAddr: "",
- expectedErr: tcpip.ErrNetworkUnreachable,
+ name: "Multicast with no local address available",
+ remoteLinkAddr: "",
+ expectedErr: tcpip.ErrNetworkUnreachable,
+ expectedRequestsSent: 0,
+ expectedRequestBadLocalAddressErrors: 0,
+ expectedRequestNetworkUnreachableErrors: 1,
+ },
+ {
+ name: "Link error",
+ nicAddr: stackAddr,
+ localAddr: stackAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ linkErr: tcpip.ErrInvalidEndpointState,
+ expectedErr: tcpip.ErrInvalidEndpointState,
+ expectedRequestDroppedErrors: 1,
},
}
@@ -543,10 +705,24 @@ func TestLinkAddressRequest(t *testing.T) {
// can mock a link address request and observe the packets sent to the
// link endpoint even though the stack uses the real NIC to validate the
// local address.
- if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr {
+ iface := testInterface{LinkEndpoint: linkEP, nicID: nicID, writeErr: test.linkErr}
+ if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &iface); err != test.expectedErr {
t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, test.localAddr, test.remoteLinkAddr, err, test.expectedErr)
}
+ if got := s.Stats().ARP.OutgoingRequestsSent.Value(); got != test.expectedRequestsSent {
+ t.Errorf("got s.Stats().ARP.OutgoingRequestsSent.Value() = %d, want = %d", got, test.expectedRequestsSent)
+ }
+ if got := s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value(); got != test.expectedRequestBadLocalAddressErrors {
+ t.Errorf("got s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestBadLocalAddressErrors)
+ }
+ if got := s.Stats().ARP.OutgoingRequestNetworkUnreachableErrors.Value(); got != test.expectedRequestNetworkUnreachableErrors {
+ t.Errorf("got s.Stats().ARP.OutgoingRequestNetworkUnreachableErrors.Value() = %d, want = %d", got, test.expectedRequestNetworkUnreachableErrors)
+ }
+ if got := s.Stats().ARP.OutgoingRequestsDropped.Value(); got != test.expectedRequestDroppedErrors {
+ t.Errorf("got s.Stats().ARP.OutgoingRequestsDropped.Value() = %d, want = %d", got, test.expectedRequestDroppedErrors)
+ }
+
if test.expectedErr != nil {
return
}
@@ -561,6 +737,9 @@ func TestLinkAddressRequest(t *testing.T) {
}
rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
+ if got := rep.Op(); got != header.ARPRequest {
+ t.Errorf("got Op = %d, want = %d", got, header.ARPRequest)
+ }
if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
t.Errorf("got HardwareAddressSender = %s, want = %s", got, stackLinkAddr)
}
@@ -576,3 +755,22 @@ func TestLinkAddressRequest(t *testing.T) {
})
}
}
+
+func TestLinkAddressRequestWithoutNIC(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
+ })
+ p := s.NetworkProtocolInstance(arp.ProtocolNumber)
+ linkRes, ok := p.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
+ }
+
+ if err := linkRes.LinkAddressRequest(remoteAddr, "", remoteLinkAddr, &testInterface{nicID: nicID}); err != tcpip.ErrUnknownNICID {
+ t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, "", remoteLinkAddr, err, tcpip.ErrUnknownNICID)
+ }
+
+ if got := s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value(); got != 1 {
+ t.Errorf("got s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value() = %d, want = 1", got)
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index e9ff70d04..cc045c7a9 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -64,6 +64,7 @@ const (
var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix()
+var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil)
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
var _ stack.AddressableEndpoint = (*endpoint)(nil)
var _ stack.NetworkEndpoint = (*endpoint)(nil)
@@ -87,6 +88,21 @@ type endpoint struct {
}
}
+// HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint.
+func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) {
+ // handleControl expects the entire offending packet to be in the packet
+ // buffer's data field.
+ pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ })
+ pkt.NICID = e.nic.ID()
+ pkt.NetworkProtocolNumber = ProtocolNumber
+ // Use the same control type as an ICMPv4 destination host unreachable error
+ // since the host is considered unreachable if we cannot resolve the link
+ // address to the next hop.
+ e.handleControl(stack.ControlNoRoute, 0, pkt)
+}
+
// NewEndpoint creates a new ipv4 endpoint.
func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index bbce1ef78..0ec0a0fef 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -645,29 +645,18 @@ func TestLinkResolution(t *testing.T) {
t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err)
}
- for {
- _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}})
- if resCh != nil {
- if err != tcpip.ErrNoLinkAddress {
- t.Fatalf("ep.Write(_) = (_, <non-nil>, %s), want = (_, <non-nil>, tcpip.ErrNoLinkAddress)", err)
- }
- for _, args := range []routeArgs{
- {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))},
- {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert},
- } {
- routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) {
- if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want {
- t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want)
- }
- })
+ if _, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil {
+ t.Fatalf("ep.Write(_): %s", err)
+ }
+ for _, args := range []routeArgs{
+ {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))},
+ {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert},
+ } {
+ routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) {
+ if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want {
+ t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want)
}
- <-resCh
- continue
- }
- if err != nil {
- t.Fatalf("ep.Write(_) = (_, _, %s)", err)
- }
- break
+ })
}
for _, args := range []routeArgs{
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index f2018d073..2f82c3d5f 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -163,6 +163,7 @@ func getLabel(addr tcpip.Address) uint8 {
panic(fmt.Sprintf("should have a label for address = %s", addr))
}
+var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil)
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
var _ stack.AddressableEndpoint = (*endpoint)(nil)
var _ stack.NetworkEndpoint = (*endpoint)(nil)
@@ -224,6 +225,18 @@ type OpaqueInterfaceIdentifierOptions struct {
SecretKey []byte
}
+// HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint.
+func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) {
+ // handleControl expects the entire offending packet to be in the packet
+ // buffer's data field.
+ pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ })
+ pkt.NICID = e.nic.ID()
+ pkt.NetworkProtocolNumber = ProtocolNumber
+ e.handleControl(stack.ControlAddressUnreachable, 0, pkt)
+}
+
// onAddressAssignedLocked handles an address being assigned.
//
// Precondition: e.mu must be exclusively locked.
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 4777163cd..a7da9dcd9 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -82,7 +82,7 @@ func writer(ch chan struct{}, ep tcpip.Endpoint) {
v.CapLength(n)
for len(v) > 0 {
- n, _, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
+ n, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
if err != nil {
fmt.Println("Write failed:", err)
return
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 6883045b5..03b2f2d6f 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -83,7 +83,7 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe
got, ch, err := c.get(addr, linkRes, "", nil, nil)
if err == tcpip.ErrWouldBlock {
if attemptedResolution {
- return got, tcpip.ErrNoLinkAddress
+ return got, tcpip.ErrTimeout
}
attemptedResolution = true
<-ch
@@ -253,8 +253,8 @@ func TestCacheResolutionFailed(t *testing.T) {
before := atomic.LoadUint32(&requestCount)
e.addr.Addr += "2"
- if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
- t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout {
+ t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout)
}
if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want {
@@ -269,8 +269,8 @@ func TestCacheResolutionTimeout(t *testing.T) {
linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay}
e := testAddrs[0]
- if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
- t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout {
+ t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout)
}
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 4a34805b5..8a946b4fa 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -217,6 +217,16 @@ func (n *NIC) disableLocked() {
ep.Disable()
}
+ // Clear the neighbour table (including static entries) as we cannot guarantee
+ // that the current neighbour table will be valid when the NIC is enabled
+ // again.
+ //
+ // This matches Linux's behaviour at the time of writing:
+ // https://github.com/torvalds/linux/blob/71c061d2443814de15e177489d5cc00a4a253ef3/net/core/neighbour.c#L371
+ if err := n.clearNeighbors(); err != nil && err != tcpip.ErrNotSupported {
+ panic(fmt.Sprintf("n.clearNeighbors(): %s", err))
+ }
+
if !n.setEnabled(false) {
panic("should have only done work to disable the NIC if it was enabled")
}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 664cc6fa0..5f216ca21 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -268,17 +268,6 @@ func (pk *PacketBuffer) Clone() *PacketBuffer {
}
}
-// SourceLinkAddress returns the source link address of the packet.
-func (pk *PacketBuffer) SourceLinkAddress() tcpip.LinkAddress {
- link := pk.LinkHeader().View()
-
- if link.IsEmpty() {
- return ""
- }
-
- return header.Ethernet(link).SourceAddress()
-}
-
// Network returns the network header as a header.Network.
//
// Network should only be called when NetworkHeader has been set.
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go
index 4a3adcf33..bded8814e 100644
--- a/pkg/tcpip/stack/pending_packets.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -101,10 +101,12 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro
}
for _, p := range packets {
- if cancelled {
- p.route.Stats().IP.OutgoingPacketErrors.Increment()
- } else if p.route.IsResolutionRequired() {
+ if cancelled || p.route.IsResolutionRequired() {
p.route.Stats().IP.OutgoingPacketErrors.Increment()
+
+ if linkResolvableEP, ok := p.route.outgoingNIC.getNetworkEndpoint(p.route.NetProto).(LinkResolvableNetworkEndpoint); ok {
+ linkResolvableEP.HandleLinkResolutionFailure(pkt)
+ }
} else {
p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 4795208b4..924790779 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -55,7 +55,19 @@ type ControlType int
// The following are the allowed values for ControlType values.
// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages.
const (
- ControlNetworkUnreachable ControlType = iota
+ // ControlAddressUnreachable indicates that an IPv6 packet did not reach its
+ // destination as the destination address was unreachable.
+ //
+ // This maps to the ICMPv6 Destination Unreachable Code 3 error; see
+ // RFC 4443 section 3.1 for more details.
+ ControlAddressUnreachable ControlType = iota
+ ControlNetworkUnreachable
+ // ControlNoRoute indicates that an IPv4 packet did not reach its destination
+ // because the destination host was unreachable.
+ //
+ // This maps to the ICMPv4 Destination Unreachable Code 1 error; see
+ // RFC 792's Destination Unreachable Message section (page 4) for more
+ // details.
ControlNoRoute
ControlPacketTooBig
ControlPortUnreachable
@@ -503,6 +515,13 @@ type NetworkInterface interface {
WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error
}
+// LinkResolvableNetworkEndpoint handles link resolution events.
+type LinkResolvableNetworkEndpoint interface {
+ // HandleLinkResolutionFailure is called when link resolution prevents the
+ // argument from having been sent.
+ HandleLinkResolutionFailure(*PacketBuffer)
+}
+
// NetworkEndpoint is the interface that needs to be implemented by endpoints
// of network layer protocols (e.g., ipv4, ipv6).
type NetworkEndpoint interface {
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 856ebf6d4..4a3f937e3 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -4305,3 +4305,55 @@ func TestWritePacketToRemote(t *testing.T) {
}
})
}
+
+func TestClearNeighborCacheOnNICDisable(t *testing.T) {
+ const (
+ nicID = 1
+
+ ipv4Addr = tcpip.Address("\x01\x02\x03\x04")
+ ipv6Addr = tcpip.Address("\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04")
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ )
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ UseNeighborCache: true,
+ })
+ e := channel.New(0, 0, "")
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ if err := s.AddStaticNeighbor(nicID, ipv4Addr, linkAddr); err != nil {
+ t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, ipv4Addr, linkAddr, err)
+ }
+ if err := s.AddStaticNeighbor(nicID, ipv6Addr, linkAddr); err != nil {
+ t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, ipv6Addr, linkAddr, err)
+ }
+ if neighbors, err := s.Neighbors(nicID); err != nil {
+ t.Fatalf("s.Neighbors(%d): %s", nicID, err)
+ } else if len(neighbors) != 2 {
+ t.Fatalf("got len(neighbors) = %d, want = 2; neighbors = %#v", len(neighbors), neighbors)
+ }
+
+ // Disabling the NIC should clear the neighbor table.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ if neighbors, err := s.Neighbors(nicID); err != nil {
+ t.Fatalf("s.Neighbors(%d): %s", nicID, err)
+ } else if len(neighbors) != 0 {
+ t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors)
+ }
+
+ // Re-enabling the NIC should leave the neighbor table empty.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ if neighbors, err := s.Neighbors(nicID); err != nil {
+ t.Fatalf("s.Neighbors(%d): %s", nicID, err)
+ } else if len(neighbors) != 0 {
+ t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors)
+ }
+}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 0ff32c6ea..a2ab7537c 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -90,14 +90,14 @@ func (*fakeTransportEndpoint) Read(io.Writer, int, tcpip.ReadOptions) (tcpip.Rea
return tcpip.ReadResult{}, nil
}
-func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
if len(f.route.RemoteAddress) == 0 {
- return 0, nil, tcpip.ErrNoRoute
+ return 0, tcpip.ErrNoRoute
}
v, err := p.FullPayload()
if err != nil {
- return 0, nil, err
+ return 0, err
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen,
@@ -105,10 +105,10 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
})
_ = pkt.TransportHeader().Push(fakeTransHeaderLen)
if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil {
- return 0, nil, err
+ return 0, err
}
- return int64(len(v)), nil, nil
+ return int64(len(v)), nil
}
// SetSockOpt sets a socket option. Currently not supported.
@@ -222,7 +222,6 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *
if err != nil {
return
}
- route.ResolveWith(pkt.SourceLinkAddress())
ep := &fakeTransportEndpoint{
TransportEndpointInfo: stack.TransportEndpointInfo{
@@ -522,8 +521,7 @@ func TestTransportSend(t *testing.T) {
// Create buffer that will hold the payload.
view := buffer.NewView(30)
- _, _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
- if err != nil {
+ if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("write failed: %v", err)
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index f798056c0..49d4912ad 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -104,7 +104,6 @@ var (
ErrConnectionAborted = &Error{msg: "connection aborted"}
ErrNoSuchFile = &Error{msg: "no such file"}
ErrInvalidOptionValue = &Error{msg: "invalid option value specified"}
- ErrNoLinkAddress = &Error{msg: "no remote link address"}
ErrBadAddress = &Error{msg: "bad address"}
ErrNetworkUnreachable = &Error{msg: "network is unreachable"}
ErrMessageTooLong = &Error{msg: "message too long"}
@@ -154,7 +153,6 @@ func StringToError(s string) *Error {
ErrConnectionAborted,
ErrNoSuchFile,
ErrInvalidOptionValue,
- ErrNoLinkAddress,
ErrBadAddress,
ErrNetworkUnreachable,
ErrMessageTooLong,
@@ -640,12 +638,7 @@ type Endpoint interface {
// stream (TCP) Endpoints may return partial writes, and even then only
// in the case where writing additional data would block. Other Endpoints
// will either write the entire message or return an error.
- //
- // For UDP and Ping sockets if address resolution is required,
- // ErrNoLinkAddress and a notification channel is returned for the caller to
- // block. Channel is closed once address resolution is complete (success or
- // not). The channel is only non-nil in this case.
- Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error)
+ Write(Payloader, WriteOptions) (int64, *Error)
// Connect connects the endpoint to its peer. Specifying a NIC is
// optional.
@@ -1598,6 +1591,59 @@ type IPStats struct {
OptionUnknownReceived *StatCounter
}
+// ARPStats collects ARP-specific stats.
+type ARPStats struct {
+ // PacketsReceived is the number of ARP packets received from the link layer.
+ PacketsReceived *StatCounter
+
+ // DisabledPacketsReceived is the number of ARP packets received from the link
+ // layer when the ARP layer is disabled.
+ DisabledPacketsReceived *StatCounter
+
+ // MalformedPacketsReceived is the number of ARP packets that were dropped due
+ // to being malformed.
+ MalformedPacketsReceived *StatCounter
+
+ // RequestsReceived is the number of ARP requests received.
+ RequestsReceived *StatCounter
+
+ // RequestsReceivedUnknownTargetAddress is the number of ARP requests that
+ // were targeted to an interface different from the one they were received on.
+ RequestsReceivedUnknownTargetAddress *StatCounter
+
+ // OutgoingRequestInterfaceHasNoLocalAddressErrors is the number of failures
+ // to send an ARP request because the interface has no network address
+ // assigned to it.
+ OutgoingRequestInterfaceHasNoLocalAddressErrors *StatCounter
+
+ // OutgoingRequestBadLocalAddressErrors is the number of failures to send an
+ // ARP request with a bad local address.
+ OutgoingRequestBadLocalAddressErrors *StatCounter
+
+ // OutgoingRequestNetworkUnreachableErrors is the number of failures to send
+ // an ARP request with a network unreachable error.
+ OutgoingRequestNetworkUnreachableErrors *StatCounter
+
+ // OutgoingRequestsDropped is the number of ARP requests which failed to write
+ // to a link-layer endpoint.
+ OutgoingRequestsDropped *StatCounter
+
+ // OutgoingRequestsSent is the number of ARP requests successfully written to a
+ // link-layer endpoint.
+ OutgoingRequestsSent *StatCounter
+
+ // RepliesReceived is the number of ARP replies received.
+ RepliesReceived *StatCounter
+
+ // OutgoingRepliesDropped is the number of ARP replies which failed to write
+ // to a link-layer endpoint.
+ OutgoingRepliesDropped *StatCounter
+
+ // OutgoingRepliesSent is the number of ARP replies successfully written to a
+ // link-layer endpoint.
+ OutgoingRepliesSent *StatCounter
+}
+
// TCPStats collects TCP-specific stats.
type TCPStats struct {
// ActiveConnectionOpenings is the number of connections opened
@@ -1750,6 +1796,9 @@ type Stats struct {
// IP breaks out IP-specific stats (both v4 and v6).
IP IPStats
+ // ARP breaks out ARP-specific stats.
+ ARP ARPStats
+
// TCP breaks out TCP-specific stats.
TCP TCPStats
@@ -1784,9 +1833,6 @@ type SendErrors struct {
// NoRoute is the number of times we failed to resolve IP route.
NoRoute StatCounter
-
- // NoLinkAddr is the number of times we failed to resolve ARP.
- NoLinkAddr StatCounter
}
// ReadErrors collects segment read errors from an endpoint read call.
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index ca1e88e99..1742a178d 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -31,5 +31,6 @@ go_test(
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index 4c2084d19..49acd504e 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -30,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -157,13 +158,13 @@ func TestForwarding(t *testing.T) {
tests := []struct {
name string
- epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses
+ epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses
}{
{
name: "IPv4 host1 server with host2 client",
- epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses {
- ep1, ep1WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv4.ProtocolNumber)
- ep2, ep2WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv4.ProtocolNumber)
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
return endpointAndAddresses{
serverEP: ep1,
serverAddr: host1IPv4Addr.AddressWithPrefix.Address,
@@ -177,9 +178,9 @@ func TestForwarding(t *testing.T) {
},
{
name: "IPv6 host2 server with host1 client",
- epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses {
- ep1, ep1WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv6.ProtocolNumber)
- ep2, ep2WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv6.ProtocolNumber)
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ ep1, ep1WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber)
return endpointAndAddresses{
serverEP: ep1,
serverAddr: host2IPv6Addr.AddressWithPrefix.Address,
@@ -193,9 +194,9 @@ func TestForwarding(t *testing.T) {
},
{
name: "IPv4 host2 server with routerNIC1 client",
- epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses {
- ep1, ep1WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv4.ProtocolNumber)
- ep2, ep2WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv4.ProtocolNumber)
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ ep1, ep1WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, routerStack, proto, ipv4.ProtocolNumber)
return endpointAndAddresses{
serverEP: ep1,
serverAddr: host2IPv4Addr.AddressWithPrefix.Address,
@@ -209,9 +210,9 @@ func TestForwarding(t *testing.T) {
},
{
name: "IPv6 routerNIC2 server with host1 client",
- epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses {
- ep1, ep1WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv6.ProtocolNumber)
- ep2, ep2WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv6.ProtocolNumber)
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ ep1, ep1WECH := newEP(t, routerStack, proto, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber)
return endpointAndAddresses{
serverEP: ep1,
serverAddr: routerNIC2IPv6Addr.AddressWithPrefix.Address,
@@ -225,202 +226,270 @@ func TestForwarding(t *testing.T) {
},
}
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- }
-
- host1Stack := stack.New(stackOpts)
- routerStack := stack.New(stackOpts)
- host2Stack := stack.New(stackOpts)
-
- host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2)
- routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4)
-
- if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil {
- t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
- }
- if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err)
- }
- if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err)
- }
- if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil {
- t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
- }
-
- if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil {
- t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err)
- }
- if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil {
- t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err)
- }
-
- if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err)
- }
- if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err)
- }
- if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err)
- }
- if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err)
- }
- if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err)
- }
- if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err)
- }
- if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err)
- }
- if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err)
- }
-
- host1Stack.SetRouteTable([]tcpip.Route{
- {
- Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
- NIC: host1NICID,
- },
- {
- Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
- NIC: host1NICID,
- },
- {
- Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
- Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address,
- NIC: host1NICID,
- },
- {
- Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
- Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address,
- NIC: host1NICID,
- },
- })
- routerStack.SetRouteTable([]tcpip.Route{
- {
- Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(),
- NIC: routerNICID1,
- },
- {
- Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(),
- NIC: routerNICID1,
- },
- {
- Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(),
- NIC: routerNICID2,
- },
- {
- Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(),
- NIC: routerNICID2,
- },
- })
- host2Stack.SetRouteTable([]tcpip.Route{
- {
- Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
- NIC: host2NICID,
- },
- {
- Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
- NIC: host2NICID,
- },
- {
- Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
- Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address,
- NIC: host2NICID,
- },
- {
- Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
- Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address,
- NIC: host2NICID,
- },
- })
-
- epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack)
- defer epsAndAddrs.serverEP.Close()
- defer epsAndAddrs.clientEP.Close()
-
- serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort}
- if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil {
- t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err)
- }
- clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
- if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
- t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
- }
-
- write := func(ep tcpip.Endpoint, data []byte, to *tcpip.FullAddress) {
+ subTests := []struct {
+ name string
+ proto tcpip.TransportProtocolNumber
+ expectedConnectErr *tcpip.Error
+ setupServerSide func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{})
+ needRemoteAddr bool
+ }{
+ {
+ name: "UDP",
+ proto: udp.ProtocolNumber,
+ expectedConnectErr: nil,
+ setupServerSide: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
t.Helper()
- dataPayload := tcpip.SlicePayload(data)
- wOpts := tcpip.WriteOptions{To: to}
- n, ch, err := ep.Write(dataPayload, wOpts)
- if err == tcpip.ErrNoLinkAddress {
- // Wait for link resolution to complete.
- <-ch
- n, _, err = ep.Write(dataPayload, wOpts)
- }
- if err != nil {
- t.Fatalf("ep.Write(_, _): %s", err)
- }
- if want := int64(len(data)); n != want {
- t.Fatalf("got ep.Write(_, _) = (%d, _, _), want = (%d, _, _)", n, want)
+ if err := ep.Connect(clientAddr); err != nil {
+ t.Fatalf("ep.Connect(%#v): %s", clientAddr, err)
}
- }
-
- data := []byte{1, 2, 3, 4}
- write(epsAndAddrs.clientEP, data, &serverAddr)
-
- read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.Address) tcpip.FullAddress {
+ return nil, nil
+ },
+ needRemoteAddr: true,
+ },
+ {
+ name: "TCP",
+ proto: tcp.ProtocolNumber,
+ expectedConnectErr: tcpip.ErrConnectStarted,
+ setupServerSide: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
t.Helper()
- // Wait for the endpoint to be readable.
- <-ch
- var buf bytes.Buffer
- opts := tcpip.ReadOptions{NeedRemoteAddr: true}
- res, err := ep.Read(&buf, len(data), opts)
- if err != nil {
- t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
+ if err := ep.Listen(1); err != nil {
+ t.Fatalf("ep.Listen(1): %s", err)
}
-
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: len(data),
- Total: len(data),
- RemoteAddr: tcpip.FullAddress{Addr: expectedFrom},
- }, res, checker.IgnoreCmpPath(
- "ControlMessages",
- "RemoteAddr.NIC",
- "RemoteAddr.Port",
- )); diff != "" {
- t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
- }
-
- if t.Failed() {
- t.FailNow()
+ var addr tcpip.FullAddress
+ for {
+ newEP, wq, err := ep.Accept(&addr)
+ if err == tcpip.ErrWouldBlock {
+ <-ch
+ continue
+ }
+ if err != nil {
+ t.Fatalf("ep.Accept(_): %s", err)
+ }
+ if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath(
+ "NIC",
+ )); diff != "" {
+ t.Errorf("accepted address mismatch (-want +got):\n%s", diff)
+ }
+
+ we, newCH := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ return newEP, newCH
}
+ },
+ needRemoteAddr: false,
+ },
+ }
- return res.RemoteAddr
- }
-
- addr := read(epsAndAddrs.serverReadableCH, epsAndAddrs.serverEP, data, epsAndAddrs.clientAddr)
- // Unspecify the NIC since NIC IDs are meaningless across stacks.
- addr.NIC = 0
-
- data = tcpip.SlicePayload([]byte{5, 6, 7, 8, 9, 10, 11, 12})
- write(epsAndAddrs.serverEP, data, &addr)
- addr = read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, epsAndAddrs.serverAddr)
- if addr.Port != listenPort {
- t.Errorf("got addr.Port = %d, want = %d", addr.Port, listenPort)
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
+ }
+
+ host1Stack := stack.New(stackOpts)
+ routerStack := stack.New(stackOpts)
+ host2Stack := stack.New(stackOpts)
+
+ host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2)
+ routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4)
+
+ if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil {
+ t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
+ }
+ if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil {
+ t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err)
+ }
+ if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil {
+ t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err)
+ }
+ if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil {
+ t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
+ }
+
+ if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err)
+ }
+
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err)
+ }
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err)
+ }
+
+ host1Stack.SetRouteTable([]tcpip.Route{
+ {
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ {
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ {
+ Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address,
+ NIC: host1NICID,
+ },
+ {
+ Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address,
+ NIC: host1NICID,
+ },
+ })
+ routerStack.SetRouteTable([]tcpip.Route{
+ {
+ Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID1,
+ },
+ {
+ Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID1,
+ },
+ {
+ Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID2,
+ },
+ {
+ Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID2,
+ },
+ })
+ host2Stack.SetRouteTable([]tcpip.Route{
+ {
+ Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ {
+ Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ {
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address,
+ NIC: host2NICID,
+ },
+ {
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address,
+ NIC: host2NICID,
+ },
+ })
+
+ epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto)
+ defer epsAndAddrs.serverEP.Close()
+ defer epsAndAddrs.clientEP.Close()
+
+ serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort}
+ if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil {
+ t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err)
+ }
+ clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
+ if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
+ }
+
+ if err := epsAndAddrs.clientEP.Connect(serverAddr); err != subTest.expectedConnectErr {
+ t.Fatalf("got epsAndAddrs.clientEP.Connect(%#v) = %s, want = %s", serverAddr, err, subTest.expectedConnectErr)
+ }
+ if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err)
+ } else {
+ clientAddr = addr
+ clientAddr.NIC = 0
+ }
+
+ serverEP := epsAndAddrs.serverEP
+ serverCH := epsAndAddrs.serverReadableCH
+ if ep, ch := subTest.setupServerSide(t, serverEP, serverCH, clientAddr); ep != nil {
+ defer ep.Close()
+ serverEP = ep
+ serverCH = ch
+ }
+
+ write := func(ep tcpip.Endpoint, data []byte) {
+ t.Helper()
+
+ dataPayload := tcpip.SlicePayload(data)
+ var wOpts tcpip.WriteOptions
+ n, err := ep.Write(dataPayload, wOpts)
+ if err != nil {
+ t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
+ }
+ if want := int64(len(data)); n != want {
+ t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
+ }
+ }
+
+ data := []byte{1, 2, 3, 4}
+ write(epsAndAddrs.clientEP, data)
+
+ read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) {
+ t.Helper()
+
+ // Wait for the endpoint to be readable.
+ <-ch
+ var buf bytes.Buffer
+ opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
+ res, err := ep.Read(&buf, len(data), opts)
+ if err != nil {
+ t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
+ }
+
+ readResult := tcpip.ReadResult{
+ Count: len(data),
+ Total: len(data),
+ }
+ if subTest.needRemoteAddr {
+ readResult.RemoteAddr = expectedFrom
+ }
+ if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ )); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ read(serverCH, serverEP, data, clientAddr)
+
+ data = tcpip.SlicePayload([]byte{5, 6, 7, 8, 9, 10, 11, 12})
+ write(serverEP, data)
+ read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr)
+ })
}
})
}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index b4bffaec1..ed00c90d4 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -20,6 +20,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -29,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -54,6 +56,13 @@ var (
PrefixLen: 8,
},
}
+ ipv4Addr3 = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.3").To4()),
+ PrefixLen: 8,
+ },
+ }
ipv6Addr1 = tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
@@ -68,8 +77,65 @@ var (
PrefixLen: 64,
},
}
+ ipv6Addr3 = tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::3").To16()),
+ PrefixLen: 64,
+ },
+ }
)
+func setupStack(t *testing.T, stackOpts stack.Options, host1NICID, host2NICID tcpip.NICID) (*stack.Stack, *stack.Stack) {
+ host1Stack := stack.New(stackOpts)
+ host2Stack := stack.New(stackOpts)
+
+ host1NIC, host2NIC := pipe.New(linkAddr1, linkAddr2)
+
+ if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil {
+ t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
+ }
+ if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil {
+ t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
+ }
+
+ if err := host1Stack.AddProtocolAddress(host1NICID, ipv4Addr1); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv4Addr1, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, ipv4Addr2); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv4Addr2, err)
+ }
+ if err := host1Stack.AddProtocolAddress(host1NICID, ipv6Addr1); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv6Addr1, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, ipv6Addr2); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv6Addr2, err)
+ }
+
+ host1Stack.SetRouteTable([]tcpip.Route{
+ {
+ Destination: ipv4Addr1.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ {
+ Destination: ipv6Addr1.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ })
+ host2Stack.SetRouteTable([]tcpip.Route{
+ {
+ Destination: ipv4Addr2.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ {
+ Destination: ipv6Addr2.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ })
+
+ return host1Stack, host2Stack
+}
+
// TestPing tests that two hosts can ping eachother when link resolution is
// enabled.
func TestPing(t *testing.T) {
@@ -128,51 +194,7 @@ func TestPing(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
}
- host1Stack := stack.New(stackOpts)
- host2Stack := stack.New(stackOpts)
-
- host1NIC, host2NIC := pipe.New(linkAddr1, linkAddr2)
-
- if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil {
- t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
- }
- if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil {
- t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
- }
-
- if err := host1Stack.AddProtocolAddress(host1NICID, ipv4Addr1); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv4Addr1, err)
- }
- if err := host2Stack.AddProtocolAddress(host2NICID, ipv4Addr2); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv4Addr2, err)
- }
- if err := host1Stack.AddProtocolAddress(host1NICID, ipv6Addr1); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv6Addr1, err)
- }
- if err := host2Stack.AddProtocolAddress(host2NICID, ipv6Addr2); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv6Addr2, err)
- }
-
- host1Stack.SetRouteTable([]tcpip.Route{
- {
- Destination: ipv4Addr1.AddressWithPrefix.Subnet(),
- NIC: host1NICID,
- },
- {
- Destination: ipv6Addr1.AddressWithPrefix.Subnet(),
- NIC: host1NICID,
- },
- })
- host2Stack.SetRouteTable([]tcpip.Route{
- {
- Destination: ipv4Addr2.AddressWithPrefix.Subnet(),
- NIC: host2NICID,
- },
- {
- Destination: ipv6Addr2.AddressWithPrefix.Subnet(),
- NIC: host2NICID,
- },
- })
+ host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
var wq waiter.Queue
we, waiterCH := waiter.NewChannelEntry(nil)
@@ -183,19 +205,12 @@ func TestPing(t *testing.T) {
}
defer ep.Close()
- // The first write should trigger link resolution.
icmpBuf := test.icmpBuf(t)
wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}}
- if _, ch, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got ep.Write(_, _) = %s, want = %s", err, tcpip.ErrNoLinkAddress)
- } else {
- // Wait for link resolution to complete.
- <-ch
- }
- if n, _, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != nil {
+ if n, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != nil {
t.Fatalf("ep.Write(_, _): %s", err)
} else if want := int64(len(icmpBuf)); n != want {
- t.Fatalf("got ep.Write(_, _) = (%d, _, _), want = (%d, _, _)", n, want)
+ t.Fatalf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want)
}
// Wait for the endpoint to be readable.
@@ -224,3 +239,159 @@ func TestPing(t *testing.T) {
})
}
}
+
+func TestTCPLinkResolutionFailure(t *testing.T) {
+ const (
+ host1NICID = 1
+ host2NICID = 4
+ )
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ expectedWriteErr *tcpip.Error
+ sockError tcpip.SockError
+ }{
+ {
+ name: "IPv4 with resolvable remote",
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: ipv4Addr2.AddressWithPrefix.Address,
+ expectedWriteErr: nil,
+ },
+ {
+ name: "IPv6 with resolvable remote",
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: ipv6Addr2.AddressWithPrefix.Address,
+ expectedWriteErr: nil,
+ },
+ {
+ name: "IPv4 without resolvable remote",
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: ipv4Addr3.AddressWithPrefix.Address,
+ expectedWriteErr: tcpip.ErrNoRoute,
+ sockError: tcpip.SockError{
+ Err: tcpip.ErrNoRoute,
+ ErrType: byte(header.ICMPv4DstUnreachable),
+ ErrCode: byte(header.ICMPv4HostUnreachable),
+ ErrOrigin: tcpip.SockExtErrorOriginICMP,
+ Dst: tcpip.FullAddress{
+ NIC: host1NICID,
+ Addr: ipv4Addr3.AddressWithPrefix.Address,
+ Port: 1234,
+ },
+ Offender: tcpip.FullAddress{
+ NIC: host1NICID,
+ Addr: ipv4Addr1.AddressWithPrefix.Address,
+ },
+ NetProto: ipv4.ProtocolNumber,
+ },
+ },
+ {
+ name: "IPv6 without resolvable remote",
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: ipv6Addr3.AddressWithPrefix.Address,
+ expectedWriteErr: tcpip.ErrNoRoute,
+ sockError: tcpip.SockError{
+ Err: tcpip.ErrNoRoute,
+ ErrType: byte(header.ICMPv6DstUnreachable),
+ ErrCode: byte(header.ICMPv6AddressUnreachable),
+ ErrOrigin: tcpip.SockExtErrorOriginICMP6,
+ Dst: tcpip.FullAddress{
+ NIC: host1NICID,
+ Addr: ipv6Addr3.AddressWithPrefix.Address,
+ Port: 1234,
+ },
+ Offender: tcpip.FullAddress{
+ NIC: host1NICID,
+ Addr: ipv6Addr1.AddressWithPrefix.Address,
+ },
+ NetProto: ipv6.ProtocolNumber,
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ }
+
+ host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID)
+
+ var listenerWQ waiter.Queue
+ listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, test.netProto, &listenerWQ)
+ if err != nil {
+ t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, test.netProto, err)
+ }
+ defer listenerEP.Close()
+
+ listenerAddr := tcpip.FullAddress{Port: 1234}
+ if err := listenerEP.Bind(listenerAddr); err != nil {
+ t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err)
+ }
+
+ if err := listenerEP.Listen(1); err != nil {
+ t.Fatalf("listenerEP.Listen(1): %s", err)
+ }
+
+ var clientWQ waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&we, waiter.EventOut|waiter.EventErr)
+ clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, test.netProto, &clientWQ)
+ if err != nil {
+ t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, test.netProto, err)
+ }
+ defer clientEP.Close()
+
+ sockOpts := clientEP.SocketOptions()
+ sockOpts.SetRecvError(true)
+
+ remoteAddr := listenerAddr
+ remoteAddr.Addr = test.remoteAddr
+ if err := clientEP.Connect(remoteAddr); err != tcpip.ErrConnectStarted {
+ t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, tcpip.ErrConnectStarted)
+ }
+
+ // Wait for an error due to link resolution failing, or the endpoint to be
+ // writable.
+ <-ch
+ var wOpts tcpip.WriteOptions
+ if n, err := clientEP.Write(tcpip.SlicePayload(nil), wOpts); err != test.expectedWriteErr {
+ t.Errorf("got clientEP.Write(nil, %#v) = (%d, %s), want = (_, %s)", wOpts, n, err, test.expectedWriteErr)
+ }
+
+ if test.expectedWriteErr == nil {
+ return
+ }
+
+ sockErr := sockOpts.DequeueErr()
+ if sockErr == nil {
+ t.Fatalf("got sockOpts.DequeueErr() = nil, want = non-nil")
+ }
+
+ sockErrCmpOpts := []cmp.Option{
+ cmpopts.IgnoreUnexported(tcpip.SockError{}),
+ cmp.Comparer(func(a, b *tcpip.Error) bool {
+ // tcpip.Error holds an unexported field but the errors netstack uses
+ // are pre defined so we can simply compare pointers.
+ return a == b
+ }),
+ // Ignore the payload since we do not know the TCP seq/ack numbers.
+ checker.IgnoreCmpPath(
+ "Payload",
+ ),
+ }
+
+ if addr, err := clientEP.GetLocalAddress(); err != nil {
+ t.Fatalf("clientEP.GetLocalAddress(): %s", err)
+ } else {
+ test.sockError.Offender.Port = addr.Port
+ }
+ if diff := cmp.Diff(&test.sockError, sockErr, sockErrCmpOpts...); diff != "" {
+ t.Errorf("socket error mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index cb6169cfc..a59f25cc3 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -232,12 +232,12 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) {
Port: localPort,
},
}
- n, _, err := sep.Write(tcpip.SlicePayload(data), wopts)
+ n, err := sep.Write(tcpip.SlicePayload(data), wopts)
if err != nil {
t.Fatalf("sep.Write(_, _): %s", err)
}
if want := int64(len(data)); n != want {
- t.Fatalf("got sep.Write(_, _) = (%d, _, nil), want = (%d, _, nil)", n, want)
+ t.Fatalf("got sep.Write(_, _) = (%d, nil), want = (%d, nil)", n, want)
}
var buf bytes.Buffer
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index b42375695..eabc87938 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -587,10 +587,10 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
},
}
data := tcpip.SlicePayload([]byte{byte(i), 2, 3, 4})
- if n, _, err := wep.ep.Write(data, writeOpts); err != nil {
+ if n, err := wep.ep.Write(data, writeOpts); err != nil {
t.Fatalf("eps[%d].Write(_, _): %s", i, err)
} else if want := int64(len(data)); n != want {
- t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want)
+ t.Fatalf("got eps[%d].Write(_, _) = (%d, nil), want = (%d, nil)", i, n, want)
}
for j, rep := range eps {
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
index 52cf89b54..76f7f54c6 100644
--- a/pkg/tcpip/tests/integration/route_test.go
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -197,10 +197,10 @@ func TestLocalPing(t *testing.T) {
payload := tcpip.SlicePayload(test.icmpBuf(t))
var wOpts tcpip.WriteOptions
- if n, _, err := ep.Write(payload, wOpts); err != nil {
+ if n, err := ep.Write(payload, wOpts); err != nil {
t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err)
} else if n != int64(len(payload)) {
- t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload))
+ t.Fatalf("got ep.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", payload, wOpts, n, len(payload))
}
// Wait for the endpoint to become readable.
@@ -335,14 +335,14 @@ func TestLocalUDP(t *testing.T) {
wOpts := tcpip.WriteOptions{
To: &serverAddr,
}
- if n, _, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr {
- t.Fatalf("got client.Write(%#v, %#v) = (%d, _, %s_), want = (_, _, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr)
+ if n, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr {
+ t.Fatalf("got client.Write(%#v, %#v) = (%d, %s), want = (_, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr)
} else if subTest.expectedWriteErr != nil {
// Nothing else to test if we expected not to be able to send the
// UDP packet.
return
} else if n != int64(len(clientPayload)) {
- t.Fatalf("got client.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", clientPayload, wOpts, n, len(clientPayload))
+ t.Fatalf("got client.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", clientPayload, wOpts, n, len(clientPayload))
}
}
@@ -382,10 +382,10 @@ func TestLocalUDP(t *testing.T) {
wOpts := tcpip.WriteOptions{
To: &clientAddr,
}
- if n, _, err := server.Write(serverPayload, wOpts); err != nil {
+ if n, err := server.Write(serverPayload, wOpts); err != nil {
t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err)
} else if n != int64(len(serverPayload)) {
- t.Fatalf("got server.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", serverPayload, wOpts, n, len(serverPayload))
+ t.Fatalf("got server.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", serverPayload, wOpts, n, len(serverPayload))
}
}
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index c32fe5c4f..87277fbd3 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -236,8 +236,8 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
- n, ch, err := e.write(p, opts)
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+ n, err := e.write(p, opts)
switch err {
case nil:
e.stats.PacketsSent.Increment()
@@ -247,8 +247,6 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
e.stats.WriteErrors.WriteClosed.Increment()
case tcpip.ErrInvalidEndpointState:
e.stats.WriteErrors.InvalidEndpointState.Increment()
- case tcpip.ErrNoLinkAddress:
- e.stats.SendErrors.NoLinkAddr.Increment()
case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
// Errors indicating any problem with IP routing of the packet.
e.stats.SendErrors.NoRoute.Increment()
@@ -256,13 +254,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// For all other errors when writing to the network layer.
e.stats.SendErrors.SendToNetworkFailed.Increment()
}
- return n, ch, err
+ return n, err
}
-func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
- return 0, nil, tcpip.ErrInvalidOptionValue
+ return 0, tcpip.ErrInvalidOptionValue
}
to := opts.To
@@ -272,14 +270,14 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// If we've shutdown with SHUT_WR we are in an invalid state for sending.
if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- return 0, nil, tcpip.ErrClosedForSend
+ return 0, tcpip.ErrClosedForSend
}
// Prepare for write.
for {
retry, err := e.prepareForWrite(to)
if err != nil {
- return 0, nil, err
+ return 0, err
}
if !retry {
@@ -294,7 +292,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
nicID := to.NIC
if e.BindNICID != 0 {
if nicID != 0 && nicID != e.BindNICID {
- return 0, nil, tcpip.ErrNoRoute
+ return 0, tcpip.ErrNoRoute
}
nicID = e.BindNICID
@@ -302,31 +300,22 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
dst, netProto, err := e.checkV4MappedLocked(*to)
if err != nil {
- return 0, nil, err
+ return 0, err
}
// Find the endpoint.
r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */)
if err != nil {
- return 0, nil, err
+ return 0, err
}
defer r.Release()
route = r
}
- if route.IsResolutionRequired() {
- if ch, err := route.Resolve(nil); err != nil {
- if err == tcpip.ErrWouldBlock {
- return 0, ch, tcpip.ErrNoLinkAddress
- }
- return 0, nil, err
- }
- }
-
v, err := p.FullPayload()
if err != nil {
- return 0, nil, err
+ return 0, err
}
switch e.NetProto {
@@ -338,10 +327,10 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
if err != nil {
- return 0, nil, err
+ return 0, err
}
- return int64(len(v)), nil, nil
+ return int64(len(v)), nil
}
// SetSockOpt sets a socket option.
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 3ab060751..c3b3b8d34 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -207,9 +207,9 @@ func (ep *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpi
return res, nil
}
-func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
// TODO(gvisor.dev/issue/173): Implement.
- return 0, nil, tcpip.ErrInvalidOptionValue
+ return 0, tcpip.ErrInvalidOptionValue
}
// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index dd260535f..425bcf3ee 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -234,20 +234,20 @@ func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip
}
// Write implements tcpip.Endpoint.Write.
-func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
// We can create, but not write to, unassociated IPv6 endpoints.
if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber {
- return 0, nil, tcpip.ErrInvalidOptionValue
+ return 0, tcpip.ErrInvalidOptionValue
}
if opts.To != nil {
// Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint.
if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize {
- return 0, nil, tcpip.ErrInvalidOptionValue
+ return 0, tcpip.ErrInvalidOptionValue
}
}
- n, ch, err := e.write(p, opts)
+ n, err := e.write(p, opts)
switch err {
case nil:
e.stats.PacketsSent.Increment()
@@ -257,8 +257,6 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
e.stats.WriteErrors.WriteClosed.Increment()
case tcpip.ErrInvalidEndpointState:
e.stats.WriteErrors.InvalidEndpointState.Increment()
- case tcpip.ErrNoLinkAddress:
- e.stats.SendErrors.NoLinkAddr.Increment()
case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
// Errors indicating any problem with IP routing of the packet.
e.stats.SendErrors.NoRoute.Increment()
@@ -266,25 +264,25 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// For all other errors when writing to the network layer.
e.stats.SendErrors.SendToNetworkFailed.Increment()
}
- return n, ch, err
+ return n, err
}
-func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
// MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
if opts.More {
- return 0, nil, tcpip.ErrInvalidOptionValue
+ return 0, tcpip.ErrInvalidOptionValue
}
e.mu.RLock()
defer e.mu.RUnlock()
if e.closed {
- return 0, nil, tcpip.ErrInvalidEndpointState
+ return 0, tcpip.ErrInvalidEndpointState
}
payloadBytes, err := p.FullPayload()
if err != nil {
- return 0, nil, err
+ return 0, err
}
// If this is an unassociated socket and callee provided a nonzero
@@ -292,7 +290,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
if e.ops.GetHeaderIncluded() {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
- return 0, nil, tcpip.ErrInvalidOptionValue
+ return 0, tcpip.ErrInvalidOptionValue
}
dstAddr := ip.DestinationAddress()
// Update dstAddr with the address in the IP header, unless
@@ -313,7 +311,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// If the user doesn't specify a destination, they should have
// connected to another address.
if !e.connected {
- return 0, nil, tcpip.ErrDestinationRequired
+ return 0, tcpip.ErrDestinationRequired
}
return e.finishWrite(payloadBytes, e.route)
@@ -323,42 +321,30 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// goes through a different NIC than the endpoint was bound to.
nic := opts.To.NIC
if e.bound && nic != 0 && nic != e.BindNICID {
- return 0, nil, tcpip.ErrNoRoute
+ return 0, tcpip.ErrNoRoute
}
// Find the route to the destination. If BindAddress is 0,
// FindRoute will choose an appropriate source address.
route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
if err != nil {
- return 0, nil, err
+ return 0, err
}
- n, ch, err := e.finishWrite(payloadBytes, route)
+ n, err := e.finishWrite(payloadBytes, route)
route.Release()
- return n, ch, err
+ return n, err
}
// finishWrite writes the payload to a route. It resolves the route if
// necessary. It's really just a helper to make defer unnecessary in Write.
-func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) {
- // We may need to resolve the route (match a link layer address to the
- // network address). If that requires blocking (e.g. to use ARP),
- // return a channel on which the caller can wait.
- if route.IsResolutionRequired() {
- if ch, err := route.Resolve(nil); err != nil {
- if err == tcpip.ErrWouldBlock {
- return 0, ch, tcpip.ErrNoLinkAddress
- }
- return 0, nil, err
- }
- }
-
+func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, *tcpip.Error) {
if e.ops.GetHeaderIncluded() {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.View(payloadBytes).ToVectorisedView(),
})
if err := route.WriteHeaderIncludedPacket(pkt); err != nil {
- return 0, nil, err
+ return 0, err
}
} else {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -371,11 +357,11 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
TTL: route.DefaultTTL(),
TOS: stack.DefaultTOS,
}, pkt); err != nil {
- return 0, nil, err
+ return 0, err
}
}
- return int64(len(payloadBytes)), nil, nil
+ return int64(len(payloadBytes)), nil
}
// Disconnect implements tcpip.Endpoint.Disconnect.
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 2d96a65bd..6921de0f1 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -210,7 +210,6 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
if err != nil {
return nil, err
}
- route.ResolveWith(s.remoteLinkAddr)
n := newEndpoint(l.stack, netProto, queue)
n.ops.SetV6Only(l.v6Only)
@@ -306,10 +305,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
// Initialize and start the handshake.
h := ep.newPassiveHandshake(isn, irs, opts, deferAccept)
- if err := h.start(); err != nil {
- l.cleanupFailedHandshake(h)
- return nil, err
- }
+ h.start()
return h, nil
}
@@ -573,7 +569,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er
return err
}
defer route.Release()
- route.ResolveWith(s.remoteLinkAddr)
// Send SYN without window scaling because we currently
// don't encode this information in the cookie.
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index a00ef97c6..6cdbb8bee 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -53,7 +53,6 @@ const (
wakerForNotification = iota
wakerForNewSegment
wakerForResend
- wakerForResolution
)
const (
@@ -460,66 +459,9 @@ func (h *handshake) processSegments() *tcpip.Error {
return nil
}
-func (h *handshake) resolveRoute() *tcpip.Error {
- // Set up the wakers.
- var s sleep.Sleeper
- resolutionWaker := &sleep.Waker{}
- s.AddWaker(resolutionWaker, wakerForResolution)
- s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
- defer s.Done()
-
- // Initial action is to resolve route.
- index := wakerForResolution
- attemptedResolution := false
- for {
- switch index {
- case wakerForResolution:
- if _, err := h.ep.route.Resolve(resolutionWaker.Assert); err != tcpip.ErrWouldBlock {
- if err != nil {
- h.ep.stats.SendErrors.NoRoute.Increment()
- }
- // Either success (err == nil) or failure.
- return err
- }
- if attemptedResolution {
- h.ep.stats.SendErrors.NoLinkAddr.Increment()
- return tcpip.ErrNoLinkAddress
- }
- attemptedResolution = true
- // Resolution not completed. Keep trying...
-
- case wakerForNotification:
- n := h.ep.fetchNotifications()
- if n&notifyClose != 0 {
- return tcpip.ErrAborted
- }
- if n&notifyDrain != 0 {
- close(h.ep.drainDone)
- h.ep.mu.Unlock()
- <-h.ep.undrain
- h.ep.mu.Lock()
- }
- if n&notifyError != 0 {
- return h.ep.lastErrorLocked()
- }
- }
-
- // Wait for notification.
- h.ep.mu.Unlock()
- index, _ = s.Fetch(true /* block */)
- h.ep.mu.Lock()
- }
-}
-
-// start resolves the route if necessary and sends the first
-// SYN/SYN-ACK.
-func (h *handshake) start() *tcpip.Error {
- if h.ep.route.IsResolutionRequired() {
- if err := h.resolveRoute(); err != nil {
- return err
- }
- }
-
+// start sends the first SYN/SYN-ACK. It does not block, even if link address
+// resolution is required.
+func (h *handshake) start() {
h.startTime = time.Now()
h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
var sackEnabled tcpip.TCPSACKEnabled
@@ -560,7 +502,6 @@ func (h *handshake) start() *tcpip.Error {
ack: h.ackNum,
rcvWnd: h.rcvWnd,
}, synOpts)
- return nil
}
// complete completes the TCP 3-way handshake initiated by h.start().
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 25b180fa5..a4508e871 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1507,7 +1507,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
}
// Write writes data to the endpoint's peer.
-func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
// Linux completely ignores any address passed to sendto(2) for TCP sockets
// (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
// and opts.EndOfRecord are also ignored.
@@ -1520,7 +1520,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
e.sndBufMu.Unlock()
e.UnlockUser()
e.stats.WriteErrors.WriteClosed.Increment()
- return 0, nil, err
+ return 0, err
}
// We can release locks while copying data.
@@ -1541,7 +1541,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
e.sndBufMu.Unlock()
e.UnlockUser()
}
- return 0, nil, perr
+ return 0, perr
}
if !opts.Atomic {
@@ -1555,7 +1555,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
e.sndBufMu.Unlock()
e.UnlockUser()
e.stats.WriteErrors.WriteClosed.Increment()
- return 0, nil, err
+ return 0, err
}
// Discard any excess data copied in due to avail being reduced due
@@ -1575,7 +1575,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// Do the work inline.
e.handleWrite()
e.UnlockUser()
- return int64(len(v)), nil, nil
+ return int64(len(v)), nil
}
// selectWindowLocked returns the new window without checking for shrinking or scaling
@@ -2325,68 +2325,17 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
if run {
- if err := e.startMainLoop(handshake); err != nil {
- return err
- }
- }
-
- return tcpip.ErrConnectStarted
-}
-
-// startMainLoop sends the initial SYN and starts the main loop for the
-// endpoint.
-func (e *endpoint) startMainLoop(handshake bool) *tcpip.Error {
- preloop := func() *tcpip.Error {
if handshake {
h := e.newHandshake()
e.setEndpointState(StateSynSent)
- if err := h.start(); err != nil {
- e.lastErrorMu.Lock()
- e.lastError = err
- e.lastErrorMu.Unlock()
-
- e.setEndpointState(StateError)
- e.hardError = err
-
- // Call cleanupLocked to free up any reservations.
- e.cleanupLocked()
- return err
- }
+ h.start()
}
e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
- return nil
- }
-
- if e.route.IsResolutionRequired() {
- // If the endpoint is closed between releasing e.mu and the goroutine below
- // acquiring it, make sure that cleanup is deferred to the new goroutine.
e.workerRunning = true
-
- // Sending the initial SYN may block due to route resolution; do it in a
- // separate goroutine to avoid blocking the syscall goroutine.
- go func() { // S/R-SAFE: will be drained before save.
- e.mu.Lock()
- if err := preloop(); err != nil {
- e.workerRunning = false
- e.mu.Unlock()
- return
- }
- e.mu.Unlock()
- _ = e.protocolMainLoop(handshake, nil)
- }()
- return nil
+ go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save.
}
- // No route resolution is required, so we can send the initial SYN here without
- // blocking. This will hopefully reduce overall latency by overlapping time
- // spent waiting for a SYN-ACK and time spent spinning up a new goroutine
- // for the main loop.
- if err := preloop(); err != nil {
- return err
- }
- e.workerRunning = true
- go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save.
- return nil
+ return tcpip.ErrConnectStarted
}
// ConnectEndpoint is not supported.
@@ -2779,6 +2728,9 @@ func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt
case stack.ControlNoRoute:
e.onICMPError(tcpip.ErrNoRoute, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt)
+ case stack.ControlAddressUnreachable:
+ e.onICMPError(tcpip.ErrNoRoute, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6AddressUnreachable), extra, pkt)
+
case stack.ControlNetworkUnreachable:
e.onICMPError(tcpip.ErrNetworkUnreachable, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt)
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index c9e194f82..1720370c9 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -222,7 +222,6 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error
return err
}
defer route.Release()
- route.ResolveWith(s.remoteLinkAddr)
// Get the seqnum from the packet if the ack flag is set.
seq := seqnum.Value(0)
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index c5a6d2fba..7cca4def5 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -49,11 +49,10 @@ type segment struct {
// TODO(gvisor.dev/issue/4417): Hold a stack.PacketBuffer instead of
// individual members for link/network packet info.
- srcAddr tcpip.Address
- dstAddr tcpip.Address
- netProto tcpip.NetworkProtocolNumber
- nicID tcpip.NICID
- remoteLinkAddr tcpip.LinkAddress
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ netProto tcpip.NetworkProtocolNumber
+ nicID tcpip.NICID
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
@@ -89,13 +88,12 @@ type segment struct {
func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
netHdr := pkt.Network()
s := &segment{
- refCnt: 1,
- id: id,
- srcAddr: netHdr.SourceAddress(),
- dstAddr: netHdr.DestinationAddress(),
- netProto: pkt.NetworkProtocolNumber,
- nicID: pkt.NICID,
- remoteLinkAddr: pkt.SourceLinkAddress(),
+ refCnt: 1,
+ id: id,
+ srcAddr: netHdr.SourceAddress(),
+ dstAddr: netHdr.DestinationAddress(),
+ netProto: pkt.NetworkProtocolNumber,
+ nicID: pkt.NICID,
}
s.data = pkt.Data.Clone(s.views[:])
s.hdr = header.TCP(pkt.TransportHeader().View())
@@ -128,7 +126,6 @@ func (s *segment) clone() *segment {
window: s.window,
netProto: s.netProto,
nicID: s.nicID,
- remoteLinkAddr: s.remoteLinkAddr,
rcvdTime: s.rcvdTime,
xmitTime: s.xmitTime,
xmitCount: s.xmitCount,
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index b9993ce1a..f7aaee23f 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -49,7 +49,7 @@ func TestFastRecovery(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -214,7 +214,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -256,7 +256,7 @@ func TestCongestionAvoidance(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -361,7 +361,7 @@ func TestCubicCongestionAvoidance(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -470,11 +470,11 @@ func TestRetransmit(t *testing.T) {
// Write all the data in two shots. Packets will only be written at the
// MTU size though.
half := data[:len(data)/2]
- if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
half = data[len(data)/2:]
- if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index 9818ffa0f..342eb5eb8 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -68,7 +68,7 @@ func TestRACKUpdate(t *testing.T) {
// Write the data.
xmitTime = time.Now()
- if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -120,7 +120,7 @@ func TestRACKDetectReorder(t *testing.T) {
}
// Write the data.
- if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -151,7 +151,7 @@ func sendAndReceive(t *testing.T, c *context.Context, numPackets int) buffer.Vie
}
// Write the data.
- if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index faf0c0ad7..6635bb815 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -402,7 +402,7 @@ func TestSACKRecovery(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index aeceee7e0..729bf7ef5 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -1348,7 +1348,7 @@ func TestTOSV4(t *testing.T) {
view := buffer.NewView(len(data))
copy(view, data)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -1397,7 +1397,7 @@ func TestTrafficClassV6(t *testing.T) {
view := buffer.NewView(len(data))
copy(view, data)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -1977,7 +1977,11 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
// Keep the payload size < segment overhead and such that it is a multiple
// of the window scaled value. This enables the test to perform equality
// checks on the incoming receive window.
- payload := generateRandomPayload(t, (tcp.SegSize-1)&(1<<c.RcvdWindowScale))
+ payloadSize := 1 << c.RcvdWindowScale
+ if payloadSize >= tcp.SegSize {
+ t.Fatalf("payload size of %d is not less than the segment overhead of %d", payloadSize, tcp.SegSize)
+ }
+ payload := generateRandomPayload(t, payloadSize)
payloadLen := seqnum.Size(len(payload))
iss := seqnum.Value(789)
seqNum := iss.Add(1)
@@ -2173,7 +2177,7 @@ func TestSimpleSend(t *testing.T) {
view := buffer.NewView(len(data))
copy(view, data)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2214,8 +2218,7 @@ func TestZeroWindowSend(t *testing.T) {
view := buffer.NewView(len(data))
copy(view, data)
- _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
- if err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2283,7 +2286,7 @@ func TestScaledWindowConnect(t *testing.T) {
view := buffer.NewView(len(data))
copy(view, data)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2315,7 +2318,7 @@ func TestNonScaledWindowConnect(t *testing.T) {
view := buffer.NewView(len(data))
copy(view, data)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2389,7 +2392,7 @@ func TestScaledWindowAccept(t *testing.T) {
view := buffer.NewView(len(data))
copy(view, data)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2463,7 +2466,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
view := buffer.NewView(len(data))
copy(view, data)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2626,7 +2629,7 @@ func TestSegmentMerging(t *testing.T) {
// anymore packets from going out.
for i := 0; i < tcp.InitialCwnd; i++ {
view := buffer.NewViewFromBytes([]byte{0})
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2637,7 +2640,7 @@ func TestSegmentMerging(t *testing.T) {
for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
allData = append(allData, data...)
view := buffer.NewViewFromBytes(data)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2707,7 +2710,7 @@ func TestDelay(t *testing.T) {
for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
allData = append(allData, data...)
view := buffer.NewViewFromBytes(data)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2754,7 +2757,7 @@ func TestUndelay(t *testing.T) {
allData := [][]byte{{0}, {1, 2, 3}}
for i, data := range allData {
view := buffer.NewViewFromBytes(data)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2838,7 +2841,7 @@ func TestMSSNotDelayed(t *testing.T) {
allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)}
for i, data := range allData {
view := buffer.NewViewFromBytes(data)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2889,7 +2892,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
view := buffer.NewView(len(data))
copy(view, data)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3321,7 +3324,7 @@ func TestSendOnResetConnection(t *testing.T) {
// Try to write.
view := buffer.NewView(10)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset {
t.Fatalf("got c.EP.Write(...) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
}
@@ -3344,7 +3347,7 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
defer c.WQ.EventUnregister(&waitEntry)
- _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
+ _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
if err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3401,7 +3404,7 @@ func TestMaxRTO(t *testing.T) {
c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
- _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
+ _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
if err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3450,7 +3453,7 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) {
t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err)
}
- if _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
pkt := c.GetPacket()
@@ -3588,7 +3591,7 @@ func TestFinWithNoPendingData(t *testing.T) {
// Write something out, and have it acknowledged.
view := buffer.NewView(10)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3661,7 +3664,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
// any of them.
view := buffer.NewView(10)
for i := tcp.InitialCwnd; i > 0; i-- {
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
}
@@ -3747,7 +3750,7 @@ func TestFinWithPendingData(t *testing.T) {
// Write something out, and acknowledge it to get cwnd to 2.
view := buffer.NewView(10)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3773,7 +3776,7 @@ func TestFinWithPendingData(t *testing.T) {
})
// Write new data, but don't acknowledge it.
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3834,7 +3837,7 @@ func TestFinWithPartialAck(t *testing.T) {
// Write something out, and acknowledge it to get cwnd to 2. Also send
// FIN from the test side.
view := buffer.NewView(10)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3871,7 +3874,7 @@ func TestFinWithPartialAck(t *testing.T) {
)
// Write new data, but don't acknowledge it.
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3978,7 +3981,7 @@ func scaledSendWindow(t *testing.T, scale uint8) {
// Send some data. Check that it's capped by the window size.
view := buffer.NewView(65535)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -4035,9 +4038,6 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) {
if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 {
t.Errorf("got EP stats Stats.SendErrors.NoRoute = %d, want = %d", got, 0)
}
- if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 {
- t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %d, want = %d", got, 0)
- }
}
func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
@@ -4607,7 +4607,7 @@ func TestSelfConnect(t *testing.T) {
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
copy(view, data)
- if _, _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -4785,7 +4785,7 @@ func TestPathMTUDiscovery(t *testing.T) {
data[i] = byte(i)
}
- if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -5074,7 +5074,7 @@ func TestKeepalive(t *testing.T) {
// Send some data and wait before ACKing it. Keepalives should be disabled
// during this period.
view := buffer.NewView(3)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -5903,9 +5903,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
// Now verify that the TCP socket is usable and in a connected state.
data := "Don't panic"
- _, _, err = newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
-
- if err != nil {
+ if _, err := newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -7103,7 +7101,7 @@ func TestTCPCloseWithData(t *testing.T) {
view := buffer.NewView(len(data))
copy(view, data)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -7202,7 +7200,7 @@ func TestTCPUserTimeout(t *testing.T) {
// Send some data and wait before ACKing it.
view := buffer.NewView(3)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 9e02d467d..88fb054bb 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -154,7 +154,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
view := buffer.NewView(len(data))
copy(view, data)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Unexpected error from Write: %s", err)
}
@@ -217,7 +217,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
view := buffer.NewView(len(data))
copy(view, data)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Unexpected error from Write: %s", err)
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 5d87f3a7e..520a0ac9d 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -417,8 +417,8 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
- n, ch, err := e.write(p, opts)
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+ n, err := e.write(p, opts)
switch err {
case nil:
e.stats.PacketsSent.Increment()
@@ -428,8 +428,6 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
e.stats.WriteErrors.WriteClosed.Increment()
case tcpip.ErrInvalidEndpointState:
e.stats.WriteErrors.InvalidEndpointState.Increment()
- case tcpip.ErrNoLinkAddress:
- e.stats.SendErrors.NoLinkAddr.Increment()
case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
// Errors indicating any problem with IP routing of the packet.
e.stats.SendErrors.NoRoute.Increment()
@@ -437,17 +435,17 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// For all other errors when writing to the network layer.
e.stats.SendErrors.SendToNetworkFailed.Increment()
}
- return n, ch, err
+ return n, err
}
-func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
if err := e.LastError(); err != nil {
- return 0, nil, err
+ return 0, err
}
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
- return 0, nil, tcpip.ErrInvalidOptionValue
+ return 0, tcpip.ErrInvalidOptionValue
}
to := opts.To
@@ -463,14 +461,14 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// If we've shutdown with SHUT_WR we are in an invalid state for sending.
if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- return 0, nil, tcpip.ErrClosedForSend
+ return 0, tcpip.ErrClosedForSend
}
// Prepare for write.
for {
retry, err := e.prepareForWrite(to)
if err != nil {
- return 0, nil, err
+ return 0, err
}
if !retry {
@@ -486,7 +484,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
nicID := to.NIC
if e.BindNICID != 0 {
if nicID != 0 && nicID != e.BindNICID {
- return 0, nil, tcpip.ErrNoRoute
+ return 0, tcpip.ErrNoRoute
}
nicID = e.BindNICID
@@ -494,17 +492,17 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
if to.Port == 0 {
// Port 0 is an invalid port to send to.
- return 0, nil, tcpip.ErrInvalidEndpointState
+ return 0, tcpip.ErrInvalidEndpointState
}
dst, netProto, err := e.checkV4MappedLocked(*to)
if err != nil {
- return 0, nil, err
+ return 0, err
}
r, _, err := e.connectRoute(nicID, dst, netProto)
if err != nil {
- return 0, nil, err
+ return 0, err
}
defer r.Release()
@@ -513,21 +511,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
- return 0, nil, tcpip.ErrBroadcastDisabled
- }
-
- if route.IsResolutionRequired() {
- if ch, err := route.Resolve(nil); err != nil {
- if err == tcpip.ErrWouldBlock {
- return 0, ch, tcpip.ErrNoLinkAddress
- }
- return 0, nil, err
- }
+ return 0, tcpip.ErrBroadcastDisabled
}
v, err := p.FullPayload()
if err != nil {
- return 0, nil, err
+ return 0, err
}
if len(v) > header.UDPMaximumPacketSize {
// Payload can't possibly fit in a packet.
@@ -545,7 +534,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
v,
)
}
- return 0, nil, tcpip.ErrMessageTooLong
+ return 0, tcpip.ErrMessageTooLong
}
ttl := e.ttl
@@ -575,9 +564,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read
// locking is prohibited.
if err := sendUDP(route, buffer.View(v).ToVectorisedView(), localPort, dstPort, ttl, useDefaultTTL, sendTOS, owner, noChecksum); err != nil {
- return 0, nil, err
+ return 0, err
}
- return int64(len(v)), nil, nil
+ return int64(len(v)), nil
}
// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet.
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index d7fc21f11..49e673d58 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -75,7 +75,6 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
if err != nil {
return nil, err
}
- route.ResolveWith(r.pkt.SourceLinkAddress())
ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil {
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index c8da173f1..52403ed78 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -967,7 +967,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
payload := buffer.View(newPayload())
- _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
})
c.checkEndpointWriteStats(1, epstats, gotErr)
@@ -1008,7 +1008,7 @@ func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View
}
}
payload := buffer.View(newPayload())
- n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
+ n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
if err != nil {
c.t.Fatalf("Write failed: %s", err)
}
@@ -1184,7 +1184,7 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) {
To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort},
}
payload := buffer.View(newPayload())
- n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
+ n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
if err != nil {
c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err)
}
@@ -2317,8 +2317,6 @@ func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportE
want.WriteErrors.WriteClosed.IncrementBy(incr)
case tcpip.ErrInvalidEndpointState:
want.WriteErrors.InvalidEndpointState.IncrementBy(incr)
- case tcpip.ErrNoLinkAddress:
- want.SendErrors.NoLinkAddr.IncrementBy(incr)
case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
want.SendErrors.NoRoute.IncrementBy(incr)
default:
@@ -2510,20 +2508,20 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
expectedErrWithoutBcastOpt = nil
}
- if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
- t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt)
+ if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
+ t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt)
}
ep.SocketOptions().SetBroadcast(true)
- if n, _, err := ep.Write(data, opts); err != nil {
- t.Fatalf("got ep.Write(_, _) = (%d, _, %s), want = (_, _, nil)", n, err)
+ if n, err := ep.Write(data, opts); err != nil {
+ t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
}
ep.SocketOptions().SetBroadcast(false)
- if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
- t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt)
+ if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
+ t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt)
}
})
}
diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go
index fdd416b5e..a35c7ffa6 100644
--- a/pkg/test/testutil/testutil.go
+++ b/pkg/test/testutil/testutil.go
@@ -83,11 +83,10 @@ func ConfigureExePath() error {
// TmpDir returns the absolute path to a writable directory that can be used as
// scratch by the test.
func TmpDir() string {
- dir := os.Getenv("TEST_TMPDIR")
- if dir == "" {
- dir = "/tmp"
+ if dir, ok := os.LookupEnv("TEST_TMPDIR"); ok {
+ return dir
}
- return dir
+ return "/tmp"
}
// Logger is a simple logging wrapper.
@@ -111,6 +110,30 @@ func (d DefaultLogger) Logf(fmt string, args ...interface{}) {
log.Printf(fmt, args...)
}
+// multiLogger logs to multiple Loggers.
+type multiLogger []Logger
+
+// Name implements Logger.Name.
+func (m multiLogger) Name() string {
+ names := make([]string, len(m))
+ for i, l := range m {
+ names[i] = l.Name()
+ }
+ return strings.Join(names, "+")
+}
+
+// Logf implements Logger.Logf.
+func (m multiLogger) Logf(fmt string, args ...interface{}) {
+ for _, l := range m {
+ l.Logf(fmt, args...)
+ }
+}
+
+// NewMultiLogger returns a new Logger that logs on multiple Loggers.
+func NewMultiLogger(loggers ...Logger) Logger {
+ return multiLogger(loggers)
+}
+
// Cmd is a simple wrapper.
type Cmd struct {
logger Logger
@@ -519,7 +542,7 @@ func IsStatic(filename string) (bool, error) {
//
// See https://docs.bazel.build/versions/master/test-encyclopedia.html#role-of-the-test-runner.
func TouchShardStatusFile() error {
- if statusFile := os.Getenv("TEST_SHARD_STATUS_FILE"); statusFile != "" {
+ if statusFile, ok := os.LookupEnv("TEST_SHARD_STATUS_FILE"); ok {
cmd := exec.Command("touch", statusFile)
if b, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("touch %q failed:\n output: %s\n error: %s", statusFile, string(b), err.Error())
@@ -541,8 +564,9 @@ func TestIndicesForShard(numTests int) ([]int, error) {
shardTotal = 1
)
- indexStr, totalStr := os.Getenv("TEST_SHARD_INDEX"), os.Getenv("TEST_TOTAL_SHARDS")
- if indexStr != "" && totalStr != "" {
+ indexStr, indexOk := os.LookupEnv("TEST_SHARD_INDEX")
+ totalStr, totalOk := os.LookupEnv("TEST_TOTAL_SHARDS")
+ if indexOk && totalOk {
// Parse index and total to ints.
var err error
shardIndex, err = strconv.Atoi(indexStr)
diff --git a/runsc/cmd/gofer_test.go b/runsc/cmd/gofer_test.go
index cbea7f127..fea62a4f4 100644
--- a/runsc/cmd/gofer_test.go
+++ b/runsc/cmd/gofer_test.go
@@ -24,11 +24,10 @@ import (
)
func tmpDir() string {
- dir := os.Getenv("TEST_TMPDIR")
- if dir == "" {
- dir = "/tmp"
+ if dir, ok := os.LookupEnv("TEST_TMPDIR"); ok {
+ return dir
}
- return dir
+ return "/tmp"
}
type dir struct {
diff --git a/runsc/config/flags.go b/runsc/config/flags.go
index 02ab9255a..7e738dfdf 100644
--- a/runsc/config/flags.go
+++ b/runsc/config/flags.go
@@ -114,7 +114,7 @@ func NewFromFlags() (*Config, error) {
if len(conf.RootDir) == 0 {
// If not set, set default root dir to something (hopefully) user-writeable.
conf.RootDir = "/var/run/runsc"
- if runtimeDir := os.Getenv("XDG_RUNTIME_DIR"); runtimeDir != "" {
+ if runtimeDir, ok := os.LookupEnv("XDG_RUNTIME_DIR"); ok {
conf.RootDir = filepath.Join(runtimeDir, "runsc")
}
}
diff --git a/test/root/crictl_test.go b/test/root/crictl_test.go
index fbf134014..c26dc8577 100644
--- a/test/root/crictl_test.go
+++ b/test/root/crictl_test.go
@@ -353,8 +353,8 @@ func setup(t *testing.T) (*criutil.Crictl, func(), error) {
// because the shims will be installed there, and containerd may infer
// the binary name and search the PATH.
runtimeDir := path.Dir(runtime)
- modifiedPath := os.Getenv("PATH")
- if modifiedPath != "" {
+ modifiedPath, ok := os.LookupEnv("PATH")
+ if ok {
modifiedPath = ":" + modifiedPath // We prepend below.
}
modifiedPath = path.Dir(getContainerd()) + modifiedPath
diff --git a/test/syscalls/linux/mkdir.cc b/test/syscalls/linux/mkdir.cc
index 27758203d..11fbfa5c5 100644
--- a/test/syscalls/linux/mkdir.cc
+++ b/test/syscalls/linux/mkdir.cc
@@ -82,6 +82,39 @@ TEST_F(MkdirTest, FailsOnDirWithoutWritePerms) {
SyscallFailsWithErrno(EACCES));
}
+TEST_F(MkdirTest, DirAlreadyExists) {
+ // Drop capabilities that allow us to override file and directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds());
+ auto dir = JoinPath(dirname_.c_str(), "foo");
+ EXPECT_THAT(mkdir(dir.c_str(), 0777), SyscallSucceeds());
+
+ struct {
+ int mode;
+ int err;
+ } tests[] = {
+ {.mode = 0000, .err = EACCES}, // No perm
+ {.mode = 0100, .err = EEXIST}, // Exec only
+ {.mode = 0200, .err = EACCES}, // Write only
+ {.mode = 0300, .err = EEXIST}, // Write+exec
+ {.mode = 0400, .err = EACCES}, // Read only
+ {.mode = 0500, .err = EEXIST}, // Read+exec
+ {.mode = 0600, .err = EACCES}, // Read+write
+ {.mode = 0700, .err = EEXIST}, // All
+ };
+ for (const auto& t : tests) {
+ printf("mode: 0%o\n", t.mode);
+ EXPECT_THAT(chmod(dirname_.c_str(), t.mode), SyscallSucceeds());
+ EXPECT_THAT(mkdir(dir.c_str(), 0777), SyscallFailsWithErrno(t.err));
+ }
+
+ // Clean up.
+ EXPECT_THAT(chmod(dirname_.c_str(), 0777), SyscallSucceeds());
+ ASSERT_THAT(rmdir(dir.c_str()), SyscallSucceeds());
+}
+
TEST_F(MkdirTest, MkdirAtEmptyPath) {
ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds());
auto fd =
diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc
index c2369db54..e5730a606 100644
--- a/test/syscalls/linux/splice.cc
+++ b/test/syscalls/linux/splice.cc
@@ -483,6 +483,112 @@ TEST(SpliceTest, TwoPipes) {
EXPECT_EQ(memcmp(rbuf.data(), buf.data(), kPageSize), 0);
}
+TEST(SpliceTest, TwoPipesPartialRead) {
+ // Create two pipes.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor first_rfd(fds[0]);
+ const FileDescriptor first_wfd(fds[1]);
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor second_rfd(fds[0]);
+ const FileDescriptor second_wfd(fds[1]);
+
+ // Write half a page of data to the first pipe.
+ std::vector<char> buf(kPageSize / 2);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(first_wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize / 2));
+
+ // Attempt to splice one page from the first pipe to the second; it should
+ // immediately return after splicing the half-page previously written to the
+ // first pipe.
+ EXPECT_THAT(
+ splice(first_rfd.get(), nullptr, second_wfd.get(), nullptr, kPageSize, 0),
+ SyscallSucceedsWithValue(kPageSize / 2));
+}
+
+TEST(SpliceTest, TwoPipesPartialWrite) {
+ // Create two pipes.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor first_rfd(fds[0]);
+ const FileDescriptor first_wfd(fds[1]);
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor second_rfd(fds[0]);
+ const FileDescriptor second_wfd(fds[1]);
+
+ // Write two pages of data to the first pipe.
+ std::vector<char> buf(2 * kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(first_wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(2 * kPageSize));
+
+ // Limit the second pipe to two pages, then write one page of data to it.
+ ASSERT_THAT(fcntl(second_wfd.get(), F_SETPIPE_SZ, 2 * kPageSize),
+ SyscallSucceeds());
+ ASSERT_THAT(write(second_wfd.get(), buf.data(), buf.size() / 2),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Attempt to splice two pages from the first pipe to the second; it should
+ // immediately return after splicing the first page previously written to the
+ // first pipe.
+ EXPECT_THAT(splice(first_rfd.get(), nullptr, second_wfd.get(), nullptr,
+ 2 * kPageSize, 0),
+ SyscallSucceedsWithValue(kPageSize));
+}
+
+TEST(TeeTest, TwoPipesPartialRead) {
+ // Create two pipes.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor first_rfd(fds[0]);
+ const FileDescriptor first_wfd(fds[1]);
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor second_rfd(fds[0]);
+ const FileDescriptor second_wfd(fds[1]);
+
+ // Write half a page of data to the first pipe.
+ std::vector<char> buf(kPageSize / 2);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(first_wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize / 2));
+
+ // Attempt to tee one page from the first pipe to the second; it should
+ // immediately return after copying the half-page previously written to the
+ // first pipe.
+ EXPECT_THAT(tee(first_rfd.get(), second_wfd.get(), kPageSize, 0),
+ SyscallSucceedsWithValue(kPageSize / 2));
+}
+
+TEST(TeeTest, TwoPipesPartialWrite) {
+ // Create two pipes.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor first_rfd(fds[0]);
+ const FileDescriptor first_wfd(fds[1]);
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor second_rfd(fds[0]);
+ const FileDescriptor second_wfd(fds[1]);
+
+ // Write two pages of data to the first pipe.
+ std::vector<char> buf(2 * kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(first_wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(2 * kPageSize));
+
+ // Limit the second pipe to two pages, then write one page of data to it.
+ ASSERT_THAT(fcntl(second_wfd.get(), F_SETPIPE_SZ, 2 * kPageSize),
+ SyscallSucceeds());
+ ASSERT_THAT(write(second_wfd.get(), buf.data(), buf.size() / 2),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Attempt to tee two pages from the first pipe to the second; it should
+ // immediately return after copying the first page previously written to the
+ // first pipe.
+ EXPECT_THAT(tee(first_rfd.get(), second_wfd.get(), 2 * kPageSize, 0),
+ SyscallSucceedsWithValue(kPageSize));
+}
+
TEST(SpliceTest, TwoPipesCircular) {
// This test deadlocks the sentry on VFS1 because VFS1 splice ordering is
// based on fs.File.UniqueID, which does not prevent circular ordering between
diff --git a/website/cmd/server/main.go b/website/cmd/server/main.go
index 9f0092ed6..707a3a8f8 100644
--- a/website/cmd/server/main.go
+++ b/website/cmd/server/main.go
@@ -366,7 +366,7 @@ func registerProfile(mux *http.ServeMux) {
}
func envFlagString(name, def string) string {
- if val := os.Getenv(name); val != "" {
+ if val, ok := os.LookupEnv(name); ok {
return val
}
return def