From 863e11ac4d6a49787cd5e5f6fe1cd771d0ceb100 Mon Sep 17 00:00:00 2001 From: Rahat Mahmood Date: Thu, 29 Aug 2019 14:29:43 -0700 Subject: Implement /proc/net/udp. PiperOrigin-RevId: 266229756 --- pkg/sentry/socket/epsocket/epsocket.go | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) (limited to 'pkg/sentry/socket/epsocket') diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 635042263..def29646e 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -27,12 +27,14 @@ package epsocket import ( "bytes" "math" + "reflect" "sync" "syscall" "time" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" @@ -52,6 +54,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -2421,7 +2424,8 @@ func (s *SocketOperations) State() uint32 { return 0 } - if !s.isPacketBased() { + switch { + case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP: // TCP socket. switch tcp.EndpointState(s.Endpoint.State()) { case tcp.StateEstablished: @@ -2450,9 +2454,26 @@ func (s *SocketOperations) State() uint32 { // Internal or unknown state. return 0 } + case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP: + // UDP socket. + switch udp.EndpointState(s.Endpoint.State()) { + case udp.StateInitial, udp.StateBound, udp.StateClosed: + return linux.TCP_CLOSE + case udp.StateConnected: + return linux.TCP_ESTABLISHED + default: + return 0 + } + case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6: + // TODO(b/112063468): Export states for ICMP sockets. + case s.skType == linux.SOCK_RAW: + // TODO(b/112063468): Export states for raw sockets. + default: + // Unknown transport protocol, how did we make this socket? + log.Warningf("Unknown transport protocol for an existing socket: family=%v, type=%v, protocol=%v, internal type %v", s.family, s.skType, s.protocol, reflect.TypeOf(s.Endpoint).Elem()) + return 0 } - // TODO(b/112063468): Export states for UDP, ICMP, and raw sockets. return 0 } -- cgit v1.2.3 From 502c47f7a70a088213faf27b60e6f62ece4dd765 Mon Sep 17 00:00:00 2001 From: Fabricio Voznika Date: Fri, 30 Aug 2019 17:17:45 -0700 Subject: Return correct buffer size for ioctl(socket, FIONREAD) Ioctl was returning just the buffer size from epsocket.endpoint and it was not considering data from epsocket.SocketOperations that was read from the endpoint, but not yet sent to the caller. PiperOrigin-RevId: 266485461 --- pkg/sentry/socket/epsocket/epsocket.go | 22 +++++++++++++++++++++- test/syscalls/linux/tcp_socket.cc | 21 ++++++++++++++++++++- 2 files changed, 41 insertions(+), 2 deletions(-) (limited to 'pkg/sentry/socket/epsocket') diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index def29646e..0e37ce61b 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -2104,7 +2104,8 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, // SIOCGSTAMP is implemented by epsocket rather than all commonEndpoint // sockets. // TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP. - if int(args[1].Int()) == syscall.SIOCGSTAMP { + switch args[1].Int() { + case syscall.SIOCGSTAMP: s.readMu.Lock() defer s.readMu.Unlock() if !s.timestampValid { @@ -2116,6 +2117,25 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, AddressSpaceActive: true, }) return 0, err + + case linux.TIOCINQ: + v, terr := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption) + if terr != nil { + return 0, syserr.TranslateNetstackError(terr).ToError() + } + + // Add bytes removed from the endpoint but not yet sent to the caller. + v += len(s.readView) + + if v > math.MaxInt32 { + v = math.MaxInt32 + } + + // Copy result to user-space. + _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ + AddressSpaceActive: true, + }) + return 0, err } return Ioctl(ctx, s.Endpoint, io, args) diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 8f4d3f386..bfa031bce 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -579,7 +579,7 @@ TEST_P(TcpSocketTest, TcpInq) { if (size == sizeof(buf)) { break; } - usleep(10000); + absl::SleepFor(absl::Milliseconds(10)); } struct msghdr msg = {}; @@ -610,6 +610,25 @@ TEST_P(TcpSocketTest, TcpInq) { } } +TEST_P(TcpSocketTest, Tiocinq) { + char buf[1024]; + size_t size = sizeof(buf); + ASSERT_THAT(RetryEINTR(write)(s_, buf, size), SyscallSucceedsWithValue(size)); + + uint32_t seed = time(nullptr); + const size_t max_chunk = size / 10; + while (size > 0) { + size_t chunk = (rand_r(&seed) % max_chunk) + 1; + ssize_t read = RetryEINTR(recvfrom)(t_, buf, chunk, 0, nullptr, nullptr); + ASSERT_THAT(read, SyscallSucceeds()); + size -= read; + + int inq = 0; + ASSERT_THAT(ioctl(t_, TIOCINQ, &inq), SyscallSucceeds()); + ASSERT_EQ(inq, size); + } +} + TEST_P(TcpSocketTest, TcpSCMPriority) { char buf[1024]; ASSERT_THAT(RetryEINTR(write)(s_, buf, sizeof(buf)), -- cgit v1.2.3 From 7c6ab6a219f37a1d4c18ced4a602458fcf363f85 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Thu, 12 Sep 2019 17:42:14 -0700 Subject: Implement splice methods for pipes and sockets. This also allows the tee(2) implementation to be enabled, since dup can now be properly supported via WriteTo. Note that this change necessitated some minor restructoring with the fs.FileOperations splice methods. If the *fs.File is passed through directly, then only public API methods are accessible, which will deadlock immediately since the locking is already done by fs.Splice. Instead, we pass through an abstract io.Reader or io.Writer, which elide locks and use the underlying fs.FileOperations directly. PiperOrigin-RevId: 268805207 --- pkg/sentry/fs/file.go | 23 +++- pkg/sentry/fs/file_operations.go | 9 +- pkg/sentry/fs/file_overlay.go | 9 +- pkg/sentry/fs/fsutil/file.go | 6 +- pkg/sentry/fs/inotify.go | 5 +- pkg/sentry/fs/splice.go | 162 +++++++++++++------------- pkg/sentry/kernel/pipe/buffer.go | 25 ++++ pkg/sentry/kernel/pipe/pipe.go | 82 +++++++++++--- pkg/sentry/kernel/pipe/reader_writer.go | 76 ++++++++++++- pkg/sentry/socket/epsocket/epsocket.go | 134 +++++++++++++++++++--- pkg/sentry/syscalls/linux/linux64.go | 4 +- pkg/sentry/syscalls/linux/sys_splice.go | 86 +++++++------- pkg/tcpip/header/udp.go | 5 + pkg/tcpip/stack/transport_test.go | 4 +- pkg/tcpip/tcpip.go | 48 ++++---- pkg/tcpip/transport/icmp/endpoint.go | 4 +- pkg/tcpip/transport/raw/endpoint.go | 7 +- pkg/tcpip/transport/tcp/endpoint.go | 68 ++++++----- pkg/tcpip/transport/udp/endpoint.go | 14 +-- test/syscalls/linux/BUILD | 3 + test/syscalls/linux/pipe.cc | 14 +++ test/syscalls/linux/sendfile.cc | 69 ++++++++++++ test/syscalls/linux/splice.cc | 194 +++++++++++++++++++++++++------- 23 files changed, 770 insertions(+), 281 deletions(-) (limited to 'pkg/sentry/socket/epsocket') diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index bb8117f89..c0a6e884b 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -515,6 +515,11 @@ type lockedReader struct { // File is the file to read from. File *File + + // Offset is the offset to start at. + // + // This applies only to Read, not ReadAt. + Offset int64 } // Read implements io.Reader.Read. @@ -522,7 +527,8 @@ func (r *lockedReader) Read(buf []byte) (int, error) { if r.Ctx.Interrupted() { return 0, syserror.ErrInterrupted } - n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.File.offset) + n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.Offset) + r.Offset += n return int(n), err } @@ -544,11 +550,21 @@ type lockedWriter struct { // File is the file to write to. File *File + + // Offset is the offset to start at. + // + // This applies only to Write, not WriteAt. + Offset int64 } // Write implements io.Writer.Write. func (w *lockedWriter) Write(buf []byte) (int, error) { - return w.WriteAt(buf, w.File.offset) + if w.Ctx.Interrupted() { + return 0, syserror.ErrInterrupted + } + n, err := w.WriteAt(buf, w.Offset) + w.Offset += int64(n) + return int(n), err } // WriteAt implements io.Writer.WriteAt. @@ -562,6 +578,9 @@ func (w *lockedWriter) WriteAt(buf []byte, offset int64) (int, error) { // io.Copy, since our own Write interface does not have this same // contract. Enforce that here. for written < len(buf) { + if w.Ctx.Interrupted() { + return written, syserror.ErrInterrupted + } var n int64 n, err = w.File.FileOperations.Write(w.Ctx, w.File, usermem.BytesIOSequence(buf[written:]), offset+int64(written)) if n > 0 { diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go index d86f5bf45..b88303f17 100644 --- a/pkg/sentry/fs/file_operations.go +++ b/pkg/sentry/fs/file_operations.go @@ -15,6 +15,8 @@ package fs import ( + "io" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -105,8 +107,11 @@ type FileOperations interface { // on the destination, following by a buffered copy with standard Read // and Write operations. // + // If dup is set, the data should be duplicated into the destination + // and retained. + // // The same preconditions as Read apply. - WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (int64, error) + WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (int64, error) // Write writes src to file at offset and returns the number of bytes // written which must be greater than or equal to 0. Like Read, file @@ -126,7 +131,7 @@ type FileOperations interface { // source. See WriteTo for details regarding how this is called. // // The same preconditions as Write apply; FileFlags.Write must be set. - ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (int64, error) + ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (int64, error) // Fsync writes buffered modifications of file and/or flushes in-flight // operations to backing storage based on syncType. The range to sync is diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go index 9820f0b13..225e40186 100644 --- a/pkg/sentry/fs/file_overlay.go +++ b/pkg/sentry/fs/file_overlay.go @@ -15,6 +15,7 @@ package fs import ( + "io" "sync" "gvisor.dev/gvisor/pkg/refs" @@ -268,9 +269,9 @@ func (f *overlayFileOperations) Read(ctx context.Context, file *File, dst userme } // WriteTo implements FileOperations.WriteTo. -func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (n int64, err error) { +func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (n int64, err error) { err = f.onTop(ctx, file, func(file *File, ops FileOperations) error { - n, err = ops.WriteTo(ctx, file, dst, opts) + n, err = ops.WriteTo(ctx, file, dst, count, dup) return err // Will overwrite itself. }) return @@ -285,9 +286,9 @@ func (f *overlayFileOperations) Write(ctx context.Context, file *File, src userm } // ReadFrom implements FileOperations.ReadFrom. -func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (n int64, err error) { +func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (n int64, err error) { // See above; f.upper must be non-nil. - return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, opts) + return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, count) } // Fsync implements FileOperations.Fsync. diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go index 626b9126a..fc5b3b1a1 100644 --- a/pkg/sentry/fs/fsutil/file.go +++ b/pkg/sentry/fs/fsutil/file.go @@ -15,6 +15,8 @@ package fsutil import ( + "io" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -228,12 +230,12 @@ func (FileNoIoctl) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArgu type FileNoSplice struct{} // WriteTo implements fs.FileOperations.WriteTo. -func (FileNoSplice) WriteTo(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) { +func (FileNoSplice) WriteTo(context.Context, *fs.File, io.Writer, int64, bool) (int64, error) { return 0, syserror.ENOSYS } // ReadFrom implements fs.FileOperations.ReadFrom. -func (FileNoSplice) ReadFrom(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) { +func (FileNoSplice) ReadFrom(context.Context, *fs.File, io.Reader, int64) (int64, error) { return 0, syserror.ENOSYS } diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go index c7f4e2d13..ba3e0233d 100644 --- a/pkg/sentry/fs/inotify.go +++ b/pkg/sentry/fs/inotify.go @@ -15,6 +15,7 @@ package fs import ( + "io" "sync" "sync/atomic" @@ -172,7 +173,7 @@ func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ i } // WriteTo implements FileOperations.WriteTo. -func (*Inotify) WriteTo(context.Context, *File, *File, SpliceOpts) (int64, error) { +func (*Inotify) WriteTo(context.Context, *File, io.Writer, int64, bool) (int64, error) { return 0, syserror.ENOSYS } @@ -182,7 +183,7 @@ func (*Inotify) Fsync(context.Context, *File, int64, int64, SyncType) error { } // ReadFrom implements FileOperations.ReadFrom. -func (*Inotify) ReadFrom(context.Context, *File, *File, SpliceOpts) (int64, error) { +func (*Inotify) ReadFrom(context.Context, *File, io.Reader, int64) (int64, error) { return 0, syserror.ENOSYS } diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go index eed1c2854..b03b7f836 100644 --- a/pkg/sentry/fs/splice.go +++ b/pkg/sentry/fs/splice.go @@ -18,7 +18,6 @@ import ( "io" "sync/atomic" - "gvisor.dev/gvisor/pkg/secio" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/syserror" ) @@ -33,146 +32,131 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, } // Check whether or not the objects being sliced are stream-oriented - // (i.e. pipes or sockets). If yes, we elide checks and offset locks. - srcPipe := IsPipe(src.Dirent.Inode.StableAttr) || IsSocket(src.Dirent.Inode.StableAttr) - dstPipe := IsPipe(dst.Dirent.Inode.StableAttr) || IsSocket(dst.Dirent.Inode.StableAttr) + // (i.e. pipes or sockets). For all stream-oriented files and files + // where a specific offiset is not request, we acquire the file mutex. + // This has two important side effects. First, it provides the standard + // protection against concurrent writes that would mutate the offset. + // Second, it prevents Splice deadlocks. Only internal anonymous files + // implement the ReadFrom and WriteTo methods directly, and since such + // anonymous files are referred to by a unique fs.File object, we know + // that the file mutex takes strict precedence over internal locks. + // Since we enforce lock ordering here, we can't deadlock by using + // using a file in two different splice operations simultaneously. + srcPipe := !IsRegular(src.Dirent.Inode.StableAttr) + dstPipe := !IsRegular(dst.Dirent.Inode.StableAttr) + dstAppend := !dstPipe && dst.Flags().Append + srcLock := srcPipe || !opts.SrcOffset + dstLock := dstPipe || !opts.DstOffset || dstAppend - if !dstPipe && !opts.DstOffset && !srcPipe && !opts.SrcOffset { + switch { + case srcLock && dstLock: switch { case dst.UniqueID < src.UniqueID: // Acquire dst first. if !dst.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() if !src.mu.Lock(ctx) { + dst.mu.Unlock() return 0, syserror.ErrInterrupted } - defer src.mu.Unlock() case dst.UniqueID > src.UniqueID: // Acquire src first. if !src.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer src.mu.Unlock() if !dst.mu.Lock(ctx) { + src.mu.Unlock() return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() case dst.UniqueID == src.UniqueID: // Acquire only one lock; it's the same file. This is a // bit of a edge case, but presumably it's possible. if !dst.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() + srcLock = false // Only need one unlock. } // Use both offsets (locked). opts.DstStart = dst.offset opts.SrcStart = src.offset - } else if !dstPipe && !opts.DstOffset { + case dstLock: // Acquire only dst. if !dst.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() opts.DstStart = dst.offset // Safe: locked. - } else if !srcPipe && !opts.SrcOffset { + case srcLock: // Acquire only src. if !src.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer src.mu.Unlock() opts.SrcStart = src.offset // Safe: locked. } - // Check append-only mode and the limit. - if !dstPipe { + var err error + if dstAppend { unlock := dst.Dirent.Inode.lockAppendMu(dst.Flags().Append) defer unlock() - if dst.Flags().Append { - if opts.DstOffset { - // We need to acquire the lock. - if !dst.mu.Lock(ctx) { - return 0, syserror.ErrInterrupted - } - defer dst.mu.Unlock() - } - // Figure out the appropriate offset to use. - if err := dst.offsetForAppend(ctx, &opts.DstStart); err != nil { - return 0, err - } - } + // Figure out the appropriate offset to use. + err = dst.offsetForAppend(ctx, &opts.DstStart) + } + if err == nil && !dstPipe { // Enforce file limits. limit, ok := dst.checkLimit(ctx, opts.DstStart) switch { case ok && limit == 0: - return 0, syserror.ErrExceedsFileSizeLimit + err = syserror.ErrExceedsFileSizeLimit case ok && limit < opts.Length: opts.Length = limit // Cap the write. } } + if err != nil { + if dstLock { + dst.mu.Unlock() + } + if srcLock { + src.mu.Unlock() + } + return 0, err + } - // Attempt to do a WriteTo; this is likely the most efficient. - // - // The underlying implementation may be able to donate buffers. - newOpts := SpliceOpts{ - Length: opts.Length, - SrcStart: opts.SrcStart, - SrcOffset: !srcPipe, - Dup: opts.Dup, - DstStart: opts.DstStart, - DstOffset: !dstPipe, + // Construct readers and writers for the splice. This is used to + // provide a safer locking path for the WriteTo/ReadFrom operations + // (since they will otherwise go through public interface methods which + // conflict with locking done above), and simplifies the fallback path. + w := &lockedWriter{ + Ctx: ctx, + File: dst, + Offset: opts.DstStart, } - n, err := src.FileOperations.WriteTo(ctx, src, dst, newOpts) - if n == 0 && err != nil { - // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also - // be more efficient than a copy if buffers are cached or readily - // available. (It's unlikely that they can actually be donate - n, err = dst.FileOperations.ReadFrom(ctx, dst, src, newOpts) + r := &lockedReader{ + Ctx: ctx, + File: src, + Offset: opts.SrcStart, } - if n == 0 && err != nil { - // If we've failed up to here, and at least one of the sources - // is a pipe or socket, then we can't properly support dup. - // Return an error indicating that this operation is not - // supported. - if (srcPipe || dstPipe) && newOpts.Dup { - return 0, syserror.EINVAL - } - // We failed to splice the files. But that's fine; we just fall - // back to a slow path in this case. This copies without doing - // any mode changes, so should still be more efficient. - var ( - r io.Reader - w io.Writer - ) - fw := &lockedWriter{ - Ctx: ctx, - File: dst, - } - if newOpts.DstOffset { - // Use the provided offset. - w = secio.NewOffsetWriter(fw, newOpts.DstStart) - } else { - // Writes will proceed with no offset. - w = fw - } - fr := &lockedReader{ - Ctx: ctx, - File: src, - } - if newOpts.SrcOffset { - // Limit to the given offset and length. - r = io.NewSectionReader(fr, opts.SrcStart, opts.Length) - } else { - // Limit just to the given length. - r = &io.LimitedReader{fr, opts.Length} - } + // Attempt to do a WriteTo; this is likely the most efficient. + n, err := src.FileOperations.WriteTo(ctx, src, w, opts.Length, opts.Dup) + if n == 0 && err != nil && err != syserror.ErrWouldBlock && !opts.Dup { + // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also be + // more efficient than a copy if buffers are cached or readily + // available. (It's unlikely that they can actually be donated). + n, err = dst.FileOperations.ReadFrom(ctx, dst, r, opts.Length) + } - // Copy between the two. - n, err = io.Copy(w, r) + // Support one last fallback option, but only if at least one of + // the source and destination are regular files. This is because + // if we block at some point, we could lose data. If the source is + // not a pipe then reading is not destructive; if the destination + // is a regular file, then it is guaranteed not to block writing. + if n == 0 && err != nil && err != syserror.ErrWouldBlock && !opts.Dup && (!dstPipe || !srcPipe) { + // Fallback to an in-kernel copy. + n, err = io.Copy(w, &io.LimitedReader{ + R: r, + N: opts.Length, + }) } // Update offsets, if required. @@ -185,5 +169,13 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, } } + // Drop locks. + if dstLock { + dst.mu.Unlock() + } + if srcLock { + src.mu.Unlock() + } + return n, err } diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go index 69ef2a720..95bee2d37 100644 --- a/pkg/sentry/kernel/pipe/buffer.go +++ b/pkg/sentry/kernel/pipe/buffer.go @@ -15,6 +15,7 @@ package pipe import ( + "io" "sync" "gvisor.dev/gvisor/pkg/sentry/safemem" @@ -67,6 +68,17 @@ func (b *buffer) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { return n, err } +// WriteFromReader writes to the buffer from an io.Reader. +func (b *buffer) WriteFromReader(r io.Reader, count int64) (int64, error) { + dst := b.data[b.write:] + if count < int64(len(dst)) { + dst = b.data[b.write:][:count] + } + n, err := r.Read(dst) + b.write += n + return int64(n), err +} + // ReadToBlocks implements safemem.Reader.ReadToBlocks. func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.read:b.write])) @@ -75,6 +87,19 @@ func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { return n, err } +// ReadToWriter reads from the buffer into an io.Writer. +func (b *buffer) ReadToWriter(w io.Writer, count int64, dup bool) (int64, error) { + src := b.data[b.read:b.write] + if count < int64(len(src)) { + src = b.data[b.read:][:count] + } + n, err := w.Write(src) + if !dup { + b.read += n + } + return int64(n), err +} + // bufferPool is a pool for buffers. var bufferPool = sync.Pool{ New: func() interface{} { diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 247e2928e..93b50669f 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -173,13 +172,24 @@ func (p *Pipe) Open(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) *fs.F } } +type readOps struct { + // left returns the bytes remaining. + left func() int64 + + // limit limits subsequence reads. + limit func(int64) + + // read performs the actual read operation. + read func(*buffer) (int64, error) +} + // read reads data from the pipe into dst and returns the number of bytes // read, or returns ErrWouldBlock if the pipe is empty. // // Precondition: this pipe must have readers. -func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) { +func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) { // Don't block for a zero-length read even if the pipe is empty. - if dst.NumBytes() == 0 { + if ops.left() == 0 { return 0, nil } @@ -196,12 +206,12 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) } // Limit how much we consume. - if dst.NumBytes() > p.size { - dst = dst.TakeFirst64(p.size) + if ops.left() > p.size { + ops.limit(p.size) } done := int64(0) - for dst.NumBytes() > 0 { + for ops.left() > 0 { // Pop the first buffer. first := p.data.Front() if first == nil { @@ -209,10 +219,9 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) } // Copy user data. - n, err := dst.CopyOutFrom(ctx, first) + n, err := ops.read(first) done += int64(n) p.size -= n - dst = dst.DropFirst64(n) // Empty buffer? if first.Empty() { @@ -230,12 +239,57 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) return done, nil } +// dup duplicates all data from this pipe into the given writer. +// +// There is no blocking behavior implemented here. The writer may propagate +// some blocking error. All the writes must be complete writes. +func (p *Pipe) dup(ctx context.Context, ops readOps) (int64, error) { + p.mu.Lock() + defer p.mu.Unlock() + + // Is the pipe empty? + if p.size == 0 { + if !p.HasWriters() { + // See above. + return 0, nil + } + return 0, syserror.ErrWouldBlock + } + + // Limit how much we consume. + if ops.left() > p.size { + ops.limit(p.size) + } + + done := int64(0) + for buf := p.data.Front(); buf != nil; buf = buf.Next() { + n, err := ops.read(buf) + done += n + if err != nil { + return done, err + } + } + + return done, nil +} + +type writeOps struct { + // left returns the bytes remaining. + left func() int64 + + // limit should limit subsequent writes. + limit func(int64) + + // write should write to the provided buffer. + write func(*buffer) (int64, error) +} + // write writes data from sv into the pipe and returns the number of bytes // written. If no bytes are written because the pipe is full (or has less than // atomicIOBytes free capacity), write returns ErrWouldBlock. // // Precondition: this pipe must have writers. -func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) { +func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) { p.mu.Lock() defer p.mu.Unlock() @@ -246,17 +300,16 @@ func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) // POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be // atomic, but requires no atomicity for writes larger than this. - wanted := src.NumBytes() + wanted := ops.left() if avail := p.max - p.size; wanted > avail { if wanted <= p.atomicIOBytes { return 0, syserror.ErrWouldBlock } - // Limit to the available capacity. - src = src.TakeFirst64(avail) + ops.limit(avail) } done := int64(0) - for src.NumBytes() > 0 { + for ops.left() > 0 { // Need a new buffer? last := p.data.Back() if last == nil || last.Full() { @@ -266,10 +319,9 @@ func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) } // Copy user data. - n, err := src.CopyInTo(ctx, last) + n, err := ops.write(last) done += int64(n) p.size += n - src = src.DropFirst64(n) // Handle errors. if err != nil { diff --git a/pkg/sentry/kernel/pipe/reader_writer.go b/pkg/sentry/kernel/pipe/reader_writer.go index f69dbf27b..7c307f013 100644 --- a/pkg/sentry/kernel/pipe/reader_writer.go +++ b/pkg/sentry/kernel/pipe/reader_writer.go @@ -15,6 +15,7 @@ package pipe import ( + "io" "math" "syscall" @@ -55,7 +56,45 @@ func (rw *ReaderWriter) Release() { // Read implements fs.FileOperations.Read. func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { - n, err := rw.Pipe.read(ctx, dst) + n, err := rw.Pipe.read(ctx, readOps{ + left: func() int64 { + return dst.NumBytes() + }, + limit: func(l int64) { + dst = dst.TakeFirst64(l) + }, + read: func(buf *buffer) (int64, error) { + n, err := dst.CopyOutFrom(ctx, buf) + dst = dst.DropFirst64(n) + return n, err + }, + }) + if n > 0 { + rw.Pipe.Notify(waiter.EventOut) + } + return n, err +} + +// WriteTo implements fs.FileOperations.WriteTo. +func (rw *ReaderWriter) WriteTo(ctx context.Context, _ *fs.File, w io.Writer, count int64, dup bool) (int64, error) { + ops := readOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + read: func(buf *buffer) (int64, error) { + n, err := buf.ReadToWriter(w, count, dup) + count -= n + return n, err + }, + } + if dup { + // There is no notification for dup operations. + return rw.Pipe.dup(ctx, ops) + } + n, err := rw.Pipe.read(ctx, ops) if n > 0 { rw.Pipe.Notify(waiter.EventOut) } @@ -64,7 +103,40 @@ func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequ // Write implements fs.FileOperations.Write. func (rw *ReaderWriter) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { - n, err := rw.Pipe.write(ctx, src) + n, err := rw.Pipe.write(ctx, writeOps{ + left: func() int64 { + return src.NumBytes() + }, + limit: func(l int64) { + src = src.TakeFirst64(l) + }, + write: func(buf *buffer) (int64, error) { + n, err := src.CopyInTo(ctx, buf) + src = src.DropFirst64(n) + return n, err + }, + }) + if n > 0 { + rw.Pipe.Notify(waiter.EventIn) + } + return n, err +} + +// ReadFrom implements fs.FileOperations.WriteTo. +func (rw *ReaderWriter) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { + n, err := rw.Pipe.write(ctx, writeOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + write: func(buf *buffer) (int64, error) { + n, err := buf.WriteFromReader(r, count) + count -= n + return n, err + }, + }) if n > 0 { rw.Pipe.Notify(waiter.EventIn) } diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 0e37ce61b..3e05e40fe 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -26,6 +26,7 @@ package epsocket import ( "bytes" + "io" "math" "reflect" "sync" @@ -227,7 +228,6 @@ type SocketOperations struct { fsutil.FileNoopFlush `state:"nosave"` fsutil.FileNoFsync `state:"nosave"` fsutil.FileNoMMap `state:"nosave"` - fsutil.FileNoSplice `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` socket.SendReceiveTimeout *waiter.Queue @@ -412,17 +412,58 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS return int64(n), nil } -// ioSequencePayload implements tcpip.Payload. It copies user memory bytes on demand -// based on the requested size. +// WriteTo implements fs.FileOperations.WriteTo. +func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) { + s.readMu.Lock() + defer s.readMu.Unlock() + + // Copy as much data as possible. + done := int64(0) + for count > 0 { + // This may return a blocking error. + if err := s.fetchReadView(); err != nil { + return done, err.ToError() + } + + // Write to the underlying file. + n, err := dst.Write(s.readView) + done += int64(n) + count -= int64(n) + if dup { + // That's all we support for dup. This is generally + // supported by any Linux system calls, but the + // expectation is that now a caller will call read to + // actually remove these bytes from the socket. + return done, nil + } + + // Drop that part of the view. + s.readView.TrimFront(n) + if err != nil { + return done, err + } + } + + return done, nil +} + +// ioSequencePayload implements tcpip.Payload. +// +// t copies user memory bytes on demand based on the requested size. type ioSequencePayload struct { ctx context.Context src usermem.IOSequence } -// Get implements tcpip.Payload. -func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) { - if size > i.Size() { - size = i.Size() +// FullPayload implements tcpip.Payloader.FullPayload +func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) { + return i.Payload(int(i.src.NumBytes())) +} + +// Payload implements tcpip.Payloader.Payload. +func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) { + if max := int(i.src.NumBytes()); size > max { + size = max } v := buffer.NewView(size) if _, err := i.src.CopyIn(i.ctx, v); err != nil { @@ -431,11 +472,6 @@ func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) { return v, nil } -// Size implements tcpip.Payload. -func (i *ioSequencePayload) Size() int { - return int(i.src.NumBytes()) -} - // DropFirst drops the first n bytes from underlying src. func (i *ioSequencePayload) DropFirst(n int) { i.src = i.src.DropFirst(int(n)) @@ -469,6 +505,76 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO return int64(n), nil } +// readerPayload implements tcpip.Payloader. +// +// It allocates a view and reads from a reader on-demand, based on available +// capacity in the endpoint. +type readerPayload struct { + ctx context.Context + r io.Reader + count int64 + err error +} + +// FullPayload implements tcpip.Payloader.FullPayload. +func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) { + return r.Payload(int(r.count)) +} + +// Payload implements tcpip.Payloader.Payload. +func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) { + if size > int(r.count) { + size = int(r.count) + } + v := buffer.NewView(size) + n, err := r.r.Read(v) + if n > 0 { + // We ignore the error here. It may re-occur on subsequent + // reads, but for now we can enqueue some amount of data. + r.count -= int64(n) + return v[:n], nil + } + if err == syserror.ErrWouldBlock { + return nil, tcpip.ErrWouldBlock + } else if err != nil { + r.err = err // Save for propation. + return nil, tcpip.ErrBadAddress + } + + // There is no data and no error. Return an error, which will propagate + // r.err, which will be nil. This is the desired result: (0, nil). + return nil, tcpip.ErrBadAddress +} + +// ReadFrom implements fs.FileOperations.ReadFrom. +func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { + f := &readerPayload{ctx: ctx, r: r, count: count} + n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) + if err == tcpip.ErrWouldBlock { + return 0, syserror.ErrWouldBlock + } + + if resCh != nil { + t := ctx.(*kernel.Task) + if err := t.Block(resCh); err != nil { + return 0, syserr.FromError(err).ToError() + } + + n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{ + // Reads may be destructive but should be very fast, + // so we can't release the lock while copying data. + Atomic: true, + }) + } + if err == tcpip.ErrWouldBlock { + return n, syserror.ErrWouldBlock + } else if err != nil { + return int64(n), f.err // Propagate error. + } + + return int64(n), nil +} + // Readiness returns a mask of ready events for socket s. func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { r := s.Endpoint.Readiness(mask) @@ -2060,7 +2166,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] n, _, err = s.Endpoint.Write(v, opts) } dontWait := flags&linux.MSG_DONTWAIT != 0 - if err == nil && (n >= int64(v.Size()) || dontWait) { + if err == nil && (n >= v.src.NumBytes() || dontWait) { // Complete write. return int(n), nil } @@ -2085,7 +2191,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] return 0, syserr.TranslateNetstackError(err) } - if err == nil && v.Size() == 0 || err != nil && err != tcpip.ErrWouldBlock { + if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock { return int(total), nil } diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index ed996ba51..150999fb8 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -320,8 +320,8 @@ var AMD64 = &kernel.SyscallTable{ 272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), - 275: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) - 276: syscalls.ErrorWithEvent("tee", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 275: syscalls.Supported("splice", Splice), + 276: syscalls.Supported("tee", Tee), 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil), 278: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) 279: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly) diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index 8a98fedcb..f0a292f2f 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -29,9 +29,8 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB total int64 n int64 err error - ch chan struct{} - inW bool - outW bool + inCh chan struct{} + outCh chan struct{} ) for opts.Length > 0 { n, err = fs.Splice(t, outFile, inFile, opts) @@ -43,35 +42,33 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB break } - // Are we a registered waiter? - if ch == nil { - ch = make(chan struct{}, 1) - } - if !inW && !inFile.Flags().NonBlocking { - w, _ := waiter.NewChannelEntry(ch) - inFile.EventRegister(&w, EventMaskRead) - defer inFile.EventUnregister(&w) - inW = true // Registered. - } else if !outW && !outFile.Flags().NonBlocking { - w, _ := waiter.NewChannelEntry(ch) - outFile.EventRegister(&w, EventMaskWrite) - defer outFile.EventUnregister(&w) - outW = true // Registered. - } - - // Was anything registered? If no, everything is non-blocking. - if !inW && !outW { - break - } - - if (!inW || inFile.Readiness(EventMaskRead) != 0) && (!outW || outFile.Readiness(EventMaskWrite) != 0) { - // Something became ready, try again without blocking. - continue + // Note that the blocking behavior here is a bit different than the + // normal pattern. Because we need to have both data to read and data + // to write simultaneously, we actually explicitly block on both of + // these cases in turn before returning to the splice operation. + if inFile.Readiness(EventMaskRead) == 0 { + if inCh == nil { + inCh = make(chan struct{}, 1) + inW, _ := waiter.NewChannelEntry(inCh) + inFile.EventRegister(&inW, EventMaskRead) + defer inFile.EventUnregister(&inW) + continue // Need to refresh readiness. + } + if err = t.Block(inCh); err != nil { + break + } } - - // Block until there's data. - if err = t.Block(ch); err != nil { - break + if outFile.Readiness(EventMaskWrite) == 0 { + if outCh == nil { + outCh = make(chan struct{}, 1) + outW, _ := waiter.NewChannelEntry(outCh) + outFile.EventRegister(&outW, EventMaskWrite) + defer outFile.EventUnregister(&outW) + continue // Need to refresh readiness. + } + if err = t.Block(outCh); err != nil { + break + } } } @@ -149,7 +146,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc Length: count, SrcOffset: true, SrcStart: offset, - }, false) + }, outFile.Flags().NonBlocking) // Copy out the new offset. if _, err := t.CopyOut(offsetAddr, n+offset); err != nil { @@ -159,7 +156,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Send data using splice. n, err = doSplice(t, outFile, inFile, fs.SpliceOpts{ Length: count, - }, false) + }, outFile.Flags().NonBlocking) } // We can only pass a single file to handleIOError, so pick inFile @@ -181,12 +178,6 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.EINVAL } - // Only non-blocking is meaningful. Note that unlike in Linux, this - // flag is applied consistently. We will have either fully blocking or - // non-blocking behavior below, regardless of the underlying files - // being spliced to. It's unclear if this is a bug or not yet. - nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0 - // Get files. outFile := t.GetFile(outFD) if outFile == nil { @@ -200,6 +191,13 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } defer inFile.DecRef() + // The operation is non-blocking if anything is non-blocking. + // + // N.B. This is a rather simplistic heuristic that avoids some + // poor edge case behavior since the exact semantics here are + // underspecified and vary between versions of Linux itself. + nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0) + // Construct our options. // // Note that exactly one of the underlying buffers must be a pipe. We @@ -257,7 +255,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } // Splice data. - n, err := doSplice(t, outFile, inFile, opts, nonBlocking) + n, err := doSplice(t, outFile, inFile, opts, nonBlock) // See above; inFile is chosen arbitrarily here. return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "splice", inFile) @@ -275,9 +273,6 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo return 0, nil, syserror.EINVAL } - // Only non-blocking is meaningful. - nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0 - // Get files. outFile := t.GetFile(outFD) if outFile == nil { @@ -301,11 +296,14 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo return 0, nil, syserror.EINVAL } + // The operation is non-blocking if anything is non-blocking. + nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0) + // Splice data. n, err := doSplice(t, outFile, inFile, fs.SpliceOpts{ Length: count, Dup: true, - }, nonBlocking) + }, nonBlock) // See above; inFile is chosen arbitrarily here. return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "tee", inFile) diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index c1f454805..74412c894 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -27,6 +27,11 @@ const ( udpChecksum = 6 ) +const ( + // UDPMaximumPacketSize is the largest possible UDP packet. + UDPMaximumPacketSize = 0xffff +) + // UDPFields contains the fields of a UDP packet. It is used to describe the // fields of a packet that needs to be encoded. type UDPFields struct { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 87d1e0d0d..847d02982 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -65,13 +65,13 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr return buffer.View{}, tcpip.ControlMessages{}, nil } -func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { if len(f.route.RemoteAddress) == 0 { return 0, nil, tcpip.ErrNoRoute } hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index ebf8a2d04..2534069ab 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -261,31 +261,34 @@ type FullAddress struct { Port uint16 } -// Payload provides an interface around data that is being sent to an endpoint. -// This allows the endpoint to request the amount of data it needs based on -// internal buffers without exposing them. 'p.Get(p.Size())' reads all the data. -type Payload interface { - // Get returns a slice containing exactly 'min(size, p.Size())' bytes. - Get(size int) ([]byte, *Error) - - // Size returns the payload size. - Size() int +// Payloader is an interface that provides data. +// +// This interface allows the endpoint to request the amount of data it needs +// based on internal buffers without exposing them. +type Payloader interface { + // FullPayload returns all available bytes. + FullPayload() ([]byte, *Error) + + // Payload returns a slice containing at most size bytes. + Payload(size int) ([]byte, *Error) } -// SlicePayload implements Payload on top of slices for convenience. +// SlicePayload implements Payloader for slices. +// +// This is typically used for tests. type SlicePayload []byte -// Get implements Payload. -func (s SlicePayload) Get(size int) ([]byte, *Error) { - if size > s.Size() { - size = s.Size() - } - return s[:size], nil +// FullPayload implements Payloader.FullPayload. +func (s SlicePayload) FullPayload() ([]byte, *Error) { + return s, nil } -// Size implements Payload. -func (s SlicePayload) Size() int { - return len(s) +// Payload implements Payloader.Payload. +func (s SlicePayload) Payload(size int) ([]byte, *Error) { + if size > len(s) { + size = len(s) + } + return s[:size], nil } // A ControlMessages contains socket control messages for IP sockets. @@ -338,7 +341,7 @@ type Endpoint interface { // ErrNoLinkAddress and a notification channel is returned for the caller to // block. Channel is closed once address resolution is complete (success or // not). The channel is only non-nil in this case. - Write(Payload, WriteOptions) (int64, <-chan struct{}, *Error) + Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error) // Peek reads data without consuming it from the endpoint. // @@ -432,6 +435,11 @@ type WriteOptions struct { // EndOfRecord has the same semantics as Linux's MSG_EOR. EndOfRecord bool + + // Atomic means that all data fetched from Payloader must be written to the + // endpoint. If Atomic is false, then data fetched from the Payloader may be + // discarded if available endpoint buffer space is unsufficient. + Atomic bool } // SockOpt represents socket options which values have the int type. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index e1f622af6..3db060384 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -204,7 +204,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -289,7 +289,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha } } - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 13e17e2a6..cf1c5c433 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -207,7 +207,7 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes } // Write implements tcpip.Endpoint.Write. -func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -220,9 +220,8 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64 return 0, nil, tcpip.ErrInvalidEndpointState } - payloadBytes, err := payload.Get(payload.Size()) + payloadBytes, err := p.FullPayload() if err != nil { - ep.mu.RUnlock() return 0, nil, err } @@ -230,7 +229,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64 // destination address, route using that address. if !ep.associated { ip := header.IPv4(payloadBytes) - if !ip.IsValid(payload.Size()) { + if !ip.IsValid(len(payloadBytes)) { ep.mu.RUnlock() return 0, nil, tcpip.ErrInvalidOptionValue } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ac927569a..dd931f88c 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -806,7 +806,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { } // Write writes data to the endpoint's peer. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // Linux completely ignores any address passed to sendto(2) for TCP sockets // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More // and opts.EndOfRecord are also ignored. @@ -821,47 +821,52 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha return 0, nil, err } - e.sndBufMu.Unlock() - e.mu.RUnlock() - - // Nothing to do if the buffer is empty. - if p.Size() == 0 { - return 0, nil, nil + // We can release locks while copying data. + // + // This is not possible if atomic is set, because we can't allow the + // available buffer space to be consumed by some other caller while we + // are copying data in. + if !opts.Atomic { + e.sndBufMu.Unlock() + e.mu.RUnlock() } - // Copy in memory without holding sndBufMu so that worker goroutine can - // make progress independent of this operation. - v, perr := p.Get(avail) - if perr != nil { + // Fetch data. + v, perr := p.Payload(avail) + if perr != nil || len(v) == 0 { + if opts.Atomic { // See above. + e.sndBufMu.Unlock() + e.mu.RUnlock() + } + // Note that perr may be nil if len(v) == 0. return 0, nil, perr } - e.mu.RLock() - e.sndBufMu.Lock() + if !opts.Atomic { // See above. + e.mu.RLock() + e.sndBufMu.Lock() - // Because we released the lock before copying, check state again - // to make sure the endpoint is still in a valid state for a - // write. - avail, err = e.isEndpointWritableLocked() - if err != nil { - e.sndBufMu.Unlock() - e.mu.RUnlock() - return 0, nil, err - } + // Because we released the lock before copying, check state again + // to make sure the endpoint is still in a valid state for a write. + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.mu.RUnlock() + return 0, nil, err + } - // Discard any excess data copied in due to avail being reduced due to a - // simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } } // Add data to the send queue. - l := len(v) s := newSegmentFromView(&e.route, e.id, v) - e.sndBufUsed += l - e.sndBufInQueue += seqnum.Size(l) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() // Release the endpoint lock to prevent deadlocks due to lock // order inversion when acquiring workMu. @@ -875,7 +880,8 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha // Let the protocol goroutine do the work. e.sndWaker.Assert() } - return int64(l), nil, nil + + return int64(len(v)), nil, nil } // Peek reads data without consuming it from the endpoint. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index dccb9a7eb..6ac7c067a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,7 +15,6 @@ package udp import ( - "math" "sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -277,17 +276,12 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue } - if p.Size() > math.MaxUint16 { - // Payload can't possibly fit in a packet. - return 0, nil, tcpip.ErrMessageTooLong - } - to := opts.To e.mu.RLock() @@ -370,10 +364,14 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha } } - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } + if len(v) > header.UDPMaximumPacketSize { + // Payload can't possibly fit in a packet. + return 0, nil, tcpip.ErrMessageTooLong + } ttl := route.DefaultTTL() if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) { diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 34057e3d0..df00d2c14 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1867,7 +1867,9 @@ cc_binary( "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", + "//test/util:thread_util", "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], ) @@ -1901,6 +1903,7 @@ cc_binary( "//test/util:test_util", "//test/util:thread_util", "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], ) diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc index 65afb90f3..10e2a6dfc 100644 --- a/test/syscalls/linux/pipe.cc +++ b/test/syscalls/linux/pipe.cc @@ -168,6 +168,20 @@ TEST_P(PipeTest, Write) { EXPECT_EQ(wbuf, rbuf); } +TEST_P(PipeTest, WritePage) { + SKIP_IF(!CreateBlocking()); + + std::vector wbuf(kPageSize); + RandomizeBuffer(wbuf.data(), wbuf.size()); + std::vector rbuf(wbuf.size()); + + ASSERT_THAT(write(wfd_.get(), wbuf.data(), wbuf.size()), + SyscallSucceedsWithValue(wbuf.size())); + ASSERT_THAT(read(rfd_.get(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(rbuf.size())); + EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), wbuf.size()), 0); +} + TEST_P(PipeTest, NonBlocking) { SKIP_IF(!CreateNonBlocking()); diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc index 9167ab066..4502e7fb4 100644 --- a/test/syscalls/linux/sendfile.cc +++ b/test/syscalls/linux/sendfile.cc @@ -19,9 +19,12 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/strings/string_view.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" +#include "test/util/thread_util.h" namespace gvisor { namespace testing { @@ -442,6 +445,72 @@ TEST(SendFileTest, SendToNotARegularFile) { EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, 0), SyscallFailsWithErrno(EINVAL)); } + +TEST(SendFileTest, SendPipeWouldBlock) { + // Create temp file. + constexpr char kData[] = + "The fool doth think he is wise, but the wise man knows himself to be a " + "fool."; + constexpr int kDataSize = sizeof(kData) - 1; + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); + + // Open the input file as read only. + const FileDescriptor inf = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Setup the output named pipe. + int fds[2]; + ASSERT_THAT(pipe2(fds, O_NONBLOCK), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Fill up the pipe's buffer. + int pipe_size = -1; + ASSERT_THAT(pipe_size = fcntl(wfd.get(), F_GETPIPE_SZ), SyscallSucceeds()); + std::vector buf(2 * pipe_size); + ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(pipe_size)); + + EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST(SendFileTest, SendPipeBlocks) { + // Create temp file. + constexpr char kData[] = + "The fault, dear Brutus, is not in our stars, but in ourselves."; + constexpr int kDataSize = sizeof(kData) - 1; + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); + + // Open the input file as read only. + const FileDescriptor inf = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Setup the output named pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Fill up the pipe's buffer. + int pipe_size = -1; + ASSERT_THAT(pipe_size = fcntl(wfd.get(), F_GETPIPE_SZ), SyscallSucceeds()); + std::vector buf(pipe_size); + ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(pipe_size)); + + ScopedThread t([&]() { + absl::SleepFor(absl::Milliseconds(100)); + ASSERT_THAT(read(rfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(pipe_size)); + }); + + EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize), + SyscallSucceedsWithValue(kDataSize)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc index e25f264f6..85232cb1f 100644 --- a/test/syscalls/linux/splice.cc +++ b/test/syscalls/linux/splice.cc @@ -14,12 +14,16 @@ #include #include +#include #include +#include #include #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/strings/string_view.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -36,23 +40,23 @@ TEST(SpliceTest, TwoRegularFiles) { const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); // Open the input file as read only. - const FileDescriptor inf = + const FileDescriptor in_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); // Open the output file as write only. - const FileDescriptor outf = + const FileDescriptor out_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY)); // Verify that it is rejected as expected; regardless of offsets. loff_t in_offset = 0; loff_t out_offset = 0; - EXPECT_THAT(splice(inf.get(), &in_offset, outf.get(), &out_offset, 1, 0), + EXPECT_THAT(splice(in_fd.get(), &in_offset, out_fd.get(), &out_offset, 1, 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(splice(inf.get(), nullptr, outf.get(), &out_offset, 1, 0), + EXPECT_THAT(splice(in_fd.get(), nullptr, out_fd.get(), &out_offset, 1, 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(splice(inf.get(), &in_offset, outf.get(), nullptr, 1, 0), + EXPECT_THAT(splice(in_fd.get(), &in_offset, out_fd.get(), nullptr, 1, 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(splice(inf.get(), nullptr, outf.get(), nullptr, 1, 0), + EXPECT_THAT(splice(in_fd.get(), nullptr, out_fd.get(), nullptr, 1, 0), SyscallFailsWithErrno(EINVAL)); } @@ -75,8 +79,6 @@ TEST(SpliceTest, SamePipe) { } TEST(TeeTest, SamePipe) { - SKIP_IF(IsRunningOnGvisor()); - // Create a new pipe. int fds[2]; ASSERT_THAT(pipe(fds), SyscallSucceeds()); @@ -95,11 +97,9 @@ TEST(TeeTest, SamePipe) { } TEST(TeeTest, RegularFile) { - SKIP_IF(IsRunningOnGvisor()); - // Open some file. const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor inf = + const FileDescriptor in_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); // Create a new pipe. @@ -109,9 +109,9 @@ TEST(TeeTest, RegularFile) { const FileDescriptor wfd(fds[1]); // Attempt to tee from the file. - EXPECT_THAT(tee(inf.get(), wfd.get(), kPageSize, 0), + EXPECT_THAT(tee(in_fd.get(), wfd.get(), kPageSize, 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(tee(rfd.get(), inf.get(), kPageSize, 0), + EXPECT_THAT(tee(rfd.get(), in_fd.get(), kPageSize, 0), SyscallFailsWithErrno(EINVAL)); } @@ -142,7 +142,7 @@ TEST(SpliceTest, FromEventFD) { constexpr uint64_t kEventFDValue = 1; int efd; ASSERT_THAT(efd = eventfd(kEventFDValue, 0), SyscallSucceeds()); - const FileDescriptor inf(efd); + const FileDescriptor in_fd(efd); // Create a new pipe. int fds[2]; @@ -152,7 +152,7 @@ TEST(SpliceTest, FromEventFD) { // Splice 8-byte eventfd value to pipe. constexpr int kEventFDSize = 8; - EXPECT_THAT(splice(inf.get(), nullptr, wfd.get(), nullptr, kEventFDSize, 0), + EXPECT_THAT(splice(in_fd.get(), nullptr, wfd.get(), nullptr, kEventFDSize, 0), SyscallSucceedsWithValue(kEventFDSize)); // Contents should be equal. @@ -166,7 +166,7 @@ TEST(SpliceTest, FromEventFD) { TEST(SpliceTest, FromEventFDOffset) { int efd; ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds()); - const FileDescriptor inf(efd); + const FileDescriptor in_fd(efd); // Create a new pipe. int fds[2]; @@ -179,7 +179,7 @@ TEST(SpliceTest, FromEventFDOffset) { // This is not allowed because eventfd doesn't support pread. constexpr int kEventFDSize = 8; loff_t in_off = 0; - EXPECT_THAT(splice(inf.get(), &in_off, wfd.get(), nullptr, kEventFDSize, 0), + EXPECT_THAT(splice(in_fd.get(), &in_off, wfd.get(), nullptr, kEventFDSize, 0), SyscallFailsWithErrno(EINVAL)); } @@ -200,28 +200,29 @@ TEST(SpliceTest, ToEventFDOffset) { int efd; ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds()); - const FileDescriptor outf(efd); + const FileDescriptor out_fd(efd); // Attempt to splice 8-byte eventfd value to pipe with offset. // // This is not allowed because eventfd doesn't support pwrite. loff_t out_off = 0; - EXPECT_THAT(splice(rfd.get(), nullptr, outf.get(), &out_off, kEventFDSize, 0), - SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT( + splice(rfd.get(), nullptr, out_fd.get(), &out_off, kEventFDSize, 0), + SyscallFailsWithErrno(EINVAL)); } TEST(SpliceTest, ToPipe) { // Open the input file. const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor inf = + const FileDescriptor in_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); // Fill with some random data. std::vector buf(kPageSize); RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(write(inf.get(), buf.data(), buf.size()), + ASSERT_THAT(write(in_fd.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(kPageSize)); - ASSERT_THAT(lseek(inf.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); + ASSERT_THAT(lseek(in_fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); // Create a new pipe. int fds[2]; @@ -230,7 +231,7 @@ TEST(SpliceTest, ToPipe) { const FileDescriptor wfd(fds[1]); // Splice to the pipe. - EXPECT_THAT(splice(inf.get(), nullptr, wfd.get(), nullptr, kPageSize, 0), + EXPECT_THAT(splice(in_fd.get(), nullptr, wfd.get(), nullptr, kPageSize, 0), SyscallSucceedsWithValue(kPageSize)); // Contents should be equal. @@ -243,13 +244,13 @@ TEST(SpliceTest, ToPipe) { TEST(SpliceTest, ToPipeOffset) { // Open the input file. const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor inf = + const FileDescriptor in_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); // Fill with some random data. std::vector buf(kPageSize); RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(write(inf.get(), buf.data(), buf.size()), + ASSERT_THAT(write(in_fd.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(kPageSize)); // Create a new pipe. @@ -261,7 +262,7 @@ TEST(SpliceTest, ToPipeOffset) { // Splice to the pipe. loff_t in_offset = kPageSize / 2; EXPECT_THAT( - splice(inf.get(), &in_offset, wfd.get(), nullptr, kPageSize / 2, 0), + splice(in_fd.get(), &in_offset, wfd.get(), nullptr, kPageSize / 2, 0), SyscallSucceedsWithValue(kPageSize / 2)); // Contents should be equal to only the second part. @@ -286,22 +287,22 @@ TEST(SpliceTest, FromPipe) { // Open the input file. const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor outf = + const FileDescriptor out_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR)); // Splice to the output file. - EXPECT_THAT(splice(rfd.get(), nullptr, outf.get(), nullptr, kPageSize, 0), + EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), nullptr, kPageSize, 0), SyscallSucceedsWithValue(kPageSize)); // The offset of the output should be equal to kPageSize. We assert that and // reset to zero so that we can read the contents and ensure they match. - EXPECT_THAT(lseek(outf.get(), 0, SEEK_CUR), + EXPECT_THAT(lseek(out_fd.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(kPageSize)); - ASSERT_THAT(lseek(outf.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); + ASSERT_THAT(lseek(out_fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); // Contents should be equal. std::vector rbuf(kPageSize); - ASSERT_THAT(read(outf.get(), rbuf.data(), rbuf.size()), + ASSERT_THAT(read(out_fd.get(), rbuf.data(), rbuf.size()), SyscallSucceedsWithValue(kPageSize)); EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0); } @@ -321,18 +322,19 @@ TEST(SpliceTest, FromPipeOffset) { // Open the input file. const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor outf = + const FileDescriptor out_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR)); // Splice to the output file. loff_t out_offset = kPageSize / 2; - EXPECT_THAT(splice(rfd.get(), nullptr, outf.get(), &out_offset, kPageSize, 0), - SyscallSucceedsWithValue(kPageSize)); + EXPECT_THAT( + splice(rfd.get(), nullptr, out_fd.get(), &out_offset, kPageSize, 0), + SyscallSucceedsWithValue(kPageSize)); // Content should reflect the splice. We write to a specific offset in the // file, so the internals should now be allocated sparsely. std::vector rbuf(kPageSize); - ASSERT_THAT(read(outf.get(), rbuf.data(), rbuf.size()), + ASSERT_THAT(read(out_fd.get(), rbuf.data(), rbuf.size()), SyscallSucceedsWithValue(kPageSize)); std::vector zbuf(kPageSize / 2); memset(zbuf.data(), 0, zbuf.size()); @@ -404,8 +406,6 @@ TEST(SpliceTest, Blocking) { } TEST(TeeTest, Blocking) { - SKIP_IF(IsRunningOnGvisor()); - // Create two new pipes. int first[2], second[2]; ASSERT_THAT(pipe(first), SyscallSucceeds()); @@ -440,6 +440,49 @@ TEST(TeeTest, Blocking) { EXPECT_EQ(memcmp(rbuf.data(), buf.data(), kPageSize), 0); } +TEST(TeeTest, BlockingWrite) { + // Create two new pipes. + int first[2], second[2]; + ASSERT_THAT(pipe(first), SyscallSucceeds()); + const FileDescriptor rfd1(first[0]); + const FileDescriptor wfd1(first[1]); + ASSERT_THAT(pipe(second), SyscallSucceeds()); + const FileDescriptor rfd2(second[0]); + const FileDescriptor wfd2(second[1]); + + // Make some data available to be read. + std::vector buf1(kPageSize); + RandomizeBuffer(buf1.data(), buf1.size()); + ASSERT_THAT(write(wfd1.get(), buf1.data(), buf1.size()), + SyscallSucceedsWithValue(kPageSize)); + + // Fill up the write pipe's buffer. + int pipe_size = -1; + ASSERT_THAT(pipe_size = fcntl(wfd2.get(), F_GETPIPE_SZ), SyscallSucceeds()); + std::vector buf2(pipe_size); + ASSERT_THAT(write(wfd2.get(), buf2.data(), buf2.size()), + SyscallSucceedsWithValue(pipe_size)); + + ScopedThread t([&]() { + absl::SleepFor(absl::Milliseconds(100)); + ASSERT_THAT(read(rfd2.get(), buf2.data(), buf2.size()), + SyscallSucceedsWithValue(pipe_size)); + }); + + // Attempt a tee immediately; it should block. + EXPECT_THAT(tee(rfd1.get(), wfd2.get(), kPageSize, 0), + SyscallSucceedsWithValue(kPageSize)); + + // Thread should be joinable. + t.Join(); + + // Content should reflect the tee. + std::vector rbuf(kPageSize); + ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(kPageSize)); + EXPECT_EQ(memcmp(rbuf.data(), buf1.data(), kPageSize), 0); +} + TEST(SpliceTest, NonBlocking) { // Create two new pipes. int first[2], second[2]; @@ -457,8 +500,6 @@ TEST(SpliceTest, NonBlocking) { } TEST(TeeTest, NonBlocking) { - SKIP_IF(IsRunningOnGvisor()); - // Create two new pipes. int first[2], second[2]; ASSERT_THAT(pipe(first), SyscallSucceeds()); @@ -473,6 +514,79 @@ TEST(TeeTest, NonBlocking) { SyscallFailsWithErrno(EAGAIN)); } +TEST(TeeTest, MultiPage) { + // Create two new pipes. + int first[2], second[2]; + ASSERT_THAT(pipe(first), SyscallSucceeds()); + const FileDescriptor rfd1(first[0]); + const FileDescriptor wfd1(first[1]); + ASSERT_THAT(pipe(second), SyscallSucceeds()); + const FileDescriptor rfd2(second[0]); + const FileDescriptor wfd2(second[1]); + + // Make some data available to be read. + std::vector wbuf(8 * kPageSize); + RandomizeBuffer(wbuf.data(), wbuf.size()); + ASSERT_THAT(write(wfd1.get(), wbuf.data(), wbuf.size()), + SyscallSucceedsWithValue(wbuf.size())); + + // Attempt a tee immediately; it should complete. + EXPECT_THAT(tee(rfd1.get(), wfd2.get(), wbuf.size(), 0), + SyscallSucceedsWithValue(wbuf.size())); + + // Content should reflect the tee. + std::vector rbuf(wbuf.size()); + ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(rbuf.size())); + EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), rbuf.size()), 0); + ASSERT_THAT(read(rfd1.get(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(rbuf.size())); + EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), rbuf.size()), 0); +} + +TEST(SpliceTest, FromPipeMaxFileSize) { + // Create a new pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Fill with some random data. + std::vector buf(kPageSize); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(kPageSize)); + + // Open the input file. + const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor out_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR)); + + EXPECT_THAT(ftruncate(out_fd.get(), 13 << 20), SyscallSucceeds()); + EXPECT_THAT(lseek(out_fd.get(), 0, SEEK_END), + SyscallSucceedsWithValue(13 << 20)); + + // Set our file size limit. + sigset_t set; + sigemptyset(&set); + sigaddset(&set, SIGXFSZ); + TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); + rlimit rlim = {}; + rlim.rlim_cur = rlim.rlim_max = (13 << 20); + EXPECT_THAT(setrlimit(RLIMIT_FSIZE, &rlim), SyscallSucceeds()); + + // Splice to the output file. + EXPECT_THAT( + splice(rfd.get(), nullptr, out_fd.get(), nullptr, 3 * kPageSize, 0), + SyscallFailsWithErrno(EFBIG)); + + // Contents should be equal. + std::vector rbuf(kPageSize); + ASSERT_THAT(read(rfd.get(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(kPageSize)); + EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0); +} + } // namespace } // namespace testing -- cgit v1.2.3 From 75781ab3efa7b377c6dc4cf26840323f504d5eb5 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Thu, 19 Sep 2019 13:38:14 -0700 Subject: Remove defer from hot path and ensure Atomic is applied consistently. PiperOrigin-RevId: 270114317 --- pkg/sentry/socket/epsocket/epsocket.go | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) (limited to 'pkg/sentry/socket/epsocket') diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 3e05e40fe..25adca090 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -415,13 +415,13 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS // WriteTo implements fs.FileOperations.WriteTo. func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) { s.readMu.Lock() - defer s.readMu.Unlock() // Copy as much data as possible. done := int64(0) for count > 0 { // This may return a blocking error. if err := s.fetchReadView(); err != nil { + s.readMu.Unlock() return done, err.ToError() } @@ -434,16 +434,18 @@ func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Write // supported by any Linux system calls, but the // expectation is that now a caller will call read to // actually remove these bytes from the socket. - return done, nil + break } // Drop that part of the view. s.readView.TrimFront(n) if err != nil { + s.readMu.Unlock() return done, err } } + s.readMu.Unlock() return done, nil } @@ -549,7 +551,11 @@ func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) { // ReadFrom implements fs.FileOperations.ReadFrom. func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { f := &readerPayload{ctx: ctx, r: r, count: count} - n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) + n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{ + // Reads may be destructive but should be very fast, + // so we can't release the lock while copying data. + Atomic: true, + }) if err == tcpip.ErrWouldBlock { return 0, syserror.ErrWouldBlock } @@ -561,9 +567,7 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader } n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{ - // Reads may be destructive but should be very fast, - // so we can't release the lock while copying data. - Atomic: true, + Atomic: true, // See above. }) } if err == tcpip.ErrWouldBlock { -- cgit v1.2.3 From 03ee55cc62c99c5b8f5d6fb00423a66ef44589e3 Mon Sep 17 00:00:00 2001 From: Andrei Vagin Date: Mon, 23 Sep 2019 14:37:39 -0700 Subject: netstack: convert more socket options to {Set,Get}SockOptInt PiperOrigin-RevId: 270763208 --- pkg/sentry/socket/epsocket/epsocket.go | 22 ++-- pkg/sentry/socket/unix/transport/unix.go | 82 ++++++------ pkg/tcpip/stack/transport_test.go | 5 + pkg/tcpip/tcpip.go | 32 +++-- pkg/tcpip/transport/icmp/endpoint.go | 29 +++-- pkg/tcpip/transport/raw/endpoint.go | 30 +++-- pkg/tcpip/transport/tcp/endpoint.go | 144 +++++++++++---------- pkg/tcpip/transport/tcp/tcp_noracedetector_test.go | 10 +- pkg/tcpip/transport/tcp/tcp_test.go | 121 ++++++++--------- pkg/tcpip/transport/tcp/testing/context/context.go | 8 +- pkg/tcpip/transport/udp/endpoint.go | 32 +++-- 11 files changed, 276 insertions(+), 239 deletions(-) (limited to 'pkg/sentry/socket/epsocket') diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 25adca090..3e66f9cbb 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -209,6 +209,10 @@ type commonEndpoint interface { // transport.Endpoint.SetSockOpt. SetSockOpt(interface{}) *tcpip.Error + // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and + // transport.Endpoint.SetSockOptInt. + SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. GetSockOpt(interface{}) *tcpip.Error @@ -887,8 +891,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - var size tcpip.SendBufferSizeOption - if err := ep.GetSockOpt(&size); err != nil { + size, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -903,8 +907,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - var size tcpip.ReceiveBufferSizeOption - if err := ep.GetSockOpt(&size); err != nil { + size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1275,7 +1279,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.SendBufferSizeOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.SendBufferSizeOption, int(v))) case linux.SO_RCVBUF: if len(optVal) < sizeOfInt32 { @@ -1283,7 +1287,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v))) case linux.SO_REUSEADDR: if len(optVal) < sizeOfInt32 { @@ -2317,9 +2321,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc return 0, err case linux.TIOCOUTQ: - var v tcpip.SendQueueSizeOption - if err := ep.GetSockOpt(&v); err != nil { - return 0, syserr.TranslateNetstackError(err).ToError() + v, terr := ep.GetSockOptInt(tcpip.SendQueueSizeOption) + if terr != nil { + return 0, syserr.TranslateNetstackError(terr).ToError() } if v > math.MaxInt32 { diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 2b0ad6395..1867b3a5c 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -175,6 +175,10 @@ type Endpoint interface { // types. SetSockOpt(opt interface{}) *tcpip.Error + // SetSockOptInt sets a socket option for simple cases when a value has + // the int type. + SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + // GetSockOpt gets a socket option. opt should be a pointer to one of the // tcpip.*Option types. GetSockOpt(opt interface{}) *tcpip.Error @@ -838,6 +842,10 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return nil +} + func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: @@ -853,65 +861,63 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { return -1, tcpip.ErrQueueSizeNotSupported } return v, nil - default: - return -1, tcpip.ErrUnknownProtocolOption - } -} - -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { - case tcpip.ErrorOption: - return nil - case *tcpip.SendQueueSizeOption: + case tcpip.SendQueueSizeOption: e.Lock() if !e.Connected() { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize()) + v := e.connected.SendQueuedSize() e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported - } - *o = qs - return nil - - case *tcpip.PasscredOption: - if e.Passcred() { - *o = tcpip.PasscredOption(1) - } else { - *o = tcpip.PasscredOption(0) + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported } - return nil + return int(v), nil - case *tcpip.SendBufferSizeOption: + case tcpip.SendBufferSizeOption: e.Lock() if !e.Connected() { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.SendBufferSizeOption(e.connected.SendMaxQueueSize()) + v := e.connected.SendMaxQueueSize() e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported } - *o = qs - return nil + return int(v), nil - case *tcpip.ReceiveBufferSizeOption: + case tcpip.ReceiveBufferSizeOption: e.Lock() if e.receiver == nil { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.ReceiveBufferSizeOption(e.receiver.RecvMaxQueueSize()) + v := e.receiver.RecvMaxQueueSize() e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported + } + return int(v), nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + + case *tcpip.PasscredOption: + if e.Passcred() { + *o = tcpip.PasscredOption(1) + } else { + *o = tcpip.PasscredOption(0) } - *o = qs return nil case *tcpip.KeepaliveEnabledOption: diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 847d02982..0e69ac7c8 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -91,6 +91,11 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// SetSockOptInt sets a socket option. Currently not supported. +func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 2534069ab..c021c67ac 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -401,6 +401,10 @@ type Endpoint interface { // SetSockOpt sets a socket option. opt should be one of the *Option types. SetSockOpt(opt interface{}) *Error + // SetSockOptInt sets a socket option, for simple cases where a value + // has the int type. + SetSockOptInt(opt SockOpt, v int) *Error + // GetSockOpt gets a socket option. opt should be a pointer to one of the // *Option types. GetSockOpt(opt interface{}) *Error @@ -446,10 +450,22 @@ type WriteOptions struct { type SockOpt int const ( - // ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of - // unread bytes in the input buffer should be returned. + // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the + // number of unread bytes in the input buffer should be returned. ReceiveQueueSizeOption SockOpt = iota + // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to + // specify the send buffer size option. + SendBufferSizeOption + + // ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to + // specify the receive buffer size option. + ReceiveBufferSizeOption + + // SendQueueSizeOption is used in GetSockOptInt to specify that the + // number of unread bytes in the output buffer should be returned. + SendQueueSizeOption + // TODO(b/137664753): convert all int socket options to be handled via // GetSockOptInt. ) @@ -458,18 +474,6 @@ const ( // the endpoint should be cleared and returned. type ErrorOption struct{} -// SendBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the send -// buffer size option. -type SendBufferSizeOption int - -// ReceiveBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the -// receive buffer size option. -type ReceiveBufferSizeOption int - -// SendQueueSizeOption is used in GetSockOpt to specify that the number of -// unread bytes in the output buffer should be returned. -type SendQueueSizeOption int - // V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6 // socket is to be restricted to sending and receiving IPv6 packets only. type V6OnlyOption int diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 3db060384..a111fdb2a 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -319,6 +319,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +// SetSockOptInt sets a socket option. Currently not supported. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return nil +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { @@ -331,6 +336,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { } e.rcvMu.Unlock() return v, nil + case tcpip.SendBufferSizeOption: + e.mu.Lock() + v := e.sndBufSize + e.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + v := e.rcvBufSizeMax + e.rcvMu.Unlock() + return v, nil + } return -1, tcpip.ErrUnknownProtocolOption } @@ -341,18 +358,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.SendBufferSizeOption: - e.mu.Lock() - *o = tcpip.SendBufferSizeOption(e.sndBufSize) - e.mu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax) - e.rcvMu.Unlock() - return nil - case *tcpip.KeepaliveEnabledOption: *o = 0 return nil diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index cf1c5c433..a02731a5d 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -492,6 +492,11 @@ func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { @@ -504,6 +509,19 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { } ep.rcvMu.Unlock() return v, nil + + case tcpip.SendBufferSizeOption: + ep.mu.Lock() + v := ep.sndBufSize + ep.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + ep.rcvMu.Lock() + v := ep.rcvBufSizeMax + ep.rcvMu.Unlock() + return v, nil + } return -1, tcpip.ErrUnknownProtocolOption @@ -515,18 +533,6 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.SendBufferSizeOption: - ep.mu.Lock() - *o = tcpip.SendBufferSizeOption(ep.sndBufSize) - ep.mu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - ep.rcvMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(ep.rcvBufSizeMax) - ep.rcvMu.Unlock() - return nil - case *tcpip.KeepaliveEnabledOption: *o = 0 return nil diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index dd931f88c..35b489c68 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -952,62 +952,9 @@ func (e *endpoint) zeroReceiveWindow(scale uint8) bool { return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0 } -// SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { - case tcpip.DelayOption: - if v == 0 { - atomic.StoreUint32(&e.delay, 0) - - // Handle delayed data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.delay, 1) - } - return nil - - case tcpip.CorkOption: - if v == 0 { - atomic.StoreUint32(&e.cork, 0) - - // Handle the corked data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.cork, 1) - } - return nil - - case tcpip.ReuseAddressOption: - e.mu.Lock() - e.reuseAddr = v != 0 - e.mu.Unlock() - return nil - - case tcpip.ReusePortOption: - e.mu.Lock() - e.reusePort = v != 0 - e.mu.Unlock() - return nil - - case tcpip.QuickAckOption: - if v == 0 { - atomic.StoreUint32(&e.slowAck, 1) - } else { - atomic.StoreUint32(&e.slowAck, 0) - } - return nil - - case tcpip.MaxSegOption: - userMSS := v - if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { - return tcpip.ErrInvalidOptionValue - } - e.mu.Lock() - e.userMSS = int(userMSS) - e.mu.Unlock() - e.notifyProtocolGoroutine(notifyMSSChanged) - return nil - +// SetSockOptInt sets a socket option. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + switch opt { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -1071,6 +1018,67 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.sndBufMu.Unlock() return nil + default: + return nil + } +} + +// SetSockOpt sets a socket option. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { + case tcpip.DelayOption: + if v == 0 { + atomic.StoreUint32(&e.delay, 0) + + // Handle delayed data. + e.sndWaker.Assert() + } else { + atomic.StoreUint32(&e.delay, 1) + } + return nil + + case tcpip.CorkOption: + if v == 0 { + atomic.StoreUint32(&e.cork, 0) + + // Handle the corked data. + e.sndWaker.Assert() + } else { + atomic.StoreUint32(&e.cork, 1) + } + return nil + + case tcpip.ReuseAddressOption: + e.mu.Lock() + e.reuseAddr = v != 0 + e.mu.Unlock() + return nil + + case tcpip.ReusePortOption: + e.mu.Lock() + e.reusePort = v != 0 + e.mu.Unlock() + return nil + + case tcpip.QuickAckOption: + if v == 0 { + atomic.StoreUint32(&e.slowAck, 1) + } else { + atomic.StoreUint32(&e.slowAck, 0) + } + return nil + + case tcpip.MaxSegOption: + userMSS := v + if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { + return tcpip.ErrInvalidOptionValue + } + e.mu.Lock() + e.userMSS = int(userMSS) + e.mu.Unlock() + e.notifyProtocolGoroutine(notifyMSSChanged) + return nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.netProto != header.IPv6ProtocolNumber { @@ -1182,6 +1190,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() + case tcpip.SendBufferSizeOption: + e.sndBufMu.Lock() + v := e.sndBufSize + e.sndBufMu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvListMu.Lock() + v := e.rcvBufSize + e.rcvListMu.Unlock() + return v, nil + } return -1, tcpip.ErrUnknownProtocolOption } @@ -1204,18 +1224,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = header.TCPDefaultMSS return nil - case *tcpip.SendBufferSizeOption: - e.sndBufMu.Lock() - *o = tcpip.SendBufferSizeOption(e.sndBufSize) - e.sndBufMu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - e.rcvListMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize) - e.rcvListMu.Unlock() - return nil - case *tcpip.DelayOption: *o = 0 if v := atomic.LoadUint32(&e.delay); v != 0 { diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go index 272bbcdbd..9fa97528b 100644 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go @@ -38,7 +38,7 @@ func TestFastRecovery(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -190,7 +190,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -232,7 +232,7 @@ func TestCongestionAvoidance(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -336,7 +336,7 @@ func TestCubicCongestionAvoidance(t *testing.T) { enableCUBIC(t, c) - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -445,7 +445,7 @@ func TestRetransmit(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 32bb45224..7fa5cfb6e 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -84,7 +84,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) { stats := c.Stack().Stats() want := stats.TCP.ActiveConnectionOpenings.Value() + 1 - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want { t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want) } @@ -97,7 +97,7 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { stats := c.Stack().Stats() want := stats.TCP.FailedConnectionAttempts.Value() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { t.Errorf("got stats.TCP.FailedConnectionOpenings.Value() = %v, want = %v", got, want) } @@ -131,7 +131,7 @@ func TestTCPSegmentsSentIncrement(t *testing.T) { stats := c.Stack().Stats() // SYN and ACK want := stats.TCP.SegmentsSent.Value() + 2 - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.SegmentsSent.Value(); got != want { t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want) @@ -299,7 +299,7 @@ func TestTCPResetsReceivedIncrement(t *testing.T) { want := stats.TCP.ResetsReceived.Value() + 1 iss := seqnum.Value(789) rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, nil) + c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -323,7 +323,7 @@ func TestTCPResetsDoNotGenerateResets(t *testing.T) { want := stats.TCP.ResetsReceived.Value() + 1 iss := seqnum.Value(789) rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, nil) + c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -344,14 +344,14 @@ func TestActiveHandshake(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) } func TestNonBlockingClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -367,7 +367,7 @@ func TestConnectResetAfterClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -417,7 +417,7 @@ func TestSimpleReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -469,7 +469,7 @@ func TestOutOfOrderReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -557,8 +557,7 @@ func TestOutOfOrderFlood(t *testing.T) { defer c.Cleanup() // Create a new connection with initial window size of 10. - opt := tcpip.ReceiveBufferSizeOption(10) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 10) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) @@ -631,7 +630,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -700,7 +699,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -785,7 +784,7 @@ func TestShutdownRead(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) @@ -804,8 +803,7 @@ func TestFullWindowReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - opt := tcpip.ReceiveBufferSizeOption(10) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 10) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -872,11 +870,9 @@ func TestNoWindowShrinking(t *testing.T) { defer c.Cleanup() // Start off with a window size of 10, then shrink it to 5. - opt := tcpip.ReceiveBufferSizeOption(10) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 10) - opt = 5 - if err := c.EP.SetSockOpt(opt); err != nil { + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil { t.Fatalf("SetSockOpt failed: %v", err) } @@ -976,7 +972,7 @@ func TestSimpleSend(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) data := []byte{1, 2, 3} view := buffer.NewView(len(data)) @@ -1017,7 +1013,7 @@ func TestZeroWindowSend(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 0, nil) + c.CreateConnected(789, 0, -1 /* epRcvBuf */) data := []byte{1, 2, 3} view := buffer.NewView(len(data)) @@ -1075,8 +1071,7 @@ func TestScaledWindowConnect(t *testing.T) { defer c.Cleanup() // Set the window size greater than the maximum non-scaled window. - opt := tcpip.ReceiveBufferSizeOption(65535 * 3) - c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) @@ -1110,8 +1105,7 @@ func TestNonScaledWindowConnect(t *testing.T) { defer c.Cleanup() // Set the window size greater than the maximum non-scaled window. - opt := tcpip.ReceiveBufferSizeOption(65535 * 3) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 65535*3) data := []byte{1, 2, 3} view := buffer.NewView(len(data)) @@ -1151,7 +1145,7 @@ func TestScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1224,7 +1218,7 @@ func TestNonScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1293,8 +1287,7 @@ func TestZeroScaledWindowReceive(t *testing.T) { // Set the window size such that a window scale of 4 will be used. const wnd = 65535 * 10 const ws = uint32(4) - opt := tcpip.ReceiveBufferSizeOption(wnd) - c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) @@ -1399,7 +1392,7 @@ func TestSegmentMerging(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Prevent the endpoint from processing packets. test.stop(c.EP) @@ -1449,7 +1442,7 @@ func TestDelay(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) c.EP.SetSockOpt(tcpip.DelayOption(1)) @@ -1497,7 +1490,7 @@ func TestUndelay(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) c.EP.SetSockOpt(tcpip.DelayOption(1)) @@ -1579,7 +1572,7 @@ func TestMSSNotDelayed(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) @@ -1695,7 +1688,7 @@ func TestSendGreaterThanMTU(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) testBrokenUpWrite(t, c, maxPayload) } @@ -1704,7 +1697,7 @@ func TestActiveSendMSSLessThanMTU(t *testing.T) { c := context.New(t, 65535) defer c.Cleanup() - c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) testBrokenUpWrite(t, c, maxPayload) @@ -1727,7 +1720,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { // Set the buffer size to a deterministic size so that we can check the // window scaling option. const rcvBufferSize = 0x20000 - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1871,7 +1864,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // window scaling option. const rcvBufferSize = 0x20000 const wndScale = 2 - if err := c.EP.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil { + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1973,7 +1966,7 @@ func TestReceiveOnResetConnection(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Send RST segment. c.SendPacket(nil, &context.Headers{ @@ -2010,7 +2003,7 @@ func TestSendOnResetConnection(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Send RST segment. c.SendPacket(nil, &context.Headers{ @@ -2035,7 +2028,7 @@ func TestFinImmediately(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Shutdown immediately, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { @@ -2078,7 +2071,7 @@ func TestFinRetransmit(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Shutdown immediately, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { @@ -2132,7 +2125,7 @@ func TestFinWithNoPendingData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and have it acknowledged. view := buffer.NewView(10) @@ -2203,7 +2196,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write enough segments to fill the congestion window before ACK'ing // any of them. @@ -2291,7 +2284,7 @@ func TestFinWithPendingData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and acknowledge it to get cwnd to 2. view := buffer.NewView(10) @@ -2377,7 +2370,7 @@ func TestFinWithPartialAck(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and acknowledge it to get cwnd to 2. Also send // FIN from the test side. @@ -2509,7 +2502,7 @@ func scaledSendWindow(t *testing.T, scale uint8) { defer c.Cleanup() maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize - c.CreateConnectedWithRawOptions(789, 0, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), header.TCPOptionWS, 3, scale, header.TCPOptionNOP, }) @@ -2559,7 +2552,7 @@ func TestScaledSendWindow(t *testing.T) { func TestReceivedValidSegmentCountIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.ValidSegmentsReceived.Value() + 1 @@ -2580,7 +2573,7 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) { func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.InvalidSegmentsReceived.Value() + 1 vv := c.BuildSegment(nil, &context.Headers{ @@ -2604,7 +2597,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { func TestReceivedIncorrectChecksumIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.ChecksumErrors.Value() + 1 vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{ @@ -2635,7 +2628,7 @@ func TestReceivedSegmentQueuing(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Send 200 segments. data := []byte{1, 2, 3} @@ -2681,7 +2674,7 @@ func TestReadAfterClosedState(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -2856,8 +2849,8 @@ func TestReusePort(t *testing.T) { func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - var s tcpip.ReceiveBufferSizeOption - if err := ep.GetSockOpt(&s); err != nil { + s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { t.Fatalf("GetSockOpt failed: %v", err) } @@ -2869,8 +2862,8 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - var s tcpip.SendBufferSizeOption - if err := ep.GetSockOpt(&s); err != nil { + s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + if err != nil { t.Fatalf("GetSockOpt failed: %v", err) } @@ -2945,26 +2938,26 @@ func TestMinMaxBufferSizes(t *testing.T) { } // Set values below the min. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(199)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } checkRecvBufferSize(t, ep, 200) - if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(299)); err != nil { + if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } checkSendBufferSize(t, ep, 300) // Set values above the max. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(1 + tcp.DefaultReceiveBufferSize*20)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20) - if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(1 + tcp.DefaultSendBufferSize*30)); err != nil { + if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } @@ -3231,7 +3224,7 @@ func TestPathMTUDiscovery(t *testing.T) { // Create new connection with MSS of 1460. const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize - c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) @@ -3308,7 +3301,7 @@ func TestTCPEndpointProbe(t *testing.T) { invoked <- struct{}{} }) - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -3482,7 +3475,7 @@ func TestKeepalive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond)) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 16783e716..78eff5c3a 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -512,7 +512,7 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) { } // CreateConnected creates a connected TCP endpoint. -func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption) { +func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) { c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil) } @@ -590,7 +590,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) // // It also sets the receive buffer for the endpoint to the specified // value in epRcvBuf. -func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) { +func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) { // Create TCP endpoint. var err *tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) @@ -598,8 +598,8 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. c.t.Fatalf("NewEndpoint failed: %v", err) } - if epRcvBuf != nil { - if err := c.EP.SetSockOpt(*epRcvBuf); err != nil { + if epRcvBuf != -1 { + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil { c.t.Fatalf("SetSockOpt failed failed: %v", err) } } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 6ac7c067a..0bec7e62d 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -389,7 +389,12 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } -// SetSockOpt sets a socket option. Currently not supported. +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return nil +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { switch v := opt.(type) { case tcpip.V6OnlyOption: @@ -568,7 +573,20 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { } e.rcvMu.Unlock() return v, nil + + case tcpip.SendBufferSizeOption: + e.mu.Lock() + v := e.sndBufSize + e.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + v := e.rcvBufSizeMax + e.rcvMu.Unlock() + return v, nil } + return -1, tcpip.ErrUnknownProtocolOption } @@ -578,18 +596,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.SendBufferSizeOption: - e.mu.Lock() - *o = tcpip.SendBufferSizeOption(e.sndBufSize) - e.mu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax) - e.rcvMu.Unlock() - return nil - case *tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.netProto != header.IPv6ProtocolNumber { -- cgit v1.2.3 From 543492650dd528c1d837d788dcd3b5138e8dc1c0 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 26 Sep 2019 15:07:59 -0700 Subject: Make raw socket tests pass in environments with or without CAP_NET_RAW. PiperOrigin-RevId: 271442321 --- pkg/sentry/socket/epsocket/provider.go | 2 +- test/syscalls/linux/packet_socket.cc | 29 ++++++++++++++------- test/syscalls/linux/packet_socket_raw.cc | 21 ++++++++------- test/syscalls/linux/raw_socket_hdrincl.cc | 43 +++++++------------------------ test/syscalls/linux/raw_socket_icmp.cc | 13 +++++++--- test/syscalls/linux/raw_socket_ipv4.cc | 13 +++++++--- 6 files changed, 59 insertions(+), 62 deletions(-) (limited to 'pkg/sentry/socket/epsocket') diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go index 421f93dc4..0a9dfa6c3 100644 --- a/pkg/sentry/socket/epsocket/provider.go +++ b/pkg/sentry/socket/epsocket/provider.go @@ -65,7 +65,7 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in // Raw sockets require CAP_NET_RAW. creds := auth.CredentialsFromContext(ctx) if !creds.HasCapability(linux.CAP_NET_RAW) { - return 0, true, syserr.ErrPermissionDenied + return 0, true, syserr.ErrNotPermitted } switch protocol { diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index 7a3379b9e..37b4e6575 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -83,9 +83,15 @@ void SendUDPMessage(int sock) { // Send an IP packet and make sure ETH_P_ doesn't pick it up. TEST(BasicCookedPacketTest, WrongType) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + // (b/129292371): Remove once we support packet sockets. SKIP_IF(IsRunningOnGvisor()); + if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP), + SyscallFailsWithErrno(EPERM)); + GTEST_SKIP(); + } + FileDescriptor sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP)); @@ -118,18 +124,27 @@ class CookedPacketTest : public ::testing::TestWithParam { }; void CookedPacketTest::SetUp() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + // (b/129292371): Remove once we support packet sockets. SKIP_IF(IsRunningOnGvisor()); + if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())), + SyscallFailsWithErrno(EPERM)); + GTEST_SKIP(); + } + ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())), SyscallSucceeds()); } void CookedPacketTest::TearDown() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + // (b/129292371): Remove once we support packet sockets. SKIP_IF(IsRunningOnGvisor()); - EXPECT_THAT(close(socket_), SyscallSucceeds()); + // TearDown will be run even if we skip the test. + if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + EXPECT_THAT(close(socket_), SyscallSucceeds()); + } } int CookedPacketTest::GetLoopbackIndex() { @@ -142,9 +157,6 @@ int CookedPacketTest::GetLoopbackIndex() { // Receive via a packet socket. TEST_P(CookedPacketTest, Receive) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - SKIP_IF(IsRunningOnGvisor()); - // Let's use a simple IP payload: a UDP datagram. FileDescriptor udp_sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); @@ -201,9 +213,6 @@ TEST_P(CookedPacketTest, Receive) { // Send via a packet socket. TEST_P(CookedPacketTest, Send) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - SKIP_IF(IsRunningOnGvisor()); - // Let's send a UDP packet and receive it using a regular UDP socket. FileDescriptor udp_sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc index 9e96460ee..6491453b6 100644 --- a/test/syscalls/linux/packet_socket_raw.cc +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -97,9 +97,15 @@ class RawPacketTest : public ::testing::TestWithParam { }; void RawPacketTest::SetUp() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + // (b/129292371): Remove once we support packet sockets. SKIP_IF(IsRunningOnGvisor()); + if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + ASSERT_THAT(socket(AF_PACKET, SOCK_RAW, htons(GetParam())), + SyscallFailsWithErrno(EPERM)); + GTEST_SKIP(); + } + if (!IsRunningOnGvisor()) { FileDescriptor acceptLocal = ASSERT_NO_ERRNO_AND_VALUE( Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDONLY)); @@ -119,10 +125,13 @@ void RawPacketTest::SetUp() { } void RawPacketTest::TearDown() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + // (b/129292371): Remove once we support packet sockets. SKIP_IF(IsRunningOnGvisor()); - EXPECT_THAT(close(socket_), SyscallSucceeds()); + // TearDown will be run even if we skip the test. + if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + EXPECT_THAT(close(socket_), SyscallSucceeds()); + } } int RawPacketTest::GetLoopbackIndex() { @@ -135,9 +144,6 @@ int RawPacketTest::GetLoopbackIndex() { // Receive via a packet socket. TEST_P(RawPacketTest, Receive) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - SKIP_IF(IsRunningOnGvisor()); - // Let's use a simple IP payload: a UDP datagram. FileDescriptor udp_sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); @@ -208,9 +214,6 @@ TEST_P(RawPacketTest, Receive) { // Send via a packet socket. TEST_P(RawPacketTest, Send) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - SKIP_IF(IsRunningOnGvisor()); - // Let's send a UDP packet and receive it using a regular UDP socket. FileDescriptor udp_sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); diff --git a/test/syscalls/linux/raw_socket_hdrincl.cc b/test/syscalls/linux/raw_socket_hdrincl.cc index a070817eb..0a27506aa 100644 --- a/test/syscalls/linux/raw_socket_hdrincl.cc +++ b/test/syscalls/linux/raw_socket_hdrincl.cc @@ -63,7 +63,11 @@ class RawHDRINCL : public ::testing::Test { }; void RawHDRINCL::SetUp() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + ASSERT_THAT(socket(AF_INET, SOCK_RAW, IPPROTO_RAW), + SyscallFailsWithErrno(EPERM)); + GTEST_SKIP(); + } ASSERT_THAT(socket_ = socket(AF_INET, SOCK_RAW, IPPROTO_RAW), SyscallSucceeds()); @@ -76,9 +80,10 @@ void RawHDRINCL::SetUp() { } void RawHDRINCL::TearDown() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - EXPECT_THAT(close(socket_), SyscallSucceeds()); + // TearDown will be run even if we skip the test. + if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + EXPECT_THAT(close(socket_), SyscallSucceeds()); + } } struct iphdr RawHDRINCL::LoopbackHeader() { @@ -123,8 +128,6 @@ bool RawHDRINCL::FillPacket(char* buf, size_t buf_size, int port, // We should be able to create multiple IPPROTO_RAW sockets. RawHDRINCL::Setup // creates the first one, so we only have to create one more here. TEST_F(RawHDRINCL, MultipleCreation) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - int s2; ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, IPPROTO_RAW), SyscallSucceeds()); @@ -133,23 +136,17 @@ TEST_F(RawHDRINCL, MultipleCreation) { // Test that shutting down an unconnected socket fails. TEST_F(RawHDRINCL, FailShutdownWithoutConnect) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - ASSERT_THAT(shutdown(socket_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); ASSERT_THAT(shutdown(socket_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); } // Test that listen() fails. TEST_F(RawHDRINCL, FailListen) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - ASSERT_THAT(listen(socket_, 1), SyscallFailsWithErrno(ENOTSUP)); } // Test that accept() fails. TEST_F(RawHDRINCL, FailAccept) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - struct sockaddr saddr; socklen_t addrlen; ASSERT_THAT(accept(socket_, &saddr, &addrlen), @@ -158,8 +155,6 @@ TEST_F(RawHDRINCL, FailAccept) { // Test that the socket is writable immediately. TEST_F(RawHDRINCL, PollWritableImmediately) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - struct pollfd pfd = {}; pfd.fd = socket_; pfd.events = POLLOUT; @@ -168,8 +163,6 @@ TEST_F(RawHDRINCL, PollWritableImmediately) { // Test that the socket isn't readable. TEST_F(RawHDRINCL, NotReadable) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - // Try to receive data with MSG_DONTWAIT, which returns immediately if there's // nothing to be read. char buf[117]; @@ -179,16 +172,12 @@ TEST_F(RawHDRINCL, NotReadable) { // Test that we can connect() to a valid IP (loopback). TEST_F(RawHDRINCL, ConnectToLoopback) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - ASSERT_THAT(connect(socket_, reinterpret_cast(&addr_), sizeof(addr_)), SyscallSucceeds()); } TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - struct iphdr hdr = LoopbackHeader(); ASSERT_THAT(send(socket_, &hdr, sizeof(hdr), 0), SyscallSucceedsWithValue(sizeof(hdr))); @@ -197,8 +186,6 @@ TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) { // HDRINCL implies write-only. Verify that we can't read a packet sent to // loopback. TEST_F(RawHDRINCL, NotReadableAfterWrite) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - ASSERT_THAT(connect(socket_, reinterpret_cast(&addr_), sizeof(addr_)), SyscallSucceeds()); @@ -221,8 +208,6 @@ TEST_F(RawHDRINCL, NotReadableAfterWrite) { } TEST_F(RawHDRINCL, WriteTooSmall) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - ASSERT_THAT(connect(socket_, reinterpret_cast(&addr_), sizeof(addr_)), SyscallSucceeds()); @@ -235,8 +220,6 @@ TEST_F(RawHDRINCL, WriteTooSmall) { // Bind to localhost. TEST_F(RawHDRINCL, BindToLocalhost) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - ASSERT_THAT( bind(socket_, reinterpret_cast(&addr_), sizeof(addr_)), SyscallSucceeds()); @@ -244,8 +227,6 @@ TEST_F(RawHDRINCL, BindToLocalhost) { // Bind to a different address. TEST_F(RawHDRINCL, BindToInvalid) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - struct sockaddr_in bind_addr = {}; bind_addr.sin_family = AF_INET; bind_addr.sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to. @@ -256,8 +237,6 @@ TEST_F(RawHDRINCL, BindToInvalid) { // Send and receive a packet. TEST_F(RawHDRINCL, SendAndReceive) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - int port = 40000; if (!IsRunningOnGvisor()) { port = static_cast(ASSERT_NO_ERRNO_AND_VALUE( @@ -302,8 +281,6 @@ TEST_F(RawHDRINCL, SendAndReceive) { // Send and receive a packet with nonzero IP ID. TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - int port = 40000; if (!IsRunningOnGvisor()) { port = static_cast(ASSERT_NO_ERRNO_AND_VALUE( @@ -349,8 +326,6 @@ TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) { // Send and receive a packet where the sendto address is not the same as the // provided destination. TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - int port = 40000; if (!IsRunningOnGvisor()) { port = static_cast(ASSERT_NO_ERRNO_AND_VALUE( diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc index 971592d7d..8bcaba6f1 100644 --- a/test/syscalls/linux/raw_socket_icmp.cc +++ b/test/syscalls/linux/raw_socket_icmp.cc @@ -77,7 +77,11 @@ class RawSocketICMPTest : public ::testing::Test { }; void RawSocketICMPTest::SetUp() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + ASSERT_THAT(socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), + SyscallFailsWithErrno(EPERM)); + GTEST_SKIP(); + } ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), SyscallSucceeds()); @@ -90,9 +94,10 @@ void RawSocketICMPTest::SetUp() { } void RawSocketICMPTest::TearDown() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - EXPECT_THAT(close(s_), SyscallSucceeds()); + // TearDown will be run even if we skip the test. + if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + EXPECT_THAT(close(s_), SyscallSucceeds()); + } } // We'll only read an echo in this case, as the kernel won't respond to the diff --git a/test/syscalls/linux/raw_socket_ipv4.cc b/test/syscalls/linux/raw_socket_ipv4.cc index 352037c88..cde2f07c9 100644 --- a/test/syscalls/linux/raw_socket_ipv4.cc +++ b/test/syscalls/linux/raw_socket_ipv4.cc @@ -67,7 +67,11 @@ class RawSocketTest : public ::testing::TestWithParam { }; void RawSocketTest::SetUp() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + ASSERT_THAT(socket(AF_INET, SOCK_RAW, Protocol()), + SyscallFailsWithErrno(EPERM)); + GTEST_SKIP(); + } ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds()); @@ -79,9 +83,10 @@ void RawSocketTest::SetUp() { } void RawSocketTest::TearDown() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - EXPECT_THAT(close(s_), SyscallSucceeds()); + // TearDown will be run even if we skip the test. + if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + EXPECT_THAT(close(s_), SyscallSucceeds()); + } } // We should be able to create multiple raw sockets for the same protocol. -- cgit v1.2.3 From abbee5615f4480d8a41b4cf63839d2ab13b19abf Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Fri, 27 Sep 2019 14:12:35 -0700 Subject: Implement SO_BINDTODEVICE sockopt PiperOrigin-RevId: 271644926 --- pkg/sentry/socket/epsocket/epsocket.go | 20 ++ pkg/sentry/syscalls/linux/sys_socket.go | 2 +- pkg/tcpip/ports/ports.go | 114 ++++-- pkg/tcpip/ports/ports_test.go | 113 +++++- pkg/tcpip/stack/BUILD | 4 + pkg/tcpip/stack/nic.go | 15 +- pkg/tcpip/stack/stack.go | 58 +--- pkg/tcpip/stack/transport_demuxer.go | 227 +++++++----- pkg/tcpip/stack/transport_demuxer_test.go | 352 +++++++++++++++++++ pkg/tcpip/stack/transport_test.go | 5 +- pkg/tcpip/tcpip.go | 4 + pkg/tcpip/transport/icmp/endpoint.go | 6 +- pkg/tcpip/transport/tcp/accept.go | 2 +- pkg/tcpip/transport/tcp/endpoint.go | 56 ++- pkg/tcpip/transport/tcp/tcp_test.go | 116 +++++++ pkg/tcpip/transport/tcp/testing/context/context.go | 26 +- pkg/tcpip/transport/udp/BUILD | 1 + pkg/tcpip/transport/udp/endpoint.go | 42 ++- pkg/tcpip/transport/udp/forwarder.go | 2 +- pkg/tcpip/transport/udp/udp_test.go | 120 +++---- test/syscalls/linux/BUILD | 75 ++++ test/syscalls/linux/socket_bind_to_device.cc | 314 +++++++++++++++++ .../linux/socket_bind_to_device_distribution.cc | 381 +++++++++++++++++++++ .../linux/socket_bind_to_device_sequence.cc | 316 +++++++++++++++++ test/syscalls/linux/socket_bind_to_device_util.cc | 75 ++++ test/syscalls/linux/socket_bind_to_device_util.h | 67 ++++ test/syscalls/linux/uidgid.cc | 25 +- test/util/BUILD | 11 + test/util/uid_util.cc | 44 +++ test/util/uid_util.h | 29 ++ 30 files changed, 2308 insertions(+), 314 deletions(-) create mode 100644 pkg/tcpip/stack/transport_demuxer_test.go create mode 100644 test/syscalls/linux/socket_bind_to_device.cc create mode 100644 test/syscalls/linux/socket_bind_to_device_distribution.cc create mode 100644 test/syscalls/linux/socket_bind_to_device_sequence.cc create mode 100644 test/syscalls/linux/socket_bind_to_device_util.cc create mode 100644 test/syscalls/linux/socket_bind_to_device_util.h create mode 100644 test/util/uid_util.cc create mode 100644 test/util/uid_util.h (limited to 'pkg/sentry/socket/epsocket') diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 3e66f9cbb..5812085fa 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -942,6 +942,19 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return int32(v), nil + case linux.SO_BINDTODEVICE: + var v tcpip.BindToDeviceOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + if len(v) == 0 { + return []byte{}, nil + } + if outLen < linux.IFNAMSIZ { + return nil, syserr.ErrInvalidArgument + } + return append([]byte(v), 0), nil + case linux.SO_BROADCAST: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -1305,6 +1318,13 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i v := usermem.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v))) + case linux.SO_BINDTODEVICE: + n := bytes.IndexByte(optVal, 0) + if n == -1 { + n = len(optVal) + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n]))) + case linux.SO_BROADCAST: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 3bac4d90d..b5a72ce63 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -531,7 +531,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.ENOTSOCK } - if optLen <= 0 { + if optLen < 0 { return 0, nil, syserror.EINVAL } if optLen > maxOptLen { diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 315780c0c..40e202717 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -47,43 +47,76 @@ type portNode struct { refs int } -// bindAddresses is a set of IP addresses. -type bindAddresses map[tcpip.Address]portNode - -// isAvailable checks whether an IP address is available to bind to. -func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool) bool { - if addr == anyIPAddress { - if len(b) == 0 { - return true - } +// deviceNode is never empty. When it has no elements, it is removed from the +// map that references it. +type deviceNode map[tcpip.NICID]portNode + +// isAvailable checks whether binding is possible by device. If not binding to a +// device, check against all portNodes. If binding to a specific device, check +// against the unspecified device and the provided device. +func (d deviceNode) isAvailable(reuse bool, bindToDevice tcpip.NICID) bool { + if bindToDevice == 0 { + // Trying to binding all devices. if !reuse { + // Can't bind because the (addr,port) is already bound. return false } - for _, n := range b { - if !n.reuse { + for _, p := range d { + if !p.reuse { + // Can't bind because the (addr,port) was previously bound without reuse. return false } } return true } - // If all addresses for this portDescriptor are already bound, no - // address is available. - if n, ok := b[anyIPAddress]; ok { - if !reuse { + if p, ok := d[0]; ok { + if !reuse || !p.reuse { return false } - if !n.reuse { + } + + if p, ok := d[bindToDevice]; ok { + if !reuse || !p.reuse { return false } } - if n, ok := b[addr]; ok { - if !reuse { + return true +} + +// bindAddresses is a set of IP addresses. +type bindAddresses map[tcpip.Address]deviceNode + +// isAvailable checks whether an IP address is available to bind to. If the +// address is the "any" address, check all other addresses. Otherwise, just +// check against the "any" address and the provided address. +func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice tcpip.NICID) bool { + if addr == anyIPAddress { + // If binding to the "any" address then check that there are no conflicts + // with all addresses. + for _, d := range b { + if !d.isAvailable(reuse, bindToDevice) { + return false + } + } + return true + } + + // Check that there is no conflict with the "any" address. + if d, ok := b[anyIPAddress]; ok { + if !d.isAvailable(reuse, bindToDevice) { + return false + } + } + + // Check that this is no conflict with the provided address. + if d, ok := b[addr]; ok { + if !d.isAvailable(reuse, bindToDevice) { return false } - return n.reuse } + return true } @@ -116,17 +149,17 @@ func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Er } // IsPortAvailable tests if the given port is available on all given protocols. -func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool { +func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { s.mu.Lock() defer s.mu.Unlock() - return s.isPortAvailableLocked(networks, transport, addr, port, reuse) + return s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) } -func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool { +func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { for _, network := range networks { desc := portDescriptor{network, transport, port} if addrs, ok := s.allocatedPorts[desc]; ok { - if !addrs.isAvailable(addr, reuse) { + if !addrs.isAvailable(addr, reuse, bindToDevice) { return false } } @@ -138,14 +171,14 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the // "port" return value. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) (reservedPort uint16, err *tcpip.Error) { +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() // If a port is specified, just try to reserve it for all network // protocols. if port != 0 { - if !s.reserveSpecificPort(networks, transport, addr, port, reuse) { + if !s.reserveSpecificPort(networks, transport, addr, port, reuse, bindToDevice) { return 0, tcpip.ErrPortInUse } return port, nil @@ -153,13 +186,13 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp // A port wasn't specified, so try to find one. return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - return s.reserveSpecificPort(networks, transport, addr, p, reuse), nil + return s.reserveSpecificPort(networks, transport, addr, p, reuse, bindToDevice), nil }) } // reserveSpecificPort tries to reserve the given port on all given protocols. -func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool { - if !s.isPortAvailableLocked(networks, transport, addr, port, reuse) { +func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { + if !s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) { return false } @@ -171,11 +204,16 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber m = make(bindAddresses) s.allocatedPorts[desc] = m } - if n, ok := m[addr]; ok { + d, ok := m[addr] + if !ok { + d = make(deviceNode) + m[addr] = d + } + if n, ok := d[bindToDevice]; ok { n.refs++ - m[addr] = n + d[bindToDevice] = n } else { - m[addr] = portNode{reuse: reuse, refs: 1} + d[bindToDevice] = portNode{reuse: reuse, refs: 1} } } @@ -184,22 +222,28 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber // ReleasePort releases the reservation on a port/IP combination so that it can // be reserved by other endpoints. -func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) { +func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, bindToDevice tcpip.NICID) { s.mu.Lock() defer s.mu.Unlock() for _, network := range networks { desc := portDescriptor{network, transport, port} if m, ok := s.allocatedPorts[desc]; ok { - n, ok := m[addr] + d, ok := m[addr] + if !ok { + continue + } + n, ok := d[bindToDevice] if !ok { continue } n.refs-- + d[bindToDevice] = n if n.refs == 0 { + delete(d, bindToDevice) + } + if len(d) == 0 { delete(m, addr) - } else { - m[addr] = n } if len(m) == 0 { delete(s.allocatedPorts, desc) diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 689401661..a67e283f1 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -34,6 +34,7 @@ type portReserveTestAction struct { want *tcpip.Error reuse bool release bool + device tcpip.NICID } func TestPortReservation(t *testing.T) { @@ -100,6 +101,112 @@ func TestPortReservation(t *testing.T) { {port: 24, ip: anyIPAddress, release: true}, {port: 24, ip: anyIPAddress, reuse: false, want: nil}, }, + }, { + tname: "bind twice with device fails", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 3, want: nil}, + {port: 24, ip: fakeIPAddress, device: 3, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind to device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 1, want: nil}, + {port: 24, ip: fakeIPAddress, device: 2, want: nil}, + }, + }, { + tname: "bind to device and then without device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind without device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, want: nil}, + {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuse", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, + }, + }, { + tname: "binding with reuse and device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "mixing reuse and not reuse by binding to device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 999, want: nil}, + }, + }, { + tname: "can't bind to 0 after mixing reuse and not reuse", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind and release", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + + // Release the bind to device 0 and try again. + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil}, + }, + }, { + tname: "bind twice with reuse once", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "release an unreserved device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil}, + // The below don't exist. + {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil, release: true}, + {port: 9999, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true}, + // Release all. + {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil, release: true}, + }, }, } { t.Run(test.tname, func(t *testing.T) { @@ -108,12 +215,12 @@ func TestPortReservation(t *testing.T) { for _, test := range test.actions { if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port) + pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.device) continue } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse) + gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse, test.device) if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d, %t) = %v, want %v", test.ip, test.port, test.release, err, test.want) + t.Fatalf("ReservePort(.., .., %s, %d, %t, %d) = %v, want %v", test.ip, test.port, test.reuse, test.device, err, test.want) } if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 28c49e8ff..3842f1f7d 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -54,6 +54,7 @@ go_test( size = "small", srcs = [ "stack_test.go", + "transport_demuxer_test.go", "transport_test.go", ], deps = [ @@ -64,6 +65,9 @@ go_test( "//pkg/tcpip/iptables", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/transport/udp", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 0e8a23f00..f6106f762 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -34,8 +34,6 @@ type NIC struct { linkEP LinkEndpoint loopback bool - demux *transportDemuxer - mu sync.RWMutex spoofing bool promiscuous bool @@ -85,7 +83,6 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback name: name, linkEP: ep, loopback: loopback, - demux: newTransportDemuxer(stack), primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List), endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), mcastJoins: make(map[NetworkEndpointID]int32), @@ -707,9 +704,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // Raw socket packets are delivered based solely on the transport // protocol number. We do not inspect the payload to ensure it's // validly formed. - if !n.demux.deliverRawPacket(r, protocol, netHeader, vv) { - n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv) - } + n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv) if len(vv.First()) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() @@ -723,9 +718,6 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN } id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.demux.deliverPacket(r, protocol, netHeader, vv, id) { - return - } if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) { return } @@ -767,10 +759,7 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp } id := TransportEndpointID{srcPort, local, dstPort, remote} - if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) { - return - } - if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) { + if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, vv, id) { return } } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 18d1704a5..6a8079823 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1033,73 +1033,27 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep. // transport dispatcher. Received packets that match the provided id will be // delivered to the given endpoint; specifying a nic is optional, but // nic-specific IDs have precedence over global ones. -func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { - if nicID == 0 { - return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort) - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic == nil { - return tcpip.ErrUnknownNICID - } - - return nic.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort) +func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { + return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice) } // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. -func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) { - if nicID == 0 { - s.demux.unregisterEndpoint(netProtos, protocol, id, ep) - return - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic != nil { - nic.demux.unregisterEndpoint(netProtos, protocol, id, ep) - } +func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) } // RegisterRawTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided transport // protocol will be delivered to the given endpoint. func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { - if nicID == 0 { - return s.demux.registerRawEndpoint(netProto, transProto, ep) - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic == nil { - return tcpip.ErrUnknownNICID - } - - return nic.demux.registerRawEndpoint(netProto, transProto, ep) + return s.demux.registerRawEndpoint(netProto, transProto, ep) } // UnregisterRawTransportEndpoint removes the endpoint for the transport // protocol from the stack transport dispatcher. func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { - if nicID == 0 { - s.demux.unregisterRawEndpoint(netProto, transProto, ep) - return - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic != nil { - nic.demux.unregisterRawEndpoint(netProto, transProto, ep) - } + s.demux.unregisterRawEndpoint(netProto, transProto, ep) } // RegisterRestoredEndpoint records e as an endpoint that has been restored on diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index cf8a6d129..8c768c299 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -35,25 +35,109 @@ type protocolIDs struct { type transportEndpoints struct { // mu protects all fields of the transportEndpoints. mu sync.RWMutex - endpoints map[TransportEndpointID]TransportEndpoint + endpoints map[TransportEndpointID]*endpointsByNic // rawEndpoints contains endpoints for raw sockets, which receive all // traffic of a given protocol regardless of port. rawEndpoints []RawTransportEndpoint } +type endpointsByNic struct { + mu sync.RWMutex + endpoints map[tcpip.NICID]*multiPortEndpoint + // seed is a random secret for a jenkins hash. + seed uint32 +} + +// HandlePacket is called by the stack when new packets arrive to this transport +// endpoint. +func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { + epsByNic.mu.RLock() + + mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + if !ok { + if mpep, ok = epsByNic.endpoints[0]; !ok { + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + return + } + } + + // If this is a broadcast or multicast datagram, deliver the datagram to all + // endpoints bound to the right device. + if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) { + mpep.handlePacketAll(r, id, vv) + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + return + } + + // multiPortEndpoints are guaranteed to have at least one element. + selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, vv) + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. +} + +// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) { + epsByNic.mu.RLock() + defer epsByNic.mu.RUnlock() + + mpep, ok := epsByNic.endpoints[n.ID()] + if !ok { + mpep, ok = epsByNic.endpoints[0] + } + if !ok { + return + } + + // TODO(eyalsoha): Why don't we look at id to see if this packet needs to + // broadcast like we are doing with handlePacket above? + + // multiPortEndpoints are guaranteed to have at least one element. + selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, vv) +} + +// registerEndpoint returns true if it succeeds. It fails and returns +// false if ep already has an element with the same key. +func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { + epsByNic.mu.Lock() + defer epsByNic.mu.Unlock() + + if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok { + // There was already a bind. + return multiPortEp.singleRegisterEndpoint(t, reusePort) + } + + // This is a new binding. + multiPortEp := &multiPortEndpoint{} + multiPortEp.endpointsMap = make(map[TransportEndpoint]int) + multiPortEp.reuse = reusePort + epsByNic.endpoints[bindToDevice] = multiPortEp + return multiPortEp.singleRegisterEndpoint(t, reusePort) +} + +// unregisterEndpoint returns true if endpointsByNic has to be unregistered. +func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { + epsByNic.mu.Lock() + defer epsByNic.mu.Unlock() + multiPortEp, ok := epsByNic.endpoints[bindToDevice] + if !ok { + return false + } + if multiPortEp.unregisterEndpoint(t) { + delete(epsByNic.endpoints, bindToDevice) + } + return len(epsByNic.endpoints) == 0 +} + // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) { +func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { eps.mu.Lock() defer eps.mu.Unlock() - e, ok := eps.endpoints[id] + epsByNic, ok := eps.endpoints[id] if !ok { return } - if multiPortEp, ok := e.(*multiPortEndpoint); ok { - if !multiPortEp.unregisterEndpoint(ep) { - return - } + if !epsByNic.unregisterEndpoint(bindToDevice, ep) { + return } delete(eps.endpoints, id) } @@ -75,7 +159,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { for netProto := range stack.networkProtocols { for proto := range stack.transportProtocols { d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{ - endpoints: make(map[TransportEndpointID]TransportEndpoint), + endpoints: make(map[TransportEndpointID]*endpointsByNic), } } } @@ -85,10 +169,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { // registerEndpoint registers the given endpoint with the dispatcher such that // packets that match the endpoint ID are delivered to it. -func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { +func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { for i, n := range netProtos { - if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil { - d.unregisterEndpoint(netProtos[:i], protocol, id, ep) + if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil { + d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice) return err } } @@ -97,13 +181,14 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum } // multiPortEndpoint is a container for TransportEndpoints which are bound to -// the same pair of address and port. +// the same pair of address and port. endpointsArr always has at least one +// element. type multiPortEndpoint struct { mu sync.RWMutex endpointsArr []TransportEndpoint endpointsMap map[TransportEndpoint]int - // seed is a random secret for a jenkins hash. - seed uint32 + // reuse indicates if more than one endpoint is allowed. + reuse bool } // reciprocalScale scales a value into range [0, n). @@ -117,9 +202,10 @@ func reciprocalScale(val, n uint32) uint32 { // selectEndpoint calculates a hash of destination and source addresses and // ports then uses it to select a socket. In this case, all packets from one // address will be sent to same endpoint. -func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint { - ep.mu.RLock() - defer ep.mu.RUnlock() +func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint { + if len(mpep.endpointsArr) == 1 { + return mpep.endpointsArr[0] + } payload := []byte{ byte(id.LocalPort), @@ -128,51 +214,50 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd byte(id.RemotePort >> 8), } - h := jenkins.Sum32(ep.seed) + h := jenkins.Sum32(seed) h.Write(payload) h.Write([]byte(id.LocalAddress)) h.Write([]byte(id.RemoteAddress)) hash := h.Sum32() - idx := reciprocalScale(hash, uint32(len(ep.endpointsArr))) - return ep.endpointsArr[idx] + idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr))) + return mpep.endpointsArr[idx] } -// HandlePacket is called by the stack when new packets arrive to this transport -// endpoint. -func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { - // If this is a broadcast or multicast datagram, deliver the datagram to all - // endpoints managed by ep. - if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) { - for i, endpoint := range ep.endpointsArr { - // HandlePacket modifies vv, so each endpoint needs its own copy. - if i == len(ep.endpointsArr)-1 { - endpoint.HandlePacket(r, id, vv) - break - } - vvCopy := buffer.NewView(vv.Size()) - copy(vvCopy, vv.ToView()) - endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) +func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { + ep.mu.RLock() + for i, endpoint := range ep.endpointsArr { + // HandlePacket modifies vv, so each endpoint needs its own copy except for + // the final one. + if i == len(ep.endpointsArr)-1 { + endpoint.HandlePacket(r, id, vv) + break } - } else { - ep.selectEndpoint(id).HandlePacket(r, id, vv) + vvCopy := buffer.NewView(vv.Size()) + copy(vvCopy, vv.ToView()) + endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) } + ep.mu.RUnlock() // Don't use defer for performance reasons. } -// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) { - ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv) -} - -func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) { +// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint +// list. The list might be empty already. +func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - // A new endpoint is added into endpointsArr and its index there is - // saved in endpointsMap. This will allows to remove endpoint from - // the array fast. + if len(ep.endpointsArr) > 0 { + // If it was previously bound, we need to check if we can bind again. + if !ep.reuse || !reusePort { + return tcpip.ErrPortInUse + } + } + + // A new endpoint is added into endpointsArr and its index there is saved in + // endpointsMap. This will allow us to remove endpoint from the array fast. ep.endpointsMap[t] = len(ep.endpointsArr) ep.endpointsArr = append(ep.endpointsArr, t) + return nil } // unregisterEndpoint returns true if multiPortEndpoint has to be unregistered. @@ -197,53 +282,41 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool { return true } -func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { +func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { if id.RemotePort != 0 { + // TODO(eyalsoha): Why? reusePort = false } eps, ok := d.protocol[protocolIDs{netProto, protocol}] if !ok { - return nil + return tcpip.ErrUnknownProtocol } eps.mu.Lock() defer eps.mu.Unlock() - var multiPortEp *multiPortEndpoint - if _, ok := eps.endpoints[id]; ok { - if !reusePort { - return tcpip.ErrPortInUse - } - multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint) - if !ok { - return tcpip.ErrPortInUse - } + if epsByNic, ok := eps.endpoints[id]; ok { + // There was already a binding. + return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) } - if reusePort { - if multiPortEp == nil { - multiPortEp = &multiPortEndpoint{} - multiPortEp.endpointsMap = make(map[TransportEndpoint]int) - multiPortEp.seed = rand.Uint32() - eps.endpoints[id] = multiPortEp - } - - multiPortEp.singleRegisterEndpoint(ep) - - return nil + // This is a new binding. + epsByNic := &endpointsByNic{ + endpoints: make(map[tcpip.NICID]*multiPortEndpoint), + seed: rand.Uint32(), } - eps.endpoints[id] = ep + eps.endpoints[id] = epsByNic - return nil + return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) { +func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { for _, n := range netProtos { if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { - eps.unregisterEndpoint(id, ep) + eps.unregisterEndpoint(id, ep, bindToDevice) } } } @@ -273,7 +346,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // If the packet is a broadcast, then find all matching transport endpoints. // Otherwise, try to find a single matching transport endpoint. - destEps := make([]TransportEndpoint, 0, 1) + destEps := make([]*endpointsByNic, 0, 1) eps.mu.RLock() if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast { @@ -299,7 +372,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // Deliver the packet. for _, ep := range destEps { - ep.HandlePacket(r, id, vv) + ep.handlePacket(r, id, vv) } return true @@ -331,7 +404,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false @@ -348,12 +421,12 @@ func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, } // Deliver the packet. - ep.HandleControlPacket(id, typ, extra, vv) + ep.handleControlPacket(n, id, typ, extra, vv) return true } -func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint { +func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic { // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { return ep diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go new file mode 100644 index 000000000..210233dc0 --- /dev/null +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -0,0 +1,352 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack_test + +import ( + "math" + "math/rand" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + + stackAddr = "\x0a\x00\x00\x01" + stackPort = 1234 + testPort = 4096 +) + +type testContext struct { + t *testing.T + linkEPs map[string]*channel.Endpoint + s *stack.Stack + + ep tcpip.Endpoint + wq waiter.Queue +} + +func (c *testContext) cleanup() { + if c.ep != nil { + c.ep.Close() + } +} + +func (c *testContext) createV6Endpoint(v6only bool) { + var err *tcpip.Error + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + var v tcpip.V6OnlyOption + if v6only { + v = 1 + } + if err := c.ep.SetSockOpt(v); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } +} + +// newDualTestContextMultiNic creates the testing context and also linkEpNames +// named NICs. +func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + linkEPs := make(map[string]*channel.Endpoint) + for i, linkEpName := range linkEpNames { + channelEP := channel.New(256, mtu, "") + nicid := tcpip.NICID(i + 1) + if err := s.CreateNamedNIC(nicid, linkEpName, channelEP); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + linkEPs[linkEpName] = channelEP + + if err := s.AddAddress(nicid, ipv4.ProtocolNumber, stackAddr); err != nil { + t.Fatalf("AddAddress IPv4 failed: %v", err) + } + + if err := s.AddAddress(nicid, ipv6.ProtocolNumber, stackV6Addr); err != nil { + t.Fatalf("AddAddress IPv6 failed: %v", err) + } + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: 1, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: 1, + }, + }) + + return &testContext{ + t: t, + s: s, + linkEPs: linkEPs, + } +} + +type headers struct { + srcPort uint16 + dstPort uint16 +} + +func newPayload() []byte { + b := make([]byte, 30+rand.Intn(100)) + for i := range b { + b[i] = byte(rand.Intn(256)) + } + return b +} + +func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) { + // Allocate a buffer for data and headers. + buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) + copy(buf[len(buf)-len(payload):], payload) + + // Initialize the IP header. + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(header.UDPMinimumSize + len(payload)), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 65, + SrcAddr: testV6Addr, + DstAddr: stackV6Addr, + }) + + // Initialize the UDP header. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.Encode(&header.UDPFields{ + SrcPort: h.srcPort, + DstPort: h.dstPort, + Length: uint16(header.UDPMinimumSize + len(payload)), + }) + + // Calculate the UDP pseudo-header checksum. + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) + + // Calculate the UDP checksum and set it. + xsum = header.Checksum(payload, xsum) + u.SetChecksum(^u.CalculateChecksum(xsum)) + + // Inject packet. + c.linkEPs[linkEpName].Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) +} + +func TestTransportDemuxerRegister(t *testing.T) { + for _, test := range []struct { + name string + proto tcpip.NetworkProtocolNumber + want *tcpip.Error + }{ + {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol}, + {"success", ipv4.ProtocolNumber, nil}, + } { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want { + t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want) + } + }) + } +} + +// TestReuseBindToDevice injects varied packets on input devices and checks that +// the distribution of packets received matches expectations. +func TestDistribution(t *testing.T) { + type endpointSockopts struct { + reuse int + bindToDevice string + } + for _, test := range []struct { + name string + // endpoints will received the inject packets. + endpoints []endpointSockopts + // wantedDistribution is the wanted ratio of packets received on each + // endpoint for each NIC on which packets are injected. + wantedDistributions map[string][]float64 + }{ + { + "BindPortReuse", + // 5 endpoints that all have reuse set. + []endpointSockopts{ + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + }, + map[string][]float64{ + // Injected packets on dev0 get distributed evenly. + "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2}, + }, + }, + { + "BindToDevice", + // 3 endpoints with various bindings. + []endpointSockopts{ + endpointSockopts{0, "dev0"}, + endpointSockopts{0, "dev1"}, + endpointSockopts{0, "dev2"}, + }, + map[string][]float64{ + // Injected packets on dev0 go only to the endpoint bound to dev0. + "dev0": []float64{1, 0, 0}, + // Injected packets on dev1 go only to the endpoint bound to dev1. + "dev1": []float64{0, 1, 0}, + // Injected packets on dev2 go only to the endpoint bound to dev2. + "dev2": []float64{0, 0, 1}, + }, + }, + { + "ReuseAndBindToDevice", + // 6 endpoints with various bindings. + []endpointSockopts{ + endpointSockopts{1, "dev0"}, + endpointSockopts{1, "dev0"}, + endpointSockopts{1, "dev1"}, + endpointSockopts{1, "dev1"}, + endpointSockopts{1, "dev1"}, + endpointSockopts{1, ""}, + }, + map[string][]float64{ + // Injected packets on dev0 get distributed among endpoints bound to + // dev0. + "dev0": []float64{0.5, 0.5, 0, 0, 0, 0}, + // Injected packets on dev1 get distributed among endpoints bound to + // dev1 or unbound. + "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, + // Injected packets on dev999 go only to the unbound. + "dev999": []float64{0, 0, 0, 0, 0, 1}, + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + for device, wantedDistribution := range test.wantedDistributions { + t.Run(device, func(t *testing.T) { + var devices []string + for d := range test.wantedDistributions { + devices = append(devices, d) + } + c := newDualTestContextMultiNic(t, defaultMTU, devices) + defer c.cleanup() + + c.createV6Endpoint(false) + + eps := make(map[tcpip.Endpoint]int) + + pollChannel := make(chan tcpip.Endpoint) + for i, endpoint := range test.endpoints { + // Try to receive the data. + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + + var err *tcpip.Error + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + eps[ep] = i + + go func(ep tcpip.Endpoint) { + for range ch { + pollChannel <- ep + } + }(ep) + + defer ep.Close() + reusePortOption := tcpip.ReusePortOption(endpoint.reuse) + if err := ep.SetSockOpt(reusePortOption); err != nil { + c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err) + } + bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) + if err := ep.SetSockOpt(bindToDeviceOption); err != nil { + c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err) + } + if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { + t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err) + } + } + + npackets := 100000 + nports := 10000 + if got, want := len(test.endpoints), len(wantedDistribution); got != want { + t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) + } + ports := make(map[uint16]tcpip.Endpoint) + stats := make(map[tcpip.Endpoint]int) + for i := 0; i < npackets; i++ { + // Send a packet. + port := uint16(i % nports) + payload := newPayload() + c.sendV6Packet(payload, + &headers{ + srcPort: testPort + port, + dstPort: stackPort}, + device) + + var addr tcpip.FullAddress + ep := <-pollChannel + _, _, err := ep.Read(&addr) + if err != nil { + c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err) + } + stats[ep]++ + if i < nports { + ports[uint16(i)] = ep + } else { + // Check that all packets from one client are handled by the same + // socket. + if want, got := ports[port], ep; want != got { + t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) + } + } + } + + // Check that a packet distribution is as expected. + for ep, i := range eps { + wantedRatio := wantedDistribution[i] + wantedRecv := wantedRatio * float64(npackets) + actualRecv := stats[ep] + actualRatio := float64(stats[ep]) / float64(npackets) + // The deviation is less than 10%. + if math.Abs(actualRatio-wantedRatio) > 0.05 { + t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets) + } + } + }) + } + }) + } +} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 56e8a5d9b..842a16277 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -127,7 +127,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Try to register so that we can start receiving packets. f.id.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false) + err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false /* reuse */, 0 /* bindToDevice */) if err != nil { return err } @@ -168,7 +168,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { fakeTransNumber, stack.TransportEndpointID{LocalAddress: a.Addr}, f, - false, + false, /* reuse */ + 0, /* bindtoDevice */ ); err != nil { return err } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index c021c67ac..faaa4a4e3 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -495,6 +495,10 @@ type ReuseAddressOption int // to be bound to an identical socket address. type ReusePortOption int +// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets +// should bind only on a specific NIC. +type BindToDeviceOption string + // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. type QuickAckOption int diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index a111fdb2a..a3a910d41 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -104,7 +104,7 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e) + e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e, 0 /* bindToDevice */) } // Close the receive list and drain it. @@ -543,14 +543,14 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false /* reuse */, 0 /* bindToDevice */) return id, err } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false /* reuse */, 0 /* bindtodevice */) switch err { case nil: return true, nil diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 0802e984e..3ae4a5426 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -242,7 +242,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.initGSO() // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil { + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort, n.bindToDevice); err != nil { n.Close() return nil, err } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 35b489c68..a1cd0d481 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -280,6 +280,9 @@ type endpoint struct { // reusePort is set to true if SO_REUSEPORT is enabled. reusePort bool + // bindToDevice is set to the NIC on which to bind or disabled if 0. + bindToDevice tcpip.NICID + // delay enables Nagle's algorithm. // // delay is a boolean (0 is false) and must be accessed atomically. @@ -564,11 +567,11 @@ func (e *endpoint) Close() { // in Listen() when trying to register. if e.state == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) e.isRegistered = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice) e.isPortReserved = false } @@ -625,12 +628,12 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) e.isRegistered = false } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice) e.isPortReserved = false } @@ -1060,6 +1063,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.BindToDeviceOption: + e.mu.Lock() + defer e.mu.Unlock() + if v == "" { + e.bindToDevice = 0 + return nil + } + for nicid, nic := range e.stack.NICInfo() { + if nic.Name == string(v) { + e.bindToDevice = nicid + return nil + } + } + return tcpip.ErrUnknownDevice + case tcpip.QuickAckOption: if v == 0 { atomic.StoreUint32(&e.slowAck, 1) @@ -1260,6 +1278,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.BindToDeviceOption: + e.mu.RLock() + defer e.mu.RUnlock() + if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { + *o = tcpip.BindToDeviceOption(nic.Name) + return nil + } + *o = "" + return nil + case *tcpip.QuickAckOption: *o = 1 if v := atomic.LoadUint32(&e.slowAck); v != 0 { @@ -1466,7 +1494,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er if e.id.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort, e.bindToDevice) if err != nil { return err } @@ -1480,13 +1508,15 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er if sameAddr && p == e.id.RemotePort { return false, nil } - if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) { + // reusePort is false below because connect cannot reuse a port even if + // reusePort was set. + if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false /* reusePort */, e.bindToDevice) { return false, nil } id := e.id id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) { + switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { case nil: e.id = id return true, nil @@ -1504,7 +1534,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er // before Connect: in such a case we don't want to hold on to // reservations anymore. if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice) e.isPortReserved = false } @@ -1648,7 +1678,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort, e.bindToDevice); err != nil { return err } @@ -1729,7 +1759,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice) if err != nil { return err } @@ -1739,16 +1769,16 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { e.id.LocalPort = port // Any failures beyond this point must remove the port registration. - defer func() { + defer func(bindToDevice tcpip.NICID) { if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port) + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice) e.isPortReserved = false e.effectiveNetProtos = nil e.id.LocalPort = 0 e.id.LocalAddress = "" e.boundNICID = 0 } - }() + }(e.bindToDevice) // If an address is specified, we must ensure that it's one of our // local addresses. diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 2be094876..089826a88 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -465,6 +465,66 @@ func TestSimpleReceive(t *testing.T) { ) } +func TestConnectBindToDevice(t *testing.T) { + for _, test := range []struct { + name string + device string + want tcp.EndpointState + }{ + {"RightDevice", "nic1", tcp.StateEstablished}, + {"WrongDevice", "nic2", tcp.StateSynSent}, + {"AnyDevice", "", tcp.StateEstablished}, + } { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + bindToDevice := tcpip.BindToDeviceOption(test.device) + c.EP.SetSockOpt(bindToDevice) + // Start connection attempt. + waitEntry, _ := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventOut) + defer c.WQ.EventUnregister(&waitEntry) + + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + iss := seqnum.Value(789) + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + TCPOpts: nil, + }) + + c.GetPacket() + if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want { + t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } + }) + } +} + func TestOutOfOrderReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -2970,6 +3030,62 @@ func TestMinMaxBufferSizes(t *testing.T) { checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30) } +func TestBindToDeviceOption(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}}) + + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) + } + defer ep.Close() + + if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { + t.Errorf("CreateNamedNIC failed: %v", err) + } + + // Make an nameless NIC. + if err := s.CreateNIC(54321, loopback.New()); err != nil { + t.Errorf("CreateNIC failed: %v", err) + } + + // strPtr is used instead of taking the address of string literals, which is + // a compiler error. + strPtr := func(s string) *string { + return &s + } + + testActions := []struct { + name string + setBindToDevice *string + setBindToDeviceError *tcpip.Error + getBindToDevice tcpip.BindToDeviceOption + }{ + {"GetDefaultValue", nil, nil, ""}, + {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, + {"BindToExistent", strPtr("my_device"), nil, "my_device"}, + {"UnbindToDevice", strPtr(""), nil, ""}, + } + for _, testAction := range testActions { + t.Run(testAction.name, func(t *testing.T) { + if testAction.setBindToDevice != nil { + bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) + if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + } + } + bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") + if ep.GetSockOpt(&bindToDevice) != nil { + t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + } + if got, want := bindToDevice, testAction.getBindToDevice; got != want { + t.Errorf("bindToDevice got %q, want %q", got, want) + } + }) + } +} + func makeStack() (*stack.Stack, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index d3f1d2cdf..ef823e4ae 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -158,7 +158,14 @@ func New(t *testing.T, mtu uint32) *Context { if testing.Verbose() { wep = sniffer.New(ep) } - if err := s.CreateNIC(1, wep); err != nil { + if err := s.CreateNamedNIC(1, "nic1", wep); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + wep2 := stack.LinkEndpoint(channel.New(1000, mtu, "")) + if testing.Verbose() { + wep2 = sniffer.New(channel.New(1000, mtu, "")) + } + if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -588,12 +595,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) c.Port = tcpHdr.SourcePort() } -// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends -// the specified option bytes as the Option field in the initial SYN packet. -// -// It also sets the receive buffer for the endpoint to the specified -// value in epRcvBuf. -func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) { +// Create creates a TCP endpoint. +func (c *Context) Create(epRcvBuf int) { // Create TCP endpoint. var err *tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) @@ -606,6 +609,15 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. c.t.Fatalf("SetSockOpt failed failed: %v", err) } } +} + +// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends +// the specified option bytes as the Option field in the initial SYN packet. +// +// It also sets the receive buffer for the endpoint to the specified +// value in epRcvBuf. +func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) { + c.Create(epRcvBuf) c.Connect(iss, rcvWnd, options) } diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index c1ca22b35..7a635ab8d 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -52,6 +52,7 @@ go_test( "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/loopback", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 0bec7e62d..52f5af777 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -88,6 +88,7 @@ type endpoint struct { multicastNICID tcpip.NICID multicastLoop bool reusePort bool + bindToDevice tcpip.NICID broadcast bool // shutdownFlags represent the current shutdown state of the endpoint. @@ -144,8 +145,8 @@ func (e *endpoint) Close() { switch e.state { case StateBound, StateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice) } for _, mem := range e.multicastMemberships { @@ -551,6 +552,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.reusePort = v != 0 e.mu.Unlock() + case tcpip.BindToDeviceOption: + e.mu.Lock() + defer e.mu.Unlock() + if v == "" { + e.bindToDevice = 0 + return nil + } + for nicid, nic := range e.stack.NICInfo() { + if nic.Name == string(v) { + e.bindToDevice = nicid + return nil + } + } + return tcpip.ErrUnknownDevice + case tcpip.BroadcastOption: e.mu.Lock() e.broadcast = v != 0 @@ -646,6 +662,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.BindToDeviceOption: + e.mu.RLock() + defer e.mu.RUnlock() + if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { + *o = tcpip.BindToDeviceOption(nic.Name) + return nil + } + *o = tcpip.BindToDeviceOption("") + return nil + case *tcpip.KeepaliveEnabledOption: *o = 0 return nil @@ -753,12 +779,12 @@ func (e *endpoint) Disconnect() *tcpip.Error { } else { if e.id.LocalPort != 0 { // Release the ephemeral port. - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice) } e.state = StateInitial } - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) e.id = id e.route.Release() e.route = stack.Route{} @@ -835,7 +861,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Remove the old registration. if e.id.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) } e.id = id @@ -898,16 +924,16 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { if e.id.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice) if err != nil { return id, err } id.LocalPort = port } - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice) } return id, err } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index a9edc2c8d..2d0bc5221 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -74,7 +74,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { // CreateEndpoint creates a connected UDP endpoint for the session request. func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { ep := newEndpoint(r.stack, r.route.NetProto, queue) - if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort); err != nil { + if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil { ep.Close() return nil, err } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 2ec27be4d..5059ca22d 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -17,7 +17,6 @@ package udp_test import ( "bytes" "fmt" - "math" "math/rand" "testing" "time" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -476,87 +476,59 @@ func newMinPayload(minSize int) []byte { return b } -func TestBindPortReuse(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - var eps [5]tcpip.Endpoint - reusePortOpt := tcpip.ReusePortOption(1) - - pollChannel := make(chan tcpip.Endpoint) - for i := 0; i < len(eps); i++ { - // Try to receive the data. - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - - var err *tcpip.Error - eps[i], err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - go func(ep tcpip.Endpoint) { - for range ch { - pollChannel <- ep - } - }(eps[i]) +func TestBindToDeviceOption(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - defer eps[i].Close() - if err := eps[i].SetSockOpt(reusePortOpt); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } - if err := eps[i].Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) } + defer ep.Close() - npackets := 100000 - nports := 10000 - ports := make(map[uint16]tcpip.Endpoint) - stats := make(map[tcpip.Endpoint]int) - for i := 0; i < npackets; i++ { - // Send a packet. - port := uint16(i % nports) - payload := newPayload() - c.injectV6Packet(payload, &header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port}, - dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, - }) + if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { + t.Errorf("CreateNamedNIC failed: %v", err) + } - var addr tcpip.FullAddress - ep := <-pollChannel - _, _, err := ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read failed: %v", err) - } - stats[ep]++ - if i < nports { - ports[uint16(i)] = ep - } else { - // Check that all packets from one client are handled - // by the same socket. - if ports[port] != ep { - t.Fatalf("Port mismatch") - } - } + // Make an nameless NIC. + if err := s.CreateNIC(54321, loopback.New()); err != nil { + t.Errorf("CreateNIC failed: %v", err) } - if len(stats) != len(eps) { - t.Fatalf("Only %d(expected %d) sockets received packets", len(stats), len(eps)) + // strPtr is used instead of taking the address of string literals, which is + // a compiler error. + strPtr := func(s string) *string { + return &s } - // Check that a packet distribution is fair between sockets. - for _, c := range stats { - n := float64(npackets) / float64(len(eps)) - // The deviation is less than 10%. - if math.Abs(float64(c)-n) > n/10 { - t.Fatal(c, n) - } + testActions := []struct { + name string + setBindToDevice *string + setBindToDeviceError *tcpip.Error + getBindToDevice tcpip.BindToDeviceOption + }{ + {"GetDefaultValue", nil, nil, ""}, + {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, + {"BindToExistent", strPtr("my_device"), nil, "my_device"}, + {"UnbindToDevice", strPtr(""), nil, ""}, + } + for _, testAction := range testActions { + t.Run(testAction.name, func(t *testing.T) { + if testAction.setBindToDevice != nil { + bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) + if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + } + } + bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") + if ep.GetSockOpt(&bindToDevice) != nil { + t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + } + if got, want := bindToDevice, testAction.getBindToDevice; got != want { + t.Errorf("bindToDevice got %q, want %q", got, want) + } + }) } } diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 28b23ce58..e645eebfa 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2463,6 +2463,63 @@ cc_binary( ], ) +cc_binary( + name = "socket_bind_to_device_test", + testonly = 1, + srcs = [ + "socket_bind_to_device.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_bind_to_device_util", + ":socket_test_util", + "//test/util:capability_util", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( + name = "socket_bind_to_device_sequence_test", + testonly = 1, + srcs = [ + "socket_bind_to_device_sequence.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_bind_to_device_util", + ":socket_test_util", + "//test/util:capability_util", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( + name = "socket_bind_to_device_distribution_test", + testonly = 1, + srcs = [ + "socket_bind_to_device_distribution.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_bind_to_device_util", + ":socket_test_util", + "//test/util:capability_util", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_googletest//:gtest", + ], +) + cc_binary( name = "socket_ip_udp_loopback_non_blocking_test", testonly = 1, @@ -2740,6 +2797,23 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "socket_bind_to_device_util", + testonly = 1, + srcs = [ + "socket_bind_to_device_util.cc", + ], + hdrs = [ + "socket_bind_to_device_util.h", + ], + deps = [ + "//test/util:test_util", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + cc_binary( name = "socket_stream_local_test", testonly = 1, @@ -3253,6 +3327,7 @@ cc_binary( "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + "//test/util:uid_util", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", diff --git a/test/syscalls/linux/socket_bind_to_device.cc b/test/syscalls/linux/socket_bind_to_device.cc new file mode 100644 index 000000000..d20821cac --- /dev/null +++ b/test/syscalls/linux/socket_bind_to_device.cc @@ -0,0 +1,314 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_bind_to_device_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +using std::string; + +// Test fixture for SO_BINDTODEVICE tests. +class BindToDeviceTest : public ::testing::TestWithParam { + protected: + void SetUp() override { + printf("Testing case: %s\n", GetParam().description.c_str()); + ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) + << "CAP_NET_RAW is required to use SO_BINDTODEVICE"; + + interface_name_ = "eth1"; + auto interface_names = GetInterfaceNames(); + if (interface_names.find(interface_name_) == interface_names.end()) { + // Need a tunnel. + tunnel_ = ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New()); + interface_name_ = tunnel_->GetName(); + ASSERT_FALSE(interface_name_.empty()); + } + socket_ = ASSERT_NO_ERRNO_AND_VALUE(GetParam().Create()); + } + + string interface_name() const { return interface_name_; } + + int socket_fd() const { return socket_->get(); } + + private: + std::unique_ptr tunnel_; + string interface_name_; + std::unique_ptr socket_; +}; + +constexpr char kIllegalIfnameChar = '/'; + +// Tests getsockopt of the default value. +TEST_P(BindToDeviceTest, GetsockoptDefault) { + char name_buffer[IFNAMSIZ * 2]; + char original_name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Read the default SO_BINDTODEVICE. + memset(original_name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + for (size_t i = 0; i <= sizeof(name_buffer); i++) { + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = i; + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, &name_buffer_size), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(name_buffer_size, 0); + EXPECT_EQ(memcmp(name_buffer, original_name_buffer, sizeof(name_buffer)), + 0); + } +} + +// Tests setsockopt of invalid device name. +TEST_P(BindToDeviceTest, SetsockoptInvalidDeviceName) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Set an invalid device name. + memset(name_buffer, kIllegalIfnameChar, 5); + name_buffer_size = 5; + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + name_buffer_size), + SyscallFailsWithErrno(ENODEV)); +} + +// Tests setsockopt of a buffer with a valid device name but not +// null-terminated, with different sizes of buffer. +TEST_P(BindToDeviceTest, SetsockoptValidDeviceNameWithoutNullTermination) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + strncpy(name_buffer, interface_name().c_str(), interface_name().size() + 1); + // Intentionally overwrite the null at the end. + memset(name_buffer + interface_name().size(), kIllegalIfnameChar, + sizeof(name_buffer) - interface_name().size()); + for (size_t i = 1; i <= sizeof(name_buffer); i++) { + name_buffer_size = i; + SCOPED_TRACE(absl::StrCat("Buffer size: ", i)); + // It should only work if the size provided is exactly right. + if (name_buffer_size == interface_name().size()) { + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, name_buffer_size), + SyscallSucceeds()); + } else { + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, name_buffer_size), + SyscallFailsWithErrno(ENODEV)); + } + } +} + +// Tests setsockopt of a buffer with a valid device name and null-terminated, +// with different sizes of buffer. +TEST_P(BindToDeviceTest, SetsockoptValidDeviceNameWithNullTermination) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + strncpy(name_buffer, interface_name().c_str(), interface_name().size() + 1); + // Don't overwrite the null at the end. + memset(name_buffer + interface_name().size() + 1, kIllegalIfnameChar, + sizeof(name_buffer) - interface_name().size() - 1); + for (size_t i = 1; i <= sizeof(name_buffer); i++) { + name_buffer_size = i; + SCOPED_TRACE(absl::StrCat("Buffer size: ", i)); + // It should only work if the size provided is at least the right size. + if (name_buffer_size >= interface_name().size()) { + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, name_buffer_size), + SyscallSucceeds()); + } else { + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, name_buffer_size), + SyscallFailsWithErrno(ENODEV)); + } + } +} + +// Tests that setsockopt of an invalid device name doesn't unset the previous +// valid setsockopt. +TEST_P(BindToDeviceTest, SetsockoptValidThenInvalid) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Write successfully. + strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer)); + ASSERT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + sizeof(name_buffer)), + SyscallSucceeds()); + + // Read it back successfully. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, interface_name().size() + 1); + EXPECT_STREQ(name_buffer, interface_name().c_str()); + + // Write unsuccessfully. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = 5; + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + sizeof(name_buffer)), + SyscallFailsWithErrno(ENODEV)); + + // Read it back successfully, it's unchanged. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, interface_name().size() + 1); + EXPECT_STREQ(name_buffer, interface_name().c_str()); +} + +// Tests that setsockopt of zero-length string correctly unsets the previous +// value. +TEST_P(BindToDeviceTest, SetsockoptValidThenClear) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Write successfully. + strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer)); + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + sizeof(name_buffer)), + SyscallSucceeds()); + + // Read it back successfully. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, interface_name().size() + 1); + EXPECT_STREQ(name_buffer, interface_name().c_str()); + + // Clear it successfully. + name_buffer_size = 0; + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + name_buffer_size), + SyscallSucceeds()); + + // Read it back successfully, it's cleared. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, 0); +} + +// Tests that setsockopt of empty string correctly unsets the previous +// value. +TEST_P(BindToDeviceTest, SetsockoptValidThenClearWithNull) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Write successfully. + strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer)); + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + sizeof(name_buffer)), + SyscallSucceeds()); + + // Read it back successfully. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, interface_name().size() + 1); + EXPECT_STREQ(name_buffer, interface_name().c_str()); + + // Clear it successfully. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer[0] = 0; + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + name_buffer_size), + SyscallSucceeds()); + + // Read it back successfully, it's cleared. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, 0); +} + +// Tests getsockopt with different buffer sizes. +TEST_P(BindToDeviceTest, GetsockoptDevice) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Write successfully. + strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer)); + ASSERT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + sizeof(name_buffer)), + SyscallSucceeds()); + + // Read it back at various buffer sizes. + for (size_t i = 0; i <= sizeof(name_buffer); i++) { + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = i; + SCOPED_TRACE(absl::StrCat("Buffer size: ", i)); + // Linux only allows a buffer at least IFNAMSIZ, even if less would suffice + // for this interface name. + if (name_buffer_size >= IFNAMSIZ) { + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, interface_name().size() + 1); + EXPECT_STREQ(name_buffer, interface_name().c_str()); + } else { + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, &name_buffer_size), + SyscallFailsWithErrno(EINVAL)); + EXPECT_EQ(name_buffer_size, i); + } + } +} + +INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceTest, + ::testing::Values(IPv4UDPUnboundSocket(0), + IPv4TCPUnboundSocket(0))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc new file mode 100644 index 000000000..4d2400328 --- /dev/null +++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc @@ -0,0 +1,381 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_bind_to_device_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +using std::string; +using std::vector; + +struct EndpointConfig { + std::string bind_to_device; + double expected_ratio; +}; + +struct DistributionTestCase { + std::string name; + std::vector endpoints; +}; + +struct ListenerConnector { + TestAddress listener; + TestAddress connector; +}; + +// Test fixture for SO_BINDTODEVICE tests the distribution of packets received +// with varying SO_BINDTODEVICE settings. +class BindToDeviceDistributionTest + : public ::testing::TestWithParam< + ::testing::tuple> { + protected: + void SetUp() override { + printf("Testing case: %s, listener=%s, connector=%s\n", + ::testing::get<1>(GetParam()).name.c_str(), + ::testing::get<0>(GetParam()).listener.description.c_str(), + ::testing::get<0>(GetParam()).connector.description.c_str()); + ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) + << "CAP_NET_RAW is required to use SO_BINDTODEVICE"; + } +}; + +PosixErrorOr AddrPort(int family, sockaddr_storage const& addr) { + switch (family) { + case AF_INET: + return static_cast( + reinterpret_cast(&addr)->sin_port); + case AF_INET6: + return static_cast( + reinterpret_cast(&addr)->sin6_port); + default: + return PosixError(EINVAL, + absl::StrCat("unknown socket family: ", family)); + } +} + +PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) { + switch (family) { + case AF_INET: + reinterpret_cast(addr)->sin_port = port; + return NoError(); + case AF_INET6: + reinterpret_cast(addr)->sin6_port = port; + return NoError(); + default: + return PosixError(EINVAL, + absl::StrCat("unknown socket family: ", family)); + } +} + +// Binds sockets to different devices and then creates many TCP connections. +// Checks that the distribution of connections received on the sockets matches +// the expectation. +TEST_P(BindToDeviceDistributionTest, Tcp) { + auto const& [listener_connector, test] = GetParam(); + + TestAddress const& listener = listener_connector.listener; + TestAddress const& connector = listener_connector.connector; + sockaddr_storage listen_addr = listener.addr; + sockaddr_storage conn_addr = connector.addr; + + auto interface_names = GetInterfaceNames(); + + // Create the listening sockets. + std::vector listener_fds; + std::vector> all_tunnels; + for (auto const& endpoint : test.endpoints) { + if (!endpoint.bind_to_device.empty() && + interface_names.find(endpoint.bind_to_device) == + interface_names.end()) { + all_tunnels.push_back( + ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device))); + interface_names.insert(endpoint.bind_to_device); + } + + listener_fds.push_back(ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP))); + int fd = listener_fds.back().get(); + + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, + endpoint.bind_to_device.c_str(), + endpoint.bind_to_device.size() + 1), + SyscallSucceeds()); + ASSERT_THAT( + bind(fd, reinterpret_cast(&listen_addr), listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(fd, 40), SyscallSucceeds()); + + // On the first bind we need to determine which port was bound. + if (listener_fds.size() > 1) { + continue; + } + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listener_fds[0].get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + } + + constexpr int kConnectAttempts = 10000; + std::atomic connects_received = ATOMIC_VAR_INIT(0); + std::vector accept_counts(listener_fds.size(), 0); + std::vector> listen_threads( + listener_fds.size()); + + for (int i = 0; i < listener_fds.size(); i++) { + listen_threads[i] = absl::make_unique( + [&listener_fds, &accept_counts, &connects_received, i, + kConnectAttempts]() { + do { + auto fd = Accept(listener_fds[i].get(), nullptr, nullptr); + if (!fd.ok()) { + // Another thread has shutdown our read side causing the accept to + // fail. + ASSERT_GE(connects_received, kConnectAttempts) + << "errno = " << fd.error(); + return; + } + // Receive some data from a socket to be sure that the connect() + // system call has been completed on another side. + int data; + EXPECT_THAT( + RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(sizeof(data))); + accept_counts[i]++; + } while (++connects_received < kConnectAttempts); + + // Shutdown all sockets to wake up other threads. + for (auto const& listener_fd : listener_fds) { + shutdown(listener_fd.get(), SHUT_RDWR); + } + }); + } + + for (int i = 0; i < kConnectAttempts; i++) { + FileDescriptor const fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT( + RetryEINTR(connect)(fd.get(), reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0), + SyscallSucceedsWithValue(sizeof(i))); + } + + // Join threads to be sure that all connections have been counted. + for (auto const& listen_thread : listen_threads) { + listen_thread->Join(); + } + // Check that connections are distributed correctly among listening sockets. + for (int i = 0; i < accept_counts.size(); i++) { + EXPECT_THAT( + accept_counts[i], + EquivalentWithin(static_cast(kConnectAttempts * + test.endpoints[i].expected_ratio), + 0.10)) + << "endpoint " << i << " got the wrong number of packets"; + } +} + +// Binds sockets to different devices and then sends many UDP packets. Checks +// that the distribution of packets received on the sockets matches the +// expectation. +TEST_P(BindToDeviceDistributionTest, Udp) { + auto const& [listener_connector, test] = GetParam(); + + TestAddress const& listener = listener_connector.listener; + TestAddress const& connector = listener_connector.connector; + sockaddr_storage listen_addr = listener.addr; + sockaddr_storage conn_addr = connector.addr; + + auto interface_names = GetInterfaceNames(); + + // Create the listening socket. + std::vector listener_fds; + std::vector> all_tunnels; + for (auto const& endpoint : test.endpoints) { + if (!endpoint.bind_to_device.empty() && + interface_names.find(endpoint.bind_to_device) == + interface_names.end()) { + all_tunnels.push_back( + ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device))); + interface_names.insert(endpoint.bind_to_device); + } + + listener_fds.push_back( + ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0))); + int fd = listener_fds.back().get(); + + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, + endpoint.bind_to_device.c_str(), + endpoint.bind_to_device.size() + 1), + SyscallSucceeds()); + ASSERT_THAT( + bind(fd, reinterpret_cast(&listen_addr), listener.addr_len), + SyscallSucceeds()); + + // On the first bind we need to determine which port was bound. + if (listener_fds.size() > 1) { + continue; + } + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listener_fds[0].get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port)); + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + } + + constexpr int kConnectAttempts = 10000; + std::atomic packets_received = ATOMIC_VAR_INIT(0); + std::vector packets_per_socket(listener_fds.size(), 0); + std::vector> receiver_threads( + listener_fds.size()); + + for (int i = 0; i < listener_fds.size(); i++) { + receiver_threads[i] = absl::make_unique( + [&listener_fds, &packets_per_socket, &packets_received, i]() { + do { + struct sockaddr_storage addr = {}; + socklen_t addrlen = sizeof(addr); + int data; + + auto ret = RetryEINTR(recvfrom)( + listener_fds[i].get(), &data, sizeof(data), 0, + reinterpret_cast(&addr), &addrlen); + + if (packets_received < kConnectAttempts) { + ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data))); + } + + if (ret != sizeof(data)) { + // Another thread may have shutdown our read side causing the + // recvfrom to fail. + break; + } + + packets_received++; + packets_per_socket[i]++; + + // A response is required to synchronize with the main thread, + // otherwise the main thread can send more than can fit into receive + // queues. + EXPECT_THAT(RetryEINTR(sendto)( + listener_fds[i].get(), &data, sizeof(data), 0, + reinterpret_cast(&addr), addrlen), + SyscallSucceedsWithValue(sizeof(data))); + } while (packets_received < kConnectAttempts); + + // Shutdown all sockets to wake up other threads. + for (auto const& listener_fd : listener_fds) { + shutdown(listener_fd.get(), SHUT_RDWR); + } + }); + } + + for (int i = 0; i < kConnectAttempts; i++) { + FileDescriptor const fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0)); + EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0, + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceedsWithValue(sizeof(i))); + int data; + EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(sizeof(data))); + } + + // Join threads to be sure that all connections have been counted. + for (auto const& receiver_thread : receiver_threads) { + receiver_thread->Join(); + } + // Check that packets are distributed correctly among listening sockets. + for (int i = 0; i < packets_per_socket.size(); i++) { + EXPECT_THAT( + packets_per_socket[i], + EquivalentWithin(static_cast(kConnectAttempts * + test.endpoints[i].expected_ratio), + 0.10)) + << "endpoint " << i << " got the wrong number of packets"; + } +} + +std::vector GetDistributionTestCases() { + return std::vector{ + {"Even distribution among sockets not bound to device", + {{"", 1. / 3}, {"", 1. / 3}, {"", 1. / 3}}}, + {"Sockets bound to other interfaces get no packets", + {{"eth1", 0}, {"", 1. / 2}, {"", 1. / 2}}}, + {"Bound has priority over unbound", {{"eth1", 0}, {"", 0}, {"lo", 1}}}, + {"Even distribution among sockets bound to device", + {{"eth1", 0}, {"lo", 1. / 2}, {"lo", 1. / 2}}}, + }; +} + +INSTANTIATE_TEST_SUITE_P( + BindToDeviceTest, BindToDeviceDistributionTest, + ::testing::Combine(::testing::Values( + // Listeners bound to IPv4 addresses refuse + // connections using IPv6 addresses. + ListenerConnector{V4Any(), V4Loopback()}, + ListenerConnector{V4Loopback(), V4MappedLoopback()}), + ::testing::ValuesIn(GetDistributionTestCases()))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device_sequence.cc b/test/syscalls/linux/socket_bind_to_device_sequence.cc new file mode 100644 index 000000000..a7365d139 --- /dev/null +++ b/test/syscalls/linux/socket_bind_to_device_sequence.cc @@ -0,0 +1,316 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_bind_to_device_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +using std::string; +using std::vector; + +// Test fixture for SO_BINDTODEVICE tests the results of sequences of socket +// binding. +class BindToDeviceSequenceTest : public ::testing::TestWithParam { + protected: + void SetUp() override { + printf("Testing case: %s\n", GetParam().description.c_str()); + ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) + << "CAP_NET_RAW is required to use SO_BINDTODEVICE"; + socket_factory_ = GetParam(); + + interface_names_ = GetInterfaceNames(); + } + + PosixErrorOr> NewSocket() const { + return socket_factory_.Create(); + } + + // Gets a device by device_id. If the device_id has been seen before, returns + // the previously returned device. If not, finds or creates a new device. + // Returns an empty string on failure. + void GetDevice(int device_id, string *device_name) { + auto device = devices_.find(device_id); + if (device != devices_.end()) { + *device_name = device->second; + return; + } + + // Need to pick a new device. Try ethernet first. + *device_name = absl::StrCat("eth", next_unused_eth_); + if (interface_names_.find(*device_name) != interface_names_.end()) { + devices_[device_id] = *device_name; + next_unused_eth_++; + return; + } + + // Need to make a new tunnel device. gVisor tests should have enough + // ethernet devices to never reach here. + ASSERT_FALSE(IsRunningOnGvisor()); + // Need a tunnel. + tunnels_.push_back(ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New())); + devices_[device_id] = tunnels_.back()->GetName(); + *device_name = devices_[device_id]; + } + + // Release the socket + void ReleaseSocket(int socket_id) { + // Close the socket that was made in a previous action. The socket_id + // indicates which socket to close based on index into the list of actions. + sockets_to_close_.erase(socket_id); + } + + // Bind a socket with the reuse option and bind_to_device options. Checks + // that all steps succeed and that the bind command's error matches want. + // Sets the socket_id to uniquely identify the socket bound if it is not + // nullptr. + void BindSocket(bool reuse, int device_id = 0, int want = 0, + int *socket_id = nullptr) { + next_socket_id_++; + sockets_to_close_[next_socket_id_] = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket_fd = sockets_to_close_[next_socket_id_]->get(); + if (socket_id != nullptr) { + *socket_id = next_socket_id_; + } + + // If reuse is indicated, do that. + if (reuse) { + EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + } + + // If the device is non-zero, bind to that device. + if (device_id != 0) { + string device_name; + ASSERT_NO_FATAL_FAILURE(GetDevice(device_id, &device_name)); + EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, + device_name.c_str(), device_name.size() + 1), + SyscallSucceedsWithValue(0)); + char get_device[100]; + socklen_t get_device_size = 100; + EXPECT_THAT(getsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, get_device, + &get_device_size), + SyscallSucceedsWithValue(0)); + } + + struct sockaddr_in addr = {}; + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + addr.sin_port = port_; + if (want == 0) { + ASSERT_THAT( + bind(socket_fd, reinterpret_cast(&addr), + sizeof(addr)), + SyscallSucceeds()); + } else { + ASSERT_THAT( + bind(socket_fd, reinterpret_cast(&addr), + sizeof(addr)), + SyscallFailsWithErrno(want)); + } + + if (port_ == 0) { + // We don't yet know what port we'll be using so we need to fetch it and + // remember it for future commands. + socklen_t addr_size = sizeof(addr); + ASSERT_THAT( + getsockname(socket_fd, reinterpret_cast(&addr), + &addr_size), + SyscallSucceeds()); + port_ = addr.sin_port; + } + } + + private: + SocketKind socket_factory_; + // devices maps from the device id in the test case to the name of the device. + std::unordered_map devices_; + // These are the tunnels that were created for the test and will be destroyed + // by the destructor. + vector> tunnels_; + // A list of all interface names before the test started. + std::unordered_set interface_names_; + // The next ethernet device to use when requested a device. + int next_unused_eth_ = 1; + // The port for all tests. Originally 0 (any) and later set to the port that + // all further commands will use. + in_port_t port_ = 0; + // sockets_to_close_ is a map from action index to the socket that was + // created. + std::unordered_map> + sockets_to_close_; + int next_socket_id_ = 0; +}; + +TEST_P(BindToDeviceSequenceTest, BindTwiceWithDeviceFails) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 3)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 3, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindToDevice) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 1)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 2)); +} + +TEST_P(BindToDeviceSequenceTest, BindToDeviceAndThenWithoutDevice) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindWithoutDevice) { + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ false)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindWithDevice) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 456, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 789, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindWithReuse) { + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, /* bind_to_device */ 0)); +} + +TEST_P(BindToDeviceSequenceTest, BindingWithReuseAndDevice) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 456)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 789)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 999, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, MixingReuseAndNotReuseByBindingToDevice) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 456, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 789, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 999, 0)); +} + +TEST_P(BindToDeviceSequenceTest, CannotBindTo0AfterMixingReuseAndNotReuse) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 456)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindAndRelease) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + int to_release; + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, 0, &to_release)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 345, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 789)); + // Release the bind to device 0 and try again. + ASSERT_NO_FATAL_FAILURE(ReleaseSocket(to_release)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 345)); +} + +TEST_P(BindToDeviceSequenceTest, BindTwiceWithReuseOnce) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); +} + +INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceSequenceTest, + ::testing::Values(IPv4UDPUnboundSocket(0), + IPv4TCPUnboundSocket(0))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device_util.cc b/test/syscalls/linux/socket_bind_to_device_util.cc new file mode 100644 index 000000000..f4ee775bd --- /dev/null +++ b/test/syscalls/linux/socket_bind_to_device_util.cc @@ -0,0 +1,75 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_bind_to_device_util.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +using std::string; + +PosixErrorOr> Tunnel::New(string tunnel_name) { + int fd; + RETURN_ERROR_IF_SYSCALL_FAIL(fd = open("/dev/net/tun", O_RDWR)); + + // Using `new` to access a non-public constructor. + auto new_tunnel = absl::WrapUnique(new Tunnel(fd)); + + ifreq ifr = {}; + ifr.ifr_flags = IFF_TUN; + strncpy(ifr.ifr_name, tunnel_name.c_str(), sizeof(ifr.ifr_name)); + + RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(fd, TUNSETIFF, &ifr)); + new_tunnel->name_ = ifr.ifr_name; + return new_tunnel; +} + +std::unordered_set GetInterfaceNames() { + struct if_nameindex* interfaces = if_nameindex(); + std::unordered_set names; + if (interfaces == nullptr) { + return names; + } + for (auto interface = interfaces; + interface->if_index != 0 || interface->if_name != nullptr; interface++) { + names.insert(interface->if_name); + } + if_freenameindex(interfaces); + return names; +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device_util.h b/test/syscalls/linux/socket_bind_to_device_util.h new file mode 100644 index 000000000..f941ccc86 --- /dev/null +++ b/test/syscalls/linux/socket_bind_to_device_util.h @@ -0,0 +1,67 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_ +#define GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +class Tunnel { + public: + static PosixErrorOr> New( + std::string tunnel_name = ""); + const std::string& GetName() const { return name_; } + + ~Tunnel() { + if (fd_ != -1) { + close(fd_); + } + } + + private: + Tunnel(int fd) : fd_(fd) {} + int fd_ = -1; + std::string name_; +}; + +std::unordered_set GetInterfaceNames(); + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_ diff --git a/test/syscalls/linux/uidgid.cc b/test/syscalls/linux/uidgid.cc index d48453a93..6218fbce1 100644 --- a/test/syscalls/linux/uidgid.cc +++ b/test/syscalls/linux/uidgid.cc @@ -25,6 +25,7 @@ #include "test/util/posix_error.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" +#include "test/util/uid_util.h" ABSL_FLAG(int32_t, scratch_uid1, 65534, "first scratch UID"); ABSL_FLAG(int32_t, scratch_uid2, 65533, "second scratch UID"); @@ -68,30 +69,6 @@ TEST(UidGidTest, Getgroups) { // here; see the setgroups test below. } -// If the caller's real/effective/saved user/group IDs are all 0, IsRoot returns -// true. Otherwise IsRoot logs an explanatory message and returns false. -PosixErrorOr IsRoot() { - uid_t ruid, euid, suid; - int rc = getresuid(&ruid, &euid, &suid); - MaybeSave(); - if (rc < 0) { - return PosixError(errno, "getresuid"); - } - if (ruid != 0 || euid != 0 || suid != 0) { - return false; - } - gid_t rgid, egid, sgid; - rc = getresgid(&rgid, &egid, &sgid); - MaybeSave(); - if (rc < 0) { - return PosixError(errno, "getresgid"); - } - if (rgid != 0 || egid != 0 || sgid != 0) { - return false; - } - return true; -} - // Checks that the calling process' real/effective/saved user IDs are // ruid/euid/suid respectively. PosixError CheckUIDs(uid_t ruid, uid_t euid, uid_t suid) { diff --git a/test/util/BUILD b/test/util/BUILD index 25ed9c944..5d2a9cc2c 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -324,3 +324,14 @@ cc_library( ":test_util", ], ) + +cc_library( + name = "uid_util", + testonly = 1, + srcs = ["uid_util.cc"], + hdrs = ["uid_util.h"], + deps = [ + ":posix_error", + ":save_util", + ], +) diff --git a/test/util/uid_util.cc b/test/util/uid_util.cc new file mode 100644 index 000000000..b131b4b99 --- /dev/null +++ b/test/util/uid_util.cc @@ -0,0 +1,44 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/util/posix_error.h" +#include "test/util/save_util.h" + +namespace gvisor { +namespace testing { + +PosixErrorOr IsRoot() { + uid_t ruid, euid, suid; + int rc = getresuid(&ruid, &euid, &suid); + MaybeSave(); + if (rc < 0) { + return PosixError(errno, "getresuid"); + } + if (ruid != 0 || euid != 0 || suid != 0) { + return false; + } + gid_t rgid, egid, sgid; + rc = getresgid(&rgid, &egid, &sgid); + MaybeSave(); + if (rc < 0) { + return PosixError(errno, "getresgid"); + } + if (rgid != 0 || egid != 0 || sgid != 0) { + return false; + } + return true; +} + +} // namespace testing +} // namespace gvisor diff --git a/test/util/uid_util.h b/test/util/uid_util.h new file mode 100644 index 000000000..2cd387fb0 --- /dev/null +++ b/test/util/uid_util.h @@ -0,0 +1,29 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_UID_UTIL_H_ +#define GVISOR_TEST_SYSCALLS_UID_UTIL_H_ + +#include "test/util/posix_error.h" + +namespace gvisor { +namespace testing { + +// Returns true if the caller's real/effective/saved user/group IDs are all 0. +PosixErrorOr IsRoot(); + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_UID_UTIL_H_ -- cgit v1.2.3