diff options
Diffstat (limited to 'pkg')
53 files changed, 1465 insertions, 1085 deletions
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index df27554d3..91d5dc174 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -407,33 +407,44 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if err != nil { return err } - if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { + + // Order of checks is important. First check if parent directory can be + // executed, then check for existence, and lastly check if mount is writable. + if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { return err } name := rp.Component() if name == "." || name == ".." { return syserror.EEXIST } - if len(name) > maxFilenameLen { - return syserror.ENAMETOOLONG - } if parent.isDeleted() { return syserror.ENOENT } + + parent.dirMu.Lock() + defer parent.dirMu.Unlock() + + child, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), parent, name, &ds) + switch { + case err != nil && err != syserror.ENOENT: + return err + case child != nil: + return syserror.EEXIST + } + mnt := rp.Mount() if err := mnt.CheckBeginWrite(); err != nil { return err } defer mnt.EndWrite() - parent.dirMu.Lock() - defer parent.dirMu.Unlock() + + if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil { + return err + } + if !dir && rp.MustBeDir() { + return syserror.ENOENT + } if parent.isSynthetic() { - if child := parent.children[name]; child != nil { - return syserror.EEXIST - } - if !dir && rp.MustBeDir() { - return syserror.ENOENT - } if createInSyntheticDir == nil { return syserror.EPERM } @@ -449,47 +460,20 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir parent.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) return nil } - if fs.opts.interop == InteropModeShared { - if child := parent.children[name]; child != nil && child.isSynthetic() { - return syserror.EEXIST - } - if !dir && rp.MustBeDir() { - return syserror.ENOENT - } - // The existence of a non-synthetic dentry at name would be inconclusive - // because the file it represents may have been deleted from the remote - // filesystem, so we would need to make an RPC to revalidate the dentry. - // Just attempt the file creation RPC instead. If a file does exist, the - // RPC will fail with EEXIST like we would have. If the RPC succeeds, and a - // stale dentry exists, the dentry will fail revalidation next time it's - // used. - if err := createInRemoteDir(parent, name, &ds); err != nil { - return err - } - ev := linux.IN_CREATE - if dir { - ev |= linux.IN_ISDIR - } - parent.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) - return nil - } - if child := parent.children[name]; child != nil { - return syserror.EEXIST - } - if !dir && rp.MustBeDir() { - return syserror.ENOENT - } - // No cached dentry exists; however, there might still be an existing file - // at name. As above, we attempt the file creation RPC anyway. + // No cached dentry exists; however, in InteropModeShared there might still be + // an existing file at name. Just attempt the file creation RPC anyways. If a + // file does exist, the RPC will fail with EEXIST like we would have. if err := createInRemoteDir(parent, name, &ds); err != nil { return err } - if child, ok := parent.children[name]; ok && child == nil { - // Delete the now-stale negative dentry. - delete(parent.children, name) + if fs.opts.interop != InteropModeShared { + if child, ok := parent.children[name]; ok && child == nil { + // Delete the now-stale negative dentry. + delete(parent.children, name) + } + parent.touchCMtime() + parent.dirents = nil } - parent.touchCMtime() - parent.dirents = nil ev := linux.IN_CREATE if dir { ev |= linux.IN_ISDIR diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index e77523f22..a7a553619 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -208,7 +208,9 @@ func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving // * Filesystem.mu must be locked for at least reading. // * isDir(parentInode) == true. func checkCreateLocked(ctx context.Context, creds *auth.Credentials, name string, parent *Dentry) error { - if err := parent.inode.CheckPermissions(ctx, creds, vfs.MayWrite|vfs.MayExec); err != nil { + // Order of checks is important. First check if parent directory can be + // executed, then check for existence, and lastly check if mount is writable. + if err := parent.inode.CheckPermissions(ctx, creds, vfs.MayExec); err != nil { return err } if name == "." || name == ".." { @@ -223,6 +225,9 @@ func checkCreateLocked(ctx context.Context, creds *auth.Credentials, name string if parent.VFSDentry().IsDead() { return syserror.ENOENT } + if err := parent.inode.CheckPermissions(ctx, creds, vfs.MayWrite); err != nil { + return err + } return nil } diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index d55bdc97f..e46f593c7 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -480,9 +480,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if err != nil { return err } - if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { - return err - } name := rp.Component() if name == "." || name == ".." { return syserror.EEXIST @@ -490,11 +487,11 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if parent.vfsd.IsDead() { return syserror.ENOENT } - mnt := rp.Mount() - if err := mnt.CheckBeginWrite(); err != nil { + + if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { return err } - defer mnt.EndWrite() + parent.dirMu.Lock() defer parent.dirMu.Unlock() @@ -514,6 +511,14 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir return syserror.ENOENT } + mnt := rp.Mount() + if err := mnt.CheckBeginWrite(); err != nil { + return err + } + defer mnt.EndWrite() + if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { + return err + } // Ensure that the parent directory is copied-up so that we can create the // new file in the upper layer. if err := parent.copyUpLocked(ctx); err != nil { diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index 9296db2fb..453e41d11 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -153,7 +153,10 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if err != nil { return err } - if err := parentDir.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { + + // Order of checks is important. First check if parent directory can be + // executed, then check for existence, and lastly check if mount is writable. + if err := parentDir.inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { return err } name := rp.Component() @@ -179,6 +182,10 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir return err } defer mnt.EndWrite() + + if err := parentDir.inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil { + return err + } if err := create(parentDir, name); err != nil { return err } diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 99134e634..2c32d017d 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -12,6 +12,7 @@ go_library( "pipe_util.go", "reader.go", "reader_writer.go", + "save_restore.go", "vfs.go", "writer.go", ], @@ -19,7 +20,6 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/amutex", - "//pkg/buffer", "//pkg/context", "//pkg/marshal/primitive", "//pkg/safemem", diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index b989e14c7..c551acd99 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -21,8 +21,8 @@ import ( "sync/atomic" "syscall" - "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -75,10 +75,18 @@ type Pipe struct { // mu protects all pipe internal state below. mu sync.Mutex `state:"nosave"` - // view is the underlying set of buffers. + // buf holds the pipe's data. buf is a circular buffer; the first valid + // byte in buf is at offset off, and the pipe contains size valid bytes. + // bufBlocks contains two identical safemem.Blocks representing buf; this + // avoids needing to heap-allocate a new safemem.Block slice when buf is + // resized. bufBlockSeq is a safemem.BlockSeq representing bufBlocks. // - // This is protected by mu. - view buffer.View + // These fields are protected by mu. + buf []byte + bufBlocks [2]safemem.Block `state:"nosave"` + bufBlockSeq safemem.BlockSeq `state:"nosave"` + off int64 + size int64 // max is the maximum size of the pipe in bytes. When this max has been // reached, writers will get EWOULDBLOCK. @@ -99,12 +107,6 @@ type Pipe struct { // // N.B. The size will be bounded. func NewPipe(isNamed bool, sizeBytes int64) *Pipe { - if sizeBytes < MinimumPipeSize { - sizeBytes = MinimumPipeSize - } - if sizeBytes > MaximumPipeSize { - sizeBytes = MaximumPipeSize - } var p Pipe initPipe(&p, isNamed, sizeBytes) return &p @@ -175,75 +177,71 @@ func (p *Pipe) Open(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) *fs.F } } -type readOps struct { - // left returns the bytes remaining. - left func() int64 - - // limit limits subsequence reads. - limit func(int64) - - // read performs the actual read operation. - read func(*buffer.View) (int64, error) -} - -// read reads data from the pipe into dst and returns the number of bytes -// read, or returns ErrWouldBlock if the pipe is empty. +// peekLocked passes the first count bytes in the pipe to f and returns its +// result. If fewer than count bytes are available, the safemem.BlockSeq passed +// to f will be less than count bytes in length. // -// Precondition: this pipe must have readers. -func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) { - p.mu.Lock() - defer p.mu.Unlock() - return p.readLocked(ctx, ops) -} - -func (p *Pipe) readLocked(ctx context.Context, ops readOps) (int64, error) { +// peekLocked does not mutate the pipe; if the read consumes bytes from the +// pipe, then the caller is responsible for calling p.consumeLocked() and +// p.Notify(waiter.EventOut). (The latter must be called with p.mu unlocked.) +// +// Preconditions: +// * p.mu must be locked. +// * This pipe must have readers. +func (p *Pipe) peekLocked(count int64, f func(safemem.BlockSeq) (uint64, error)) (int64, error) { // Don't block for a zero-length read even if the pipe is empty. - if ops.left() == 0 { + if count == 0 { return 0, nil } - // Is the pipe empty? - if p.view.Size() == 0 { - if !p.HasWriters() { - // There are no writers, return EOF. - return 0, io.EOF + // Limit the amount of data read to the amount of data in the pipe. + if count > p.size { + if p.size == 0 { + if !p.HasWriters() { + return 0, io.EOF + } + return 0, syserror.ErrWouldBlock } - return 0, syserror.ErrWouldBlock + count = p.size } - // Limit how much we consume. - if ops.left() > p.view.Size() { - ops.limit(p.view.Size()) - } + // Prepare the view of the data to be read. + bs := p.bufBlockSeq.DropFirst64(uint64(p.off)).TakeFirst64(uint64(count)) - // Copy user data; the read op is responsible for trimming. - done, err := ops.read(&p.view) - return done, err + // Perform the read. + done, err := f(bs) + return int64(done), err } -type writeOps struct { - // left returns the bytes remaining. - left func() int64 - - // limit should limit subsequent writes. - limit func(int64) - - // write should write to the provided buffer. - write func(*buffer.View) (int64, error) -} - -// write writes data from sv into the pipe and returns the number of bytes -// written. If no bytes are written because the pipe is full (or has less than -// atomicIOBytes free capacity), write returns ErrWouldBlock. +// consumeLocked consumes the first n bytes in the pipe, such that they will no +// longer be visible to future reads. // -// Precondition: this pipe must have writers. -func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) { - p.mu.Lock() - defer p.mu.Unlock() - return p.writeLocked(ctx, ops) +// Preconditions: +// * p.mu must be locked. +// * The pipe must contain at least n bytes. +func (p *Pipe) consumeLocked(n int64) { + p.off += n + if max := int64(len(p.buf)); p.off >= max { + p.off -= max + } + p.size -= n } -func (p *Pipe) writeLocked(ctx context.Context, ops writeOps) (int64, error) { +// writeLocked passes a safemem.BlockSeq representing the first count bytes of +// unused space in the pipe to f and returns the result. If fewer than count +// bytes are free, the safemem.BlockSeq passed to f will be less than count +// bytes in length. If the pipe is full or otherwise cannot accomodate a write +// of any number of bytes up to count, writeLocked returns ErrWouldBlock +// without calling f. +// +// Unlike peekLocked, writeLocked assumes that f returns the number of bytes +// written to the pipe, and increases the number of bytes stored in the pipe +// accordingly. Callers are still responsible for calling +// p.Notify(waiter.EventIn) with p.mu unlocked. +// +// Preconditions: +// * p.mu must be locked. +func (p *Pipe) writeLocked(count int64, f func(safemem.BlockSeq) (uint64, error)) (int64, error) { // Can't write to a pipe with no readers. if !p.HasReaders() { return 0, syscall.EPIPE @@ -251,29 +249,59 @@ func (p *Pipe) writeLocked(ctx context.Context, ops writeOps) (int64, error) { // POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be // atomic, but requires no atomicity for writes larger than this. - wanted := ops.left() - avail := p.max - p.view.Size() - if wanted > avail { - if wanted <= atomicIOBytes { + avail := p.max - p.size + short := false + if count > avail { + if count <= atomicIOBytes { return 0, syserror.ErrWouldBlock } - ops.limit(avail) + count = avail + short = true } - // Copy user data. - done, err := ops.write(&p.view) - if err != nil { - return done, err + // Ensure that the buffer is big enough. + if newLen, oldCap := p.size+count, int64(len(p.buf)); newLen > oldCap { + // Allocate a new buffer. + newCap := oldCap * 2 + if oldCap == 0 { + newCap = 8 // arbitrary; sending individual integers across pipes is relatively common + } + for newLen > newCap { + newCap *= 2 + } + if newCap > p.max { + newCap = p.max + } + newBuf := make([]byte, newCap) + // Copy the old buffer's contents to the beginning of the new one. + safemem.CopySeq( + safemem.BlockSeqOf(safemem.BlockFromSafeSlice(newBuf)), + p.bufBlockSeq.DropFirst64(uint64(p.off)).TakeFirst64(uint64(p.size))) + // Switch to the new buffer. + p.buf = newBuf + p.bufBlocks[0] = safemem.BlockFromSafeSlice(newBuf) + p.bufBlocks[1] = p.bufBlocks[0] + p.bufBlockSeq = safemem.BlockSeqFromSlice(p.bufBlocks[:]) + p.off = 0 } - if done < avail { - // Non-failure, but short write. - return done, nil + // Prepare the view of the space to be written. + woff := p.off + p.size + if woff >= int64(len(p.buf)) { + woff -= int64(len(p.buf)) } - if done < wanted { - // Partial write due to full pipe. Note that this could also be - // the short write case above, we would expect a second call - // and the write to return zero bytes in this case. + bs := p.bufBlockSeq.DropFirst64(uint64(woff)).TakeFirst64(uint64(count)) + + // Perform the write. + doneU64, err := f(bs) + done := int64(doneU64) + p.size += done + if done < count || err != nil { + return done, err + } + + // If we shortened the write, adjust the returned error appropriately. + if short { return done, syserror.ErrWouldBlock } @@ -324,7 +352,7 @@ func (p *Pipe) HasWriters() bool { // Precondition: mu must be held. func (p *Pipe) rReadinessLocked() waiter.EventMask { ready := waiter.EventMask(0) - if p.HasReaders() && p.view.Size() != 0 { + if p.HasReaders() && p.size != 0 { ready |= waiter.EventIn } if !p.HasWriters() && p.hadWriter { @@ -350,7 +378,7 @@ func (p *Pipe) rReadiness() waiter.EventMask { // Precondition: mu must be held. func (p *Pipe) wReadinessLocked() waiter.EventMask { ready := waiter.EventMask(0) - if p.HasWriters() && p.view.Size() < p.max { + if p.HasWriters() && p.size < p.max { ready |= waiter.EventOut } if !p.HasReaders() { @@ -383,7 +411,7 @@ func (p *Pipe) queued() int64 { } func (p *Pipe) queuedLocked() int64 { - return p.view.Size() + return p.size } // FifoSize implements fs.FifoSizer.FifoSize. @@ -406,7 +434,7 @@ func (p *Pipe) SetFifoSize(size int64) (int64, error) { } p.mu.Lock() defer p.mu.Unlock() - if size < p.view.Size() { + if size < p.size { return 0, syserror.EBUSY } p.max = size diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index f665920cb..77246edbe 100644 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -21,9 +21,9 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/amutex" - "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/marshal/primitive" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" @@ -44,46 +44,37 @@ func (p *Pipe) Release(context.Context) { // Read reads from the Pipe into dst. func (p *Pipe) Read(ctx context.Context, dst usermem.IOSequence) (int64, error) { - n, err := p.read(ctx, readOps{ - left: func() int64 { - return dst.NumBytes() - }, - limit: func(l int64) { - dst = dst.TakeFirst64(l) - }, - read: func(view *buffer.View) (int64, error) { - n, err := dst.CopyOutFrom(ctx, view) - dst = dst.DropFirst64(n) - view.TrimFront(n) - return n, err - }, - }) + n, err := dst.CopyOutFrom(ctx, p) if n > 0 { p.Notify(waiter.EventOut) } return n, err } +// ReadToBlocks implements safemem.Reader.ReadToBlocks for Pipe.Read. +func (p *Pipe) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { + n, err := p.read(int64(dsts.NumBytes()), func(srcs safemem.BlockSeq) (uint64, error) { + return safemem.CopySeq(dsts, srcs) + }, true /* removeFromSrc */) + return uint64(n), err +} + +func (p *Pipe) read(count int64, f func(srcs safemem.BlockSeq) (uint64, error), removeFromSrc bool) (int64, error) { + p.mu.Lock() + defer p.mu.Unlock() + n, err := p.peekLocked(count, f) + if n > 0 && removeFromSrc { + p.consumeLocked(n) + } + return n, err +} + // WriteTo writes to w from the Pipe. func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool) (int64, error) { - ops := readOps{ - left: func() int64 { - return count - }, - limit: func(l int64) { - count = l - }, - read: func(view *buffer.View) (int64, error) { - n, err := view.ReadToWriter(w, count) - if !dup { - view.TrimFront(n) - } - count -= n - return n, err - }, - } - n, err := p.read(ctx, ops) - if n > 0 { + n, err := p.read(count, func(srcs safemem.BlockSeq) (uint64, error) { + return safemem.FromIOWriter{w}.WriteFromBlocks(srcs) + }, !dup /* removeFromSrc */) + if n > 0 && !dup { p.Notify(waiter.EventOut) } return n, err @@ -91,39 +82,31 @@ func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool) // Write writes to the Pipe from src. func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error) { - n, err := p.write(ctx, writeOps{ - left: func() int64 { - return src.NumBytes() - }, - limit: func(l int64) { - src = src.TakeFirst64(l) - }, - write: func(view *buffer.View) (int64, error) { - n, err := src.CopyInTo(ctx, view) - src = src.DropFirst64(n) - return n, err - }, - }) + n, err := src.CopyInTo(ctx, p) if n > 0 { p.Notify(waiter.EventIn) } return n, err } +// WriteFromBlocks implements safemem.Writer.WriteFromBlocks for Pipe.Write. +func (p *Pipe) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { + n, err := p.write(int64(srcs.NumBytes()), func(dsts safemem.BlockSeq) (uint64, error) { + return safemem.CopySeq(dsts, srcs) + }) + return uint64(n), err +} + +func (p *Pipe) write(count int64, f func(safemem.BlockSeq) (uint64, error)) (int64, error) { + p.mu.Lock() + defer p.mu.Unlock() + return p.writeLocked(count, f) +} + // ReadFrom reads from r to the Pipe. func (p *Pipe) ReadFrom(ctx context.Context, r io.Reader, count int64) (int64, error) { - n, err := p.write(ctx, writeOps{ - left: func() int64 { - return count - }, - limit: func(l int64) { - count = l - }, - write: func(view *buffer.View) (int64, error) { - n, err := view.WriteFromReader(r, count) - count -= n - return n, err - }, + n, err := p.write(count, func(dsts safemem.BlockSeq) (uint64, error) { + return safemem.FromIOReader{r}.ReadToBlocks(dsts) }) if n > 0 { p.Notify(waiter.EventIn) diff --git a/pkg/sentry/kernel/pipe/save_restore.go b/pkg/sentry/kernel/pipe/save_restore.go new file mode 100644 index 000000000..f135827de --- /dev/null +++ b/pkg/sentry/kernel/pipe/save_restore.go @@ -0,0 +1,26 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pipe + +import ( + "gvisor.dev/gvisor/pkg/safemem" +) + +// afterLoad is called by stateify. +func (p *Pipe) afterLoad() { + p.bufBlocks[0] = safemem.BlockFromSafeSlice(p.buf) + p.bufBlocks[1] = p.bufBlocks[0] + p.bufBlockSeq = safemem.BlockSeqFromSlice(p.bufBlocks[:]) +} diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 2d47d2e82..d5a91730d 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -16,7 +16,6 @@ package pipe import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -269,12 +268,10 @@ func (fd *VFSPipeFD) SetPipeSize(size int64) (int64, error) { // SpliceToNonPipe performs a splice operation from fd to a non-pipe file. func (fd *VFSPipeFD) SpliceToNonPipe(ctx context.Context, out *vfs.FileDescription, off, count int64) (int64, error) { fd.pipe.mu.Lock() - defer fd.pipe.mu.Unlock() // Cap the sequence at number of bytes actually available. - v := fd.pipe.queuedLocked() - if v < count { - count = v + if count > fd.pipe.size { + count = fd.pipe.size } src := usermem.IOSequence{ IO: fd, @@ -291,154 +288,97 @@ func (fd *VFSPipeFD) SpliceToNonPipe(ctx context.Context, out *vfs.FileDescripti n, err = out.PWrite(ctx, src, off, vfs.WriteOptions{}) } if n > 0 { - fd.pipe.view.TrimFront(n) + fd.pipe.consumeLocked(n) + } + + fd.pipe.mu.Unlock() + + if n > 0 { + fd.pipe.Notify(waiter.EventOut) } return n, err } // SpliceFromNonPipe performs a splice operation from a non-pipe file to fd. func (fd *VFSPipeFD) SpliceFromNonPipe(ctx context.Context, in *vfs.FileDescription, off, count int64) (int64, error) { - fd.pipe.mu.Lock() - defer fd.pipe.mu.Unlock() - dst := usermem.IOSequence{ IO: fd, Addrs: usermem.AddrRangeSeqOf(usermem.AddrRange{0, usermem.Addr(count)}), } + var ( + n int64 + err error + ) + fd.pipe.mu.Lock() if off == -1 { - return in.Read(ctx, dst, vfs.ReadOptions{}) + n, err = in.Read(ctx, dst, vfs.ReadOptions{}) + } else { + n, err = in.PRead(ctx, dst, off, vfs.ReadOptions{}) + } + fd.pipe.mu.Unlock() + + if n > 0 { + fd.pipe.Notify(waiter.EventIn) } - return in.PRead(ctx, dst, off, vfs.ReadOptions{}) + return n, err } // CopyIn implements usermem.IO.CopyIn. Note that it is the caller's -// responsibility to trim fd.pipe.view after the read is completed. +// responsibility to call fd.pipe.consumeLocked() and +// fd.pipe.Notify(waiter.EventOut) after the read is completed. +// +// Preconditions: fd.pipe.mu must be locked. func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, opts usermem.IOOpts) (int, error) { - origCount := int64(len(dst)) - n, err := fd.pipe.readLocked(ctx, readOps{ - left: func() int64 { - return int64(len(dst)) - }, - limit: func(l int64) { - dst = dst[:l] - }, - read: func(view *buffer.View) (int64, error) { - n, err := view.ReadAt(dst, 0) - return int64(n), err - }, + n, err := fd.pipe.peekLocked(int64(len(dst)), func(srcs safemem.BlockSeq) (uint64, error) { + return safemem.CopySeq(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(dst)), srcs) }) - if n > 0 { - fd.pipe.Notify(waiter.EventOut) - } - if err == nil && n != origCount { - return int(n), syserror.ErrWouldBlock - } return int(n), err } -// CopyOut implements usermem.IO.CopyOut. +// CopyOut implements usermem.IO.CopyOut. Note that it is the caller's +// responsibility to call fd.pipe.Notify(waiter.EventIn) after the +// write is completed. +// +// Preconditions: fd.pipe.mu must be locked. func (fd *VFSPipeFD) CopyOut(ctx context.Context, addr usermem.Addr, src []byte, opts usermem.IOOpts) (int, error) { - origCount := int64(len(src)) - n, err := fd.pipe.writeLocked(ctx, writeOps{ - left: func() int64 { - return int64(len(src)) - }, - limit: func(l int64) { - src = src[:l] - }, - write: func(view *buffer.View) (int64, error) { - view.Append(src) - return int64(len(src)), nil - }, + n, err := fd.pipe.writeLocked(int64(len(src)), func(dsts safemem.BlockSeq) (uint64, error) { + return safemem.CopySeq(dsts, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(src))) }) - if n > 0 { - fd.pipe.Notify(waiter.EventIn) - } - if err == nil && n != origCount { - return int(n), syserror.ErrWouldBlock - } return int(n), err } // ZeroOut implements usermem.IO.ZeroOut. +// +// Preconditions: fd.pipe.mu must be locked. func (fd *VFSPipeFD) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int64, opts usermem.IOOpts) (int64, error) { - origCount := toZero - n, err := fd.pipe.writeLocked(ctx, writeOps{ - left: func() int64 { - return toZero - }, - limit: func(l int64) { - toZero = l - }, - write: func(view *buffer.View) (int64, error) { - view.Grow(view.Size()+toZero, true /* zero */) - return toZero, nil - }, + n, err := fd.pipe.writeLocked(toZero, func(dsts safemem.BlockSeq) (uint64, error) { + return safemem.ZeroSeq(dsts) }) - if n > 0 { - fd.pipe.Notify(waiter.EventIn) - } - if err == nil && n != origCount { - return n, syserror.ErrWouldBlock - } return n, err } // CopyInTo implements usermem.IO.CopyInTo. Note that it is the caller's -// responsibility to trim fd.pipe.view after the read is completed. +// responsibility to call fd.pipe.consumeLocked() and +// fd.pipe.Notify(waiter.EventOut) after the read is completed. +// +// Preconditions: fd.pipe.mu must be locked. func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) { - count := ars.NumBytes() - if count == 0 { - return 0, nil - } - origCount := count - n, err := fd.pipe.readLocked(ctx, readOps{ - left: func() int64 { - return count - }, - limit: func(l int64) { - count = l - }, - read: func(view *buffer.View) (int64, error) { - n, err := view.ReadToSafememWriter(dst, uint64(count)) - return int64(n), err - }, + return fd.pipe.peekLocked(ars.NumBytes(), func(srcs safemem.BlockSeq) (uint64, error) { + return dst.WriteFromBlocks(srcs) }) - if n > 0 { - fd.pipe.Notify(waiter.EventOut) - } - if err == nil && n != origCount { - return n, syserror.ErrWouldBlock - } - return n, err } // CopyOutFrom implements usermem.IO.CopyOutFrom. +// +// Preconditions: fd.pipe.mu must be locked. func (fd *VFSPipeFD) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) { - count := ars.NumBytes() - if count == 0 { - return 0, nil - } - origCount := count - n, err := fd.pipe.writeLocked(ctx, writeOps{ - left: func() int64 { - return count - }, - limit: func(l int64) { - count = l - }, - write: func(view *buffer.View) (int64, error) { - n, err := view.WriteFromSafememReader(src, uint64(count)) - return int64(n), err - }, + n, err := fd.pipe.writeLocked(ars.NumBytes(), func(dsts safemem.BlockSeq) (uint64, error) { + return src.ReadToBlocks(dsts) }) if n > 0 { fd.pipe.Notify(waiter.EventIn) } - if err == nil && n != origCount { - return n, syserror.ErrWouldBlock - } return n, err } @@ -481,37 +421,23 @@ func spliceOrTee(ctx context.Context, dst, src *VFSPipeFD, count int64, removeFr } lockTwoPipes(dst.pipe, src.pipe) - defer dst.pipe.mu.Unlock() - defer src.pipe.mu.Unlock() - - n, err := dst.pipe.writeLocked(ctx, writeOps{ - left: func() int64 { - return count - }, - limit: func(l int64) { - count = l - }, - write: func(dstView *buffer.View) (int64, error) { - return src.pipe.readLocked(ctx, readOps{ - left: func() int64 { - return count - }, - limit: func(l int64) { - count = l - }, - read: func(srcView *buffer.View) (int64, error) { - n, err := srcView.ReadToSafememWriter(dstView, uint64(count)) - if n > 0 && removeFromSrc { - srcView.TrimFront(int64(n)) - } - return int64(n), err - }, - }) - }, + n, err := dst.pipe.writeLocked(count, func(dsts safemem.BlockSeq) (uint64, error) { + n, err := src.pipe.peekLocked(int64(dsts.NumBytes()), func(srcs safemem.BlockSeq) (uint64, error) { + return safemem.CopySeq(dsts, srcs) + }) + if n > 0 && removeFromSrc { + src.pipe.consumeLocked(n) + } + return uint64(n), err }) + dst.pipe.mu.Unlock() + src.pipe.mu.Unlock() + if n > 0 { dst.pipe.Notify(waiter.EventIn) - src.pipe.Notify(waiter.EventOut) + if removeFromSrc { + src.pipe.Notify(waiter.EventOut) + } } return n, err } diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index b2206900b..22abca120 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -18,7 +18,6 @@ go_library( ], deps = [ "//pkg/abi/linux", - "//pkg/amutex", "//pkg/binary", "//pkg/context", "//pkg/log", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 57f224120..03749a8bf 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -36,7 +36,6 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" @@ -187,6 +186,21 @@ var Metrics = tcpip.Stats{ IPTablesInputDropped: mustCreateMetric("/netstack/ip/iptables/input_dropped", "Total number of IP packets dropped in the Input chain."), IPTablesOutputDropped: mustCreateMetric("/netstack/ip/iptables/output_dropped", "Total number of IP packets dropped in the Output chain."), }, + ARP: tcpip.ARPStats{ + PacketsReceived: mustCreateMetric("/netstack/arp/packets_received", "Number of ARP packets received from the link layer."), + DisabledPacketsReceived: mustCreateMetric("/netstack/arp/disabled_packets_received", "Number of ARP packets received from the link layer when the ARP layer is disabled."), + MalformedPacketsReceived: mustCreateMetric("/netstack/arp/malformed_packets_received", "Number of ARP packets which failed ARP header validation checks."), + RequestsReceived: mustCreateMetric("/netstack/arp/requests_received", "Number of ARP requests received."), + RequestsReceivedUnknownTargetAddress: mustCreateMetric("/netstack/arp/requests_received_unknown_addr", "Number of ARP requests received with an unknown target address."), + OutgoingRequestInterfaceHasNoLocalAddressErrors: mustCreateMetric("/netstack/arp/outgoing_requests_iface_has_no_addr", "Number of failed attempts to send an ARP request with an interface that has no network address."), + OutgoingRequestBadLocalAddressErrors: mustCreateMetric("/netstack/arp/outgoing_requests_invalid_local_addr", "Number of failed attempts to send an ARP request with a provided local address that is invalid."), + OutgoingRequestNetworkUnreachableErrors: mustCreateMetric("/netstack/arp/outgoing_requests_network_unreachable", "Number of failed attempts to send an ARP request with a network unreachable error."), + OutgoingRequestsDropped: mustCreateMetric("/netstack/arp/outgoing_requests_dropped", "Number of ARP requests which failed to write to a link-layer endpoint."), + OutgoingRequestsSent: mustCreateMetric("/netstack/arp/outgoing_requests_sent", "Number of ARP requests sent."), + RepliesReceived: mustCreateMetric("/netstack/arp/replies_received", "Number of ARP replies received."), + OutgoingRepliesDropped: mustCreateMetric("/netstack/arp/outgoing_replies_dropped", "Number of ARP replies which failed to write to a link-layer endpoint."), + OutgoingRepliesSent: mustCreateMetric("/netstack/arp/outgoing_replies_sent", "Number of ARP replies sent."), + }, TCP: tcpip.TCPStats{ ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."), PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."), @@ -459,18 +473,10 @@ func (i *ioSequencePayload) DropFirst(n int) { // Write implements fs.FileOperations.Write. func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { f := &ioSequencePayload{ctx: ctx, src: src} - n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) + n, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) if err == tcpip.ErrWouldBlock { return 0, syserror.ErrWouldBlock } - - if resCh != nil { - if err := amutex.Block(ctx, resCh); err != nil { - return 0, err - } - n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{}) - } - if err != nil { return 0, syserr.TranslateNetstackError(err).ToError() } @@ -526,24 +532,12 @@ func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) { // ReadFrom implements fs.FileOperations.ReadFrom. func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { f := &readerPayload{ctx: ctx, r: r, count: count} - n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{ + n, err := s.Endpoint.Write(f, tcpip.WriteOptions{ // Reads may be destructive but should be very fast, // so we can't release the lock while copying data. Atomic: true, }) if err == tcpip.ErrWouldBlock { - return 0, syserror.ErrWouldBlock - } - - if resCh != nil { - if err := amutex.Block(ctx, resCh); err != nil { - return 0, err - } - n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{ - Atomic: true, // See above. - }) - } - if err == tcpip.ErrWouldBlock { return n, syserror.ErrWouldBlock } else if err != nil { return int64(n), f.err // Propagate error. @@ -2836,13 +2830,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b } v := &ioSequencePayload{t, src} - n, resCh, err := s.Endpoint.Write(v, opts) - if resCh != nil { - if err := t.Block(resCh); err != nil { - return 0, syserr.FromError(err) - } - n, _, err = s.Endpoint.Write(v, opts) - } + n, err := s.Endpoint.Write(v, opts) dontWait := flags&linux.MSG_DONTWAIT != 0 if err == nil && (n >= v.src.NumBytes() || dontWait) { // Complete write. @@ -2861,7 +2849,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b v.DropFirst(int(n)) total := n for { - n, _, err = s.Endpoint.Write(v, opts) + n, err = s.Endpoint.Write(v, opts) v.DropFirst(int(n)) total += n diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index b756bfca0..6f70b02fc 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -16,7 +16,6 @@ package netstack import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" @@ -131,18 +130,10 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs } f := &ioSequencePayload{ctx: ctx, src: src} - n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) + n, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) if err == tcpip.ErrWouldBlock { return 0, syserror.ErrWouldBlock } - - if resCh != nil { - if err := amutex.Block(ctx, resCh); err != nil { - return 0, err - } - n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{}) - } - if err != nil { return 0, syserr.TranslateNetstackError(err).ToError() } diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index 1c4cdb0dd..134051124 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -29,24 +29,23 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB if opts.Length < 0 || opts.SrcStart < 0 || opts.DstStart < 0 || (opts.SrcStart+opts.Length < 0) { return 0, syserror.EINVAL } - + if opts.Length == 0 { + return 0, nil + } if opts.Length > int64(kernel.MAX_RW_COUNT) { opts.Length = int64(kernel.MAX_RW_COUNT) } var ( - total int64 n int64 err error inCh chan struct{} outCh chan struct{} ) - for opts.Length > 0 { + for { n, err = fs.Splice(t, outFile, inFile, opts) - opts.Length -= n - total += n - if err != syserror.ErrWouldBlock { + if n != 0 || err != syserror.ErrWouldBlock { break } else if err == syserror.ErrWouldBlock && nonBlocking { break @@ -87,13 +86,13 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB } } - if total > 0 { + if n > 0 { // On Linux, inotify behavior is not very consistent with splice(2). We try // our best to emulate Linux for very basic calls to splice, where for some // reason, events are generated for output files, but not input files. outFile.Dirent.InotifyEvent(linux.IN_MODIFY, 0) } - return total, err + return n, err } // Sendfile implements linux system call sendfile(2). diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index 2756d4471..cb8981633 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -93,7 +93,6 @@ func init() { addErrMapping(tcpip.ErrConnectionAborted, ErrConnectionAborted) addErrMapping(tcpip.ErrNoSuchFile, ErrNoSuchFile) addErrMapping(tcpip.ErrInvalidOptionValue, ErrInvalidOptionValue) - addErrMapping(tcpip.ErrNoLinkAddress, ErrHostDown) addErrMapping(tcpip.ErrBadAddress, ErrBadAddress) addErrMapping(tcpip.ErrNetworkUnreachable, ErrNetworkUnreachable) addErrMapping(tcpip.ErrMessageTooLong, ErrMessageTooLong) diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 7193f56ad..85a0b8b90 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -397,22 +397,9 @@ func (c *TCPConn) Write(b []byte) (int, error) { } var n int64 - var resCh <-chan struct{} - n, resCh, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) + n, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) nbytes += int(n) v.TrimFront(int(n)) - - if resCh != nil { - select { - case <-deadline: - return nbytes, c.newOpError("write", &timeoutError{}) - case <-resCh: - } - - n, _, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - nbytes += int(n) - v.TrimFront(int(n)) - } } if err == nil { @@ -666,17 +653,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { v := buffer.NewView(len(b)) copy(v, b) - n, resCh, err := c.ep.Write(tcpip.SlicePayload(v), wopts) - if resCh != nil { - select { - case <-deadline: - return int(n), c.newRemoteOpError("write", addr, &timeoutError{}) - case <-resCh: - } - - n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts) - } - + n, err := c.ep.Write(tcpip.SlicePayload(v), wopts) if err == tcpip.ErrWouldBlock { // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) @@ -689,7 +666,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { case <-notifyCh: } - n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts) + n, err = c.ep.Write(tcpip.SlicePayload(v), wopts) if err != tcpip.ErrWouldBlock { break } diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index dd2e1a125..9f5ee43d7 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -191,8 +191,8 @@ func shuffle(b []int) { } func createFile(t *testing.T, size int64, initQueue bool) int { - tmpDir := os.Getenv("TEST_TMPDIR") - if tmpDir == "" { + tmpDir, ok := os.LookupEnv("TEST_TMPDIR") + if !ok { tmpDir = os.Getenv("TMPDIR") } f, err := ioutil.TempFile(tmpDir, "sharedmem_test") diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 3d5c0d270..3259d052f 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -119,21 +119,28 @@ func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) *t } func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { + stats := e.protocol.stack.Stats().ARP + stats.PacketsReceived.Increment() + if !e.isEnabled() { + stats.DisabledPacketsReceived.Increment() return } h := header.ARP(pkt.NetworkHeader().View()) if !h.IsValid() { + stats.MalformedPacketsReceived.Increment() return } switch h.Op() { case header.ARPRequest: + stats.RequestsReceived.Increment() localAddr := tcpip.Address(h.ProtocolAddressTarget()) if e.nud == nil { if e.linkAddrCache.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { + stats.RequestsReceivedUnknownTargetAddress.Increment() return // we have no useful answer, ignore the request } @@ -142,6 +149,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) } else { if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { + stats.RequestsReceivedUnknownTargetAddress.Increment() return // we have no useful answer, ignore the request } @@ -177,9 +185,14 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // // Send the packet to the (new) target hardware address on the same // hardware on which the request was received. - _ = e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt) + if err := e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt); err != nil { + stats.OutgoingRepliesDropped.Increment() + } else { + stats.OutgoingRepliesSent.Increment() + } case header.ARPReply: + stats.RepliesReceived.Increment() addr := tcpip.Address(h.ProtocolAddressSender()) linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) @@ -233,6 +246,8 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { + stats := p.stack.Stats().ARP + if len(remoteLinkAddr) == 0 { remoteLinkAddr = header.EthernetBroadcastAddress } @@ -241,15 +256,18 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot if len(localAddr) == 0 { addr, err := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) if err != nil { + stats.OutgoingRequestInterfaceHasNoLocalAddressErrors.Increment() return err } if len(addr.Address) == 0 { + stats.OutgoingRequestNetworkUnreachableErrors.Increment() return tcpip.ErrNetworkUnreachable } localAddr = addr.Address } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + stats.OutgoingRequestBadLocalAddressErrors.Increment() return tcpip.ErrBadLocalAddress } @@ -269,7 +287,12 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize { panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) } - return nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt) + if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { + stats.OutgoingRequestsDropped.Increment() + return err + } + stats.OutgoingRequestsSent.Increment() + return nil } // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index a25cba513..6b61f57ad 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -240,6 +240,10 @@ func TestDirectRequest(t *testing.T) { for i, address := range []tcpip.Address{stackAddr, remoteAddr} { t.Run(strconv.Itoa(i), func(t *testing.T) { + expectedPacketsReceived := c.s.Stats().ARP.PacketsReceived.Value() + 1 + expectedRequestsReceived := c.s.Stats().ARP.RequestsReceived.Value() + 1 + expectedRepliesSent := c.s.Stats().ARP.OutgoingRepliesSent.Value() + 1 + inject(address) pi, _ := c.linkEP.ReadContext(context.Background()) if pi.Proto != arp.ProtocolNumber { @@ -249,6 +253,9 @@ func TestDirectRequest(t *testing.T) { if !rep.IsValid() { t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep) } + if got := rep.Op(); got != header.ARPReply { + t.Fatalf("got Op = %d, want = %d", got, header.ARPReply) + } if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) } @@ -261,6 +268,16 @@ func TestDirectRequest(t *testing.T) { if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want { t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, want) } + + if got := c.s.Stats().ARP.PacketsReceived.Value(); got != expectedPacketsReceived { + t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, expectedPacketsReceived) + } + if got := c.s.Stats().ARP.RequestsReceived.Value(); got != expectedRequestsReceived { + t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, expectedRequestsReceived) + } + if got := c.s.Stats().ARP.OutgoingRepliesSent.Value(); got != expectedRepliesSent { + t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, expectedRepliesSent) + } }) } @@ -273,6 +290,84 @@ func TestDirectRequest(t *testing.T) { if pkt, ok := c.linkEP.ReadContext(ctx); ok { t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) } + if got := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value(); got != 1 { + t.Errorf("got c.s.Stats().ARP.RequestsReceivedUnKnownTargetAddress.Value() = %d, want = 1", got) + } +} + +func TestMalformedPacket(t *testing.T) { + c := newTestContext(t, false) + defer c.cleanup() + + v := make(buffer.View, header.ARPSize) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: v.ToVectorisedView(), + }) + + c.linkEP.InjectInbound(arp.ProtocolNumber, pkt) + + if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 { + t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got) + } + if got := c.s.Stats().ARP.MalformedPacketsReceived.Value(); got != 1 { + t.Errorf("got c.s.Stats().ARP.MalformedPacketsReceived.Value() = %d, want = 1", got) + } +} + +func TestDisabledEndpoint(t *testing.T) { + c := newTestContext(t, false) + defer c.cleanup() + + ep, err := c.s.GetNetworkEndpoint(nicID, header.ARPProtocolNumber) + if err != nil { + t.Fatalf("GetNetworkEndpoint(%d, header.ARPProtocolNumber) failed: %s", nicID, err) + } + ep.Disable() + + v := make(buffer.View, header.ARPSize) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: v.ToVectorisedView(), + }) + + c.linkEP.InjectInbound(arp.ProtocolNumber, pkt) + + if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 { + t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got) + } + if got := c.s.Stats().ARP.DisabledPacketsReceived.Value(); got != 1 { + t.Errorf("got c.s.Stats().ARP.DisabledPacketsReceived.Value() = %d, want = 1", got) + } +} + +func TestDirectReply(t *testing.T) { + c := newTestContext(t, false) + defer c.cleanup() + + const senderMAC = "\x01\x02\x03\x04\x05\x06" + const senderIPv4 = "\x0a\x00\x00\x02" + + v := make(buffer.View, header.ARPSize) + h := header.ARP(v) + h.SetIPv4OverEthernet() + h.SetOp(header.ARPReply) + + copy(h.HardwareAddressSender(), senderMAC) + copy(h.ProtocolAddressSender(), senderIPv4) + copy(h.HardwareAddressTarget(), stackLinkAddr) + copy(h.ProtocolAddressTarget(), stackAddr) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: v.ToVectorisedView(), + }) + + c.linkEP.InjectInbound(arp.ProtocolNumber, pkt) + + if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 { + t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got) + } + if got := c.s.Stats().ARP.RepliesReceived.Value(); got != 1 { + t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got) + } } func TestDirectRequestWithNeighborCache(t *testing.T) { @@ -311,6 +406,11 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + packetsRecv := c.s.Stats().ARP.PacketsReceived.Value() + requestsRecv := c.s.Stats().ARP.RequestsReceived.Value() + requestsRecvUnknownAddr := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value() + outgoingReplies := c.s.Stats().ARP.OutgoingRepliesSent.Value() + // Inject an incoming ARP request. v := make(buffer.View, header.ARPSize) h := header.ARP(v) @@ -323,6 +423,13 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { Data: v.ToVectorisedView(), })) + if got, want := c.s.Stats().ARP.PacketsReceived.Value(), packetsRecv+1; got != want { + t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, want) + } + if got, want := c.s.Stats().ARP.RequestsReceived.Value(), requestsRecv+1; got != want { + t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, want) + } + if !test.isValid { // No packets should be sent after receiving an invalid ARP request. // There is no need to perform a blocking read here, since packets are @@ -330,9 +437,20 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { if pkt, ok := c.linkEP.Read(); ok { t.Errorf("unexpected packet sent with network protocol number %d", pkt.Proto) } + if got, want := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value(), requestsRecvUnknownAddr+1; got != want { + t.Errorf("got c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value() = %d, want = %d", got, want) + } + if got, want := c.s.Stats().ARP.OutgoingRepliesSent.Value(), outgoingReplies; got != want { + t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, want) + } + return } + if got, want := c.s.Stats().ARP.OutgoingRepliesSent.Value(), outgoingReplies+1; got != want { + t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, want) + } + // Verify an ARP response was sent. pi, ok := c.linkEP.Read() if !ok { @@ -418,6 +536,8 @@ type testInterface struct { stack.LinkEndpoint nicID tcpip.NICID + + writeErr *tcpip.Error } func (t *testInterface) ID() tcpip.NICID { @@ -441,6 +561,10 @@ func (*testInterface) Promiscuous() bool { } func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + if t.writeErr != nil { + return t.writeErr + } + var r stack.Route r.NetProto = protocol r.ResolveWith(remoteLinkAddr) @@ -458,61 +582,99 @@ func TestLinkAddressRequest(t *testing.T) { localAddr tcpip.Address remoteLinkAddr tcpip.LinkAddress - expectedErr *tcpip.Error - expectedLocalAddr tcpip.Address - expectedRemoteLinkAddr tcpip.LinkAddress + linkErr *tcpip.Error + expectedErr *tcpip.Error + expectedLocalAddr tcpip.Address + expectedRemoteLinkAddr tcpip.LinkAddress + expectedRequestsSent uint64 + expectedRequestBadLocalAddressErrors uint64 + expectedRequestNetworkUnreachableErrors uint64 + expectedRequestDroppedErrors uint64 }{ { - name: "Unicast", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: remoteLinkAddr, - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: remoteLinkAddr, + name: "Unicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + expectedRequestsSent: 1, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestNetworkUnreachableErrors: 0, }, { - name: "Multicast", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: "", - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + name: "Multicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + expectedRequestsSent: 1, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestNetworkUnreachableErrors: 0, }, { - name: "Unicast with unspecified source", - nicAddr: stackAddr, - remoteLinkAddr: remoteLinkAddr, - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: remoteLinkAddr, + name: "Unicast with unspecified source", + nicAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + expectedRequestsSent: 1, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestNetworkUnreachableErrors: 0, }, { - name: "Multicast with unspecified source", - nicAddr: stackAddr, - remoteLinkAddr: "", - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + name: "Multicast with unspecified source", + nicAddr: stackAddr, + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + expectedRequestsSent: 1, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestNetworkUnreachableErrors: 0, }, { - name: "Unicast with unassigned address", - localAddr: testAddr, - remoteLinkAddr: remoteLinkAddr, - expectedErr: tcpip.ErrBadLocalAddress, + name: "Unicast with unassigned address", + localAddr: testAddr, + remoteLinkAddr: remoteLinkAddr, + expectedErr: tcpip.ErrBadLocalAddress, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 1, + expectedRequestNetworkUnreachableErrors: 0, }, { - name: "Multicast with unassigned address", - localAddr: testAddr, - remoteLinkAddr: "", - expectedErr: tcpip.ErrBadLocalAddress, + name: "Multicast with unassigned address", + localAddr: testAddr, + remoteLinkAddr: "", + expectedErr: tcpip.ErrBadLocalAddress, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 1, + expectedRequestNetworkUnreachableErrors: 0, }, { - name: "Unicast with no local address available", - remoteLinkAddr: remoteLinkAddr, - expectedErr: tcpip.ErrNetworkUnreachable, + name: "Unicast with no local address available", + remoteLinkAddr: remoteLinkAddr, + expectedErr: tcpip.ErrNetworkUnreachable, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestNetworkUnreachableErrors: 1, }, { - name: "Multicast with no local address available", - remoteLinkAddr: "", - expectedErr: tcpip.ErrNetworkUnreachable, + name: "Multicast with no local address available", + remoteLinkAddr: "", + expectedErr: tcpip.ErrNetworkUnreachable, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestNetworkUnreachableErrors: 1, + }, + { + name: "Link error", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + linkErr: tcpip.ErrInvalidEndpointState, + expectedErr: tcpip.ErrInvalidEndpointState, + expectedRequestDroppedErrors: 1, }, } @@ -543,10 +705,24 @@ func TestLinkAddressRequest(t *testing.T) { // can mock a link address request and observe the packets sent to the // link endpoint even though the stack uses the real NIC to validate the // local address. - if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr { + iface := testInterface{LinkEndpoint: linkEP, nicID: nicID, writeErr: test.linkErr} + if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &iface); err != test.expectedErr { t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) } + if got := s.Stats().ARP.OutgoingRequestsSent.Value(); got != test.expectedRequestsSent { + t.Errorf("got s.Stats().ARP.OutgoingRequestsSent.Value() = %d, want = %d", got, test.expectedRequestsSent) + } + if got := s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value(); got != test.expectedRequestBadLocalAddressErrors { + t.Errorf("got s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestBadLocalAddressErrors) + } + if got := s.Stats().ARP.OutgoingRequestNetworkUnreachableErrors.Value(); got != test.expectedRequestNetworkUnreachableErrors { + t.Errorf("got s.Stats().ARP.OutgoingRequestNetworkUnreachableErrors.Value() = %d, want = %d", got, test.expectedRequestNetworkUnreachableErrors) + } + if got := s.Stats().ARP.OutgoingRequestsDropped.Value(); got != test.expectedRequestDroppedErrors { + t.Errorf("got s.Stats().ARP.OutgoingRequestsDropped.Value() = %d, want = %d", got, test.expectedRequestDroppedErrors) + } + if test.expectedErr != nil { return } @@ -561,6 +737,9 @@ func TestLinkAddressRequest(t *testing.T) { } rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader())) + if got := rep.Op(); got != header.ARPRequest { + t.Errorf("got Op = %d, want = %d", got, header.ARPRequest) + } if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { t.Errorf("got HardwareAddressSender = %s, want = %s", got, stackLinkAddr) } @@ -576,3 +755,22 @@ func TestLinkAddressRequest(t *testing.T) { }) } } + +func TestLinkAddressRequestWithoutNIC(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + }) + p := s.NetworkProtocolInstance(arp.ProtocolNumber) + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") + } + + if err := linkRes.LinkAddressRequest(remoteAddr, "", remoteLinkAddr, &testInterface{nicID: nicID}); err != tcpip.ErrUnknownNICID { + t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, "", remoteLinkAddr, err, tcpip.ErrUnknownNICID) + } + + if got := s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value(); got != 1 { + t.Errorf("got s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value() = %d, want = 1", got) + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index e9ff70d04..cc045c7a9 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -64,6 +64,7 @@ const ( var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix() +var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil) var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) var _ stack.AddressableEndpoint = (*endpoint)(nil) var _ stack.NetworkEndpoint = (*endpoint)(nil) @@ -87,6 +88,21 @@ type endpoint struct { } } +// HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint. +func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { + // handleControl expects the entire offending packet to be in the packet + // buffer's data field. + pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + }) + pkt.NICID = e.nic.ID() + pkt.NetworkProtocolNumber = ProtocolNumber + // Use the same control type as an ICMPv4 destination host unreachable error + // since the host is considered unreachable if we cannot resolve the link + // address to the next hop. + e.handleControl(stack.ControlNoRoute, 0, pkt) +} + // NewEndpoint creates a new ipv4 endpoint. func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index bbce1ef78..0ec0a0fef 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -645,29 +645,18 @@ func TestLinkResolution(t *testing.T) { t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err) } - for { - _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}) - if resCh != nil { - if err != tcpip.ErrNoLinkAddress { - t.Fatalf("ep.Write(_) = (_, <non-nil>, %s), want = (_, <non-nil>, tcpip.ErrNoLinkAddress)", err) - } - for _, args := range []routeArgs{ - {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, - {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert}, - } { - routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) { - if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want { - t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want) - } - }) + if _, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil { + t.Fatalf("ep.Write(_): %s", err) + } + for _, args := range []routeArgs{ + {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, + {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert}, + } { + routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) { + if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want { + t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want) } - <-resCh - continue - } - if err != nil { - t.Fatalf("ep.Write(_) = (_, _, %s)", err) - } - break + }) } for _, args := range []routeArgs{ diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index f2018d073..2f82c3d5f 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -163,6 +163,7 @@ func getLabel(addr tcpip.Address) uint8 { panic(fmt.Sprintf("should have a label for address = %s", addr)) } +var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil) var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) var _ stack.AddressableEndpoint = (*endpoint)(nil) var _ stack.NetworkEndpoint = (*endpoint)(nil) @@ -224,6 +225,18 @@ type OpaqueInterfaceIdentifierOptions struct { SecretKey []byte } +// HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint. +func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { + // handleControl expects the entire offending packet to be in the packet + // buffer's data field. + pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + }) + pkt.NICID = e.nic.ID() + pkt.NetworkProtocolNumber = ProtocolNumber + e.handleControl(stack.ControlAddressUnreachable, 0, pkt) +} + // onAddressAssignedLocked handles an address being assigned. // // Precondition: e.mu must be exclusively locked. diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 4777163cd..a7da9dcd9 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -82,7 +82,7 @@ func writer(ch chan struct{}, ep tcpip.Endpoint) { v.CapLength(n) for len(v) > 0 { - n, _, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) + n, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) if err != nil { fmt.Println("Write failed:", err) return diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 6883045b5..03b2f2d6f 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -83,7 +83,7 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe got, ch, err := c.get(addr, linkRes, "", nil, nil) if err == tcpip.ErrWouldBlock { if attemptedResolution { - return got, tcpip.ErrNoLinkAddress + return got, tcpip.ErrTimeout } attemptedResolution = true <-ch @@ -253,8 +253,8 @@ func TestCacheResolutionFailed(t *testing.T) { before := atomic.LoadUint32(&requestCount) e.addr.Addr += "2" - if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout { + t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout) } if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { @@ -269,8 +269,8 @@ func TestCacheResolutionTimeout(t *testing.T) { linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} e := testAddrs[0] - if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout { + t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout) } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 4a34805b5..8a946b4fa 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -217,6 +217,16 @@ func (n *NIC) disableLocked() { ep.Disable() } + // Clear the neighbour table (including static entries) as we cannot guarantee + // that the current neighbour table will be valid when the NIC is enabled + // again. + // + // This matches linux's behaviour at the time of writing: + // https://github.com/torvalds/linux/blob/71c061d2443814de15e177489d5cc00a4a253ef3/net/core/neighbour.c#L371 + if err := n.clearNeighbors(); err != nil && err != tcpip.ErrNotSupported { + panic(fmt.Sprintf("n.clearNeighbors(): %s", err)) + } + if !n.setEnabled(false) { panic("should have only done work to disable the NIC if it was enabled") } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 664cc6fa0..5f216ca21 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -268,17 +268,6 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { } } -// SourceLinkAddress returns the source link address of the packet. -func (pk *PacketBuffer) SourceLinkAddress() tcpip.LinkAddress { - link := pk.LinkHeader().View() - - if link.IsEmpty() { - return "" - } - - return header.Ethernet(link).SourceAddress() -} - // Network returns the network header as a header.Network. // // Network should only be called when NetworkHeader has been set. diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index 4a3adcf33..bded8814e 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -101,10 +101,12 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro } for _, p := range packets { - if cancelled { - p.route.Stats().IP.OutgoingPacketErrors.Increment() - } else if p.route.IsResolutionRequired() { + if cancelled || p.route.IsResolutionRequired() { p.route.Stats().IP.OutgoingPacketErrors.Increment() + + if linkResolvableEP, ok := p.route.outgoingNIC.getNetworkEndpoint(p.route.NetProto).(LinkResolvableNetworkEndpoint); ok { + linkResolvableEP.HandleLinkResolutionFailure(pkt) + } } else { p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt) } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 4795208b4..924790779 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -55,7 +55,19 @@ type ControlType int // The following are the allowed values for ControlType values. // TODO(http://gvisor.dev/issue/3210): Support time exceeded messages. const ( - ControlNetworkUnreachable ControlType = iota + // ControlAddressUnreachable indicates that an IPv6 packet did not reach its + // destination as the destination address was unreachable. + // + // This maps to the ICMPv6 Destination Ureachable Code 3 error; see + // RFC 4443 section 3.1 for more details. + ControlAddressUnreachable ControlType = iota + ControlNetworkUnreachable + // ControlNoRoute indicates that an IPv4 packet did not reach its destination + // because the destination host was unreachable. + // + // This maps to the ICMPv4 Destination Ureachable Code 1 error; see + // RFC 791's Destination Unreachable Message section (page 4) for more + // details. ControlNoRoute ControlPacketTooBig ControlPortUnreachable @@ -503,6 +515,13 @@ type NetworkInterface interface { WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error } +// LinkResolvableNetworkEndpoint handles link resolution events. +type LinkResolvableNetworkEndpoint interface { + // HandleLinkResolutionFailure is called when link resolution prevents the + // argument from having been sent. + HandleLinkResolutionFailure(*PacketBuffer) +} + // NetworkEndpoint is the interface that needs to be implemented by endpoints // of network layer protocols (e.g., ipv4, ipv6). type NetworkEndpoint interface { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 856ebf6d4..4a3f937e3 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -4305,3 +4305,55 @@ func TestWritePacketToRemote(t *testing.T) { } }) } + +func TestClearNeighborCacheOnNICDisable(t *testing.T) { + const ( + nicID = 1 + + ipv4Addr = tcpip.Address("\x01\x02\x03\x04") + ipv6Addr = tcpip.Address("\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04") + linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + UseNeighborCache: true, + }) + e := channel.New(0, 0, "") + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + if err := s.AddStaticNeighbor(nicID, ipv4Addr, linkAddr); err != nil { + t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, ipv4Addr, linkAddr, err) + } + if err := s.AddStaticNeighbor(nicID, ipv6Addr, linkAddr); err != nil { + t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, ipv6Addr, linkAddr, err) + } + if neighbors, err := s.Neighbors(nicID); err != nil { + t.Fatalf("s.Neighbors(%d): %s", nicID, err) + } else if len(neighbors) != 2 { + t.Fatalf("got len(neighbors) = %d, want = 2; neighbors = %#v", len(neighbors), neighbors) + } + + // Disabling the NIC should clear the neighbor table. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + if neighbors, err := s.Neighbors(nicID); err != nil { + t.Fatalf("s.Neighbors(%d): %s", nicID, err) + } else if len(neighbors) != 0 { + t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + } + + // Enabling the NIC should have an empty neighbor table. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + if neighbors, err := s.Neighbors(nicID); err != nil { + t.Fatalf("s.Neighbors(%d): %s", nicID, err) + } else if len(neighbors) != 0 { + t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + } +} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 0ff32c6ea..a2ab7537c 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -90,14 +90,14 @@ func (*fakeTransportEndpoint) Read(io.Writer, int, tcpip.ReadOptions) (tcpip.Rea return tcpip.ReadResult{}, nil } -func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { if len(f.route.RemoteAddress) == 0 { - return 0, nil, tcpip.ErrNoRoute + return 0, tcpip.ErrNoRoute } v, err := p.FullPayload() if err != nil { - return 0, nil, err + return 0, err } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen, @@ -105,10 +105,10 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions }) _ = pkt.TransportHeader().Push(fakeTransHeaderLen) if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil { - return 0, nil, err + return 0, err } - return int64(len(v)), nil, nil + return int64(len(v)), nil } // SetSockOpt sets a socket option. Currently not supported. @@ -222,7 +222,6 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * if err != nil { return } - route.ResolveWith(pkt.SourceLinkAddress()) ep := &fakeTransportEndpoint{ TransportEndpointInfo: stack.TransportEndpointInfo{ @@ -522,8 +521,7 @@ func TestTransportSend(t *testing.T) { // Create buffer that will hold the payload. view := buffer.NewView(30) - _, _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) - if err != nil { + if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("write failed: %v", err) } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index f798056c0..49d4912ad 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -104,7 +104,6 @@ var ( ErrConnectionAborted = &Error{msg: "connection aborted"} ErrNoSuchFile = &Error{msg: "no such file"} ErrInvalidOptionValue = &Error{msg: "invalid option value specified"} - ErrNoLinkAddress = &Error{msg: "no remote link address"} ErrBadAddress = &Error{msg: "bad address"} ErrNetworkUnreachable = &Error{msg: "network is unreachable"} ErrMessageTooLong = &Error{msg: "message too long"} @@ -154,7 +153,6 @@ func StringToError(s string) *Error { ErrConnectionAborted, ErrNoSuchFile, ErrInvalidOptionValue, - ErrNoLinkAddress, ErrBadAddress, ErrNetworkUnreachable, ErrMessageTooLong, @@ -640,12 +638,7 @@ type Endpoint interface { // stream (TCP) Endpoints may return partial writes, and even then only // in the case where writing additional data would block. Other Endpoints // will either write the entire message or return an error. - // - // For UDP and Ping sockets if address resolution is required, - // ErrNoLinkAddress and a notification channel is returned for the caller to - // block. Channel is closed once address resolution is complete (success or - // not). The channel is only non-nil in this case. - Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error) + Write(Payloader, WriteOptions) (int64, *Error) // Connect connects the endpoint to its peer. Specifying a NIC is // optional. @@ -1598,6 +1591,59 @@ type IPStats struct { OptionUnknownReceived *StatCounter } +// ARPStats collects ARP-specific stats. +type ARPStats struct { + // PacketsReceived is the number of ARP packets received from the link layer. + PacketsReceived *StatCounter + + // DisabledPacketsReceived is the number of ARP packets received from the link + // layer when the ARP layer is disabled. + DisabledPacketsReceived *StatCounter + + // MalformedPacketsReceived is the number of ARP packets that were dropped due + // to being malformed. + MalformedPacketsReceived *StatCounter + + // RequestsReceived is the number of ARP requests received. + RequestsReceived *StatCounter + + // RequestsReceivedUnknownTargetAddress is the number of ARP requests that + // were targeted to an interface different from the one it was received on. + RequestsReceivedUnknownTargetAddress *StatCounter + + // OutgoingRequestInterfaceHasNoLocalAddressErrors is the number of failures + // to send an ARP request because the interface has no network address + // assigned to it. + OutgoingRequestInterfaceHasNoLocalAddressErrors *StatCounter + + // OutgoingRequestBadLocalAddressErrors is the number of failures to send an + // ARP request with a bad local address. + OutgoingRequestBadLocalAddressErrors *StatCounter + + // OutgoingRequestNetworkUnreachableErrors is the number of failures to send + // an ARP request with a network unreachable error. + OutgoingRequestNetworkUnreachableErrors *StatCounter + + // OutgoingRequestsDropped is the number of ARP requests which failed to write + // to a link-layer endpoint. + OutgoingRequestsDropped *StatCounter + + // OutgoingRequestSent is the number of ARP requests successfully written to a + // link-layer endpoint. + OutgoingRequestsSent *StatCounter + + // RepliesReceived is the number of ARP replies received. + RepliesReceived *StatCounter + + // OutgoingRepliesDropped is the number of ARP replies which failed to write + // to a link-layer endpoint. + OutgoingRepliesDropped *StatCounter + + // OutgoingRepliesSent is the number of ARP replies successfully written to a + // link-layer endpoint. + OutgoingRepliesSent *StatCounter +} + // TCPStats collects TCP-specific stats. type TCPStats struct { // ActiveConnectionOpenings is the number of connections opened @@ -1750,6 +1796,9 @@ type Stats struct { // IP breaks out IP-specific stats (both v4 and v6). IP IPStats + // ARP breaks out ARP-specific stats. + ARP ARPStats + // TCP breaks out TCP-specific stats. TCP TCPStats @@ -1784,9 +1833,6 @@ type SendErrors struct { // NoRoute is the number of times we failed to resolve IP route. NoRoute StatCounter - - // NoLinkAddr is the number of times we failed to resolve ARP. - NoLinkAddr StatCounter } // ReadErrors collects segment read errors from an endpoint read call. diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index ca1e88e99..1742a178d 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -31,5 +31,6 @@ go_test( "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index 4c2084d19..49acd504e 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -30,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -157,13 +158,13 @@ func TestForwarding(t *testing.T) { tests := []struct { name string - epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses + epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses }{ { name: "IPv4 host1 server with host2 client", - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { - ep1, ep1WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv4.ProtocolNumber) - ep2, ep2WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv4.ProtocolNumber) + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) return endpointAndAddresses{ serverEP: ep1, serverAddr: host1IPv4Addr.AddressWithPrefix.Address, @@ -177,9 +178,9 @@ func TestForwarding(t *testing.T) { }, { name: "IPv6 host2 server with host1 client", - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { - ep1, ep1WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv6.ProtocolNumber) - ep2, ep2WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv6.ProtocolNumber) + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + ep1, ep1WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) + ep2, ep2WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber) return endpointAndAddresses{ serverEP: ep1, serverAddr: host2IPv6Addr.AddressWithPrefix.Address, @@ -193,9 +194,9 @@ func TestForwarding(t *testing.T) { }, { name: "IPv4 host2 server with routerNIC1 client", - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { - ep1, ep1WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv4.ProtocolNumber) - ep2, ep2WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv4.ProtocolNumber) + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + ep1, ep1WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) + ep2, ep2WECH := newEP(t, routerStack, proto, ipv4.ProtocolNumber) return endpointAndAddresses{ serverEP: ep1, serverAddr: host2IPv4Addr.AddressWithPrefix.Address, @@ -209,9 +210,9 @@ func TestForwarding(t *testing.T) { }, { name: "IPv6 routerNIC2 server with host1 client", - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { - ep1, ep1WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv6.ProtocolNumber) - ep2, ep2WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv6.ProtocolNumber) + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + ep1, ep1WECH := newEP(t, routerStack, proto, ipv6.ProtocolNumber) + ep2, ep2WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber) return endpointAndAddresses{ serverEP: ep1, serverAddr: routerNIC2IPv6Addr.AddressWithPrefix.Address, @@ -225,202 +226,270 @@ func TestForwarding(t *testing.T) { }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - } - - host1Stack := stack.New(stackOpts) - routerStack := stack.New(stackOpts) - host2Stack := stack.New(stackOpts) - - host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) - routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) - - if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { - t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) - } - if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) - } - if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) - } - if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { - t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) - } - - if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) - } - if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) - } - - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) - } - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) - } - - host1Stack.SetRouteTable([]tcpip.Route{ - { - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - { - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - { - Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address, - NIC: host1NICID, - }, - { - Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address, - NIC: host1NICID, - }, - }) - routerStack.SetRouteTable([]tcpip.Route{ - { - Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID1, - }, - { - Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID1, - }, - { - Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID2, - }, - { - Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID2, - }, - }) - host2Stack.SetRouteTable([]tcpip.Route{ - { - Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - { - Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - { - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address, - NIC: host2NICID, - }, - { - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address, - NIC: host2NICID, - }, - }) - - epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack) - defer epsAndAddrs.serverEP.Close() - defer epsAndAddrs.clientEP.Close() - - serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort} - if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil { - t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err) - } - clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr} - if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil { - t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) - } - - write := func(ep tcpip.Endpoint, data []byte, to *tcpip.FullAddress) { + subTests := []struct { + name string + proto tcpip.TransportProtocolNumber + expectedConnectErr *tcpip.Error + setupServerSide func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) + needRemoteAddr bool + }{ + { + name: "UDP", + proto: udp.ProtocolNumber, + expectedConnectErr: nil, + setupServerSide: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { t.Helper() - dataPayload := tcpip.SlicePayload(data) - wOpts := tcpip.WriteOptions{To: to} - n, ch, err := ep.Write(dataPayload, wOpts) - if err == tcpip.ErrNoLinkAddress { - // Wait for link resolution to complete. - <-ch - n, _, err = ep.Write(dataPayload, wOpts) - } - if err != nil { - t.Fatalf("ep.Write(_, _): %s", err) - } - if want := int64(len(data)); n != want { - t.Fatalf("got ep.Write(_, _) = (%d, _, _), want = (%d, _, _)", n, want) + if err := ep.Connect(clientAddr); err != nil { + t.Fatalf("ep.Connect(%#v): %s", clientAddr, err) } - } - - data := []byte{1, 2, 3, 4} - write(epsAndAddrs.clientEP, data, &serverAddr) - - read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.Address) tcpip.FullAddress { + return nil, nil + }, + needRemoteAddr: true, + }, + { + name: "TCP", + proto: tcp.ProtocolNumber, + expectedConnectErr: tcpip.ErrConnectStarted, + setupServerSide: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { t.Helper() - // Wait for the endpoint to be readable. - <-ch - var buf bytes.Buffer - opts := tcpip.ReadOptions{NeedRemoteAddr: true} - res, err := ep.Read(&buf, len(data), opts) - if err != nil { - t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) + if err := ep.Listen(1); err != nil { + t.Fatalf("ep.Listen(1): %s", err) } - - if diff := cmp.Diff(tcpip.ReadResult{ - Count: len(data), - Total: len(data), - RemoteAddr: tcpip.FullAddress{Addr: expectedFrom}, - }, res, checker.IgnoreCmpPath( - "ControlMessages", - "RemoteAddr.NIC", - "RemoteAddr.Port", - )); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(buf.Bytes(), data); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) - } - - if t.Failed() { - t.FailNow() + var addr tcpip.FullAddress + for { + newEP, wq, err := ep.Accept(&addr) + if err == tcpip.ErrWouldBlock { + <-ch + continue + } + if err != nil { + t.Fatalf("ep.Accept(_): %s", err) + } + if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath( + "NIC", + )); diff != "" { + t.Errorf("accepted address mismatch (-want +got):\n%s", diff) + } + + we, newCH := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + return newEP, newCH } + }, + needRemoteAddr: false, + }, + } - return res.RemoteAddr - } - - addr := read(epsAndAddrs.serverReadableCH, epsAndAddrs.serverEP, data, epsAndAddrs.clientAddr) - // Unspecify the NIC since NIC IDs are meaningless across stacks. - addr.NIC = 0 - - data = tcpip.SlicePayload([]byte{5, 6, 7, 8, 9, 10, 11, 12}) - write(epsAndAddrs.serverEP, data, &addr) - addr = read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, epsAndAddrs.serverAddr) - if addr.Port != listenPort { - t.Errorf("got addr.Port = %d, want = %d", addr.Port, listenPort) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, + } + + host1Stack := stack.New(stackOpts) + routerStack := stack.New(stackOpts) + host2Stack := stack.New(stackOpts) + + host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) + routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) + + if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { + t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) + } + if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil { + t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) + } + if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil { + t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) + } + if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { + t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) + } + + if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) + } + if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) + } + + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) + } + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) + } + + host1Stack.SetRouteTable([]tcpip.Route{ + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + { + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + { + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address, + NIC: host1NICID, + }, + { + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address, + NIC: host1NICID, + }, + }) + routerStack.SetRouteTable([]tcpip.Route{ + { + Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID1, + }, + { + Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID1, + }, + { + Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID2, + }, + { + Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID2, + }, + }) + host2Stack.SetRouteTable([]tcpip.Route{ + { + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + { + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address, + NIC: host2NICID, + }, + { + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address, + NIC: host2NICID, + }, + }) + + epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto) + defer epsAndAddrs.serverEP.Close() + defer epsAndAddrs.clientEP.Close() + + serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort} + if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil { + t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err) + } + clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr} + if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil { + t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) + } + + if err := epsAndAddrs.clientEP.Connect(serverAddr); err != subTest.expectedConnectErr { + t.Fatalf("got epsAndAddrs.clientEP.Connect(%#v) = %s, want = %s", serverAddr, err, subTest.expectedConnectErr) + } + if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil { + t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err) + } else { + clientAddr = addr + clientAddr.NIC = 0 + } + + serverEP := epsAndAddrs.serverEP + serverCH := epsAndAddrs.serverReadableCH + if ep, ch := subTest.setupServerSide(t, serverEP, serverCH, clientAddr); ep != nil { + defer ep.Close() + serverEP = ep + serverCH = ch + } + + write := func(ep tcpip.Endpoint, data []byte) { + t.Helper() + + dataPayload := tcpip.SlicePayload(data) + var wOpts tcpip.WriteOptions + n, err := ep.Write(dataPayload, wOpts) + if err != nil { + t.Fatalf("ep.Write(_, %#v): %s", wOpts, err) + } + if want := int64(len(data)); n != want { + t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want) + } + } + + data := []byte{1, 2, 3, 4} + write(epsAndAddrs.clientEP, data) + + read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) { + t.Helper() + + // Wait for the endpoint to be readable. + <-ch + var buf bytes.Buffer + opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr} + res, err := ep.Read(&buf, len(data), opts) + if err != nil { + t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) + } + + readResult := tcpip.ReadResult{ + Count: len(data), + Total: len(data), + } + if subTest.needRemoteAddr { + readResult.RemoteAddr = expectedFrom + } + if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + )); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) + } + if diff := cmp.Diff(buf.Bytes(), data); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + + if t.Failed() { + t.FailNow() + } + } + + read(serverCH, serverEP, data, clientAddr) + + data = tcpip.SlicePayload([]byte{5, 6, 7, 8, 9, 10, 11, 12}) + write(serverEP, data) + read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr) + }) } }) } diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index b4bffaec1..ed00c90d4 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -29,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -54,6 +56,13 @@ var ( PrefixLen: 8, }, } + ipv4Addr3 = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.3").To4()), + PrefixLen: 8, + }, + } ipv6Addr1 = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ @@ -68,8 +77,65 @@ var ( PrefixLen: 64, }, } + ipv6Addr3 = tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("a::3").To16()), + PrefixLen: 64, + }, + } ) +func setupStack(t *testing.T, stackOpts stack.Options, host1NICID, host2NICID tcpip.NICID) (*stack.Stack, *stack.Stack) { + host1Stack := stack.New(stackOpts) + host2Stack := stack.New(stackOpts) + + host1NIC, host2NIC := pipe.New(linkAddr1, linkAddr2) + + if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { + t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) + } + if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { + t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) + } + + if err := host1Stack.AddProtocolAddress(host1NICID, ipv4Addr1); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv4Addr1, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, ipv4Addr2); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv4Addr2, err) + } + if err := host1Stack.AddProtocolAddress(host1NICID, ipv6Addr1); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv6Addr1, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, ipv6Addr2); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv6Addr2, err) + } + + host1Stack.SetRouteTable([]tcpip.Route{ + { + Destination: ipv4Addr1.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + { + Destination: ipv6Addr1.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + }) + host2Stack.SetRouteTable([]tcpip.Route{ + { + Destination: ipv4Addr2.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + { + Destination: ipv6Addr2.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + }) + + return host1Stack, host2Stack +} + // TestPing tests that two hosts can ping eachother when link resolution is // enabled. func TestPing(t *testing.T) { @@ -128,51 +194,7 @@ func TestPing(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, } - host1Stack := stack.New(stackOpts) - host2Stack := stack.New(stackOpts) - - host1NIC, host2NIC := pipe.New(linkAddr1, linkAddr2) - - if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { - t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) - } - if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { - t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) - } - - if err := host1Stack.AddProtocolAddress(host1NICID, ipv4Addr1); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv4Addr1, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, ipv4Addr2); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv4Addr2, err) - } - if err := host1Stack.AddProtocolAddress(host1NICID, ipv6Addr1); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv6Addr1, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, ipv6Addr2); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv6Addr2, err) - } - - host1Stack.SetRouteTable([]tcpip.Route{ - { - Destination: ipv4Addr1.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - { - Destination: ipv6Addr1.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - }) - host2Stack.SetRouteTable([]tcpip.Route{ - { - Destination: ipv4Addr2.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - { - Destination: ipv6Addr2.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - }) + host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) var wq waiter.Queue we, waiterCH := waiter.NewChannelEntry(nil) @@ -183,19 +205,12 @@ func TestPing(t *testing.T) { } defer ep.Close() - // The first write should trigger link resolution. icmpBuf := test.icmpBuf(t) wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}} - if _, ch, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got ep.Write(_, _) = %s, want = %s", err, tcpip.ErrNoLinkAddress) - } else { - // Wait for link resolution to complete. - <-ch - } - if n, _, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != nil { + if n, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != nil { t.Fatalf("ep.Write(_, _): %s", err) } else if want := int64(len(icmpBuf)); n != want { - t.Fatalf("got ep.Write(_, _) = (%d, _, _), want = (%d, _, _)", n, want) + t.Fatalf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want) } // Wait for the endpoint to be readable. @@ -224,3 +239,159 @@ func TestPing(t *testing.T) { }) } } + +func TestTCPLinkResolutionFailure(t *testing.T) { + const ( + host1NICID = 1 + host2NICID = 4 + ) + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + expectedWriteErr *tcpip.Error + sockError tcpip.SockError + }{ + { + name: "IPv4 with resolvable remote", + netProto: ipv4.ProtocolNumber, + remoteAddr: ipv4Addr2.AddressWithPrefix.Address, + expectedWriteErr: nil, + }, + { + name: "IPv6 with resolvable remote", + netProto: ipv6.ProtocolNumber, + remoteAddr: ipv6Addr2.AddressWithPrefix.Address, + expectedWriteErr: nil, + }, + { + name: "IPv4 without resolvable remote", + netProto: ipv4.ProtocolNumber, + remoteAddr: ipv4Addr3.AddressWithPrefix.Address, + expectedWriteErr: tcpip.ErrNoRoute, + sockError: tcpip.SockError{ + Err: tcpip.ErrNoRoute, + ErrType: byte(header.ICMPv4DstUnreachable), + ErrCode: byte(header.ICMPv4HostUnreachable), + ErrOrigin: tcpip.SockExtErrorOriginICMP, + Dst: tcpip.FullAddress{ + NIC: host1NICID, + Addr: ipv4Addr3.AddressWithPrefix.Address, + Port: 1234, + }, + Offender: tcpip.FullAddress{ + NIC: host1NICID, + Addr: ipv4Addr1.AddressWithPrefix.Address, + }, + NetProto: ipv4.ProtocolNumber, + }, + }, + { + name: "IPv6 without resolvable remote", + netProto: ipv6.ProtocolNumber, + remoteAddr: ipv6Addr3.AddressWithPrefix.Address, + expectedWriteErr: tcpip.ErrNoRoute, + sockError: tcpip.SockError{ + Err: tcpip.ErrNoRoute, + ErrType: byte(header.ICMPv6DstUnreachable), + ErrCode: byte(header.ICMPv6AddressUnreachable), + ErrOrigin: tcpip.SockExtErrorOriginICMP6, + Dst: tcpip.FullAddress{ + NIC: host1NICID, + Addr: ipv6Addr3.AddressWithPrefix.Address, + Port: 1234, + }, + Offender: tcpip.FullAddress{ + NIC: host1NICID, + Addr: ipv6Addr1.AddressWithPrefix.Address, + }, + NetProto: ipv6.ProtocolNumber, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + } + + host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID) + + var listenerWQ waiter.Queue + listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, test.netProto, &listenerWQ) + if err != nil { + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, test.netProto, err) + } + defer listenerEP.Close() + + listenerAddr := tcpip.FullAddress{Port: 1234} + if err := listenerEP.Bind(listenerAddr); err != nil { + t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err) + } + + if err := listenerEP.Listen(1); err != nil { + t.Fatalf("listenerEP.Listen(1): %s", err) + } + + var clientWQ waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&we, waiter.EventOut|waiter.EventErr) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, test.netProto, &clientWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, test.netProto, err) + } + defer clientEP.Close() + + sockOpts := clientEP.SocketOptions() + sockOpts.SetRecvError(true) + + remoteAddr := listenerAddr + remoteAddr.Addr = test.remoteAddr + if err := clientEP.Connect(remoteAddr); err != tcpip.ErrConnectStarted { + t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, tcpip.ErrConnectStarted) + } + + // Wait for an error due to link resolution failing, or the endpoint to be + // writable. + <-ch + var wOpts tcpip.WriteOptions + if n, err := clientEP.Write(tcpip.SlicePayload(nil), wOpts); err != test.expectedWriteErr { + t.Errorf("got clientEP.Write(nil, %#v) = (%d, %s), want = (_, %s)", wOpts, n, err, test.expectedWriteErr) + } + + if test.expectedWriteErr == nil { + return + } + + sockErr := sockOpts.DequeueErr() + if sockErr == nil { + t.Fatalf("got sockOpts.DequeueErr() = nil, want = non-nil") + } + + sockErrCmpOpts := []cmp.Option{ + cmpopts.IgnoreUnexported(tcpip.SockError{}), + cmp.Comparer(func(a, b *tcpip.Error) bool { + // tcpip.Error holds an unexported field but the errors netstack uses + // are pre defined so we can simply compare pointers. + return a == b + }), + // Ignore the payload since we do not know the TCP seq/ack numbers. + checker.IgnoreCmpPath( + "Payload", + ), + } + + if addr, err := clientEP.GetLocalAddress(); err != nil { + t.Fatalf("clientEP.GetLocalAddress(): %s", err) + } else { + test.sockError.Offender.Port = addr.Port + } + if diff := cmp.Diff(&test.sockError, sockErr, sockErrCmpOpts...); diff != "" { + t.Errorf("socket error mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index cb6169cfc..a59f25cc3 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -232,12 +232,12 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { Port: localPort, }, } - n, _, err := sep.Write(tcpip.SlicePayload(data), wopts) + n, err := sep.Write(tcpip.SlicePayload(data), wopts) if err != nil { t.Fatalf("sep.Write(_, _): %s", err) } if want := int64(len(data)); n != want { - t.Fatalf("got sep.Write(_, _) = (%d, _, nil), want = (%d, _, nil)", n, want) + t.Fatalf("got sep.Write(_, _) = (%d, nil), want = (%d, nil)", n, want) } var buf bytes.Buffer diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index b42375695..eabc87938 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -587,10 +587,10 @@ func TestReuseAddrAndBroadcast(t *testing.T) { }, } data := tcpip.SlicePayload([]byte{byte(i), 2, 3, 4}) - if n, _, err := wep.ep.Write(data, writeOpts); err != nil { + if n, err := wep.ep.Write(data, writeOpts); err != nil { t.Fatalf("eps[%d].Write(_, _): %s", i, err) } else if want := int64(len(data)); n != want { - t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want) + t.Fatalf("got eps[%d].Write(_, _) = (%d, nil), want = (%d, nil)", i, n, want) } for j, rep := range eps { diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index 52cf89b54..76f7f54c6 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -197,10 +197,10 @@ func TestLocalPing(t *testing.T) { payload := tcpip.SlicePayload(test.icmpBuf(t)) var wOpts tcpip.WriteOptions - if n, _, err := ep.Write(payload, wOpts); err != nil { + if n, err := ep.Write(payload, wOpts); err != nil { t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) } else if n != int64(len(payload)) { - t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload)) + t.Fatalf("got ep.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", payload, wOpts, n, len(payload)) } // Wait for the endpoint to become readable. @@ -335,14 +335,14 @@ func TestLocalUDP(t *testing.T) { wOpts := tcpip.WriteOptions{ To: &serverAddr, } - if n, _, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr { - t.Fatalf("got client.Write(%#v, %#v) = (%d, _, %s_), want = (_, _, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr) + if n, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr { + t.Fatalf("got client.Write(%#v, %#v) = (%d, %s), want = (_, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr) } else if subTest.expectedWriteErr != nil { // Nothing else to test if we expected not to be able to send the // UDP packet. return } else if n != int64(len(clientPayload)) { - t.Fatalf("got client.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", clientPayload, wOpts, n, len(clientPayload)) + t.Fatalf("got client.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", clientPayload, wOpts, n, len(clientPayload)) } } @@ -382,10 +382,10 @@ func TestLocalUDP(t *testing.T) { wOpts := tcpip.WriteOptions{ To: &clientAddr, } - if n, _, err := server.Write(serverPayload, wOpts); err != nil { + if n, err := server.Write(serverPayload, wOpts); err != nil { t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err) } else if n != int64(len(serverPayload)) { - t.Fatalf("got server.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", serverPayload, wOpts, n, len(serverPayload)) + t.Fatalf("got server.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", serverPayload, wOpts, n, len(serverPayload)) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index c32fe5c4f..87277fbd3 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -236,8 +236,8 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { - n, ch, err := e.write(p, opts) +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { + n, err := e.write(p, opts) switch err { case nil: e.stats.PacketsSent.Increment() @@ -247,8 +247,6 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c e.stats.WriteErrors.WriteClosed.Increment() case tcpip.ErrInvalidEndpointState: e.stats.WriteErrors.InvalidEndpointState.Increment() - case tcpip.ErrNoLinkAddress: - e.stats.SendErrors.NoLinkAddr.Increment() case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: // Errors indicating any problem with IP routing of the packet. e.stats.SendErrors.NoRoute.Increment() @@ -256,13 +254,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // For all other errors when writing to the network layer. e.stats.SendErrors.SendToNetworkFailed.Increment() } - return n, ch, err + return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { - return 0, nil, tcpip.ErrInvalidOptionValue + return 0, tcpip.ErrInvalidOptionValue } to := opts.To @@ -272,14 +270,14 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If we've shutdown with SHUT_WR we are in an invalid state for sending. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return 0, nil, tcpip.ErrClosedForSend + return 0, tcpip.ErrClosedForSend } // Prepare for write. for { retry, err := e.prepareForWrite(to) if err != nil { - return 0, nil, err + return 0, err } if !retry { @@ -294,7 +292,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c nicID := to.NIC if e.BindNICID != 0 { if nicID != 0 && nicID != e.BindNICID { - return 0, nil, tcpip.ErrNoRoute + return 0, tcpip.ErrNoRoute } nicID = e.BindNICID @@ -302,31 +300,22 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { - return 0, nil, err + return 0, err } // Find the endpoint. r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */) if err != nil { - return 0, nil, err + return 0, err } defer r.Release() route = r } - if route.IsResolutionRequired() { - if ch, err := route.Resolve(nil); err != nil { - if err == tcpip.ErrWouldBlock { - return 0, ch, tcpip.ErrNoLinkAddress - } - return 0, nil, err - } - } - v, err := p.FullPayload() if err != nil { - return 0, nil, err + return 0, err } switch e.NetProto { @@ -338,10 +327,10 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } if err != nil { - return 0, nil, err + return 0, err } - return int64(len(v)), nil, nil + return int64(len(v)), nil } // SetSockOpt sets a socket option. diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 3ab060751..c3b3b8d34 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -207,9 +207,9 @@ func (ep *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpi return res, nil } -func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { // TODO(gvisor.dev/issue/173): Implement. - return 0, nil, tcpip.ErrInvalidOptionValue + return 0, tcpip.ErrInvalidOptionValue } // Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index dd260535f..425bcf3ee 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -234,20 +234,20 @@ func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip } // Write implements tcpip.Endpoint.Write. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { // We can create, but not write to, unassociated IPv6 endpoints. if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber { - return 0, nil, tcpip.ErrInvalidOptionValue + return 0, tcpip.ErrInvalidOptionValue } if opts.To != nil { // Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint. if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize { - return 0, nil, tcpip.ErrInvalidOptionValue + return 0, tcpip.ErrInvalidOptionValue } } - n, ch, err := e.write(p, opts) + n, err := e.write(p, opts) switch err { case nil: e.stats.PacketsSent.Increment() @@ -257,8 +257,6 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c e.stats.WriteErrors.WriteClosed.Increment() case tcpip.ErrInvalidEndpointState: e.stats.WriteErrors.InvalidEndpointState.Increment() - case tcpip.ErrNoLinkAddress: - e.stats.SendErrors.NoLinkAddr.Increment() case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: // Errors indicating any problem with IP routing of the packet. e.stats.SendErrors.NoRoute.Increment() @@ -266,25 +264,25 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // For all other errors when writing to the network layer. e.stats.SendErrors.SendToNetworkFailed.Increment() } - return n, ch, err + return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. if opts.More { - return 0, nil, tcpip.ErrInvalidOptionValue + return 0, tcpip.ErrInvalidOptionValue } e.mu.RLock() defer e.mu.RUnlock() if e.closed { - return 0, nil, tcpip.ErrInvalidEndpointState + return 0, tcpip.ErrInvalidEndpointState } payloadBytes, err := p.FullPayload() if err != nil { - return 0, nil, err + return 0, err } // If this is an unassociated socket and callee provided a nonzero @@ -292,7 +290,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c if e.ops.GetHeaderIncluded() { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { - return 0, nil, tcpip.ErrInvalidOptionValue + return 0, tcpip.ErrInvalidOptionValue } dstAddr := ip.DestinationAddress() // Update dstAddr with the address in the IP header, unless @@ -313,7 +311,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If the user doesn't specify a destination, they should have // connected to another address. if !e.connected { - return 0, nil, tcpip.ErrDestinationRequired + return 0, tcpip.ErrDestinationRequired } return e.finishWrite(payloadBytes, e.route) @@ -323,42 +321,30 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // goes through a different NIC than the endpoint was bound to. nic := opts.To.NIC if e.bound && nic != 0 && nic != e.BindNICID { - return 0, nil, tcpip.ErrNoRoute + return 0, tcpip.ErrNoRoute } // Find the route to the destination. If BindAddress is 0, // FindRoute will choose an appropriate source address. route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) if err != nil { - return 0, nil, err + return 0, err } - n, ch, err := e.finishWrite(payloadBytes, route) + n, err := e.finishWrite(payloadBytes, route) route.Release() - return n, ch, err + return n, err } // finishWrite writes the payload to a route. It resolves the route if // necessary. It's really just a helper to make defer unnecessary in Write. -func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) { - // We may need to resolve the route (match a link layer address to the - // network address). If that requires blocking (e.g. to use ARP), - // return a channel on which the caller can wait. - if route.IsResolutionRequired() { - if ch, err := route.Resolve(nil); err != nil { - if err == tcpip.ErrWouldBlock { - return 0, ch, tcpip.ErrNoLinkAddress - } - return 0, nil, err - } - } - +func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, *tcpip.Error) { if e.ops.GetHeaderIncluded() { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.View(payloadBytes).ToVectorisedView(), }) if err := route.WriteHeaderIncludedPacket(pkt); err != nil { - return 0, nil, err + return 0, err } } else { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -371,11 +357,11 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS, }, pkt); err != nil { - return 0, nil, err + return 0, err } } - return int64(len(payloadBytes)), nil, nil + return int64(len(payloadBytes)), nil } // Disconnect implements tcpip.Endpoint.Disconnect. diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 2d96a65bd..6921de0f1 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -210,7 +210,6 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i if err != nil { return nil, err } - route.ResolveWith(s.remoteLinkAddr) n := newEndpoint(l.stack, netProto, queue) n.ops.SetV6Only(l.v6Only) @@ -306,10 +305,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q // Initialize and start the handshake. h := ep.newPassiveHandshake(isn, irs, opts, deferAccept) - if err := h.start(); err != nil { - l.cleanupFailedHandshake(h) - return nil, err - } + h.start() return h, nil } @@ -573,7 +569,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er return err } defer route.Release() - route.ResolveWith(s.remoteLinkAddr) // Send SYN without window scaling because we currently // don't encode this information in the cookie. diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index a00ef97c6..6cdbb8bee 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -53,7 +53,6 @@ const ( wakerForNotification = iota wakerForNewSegment wakerForResend - wakerForResolution ) const ( @@ -460,66 +459,9 @@ func (h *handshake) processSegments() *tcpip.Error { return nil } -func (h *handshake) resolveRoute() *tcpip.Error { - // Set up the wakers. - var s sleep.Sleeper - resolutionWaker := &sleep.Waker{} - s.AddWaker(resolutionWaker, wakerForResolution) - s.AddWaker(&h.ep.notificationWaker, wakerForNotification) - defer s.Done() - - // Initial action is to resolve route. - index := wakerForResolution - attemptedResolution := false - for { - switch index { - case wakerForResolution: - if _, err := h.ep.route.Resolve(resolutionWaker.Assert); err != tcpip.ErrWouldBlock { - if err != nil { - h.ep.stats.SendErrors.NoRoute.Increment() - } - // Either success (err == nil) or failure. - return err - } - if attemptedResolution { - h.ep.stats.SendErrors.NoLinkAddr.Increment() - return tcpip.ErrNoLinkAddress - } - attemptedResolution = true - // Resolution not completed. Keep trying... - - case wakerForNotification: - n := h.ep.fetchNotifications() - if n¬ifyClose != 0 { - return tcpip.ErrAborted - } - if n¬ifyDrain != 0 { - close(h.ep.drainDone) - h.ep.mu.Unlock() - <-h.ep.undrain - h.ep.mu.Lock() - } - if n¬ifyError != 0 { - return h.ep.lastErrorLocked() - } - } - - // Wait for notification. - h.ep.mu.Unlock() - index, _ = s.Fetch(true /* block */) - h.ep.mu.Lock() - } -} - -// start resolves the route if necessary and sends the first -// SYN/SYN-ACK. -func (h *handshake) start() *tcpip.Error { - if h.ep.route.IsResolutionRequired() { - if err := h.resolveRoute(); err != nil { - return err - } - } - +// start sends the first SYN/SYN-ACK. It does not block, even if link address +// resolution is required. +func (h *handshake) start() { h.startTime = time.Now() h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) var sackEnabled tcpip.TCPSACKEnabled @@ -560,7 +502,6 @@ func (h *handshake) start() *tcpip.Error { ack: h.ackNum, rcvWnd: h.rcvWnd, }, synOpts) - return nil } // complete completes the TCP 3-way handshake initiated by h.start(). diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 25b180fa5..a4508e871 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1507,7 +1507,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { } // Write writes data to the endpoint's peer. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { // Linux completely ignores any address passed to sendto(2) for TCP sockets // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More // and opts.EndOfRecord are also ignored. @@ -1520,7 +1520,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c e.sndBufMu.Unlock() e.UnlockUser() e.stats.WriteErrors.WriteClosed.Increment() - return 0, nil, err + return 0, err } // We can release locks while copying data. @@ -1541,7 +1541,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c e.sndBufMu.Unlock() e.UnlockUser() } - return 0, nil, perr + return 0, perr } if !opts.Atomic { @@ -1555,7 +1555,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c e.sndBufMu.Unlock() e.UnlockUser() e.stats.WriteErrors.WriteClosed.Increment() - return 0, nil, err + return 0, err } // Discard any excess data copied in due to avail being reduced due @@ -1575,7 +1575,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // Do the work inline. e.handleWrite() e.UnlockUser() - return int64(len(v)), nil, nil + return int64(len(v)), nil } // selectWindowLocked returns the new window without checking for shrinking or scaling @@ -2325,68 +2325,17 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } if run { - if err := e.startMainLoop(handshake); err != nil { - return err - } - } - - return tcpip.ErrConnectStarted -} - -// startMainLoop sends the initial SYN and starts the main loop for the -// endpoint. -func (e *endpoint) startMainLoop(handshake bool) *tcpip.Error { - preloop := func() *tcpip.Error { if handshake { h := e.newHandshake() e.setEndpointState(StateSynSent) - if err := h.start(); err != nil { - e.lastErrorMu.Lock() - e.lastError = err - e.lastErrorMu.Unlock() - - e.setEndpointState(StateError) - e.hardError = err - - // Call cleanupLocked to free up any reservations. - e.cleanupLocked() - return err - } + h.start() } e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() - return nil - } - - if e.route.IsResolutionRequired() { - // If the endpoint is closed between releasing e.mu and the goroutine below - // acquiring it, make sure that cleanup is deferred to the new goroutine. e.workerRunning = true - - // Sending the initial SYN may block due to route resolution; do it in a - // separate goroutine to avoid blocking the syscall goroutine. - go func() { // S/R-SAFE: will be drained before save. - e.mu.Lock() - if err := preloop(); err != nil { - e.workerRunning = false - e.mu.Unlock() - return - } - e.mu.Unlock() - _ = e.protocolMainLoop(handshake, nil) - }() - return nil + go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. } - // No route resolution is required, so we can send the initial SYN here without - // blocking. This will hopefully reduce overall latency by overlapping time - // spent waiting for a SYN-ACK and time spent spinning up a new goroutine - // for the main loop. - if err := preloop(); err != nil { - return err - } - e.workerRunning = true - go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. - return nil + return tcpip.ErrConnectStarted } // ConnectEndpoint is not supported. @@ -2779,6 +2728,9 @@ func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt case stack.ControlNoRoute: e.onICMPError(tcpip.ErrNoRoute, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt) + case stack.ControlAddressUnreachable: + e.onICMPError(tcpip.ErrNoRoute, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6AddressUnreachable), extra, pkt) + case stack.ControlNetworkUnreachable: e.onICMPError(tcpip.ErrNetworkUnreachable, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt) } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index c9e194f82..1720370c9 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -222,7 +222,6 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error return err } defer route.Release() - route.ResolveWith(s.remoteLinkAddr) // Get the seqnum from the packet if the ack flag is set. seq := seqnum.Value(0) diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index c5a6d2fba..7cca4def5 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -49,11 +49,10 @@ type segment struct { // TODO(gvisor.dev/issue/4417): Hold a stack.PacketBuffer instead of // individual members for link/network packet info. - srcAddr tcpip.Address - dstAddr tcpip.Address - netProto tcpip.NetworkProtocolNumber - nicID tcpip.NICID - remoteLinkAddr tcpip.LinkAddress + srcAddr tcpip.Address + dstAddr tcpip.Address + netProto tcpip.NetworkProtocolNumber + nicID tcpip.NICID data buffer.VectorisedView `state:".(buffer.VectorisedView)"` @@ -89,13 +88,12 @@ type segment struct { func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { netHdr := pkt.Network() s := &segment{ - refCnt: 1, - id: id, - srcAddr: netHdr.SourceAddress(), - dstAddr: netHdr.DestinationAddress(), - netProto: pkt.NetworkProtocolNumber, - nicID: pkt.NICID, - remoteLinkAddr: pkt.SourceLinkAddress(), + refCnt: 1, + id: id, + srcAddr: netHdr.SourceAddress(), + dstAddr: netHdr.DestinationAddress(), + netProto: pkt.NetworkProtocolNumber, + nicID: pkt.NICID, } s.data = pkt.Data.Clone(s.views[:]) s.hdr = header.TCP(pkt.TransportHeader().View()) @@ -128,7 +126,6 @@ func (s *segment) clone() *segment { window: s.window, netProto: s.netProto, nicID: s.nicID, - remoteLinkAddr: s.remoteLinkAddr, rcvdTime: s.rcvdTime, xmitTime: s.xmitTime, xmitCount: s.xmitCount, diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go index b9993ce1a..f7aaee23f 100644 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go @@ -49,7 +49,7 @@ func TestFastRecovery(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -214,7 +214,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -256,7 +256,7 @@ func TestCongestionAvoidance(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -361,7 +361,7 @@ func TestCubicCongestionAvoidance(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -470,11 +470,11 @@ func TestRetransmit(t *testing.T) { // Write all the data in two shots. Packets will only be written at the // MTU size though. half := data[:len(data)/2] - if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } half = data[len(data)/2:] - if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index 9818ffa0f..342eb5eb8 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -68,7 +68,7 @@ func TestRACKUpdate(t *testing.T) { // Write the data. xmitTime = time.Now() - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -120,7 +120,7 @@ func TestRACKDetectReorder(t *testing.T) { } // Write the data. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -151,7 +151,7 @@ func sendAndReceive(t *testing.T, c *context.Context, numPackets int) buffer.Vie } // Write the data. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index faf0c0ad7..6635bb815 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -402,7 +402,7 @@ func TestSACKRecovery(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index aeceee7e0..729bf7ef5 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1348,7 +1348,7 @@ func TestTOSV4(t *testing.T) { view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -1397,7 +1397,7 @@ func TestTrafficClassV6(t *testing.T) { view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -1977,7 +1977,11 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { // Keep the payload size < segment overhead and such that it is a multiple // of the window scaled value. This enables the test to perform equality // checks on the incoming receive window. - payload := generateRandomPayload(t, (tcp.SegSize-1)&(1<<c.RcvdWindowScale)) + payloadSize := 1 << c.RcvdWindowScale + if payloadSize >= tcp.SegSize { + t.Fatalf("payload size of %d is not less than the segment overhead of %d", payloadSize, tcp.SegSize) + } + payload := generateRandomPayload(t, payloadSize) payloadLen := seqnum.Size(len(payload)) iss := seqnum.Value(789) seqNum := iss.Add(1) @@ -2173,7 +2177,7 @@ func TestSimpleSend(t *testing.T) { view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2214,8 +2218,7 @@ func TestZeroWindowSend(t *testing.T) { view := buffer.NewView(len(data)) copy(view, data) - _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) - if err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2283,7 +2286,7 @@ func TestScaledWindowConnect(t *testing.T) { view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2315,7 +2318,7 @@ func TestNonScaledWindowConnect(t *testing.T) { view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2389,7 +2392,7 @@ func TestScaledWindowAccept(t *testing.T) { view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2463,7 +2466,7 @@ func TestNonScaledWindowAccept(t *testing.T) { view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2626,7 +2629,7 @@ func TestSegmentMerging(t *testing.T) { // anymore packets from going out. for i := 0; i < tcp.InitialCwnd; i++ { view := buffer.NewViewFromBytes([]byte{0}) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2637,7 +2640,7 @@ func TestSegmentMerging(t *testing.T) { for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { allData = append(allData, data...) view := buffer.NewViewFromBytes(data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2707,7 +2710,7 @@ func TestDelay(t *testing.T) { for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { allData = append(allData, data...) view := buffer.NewViewFromBytes(data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2754,7 +2757,7 @@ func TestUndelay(t *testing.T) { allData := [][]byte{{0}, {1, 2, 3}} for i, data := range allData { view := buffer.NewViewFromBytes(data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2838,7 +2841,7 @@ func TestMSSNotDelayed(t *testing.T) { allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)} for i, data := range allData { view := buffer.NewViewFromBytes(data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2889,7 +2892,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3321,7 +3324,7 @@ func TestSendOnResetConnection(t *testing.T) { // Try to write. view := buffer.NewView(10) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset { t.Fatalf("got c.EP.Write(...) = %s, want = %s", err, tcpip.ErrConnectionReset) } } @@ -3344,7 +3347,7 @@ func TestMaxRetransmitsTimeout(t *testing.T) { c.WQ.EventRegister(&waitEntry, waiter.EventHUp) defer c.WQ.EventUnregister(&waitEntry) - _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) + _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) if err != nil { t.Fatalf("Write failed: %s", err) } @@ -3401,7 +3404,7 @@ func TestMaxRTO(t *testing.T) { c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) - _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) + _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) if err != nil { t.Fatalf("Write failed: %s", err) } @@ -3450,7 +3453,7 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) { t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err) } - if _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } pkt := c.GetPacket() @@ -3588,7 +3591,7 @@ func TestFinWithNoPendingData(t *testing.T) { // Write something out, and have it acknowledged. view := buffer.NewView(10) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3661,7 +3664,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { // any of them. view := buffer.NewView(10) for i := tcp.InitialCwnd; i > 0; i-- { - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } } @@ -3747,7 +3750,7 @@ func TestFinWithPendingData(t *testing.T) { // Write something out, and acknowledge it to get cwnd to 2. view := buffer.NewView(10) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3773,7 +3776,7 @@ func TestFinWithPendingData(t *testing.T) { }) // Write new data, but don't acknowledge it. - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3834,7 +3837,7 @@ func TestFinWithPartialAck(t *testing.T) { // Write something out, and acknowledge it to get cwnd to 2. Also send // FIN from the test side. view := buffer.NewView(10) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3871,7 +3874,7 @@ func TestFinWithPartialAck(t *testing.T) { ) // Write new data, but don't acknowledge it. - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3978,7 +3981,7 @@ func scaledSendWindow(t *testing.T, scale uint8) { // Send some data. Check that it's capped by the window size. view := buffer.NewView(65535) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -4035,9 +4038,6 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) { if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 { t.Errorf("got EP stats Stats.SendErrors.NoRoute = %d, want = %d", got, 0) } - if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 { - t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %d, want = %d", got, 0) - } } func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { @@ -4607,7 +4607,7 @@ func TestSelfConnect(t *testing.T) { data := []byte{1, 2, 3} view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -4785,7 +4785,7 @@ func TestPathMTUDiscovery(t *testing.T) { data[i] = byte(i) } - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -5074,7 +5074,7 @@ func TestKeepalive(t *testing.T) { // Send some data and wait before ACKing it. Keepalives should be disabled // during this period. view := buffer.NewView(3) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -5903,9 +5903,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { // Now verify that the TCP socket is usable and in a connected state. data := "Don't panic" - _, _, err = newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) - - if err != nil { + if _, err := newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -7103,7 +7101,7 @@ func TestTCPCloseWithData(t *testing.T) { view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -7202,7 +7200,7 @@ func TestTCPUserTimeout(t *testing.T) { // Send some data and wait before ACKing it. view := buffer.NewView(3) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 9e02d467d..88fb054bb 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -154,7 +154,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Unexpected error from Write: %s", err) } @@ -217,7 +217,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Unexpected error from Write: %s", err) } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 5d87f3a7e..520a0ac9d 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -417,8 +417,8 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { - n, ch, err := e.write(p, opts) +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { + n, err := e.write(p, opts) switch err { case nil: e.stats.PacketsSent.Increment() @@ -428,8 +428,6 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c e.stats.WriteErrors.WriteClosed.Increment() case tcpip.ErrInvalidEndpointState: e.stats.WriteErrors.InvalidEndpointState.Increment() - case tcpip.ErrNoLinkAddress: - e.stats.SendErrors.NoLinkAddr.Increment() case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: // Errors indicating any problem with IP routing of the packet. e.stats.SendErrors.NoRoute.Increment() @@ -437,17 +435,17 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // For all other errors when writing to the network layer. e.stats.SendErrors.SendToNetworkFailed.Increment() } - return n, ch, err + return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { if err := e.LastError(); err != nil { - return 0, nil, err + return 0, err } // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { - return 0, nil, tcpip.ErrInvalidOptionValue + return 0, tcpip.ErrInvalidOptionValue } to := opts.To @@ -463,14 +461,14 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If we've shutdown with SHUT_WR we are in an invalid state for sending. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return 0, nil, tcpip.ErrClosedForSend + return 0, tcpip.ErrClosedForSend } // Prepare for write. for { retry, err := e.prepareForWrite(to) if err != nil { - return 0, nil, err + return 0, err } if !retry { @@ -486,7 +484,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c nicID := to.NIC if e.BindNICID != 0 { if nicID != 0 && nicID != e.BindNICID { - return 0, nil, tcpip.ErrNoRoute + return 0, tcpip.ErrNoRoute } nicID = e.BindNICID @@ -494,17 +492,17 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c if to.Port == 0 { // Port 0 is an invalid port to send to. - return 0, nil, tcpip.ErrInvalidEndpointState + return 0, tcpip.ErrInvalidEndpointState } dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { - return 0, nil, err + return 0, err } r, _, err := e.connectRoute(nicID, dst, netProto) if err != nil { - return 0, nil, err + return 0, err } defer r.Release() @@ -513,21 +511,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() { - return 0, nil, tcpip.ErrBroadcastDisabled - } - - if route.IsResolutionRequired() { - if ch, err := route.Resolve(nil); err != nil { - if err == tcpip.ErrWouldBlock { - return 0, ch, tcpip.ErrNoLinkAddress - } - return 0, nil, err - } + return 0, tcpip.ErrBroadcastDisabled } v, err := p.FullPayload() if err != nil { - return 0, nil, err + return 0, err } if len(v) > header.UDPMaximumPacketSize { // Payload can't possibly fit in a packet. @@ -545,7 +534,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c v, ) } - return 0, nil, tcpip.ErrMessageTooLong + return 0, tcpip.ErrMessageTooLong } ttl := e.ttl @@ -575,9 +564,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read // locking is prohibited. if err := sendUDP(route, buffer.View(v).ToVectorisedView(), localPort, dstPort, ttl, useDefaultTTL, sendTOS, owner, noChecksum); err != nil { - return 0, nil, err + return 0, err } - return int64(len(v)), nil, nil + return int64(len(v)), nil } // OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet. diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index d7fc21f11..49e673d58 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -75,7 +75,6 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, if err != nil { return nil, err } - route.ResolveWith(r.pkt.SourceLinkAddress()) ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil { diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index c8da173f1..52403ed78 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -967,7 +967,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) payload := buffer.View(newPayload()) - _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, }) c.checkEndpointWriteStats(1, epstats, gotErr) @@ -1008,7 +1008,7 @@ func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View } } payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) + n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) if err != nil { c.t.Fatalf("Write failed: %s", err) } @@ -1184,7 +1184,7 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) { To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}, } payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) + n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) if err != nil { c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err) } @@ -2317,8 +2317,6 @@ func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportE want.WriteErrors.WriteClosed.IncrementBy(incr) case tcpip.ErrInvalidEndpointState: want.WriteErrors.InvalidEndpointState.IncrementBy(incr) - case tcpip.ErrNoLinkAddress: - want.SendErrors.NoLinkAddr.IncrementBy(incr) case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: want.SendErrors.NoRoute.IncrementBy(incr) default: @@ -2510,20 +2508,20 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { expectedErrWithoutBcastOpt = nil } - if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { - t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt) + if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt) } ep.SocketOptions().SetBroadcast(true) - if n, _, err := ep.Write(data, opts); err != nil { - t.Fatalf("got ep.Write(_, _) = (%d, _, %s), want = (_, _, nil)", n, err) + if n, err := ep.Write(data, opts); err != nil { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) } ep.SocketOptions().SetBroadcast(false) - if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { - t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt) + if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt) } }) } diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go index fdd416b5e..a35c7ffa6 100644 --- a/pkg/test/testutil/testutil.go +++ b/pkg/test/testutil/testutil.go @@ -83,11 +83,10 @@ func ConfigureExePath() error { // TmpDir returns the absolute path to a writable directory that can be used as // scratch by the test. func TmpDir() string { - dir := os.Getenv("TEST_TMPDIR") - if dir == "" { - dir = "/tmp" + if dir, ok := os.LookupEnv("TEST_TMPDIR"); ok { + return dir } - return dir + return "/tmp" } // Logger is a simple logging wrapper. @@ -111,6 +110,30 @@ func (d DefaultLogger) Logf(fmt string, args ...interface{}) { log.Printf(fmt, args...) } +// multiLogger logs to multiple Loggers. +type multiLogger []Logger + +// Name implements Logger.Name. +func (m multiLogger) Name() string { + names := make([]string, len(m)) + for i, l := range m { + names[i] = l.Name() + } + return strings.Join(names, "+") +} + +// Logf implements Logger.Logf. +func (m multiLogger) Logf(fmt string, args ...interface{}) { + for _, l := range m { + l.Logf(fmt, args...) + } +} + +// NewMultiLogger returns a new Logger that logs on multiple Loggers. +func NewMultiLogger(loggers ...Logger) Logger { + return multiLogger(loggers) +} + // Cmd is a simple wrapper. type Cmd struct { logger Logger @@ -519,7 +542,7 @@ func IsStatic(filename string) (bool, error) { // // See https://docs.bazel.build/versions/master/test-encyclopedia.html#role-of-the-test-runner. func TouchShardStatusFile() error { - if statusFile := os.Getenv("TEST_SHARD_STATUS_FILE"); statusFile != "" { + if statusFile, ok := os.LookupEnv("TEST_SHARD_STATUS_FILE"); ok { cmd := exec.Command("touch", statusFile) if b, err := cmd.CombinedOutput(); err != nil { return fmt.Errorf("touch %q failed:\n output: %s\n error: %s", statusFile, string(b), err.Error()) @@ -541,8 +564,9 @@ func TestIndicesForShard(numTests int) ([]int, error) { shardTotal = 1 ) - indexStr, totalStr := os.Getenv("TEST_SHARD_INDEX"), os.Getenv("TEST_TOTAL_SHARDS") - if indexStr != "" && totalStr != "" { + indexStr, indexOk := os.LookupEnv("TEST_SHARD_INDEX") + totalStr, totalOk := os.LookupEnv("TEST_TOTAL_SHARDS") + if indexOk && totalOk { // Parse index and total to ints. var err error shardIndex, err = strconv.Atoi(indexStr) |