diff options
144 files changed, 1675 insertions, 1558 deletions
@@ -107,7 +107,7 @@ analyzers: - pkg/sentry/socket/unix/transport/connectioned.go # unsupported usage. - pkg/sentry/vfs/dentry.go # unsupported usage. - pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go # unsupported usage. - - pkg/tcpip/stack/conntrack.go # unsupported usage. + - pkg/tcpip/stack/conntrack.go # unsupported usage. - pkg/tcpip/transport/packet/endpoint_state.go # unsupported usage. - pkg/tcpip/transport/raw/endpoint_state.go # unsupported usage. - pkg/tcpip/transport/icmp/endpoint.go # unsupported usage. @@ -177,6 +177,7 @@ analyzers: - pkg/sentry/platform/kvm/bluepill_unsafe.go # Special case. - pkg/sentry/platform/kvm/machine_unsafe.go # Special case. - pkg/sentry/platform/safecopy/safecopy_unsafe.go # Special case. + - pkg/sentry/usage/memory_unsafe.go # Special case. - pkg/sentry/vfs/mount_unsafe.go # Special case. - pkg/state/decode_unsafe.go # Special case. unusedresult: diff --git a/pkg/errors/linuxerr/BUILD b/pkg/errors/linuxerr/BUILD index 8afc9688c..201727780 100644 --- a/pkg/errors/linuxerr/BUILD +++ b/pkg/errors/linuxerr/BUILD @@ -9,6 +9,7 @@ go_library( deps = [ "//pkg/abi/linux/errno", "//pkg/errors", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/errors/linuxerr/linuxerr.go b/pkg/errors/linuxerr/linuxerr.go index bbdcdecd0..9246f2e89 100644 --- a/pkg/errors/linuxerr/linuxerr.go +++ b/pkg/errors/linuxerr/linuxerr.go @@ -20,6 +20,7 @@ package linuxerr import ( "fmt" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux/errno" "gvisor.dev/gvisor/pkg/errors" ) @@ -325,3 +326,22 @@ func ErrorFromErrno(e errno.Errno) *errors.Error { } panic(fmt.Sprintf("invalid error requested with errno: %d", e)) } + +// Equals compars a linuxerr to a given error +// TODO(b/34162363): Remove when syserror is removed. +func Equals(e *errors.Error, err error) bool { + if err == nil { + return e == NOERROR || e == nil + } + if e == nil { + return err == NOERROR || err == unix.Errno(0) + } + + switch err.(type) { + case *errors.Error: + return e == err + case unix.Errno, error: + return unix.Errno(e.Errno()) == err + } + return false +} diff --git a/pkg/errors/linuxerr/linuxerr_test.go b/pkg/errors/linuxerr/linuxerr_test.go index a81dd9560..62743c338 100644 --- a/pkg/errors/linuxerr/linuxerr_test.go +++ b/pkg/errors/linuxerr/linuxerr_test.go @@ -16,6 +16,8 @@ package syserror_test import ( "errors" + "io" + "io/fs" "syscall" "testing" @@ -243,3 +245,62 @@ func TestSyscallErrnoToErrors(t *testing.T) { }) } } + +// TestEqualsMethod tests that the Equals method correctly compares syerror, +// unix.Errno and linuxerr. +// TODO (b/34162363): Remove this. +func TestEqualsMethod(t *testing.T) { + for _, tc := range []struct { + name string + linuxErr []*gErrors.Error + err []error + equal bool + }{ + { + name: "compare nil", + linuxErr: []*gErrors.Error{nil, linuxerr.NOERROR}, + err: []error{nil, linuxerr.NOERROR, unix.Errno(0)}, + equal: true, + }, + { + name: "linuxerr nil error not", + linuxErr: []*gErrors.Error{nil, linuxerr.NOERROR}, + err: []error{unix.Errno(1), linuxerr.EPERM, syserror.EACCES}, + equal: false, + }, + { + name: "linuxerr not nil error nil", + linuxErr: []*gErrors.Error{linuxerr.ENOENT}, + err: []error{nil, unix.Errno(0), linuxerr.NOERROR}, + equal: false, + }, + { + name: "equal errors", + linuxErr: []*gErrors.Error{linuxerr.ESRCH}, + err: []error{linuxerr.ESRCH, syserror.ESRCH, unix.Errno(linuxerr.ESRCH.Errno())}, + equal: true, + }, + { + name: "unequal errors", + linuxErr: []*gErrors.Error{linuxerr.ENOENT}, + err: []error{linuxerr.ESRCH, syserror.ESRCH, unix.Errno(linuxerr.ESRCH.Errno())}, + equal: false, + }, + { + name: "other error", + linuxErr: []*gErrors.Error{nil, linuxerr.NOERROR, linuxerr.E2BIG, linuxerr.EINVAL}, + err: []error{fs.ErrInvalid, io.EOF}, + equal: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + for _, le := range tc.linuxErr { + for _, e := range tc.err { + if linuxerr.Equals(le, e) != tc.equal { + t.Fatalf("Expected %t from Equals method for linuxerr: %s %T and error: %s %T", tc.equal, le, le, e, e) + } + } + } + }) + } +} diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD index 9730b88c1..c810c7946 100644 --- a/pkg/flipcall/BUILD +++ b/pkg/flipcall/BUILD @@ -10,9 +10,7 @@ go_library( "flipcall_unsafe.go", "futex_linux.go", "io.go", - "packet_window_allocator.go", - "packet_window_mmap_amd64.go", - "packet_window_mmap_arm64.go", + "packet_window.go", ], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go index 8d8309a73..f0e4ff487 100644 --- a/pkg/flipcall/flipcall.go +++ b/pkg/flipcall/flipcall.go @@ -22,6 +22,7 @@ import ( "sync/atomic" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/memutil" ) // An Endpoint provides the ability to synchronously transfer data and control @@ -96,9 +97,9 @@ func (ep *Endpoint) Init(side EndpointSide, pwd PacketWindowDescriptor, opts ... if pwd.Length > math.MaxUint32 { return fmt.Errorf("packet window size (%d) exceeds maximum (%d)", pwd.Length, math.MaxUint32) } - m, e := packetWindowMmap(pwd) - if e != 0 { - return fmt.Errorf("failed to mmap packet window: %v", e) + m, err := memutil.MapFile(0, uintptr(pwd.Length), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset)) + if err != nil { + return fmt.Errorf("failed to mmap packet window: %v", err) } ep.packet = m ep.dataCap = uint32(pwd.Length) - uint32(PacketHeaderBytes) diff --git a/pkg/flipcall/packet_window_allocator.go b/pkg/flipcall/packet_window.go index 9122c97b7..9122c97b7 100644 --- a/pkg/flipcall/packet_window_allocator.go +++ b/pkg/flipcall/packet_window.go diff --git a/pkg/memutil/BUILD b/pkg/memutil/BUILD index 9d07d98b4..bea595286 100644 --- a/pkg/memutil/BUILD +++ b/pkg/memutil/BUILD @@ -4,7 +4,11 @@ package(licenses = ["notice"]) go_library( name = "memutil", - srcs = ["memutil_unsafe.go"], + srcs = [ + "memfd_linux_unsafe.go", + "memutil.go", + "mmap.go", + ], visibility = ["//visibility:public"], deps = ["@org_golang_x_sys//unix:go_default_library"], ) diff --git a/pkg/memutil/memutil_unsafe.go b/pkg/memutil/memfd_linux_unsafe.go index 6676d1ce3..504382213 100644 --- a/pkg/memutil/memutil_unsafe.go +++ b/pkg/memutil/memfd_linux_unsafe.go @@ -14,7 +14,6 @@ // +build linux -// Package memutil provides a wrapper for the memfd_create() system call. package memutil import ( diff --git a/pkg/flipcall/packet_window_mmap_arm64.go b/pkg/memutil/memutil.go index 87ad1a4a1..3185882fd 100644 --- a/pkg/flipcall/packet_window_mmap_arm64.go +++ b/pkg/memutil/memutil.go @@ -1,4 +1,4 @@ -// Copyright 2020 The gVisor Authors. +// Copyright 2018 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,14 +12,5 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build arm64 - -package flipcall - -import "golang.org/x/sys/unix" - -// Return a memory mapping of the pwd in memory that can be shared outside the sandbox. -func packetWindowMmap(pwd PacketWindowDescriptor) (uintptr, unix.Errno) { - m, _, err := unix.RawSyscall6(unix.SYS_MMAP, 0, uintptr(pwd.Length), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset)) - return m, err -} +// Package memutil provides utilities for working with shared memory files. +package memutil diff --git a/pkg/flipcall/packet_window_mmap_amd64.go b/pkg/memutil/mmap.go index ced587a2a..7c939293f 100644 --- a/pkg/flipcall/packet_window_mmap_amd64.go +++ b/pkg/memutil/mmap.go @@ -12,12 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -package flipcall +package memutil -import "golang.org/x/sys/unix" +import ( + "golang.org/x/sys/unix" +) -// Return a memory mapping of the pwd in memory that can be shared outside the sandbox. -func packetWindowMmap(pwd PacketWindowDescriptor) (uintptr, unix.Errno) { - m, _, err := unix.RawSyscall6(unix.SYS_MMAP, 0, uintptr(pwd.Length), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset)) - return m, err +// MapFile returns a memory mapping configured by the given options as per +// mmap(2). +func MapFile(addr, len, prot, flags, fd, offset uintptr) (uintptr, error) { + m, _, e := unix.RawSyscall6(unix.SYS_MMAP, addr, len, prot, flags, fd, offset) + if e != 0 { + return 0, e + } + return m, nil } diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index 0dc100f9b..74adbfa55 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -48,6 +48,7 @@ go_library( "//pkg/abi/linux", "//pkg/amutex", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/log", "//pkg/p9", @@ -110,6 +111,7 @@ go_test( deps = [ ":fs", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/ramfs", "//pkg/sentry/fs/tmpfs", diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go index 5aa668873..ae282d14e 100644 --- a/pkg/sentry/fs/copy_up.go +++ b/pkg/sentry/fs/copy_up.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -410,7 +411,7 @@ func copyAttributesLocked(ctx context.Context, upper *Inode, lower *Inode) error return err } lowerXattr, err := lower.ListXattr(ctx, linux.XATTR_SIZE_MAX) - if err != nil && err != syserror.EOPNOTSUPP { + if err != nil && !linuxerr.Equals(linuxerr.EOPNOTSUPP, err) { return err } diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index 9d5d40954..e45749be6 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -22,6 +22,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" @@ -1439,7 +1440,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string // replaced is the dirent that is being overwritten by rename. replaced, err := newParent.walk(ctx, root, newName, false /* may unlock */) if err != nil { - if err != syserror.ENOENT { + if !linuxerr.Equals(linuxerr.ENOENT, err) { return err } diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD index 2120f2bad..7fc53ed22 100644 --- a/pkg/sentry/fs/fdpipe/BUILD +++ b/pkg/sentry/fs/fdpipe/BUILD @@ -13,6 +13,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fd", "//pkg/fdnotifier", "//pkg/log", diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go index 757b7d511..f8a29816b 100644 --- a/pkg/sentry/fs/fdpipe/pipe.go +++ b/pkg/sentry/fs/fdpipe/pipe.go @@ -20,6 +20,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/log" @@ -158,7 +159,7 @@ func (p *pipeOperations) Write(ctx context.Context, file *fs.File, src usermem.I // isBlockError unwraps os errors and checks if they are caused by EAGAIN or // EWOULDBLOCK. This is so they can be transformed into syserror.ErrWouldBlock. func isBlockError(err error) bool { - if err == syserror.EAGAIN || err == syserror.EWOULDBLOCK { + if linuxerr.Equals(linuxerr.EAGAIN, err) || linuxerr.Equals(linuxerr.EWOULDBLOCK, err) { return true } if pe, ok := err.(*os.PathError); ok { diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index 94cb05246..c08301d19 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -26,6 +26,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fd", "//pkg/hostarch", "//pkg/log", diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go index 940838a44..1a6f353d0 100644 --- a/pkg/sentry/fs/gofer/path.go +++ b/pkg/sentry/fs/gofer/path.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/device" @@ -66,7 +67,7 @@ func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string // Get a p9.File for name. qids, newFile, mask, p9attr, err := i.fileState.file.walkGetAttr(ctx, []string{name}) if err != nil { - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { if cp.cacheNegativeDirents() { // Return a negative Dirent. It will stay cached until something // is created over it. @@ -298,7 +299,7 @@ func (i *inodeOperations) CreateFifo(ctx context.Context, dir *fs.Inode, name st // N.B. FIFOs use major/minor numbers 0. if _, err := i.fileState.file.mknod(ctx, name, mode, 0, 0, p9.UID(owner.UID), p9.GID(owner.GID)); err != nil { - if i.session().overrides == nil || err != syserror.EPERM { + if i.session().overrides == nil || !linuxerr.Equals(linuxerr.EPERM, err) { return err } // If gofer doesn't support mknod, check if we can create an internal fifo. diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index 3c45f6cc5..52de3875d 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -28,6 +28,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fd", "//pkg/fdnotifier", "//pkg/iovec", diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 46a2dc47d..225244868 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -21,6 +21,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/refs" @@ -213,7 +214,7 @@ func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMess // block (and only for stream sockets). err = syserror.EAGAIN } - if n > 0 && err != syserror.EAGAIN { + if n > 0 && !linuxerr.Equals(linuxerr.EAGAIN, err) { // The caller may need to block to send more data, but // otherwise there isn't anything that can be done about an // error with a partial write. diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go index 1183727ab..77613bfd5 100644 --- a/pkg/sentry/fs/host/tty.go +++ b/pkg/sentry/fs/host/tty.go @@ -17,6 +17,7 @@ package host import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -191,7 +192,7 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO if err := t.checkChange(ctx, linux.SIGTTOU); err != nil { // drivers/tty/tty_io.c:tiocspgrp() converts -EIO from // tty_check_change() to -ENOTTY. - if err == syserror.EIO { + if linuxerr.Equals(linuxerr.EIO, err) { return 0, syserror.ENOTTY } return 0, err diff --git a/pkg/sentry/fs/host/util.go b/pkg/sentry/fs/host/util.go index ab74724a3..e7db79189 100644 --- a/pkg/sentry/fs/host/util.go +++ b/pkg/sentry/fs/host/util.go @@ -19,12 +19,12 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/syserror" ) func nodeType(s *unix.Stat_t) fs.InodeType { @@ -98,7 +98,7 @@ type dirInfo struct { // isBlockError unwraps os errors and checks if they are caused by EAGAIN or // EWOULDBLOCK. This is so they can be transformed into syserror.ErrWouldBlock. func isBlockError(err error) bool { - if err == syserror.EAGAIN || err == syserror.EWOULDBLOCK { + if linuxerr.Equals(linuxerr.EAGAIN, err) || linuxerr.Equals(linuxerr.EWOULDBLOCK, err) { return true } if pe, ok := err.(*os.PathError); ok { diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index e97afc626..bd1125dcc 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/syserror" @@ -71,7 +72,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name // A file could have been created over a whiteout, so we need to // check if something exists in the upper file system first. child, err := parent.upper.Lookup(ctx, name) - if err != nil && err != syserror.ENOENT { + if err != nil && !linuxerr.Equals(linuxerr.ENOENT, err) { // We encountered an error that an overlay cannot handle, // we must propagate it to the caller. parent.copyMu.RUnlock() @@ -125,7 +126,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name // Check the lower file system. child, err := parent.lower.Lookup(ctx, name) // Same song and dance as above. - if err != nil && err != syserror.ENOENT { + if err != nil && !linuxerr.Equals(linuxerr.ENOENT, err) { // Don't leak resources. if upperInode != nil { upperInode.DecRef(ctx) @@ -396,7 +397,7 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena // newName has been removed out from under us. That's fine; // filesystems where that can happen must handle stale // 'replaced'. - if err != nil && err != syserror.ENOENT { + if err != nil && !linuxerr.Equals(linuxerr.ENOENT, err) { return err } if err == nil { diff --git a/pkg/sentry/fs/inode_overlay_test.go b/pkg/sentry/fs/inode_overlay_test.go index aa9851b26..cc5ffa6f1 100644 --- a/pkg/sentry/fs/inode_overlay_test.go +++ b/pkg/sentry/fs/inode_overlay_test.go @@ -18,6 +18,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" @@ -191,11 +192,11 @@ func TestLookup(t *testing.T) { } { t.Run(test.desc, func(t *testing.T) { dirent, err := test.dir.Lookup(ctx, test.name) - if test.found && (err == syserror.ENOENT || dirent.IsNegative()) { + if test.found && (linuxerr.Equals(linuxerr.ENOENT, err) || dirent.IsNegative()) { t.Fatalf("lookup %q expected to find positive dirent, got dirent %v err %v", test.name, dirent, err) } if !test.found { - if err != syserror.ENOENT && !dirent.IsNegative() { + if !linuxerr.Equals(linuxerr.ENOENT, err) && !dirent.IsNegative() { t.Errorf("lookup %q expected to return ENOENT or negative dirent, got dirent %v err %v", test.name, dirent, err) } // Nothing more to check. diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index 7af7e0b45..e6d74b949 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -30,6 +30,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/log", "//pkg/sentry/fs", diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 91c35eea9..187e9a921 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -34,7 +35,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -291,7 +291,7 @@ func (n *netSnmp) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s continue } if err := n.s.Statistics(stat, line.prefix); err != nil { - if err == syserror.EOPNOTSUPP { + if linuxerr.Equals(linuxerr.EOPNOTSUPP, err) { log.Infof("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err) } else { log.Warningf("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err) diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go index 33da82868..ca9f645f6 100644 --- a/pkg/sentry/fs/splice.go +++ b/pkg/sentry/fs/splice.go @@ -19,6 +19,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/syserror" ) @@ -139,7 +140,7 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, // Attempt to do a WriteTo; this is likely the most efficient. n, err := src.FileOperations.WriteTo(ctx, src, w, opts.Length, opts.Dup) - if n == 0 && err == syserror.ENOSYS && !opts.Dup { + if n == 0 && linuxerr.Equals(linuxerr.ENOSYS, err) && !opts.Dup { // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also be // more efficient than a copy if buffers are cached or readily // available. (It's unlikely that they can actually be donated). @@ -151,7 +152,7 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, // if we block at some point, we could lose data. If the source is // not a pipe then reading is not destructive; if the destination // is a regular file, then it is guaranteed not to block writing. - if n == 0 && err == syserror.ENOSYS && !opts.Dup && (!dstPipe || !srcPipe) { + if n == 0 && linuxerr.Equals(linuxerr.ENOSYS, err) && !opts.Dup && (!dstPipe || !srcPipe) { // Fallback to an in-kernel copy. n, err = io.Copy(w, &io.LimitedReader{ R: r, diff --git a/pkg/sentry/fs/user/BUILD b/pkg/sentry/fs/user/BUILD index 66e949c95..4acc73ee0 100644 --- a/pkg/sentry/fs/user/BUILD +++ b/pkg/sentry/fs/user/BUILD @@ -12,6 +12,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fspath", "//pkg/log", "//pkg/sentry/fs", diff --git a/pkg/sentry/fs/user/path.go b/pkg/sentry/fs/user/path.go index 124bc95ed..f6eaab2bd 100644 --- a/pkg/sentry/fs/user/path.go +++ b/pkg/sentry/fs/user/path.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -93,7 +94,7 @@ func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name s binPath := path.Join(p, name) traversals := uint(linux.MaxSymlinkTraversals) d, err := mns.FindInode(ctx, root, nil, binPath, &traversals) - if err == syserror.ENOENT || err == syserror.EACCES { + if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.EACCES, err) { // Didn't find it here. continue } @@ -142,7 +143,7 @@ func resolveVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNam Flags: linux.O_RDONLY, } dentry, err := root.Mount().Filesystem().VirtualFilesystem().OpenAt(ctx, creds, pop, opts) - if err == syserror.ENOENT || err == syserror.EACCES { + if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.EACCES, err) { // Didn't find it here. continue } diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index 2dbc6bfd5..1060b5301 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -88,13 +88,13 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fspath", "//pkg/marshal/primitive", "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/ext/disklayout", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", - "//pkg/syserror", "//pkg/test/testutil", "//pkg/usermem", "@com_github_google_go_cmp//cmp:go_default_library", diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go index d9fd4590c..db712e71f 100644 --- a/pkg/sentry/fsimpl/ext/ext_test.go +++ b/pkg/sentry/fsimpl/ext/ext_test.go @@ -26,12 +26,12 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/test/testutil" "gvisor.dev/gvisor/pkg/usermem" ) @@ -173,7 +173,7 @@ func TestSeek(t *testing.T) { } // EINVAL should be returned if the resulting offset is negative. - if _, err := fd.Seek(ctx, -1, linux.SEEK_SET); err != syserror.EINVAL { + if _, err := fd.Seek(ctx, -1, linux.SEEK_SET); !linuxerr.Equals(linuxerr.EINVAL, err) { t.Errorf("expected error EINVAL but got %v", err) } @@ -187,7 +187,7 @@ func TestSeek(t *testing.T) { } // EINVAL should be returned if the resulting offset is negative. - if _, err := fd.Seek(ctx, -(size + 2), linux.SEEK_CUR); err != syserror.EINVAL { + if _, err := fd.Seek(ctx, -(size + 2), linux.SEEK_CUR); !linuxerr.Equals(linuxerr.EINVAL, err) { t.Errorf("expected error EINVAL but got %v", err) } @@ -204,7 +204,7 @@ func TestSeek(t *testing.T) { } // EINVAL should be returned if the resulting offset is negative. - if _, err := fd.Seek(ctx, -(size + 1), linux.SEEK_END); err != syserror.EINVAL { + if _, err := fd.Seek(ctx, -(size + 1), linux.SEEK_END); !linuxerr.Equals(linuxerr.EINVAL, err) { t.Errorf("expected error EINVAL but got %v", err) } } diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD index 3a4777fbe..871df5984 100644 --- a/pkg/sentry/fsimpl/fuse/BUILD +++ b/pkg/sentry/fsimpl/fuse/BUILD @@ -46,6 +46,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/log", "//pkg/marshal", @@ -76,6 +77,7 @@ go_test( library = ":fuse", deps = [ "//pkg/abi/linux", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/marshal", "//pkg/sentry/fsimpl/testutil", diff --git a/pkg/sentry/fsimpl/fuse/connection_test.go b/pkg/sentry/fsimpl/fuse/connection_test.go index 78ea6a31e..1fddd858e 100644 --- a/pkg/sentry/fsimpl/fuse/connection_test.go +++ b/pkg/sentry/fsimpl/fuse/connection_test.go @@ -19,9 +19,9 @@ import ( "testing" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/syserror" ) // TestConnectionInitBlock tests if initialization @@ -104,7 +104,7 @@ func TestConnectionAbort(t *testing.T) { // After abort, Call() should return directly with ENOTCONN. req := conn.NewRequest(creds, 0, 0, 0, testObj) _, err = conn.Call(task, req) - if err != syserror.ENOTCONN { + if !linuxerr.Equals(linuxerr.ENOTCONN, err) { t.Fatalf("Incorrect error code received for Call() after connection aborted") } diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index 167c899e2..47794810c 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" @@ -440,7 +441,7 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentr if err != nil { return nil, err } - if err := res.Error(); err == syserror.ENOSYS && !isDir { + if err := res.Error(); linuxerr.Equals(linuxerr.ENOSYS, err) && !isDir { i.fs.conn.noOpen = true } else if err != nil { return nil, err diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index 368272f12..752060044 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -49,6 +49,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fd", "//pkg/fdnotifier", "//pkg/fspath", diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index eb09d54c3..af2b773c3 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/fsimpl/host" @@ -255,7 +256,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name) if err != nil { - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { parent.cacheNegativeLookupLocked(name) } return nil, err @@ -382,7 +383,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir return syserror.EEXIST } checkExistence := func() error { - if child, err := fs.getChildLocked(ctx, parent, name, &ds); err != nil && err != syserror.ENOENT { + if child, err := fs.getChildLocked(ctx, parent, name, &ds); err != nil && !linuxerr.Equals(linuxerr.ENOENT, err) { return err } else if child != nil { return syserror.EEXIST @@ -715,7 +716,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v mode |= linux.S_ISGID } if _, err := parent.file.mkdir(ctx, name, p9.FileMode(mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid)); err != nil { - if !opts.ForSyntheticMountpoint || err == syserror.EEXIST { + if !opts.ForSyntheticMountpoint || linuxerr.Equals(linuxerr.EEXIST, err) { return err } ctx.Infof("Failed to create remote directory %q: %v; falling back to synthetic directory", name, err) @@ -752,7 +753,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, ds **[]*dentry) error { creds := rp.Credentials() _, err := parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)) - if err != syserror.EPERM { + if !linuxerr.Equals(linuxerr.EPERM, err) { return err } @@ -765,7 +766,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v case err == nil: // Step succeeded, another file exists. return syserror.EEXIST - case err != syserror.ENOENT: + case !linuxerr.Equals(linuxerr.ENOENT, err): // Unexpected error. return err } @@ -862,7 +863,7 @@ afterTrailingSymlink: // Determine whether or not we need to create a file. parent.dirMu.Lock() child, _, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) - if err == syserror.ENOENT && mayCreate { + if linuxerr.Equals(linuxerr.ENOENT, err) && mayCreate { if parent.isSynthetic() { parent.dirMu.Unlock() return nil, syserror.EPERM @@ -1033,7 +1034,7 @@ func (d *dentry) openSpecialFile(ctx context.Context, mnt *vfs.Mount, opts *vfs. retry: h, err := openHandle(ctx, d.file, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0) if err != nil { - if isBlockingOpenOfNamedPipe && ats == vfs.MayWrite && err == syserror.ENXIO { + if isBlockingOpenOfNamedPipe && ats == vfs.MayWrite && linuxerr.Equals(linuxerr.ENXIO, err) { // An attempt to open a named pipe with O_WRONLY|O_NONBLOCK fails // with ENXIO if opening the same named pipe with O_WRONLY would // block because there are no readers of the pipe. @@ -1284,7 +1285,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa return syserror.ENOENT } replaced, err := fs.getChildLocked(ctx, newParent, newName, &ds) - if err != nil && err != syserror.ENOENT { + if err != nil && !linuxerr.Equals(linuxerr.ENOENT, err) { return err } var replacedVFSD *vfs.Dentry diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index cf69e1b7a..496e31e34 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -46,6 +46,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" @@ -1763,7 +1764,7 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool openReadable := !d.readFile.isNil() || read openWritable := !d.writeFile.isNil() || write h, err := openHandle(ctx, d.file, openReadable, openWritable, trunc) - if err == syserror.EACCES && (openReadable != read || openWritable != write) { + if linuxerr.Equals(linuxerr.EACCES, err) && (openReadable != read || openWritable != write) { // It may not be possible to use a single handle for both // reading and writing, since permissions on the file may have // changed to e.g. disallow reading after previously being diff --git a/pkg/sentry/fsimpl/gofer/host_named_pipe.go b/pkg/sentry/fsimpl/gofer/host_named_pipe.go index c7bf10007..398288ee3 100644 --- a/pkg/sentry/fsimpl/gofer/host_named_pipe.go +++ b/pkg/sentry/fsimpl/gofer/host_named_pipe.go @@ -21,6 +21,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/syserror" ) @@ -78,7 +79,7 @@ func nonblockingPipeHasWriter(fd int32) (bool, error) { defer tempPipeMu.Unlock() // Copy 1 byte from fd into the temporary pipe. n, err := unix.Tee(int(fd), tempPipeWriteFD, 1, unix.SPLICE_F_NONBLOCK) - if err == syserror.EAGAIN { + if linuxerr.Equals(linuxerr.EAGAIN, err) { // The pipe represented by fd is empty, but has a writer. return true, nil } diff --git a/pkg/sentry/fsimpl/gofer/save_restore.go b/pkg/sentry/fsimpl/gofer/save_restore.go index 83e841a51..e67422a2f 100644 --- a/pkg/sentry/fsimpl/gofer/save_restore.go +++ b/pkg/sentry/fsimpl/gofer/save_restore.go @@ -21,13 +21,13 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) type saveRestoreContextID int @@ -92,7 +92,7 @@ func (fd *specialFileFD) savePipeData(ctx context.Context) error { fd.buf = append(fd.buf, buf[:n]...) } if err != nil { - if err == io.EOF || err == syserror.EAGAIN { + if err == io.EOF || linuxerr.Equals(linuxerr.EAGAIN, err) { break } return err diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index c12444b7e..3d7b5506e 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -20,6 +20,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/p9" @@ -228,7 +229,7 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs // Just buffer the read instead. buf := make([]byte, dst.NumBytes()) n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset)) - if err == syserror.EAGAIN { + if linuxerr.Equals(linuxerr.EAGAIN, err) { err = syserror.ErrWouldBlock } if n == 0 { @@ -316,7 +317,7 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off return 0, offset, copyErr } n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:copied])), uint64(offset)) - if err == syserror.EAGAIN { + if linuxerr.Equals(linuxerr.EAGAIN, err) { err = syserror.ErrWouldBlock } // Update offset if the offset is valid. diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index b94dfeb7f..f2f83796c 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -45,6 +45,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fdnotifier", "//pkg/fspath", "//pkg/hostarch", diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index a81f550b1..2dbfbdecf 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -24,6 +24,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/hostarch" @@ -109,7 +110,7 @@ type inode struct { func newInode(ctx context.Context, fs *filesystem, hostFD int, savable bool, fileType linux.FileMode, isTTY bool) (*inode, error) { // Determine if hostFD is seekable. _, err := unix.Seek(hostFD, 0, linux.SEEK_CUR) - seekable := err != syserror.ESPIPE + seekable := !linuxerr.Equals(linuxerr.ESPIPE, err) // We expect regular files to be seekable, as this is required for them to // be memory-mappable. if !seekable && fileType == unix.S_IFREG { @@ -301,7 +302,7 @@ func (i *inode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOp mask := opts.Mask & linux.STATX_ALL var s unix.Statx_t err := unix.Statx(i.hostFD, "", int(unix.AT_EMPTY_PATH|opts.Sync), int(mask), &s) - if err == syserror.ENOSYS { + if linuxerr.Equals(linuxerr.ENOSYS, err) { // Fallback to fstat(2), if statx(2) is not supported on the host. // // TODO(b/151263641): Remove fallback. diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go index ca85f5601..8cce36212 100644 --- a/pkg/sentry/fsimpl/host/socket.go +++ b/pkg/sentry/fsimpl/host/socket.go @@ -21,6 +21,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/socket/control" @@ -160,7 +161,7 @@ func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMess // block (and only for stream sockets). err = syserror.EAGAIN } - if n > 0 && err != syserror.EAGAIN { + if n > 0 && !linuxerr.Equals(linuxerr.EAGAIN, err) { // The caller may need to block to send more data, but // otherwise there isn't anything that can be done about an // error with a partial write. diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go index 0f9e20a84..2cf360065 100644 --- a/pkg/sentry/fsimpl/host/tty.go +++ b/pkg/sentry/fsimpl/host/tty.go @@ -17,6 +17,7 @@ package host import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -211,7 +212,7 @@ func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch if err := t.checkChange(ctx, linux.SIGTTOU); err != nil { // drivers/tty/tty_io.c:tiocspgrp() converts -EIO from tty_check_change() // to -ENOTTY. - if err == syserror.EIO { + if linuxerr.Equals(linuxerr.EIO, err) { return 0, syserror.ENOTTY } return 0, err diff --git a/pkg/sentry/fsimpl/host/util.go b/pkg/sentry/fsimpl/host/util.go index 63b465859..95d7ebe2e 100644 --- a/pkg/sentry/fsimpl/host/util.go +++ b/pkg/sentry/fsimpl/host/util.go @@ -17,7 +17,7 @@ package host import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/errors/linuxerr" ) func toTimespec(ts linux.StatxTimestamp, omit bool) unix.Timespec { @@ -44,5 +44,5 @@ func timespecToStatxTimestamp(ts unix.Timespec) linux.StatxTimestamp { // isBlockError checks if an error is EAGAIN or EWOULDBLOCK. // If so, they can be transformed into syserror.ErrWouldBlock. func isBlockError(err error) bool { - return err == syserror.EAGAIN || err == syserror.EWOULDBLOCK + return linuxerr.Equals(linuxerr.EAGAIN, err) || linuxerr.Equals(linuxerr.EWOULDBLOCK, err) } diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index b7d13cced..d53937db6 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -104,6 +104,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fspath", "//pkg/hostarch", "//pkg/log", @@ -135,6 +136,7 @@ go_test( ":kernfs", "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/log", "//pkg/refs", "//pkg/refsvfs2", diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index 8fac53c60..20319ab76 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" @@ -411,7 +412,7 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v defer rp.Mount().EndWrite() childI, err := parent.inode.NewDir(ctx, pc, opts) if err != nil { - if !opts.ForSyntheticMountpoint || err == syserror.EEXIST { + if !opts.ForSyntheticMountpoint || linuxerr.Equals(linuxerr.EEXIST, err) { return err } childI = newSyntheticDirectory(ctx, rp.Credentials(), opts.Mode) @@ -546,7 +547,7 @@ afterTrailingSymlink: } // Determine whether or not we need to create a file. child, err := fs.stepExistingLocked(ctx, rp, parent, false /* mayFollowSymlinks */) - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { // Already checked for searchability above; now check for writability. if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil { return nil, err @@ -684,10 +685,12 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } return syserror.EBUSY } - switch err := checkCreateLocked(ctx, rp.Credentials(), newName, dstDir); err { - case nil: + + err = checkCreateLocked(ctx, rp.Credentials(), newName, dstDir) + switch { + case err == nil: // Ok, continue with rename as replacement. - case syserror.EEXIST: + case linuxerr.Equals(linuxerr.EEXIST, err): if noReplace { // Won't overwrite existing node since RENAME_NOREPLACE was requested. return syserror.EEXIST diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index 1cd3137e6..de046ce1f 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -22,6 +22,7 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" @@ -318,10 +319,10 @@ func TestDirFDReadWrite(t *testing.T) { defer fd.DecRef(sys.Ctx) // Read/Write should fail for directory FDs. - if _, err := fd.Read(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.ReadOptions{}); err != syserror.EISDIR { + if _, err := fd.Read(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.ReadOptions{}); !linuxerr.Equals(linuxerr.EISDIR, err) { t.Fatalf("Read for directory FD failed with unexpected error: %v", err) } - if _, err := fd.Write(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.WriteOptions{}); err != syserror.EBADF { + if _, err := fd.Write(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.WriteOptions{}); !linuxerr.Equals(linuxerr.EBADF, err) { t.Fatalf("Write for directory FD failed with unexpected error: %v", err) } } diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD index 5504476c8..ed730e215 100644 --- a/pkg/sentry/fsimpl/overlay/BUILD +++ b/pkg/sentry/fsimpl/overlay/BUILD @@ -29,6 +29,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fspath", "//pkg/hostarch", "//pkg/log", diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go index 45aa5a494..8fd51e9d0 100644 --- a/pkg/sentry/fsimpl/overlay/copy_up.go +++ b/pkg/sentry/fsimpl/overlay/copy_up.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -349,7 +350,7 @@ func (d *dentry) copyXattrsLocked(ctx context.Context) error { lowerXattrs, err := vfsObj.ListXattrAt(ctx, d.fs.creds, lowerPop, 0) if err != nil { - if err == syserror.EOPNOTSUPP { + if linuxerr.Equals(linuxerr.EOPNOTSUPP, err) { // There are no guarantees as to the contents of lowerXattrs. return nil } diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index 6b6fa0bd5..81745bccd 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -218,7 +219,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str Start: parentVD, Path: childPath, }, &vfs.GetDentryOptions{}) - if err == syserror.ENOENT || err == syserror.ENAMETOOLONG { + if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.ENAMETOOLONG, err) { // The file doesn't exist on this layer. Proceed to the next one. return true } @@ -352,7 +353,7 @@ func (fs *filesystem) lookupLayerLocked(ctx context.Context, parent *dentry, nam }, &vfs.StatOptions{ Mask: linux.STATX_TYPE, }) - if err == syserror.ENOENT || err == syserror.ENAMETOOLONG { + if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.ENAMETOOLONG, err) { // The file doesn't exist on this layer. Proceed to the next // one. return true @@ -811,7 +812,7 @@ afterTrailingSymlink: // Determine whether or not we need to create a file. parent.dirMu.Lock() child, topLookupLayer, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) - if err == syserror.ENOENT && mayCreate { + if linuxerr.Equals(linuxerr.ENOENT, err) && mayCreate { fd, err := fs.createAndOpenLocked(ctx, rp, parent, &opts, &ds, topLookupLayer == lookupLayerUpperWhiteout) parent.dirMu.Unlock() return fd, err @@ -1094,7 +1095,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa whiteouts map[string]bool ) replaced, replacedLayer, err = fs.getChildLocked(ctx, newParent, newName, &ds) - if err != nil && err != syserror.ENOENT { + if err != nil && !linuxerr.Equals(linuxerr.ENOENT, err) { return err } if replaced != nil { @@ -1177,7 +1178,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa Root: replaced.upperVD, Start: replaced.upperVD, Path: fspath.Parse(whiteoutName), - }); err != nil && err != syserror.EEXIST { + }); err != nil && !linuxerr.Equals(linuxerr.EEXIST, err) { panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to recreate deleted whiteout after RenameAt failure: %v", err)) } } @@ -1344,7 +1345,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error Root: child.upperVD, Start: child.upperVD, Path: fspath.Parse(whiteoutName), - }); err != nil && err != syserror.EEXIST { + }); err != nil && !linuxerr.Equals(linuxerr.EEXIST, err) { panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to recreate deleted whiteout after RmdirAt failure: %v", err)) } } diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index 2b628bd55..1d3d2d95f 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -81,6 +81,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/log", "//pkg/refs", @@ -119,6 +120,7 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fspath", "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/testutil", @@ -127,7 +129,6 @@ go_test( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", - "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go index 177cb828f..ab47ea5a7 100644 --- a/pkg/sentry/fsimpl/proc/task_net.go +++ b/pkg/sentry/fsimpl/proc/task_net.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" @@ -33,7 +34,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -679,7 +679,7 @@ func (d *netSnmpData) Generate(ctx context.Context, buf *bytes.Buffer) error { continue } if err := d.stack.Statistics(stat, line.prefix); err != nil { - if err == syserror.EOPNOTSUPP { + if linuxerr.Equals(linuxerr.EOPNOTSUPP, err) { log.Infof("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err) } else { log.Warningf("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err) diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index e534fbca8..14f806c3c 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -23,13 +23,13 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -227,7 +227,7 @@ func TestTasks(t *testing.T) { defer fd.DecRef(s.Ctx) buf := make([]byte, 1) bufIOSeq := usermem.BytesIOSequence(buf) - if _, err := fd.Read(s.Ctx, bufIOSeq, vfs.ReadOptions{}); err != syserror.EISDIR { + if _, err := fd.Read(s.Ctx, bufIOSeq, vfs.ReadOptions{}); !linuxerr.Equals(linuxerr.EISDIR, err) { t.Errorf("wrong error reading directory: %v", err) } } @@ -237,7 +237,7 @@ func TestTasks(t *testing.T) { s.Creds, s.PathOpAtRoot("/proc/9999"), &vfs.OpenOptions{}, - ); err != syserror.ENOENT { + ); !linuxerr.Equals(linuxerr.ENOENT, err) { t.Fatalf("wrong error from vfsfs.OpenAt(/proc/9999): %v", err) } } diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index e21fddd7f..341b4f904 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -118,6 +118,7 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fspath", "//pkg/sentry/contexttest", "//pkg/sentry/fs/lock", diff --git a/pkg/sentry/fsimpl/tmpfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go index 2f856ce36..418c7994e 100644 --- a/pkg/sentry/fsimpl/tmpfs/pipe_test.go +++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -114,7 +115,7 @@ func TestNonblockingWriteError(t *testing.T) { } openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY | linux.O_NONBLOCK} _, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) - if err != syserror.ENXIO { + if !linuxerr.Equals(linuxerr.ENXIO, err) { t.Fatalf("expected ENXIO, but got error: %v", err) } } diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD index d473a922d..1d855234c 100644 --- a/pkg/sentry/fsimpl/verity/BUILD +++ b/pkg/sentry/fsimpl/verity/BUILD @@ -13,6 +13,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fspath", "//pkg/hostarch", "//pkg/marshal/primitive", @@ -41,6 +42,7 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fspath", "//pkg/sentry/arch", "//pkg/sentry/fsimpl/testutil", @@ -48,7 +50,6 @@ go_test( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", - "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 3582d14c9..e84452421 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/merkletree" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -195,7 +196,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi // The Merkle tree file for the child should have been created and // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. - if err == syserror.ENOENT || err == syserror.ENODATA { + if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.ENODATA, err) { return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) } if err != nil { @@ -218,7 +219,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi // The parent Merkle tree file should have been created. If it's // missing, it indicates an unexpected modification to the file system. - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) } if err != nil { @@ -238,7 +239,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi // The Merkle tree file for the child should have been created and // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. - if err == syserror.ENOENT || err == syserror.ENODATA { + if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.ENODATA, err) { return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { @@ -261,7 +262,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi Root: parent.lowerVD, Start: parent.lowerVD, }, &vfs.StatOptions{}) - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err)) } if err != nil { @@ -327,7 +328,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry }, &vfs.OpenOptions{ Flags: linux.O_RDONLY, }) - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { return fs.alertIntegrityViolation(fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err)) } if err != nil { @@ -341,7 +342,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry Size: sizeOfStringInt32, }) - if err == syserror.ENODATA { + if linuxerr.Equals(linuxerr.ENODATA, err) { return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { @@ -359,7 +360,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry Size: sizeOfStringInt32, }) - if err == syserror.ENODATA { + if linuxerr.Equals(linuxerr.ENODATA, err) { return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenOffsetXattr, childPath, err)) } if err != nil { @@ -375,7 +376,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry Size: sizeOfStringInt32, }) - if err == syserror.ENODATA { + if linuxerr.Equals(linuxerr.ENODATA, err) { return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenSizeXattr, childPath, err)) } if err != nil { @@ -465,7 +466,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s } childVD, err := parent.getLowerAt(ctx, vfsObj, name) - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { // The file was previously accessed. If the // file does not exist now, it indicates an // unexpected modification to the file system. @@ -480,7 +481,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s // The Merkle tree file was previous accessed. If it // does not exist now, it indicates an unexpected // modification to the file system. - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { return nil, fs.alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", path)) } if err != nil { @@ -551,7 +552,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, } childVD, err := parent.getLowerAt(ctx, vfsObj, name) - if parent.verityEnabled() && err == syserror.ENOENT { + if parent.verityEnabled() && linuxerr.Equals(linuxerr.ENOENT, err) { return nil, fs.alertIntegrityViolation(fmt.Sprintf("file %s expected but not found", parentPath+"/"+name)) } if err != nil { @@ -564,7 +565,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, childMerkleVD, err := parent.getLowerAt(ctx, vfsObj, merklePrefix+name) if err != nil { - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { if parent.verityEnabled() { return nil, fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath+"/"+name)) } @@ -854,7 +855,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // The file should exist, as we succeeded in finding its dentry. If it's // missing, it indicates an unexpected modification to the file system. if err != nil { - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("File %s expected but not found", path)) } return nil, err @@ -877,7 +878,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // dentry. If it's missing, it indicates an unexpected modification to // the file system. if err != nil { - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) } return nil, err @@ -902,7 +903,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf Flags: linux.O_WRONLY | linux.O_APPEND, }) if err != nil { - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) } return nil, err @@ -919,7 +920,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf Flags: linux.O_WRONLY | linux.O_APPEND, }) if err != nil { - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD) return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) } diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 969003613..c5f59d851 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -45,6 +45,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" @@ -358,7 +359,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // If runtime enable is allowed, the root merkle tree may be absent. We // should create the tree file. - if err == syserror.ENOENT && fs.allowRuntimeEnable { + if linuxerr.Equals(linuxerr.ENOENT, err) && fs.allowRuntimeEnable { lowerMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ Root: lowerVD, Start: lowerVD, @@ -451,7 +452,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt Name: childrenOffsetXattr, Size: sizeOfStringInt32, }) - if err == syserror.ENOENT || err == syserror.ENODATA { + if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.ENODATA, err) { return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenOffsetXattr, err)) } if err != nil { @@ -470,7 +471,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt Name: childrenSizeXattr, Size: sizeOfStringInt32, }) - if err == syserror.ENOENT || err == syserror.ENODATA { + if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.ENODATA, err) { return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenSizeXattr, err)) } if err != nil { @@ -487,7 +488,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt }, &vfs.OpenOptions{ Flags: linux.O_RDONLY, }) - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to open root Merkle file: %v", err)) } if err != nil { @@ -1227,7 +1228,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of // The Merkle tree file for the child should have been created and // contains the expected xattrs. If the xattr does not exist, it // indicates unexpected modifications to the file system. - if err == syserror.ENODATA { + if linuxerr.Equals(linuxerr.ENODATA, err) { return 0, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) } if err != nil { @@ -1349,7 +1350,7 @@ func (fd *fileDescription) Translate(ctx context.Context, required, optional mem // The Merkle tree file for the child should have been created and // contains the expected xattrs. If the xattr does not exist, it // indicates unexpected modifications to the file system. - if err == syserror.ENODATA { + if linuxerr.Equals(linuxerr.ENODATA, err) { return nil, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) } if err != nil { diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go index 5c78a0019..65465b814 100644 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" @@ -31,7 +32,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -476,7 +476,7 @@ func TestOpenNonexistentFile(t *testing.T) { // Ensure open an unexpected file in the parent directory fails with // ENOENT rather than verification failure. - if _, err = openVerityAt(ctx, vfsObj, root, filename+"abc", linux.O_RDONLY, linux.ModeRegular); err != syserror.ENOENT { + if _, err = openVerityAt(ctx, vfsObj, root, filename+"abc", linux.O_RDONLY, linux.ModeRegular); !linuxerr.Equals(linuxerr.ENOENT, err) { t.Errorf("OpenAt unexpected error: %v", err) } } @@ -767,7 +767,7 @@ func TestOpenDeletedFileFails(t *testing.T) { } // Ensure reopening the verity enabled file fails. - if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != syserror.EIO { + if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); !linuxerr.Equals(linuxerr.EIO, err) { t.Errorf("got OpenAt error: %v, expected EIO", err) } }) @@ -829,7 +829,7 @@ func TestOpenRenamedFileFails(t *testing.T) { } // Ensure reopening the verity enabled file fails. - if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != syserror.EIO { + if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); !linuxerr.Equals(linuxerr.EIO, err) { t.Errorf("got OpenAt error: %v, expected EIO", err) } }) @@ -1063,14 +1063,14 @@ func TestDeletedSymlinkFileReadFails(t *testing.T) { Root: root, Start: root, Path: fspath.Parse(symlink), - }); err != syserror.EIO { + }); !linuxerr.Equals(linuxerr.EIO, err) { t.Fatalf("ReadlinkAt succeeded with modified symlink: %v", err) } if tc.testWalk { fileInSymlinkDirectory := symlink + "/verity-test-file" // Ensure opening the verity enabled file in the symlink directory fails. - if _, err := openVerityAt(ctx, vfsObj, root, fileInSymlinkDirectory, linux.O_RDONLY, linux.ModeRegular); err != syserror.EIO { + if _, err := openVerityAt(ctx, vfsObj, root, fileInSymlinkDirectory, linux.O_RDONLY, linux.ModeRegular); !linuxerr.Equals(linuxerr.EIO, err) { t.Errorf("Open succeeded with modified symlink: %v", err) } } @@ -1195,14 +1195,14 @@ func TestModifiedSymlinkFileReadFails(t *testing.T) { Root: root, Start: root, Path: fspath.Parse(symlink), - }); err != syserror.EIO { + }); !linuxerr.Equals(linuxerr.EIO, err) { t.Fatalf("ReadlinkAt succeeded with modified symlink: %v", err) } if tc.testWalk { fileInSymlinkDirectory := symlink + "/verity-test-file" // Ensure opening the verity enabled file in the symlink directory fails. - if _, err := openVerityAt(ctx, vfsObj, root, fileInSymlinkDirectory, linux.O_RDONLY, linux.ModeRegular); err != syserror.EIO { + if _, err := openVerityAt(ctx, vfsObj, root, fileInSymlinkDirectory, linux.O_RDONLY, linux.ModeRegular); !linuxerr.Equals(linuxerr.EIO, err) { t.Errorf("Open succeeded with modified symlink: %v", err) } } diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index a82d641da..9a4b08469 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -226,6 +226,7 @@ go_library( "//pkg/context", "//pkg/coverage", "//pkg/cpuid", + "//pkg/errors/linuxerr", "//pkg/eventchannel", "//pkg/fspath", "//pkg/goid", diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 62777faa8..8786a70b5 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -23,12 +23,12 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) // FDFlags define flags for an individual descriptor. @@ -156,7 +156,7 @@ func (f *FDTable) dropVFS2(ctx context.Context, file *vfs.FileDescription) { // Release any POSIX lock possibly held by the FDTable. if file.SupportsLocks() { err := file.UnlockPOSIX(ctx, f, lock.LockRange{0, lock.LockEOF}) - if err != nil && err != syserror.ENOLCK { + if err != nil && !linuxerr.Equals(linuxerr.ENOLCK, err) { panic(fmt.Sprintf("UnlockPOSIX failed: %v", err)) } } diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 34c617b08..af46b3e08 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -47,6 +47,7 @@ go_test( library = ":pipe", deps = [ "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/sentry/contexttest", "//pkg/sentry/fs", "//pkg/syserror", diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go index d6fb0fdb8..d25cf658e 100644 --- a/pkg/sentry/kernel/pipe/node_test.go +++ b/pkg/sentry/kernel/pipe/node_test.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/syserror" @@ -258,7 +259,7 @@ func TestNonblockingWriteOpenFileNoReaders(t *testing.T) { ctx := newSleeperContext(t) f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - if _, err := testOpen(ctx, t, f, fs.FileFlags{Write: true, NonBlocking: true}, nil); err != syserror.ENXIO { + if _, err := testOpen(ctx, t, f, fs.FileFlags{Write: true, NonBlocking: true}, nil); !linuxerr.Equals(linuxerr.ENXIO, err) { t.Fatalf("Nonblocking open for write failed unexpected error %v.", err) } } diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go index a6287fd6a..20563f02a 100644 --- a/pkg/sentry/kernel/ptrace.go +++ b/pkg/sentry/kernel/ptrace.go @@ -294,7 +294,7 @@ func (t *Task) isYAMADescendantOfLocked(ancestor *Task) bool { // Precondition: the TaskSet mutex must be locked (for reading or writing). func (t *Task) hasYAMAExceptionForLocked(tracer *Task) bool { - allowed, ok := t.k.ptraceExceptions[t] + allowed, ok := t.k.ptraceExceptions[t.tg.leader] if !ok { return false } diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go index ecbe8f920..07533d982 100644 --- a/pkg/sentry/kernel/task_block.go +++ b/pkg/sentry/kernel/task_block.go @@ -19,6 +19,7 @@ import ( "runtime/trace" "time" + "gvisor.dev/gvisor/pkg/errors/linuxerr" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -45,7 +46,7 @@ func (t *Task) BlockWithTimeout(C chan struct{}, haveTimeout bool, timeout time. err := t.BlockWithDeadline(C, true, deadline) // Timeout, explicitly return a remaining duration of 0. - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { return 0, err } diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go index 601fc0d3a..1874f74e5 100644 --- a/pkg/sentry/kernel/task_syscall.go +++ b/pkg/sentry/kernel/task_syscall.go @@ -22,6 +22,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/metric" @@ -357,7 +358,7 @@ func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, calle t.Arch().SetReturn(uintptr(rval)) } else { t.Debugf("vsyscall %d, caller %x: emulated syscall returned error: %v", sysno, t.Arch().Value(caller), err) - if err == syserror.EFAULT { + if linuxerr.Equals(linuxerr.EFAULT, err) { t.forceSignal(linux.SIGSEGV, false /* unconditional */) t.SendSignal(SignalInfoPriv(linux.SIGSEGV)) // A return is not emulated in this case. diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD index 80f862628..54bfed644 100644 --- a/pkg/sentry/loader/BUILD +++ b/pkg/sentry/loader/BUILD @@ -20,6 +20,7 @@ go_library( "//pkg/abi/linux/errno", "//pkg/context", "//pkg/cpuid", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/log", "//pkg/rand", diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index 8fc3e2a79..4c7666e33 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -621,7 +622,7 @@ func loadInitialELF(ctx context.Context, m *mm.MemoryManager, fs *cpuid.FeatureS func loadInterpreterELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, initial loadedELF) (loadedELF, error) { info, err := parseHeader(ctx, f) if err != nil { - if err == syserror.ENOEXEC { + if linuxerr.Equals(linuxerr.ENOEXEC, err) { // Bad interpreter. err = syserror.ELIBBAD } diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index b417c2da7..69aff21b6 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -125,6 +125,7 @@ go_library( "//pkg/abi/linux", "//pkg/atomicbitops", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/log", "//pkg/refs", @@ -156,6 +157,7 @@ go_test( library = ":mm", deps = [ "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/sentry/arch", "//pkg/sentry/contexttest", @@ -163,7 +165,6 @@ go_test( "//pkg/sentry/memmap", "//pkg/sentry/pgalloc", "//pkg/sentry/platform", - "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go index 1304b0a2f..84cb8158d 100644 --- a/pkg/sentry/mm/mm_test.go +++ b/pkg/sentry/mm/mm_test.go @@ -18,6 +18,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/contexttest" @@ -25,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -171,7 +171,7 @@ func TestIOAfterUnmap(t *testing.T) { } n, err = mm.CopyIn(ctx, addr, b, usermem.IOOpts{}) - if err != syserror.EFAULT { + if !linuxerr.Equals(linuxerr.EFAULT, err) { t.Errorf("CopyIn got err %v want EFAULT", err) } if n != 0 { @@ -212,7 +212,7 @@ func TestIOAfterMProtect(t *testing.T) { // Without IgnorePermissions, CopyOut should no longer succeed. n, err = mm.CopyOut(ctx, addr, b, usermem.IOOpts{}) - if err != syserror.EFAULT { + if !linuxerr.Equals(linuxerr.EFAULT, err) { t.Errorf("CopyOut got err %v want EFAULT", err) } if n != 0 { @@ -249,7 +249,7 @@ func TestAIOPrepareAfterDestroy(t *testing.T) { mm.DestroyAIOContext(ctx, id) // Prepare should fail because aioCtx should be destroyed. - if err := aioCtx.Prepare(); err != syserror.EINVAL { + if err := aioCtx.Prepare(); !linuxerr.Equals(linuxerr.EINVAL, err) { t.Errorf("aioCtx.Prepare got err %v want nil", err) } else if err == nil { aioCtx.CancelPendingRequest() diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go index 7ad6b7c21..f46f85eb1 100644 --- a/pkg/sentry/mm/syscalls.go +++ b/pkg/sentry/mm/syscalls.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" @@ -855,10 +856,10 @@ func (mm *MemoryManager) MLock(ctx context.Context, addr hostarch.Addr, length u mm.activeMu.Unlock() mm.mappingMu.RUnlock() // Linux: mm/mlock.c:__mlock_posix_error_return() - if err == syserror.EFAULT { + if linuxerr.Equals(linuxerr.EFAULT, err) { return syserror.ENOMEM } - if err == syserror.ENOMEM { + if linuxerr.Equals(linuxerr.ENOMEM, err) { return syserror.EAGAIN } return err diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index 578852c3f..9e5c52923 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -25,29 +25,6 @@ import ( var ( // The action for bluepillSignal is changed by sigaction(). bluepillSignal = unix.SIGILL - - // vcpuSErrBounce is the event of system error for bouncing KVM. - vcpuSErrBounce = kvmVcpuEvents{ - exception: exception{ - sErrPending: 1, - }, - } - - // vcpuSErrNMI is the event of system error to trigger sigbus. - vcpuSErrNMI = kvmVcpuEvents{ - exception: exception{ - sErrPending: 1, - sErrHasEsr: 1, - sErrEsr: _ESR_ELx_SERR_NMI, - }, - } - - // vcpuExtDabt is the event of ext_dabt. - vcpuExtDabt = kvmVcpuEvents{ - exception: exception{ - extDabtPending: 1, - }, - } ) // getTLS returns the value of TPIDR_EL0 register. diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go index 07fc4f216..f105fdbd0 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go @@ -80,11 +80,18 @@ func getHypercallID(addr uintptr) int { // //go:nosplit func bluepillStopGuest(c *vCPU) { + // vcpuSErrBounce is the event of system error for bouncing KVM. + vcpuSErrBounce := &kvmVcpuEvents{ + exception: exception{ + sErrPending: 1, + }, + } + if _, _, errno := unix.RawSyscall( // escapes: no. unix.SYS_IOCTL, uintptr(c.fd), _KVM_SET_VCPU_EVENTS, - uintptr(unsafe.Pointer(&vcpuSErrBounce))); errno != 0 { + uintptr(unsafe.Pointer(vcpuSErrBounce))); errno != 0 { throw("bounce sErr injection failed") } } @@ -93,12 +100,21 @@ func bluepillStopGuest(c *vCPU) { // //go:nosplit func bluepillSigBus(c *vCPU) { + // vcpuSErrNMI is the event of system error to trigger sigbus. + vcpuSErrNMI := &kvmVcpuEvents{ + exception: exception{ + sErrPending: 1, + sErrHasEsr: 1, + sErrEsr: _ESR_ELx_SERR_NMI, + }, + } + // Host must support ARM64_HAS_RAS_EXTN. if _, _, errno := unix.RawSyscall( // escapes: no. unix.SYS_IOCTL, uintptr(c.fd), _KVM_SET_VCPU_EVENTS, - uintptr(unsafe.Pointer(&vcpuSErrNMI))); errno != 0 { + uintptr(unsafe.Pointer(vcpuSErrNMI))); errno != 0 { if errno == unix.EINVAL { throw("No ARM64_HAS_RAS_EXTN feature in host.") } @@ -110,11 +126,18 @@ func bluepillSigBus(c *vCPU) { // //go:nosplit func bluepillExtDabt(c *vCPU) { + // vcpuExtDabt is the event of ext_dabt. + vcpuExtDabt := &kvmVcpuEvents{ + exception: exception{ + extDabtPending: 1, + }, + } + if _, _, errno := unix.RawSyscall( // escapes: no. unix.SYS_IOCTL, uintptr(c.fd), _KVM_SET_VCPU_EVENTS, - uintptr(unsafe.Pointer(&vcpuExtDabt))); errno != 0 { + uintptr(unsafe.Pointer(vcpuExtDabt))); errno != 0 { throw("ext_dabt injection failed") } } diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 1b0a6e0a7..f6aa519b1 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -140,22 +140,15 @@ func (c *vCPU) initArchState() error { // vbar_el1 reg.id = _KVM_ARM64_REGS_VBAR_EL1 - - fromLocation := reflect.ValueOf(ring0.Vectors).Pointer() - offset := fromLocation & (1<<11 - 1) - if offset != 0 { - offset = 1<<11 - offset - } - - toLocation := fromLocation + offset - data = uint64(ring0.KernelStartAddress | toLocation) + vectorLocation := reflect.ValueOf(ring0.Vectors).Pointer() + data = uint64(ring0.KernelStartAddress | vectorLocation) if err := c.setOneRegister(®); err != nil { return err } // Use the address of the exception vector table as // the MMIO address base. - arm64HypercallMMIOBase = toLocation + arm64HypercallMMIOBase = vectorLocation // Initialize the PCID database. if hasGuestPCID { diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index 3c6511ead..3950caa0f 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -18,6 +18,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fdnotifier", "//pkg/hostarch", "//pkg/log", diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index b9473da6c..38cb2c99c 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -20,6 +20,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" @@ -714,7 +715,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b } if ch != nil { if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = syserror.ErrWouldBlock } break diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 6b83698ad..ed85404da 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -17,6 +17,7 @@ go_library( "//pkg/abi/linux/errno", "//pkg/bits", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/marshal", "//pkg/marshal/primitive", diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index c9f784cf4..d53f23a9a 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/abi/linux/errno" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" @@ -559,7 +560,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain } return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index 96c425619..e828982eb 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -21,6 +21,7 @@ go_library( "//pkg/abi/linux", "//pkg/abi/linux/errno", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/log", "//pkg/marshal", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 66d0fcb47..11f75628c 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -38,6 +38,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/abi/linux/errno" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/marshal" @@ -2809,7 +2810,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if n > 0 { return n, msgFlags, senderAddr, senderAddrLen, controlMessages, nil } - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain } return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) @@ -2877,7 +2878,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b // became available between when we last checked and when we setup // the notification. if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { return int(total), syserr.ErrTryAgain } // handleIOError will consume errors from t.Block if needed. diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index c9cbefb3a..5c3cdef6a 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -39,6 +39,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fspath", "//pkg/hostarch", "//pkg/log", diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index db7b1affe..8ccdadae9 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -23,6 +23,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" @@ -518,7 +519,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b } if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = syserror.ErrWouldBlock } break @@ -719,7 +720,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if total > 0 { err = nil } - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain } return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.FromError(err) diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index c39e317ff..08a00a12f 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -17,6 +17,7 @@ package unix import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" @@ -236,7 +237,7 @@ func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { Mode: linux.FileMode(linux.S_IFSOCK | uint(stat.Mode)&^t.FSContext().Umask()), Endpoint: bep, }) - if err == syserror.EEXIST { + if linuxerr.Equals(linuxerr.EEXIST, err) { return syserr.ErrAddressInUse } return syserr.FromError(err) diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD index 3e801182c..7f02807c5 100644 --- a/pkg/sentry/state/BUILD +++ b/pkg/sentry/state/BUILD @@ -13,6 +13,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/log", "//pkg/sentry/inet", "//pkg/sentry/kernel", @@ -20,7 +21,6 @@ go_library( "//pkg/sentry/vfs", "//pkg/sentry/watchdog", "//pkg/state/statefile", - "//pkg/syserror", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go index 2f0aba4e2..e9d544f3d 100644 --- a/pkg/sentry/state/state.go +++ b/pkg/sentry/state/state.go @@ -20,6 +20,7 @@ import ( "io" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -27,7 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sentry/watchdog" "gvisor.dev/gvisor/pkg/state/statefile" - "gvisor.dev/gvisor/pkg/syserror" ) var previousMetadata map[string]string @@ -88,7 +88,7 @@ func (opts SaveOpts) Save(ctx context.Context, k *kernel.Kernel, w *watchdog.Wat // ENOSPC is a state file error. This error can only come from // writing the state file, and not from fs.FileOperations.Fsync // because we wrap those in kernel.TaskSet.flushWritesToFiles. - if err == syserror.ENOSPC { + if linuxerr.Equals(linuxerr.ENOSPC, err) { err = ErrStateFile{err} } diff --git a/pkg/sentry/syscalls/BUILD b/pkg/sentry/syscalls/BUILD index b8d1bd415..f2c55588f 100644 --- a/pkg/sentry/syscalls/BUILD +++ b/pkg/sentry/syscalls/BUILD @@ -11,6 +11,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", + "//pkg/errors/linuxerr", "//pkg/sentry/arch", "//pkg/sentry/kernel", "//pkg/sentry/kernel/epoll", diff --git a/pkg/sentry/syscalls/epoll.go b/pkg/sentry/syscalls/epoll.go index 3b4d79889..02debfc7e 100644 --- a/pkg/sentry/syscalls/epoll.go +++ b/pkg/sentry/syscalls/epoll.go @@ -18,6 +18,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/epoll" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -163,7 +164,7 @@ func WaitEpoll(t *kernel.Task, fd int32, max int, timeoutInNanos int64) ([]linux } if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { return nil, nil } diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 408a6c422..a2f612f45 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -64,6 +64,7 @@ go_library( "//pkg/abi/linux", "//pkg/bpf", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/log", "//pkg/marshal", diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go index 6eabfd219..165922332 100644 --- a/pkg/sentry/syscalls/linux/error.go +++ b/pkg/sentry/syscalls/linux/error.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -94,13 +95,13 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, errOrig, intr er if errno, ok := syserror.TranslateError(errOrig); ok { translatedErr = errno } - switch translatedErr { - case io.EOF: + switch { + case translatedErr == io.EOF: // EOF is always consumed. If this is a partial read/write // (result != 0), the application will see that, otherwise // they will see 0. return true, nil - case syserror.EFBIG: + case linuxerr.Equals(linuxerr.EFBIG, translatedErr): t := kernel.TaskFromContext(ctx) if t == nil { panic("I/O error should only occur from a context associated with a Task") @@ -113,7 +114,7 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, errOrig, intr er // Simultaneously send a SIGXFSZ per setrlimit(2). t.SendSignal(kernel.SignalInfoNoInfo(linux.SIGXFSZ, t, t)) return true, syserror.EFBIG - case syserror.EINTR: + case linuxerr.Equals(linuxerr.EINTR, translatedErr): // The syscall was interrupted. Return nil if it completed // partially, otherwise return the error code that the syscall // needs (to indicate to the kernel what it should do). @@ -128,21 +129,21 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, errOrig, intr er return true, errOrig } - switch translatedErr { - case syserror.EINTR: + switch { + case linuxerr.Equals(linuxerr.EINTR, translatedErr): // Syscall interrupted, but completed a partial // read/write. Like ErrWouldBlock, since we have a // partial read/write, we consume the error and return // the partial result. return true, nil - case syserror.EFAULT: + case linuxerr.Equals(linuxerr.EFAULT, translatedErr): // EFAULT is only shown the user if nothing was // read/written. If we read something (this case), they see // a partial read/write. They will then presumably try again // with an incremented buffer, which will EFAULT with // result == 0. return true, nil - case syserror.EPIPE: + case linuxerr.Equals(linuxerr.EPIPE, translatedErr): // Writes to a pipe or socket will return EPIPE if the other // side is gone. The partial write is returned. EPIPE will be // returned on the next call. @@ -150,15 +151,17 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, errOrig, intr er // TODO(gvisor.dev/issue/161): In some cases SIGPIPE should // also be sent to the application. return true, nil - case syserror.ENOSPC: + case linuxerr.Equals(linuxerr.ENOSPC, translatedErr): // Similar to EPIPE. Return what we wrote this time, and let // ENOSPC be returned on the next call. return true, nil - case syserror.ECONNRESET, syserror.ETIMEDOUT: + case linuxerr.Equals(linuxerr.ECONNRESET, translatedErr): + fallthrough + case linuxerr.Equals(linuxerr.ETIMEDOUT, translatedErr): // For TCP sendfile connections, we may have a reset or timeout. But we // should just return n as the result. return true, nil - case syserror.EWOULDBLOCK: + case linuxerr.Equals(linuxerr.EWOULDBLOCK, translatedErr): // Syscall would block, but completed a partial read/write. // This case should only be returned by IssueIO for nonblocking // files. Since we have a partial read/write, we consume diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go index 70e8569a8..c338a4cc9 100644 --- a/pkg/sentry/syscalls/linux/sys_aio.go +++ b/pkg/sentry/syscalls/linux/sys_aio.go @@ -17,6 +17,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -134,7 +135,7 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S var err error v, err = waitForRequest(ctx, t, haveDeadline, deadline) if err != nil { - if count > 0 || err == syserror.ETIMEDOUT { + if count > 0 || linuxerr.Equals(linuxerr.ETIMEDOUT, err) { return uintptr(count), nil, nil } return 0, nil, syserror.ConvertIntr(err, syserror.EINTR) diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 90a719ba2..6109a2d8c 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -18,6 +18,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -394,8 +395,8 @@ func createAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, flags uint, mode } var newFile *fs.File - switch err { - case nil: + switch { + case err == nil: // Like sys_open, check for a few things about the // filesystem before trying to get a reference to the // fs.File. The same constraints on Check apply. @@ -418,7 +419,7 @@ func createAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, flags uint, mode return syserror.ConvertIntr(err, syserror.ERESTARTSYS) } defer newFile.DecRef(t) - case syserror.ENOENT: + case linuxerr.Equals(linuxerr.ENOENT, err): // File does not exist. Proceed with creation. // Do we have write permissions on the parent? @@ -1178,12 +1179,12 @@ func mkdirAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, mode linux.FileMod // Does this directory exist already? remainingTraversals := uint(linux.MaxSymlinkTraversals) f, err := t.MountNamespace().FindInode(t, root, d, name, &remainingTraversals) - switch err { - case nil: + switch { + case err == nil: // The directory existed. defer f.DecRef(t) return syserror.EEXIST - case syserror.EACCES: + case linuxerr.Equals(linuxerr.EACCES, err): // Permission denied while walking to the directory. return err default: @@ -1464,7 +1465,7 @@ func readlinkAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, bufAddr hostarc } s, err := d.Inode.Readlink(t) - if err == syserror.ENOLINK { + if linuxerr.Equals(linuxerr.ENOLINK, err) { return syserror.EINVAL } if err != nil { diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go index da548a14a..024632475 100644 --- a/pkg/sentry/syscalls/linux/sys_poll.go +++ b/pkg/sentry/syscalls/linux/sys_poll.go @@ -18,6 +18,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -128,7 +129,7 @@ func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time. // Wait for a notification. timeout, err = t.BlockWithTimeout(ch, !forever, timeout) if err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = nil } return timeout, 0, err @@ -404,7 +405,7 @@ func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) { func poll(t *kernel.Task, pfdAddr hostarch.Addr, nfds uint, timeout time.Duration) (uintptr, error) { remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout) // On an interrupt poll(2) is restarted with the remaining timeout. - if err == syserror.EINTR { + if linuxerr.Equals(linuxerr.EINTR, err) { t.SetSyscallRestartBlock(&pollRestartBlock{ pfdAddr: pfdAddr, nfds: nfds, @@ -463,7 +464,7 @@ func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // // Note that this means that if err is nil but copyErr is not, copyErr is // ignored. This is consistent with Linux. - if err == syserror.EINTR && copyErr == nil { + if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil { err = syserror.ERESTARTNOHAND } return n, nil, err @@ -493,7 +494,7 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout) copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr) // See comment in Ppoll. - if err == syserror.EINTR && copyErr == nil { + if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil { err = syserror.ERESTARTNOHAND } return n, nil, err @@ -538,7 +539,7 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout) copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr) // See comment in Ppoll. - if err == syserror.EINTR && copyErr == nil { + if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil { err = syserror.ERESTARTNOHAND } return n, nil, err diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go index 9890dd946..30c15af4a 100644 --- a/pkg/sentry/syscalls/linux/sys_prctl.go +++ b/pkg/sentry/syscalls/linux/sys_prctl.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -98,7 +99,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.PR_SET_NAME: addr := args[1].Pointer() name, err := t.CopyInString(addr, linux.TASK_COMM_LEN-1) - if err != nil && err != syserror.ENAMETOOLONG { + if err != nil && !linuxerr.Equals(linuxerr.ENAMETOOLONG, err) { return 0, nil, err } t.SetName(name) diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go index 13e5e3a51..0f9329fe8 100644 --- a/pkg/sentry/syscalls/linux/sys_read.go +++ b/pkg/sentry/syscalls/linux/sys_read.go @@ -18,6 +18,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -331,7 +332,7 @@ func readv(t *kernel.Task, f *fs.File, dst usermem.IOSequence) (int64, error) { // Wait for a notification that we should retry. if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = syserror.ErrWouldBlock } break diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go index c84260080..cb320c536 100644 --- a/pkg/sentry/syscalls/linux/sys_sem.go +++ b/pkg/sentry/syscalls/linux/sys_sem.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -81,7 +82,7 @@ func Semtimedop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } if err := semTimedOp(t, id, ops, true, timeout.ToDuration()); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { return 0, nil, syserror.EAGAIN } return 0, nil, err diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go index 27a7f7fe1..db763c68e 100644 --- a/pkg/sentry/syscalls/linux/sys_signal.go +++ b/pkg/sentry/syscalls/linux/sys_signal.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -90,7 +91,7 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC } info.SetPID(int32(target.PIDNamespace().IDOfTask(t))) info.SetUID(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow())) - if err := target.SendGroupSignal(info); err != syserror.ESRCH { + if err := target.SendGroupSignal(info); !linuxerr.Equals(linuxerr.ESRCH, err) { return 0, nil, err } } @@ -130,7 +131,7 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC info.SetPID(int32(tg.PIDNamespace().IDOfTask(t))) info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) err := tg.SendSignal(info) - if err == syserror.ESRCH { + if linuxerr.Equals(linuxerr.ESRCH, err) { // ESRCH is ignored because it means the task // exited while we were iterating. This is a // race which would not normally exist on @@ -174,7 +175,7 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC info.SetPID(int32(tg.PIDNamespace().IDOfTask(t))) info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) // See note above regarding ESRCH race above. - if err := tg.SendSignal(info); err != syserror.ESRCH { + if err := tg.SendSignal(info); !linuxerr.Equals(linuxerr.ESRCH, err) { lastErr = err } } @@ -433,7 +434,7 @@ func RtSigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne return 0, nil, syserror.EPERM } - if err := target.SendGroupSignal(&info); err != syserror.ESRCH { + if err := target.SendGroupSignal(&info); !linuxerr.Equals(linuxerr.ESRCH, err) { return 0, nil, err } } diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index e07917613..3bd21a911 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -18,6 +18,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" @@ -305,7 +306,7 @@ func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr, if peerRequested { // NOTE(magi): Linux does not give you an error if it can't // write the data back out so neither do we. - if err := writeAddress(t, peer, peerLen, addr, addrLen); err == syserror.EINVAL { + if err := writeAddress(t, peer, peerLen, addr, addrLen); linuxerr.Equals(linuxerr.EINVAL, err) { return 0, err } } diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go index 5c3b3dee2..2ec74b33a 100644 --- a/pkg/sentry/syscalls/linux/sys_time.go +++ b/pkg/sentry/syscalls/linux/sys_time.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -209,11 +210,11 @@ func clockNanosleepUntil(t *kernel.Task, c ktime.Clock, end ktime.Time, rem host timer.Destroy() - switch err { - case syserror.ETIMEDOUT: + switch { + case linuxerr.Equals(linuxerr.ETIMEDOUT, err): // Slept for entire timeout. return nil - case syserror.ErrInterrupted: + case err == syserror.ErrInterrupted: // Interrupted. remaining := end.Sub(c.Now()) if remaining <= 0 { diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go index 95bfe6606..cff355550 100644 --- a/pkg/sentry/syscalls/linux/sys_write.go +++ b/pkg/sentry/syscalls/linux/sys_write.go @@ -18,6 +18,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -301,7 +302,7 @@ func writev(t *kernel.Task, f *fs.File, src usermem.IOSequence) (int64, error) { // Wait for a notification that we should retry. if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = syserror.ErrWouldBlock } break diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go index 28ad6a60e..37fb67f80 100644 --- a/pkg/sentry/syscalls/linux/sys_xattr.go +++ b/pkg/sentry/syscalls/linux/sys_xattr.go @@ -18,6 +18,7 @@ import ( "strings" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -217,7 +218,7 @@ func setXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr hostarch.Addr, s func copyInXattrName(t *kernel.Task, nameAddr hostarch.Addr) (string, error) { name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1) if err != nil { - if err == syserror.ENAMETOOLONG { + if linuxerr.Equals(linuxerr.ENAMETOOLONG, err) { return "", syserror.ERANGE } return "", err diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index 5ce0bc714..a73f096ff 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -41,6 +41,7 @@ go_library( "//pkg/abi/linux", "//pkg/bits", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fspath", "//pkg/gohacks", "//pkg/hostarch", diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go index 047d955b6..7aff01343 100644 --- a/pkg/sentry/syscalls/linux/vfs2/epoll.go +++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -174,7 +175,7 @@ func waitEpoll(t *kernel.Task, epfd int32, eventsAddr hostarch.Addr, maxEvents i haveDeadline = true } if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = nil } return 0, nil, err diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go index a69c80edd..b16773d65 100644 --- a/pkg/sentry/syscalls/linux/vfs2/poll.go +++ b/pkg/sentry/syscalls/linux/vfs2/poll.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -132,7 +133,7 @@ func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time. // Wait for a notification. timeout, err = t.BlockWithTimeout(ch, haveTimeout, timeout) if err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = nil } return timeout, 0, err @@ -410,7 +411,7 @@ func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) { func poll(t *kernel.Task, pfdAddr hostarch.Addr, nfds uint, timeout time.Duration) (uintptr, error) { remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout) // On an interrupt poll(2) is restarted with the remaining timeout. - if err == syserror.EINTR { + if linuxerr.Equals(linuxerr.EINTR, err) { t.SetSyscallRestartBlock(&pollRestartBlock{ pfdAddr: pfdAddr, nfds: nfds, @@ -462,7 +463,7 @@ func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // // Note that this means that if err is nil but copyErr is not, copyErr is // ignored. This is consistent with Linux. - if err == syserror.EINTR && copyErr == nil { + if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil { err = syserror.ERESTARTNOHAND } return n, nil, err @@ -492,7 +493,7 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout) copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr) // See comment in Ppoll. - if err == syserror.EINTR && copyErr == nil { + if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil { err = syserror.ERESTARTNOHAND } return n, nil, err @@ -539,7 +540,7 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout) copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr) // See comment in Ppoll. - if err == syserror.EINTR && copyErr == nil { + if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil { err = syserror.ERESTARTNOHAND } return n, nil, err diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go index b863d7b84..bbfa4c6d7 100644 --- a/pkg/sentry/syscalls/linux/vfs2/read_write.go +++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go @@ -18,6 +18,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -120,7 +121,7 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt // Wait for a notification that we should retry. if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = syserror.ErrWouldBlock } break @@ -275,7 +276,7 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of // Wait for a notification that we should retry. if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = syserror.ErrWouldBlock } break @@ -371,7 +372,7 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op // Wait for a notification that we should retry. if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = syserror.ErrWouldBlock } break @@ -525,7 +526,7 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o // Wait for a notification that we should retry. if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = syserror.ErrWouldBlock } break diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 69f69e3af..9a4b5e5fc 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -18,6 +18,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -309,7 +310,7 @@ func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr, if peerRequested { // NOTE(magi): Linux does not give you an error if it can't // write the data back out so neither do we. - if err := writeAddress(t, peer, peerLen, addr, addrLen); err == syserror.EINVAL { + if err := writeAddress(t, peer, peerLen, addr, addrLen); linuxerr.Equals(linuxerr.EINVAL, err) { return 0, err } } diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go index c261050c6..c779c6465 100644 --- a/pkg/sentry/syscalls/linux/vfs2/xattr.go +++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go @@ -18,6 +18,7 @@ import ( "bytes" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/gohacks" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -295,7 +296,7 @@ func Fremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. func copyInXattrName(t *kernel.Task, nameAddr hostarch.Addr) (string, error) { name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1) if err != nil { - if err == syserror.ENAMETOOLONG { + if linuxerr.Equals(linuxerr.ENAMETOOLONG, err) { return "", syserror.ERANGE } return "", err diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go index 581862ee2..e7073ec87 100644 --- a/pkg/sentry/usage/memory.go +++ b/pkg/sentry/usage/memory.go @@ -132,7 +132,7 @@ func Init() error { // always be the case for a newly mapped page from /dev/shm. If we obtain // the shared memory through some other means in the future, we may have to // explicitly zero the page. - mmap, err := unix.Mmap(int(file.Fd()), 0, int(RTMemoryStatsSize), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED) + mmap, err := memutil.MapFile(0, RTMemoryStatsSize, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED, file.Fd(), 0) if err != nil { return fmt.Errorf("error mapping usage file: %v", err) } diff --git a/pkg/sentry/usage/memory_unsafe.go b/pkg/sentry/usage/memory_unsafe.go index 9e0014ca0..bc1531b91 100644 --- a/pkg/sentry/usage/memory_unsafe.go +++ b/pkg/sentry/usage/memory_unsafe.go @@ -21,7 +21,7 @@ import ( // RTMemoryStatsSize is the size of the RTMemoryStats struct. var RTMemoryStatsSize = unsafe.Sizeof(RTMemoryStats{}) -// RTMemoryStatsPointer casts the address of the byte slice into a RTMemoryStats pointer. -func RTMemoryStatsPointer(b []byte) *RTMemoryStats { - return (*RTMemoryStats)(unsafe.Pointer(&b[0])) +// RTMemoryStatsPointer casts addr to a RTMemoryStats pointer. +func RTMemoryStatsPointer(addr uintptr) *RTMemoryStats { + return (*RTMemoryStats)(unsafe.Pointer(addr)) } diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index ac60fe8bf..a2032162d 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -95,6 +95,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fd", "//pkg/fdnotifier", "//pkg/fspath", @@ -133,6 +134,7 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/sentry/contexttest", "//pkg/sync", "//pkg/syserror", diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index ef8d8a813..2bc33d424 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsmetric" @@ -708,8 +709,8 @@ func (fd *FileDescription) ListXattr(ctx context.Context, size uint64) ([]string return names, err } names, err := fd.impl.ListXattr(ctx, size) - if err == syserror.ENOTSUP { - // Linux doesn't actually return ENOTSUP in this case; instead, + if linuxerr.Equals(linuxerr.EOPNOTSUPP, err) { + // Linux doesn't actually return EOPNOTSUPP in this case; instead, // fs/xattr.c:vfs_listxattr() falls back to allowing the security // subsystem to return security extended attributes, which by default // don't exist. diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go index 1cd607c0a..566ad856a 100644 --- a/pkg/sentry/vfs/file_description_impl_util_test.go +++ b/pkg/sentry/vfs/file_description_impl_util_test.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -155,10 +156,10 @@ func TestGenCountFD(t *testing.T) { } // Write and PWrite fails. - if _, err := fd.Write(ctx, ioseq, WriteOptions{}); err != syserror.EIO { + if _, err := fd.Write(ctx, ioseq, WriteOptions{}); !linuxerr.Equals(linuxerr.EIO, err) { t.Errorf("Write: got err %v, wanted %v", err, syserror.EIO) } - if _, err := fd.PWrite(ctx, ioseq, 0, WriteOptions{}); err != syserror.EIO { + if _, err := fd.PWrite(ctx, ioseq, 0, WriteOptions{}); !linuxerr.Equals(linuxerr.EIO, err) { t.Errorf("Write: got err %v, wanted %v", err, syserror.EIO) } } @@ -215,10 +216,10 @@ func TestWritable(t *testing.T) { if n, err := fd.Seek(ctx, 1, linux.SEEK_SET); n != 0 && err != nil { t.Errorf("Seek: got err (%v, %v), wanted (0, nil)", n, err) } - if n, err := fd.Write(ctx, writeIOSeq, WriteOptions{}); n != 0 && err != syserror.EINVAL { + if n, err := fd.Write(ctx, writeIOSeq, WriteOptions{}); n != 0 && !linuxerr.Equals(linuxerr.EINVAL, err) { t.Errorf("Write: got err (%v, %v), wanted (0, EINVAL)", n, err) } - if n, err := fd.PWrite(ctx, writeIOSeq, 2, WriteOptions{}); n != 0 && err != syserror.EINVAL { + if n, err := fd.PWrite(ctx, writeIOSeq, 2, WriteOptions{}); n != 0 && !linuxerr.Equals(linuxerr.EINVAL, err) { t.Errorf("PWrite: got err (%v, %v), wanted (0, EINVAL)", n, err) } } diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 87fdcf403..b96de247f 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -42,6 +42,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -731,8 +732,8 @@ func (vfs *VirtualFilesystem) ListXattrAt(ctx context.Context, creds *auth.Crede rp.Release(ctx) return names, nil } - if err == syserror.ENOTSUP { - // Linux doesn't actually return ENOTSUP in this case; instead, + if linuxerr.Equals(linuxerr.EOPNOTSUPP, err) { + // Linux doesn't actually return EOPNOTSUPP in this case; instead, // fs/xattr.c:vfs_listxattr() falls back to allowing the security // subsystem to return security extended attributes, which by // default don't exist. @@ -830,14 +831,14 @@ func (vfs *VirtualFilesystem) MkdirAllAt(ctx context.Context, currentPath string Path: fspath.Parse(currentPath), } stat, err := vfs.StatAt(ctx, creds, pop, &StatOptions{Mask: linux.STATX_TYPE}) - switch err { - case nil: + switch { + case err == nil: if stat.Mask&linux.STATX_TYPE == 0 || stat.Mode&linux.FileTypeMask != linux.ModeDirectory { return syserror.ENOTDIR } // Directory already exists. return nil - case syserror.ENOENT: + case linuxerr.Equals(linuxerr.ENOENT, err): // Expected, we will create the dir. default: return fmt.Errorf("stat failed for %q during directory creation: %w", currentPath, err) @@ -871,7 +872,7 @@ func (vfs *VirtualFilesystem) MakeSyntheticMountpoint(ctx context.Context, targe Root: root, Start: root, Path: fspath.Parse(target), - }, mkdirOpts); err != nil && err != syserror.EEXIST { + }, mkdirOpts); err != nil && !linuxerr.Equals(linuxerr.EEXIST, err) { return fmt.Errorf("failed to create mountpoint %q: %w", target, err) } return nil diff --git a/pkg/shim/BUILD b/pkg/shim/BUILD index d8086688c..b115556f5 100644 --- a/pkg/shim/BUILD +++ b/pkg/shim/BUILD @@ -61,6 +61,7 @@ go_test( library = ":shim", deps = [ "//pkg/shim/utils", + "@com_github_containerd_containerd//errdefs:go_default_library", "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", ], ) diff --git a/pkg/shim/errors.go b/pkg/shim/errors.go index 311a4b427..75d036411 100644 --- a/pkg/shim/errors.go +++ b/pkg/shim/errors.go @@ -29,6 +29,9 @@ import ( // // TODO(gvisor.dev/issue/6232): Remove after upgrading to containerd v1.4 func errToGRPC(err error) error { + if err == nil { + return nil + } if _, ok := status.FromError(err); ok { return err } diff --git a/pkg/syserr/BUILD b/pkg/syserr/BUILD index 7b3160309..5205fa7e4 100644 --- a/pkg/syserr/BUILD +++ b/pkg/syserr/BUILD @@ -12,6 +12,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/abi/linux/errno", + "//pkg/errors/linuxerr", "//pkg/syserror", "//pkg/tcpip", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/syserr/syserr.go b/pkg/syserr/syserr.go index fb77ac8bd..7d0a5125b 100644 --- a/pkg/syserr/syserr.go +++ b/pkg/syserr/syserr.go @@ -22,6 +22,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux/errno" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/syserror" ) @@ -54,7 +55,7 @@ func New(message string, linuxTranslation errno.Errno) *Error { // enables proper blocking semantics. This should temporary address the // class of blocking bugs that keep popping up with the current state of // the error space. - if e == syserror.EWOULDBLOCK { + if err.errno == linuxerr.EWOULDBLOCK.Errno() { e = syserror.ErrWouldBlock } linuxBackwardsTranslations[err.errno] = linuxBackwardsTranslation{err: e, ok: true} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index a27e2110b..242e6b7f8 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2372,6 +2372,9 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { e.notifyProtocolGoroutine(notifyTickleWorker) return nil } + // Wake up any readers that maybe waiting for the stream to become + // readable. + e.waiterQueue.Notify(waiter.ReadableEvents) } // Close for write. @@ -2394,6 +2397,9 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { e.sndQueueInfo.SndClosed = true e.sndQueueInfo.sndQueueMu.Unlock() e.handleClose() + // Wake up any writers that maybe waiting for the stream to become + // writable. + e.waiterQueue.Notify(waiter.WritableEvents) } return nil diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 9bbe9bc3e..d1314fcdf 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3451,17 +3451,13 @@ loop: for { switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) { case *tcpip.ErrWouldBlock: - select { - case <-ch: - // Expect the state to be StateError and subsequent Reads to fail with HardError. - _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { - t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d) - } - break loop - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for reset to arrive") + <-ch + // Expect the state to be StateError and subsequent Reads to fail with HardError. + _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) + if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { + t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d) } + break loop case *tcpip.ErrConnectionReset: break loop default: @@ -3472,14 +3468,27 @@ loop: if tcp.EndpointState(c.EP.State()) != tcp.StateError { t.Fatalf("got EP state is not StateError") } - if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { - t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got) + + checkValid := func() []error { + var errors []error + if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { + errors = append(errors, fmt.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + errors = append(errors, fmt.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + errors = append(errors, fmt.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)) + } + return errors } - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) + + start := time.Now() + for time.Since(start) < time.Minute && len(checkValid()) > 0 { + time.Sleep(50 * time.Millisecond) } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + for _, err := range checkValid() { + t.Error(err) } } @@ -6092,15 +6101,10 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { defer c.WQ.EventUnregister(&we) // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") + <-ch + newEP, _, err = c.EP.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) } } @@ -6209,12 +6213,26 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { RcvWnd: 30000, }) - time.Sleep(50 * time.Millisecond) - if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { - t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want) + checkValid := func() []error { + var errors []error + if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { + errors = append(errors, fmt.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want)) + } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { + errors = append(errors, fmt.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want)) + } + return errors } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { - t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want) + + start := time.Now() + for time.Since(start) < time.Minute && len(checkValid()) > 0 { + time.Sleep(50 * time.Millisecond) + } + for _, err := range checkValid() { + t.Error(err) + } + if t.Failed() { + t.FailNow() } we, ch := waiter.NewChannelEntry(nil) @@ -6225,15 +6243,10 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { _, _, err = c.EP.Accept(nil) if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") + <-ch + _, _, err = c.EP.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) } } } @@ -7483,7 +7496,7 @@ func TestTCPUserTimeout(t *testing.T) { select { case <-notifyCh: case <-time.After(2 * initRTO): - t.Fatalf("connection still alive after %s, should have been closed after :%s", 2*initRTO, userTimeout) + t.Fatalf("connection still alive after %s, should have been closed after %s", 2*initRTO, userTimeout) } // No packet should be received as the connection should be silently diff --git a/pkg/usermem/BUILD b/pkg/usermem/BUILD index d7decd78a..229a8341b 100644 --- a/pkg/usermem/BUILD +++ b/pkg/usermem/BUILD @@ -30,6 +30,7 @@ go_test( library = ":usermem", deps = [ "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/safemem", "//pkg/syserror", diff --git a/pkg/usermem/usermem_test.go b/pkg/usermem/usermem_test.go index 9b697b593..6ef2b571f 100644 --- a/pkg/usermem/usermem_test.go +++ b/pkg/usermem/usermem_test.go @@ -22,6 +22,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/syserror" @@ -272,7 +273,7 @@ func TestCopyInt32StringsInVecRequiresOneValidValue(t *testing.T) { src := BytesIOSequence([]byte(s)) initial := []int32{1, 2} dsts := append([]int32(nil), initial...) - if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); err != syserror.EINVAL { + if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); !linuxerr.Equals(linuxerr.EINVAL, err) { t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (_, %v)", n, err, syserror.EINVAL) } if !reflect.DeepEqual(dsts, initial) { diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index a79afbdc4..c7b26746b 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/control/server", "//pkg/coverage", "//pkg/cpuid", + "//pkg/errors/linuxerr", "//pkg/eventchannel", "//pkg/fd", "//pkg/flipcall", diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go index c4590aab1..7fce2b708 100644 --- a/runsc/boot/fs.go +++ b/runsc/boot/fs.go @@ -25,6 +25,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -41,7 +42,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/specutils" @@ -1039,8 +1039,8 @@ func (c *containerMounter) mountTmp(ctx context.Context, conf *config.Config, mn maxTraversals := uint(0) tmp, err := mns.FindInode(ctx, root, root, "tmp", &maxTraversals) - switch err { - case nil: + switch { + case err == nil: // Found '/tmp' in filesystem, check if it's empty. defer tmp.DecRef(ctx) f, err := tmp.Inode.GetFile(ctx, tmp, fs.FileFlags{Read: true, Directory: true}) @@ -1061,7 +1061,7 @@ func (c *containerMounter) mountTmp(ctx context.Context, conf *config.Config, mn log.Infof("Mounting internal tmpfs on top of empty %q", "/tmp") fallthrough - case syserror.ENOENT: + case linuxerr.Equals(linuxerr.ENOENT, err): // No '/tmp' found (or fallthrough from above). Safe to mount internal // tmpfs. tmpMount := specs.Mount{ diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go index 52aa33529..ca1a86e39 100644 --- a/runsc/boot/vfs.go +++ b/runsc/boot/vfs.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/cleanup" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/devices/memdev" @@ -656,8 +657,8 @@ func (c *containerMounter) mountTmpVFS2(ctx context.Context, conf *config.Config Path: fspath.Parse("/tmp"), } fd, err := c.k.VFS().OpenAt(ctx, creds, &pop, &vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_DIRECTORY}) - switch err { - case nil: + switch { + case err == nil: defer fd.DecRef(ctx) err := fd.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error { @@ -666,10 +667,10 @@ func (c *containerMounter) mountTmpVFS2(ctx context.Context, conf *config.Config } return nil })) - switch err { - case nil: + switch { + case err == nil: log.Infof(`Mounting internal tmpfs on top of empty "/tmp"`) - case syserror.ENOTEMPTY: + case linuxerr.Equals(linuxerr.ENOTEMPTY, err): // If more than "." and ".." is found, skip internal tmpfs to prevent // hiding existing files. log.Infof(`Skipping internal tmpfs mount for "/tmp" because it's not empty`) @@ -679,7 +680,7 @@ func (c *containerMounter) mountTmpVFS2(ctx context.Context, conf *config.Config } fallthrough - case syserror.ENOENT: + case linuxerr.Equals(linuxerr.ENOENT, err): // No '/tmp' found (or fallthrough from above). It's safe to mount internal // tmpfs. tmpMount := specs.Mount{ @@ -692,7 +693,7 @@ func (c *containerMounter) mountTmpVFS2(ctx context.Context, conf *config.Config _, err := c.mountSubmountVFS2(ctx, conf, mns, creds, &mountAndFD{mount: &tmpMount}) return err - case syserror.ENOTDIR: + case linuxerr.Equals(linuxerr.ENOTDIR, err): // Not a dir?! Let it be. return nil diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go index 0cac0bf1b..7e89ba2b3 100644 --- a/test/packetimpact/testbench/dut.go +++ b/test/packetimpact/testbench/dut.go @@ -180,9 +180,7 @@ func (dut *DUT) CreateListener(t *testing.T, typ, proto, backlog int32) (int32, func (dut *DUT) Accept(t *testing.T, sockfd int32) (int32, unix.Sockaddr) { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - fd, sa, err := dut.AcceptWithErrno(ctx, t, sockfd) + fd, sa, err := dut.AcceptWithErrno(context.Background(), t, sockfd) if fd < 0 { t.Fatalf("failed to accept: %s", err) } @@ -209,9 +207,7 @@ func (dut *DUT) AcceptWithErrno(ctx context.Context, t *testing.T, sockfd int32) func (dut *DUT) Bind(t *testing.T, fd int32, sa unix.Sockaddr) { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, err := dut.BindWithErrno(ctx, t, fd, sa) + ret, err := dut.BindWithErrno(context.Background(), t, fd, sa) if ret != 0 { t.Fatalf("failed to bind socket: %s", err) } @@ -238,9 +234,7 @@ func (dut *DUT) BindWithErrno(ctx context.Context, t *testing.T, fd int32, sa un func (dut *DUT) Close(t *testing.T, fd int32) { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, err := dut.CloseWithErrno(ctx, t, fd) + ret, err := dut.CloseWithErrno(context.Background(), t, fd) if ret != 0 { t.Fatalf("failed to close: %s", err) } @@ -266,9 +260,7 @@ func (dut *DUT) CloseWithErrno(ctx context.Context, t *testing.T, fd int32) (int func (dut *DUT) Connect(t *testing.T, fd int32, sa unix.Sockaddr) { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, err := dut.ConnectWithErrno(ctx, t, fd, sa) + ret, err := dut.ConnectWithErrno(context.Background(), t, fd, sa) // Ignore 'operation in progress' error that can be returned when the socket // is non-blocking. if err != unix.EINPROGRESS && ret != 0 { @@ -297,9 +289,7 @@ func (dut *DUT) ConnectWithErrno(ctx context.Context, t *testing.T, fd int32, sa func (dut *DUT) GetSockName(t *testing.T, sockfd int32) unix.Sockaddr { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, sa, err := dut.GetSockNameWithErrno(ctx, t, sockfd) + ret, sa, err := dut.GetSockNameWithErrno(context.Background(), t, sockfd) if ret != 0 { t.Fatalf("failed to getsockname: %s", err) } @@ -349,9 +339,7 @@ func (dut *DUT) getSockOpt(ctx context.Context, t *testing.T, sockfd, level, opt func (dut *DUT) GetSockOpt(t *testing.T, sockfd, level, optname, optlen int32) []byte { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, optval, err := dut.GetSockOptWithErrno(ctx, t, sockfd, level, optname, optlen) + ret, optval, err := dut.GetSockOptWithErrno(context.Background(), t, sockfd, level, optname, optlen) if ret != 0 { t.Fatalf("failed to GetSockOpt: %s", err) } @@ -378,9 +366,7 @@ func (dut *DUT) GetSockOptWithErrno(ctx context.Context, t *testing.T, sockfd, l func (dut *DUT) GetSockOptInt(t *testing.T, sockfd, level, optname int32) int32 { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, intval, err := dut.GetSockOptIntWithErrno(ctx, t, sockfd, level, optname) + ret, intval, err := dut.GetSockOptIntWithErrno(context.Background(), t, sockfd, level, optname) if ret != 0 { t.Fatalf("failed to GetSockOptInt: %s", err) } @@ -405,9 +391,7 @@ func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, t *testing.T, sockfd func (dut *DUT) GetSockOptTimeval(t *testing.T, sockfd, level, optname int32) unix.Timeval { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, timeval, err := dut.GetSockOptTimevalWithErrno(ctx, t, sockfd, level, optname) + ret, timeval, err := dut.GetSockOptTimevalWithErrno(context.Background(), t, sockfd, level, optname) if ret != 0 { t.Fatalf("failed to GetSockOptTimeval: %s", err) } @@ -434,9 +418,7 @@ func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, t *testing.T, so func (dut *DUT) GetSockOptTCPInfo(t *testing.T, sockfd int32) linux.TCPInfo { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, info, err := dut.GetSockOptTCPInfoWithErrno(ctx, t, sockfd) + ret, info, err := dut.GetSockOptTCPInfoWithErrno(context.Background(), t, sockfd) if ret != 0 || err != unix.Errno(0) { t.Fatalf("failed to GetSockOptTCPInfo: %s", err) } @@ -463,9 +445,7 @@ func (dut *DUT) GetSockOptTCPInfoWithErrno(ctx context.Context, t *testing.T, so func (dut *DUT) Listen(t *testing.T, sockfd, backlog int32) { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, err := dut.ListenWithErrno(ctx, t, sockfd, backlog) + ret, err := dut.ListenWithErrno(context.Background(), t, sockfd, backlog) if ret != 0 { t.Fatalf("failed to listen: %s", err) } @@ -510,13 +490,7 @@ func (dut *DUT) PollOne(t *testing.T, fd int32, events int16, timeout time.Durat func (dut *DUT) Poll(t *testing.T, pfds []unix.PollFd, timeout time.Duration) []unix.PollFd { t.Helper() - ctx := context.Background() - var cancel context.CancelFunc - if timeout >= 0 { - ctx, cancel = context.WithTimeout(ctx, timeout+RPCTimeout) - defer cancel() - } - ret, result, err := dut.PollWithErrno(ctx, t, pfds, timeout) + ret, result, err := dut.PollWithErrno(context.Background(), t, pfds, timeout) if ret < 0 { t.Fatalf("failed to poll: %s", err) } @@ -559,9 +533,7 @@ func (dut *DUT) PollWithErrno(ctx context.Context, t *testing.T, pfds []unix.Pol func (dut *DUT) Send(t *testing.T, sockfd int32, buf []byte, flags int32) int32 { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, err := dut.SendWithErrno(ctx, t, sockfd, buf, flags) + ret, err := dut.SendWithErrno(context.Background(), t, sockfd, buf, flags) if ret == -1 { t.Fatalf("failed to send: %s", err) } @@ -590,9 +562,7 @@ func (dut *DUT) SendWithErrno(ctx context.Context, t *testing.T, sockfd int32, b func (dut *DUT) SendTo(t *testing.T, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) int32 { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, err := dut.SendToWithErrno(ctx, t, sockfd, buf, flags, destAddr) + ret, err := dut.SendToWithErrno(context.Background(), t, sockfd, buf, flags, destAddr) if ret == -1 { t.Fatalf("failed to sendto: %s", err) } @@ -625,10 +595,8 @@ func (dut *DUT) SetNonBlocking(t *testing.T, fd int32, nonblocking bool) { Fd: fd, Nonblocking: nonblocking, } - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - resp, err := dut.posixServer.SetNonblocking(ctx, req) + resp, err := dut.posixServer.SetNonblocking(context.Background(), req) if err != nil { t.Fatalf("failed to call SetNonblocking: %s", err) } @@ -661,9 +629,7 @@ func (dut *DUT) setSockOpt(ctx context.Context, t *testing.T, sockfd, level, opt func (dut *DUT) SetSockOpt(t *testing.T, sockfd, level, optname int32, optval []byte) { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, err := dut.SetSockOptWithErrno(ctx, t, sockfd, level, optname, optval) + ret, err := dut.SetSockOptWithErrno(context.Background(), t, sockfd, level, optname, optval) if ret != 0 { t.Fatalf("failed to SetSockOpt: %s", err) } @@ -684,9 +650,7 @@ func (dut *DUT) SetSockOptWithErrno(ctx context.Context, t *testing.T, sockfd, l func (dut *DUT) SetSockOptInt(t *testing.T, sockfd, level, optname, optval int32) { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, err := dut.SetSockOptIntWithErrno(ctx, t, sockfd, level, optname, optval) + ret, err := dut.SetSockOptIntWithErrno(context.Background(), t, sockfd, level, optname, optval) if ret != 0 { t.Fatalf("failed to SetSockOptInt: %s", err) } @@ -705,9 +669,7 @@ func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, t *testing.T, sockfd func (dut *DUT) SetSockOptTimeval(t *testing.T, sockfd, level, optname int32, tv *unix.Timeval) { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, err := dut.SetSockOptTimevalWithErrno(ctx, t, sockfd, level, optname, tv) + ret, err := dut.SetSockOptTimevalWithErrno(context.Background(), t, sockfd, level, optname, tv) if ret != 0 { t.Fatalf("failed to SetSockOptTimeval: %s", err) } @@ -746,8 +708,7 @@ func (dut *DUT) SocketWithErrno(t *testing.T, domain, typ, proto int32) (int32, Type: typ, Protocol: proto, } - ctx := context.Background() - resp, err := dut.posixServer.Socket(ctx, req) + resp, err := dut.posixServer.Socket(context.Background(), req) if err != nil { t.Fatalf("failed to call Socket: %s", err) } @@ -760,9 +721,7 @@ func (dut *DUT) SocketWithErrno(t *testing.T, domain, typ, proto int32) (int32, func (dut *DUT) Recv(t *testing.T, sockfd, len, flags int32) []byte { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, buf, err := dut.RecvWithErrno(ctx, t, sockfd, len, flags) + ret, buf, err := dut.RecvWithErrno(context.Background(), t, sockfd, len, flags) if ret == -1 { t.Fatalf("failed to recv: %s", err) } @@ -805,9 +764,7 @@ func (dut *DUT) SetSockLingerOption(t *testing.T, sockfd int32, timeout time.Dur func (dut *DUT) Shutdown(t *testing.T, fd, how int32) { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) - defer cancel() - ret, err := dut.ShutdownWithErrno(ctx, t, fd, how) + ret, err := dut.ShutdownWithErrno(context.Background(), t, fd, how) if ret != 0 { t.Fatalf("failed to shutdown(%d, %d): %s", fd, how, err) } diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go index caa389780..38ae9c1d7 100644 --- a/test/packetimpact/testbench/testbench.go +++ b/test/packetimpact/testbench/testbench.go @@ -31,8 +31,6 @@ var ( Native = false // RPCKeepalive is the gRPC keepalive. RPCKeepalive = 10 * time.Second - // RPCTimeout is the gRPC timeout. - RPCTimeout = 100 * time.Millisecond // dutInfosJSON is the json string that describes information about all the // duts available to use. @@ -124,7 +122,6 @@ func (n *DUTTestNet) SubnetBroadcast() net.IP { // functions. func registerFlags(fs *flag.FlagSet) { fs.BoolVar(&Native, "native", Native, "whether the test is running natively") - fs.DurationVar(&RPCTimeout, "rpc_timeout", RPCTimeout, "gRPC timeout") fs.DurationVar(&RPCKeepalive, "rpc_keepalive", RPCKeepalive, "gRPC keepalive") fs.StringVar(&dutInfosJSON, "dut_infos_json", dutInfosJSON, "json that describes the DUTs") } diff --git a/test/packetimpact/tests/generic_dgram_socket_send_recv_test.go b/test/packetimpact/tests/generic_dgram_socket_send_recv_test.go index 00e0f7690..a9ffafc74 100644 --- a/test/packetimpact/tests/generic_dgram_socket_send_recv_test.go +++ b/test/packetimpact/tests/generic_dgram_socket_send_recv_test.go @@ -48,7 +48,6 @@ func maxUDPPayloadSize(addr net.IP) int { func init() { testbench.Initialize(flag.CommandLine) - testbench.RPCTimeout = 500 * time.Millisecond } func expectedEthLayer(t *testing.T, dut testbench.DUT, socketFD int32, sendTo net.IP) testbench.Layer { @@ -437,9 +436,7 @@ func (test *icmpV6Test) Send(t *testing.T, dut testbench.DUT, bindTo, sendTo net copy(destSockaddr.Addr[:], sendTo.To16()) // Tell the DUT to send a packet out the ICMPv6 socket. - ctx, cancel := context.WithTimeout(context.Background(), testbench.RPCTimeout) - defer cancel() - gotRet, gotErrno := dut.SendToWithErrno(ctx, t, env.socketFD, bytes, 0, &destSockaddr) + gotRet, gotErrno := dut.SendToWithErrno(context.Background(), t, env.socketFD, bytes, 0, &destSockaddr) if gotErrno != wantErrno { t.Fatalf("got dut.SendToWithErrno(_, _, %d, _, _, %s) = (_, %s), want = (_, %s)", env.socketFD, sendTo, gotErrno, wantErrno) @@ -677,9 +674,7 @@ func (test *udpTest) Send(t *testing.T, dut testbench.DUT, bindTo, sendTo net.IP } // Tell the DUT to send a packet out the UDP socket. - ctx, cancel := context.WithTimeout(context.Background(), testbench.RPCTimeout) - defer cancel() - gotRet, gotErrno := dut.SendToWithErrno(ctx, t, env.socketFD, payload, 0, destSockaddr) + gotRet, gotErrno := dut.SendToWithErrno(context.Background(), t, env.socketFD, payload, 0, destSockaddr) if gotErrno != wantErrno { t.Fatalf("got dut.SendToWithErrno(_, _, %d, _, _, %s) = (_, %s), want = (_, %s)", env.socketFD, sendTo, gotErrno, wantErrno) diff --git a/test/packetimpact/tests/tcp_connect_icmp_error_test.go b/test/packetimpact/tests/tcp_connect_icmp_error_test.go index 3b4c4cd63..15d603328 100644 --- a/test/packetimpact/tests/tcp_connect_icmp_error_test.go +++ b/test/packetimpact/tests/tcp_connect_icmp_error_test.go @@ -15,9 +15,7 @@ package tcp_connect_icmp_error_test import ( - "context" "flag" - "sync" "testing" "time" @@ -66,35 +64,38 @@ func TestTCPConnectICMPError(t *testing.T) { t.Fatalf("expected SYN, %s", err) } - done := make(chan bool) - defer close(done) - var wg sync.WaitGroup - defer wg.Wait() - wg.Add(1) - var block sync.WaitGroup - block.Add(1) + // Continuously try to read the ICMP error in an attempt to trigger a race + // condition. + start := make(chan struct{}) + done := make(chan struct{}) go func() { - defer wg.Done() - _, cancel := context.WithTimeout(context.Background(), time.Second*3) - defer cancel() + defer close(done) - block.Done() + close(start) for { select { case <-done: return default: - if errno := dut.GetSockOptInt(t, clientFD, unix.SOL_SOCKET, unix.SO_ERROR); errno != 0 { - return - } } + const want = unix.EHOSTUNREACH + switch got := unix.Errno(dut.GetSockOptInt(t, clientFD, unix.SOL_SOCKET, unix.SO_ERROR)); got { + case unix.Errno(0): + continue + case want: + return + default: + t.Fatalf("got SO_ERROR = %s, want %s", got, want) + } + } }() - block.Wait() + <-start sendICMPError(t, &conn, tcp) dut.PollOne(t, clientFD, unix.POLLHUP, time.Second) + <-done conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) // The DUT should reply with RST to our ACK as the state should have diff --git a/test/packetimpact/tests/tcp_linger_test.go b/test/packetimpact/tests/tcp_linger_test.go index 88942904d..46b5ca5d8 100644 --- a/test/packetimpact/tests/tcp_linger_test.go +++ b/test/packetimpact/tests/tcp_linger_test.go @@ -98,20 +98,19 @@ func TestTCPLingerNonZeroTimeout(t *testing.T) { dut.SetSockLingerOption(t, acceptFD, lingerDuration, tt.lingerOn) - // Increase timeout as Close will take longer time to - // return when SO_LINGER is set with non-zero timeout. - timeout := lingerDuration + 1*time.Second - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() start := time.Now() - dut.CloseWithErrno(ctx, t, acceptFD) - end := time.Now() - diff := end.Sub(start) - - if tt.lingerOn && diff < lingerDuration { - t.Errorf("expected close to return after %v seconds, but returned sooner", lingerDuration) - } else if !tt.lingerOn && diff > 1*time.Second { - t.Errorf("expected close to return within a second, but returned later") + dut.CloseWithErrno(context.Background(), t, acceptFD) + elapsed := time.Since(start) + + expectedMaximum := time.Second + if tt.lingerOn { + expectedMaximum += lingerDuration + if elapsed < lingerDuration { + t.Errorf("expected close to take at least %s, but took %s", lingerDuration, elapsed) + } + } + if elapsed >= expectedMaximum { + t.Errorf("expected close to take at most %s, but took %s", expectedMaximum, elapsed) } if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { @@ -144,20 +143,19 @@ func TestTCPLingerSendNonZeroTimeout(t *testing.T) { sampleData := []byte("Sample Data") dut.Send(t, acceptFD, sampleData, 0) - // Increase timeout as Close will take longer time to - // return when SO_LINGER is set with non-zero timeout. - timeout := lingerDuration + 1*time.Second - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() start := time.Now() - dut.CloseWithErrno(ctx, t, acceptFD) - end := time.Now() - diff := end.Sub(start) - - if tt.lingerOn && diff < lingerDuration { - t.Errorf("expected close to return after %v seconds, but returned sooner", lingerDuration) - } else if !tt.lingerOn && diff > 1*time.Second { - t.Errorf("expected close to return within a second, but returned later") + dut.CloseWithErrno(context.Background(), t, acceptFD) + elapsed := time.Since(start) + + expectedMaximum := time.Second + if tt.lingerOn { + expectedMaximum += lingerDuration + if elapsed < lingerDuration { + t.Errorf("expected close to take at least %s, but took %s", lingerDuration, elapsed) + } + } + if elapsed >= expectedMaximum { + t.Errorf("expected close to take at most %s, but took %s", expectedMaximum, elapsed) } samplePayload := &testbench.Payload{Bytes: sampleData} @@ -221,20 +219,19 @@ func TestTCPLingerShutdownSendNonZeroTimeout(t *testing.T) { dut.Shutdown(t, acceptFD, unix.SHUT_RDWR) - // Increase timeout as Close will take longer time to - // return when SO_LINGER is set with non-zero timeout. - timeout := lingerDuration + 1*time.Second - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() start := time.Now() - dut.CloseWithErrno(ctx, t, acceptFD) - end := time.Now() - diff := end.Sub(start) - - if tt.lingerOn && diff < lingerDuration { - t.Errorf("expected close to return after %v seconds, but returned sooner", lingerDuration) - } else if !tt.lingerOn && diff > 1*time.Second { - t.Errorf("expected close to return within a second, but returned later") + dut.CloseWithErrno(context.Background(), t, acceptFD) + elapsed := time.Since(start) + + expectedMaximum := time.Second + if tt.lingerOn { + expectedMaximum += lingerDuration + if elapsed < lingerDuration { + t.Errorf("expected close to take at least %s, but took %s", lingerDuration, elapsed) + } + } + if elapsed >= expectedMaximum { + t.Errorf("expected close to take at most %s, but took %s", expectedMaximum, elapsed) } samplePayload := &testbench.Payload{Bytes: sampleData} @@ -259,9 +256,10 @@ func TestTCPLingerNonEstablished(t *testing.T) { // and return immediately. start := time.Now() dut.CloseWithErrno(context.Background(), t, newFD) - diff := time.Since(start) + elapsed := time.Since(start) - if diff > lingerDuration { - t.Errorf("expected close to return within %s, but returned after %s", lingerDuration, diff) + expectedMaximum := time.Second + if elapsed >= time.Second { + t.Errorf("expected close to take at most %s, but took %s", expectedMaximum, elapsed) } } diff --git a/test/packetimpact/tests/tcp_network_unreachable_test.go b/test/packetimpact/tests/tcp_network_unreachable_test.go index 60a2dbf3d..e92e6aa9b 100644 --- a/test/packetimpact/tests/tcp_network_unreachable_test.go +++ b/test/packetimpact/tests/tcp_network_unreachable_test.go @@ -41,11 +41,9 @@ func TestTCPSynSentUnreachable(t *testing.T) { defer conn.Close(t) // Bring the DUT to SYN-SENT state with a non-blocking connect. - ctx, cancel := context.WithTimeout(context.Background(), testbench.RPCTimeout) - defer cancel() sa := unix.SockaddrInet4{Port: int(port)} copy(sa.Addr[:], dut.Net.LocalIPv4) - if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != unix.EINPROGRESS { + if _, err := dut.ConnectWithErrno(context.Background(), t, clientFD, &sa); err != unix.EINPROGRESS { t.Errorf("got connect() = %v, want EINPROGRESS", err) } @@ -86,14 +84,12 @@ func TestTCPSynSentUnreachable6(t *testing.T) { defer conn.Close(t) // Bring the DUT to SYN-SENT state with a non-blocking connect. - ctx, cancel := context.WithTimeout(context.Background(), testbench.RPCTimeout) - defer cancel() sa := unix.SockaddrInet6{ Port: int(conn.SrcPort()), ZoneId: dut.Net.RemoteDevID, } copy(sa.Addr[:], dut.Net.LocalIPv6) - if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != unix.EINPROGRESS { + if _, err := dut.ConnectWithErrno(context.Background(), t, clientFD, &sa); err != unix.EINPROGRESS { t.Errorf("got connect() = %v, want EINPROGRESS", err) } diff --git a/test/packetimpact/tests/tcp_queue_send_recv_in_syn_sent_test.go b/test/packetimpact/tests/tcp_queue_send_recv_in_syn_sent_test.go index 1c8b72ebe..974c15384 100644 --- a/test/packetimpact/tests/tcp_queue_send_recv_in_syn_sent_test.go +++ b/test/packetimpact/tests/tcp_queue_send_recv_in_syn_sent_test.go @@ -20,7 +20,6 @@ import ( "encoding/hex" "errors" "flag" - "sync" "testing" "time" @@ -54,37 +53,39 @@ func TestQueueSendInSynSentHandshake(t *testing.T) { // Test blocking send. dut.SetNonBlocking(t, socket, false) - var wg sync.WaitGroup - defer wg.Wait() - wg.Add(1) - var block sync.WaitGroup - block.Add(1) + start := make(chan struct{}) + done := make(chan struct{}) go func() { - defer wg.Done() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) - defer cancel() + defer close(done) - block.Done() + close(start) // Issue SEND call in SYN-SENT, this should be queued for // process until the connection is established. - n, err := dut.SendWithErrno(ctx, t, socket, sampleData, 0) - if n == -1 { + if _, err := dut.SendWithErrno(context.Background(), t, socket, sampleData, 0); err != unix.Errno(0) { t.Errorf("failed to send on DUT: %s", err) - return } }() // Wait for the goroutine to be scheduled and before it // blocks on endpoint send/receive. - block.Wait() + <-start // The following sleep is used to prevent the connection // from being established before we are blocked: there is // still a small time window between we sending the RPC // request and the system actually being blocked. time.Sleep(100 * time.Millisecond) + select { + case <-done: + t.Fatal("expected send to be blocked in SYN-SENT") + default: + } + // Bring the connection to Established. conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}) + + <-done + // Expect the data from the DUT's enqueued send request. // // On Linux, this can be piggybacked with the ACK completing the @@ -126,21 +127,16 @@ func TestQueueRecvInSynSentHandshake(t *testing.T) { // Test blocking read. dut.SetNonBlocking(t, socket, false) - var wg sync.WaitGroup - defer wg.Wait() - wg.Add(1) - var block sync.WaitGroup - block.Add(1) + start := make(chan struct{}) + done := make(chan struct{}) go func() { - defer wg.Done() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) - defer cancel() + defer close(done) - block.Done() + close(start) // Issue RECEIVE call in SYN-SENT, this should be queued for // process until the connection is established. - n, buff, err := dut.RecvWithErrno(ctx, t, socket, int32(len(sampleData)), 0) - if n == -1 { + n, buff, err := dut.RecvWithErrno(context.Background(), t, socket, int32(len(sampleData)), 0) + if err != unix.Errno(0) { t.Errorf("failed to recv on DUT: %s", err) return } @@ -151,7 +147,8 @@ func TestQueueRecvInSynSentHandshake(t *testing.T) { // Wait for the goroutine to be scheduled and before it // blocks on endpoint send/receive. - block.Wait() + <-start + // The following sleep is used to prevent the connection // from being established before we are blocked: there is // still a small time window between we sending the RPC @@ -169,6 +166,8 @@ func TestQueueRecvInSynSentHandshake(t *testing.T) { if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil { t.Fatalf("expected an ACK from DUT, but got none: %s", err) } + + <-done } // TestQueueSendInSynSentRST tests send behavior when the TCP state @@ -192,20 +191,15 @@ func TestQueueSendInSynSentRST(t *testing.T) { // Test blocking send. dut.SetNonBlocking(t, socket, false) - var wg sync.WaitGroup - defer wg.Wait() - wg.Add(1) - var block sync.WaitGroup - block.Add(1) + start := make(chan struct{}) + done := make(chan struct{}) go func() { - defer wg.Done() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) - defer cancel() + defer close(done) - block.Done() + close(start) // Issue SEND call in SYN-SENT, this should be queued for // process until the connection is established. - n, err := dut.SendWithErrno(ctx, t, socket, sampleData, 0) + n, err := dut.SendWithErrno(context.Background(), t, socket, sampleData, 0) if err != unix.ECONNREFUSED { t.Errorf("expected error %s, got %s", unix.ECONNREFUSED, err) } @@ -216,14 +210,23 @@ func TestQueueSendInSynSentRST(t *testing.T) { // Wait for the goroutine to be scheduled and before it // blocks on endpoint send/receive. - block.Wait() + <-start + // The following sleep is used to prevent the connection // from being established before we are blocked: there is // still a small time window between we sending the RPC // request and the system actually being blocked. time.Sleep(100 * time.Millisecond) + select { + case <-done: + t.Fatal("expected send to be blocked in SYN-SENT") + default: + } + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst | header.TCPFlagAck)}) + + <-done } // TestQueueRecvInSynSentRST tests recv behavior when the TCP state @@ -251,20 +254,15 @@ func TestQueueRecvInSynSentRST(t *testing.T) { // Test blocking read. dut.SetNonBlocking(t, socket, false) - var wg sync.WaitGroup - defer wg.Wait() - wg.Add(1) - var block sync.WaitGroup - block.Add(1) + start := make(chan struct{}) + done := make(chan struct{}) go func() { - defer wg.Done() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) - defer cancel() + defer close(done) - block.Done() + close(start) // Issue RECEIVE call in SYN-SENT, this should be queued for // process until the connection is established. - n, _, err := dut.RecvWithErrno(ctx, t, socket, int32(len(sampleData)), 0) + n, _, err := dut.RecvWithErrno(context.Background(), t, socket, int32(len(sampleData)), 0) if err != unix.ECONNREFUSED { t.Errorf("expected error %s, got %s", unix.ECONNREFUSED, err) } @@ -275,7 +273,8 @@ func TestQueueRecvInSynSentRST(t *testing.T) { // Wait for the goroutine to be scheduled and before it // blocks on endpoint send/receive. - block.Wait() + <-start + // The following sleep is used to prevent the connection // from being established before we are blocked: there is // still a small time window between we sending the RPC @@ -283,4 +282,5 @@ func TestQueueRecvInSynSentRST(t *testing.T) { time.Sleep(100 * time.Millisecond) conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst | header.TCPFlagAck)}) + <-done } diff --git a/test/packetimpact/tests/udp_icmp_error_propagation_test.go b/test/packetimpact/tests/udp_icmp_error_propagation_test.go index 087aeb66e..bb33ca4b3 100644 --- a/test/packetimpact/tests/udp_icmp_error_propagation_test.go +++ b/test/packetimpact/tests/udp_icmp_error_propagation_test.go @@ -141,8 +141,6 @@ func testRecv(ctx context.Context, t *testing.T, d testData) { d.conn.Send(t, testbench.UDP{}) if d.wantErrno != unix.Errno(0) { - ctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() ret, _, err := d.dut.RecvWithErrno(ctx, t, d.remoteFD, 100, 0) if ret != -1 { t.Fatalf("recv after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno) @@ -167,8 +165,6 @@ func testSendTo(ctx context.Context, t *testing.T, d testData) { } if d.wantErrno != unix.Errno(0) { - ctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() ret, err := d.dut.SendToWithErrno(ctx, t, d.remoteFD, nil, 0, d.conn.LocalAddr(t)) if ret != -1 { @@ -315,10 +311,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { defer wg.Done() if wantErrno != unix.Errno(0) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - ret, _, err := dut.RecvWithErrno(ctx, t, remoteFD, 100, 0) + ret, _, err := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0) if ret != -1 { t.Errorf("recv during ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", wantErrno) return @@ -329,10 +322,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { } } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - if ret, _, err := dut.RecvWithErrno(ctx, t, remoteFD, 100, 0); ret == -1 { + if ret, _, err := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0); ret == -1 { t.Errorf("recv after ICMP error failed with (%[1]d) %[1]", err) } }() @@ -340,10 +330,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { go func() { defer wg.Done() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - if ret, _, err := dut.RecvWithErrno(ctx, t, cleanFD, 100, 0); ret == -1 { + if ret, _, err := dut.RecvWithErrno(context.Background(), t, cleanFD, 100, 0); ret == -1 { t.Errorf("recv on clean socket failed with (%[1]d) %[1]", err) } }() diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 99743b14a..8fc96d7ba 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -648,6 +648,12 @@ syscall_test( syscall_test( size = "large", shard_count = most_shards, + test = "//test/syscalls/linux:socket_inet_loopback_isolated_test", +) + +syscall_test( + size = "large", + shard_count = most_shards, # Takes too long for TSAN. Creates a lot of TCP sockets. tags = ["nogotsan"], test = "//test/syscalls/linux:socket_inet_loopback_nogotsan_test", diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index d8b562e9d..c7991cfaa 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -9,6 +9,8 @@ exports_files( [ "socket.cc", "socket_inet_loopback.cc", + "socket_inet_loopback_isolated.cc", + "socket_inet_loopback_test_params.h", "socket_ip_loopback_blocking.cc", "socket_ip_tcp_generic_loopback.cc", "socket_ip_tcp_loopback.cc", @@ -1883,6 +1885,7 @@ cc_binary( linkstatic = 1, deps = [ "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/strings", "@com_google_absl//absl/time", gtest, "//test/util:capability_util", @@ -3135,6 +3138,16 @@ cc_binary( ], ) +cc_library( + name = "socket_inet_loopback_test_params", + testonly = 1, + hdrs = ["socket_inet_loopback_test_params.h"], + deps = [ + ":socket_test_util", + gtest, + ], +) + cc_binary( name = "socket_inet_loopback_test", testonly = 1, @@ -3142,6 +3155,7 @@ cc_binary( linkstatic = 1, deps = [ ":ip_socket_test_util", + ":socket_inet_loopback_test_params", ":socket_test_util", "//test/util:file_descriptor", "@com_google_absl//absl/memory", @@ -3163,16 +3177,29 @@ cc_binary( linkstatic = 1, deps = [ ":ip_socket_test_util", + ":socket_inet_loopback_test_params", ":socket_test_util", "//test/util:file_descriptor", - "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", gtest, "//test/util:posix_error", "//test/util:save_util", "//test/util:test_main", "//test/util:test_util", - "//test/util:thread_util", + ], +) + +cc_binary( + name = "socket_inet_loopback_isolated_test", + testonly = 1, + srcs = ["socket_inet_loopback_isolated.cc"], + linkstatic = 1, + deps = [ + ":socket_inet_loopback_test_params", + ":socket_test_util", + gtest, + "//test/util:test_main", + "@com_google_absl//absl/time", ], ) diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc index d519b65e6..f64c23ac0 100644 --- a/test/syscalls/linux/ptrace.cc +++ b/test/syscalls/linux/ptrace.cc @@ -30,6 +30,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/flags/flag.h" +#include "absl/strings/string_view.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/util/capability_util.h" @@ -51,17 +52,10 @@ ABSL_FLAG(bool, ptrace_test_execve_child, false, ABSL_FLAG(bool, ptrace_test_trace_descendants_allowed, false, "If set, run the child workload for " "PtraceTest_TraceDescendantsAllowed."); -ABSL_FLAG(bool, ptrace_test_prctl_set_ptracer_pid, false, - "If set, run the child workload for PtraceTest_PrctlSetPtracerPID."); -ABSL_FLAG(bool, ptrace_test_prctl_set_ptracer_any, false, - "If set, run the child workload for PtraceTest_PrctlSetPtracerAny."); -ABSL_FLAG(bool, ptrace_test_prctl_clear_ptracer, false, - "If set, run the child workload for PtraceTest_PrctlClearPtracer."); -ABSL_FLAG(bool, ptrace_test_prctl_replace_ptracer, false, - "If set, run the child workload for PtraceTest_PrctlReplacePtracer."); -ABSL_FLAG(int, ptrace_test_prctl_replace_ptracer_tid, -1, - "Specifies the replacement tracer tid in the child workload for " - "PtraceTest_PrctlReplacePtracer."); +ABSL_FLAG(bool, ptrace_test_ptrace_attacher, false, + "If set, run the child workload for PtraceAttacherSubprocess."); +ABSL_FLAG(bool, ptrace_test_prctl_set_ptracer, false, + "If set, run the child workload for PrctlSetPtracerSubprocess."); ABSL_FLAG(bool, ptrace_test_prctl_set_ptracer_and_exit_tracee_thread, false, "If set, run the child workload for " "PtraceTest_PrctlSetPtracerPersistsPastTraceeThreadExit."); @@ -161,6 +155,86 @@ int CheckPtraceAttach(pid_t pid) { return 0; } +class SimpleSubprocess { + public: + explicit SimpleSubprocess(absl::string_view child_flag) { + int sockets[2]; + TEST_PCHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0); + + // Allocate vector before forking (not async-signal-safe). + ExecveArray const owned_child_argv = {"/proc/self/exe", child_flag, + "--ptrace_test_fd", + std::to_string(sockets[0])}; + char* const* const child_argv = owned_child_argv.get(); + + pid_ = fork(); + if (pid_ == 0) { + TEST_PCHECK(close(sockets[1]) == 0); + execve(child_argv[0], child_argv, /* envp = */ nullptr); + TEST_PCHECK_MSG(false, "Survived execve to test child"); + } + TEST_PCHECK(pid_ > 0); + TEST_PCHECK(close(sockets[0]) == 0); + sockfd_ = sockets[1]; + } + + SimpleSubprocess(SimpleSubprocess&& orig) + : pid_(orig.pid_), sockfd_(orig.sockfd_) { + orig.pid_ = -1; + orig.sockfd_ = -1; + } + + SimpleSubprocess& operator=(SimpleSubprocess&& orig) { + if (this != &orig) { + this->~SimpleSubprocess(); + pid_ = orig.pid_; + sockfd_ = orig.sockfd_; + orig.pid_ = -1; + orig.sockfd_ = -1; + } + return *this; + } + + SimpleSubprocess(SimpleSubprocess const&) = delete; + SimpleSubprocess& operator=(SimpleSubprocess const&) = delete; + + ~SimpleSubprocess() { + if (pid_ < 0) { + return; + } + EXPECT_THAT(shutdown(sockfd_, SHUT_RDWR), SyscallSucceeds()); + EXPECT_THAT(close(sockfd_), SyscallSucceeds()); + int status; + EXPECT_THAT(waitpid(pid_, &status, 0), SyscallSucceedsWithValue(pid_)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) + << " status " << status; + } + + pid_t pid() const { return pid_; } + + // Sends the child process the given value, receives an errno in response, and + // returns a PosixError corresponding to the received errno. + template <typename T> + PosixError Cmd(T val) { + if (WriteFd(sockfd_, &val, sizeof(val)) < 0) { + return PosixError(errno, "write failed"); + } + return RecvErrno(); + } + + private: + PosixError RecvErrno() { + int resp_errno; + if (ReadFd(sockfd_, &resp_errno, sizeof(resp_errno)) < 0) { + return PosixError(errno, "read failed"); + } + return PosixError(resp_errno); + } + + pid_t pid_ = -1; + int sockfd_ = -1; +}; + TEST(PtraceTest, AttachSelf) { EXPECT_THAT(ptrace(PTRACE_ATTACH, gettid(), 0, 0), SyscallFailsWithErrno(EPERM)); @@ -343,289 +417,128 @@ TEST(PtraceTest, PrctlSetPtracerInvalidPID) { EXPECT_THAT(prctl(PR_SET_PTRACER, 123456789), SyscallFailsWithErrno(EINVAL)); } -TEST(PtraceTest, PrctlSetPtracerPID) { - SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); - - AutoCapability cap(CAP_SYS_PTRACE, false); - - // Use sockets to synchronize between tracer and tracee. - int sockets[2]; - ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets), SyscallSucceeds()); - - // Allocate vector before forking (not async-signal-safe). - ExecveArray const owned_child_argv = { - "/proc/self/exe", "--ptrace_test_prctl_set_ptracer_pid", - "--ptrace_test_fd", std::to_string(sockets[0])}; - char* const* const child_argv = owned_child_argv.get(); - - pid_t const tracee_pid = fork(); - if (tracee_pid == 0) { - TEST_PCHECK(close(sockets[1]) == 0); - // This test will create a new thread in the child process. - // pthread_create(2) isn't async-signal-safe, so we execve() first. - execve(child_argv[0], child_argv, /* envp = */ nullptr); - TEST_PCHECK_MSG(false, "Survived execve to test child"); - } - ASSERT_THAT(tracee_pid, SyscallSucceeds()); - ASSERT_THAT(close(sockets[0]), SyscallSucceeds()); - - pid_t const tracer_pid = fork(); - if (tracer_pid == 0) { - // Wait until tracee has called prctl. - char done; - TEST_PCHECK(read(sockets[1], &done, 1) == 1); - MaybeSave(); - - TEST_PCHECK(CheckPtraceAttach(tracee_pid) == 0); - _exit(0); - } - ASSERT_THAT(tracer_pid, SyscallSucceeds()); - - // Clean up tracer. - int status; - ASSERT_THAT(waitpid(tracer_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); - - // Clean up tracee. - ASSERT_THAT(kill(tracee_pid, SIGKILL), SyscallSucceeds()); - ASSERT_THAT(waitpid(tracee_pid, &status, 0), - SyscallSucceedsWithValue(tracee_pid)); - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL) - << " status " << status; +SimpleSubprocess CreatePtraceAttacherSubprocess() { + return SimpleSubprocess("--ptrace_test_ptrace_attacher"); } -[[noreturn]] void RunPrctlSetPtracerPID(int fd) { - ScopedThread t([fd] { - // Perform prctl in a separate thread to verify that it is process-wide. - TEST_PCHECK(prctl(PR_SET_PTRACER, getppid()) == 0); - MaybeSave(); - // Indicate that the prctl has been set. - TEST_PCHECK(write(fd, "x", 1) == 1); - MaybeSave(); +[[noreturn]] static void RunPtraceAttacher(int sockfd) { + // execve() may have restored CAP_SYS_PTRACE if we had real UID 0. + TEST_CHECK(SetCapability(CAP_SYS_PTRACE, false).ok()); + // Perform PTRACE_ATTACH in a separate thread to verify that permissions + // apply process-wide. + ScopedThread t([&] { + while (true) { + pid_t pid; + int rv = read(sockfd, &pid, sizeof(pid)); + if (rv == 0) { + _exit(0); + } + if (rv < 0) { + _exit(1); + } + int resp_errno = 0; + if (CheckPtraceAttach(pid) < 0) { + resp_errno = errno; + } + TEST_PCHECK(write(sockfd, &resp_errno, sizeof(resp_errno)) == + sizeof(resp_errno)); + } }); while (true) { SleepSafe(absl::Seconds(1)); } } -TEST(PtraceTest, PrctlSetPtracerAny) { - SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); - AutoCapability cap(CAP_SYS_PTRACE, false); - - // Use sockets to synchronize between tracer and tracee. - int sockets[2]; - ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets), SyscallSucceeds()); - - // Allocate vector before forking (not async-signal-safe). - ExecveArray const owned_child_argv = { - "/proc/self/exe", "--ptrace_test_prctl_set_ptracer_any", - "--ptrace_test_fd", std::to_string(sockets[0])}; - char* const* const child_argv = owned_child_argv.get(); - - pid_t const tracee_pid = fork(); - if (tracee_pid == 0) { - // This test will create a new thread in the child process. - // pthread_create(2) isn't async-signal-safe, so we execve() first. - TEST_PCHECK(close(sockets[1]) == 0); - execve(child_argv[0], child_argv, /* envp = */ nullptr); - TEST_PCHECK_MSG(false, "Survived execve to test child"); - } - ASSERT_THAT(tracee_pid, SyscallSucceeds()); - ASSERT_THAT(close(sockets[0]), SyscallSucceeds()); - - pid_t const tracer_pid = fork(); - if (tracer_pid == 0) { - // Wait until tracee has called prctl. - char done; - TEST_PCHECK(read(sockets[1], &done, 1) == 1); - MaybeSave(); - - TEST_PCHECK(CheckPtraceAttach(tracee_pid) == 0); - _exit(0); - } - ASSERT_THAT(tracer_pid, SyscallSucceeds()); - - // Clean up tracer. - int status; - ASSERT_THAT(waitpid(tracer_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << " status " << status; - - // Clean up tracee. - ASSERT_THAT(kill(tracee_pid, SIGKILL), SyscallSucceeds()); - ASSERT_THAT(waitpid(tracee_pid, &status, 0), - SyscallSucceedsWithValue(tracee_pid)); - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL) - << " status " << status; +SimpleSubprocess CreatePrctlSetPtracerSubprocess() { + return SimpleSubprocess("--ptrace_test_prctl_set_ptracer"); } -[[noreturn]] void RunPrctlSetPtracerAny(int fd) { - ScopedThread t([fd] { - // Perform prctl in a separate thread to verify that it is process-wide. - TEST_PCHECK(prctl(PR_SET_PTRACER, PR_SET_PTRACER_ANY) == 0); - MaybeSave(); - // Indicate that the prctl has been set. - TEST_PCHECK(write(fd, "x", 1) == 1); - MaybeSave(); +[[noreturn]] static void RunPrctlSetPtracer(int sockfd) { + // Perform prctl in a separate thread to verify that it applies + // process-wide. + ScopedThread t([&] { + while (true) { + pid_t pid; + int rv = read(sockfd, &pid, sizeof(pid)); + if (rv == 0) { + _exit(0); + } + if (rv < 0) { + _exit(1); + } + int resp_errno = 0; + if (prctl(PR_SET_PTRACER, pid) < 0) { + resp_errno = errno; + } + TEST_PCHECK(write(sockfd, &resp_errno, sizeof(resp_errno)) == + sizeof(resp_errno)); + } }); while (true) { SleepSafe(absl::Seconds(1)); } } -TEST(PtraceTest, PrctlClearPtracer) { +TEST(PtraceTest, PrctlSetPtracer) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); - AutoCapability cap(CAP_SYS_PTRACE, false); - - // Use sockets to synchronize between tracer and tracee. - int sockets[2]; - ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets), SyscallSucceeds()); - - // Allocate vector before forking (not async-signal-safe). - ExecveArray const owned_child_argv = { - "/proc/self/exe", "--ptrace_test_prctl_clear_ptracer", "--ptrace_test_fd", - std::to_string(sockets[0])}; - char* const* const child_argv = owned_child_argv.get(); - - pid_t const tracee_pid = fork(); - if (tracee_pid == 0) { - // This test will create a new thread in the child process. - // pthread_create(2) isn't async-signal-safe, so we execve() first. - TEST_PCHECK(close(sockets[1]) == 0); - execve(child_argv[0], child_argv, /* envp = */ nullptr); - TEST_PCHECK_MSG(false, "Survived execve to test child"); - } - ASSERT_THAT(tracee_pid, SyscallSucceeds()); - ASSERT_THAT(close(sockets[0]), SyscallSucceeds()); - - pid_t const tracer_pid = fork(); - if (tracer_pid == 0) { - // Wait until tracee has called prctl. - char done; - TEST_PCHECK(read(sockets[1], &done, 1) == 1); - MaybeSave(); - - TEST_CHECK(CheckPtraceAttach(tracee_pid) == -1); - TEST_PCHECK(errno == EPERM); - _exit(0); - } - ASSERT_THAT(tracer_pid, SyscallSucceeds()); - - // Clean up tracer. - int status; - ASSERT_THAT(waitpid(tracer_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << " status " << status; - - // Clean up tracee. - ASSERT_THAT(kill(tracee_pid, SIGKILL), SyscallSucceeds()); - ASSERT_THAT(waitpid(tracee_pid, &status, 0), - SyscallSucceedsWithValue(tracee_pid)); - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL) - << " status " << status; -} - -[[noreturn]] void RunPrctlClearPtracer(int fd) { - ScopedThread t([fd] { - // Perform prctl in a separate thread to verify that it is process-wide. - TEST_PCHECK(prctl(PR_SET_PTRACER, PR_SET_PTRACER_ANY) == 0); - MaybeSave(); - TEST_PCHECK(prctl(PR_SET_PTRACER, 0) == 0); - MaybeSave(); - // Indicate that the prctl has been set/cleared. - TEST_PCHECK(write(fd, "x", 1) == 1); - MaybeSave(); - }); - while (true) { - SleepSafe(absl::Seconds(1)); - } -} -TEST(PtraceTest, PrctlReplacePtracer) { - SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); AutoCapability cap(CAP_SYS_PTRACE, false); - pid_t const unused_pid = fork(); - if (unused_pid == 0) { - while (true) { - SleepSafe(absl::Seconds(1)); - } - } - ASSERT_THAT(unused_pid, SyscallSucceeds()); + // Ensure that initially, no tracer exception is set. + ASSERT_THAT(prctl(PR_SET_PTRACER, 0), SyscallSucceeds()); - // Use sockets to synchronize between tracer and tracee. - int sockets[2]; - ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets), SyscallSucceeds()); + SimpleSubprocess tracee = CreatePrctlSetPtracerSubprocess(); + SimpleSubprocess tracer = CreatePtraceAttacherSubprocess(); - // Allocate vector before forking (not async-signal-safe). - ExecveArray const owned_child_argv = { - "/proc/self/exe", - "--ptrace_test_prctl_replace_ptracer", - "--ptrace_test_prctl_replace_ptracer_tid", - std::to_string(unused_pid), - "--ptrace_test_fd", - std::to_string(sockets[0])}; - char* const* const child_argv = owned_child_argv.get(); + // By default, Yama should prevent tracer from tracing its parent (this + // process) or siblings (tracee). + EXPECT_THAT(tracer.Cmd(gettid()), PosixErrorIs(EPERM)); + EXPECT_THAT(tracer.Cmd(tracee.pid()), PosixErrorIs(EPERM)); - pid_t const tracee_pid = fork(); - if (tracee_pid == 0) { - TEST_PCHECK(close(sockets[1]) == 0); - // This test will create a new thread in the child process. - // pthread_create(2) isn't async-signal-safe, so we execve() first. - execve(child_argv[0], child_argv, /* envp = */ nullptr); - TEST_PCHECK_MSG(false, "Survived execve to test child"); - } - ASSERT_THAT(tracee_pid, SyscallSucceeds()); - ASSERT_THAT(close(sockets[0]), SyscallSucceeds()); + // If tracee invokes PR_SET_PTRACER on either tracer's pid, the pid of any of + // its ancestors (i.e. us), or PR_SET_PTRACER_ANY, then tracer can trace it + // (but not us). - pid_t const tracer_pid = fork(); - if (tracer_pid == 0) { - // Wait until tracee has called prctl. - char done; - TEST_PCHECK(read(sockets[1], &done, 1) == 1); - MaybeSave(); + ASSERT_THAT(tracee.Cmd(tracer.pid()), PosixErrorIs(0)); + EXPECT_THAT(tracer.Cmd(tracee.pid()), PosixErrorIs(0)); + EXPECT_THAT(tracer.Cmd(gettid()), PosixErrorIs(EPERM)); - TEST_CHECK(CheckPtraceAttach(tracee_pid) == -1); - TEST_PCHECK(errno == EPERM); - _exit(0); - } - ASSERT_THAT(tracer_pid, SyscallSucceeds()); + ASSERT_THAT(tracee.Cmd(gettid()), PosixErrorIs(0)); + EXPECT_THAT(tracer.Cmd(tracee.pid()), PosixErrorIs(0)); + EXPECT_THAT(tracer.Cmd(gettid()), PosixErrorIs(EPERM)); - // Clean up tracer. - int status; - ASSERT_THAT(waitpid(tracer_pid, &status, 0), SyscallSucceeds()); - EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) - << " status " << status; + ASSERT_THAT(tracee.Cmd(static_cast<pid_t>(PR_SET_PTRACER_ANY)), + PosixErrorIs(0)); + EXPECT_THAT(tracer.Cmd(tracee.pid()), PosixErrorIs(0)); + EXPECT_THAT(tracer.Cmd(gettid()), PosixErrorIs(EPERM)); - // Clean up tracee. - ASSERT_THAT(kill(tracee_pid, SIGKILL), SyscallSucceeds()); - ASSERT_THAT(waitpid(tracee_pid, &status, 0), - SyscallSucceedsWithValue(tracee_pid)); - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL) - << " status " << status; + // If tracee invokes PR_SET_PTRACER with pid 0, then tracer can no longer + // trace it. + ASSERT_THAT(tracee.Cmd(0), PosixErrorIs(0)); + EXPECT_THAT(tracer.Cmd(tracee.pid()), PosixErrorIs(EPERM)); - // Clean up unused. - ASSERT_THAT(kill(unused_pid, SIGKILL), SyscallSucceeds()); - ASSERT_THAT(waitpid(unused_pid, &status, 0), - SyscallSucceedsWithValue(unused_pid)); - EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGKILL) - << " status " << status; -} + // If we invoke PR_SET_PTRACER with tracer's pid, then it can trace us (but + // not our descendants). + ASSERT_THAT(prctl(PR_SET_PTRACER, tracer.pid()), SyscallSucceeds()); + EXPECT_THAT(tracer.Cmd(gettid()), PosixErrorIs(0)); + EXPECT_THAT(tracer.Cmd(tracee.pid()), PosixErrorIs(EPERM)); -[[noreturn]] void RunPrctlReplacePtracer(int new_tracer_pid, int fd) { - TEST_PCHECK(prctl(PR_SET_PTRACER, getppid()) == 0); - MaybeSave(); + // If we invoke PR_SET_PTRACER with pid 0, then tracer can no longer trace us. + ASSERT_THAT(prctl(PR_SET_PTRACER, 0), SyscallSucceeds()); + EXPECT_THAT(tracer.Cmd(gettid()), PosixErrorIs(EPERM)); - ScopedThread t([new_tracer_pid, fd] { - TEST_PCHECK(prctl(PR_SET_PTRACER, new_tracer_pid) == 0); - MaybeSave(); - // Indicate that the prctl has been set. - TEST_PCHECK(write(fd, "x", 1) == 1); - MaybeSave(); - }); - while (true) { - SleepSafe(absl::Seconds(1)); - } + // Another thread in our thread group can invoke PR_SET_PTRACER instead; its + // effect applies to the whole thread group. + pid_t const our_tid = gettid(); + ScopedThread([&] { + ASSERT_THAT(prctl(PR_SET_PTRACER, tracer.pid()), SyscallSucceeds()); + EXPECT_THAT(tracer.Cmd(gettid()), PosixErrorIs(0)); + EXPECT_THAT(tracer.Cmd(our_tid), PosixErrorIs(0)); + + ASSERT_THAT(prctl(PR_SET_PTRACER, 0), SyscallSucceeds()); + EXPECT_THAT(tracer.Cmd(gettid()), PosixErrorIs(EPERM)); + EXPECT_THAT(tracer.Cmd(our_tid), PosixErrorIs(EPERM)); + }).Join(); } // Tests that YAMA exceptions store tracees by thread group leader. Exceptions @@ -2342,21 +2255,12 @@ int main(int argc, char** argv) { gvisor::testing::RunTraceDescendantsAllowed(fd); } - if (absl::GetFlag(FLAGS_ptrace_test_prctl_set_ptracer_pid)) { - gvisor::testing::RunPrctlSetPtracerPID(fd); - } - - if (absl::GetFlag(FLAGS_ptrace_test_prctl_set_ptracer_any)) { - gvisor::testing::RunPrctlSetPtracerAny(fd); - } - - if (absl::GetFlag(FLAGS_ptrace_test_prctl_clear_ptracer)) { - gvisor::testing::RunPrctlClearPtracer(fd); + if (absl::GetFlag(FLAGS_ptrace_test_ptrace_attacher)) { + gvisor::testing::RunPtraceAttacher(fd); } - if (absl::GetFlag(FLAGS_ptrace_test_prctl_replace_ptracer)) { - gvisor::testing::RunPrctlReplacePtracer( - absl::GetFlag(FLAGS_ptrace_test_prctl_replace_ptracer_tid), fd); + if (absl::GetFlag(FLAGS_ptrace_test_prctl_set_ptracer)) { + gvisor::testing::RunPrctlSetPtracer(fd); } if (absl::GetFlag( diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc index 3b108cbd3..70b0b2742 100644 --- a/test/syscalls/linux/socket_bind_to_device_distribution.cc +++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc @@ -77,34 +77,6 @@ class BindToDeviceDistributionTest } }; -PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr) { - switch (family) { - case AF_INET: - return static_cast<uint16_t>( - reinterpret_cast<sockaddr_in const*>(&addr)->sin_port); - case AF_INET6: - return static_cast<uint16_t>( - reinterpret_cast<sockaddr_in6 const*>(&addr)->sin6_port); - default: - return PosixError(EINVAL, - absl::StrCat("unknown socket family: ", family)); - } -} - -PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) { - switch (family) { - case AF_INET: - reinterpret_cast<sockaddr_in*>(addr)->sin_port = port; - return NoError(); - case AF_INET6: - reinterpret_cast<sockaddr_in6*>(addr)->sin6_port = port; - return NoError(); - default: - return PosixError(EINVAL, - absl::StrCat("unknown socket family: ", family)); - } -} - // Binds sockets to different devices and then creates many TCP connections. // Checks that the distribution of connections received on the sockets matches // the expectation. diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 6b369d5b7..badc42ec5 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -34,6 +34,7 @@ #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_inet_loopback_test_params.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" @@ -48,45 +49,7 @@ namespace { using ::testing::Gt; -PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr) { - switch (family) { - case AF_INET: - return static_cast<uint16_t>( - reinterpret_cast<sockaddr_in const*>(&addr)->sin_port); - case AF_INET6: - return static_cast<uint16_t>( - reinterpret_cast<sockaddr_in6 const*>(&addr)->sin6_port); - default: - return PosixError(EINVAL, - absl::StrCat("unknown socket family: ", family)); - } -} - -PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) { - switch (family) { - case AF_INET: - reinterpret_cast<sockaddr_in*>(addr)->sin_port = port; - return NoError(); - case AF_INET6: - reinterpret_cast<sockaddr_in6*>(addr)->sin6_port = port; - return NoError(); - default: - return PosixError(EINVAL, - absl::StrCat("unknown socket family: ", family)); - } -} - -struct TestParam { - TestAddress listener; - TestAddress connector; -}; - -std::string DescribeTestParam(::testing::TestParamInfo<TestParam> const& info) { - return absl::StrCat("Listen", info.param.listener.description, "_Connect", - info.param.connector.description); -} - -using SocketInetLoopbackTest = ::testing::TestWithParam<TestParam>; +using SocketInetLoopbackTest = ::testing::TestWithParam<SocketInetTestParam>; TEST(BadSocketPairArgs, ValidateErrForBadCallsToSocketPair) { int fd[2] = {}; @@ -299,7 +262,7 @@ void tcpSimpleConnectTest(TestAddress const& listener, } TEST_P(SocketInetLoopbackTest, TCP) { - auto const& param = GetParam(); + SocketInetTestParam const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -307,7 +270,7 @@ TEST_P(SocketInetLoopbackTest, TCP) { } TEST_P(SocketInetLoopbackTest, TCPListenUnbound) { - auto const& param = GetParam(); + SocketInetTestParam const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -316,7 +279,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenUnbound) { } TEST_P(SocketInetLoopbackTest, TCPListenShutdownListen) { - const auto& param = GetParam(); + SocketInetTestParam const& param = GetParam(); const TestAddress& listener = param.listener; const TestAddress& connector = param.connector; @@ -362,7 +325,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdownListen) { } TEST_P(SocketInetLoopbackTest, TCPListenShutdown) { - auto const& param = GetParam(); + SocketInetTestParam const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -430,7 +393,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdown) { } TEST_P(SocketInetLoopbackTest, TCPListenClose) { - auto const& param = GetParam(); + SocketInetTestParam const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -477,7 +440,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) { // Test the protocol state information returned by TCPINFO. TEST_P(SocketInetLoopbackTest, TCPInfoState) { - auto const& param = GetParam(); + SocketInetTestParam const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -546,7 +509,7 @@ TEST_P(SocketInetLoopbackTest, TCPInfoState) { ASSERT_THAT(close(conn_fd.release()), SyscallSucceeds()); } -void TestHangupDuringConnect(const TestParam& param, +void TestHangupDuringConnect(const SocketInetTestParam& param, void (*hangup)(FileDescriptor&)) { TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -609,7 +572,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdownDuringConnect) { }); } -void TestListenHangupConnectingRead(const TestParam& param, +void TestListenHangupConnectingRead(const SocketInetTestParam& param, void (*hangup)(FileDescriptor&)) { constexpr int kTimeout = 10000; @@ -718,7 +681,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdownConnectingRead) { // Test close of a non-blocking connecting socket. TEST_P(SocketInetLoopbackTest, TCPNonBlockingConnectClose) { - TestParam const& param = GetParam(); + SocketInetTestParam const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -793,7 +756,7 @@ TEST_P(SocketInetLoopbackTest, TCPNonBlockingConnectClose) { // queue because the queue is full are not correctly delivered after restore // causing the last accept to timeout on the restore. TEST_P(SocketInetLoopbackTest, TCPAcceptBacklogSizes) { - auto const& param = GetParam(); + SocketInetTestParam const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -843,7 +806,7 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptBacklogSizes) { // queue because the queue is full are not correctly delivered after restore // causing the last accept to timeout on the restore. TEST_P(SocketInetLoopbackTest, TCPBacklog) { - auto const& param = GetParam(); + SocketInetTestParam const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -934,7 +897,7 @@ TEST_P(SocketInetLoopbackTest, TCPBacklog) { // queue because the queue is full are not correctly delivered after restore // causing the last accept to timeout on the restore. TEST_P(SocketInetLoopbackTest, TCPBacklogAcceptAll) { - auto const& param = GetParam(); + SocketInetTestParam const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -1024,175 +987,12 @@ TEST_P(SocketInetLoopbackTest, TCPBacklogAcceptAll) { } } -// TCPFinWait2Test creates a pair of connected sockets then closes one end to -// trigger FIN_WAIT2 state for the closed endpoint. Then it binds the same local -// IP/port on a new socket and tries to connect. The connect should fail w/ -// an EADDRINUSE. Then we wait till the FIN_WAIT2 timeout is over and try the -// connect again with a new socket and this time it should succeed. -// -// TCP timers are not S/R today, this can cause this test to be flaky when run -// under random S/R due to timer being reset on a restore. -TEST_P(SocketInetLoopbackTest, TCPFinWait2Test) { - auto const& param = GetParam(); - TestAddress const& listener = param.listener; - TestAddress const& connector = param.connector; - - // Create the listening socket. - const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); - sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT( - bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), - SyscallSucceeds()); - ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); - - // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), - SyscallSucceeds()); - - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - - // Connect to the listening socket. - FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - - // Lower FIN_WAIT2 state to 5 seconds for test. - constexpr int kTCPLingerTimeout = 5; - EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, - &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), - SyscallSucceedsWithValue(0)); - - sockaddr_storage conn_addr = connector.addr; - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr), - connector.addr_len), - SyscallSucceeds()); - - // Accept the connection. - auto accepted = - ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); - - // Get the address/port bound by the connecting socket. - sockaddr_storage conn_bound_addr; - socklen_t conn_addrlen = connector.addr_len; - ASSERT_THAT( - getsockname(conn_fd.get(), AsSockAddr(&conn_bound_addr), &conn_addrlen), - SyscallSucceeds()); - - // close the connecting FD to trigger FIN_WAIT2 on the connected fd. - conn_fd.reset(); - - // Now bind and connect a new socket. - const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - - // Disable cooperative saves after this point. As a save between the first - // bind/connect and the second one can cause the linger timeout timer to - // be restarted causing the final bind/connect to fail. - DisableSave ds; - - ASSERT_THAT(bind(conn_fd2.get(), AsSockAddr(&conn_bound_addr), conn_addrlen), - SyscallFailsWithErrno(EADDRINUSE)); - - // Sleep for a little over the linger timeout to reduce flakiness in - // save/restore tests. - absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 2)); - - ds.reset(); - - ASSERT_THAT( - RetryEINTR(connect)(conn_fd2.get(), AsSockAddr(&conn_addr), conn_addrlen), - SyscallSucceeds()); -} - -// TCPLinger2TimeoutAfterClose creates a pair of connected sockets -// then closes one end to trigger FIN_WAIT2 state for the closed endpont. -// It then sleeps for the TCP_LINGER2 timeout and verifies that bind/ -// connecting the same address succeeds. -// -// TCP timers are not S/R today, this can cause this test to be flaky when run -// under random S/R due to timer being reset on a restore. -TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose) { - auto const& param = GetParam(); - TestAddress const& listener = param.listener; - TestAddress const& connector = param.connector; - - // Create the listening socket. - const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); - sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT( - bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), - SyscallSucceeds()); - ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); - - // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), - SyscallSucceeds()); - - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - - // Connect to the listening socket. - FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - - sockaddr_storage conn_addr = connector.addr; - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr), - connector.addr_len), - SyscallSucceeds()); - - // Accept the connection. - auto accepted = - ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); - - // Get the address/port bound by the connecting socket. - sockaddr_storage conn_bound_addr; - socklen_t conn_addrlen = connector.addr_len; - ASSERT_THAT( - getsockname(conn_fd.get(), AsSockAddr(&conn_bound_addr), &conn_addrlen), - SyscallSucceeds()); - - // Disable cooperative saves after this point as TCP timers are not restored - // across a S/R. - { - DisableSave ds; - constexpr int kTCPLingerTimeout = 5; - EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, - &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), - SyscallSucceedsWithValue(0)); - - // close the connecting FD to trigger FIN_WAIT2 on the connected fd. - conn_fd.reset(); - - absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); - - // ds going out of scope will Re-enable S/R's since at this point the timer - // must have fired and cleaned up the endpoint. - } - - // Now bind and connect a new socket and verify that we can immediately - // rebind the address bound by the conn_fd as it never entered TIME_WAIT. - const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - - ASSERT_THAT(bind(conn_fd2.get(), AsSockAddr(&conn_bound_addr), conn_addrlen), - SyscallSucceeds()); - ASSERT_THAT( - RetryEINTR(connect)(conn_fd2.get(), AsSockAddr(&conn_addr), conn_addrlen), - SyscallSucceeds()); -} - // TCPResetAfterClose creates a pair of connected sockets then closes // one end to trigger FIN_WAIT2 state for the closed endpoint verifies // that we generate RSTs for any new data after the socket is fully // closed. TEST_P(SocketInetLoopbackTest, TCPResetAfterClose) { - auto const& param = GetParam(); + SocketInetTestParam const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -1252,198 +1052,8 @@ TEST_P(SocketInetLoopbackTest, TCPResetAfterClose) { SyscallSucceedsWithValue(0)); } -// setupTimeWaitClose sets up a socket endpoint in TIME_WAIT state. -// Callers can choose to perform active close on either ends of the connection -// and also specify if they want to enabled SO_REUSEADDR. -void setupTimeWaitClose(const TestAddress* listener, - const TestAddress* connector, bool reuse, - bool accept_close, sockaddr_storage* listen_addr, - sockaddr_storage* conn_bound_addr) { - // Create the listening socket. - FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(listener->family(), SOCK_STREAM, IPPROTO_TCP)); - if (reuse) { - ASSERT_THAT(setsockopt(listen_fd.get(), SOL_SOCKET, SO_REUSEADDR, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - } - ASSERT_THAT( - bind(listen_fd.get(), AsSockAddr(listen_addr), listener->addr_len), - SyscallSucceeds()); - ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); - - // Get the port bound by the listening socket. - socklen_t addrlen = listener->addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(listen_addr), &addrlen), - SyscallSucceeds()); - - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener->family(), *listen_addr)); - - // Connect to the listening socket. - FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector->family(), SOCK_STREAM, IPPROTO_TCP)); - - // We disable saves after this point as a S/R causes the netstack seed - // to be regenerated which changes what ports/ISN is picked for a given - // tuple (src ip,src port, dst ip, dst port). This can cause the final - // SYN to use a sequence number that looks like one from the current - // connection in TIME_WAIT and will not be accepted causing the test - // to timeout. - // - // TODO(gvisor.dev/issue/940): S/R portSeed/portHint - DisableSave ds; - - sockaddr_storage conn_addr = connector->addr; - ASSERT_NO_ERRNO(SetAddrPort(connector->family(), &conn_addr, port)); - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr), - connector->addr_len), - SyscallSucceeds()); - - // Accept the connection. - auto accepted = - ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); - - // Get the address/port bound by the connecting socket. - socklen_t conn_addrlen = connector->addr_len; - ASSERT_THAT( - getsockname(conn_fd.get(), AsSockAddr(conn_bound_addr), &conn_addrlen), - SyscallSucceeds()); - - FileDescriptor active_closefd, passive_closefd; - if (accept_close) { - active_closefd = std::move(accepted); - passive_closefd = std::move(conn_fd); - } else { - active_closefd = std::move(conn_fd); - passive_closefd = std::move(accepted); - } - - // shutdown to trigger TIME_WAIT. - ASSERT_THAT(shutdown(active_closefd.get(), SHUT_WR), SyscallSucceeds()); - { - constexpr int kTimeout = 10000; - pollfd pfd = { - .fd = passive_closefd.get(), - .events = POLLIN, - }; - ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); - ASSERT_EQ(pfd.revents, POLLIN); - } - ASSERT_THAT(shutdown(passive_closefd.get(), SHUT_WR), SyscallSucceeds()); - { - constexpr int kTimeout = 10000; - constexpr int16_t want_events = POLLHUP; - pollfd pfd = { - .fd = active_closefd.get(), - .events = want_events, - }; - ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); - } - - // This sleep is needed to reduce flake to ensure that the passive-close - // ensures the state transitions to CLOSE from LAST_ACK. - absl::SleepFor(absl::Seconds(1)); -} - -// These tests are disabled under random save as the the restore run -// results in the stack.Seed() being different which can cause -// sequence number of final connect to be one that is considered -// old and can cause the test to be flaky. -// -// Test re-binding of client and server bound addresses when the older -// connection is in TIME_WAIT. -TEST_P(SocketInetLoopbackTest, TCPPassiveCloseNoTimeWaitTest) { - auto const& param = GetParam(); - sockaddr_storage listen_addr, conn_bound_addr; - listen_addr = param.listener.addr; - setupTimeWaitClose(¶m.listener, ¶m.connector, false /*reuse*/, - true /*accept_close*/, &listen_addr, &conn_bound_addr); - - // Now bind a new socket and verify that we can immediately rebind the address - // bound by the conn_fd as it never entered TIME_WAIT. - const FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); - ASSERT_THAT(bind(conn_fd.get(), AsSockAddr(&conn_bound_addr), - param.connector.addr_len), - SyscallSucceeds()); - - FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(param.listener.family(), SOCK_STREAM, IPPROTO_TCP)); - ASSERT_THAT( - bind(listen_fd.get(), AsSockAddr(&listen_addr), param.listener.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); -} - -TEST_P(SocketInetLoopbackTest, TCPPassiveCloseNoTimeWaitReuseTest) { - auto const& param = GetParam(); - sockaddr_storage listen_addr, conn_bound_addr; - listen_addr = param.listener.addr; - setupTimeWaitClose(¶m.listener, ¶m.connector, true /*reuse*/, - true /*accept_close*/, &listen_addr, &conn_bound_addr); - - FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(param.listener.family(), SOCK_STREAM, IPPROTO_TCP)); - ASSERT_THAT(setsockopt(listen_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT( - bind(listen_fd.get(), AsSockAddr(&listen_addr), param.listener.addr_len), - SyscallSucceeds()); - ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); - - // Now bind and connect new socket and verify that we can immediately rebind - // the address bound by the conn_fd as it never entered TIME_WAIT. - const FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); - ASSERT_THAT(setsockopt(conn_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(bind(conn_fd.get(), AsSockAddr(&conn_bound_addr), - param.connector.addr_len), - SyscallSucceeds()); - - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(param.listener.family(), listen_addr)); - sockaddr_storage conn_addr = param.connector.addr; - ASSERT_NO_ERRNO(SetAddrPort(param.connector.family(), &conn_addr, port)); - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr), - param.connector.addr_len), - SyscallSucceeds()); -} - -TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitTest) { - auto const& param = GetParam(); - sockaddr_storage listen_addr, conn_bound_addr; - listen_addr = param.listener.addr; - setupTimeWaitClose(¶m.listener, ¶m.connector, false /*reuse*/, - false /*accept_close*/, &listen_addr, &conn_bound_addr); - FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); - - ASSERT_THAT(bind(conn_fd.get(), AsSockAddr(&conn_bound_addr), - param.connector.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); -} - -TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitReuseTest) { - auto const& param = GetParam(); - sockaddr_storage listen_addr, conn_bound_addr; - listen_addr = param.listener.addr; - setupTimeWaitClose(¶m.listener, ¶m.connector, true /*reuse*/, - false /*accept_close*/, &listen_addr, &conn_bound_addr); - FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); - ASSERT_THAT(setsockopt(conn_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(bind(conn_fd.get(), AsSockAddr(&conn_bound_addr), - param.connector.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); -} - TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) { - auto const& param = GetParam(); + SocketInetTestParam const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -1495,7 +1105,7 @@ TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) { } TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) { - auto const& param = GetParam(); + SocketInetTestParam const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -1606,7 +1216,7 @@ TEST_P(SocketInetLoopbackTest, TCPDeferAccept) { // saved. Enable S/R issue is fixed. DisableSave ds; - auto const& param = GetParam(); + SocketInetTestParam const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -1686,7 +1296,7 @@ TEST_P(SocketInetLoopbackTest, TCPDeferAcceptTimeout) { // saved. Enable S/R once issue is fixed. DisableSave ds; - auto const& param = GetParam(); + SocketInetTestParam const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -1753,42 +1363,16 @@ TEST_P(SocketInetLoopbackTest, TCPDeferAcceptTimeout) { ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); } -INSTANTIATE_TEST_SUITE_P( - All, SocketInetLoopbackTest, - ::testing::Values( - // Listeners bound to IPv4 addresses refuse connections using IPv6 - // addresses. - TestParam{V4Any(), V4Any()}, TestParam{V4Any(), V4Loopback()}, - TestParam{V4Any(), V4MappedAny()}, - TestParam{V4Any(), V4MappedLoopback()}, - TestParam{V4Loopback(), V4Any()}, TestParam{V4Loopback(), V4Loopback()}, - TestParam{V4Loopback(), V4MappedLoopback()}, - TestParam{V4MappedAny(), V4Any()}, - TestParam{V4MappedAny(), V4Loopback()}, - TestParam{V4MappedAny(), V4MappedAny()}, - TestParam{V4MappedAny(), V4MappedLoopback()}, - TestParam{V4MappedLoopback(), V4Any()}, - TestParam{V4MappedLoopback(), V4Loopback()}, - TestParam{V4MappedLoopback(), V4MappedLoopback()}, - - // Listeners bound to IN6ADDR_ANY accept all connections. - TestParam{V6Any(), V4Any()}, TestParam{V6Any(), V4Loopback()}, - TestParam{V6Any(), V4MappedAny()}, - TestParam{V6Any(), V4MappedLoopback()}, TestParam{V6Any(), V6Any()}, - TestParam{V6Any(), V6Loopback()}, - - // Listeners bound to IN6ADDR_LOOPBACK refuse connections using IPv4 - // addresses. - TestParam{V6Loopback(), V6Any()}, - TestParam{V6Loopback(), V6Loopback()}), - DescribeTestParam); +INSTANTIATE_TEST_SUITE_P(All, SocketInetLoopbackTest, + SocketInetLoopbackTestValues(), + DescribeSocketInetTestParam); -using SocketInetReusePortTest = ::testing::TestWithParam<TestParam>; +using SocketInetReusePortTest = ::testing::TestWithParam<SocketInetTestParam>; // TODO(gvisor.dev/issue/940): Remove when portHint/stack.Seed is // saved/restored. TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) { - auto const& param = GetParam(); + SocketInetTestParam const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -1898,7 +1482,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) { } TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) { - auto const& param = GetParam(); + SocketInetTestParam const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -2009,7 +1593,7 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) { } TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort) { - auto const& param = GetParam(); + SocketInetTestParam const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -2117,32 +1701,23 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values( // Listeners bound to IPv4 addresses refuse connections using IPv6 // addresses. - TestParam{V4Any(), V4Loopback()}, - TestParam{V4Loopback(), V4MappedLoopback()}, + SocketInetTestParam{V4Any(), V4Loopback()}, + SocketInetTestParam{V4Loopback(), V4MappedLoopback()}, // Listeners bound to IN6ADDR_ANY accept all connections. - TestParam{V6Any(), V4Loopback()}, TestParam{V6Any(), V6Loopback()}, + SocketInetTestParam{V6Any(), V4Loopback()}, + SocketInetTestParam{V6Any(), V6Loopback()}, // Listeners bound to IN6ADDR_LOOPBACK refuse connections using IPv4 // addresses. - TestParam{V6Loopback(), V6Loopback()}), - DescribeTestParam); - -struct ProtocolTestParam { - std::string description; - int type; -}; - -std::string DescribeProtocolTestParam( - ::testing::TestParamInfo<ProtocolTestParam> const& info) { - return info.param.description; -} + SocketInetTestParam{V6Loopback(), V6Loopback()}), + DescribeSocketInetTestParam); using SocketMultiProtocolInetLoopbackTest = ::testing::TestWithParam<ProtocolTestParam>; TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedLoopbackOnlyReservesV4) { - auto const& param = GetParam(); + ProtocolTestParam const& param = GetParam(); for (int i = 0; true; i++) { // Bind the v4 loopback on a dual stack socket. @@ -2191,7 +1766,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedLoopbackOnlyReservesV4) { } TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedAnyOnlyReservesV4) { - auto const& param = GetParam(); + ProtocolTestParam const& param = GetParam(); for (int i = 0; true; i++) { // Bind the v4 any on a dual stack socket. @@ -2240,7 +1815,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedAnyOnlyReservesV4) { } TEST_P(SocketMultiProtocolInetLoopbackTest, DualStackV6AnyReservesEverything) { - auto const& param = GetParam(); + ProtocolTestParam const& param = GetParam(); // Bind the v6 any on a dual stack socket. TestAddress const& test_addr_dual = V6Any(); @@ -2303,7 +1878,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, DualStackV6AnyReservesEverything) { TEST_P(SocketMultiProtocolInetLoopbackTest, DualStackV6AnyReuseAddrDoesNotReserveV4Any) { - auto const& param = GetParam(); + ProtocolTestParam const& param = GetParam(); // Bind the v6 any on a dual stack socket. TestAddress const& test_addr_dual = V6Any(); @@ -2340,7 +1915,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, TEST_P(SocketMultiProtocolInetLoopbackTest, DualStackV6AnyReuseAddrListenReservesV4Any) { - auto const& param = GetParam(); + ProtocolTestParam const& param = GetParam(); // Only TCP sockets are supported. SKIP_IF((param.type & SOCK_STREAM) == 0); @@ -2383,7 +1958,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, TEST_P(SocketMultiProtocolInetLoopbackTest, DualStackV6AnyWithListenReservesEverything) { - auto const& param = GetParam(); + ProtocolTestParam const& param = GetParam(); // Only TCP sockets are supported. SKIP_IF((param.type & SOCK_STREAM) == 0); @@ -2450,7 +2025,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, } TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) { - auto const& param = GetParam(); + ProtocolTestParam const& param = GetParam(); for (int i = 0; true; i++) { // Bind the v6 any on a v6-only socket. @@ -2503,7 +2078,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) { } TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) { - auto const& param = GetParam(); + ProtocolTestParam const& param = GetParam(); for (int i = 0; true; i++) { // Bind the v6 loopback on a dual stack socket. @@ -2583,66 +2158,8 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) { } } -TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReservedReuseAddr) { - auto const& param = GetParam(); - - // Bind the v6 loopback on a dual stack socket. - TestAddress const& test_addr = V6Loopback(); - sockaddr_storage bound_addr = test_addr.addr; - const FileDescriptor bound_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len), - SyscallSucceeds()); - ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - - // Listen iff TCP. - if (param.type == SOCK_STREAM) { - ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds()); - } - - // Get the port that we bound. - socklen_t bound_addr_len = test_addr.addr_len; - ASSERT_THAT( - getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len), - SyscallSucceeds()); - - // Connect to bind an ephemeral port. - const FileDescriptor connected_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), AsSockAddr(&bound_addr), - bound_addr_len), - SyscallSucceeds()); - - // Get the ephemeral port. - sockaddr_storage connected_addr = {}; - socklen_t connected_addr_len = sizeof(connected_addr); - ASSERT_THAT(getsockname(connected_fd.get(), AsSockAddr(&connected_addr), - &connected_addr_len), - SyscallSucceeds()); - uint16_t const ephemeral_port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr)); - - // Verify that we actually got an ephemeral port. - ASSERT_NE(ephemeral_port, 0); - - // Verify that the ephemeral port is not reserved. - const FileDescriptor checking_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(setsockopt(checking_fd.get(), SOL_SOCKET, SO_REUSEADDR, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - EXPECT_THAT( - bind(checking_fd.get(), AsSockAddr(&connected_addr), connected_addr_len), - SyscallSucceeds()); -} - TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) { - auto const& param = GetParam(); + ProtocolTestParam const& param = GetParam(); for (int i = 0; true; i++) { // Bind the v4 loopback on a dual stack socket. @@ -2754,68 +2271,8 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) { } } -TEST_P(SocketMultiProtocolInetLoopbackTest, - V4MappedEphemeralPortReservedResueAddr) { - auto const& param = GetParam(); - - // Bind the v4 loopback on a dual stack socket. - TestAddress const& test_addr = V4MappedLoopback(); - sockaddr_storage bound_addr = test_addr.addr; - const FileDescriptor bound_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len), - SyscallSucceeds()); - - ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - - // Listen iff TCP. - if (param.type == SOCK_STREAM) { - ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds()); - } - - // Get the port that we bound. - socklen_t bound_addr_len = test_addr.addr_len; - ASSERT_THAT( - getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len), - SyscallSucceeds()); - - // Connect to bind an ephemeral port. - const FileDescriptor connected_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), AsSockAddr(&bound_addr), - bound_addr_len), - SyscallSucceeds()); - - // Get the ephemeral port. - sockaddr_storage connected_addr = {}; - socklen_t connected_addr_len = sizeof(connected_addr); - ASSERT_THAT(getsockname(connected_fd.get(), AsSockAddr(&connected_addr), - &connected_addr_len), - SyscallSucceeds()); - uint16_t const ephemeral_port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr)); - - // Verify that we actually got an ephemeral port. - ASSERT_NE(ephemeral_port, 0); - - // Verify that the ephemeral port is not reserved. - const FileDescriptor checking_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(setsockopt(checking_fd.get(), SOL_SOCKET, SO_REUSEADDR, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - EXPECT_THAT( - bind(checking_fd.get(), AsSockAddr(&connected_addr), connected_addr_len), - SyscallSucceeds()); -} - TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { - auto const& param = GetParam(); + ProtocolTestParam const& param = GetParam(); for (int i = 0; true; i++) { // Bind the v4 loopback on a v4 socket. @@ -2928,71 +2385,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { } } -TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) { - auto const& param = GetParam(); - - // Bind the v4 loopback on a v4 socket. - TestAddress const& test_addr = V4Loopback(); - sockaddr_storage bound_addr = test_addr.addr; - const FileDescriptor bound_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - - ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - - ASSERT_THAT(bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len), - SyscallSucceeds()); - - // Listen iff TCP. - if (param.type == SOCK_STREAM) { - ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds()); - } - - // Get the port that we bound. - socklen_t bound_addr_len = test_addr.addr_len; - ASSERT_THAT( - getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len), - SyscallSucceeds()); - - // Connect to bind an ephemeral port. - const FileDescriptor connected_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - - ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), AsSockAddr(&bound_addr), - bound_addr_len), - SyscallSucceeds()); - - // Get the ephemeral port. - sockaddr_storage connected_addr = {}; - socklen_t connected_addr_len = sizeof(connected_addr); - ASSERT_THAT(getsockname(connected_fd.get(), AsSockAddr(&connected_addr), - &connected_addr_len), - SyscallSucceeds()); - uint16_t const ephemeral_port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr)); - - // Verify that we actually got an ephemeral port. - ASSERT_NE(ephemeral_port, 0); - - // Verify that the ephemeral port is not reserved. - const FileDescriptor checking_fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(setsockopt(checking_fd.get(), SOL_SOCKET, SO_REUSEADDR, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - EXPECT_THAT( - bind(checking_fd.get(), AsSockAddr(&connected_addr), connected_addr_len), - SyscallSucceeds()); -} - TEST_P(SocketMultiProtocolInetLoopbackTest, MultipleBindsAllowedNoListeningReuseAddr) { - const auto& param = GetParam(); + ProtocolTestParam const& param = GetParam(); // UDP sockets are allowed to bind/listen on the port w/ SO_REUSEADDR, for TCP // this is only permitted if there is no other listening socket. SKIP_IF(param.type != SOCK_STREAM); @@ -3027,7 +2422,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, } TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) { - auto const& param = GetParam(); + ProtocolTestParam const& param = GetParam(); TestAddress const& test_addr = V4Loopback(); sockaddr_storage addr = test_addr.addr; @@ -3080,7 +2475,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) { // closed, we can bind a different socket to the same address without needing // REUSEPORT. TEST_P(SocketMultiProtocolInetLoopbackTest, NoReusePortFollowingReusePort) { - auto const& param = GetParam(); + ProtocolTestParam const& param = GetParam(); TestAddress const& test_addr = V4Loopback(); sockaddr_storage addr = test_addr.addr; @@ -3107,11 +2502,8 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, NoReusePortFollowingReusePort) { ASSERT_THAT(bind(fd, AsSockAddr(&addr), addrlen), SyscallSucceeds()); } -INSTANTIATE_TEST_SUITE_P( - AllFamilies, SocketMultiProtocolInetLoopbackTest, - ::testing::Values(ProtocolTestParam{"TCP", SOCK_STREAM}, - ProtocolTestParam{"UDP", SOCK_DGRAM}), - DescribeProtocolTestParam); +INSTANTIATE_TEST_SUITE_P(AllFamilies, SocketMultiProtocolInetLoopbackTest, + ProtocolTestValues(), DescribeProtocolTestParam); } // namespace diff --git a/test/syscalls/linux/socket_inet_loopback_isolated.cc b/test/syscalls/linux/socket_inet_loopback_isolated.cc new file mode 100644 index 000000000..ccb016726 --- /dev/null +++ b/test/syscalls/linux/socket_inet_loopback_isolated.cc @@ -0,0 +1,488 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <netinet/tcp.h> + +#include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "test/syscalls/linux/socket_inet_loopback_test_params.h" +#include "test/syscalls/linux/socket_test_util.h" + +// Unit tests in this file will run in their own network namespace. + +namespace gvisor { +namespace testing { + +namespace { + +using SocketInetLoopbackIsolatedTest = + ::testing::TestWithParam<SocketInetTestParam>; + +TEST_P(SocketInetLoopbackIsolatedTest, TCPActiveCloseTimeWaitTest) { + SocketInetTestParam const& param = GetParam(); + sockaddr_storage listen_addr, conn_bound_addr; + listen_addr = param.listener.addr; + SetupTimeWaitClose(¶m.listener, ¶m.connector, false /*reuse*/, + false /*accept_close*/, &listen_addr, &conn_bound_addr); + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(conn_fd.get(), AsSockAddr(&conn_bound_addr), + param.connector.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(SocketInetLoopbackIsolatedTest, TCPActiveCloseTimeWaitReuseTest) { + SocketInetTestParam const& param = GetParam(); + sockaddr_storage listen_addr, conn_bound_addr; + listen_addr = param.listener.addr; + SetupTimeWaitClose(¶m.listener, ¶m.connector, true /*reuse*/, + false /*accept_close*/, &listen_addr, &conn_bound_addr); + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(setsockopt(conn_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(conn_fd.get(), AsSockAddr(&conn_bound_addr), + param.connector.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +// These tests are disabled under random save as the restore run +// results in the stack.Seed() being different which can cause +// sequence number of final connect to be one that is considered +// old and can cause the test to be flaky. +// +// Test re-binding of client and server bound addresses when the older +// connection is in TIME_WAIT. +TEST_P(SocketInetLoopbackIsolatedTest, TCPPassiveCloseNoTimeWaitTest) { + SocketInetTestParam const& param = GetParam(); + sockaddr_storage listen_addr, conn_bound_addr; + listen_addr = param.listener.addr; + SetupTimeWaitClose(¶m.listener, ¶m.connector, false /*reuse*/, + true /*accept_close*/, &listen_addr, &conn_bound_addr); + + // Now bind a new socket and verify that we can immediately rebind the address + // bound by the conn_fd as it never entered TIME_WAIT. + const FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(bind(conn_fd.get(), AsSockAddr(&conn_bound_addr), + param.connector.addr_len), + SyscallSucceeds()); + + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.listener.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), param.listener.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(SocketInetLoopbackIsolatedTest, TCPPassiveCloseNoTimeWaitReuseTest) { + SocketInetTestParam const& param = GetParam(); + sockaddr_storage listen_addr, conn_bound_addr; + listen_addr = param.listener.addr; + SetupTimeWaitClose(¶m.listener, ¶m.connector, true /*reuse*/, + true /*accept_close*/, &listen_addr, &conn_bound_addr); + + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.listener.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(setsockopt(listen_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), param.listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Now bind and connect new socket and verify that we can immediately rebind + // the address bound by the conn_fd as it never entered TIME_WAIT. + const FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(setsockopt(conn_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(conn_fd.get(), AsSockAddr(&conn_bound_addr), + param.connector.addr_len), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(param.listener.family(), listen_addr)); + sockaddr_storage conn_addr = param.connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(param.connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr), + param.connector.addr_len), + SyscallSucceeds()); +} + +// TCPFinWait2Test creates a pair of connected sockets then closes one end to +// trigger FIN_WAIT2 state for the closed endpoint. Then it binds the same local +// IP/port on a new socket and tries to connect. The connect should fail w/ +// an EADDRINUSE. Then we wait till the FIN_WAIT2 timeout is over and try the +// connect again with a new socket and this time it should succeed. +// +// TCP timers are not S/R today, this can cause this test to be flaky when run +// under random S/R due to timer being reset on a restore. +TEST_P(SocketInetLoopbackIsolatedTest, TCPFinWait2Test) { + SocketInetTestParam const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + // Lower FIN_WAIT2 state to 5 seconds for test. + constexpr int kTCPLingerTimeout = 5; + EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + sockaddr_storage conn_bound_addr; + socklen_t conn_addrlen = connector.addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), AsSockAddr(&conn_bound_addr), &conn_addrlen), + SyscallSucceeds()); + + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); + + // Now bind and connect a new socket. + const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + // Disable cooperative saves after this point. As a save between the first + // bind/connect and the second one can cause the linger timeout timer to + // be restarted causing the final bind/connect to fail. + DisableSave ds; + + ASSERT_THAT(bind(conn_fd2.get(), AsSockAddr(&conn_bound_addr), conn_addrlen), + SyscallFailsWithErrno(EADDRINUSE)); + + // Sleep for a little over the linger timeout to reduce flakiness in + // save/restore tests. + absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 2)); + + ds.reset(); + + ASSERT_THAT( + RetryEINTR(connect)(conn_fd2.get(), AsSockAddr(&conn_addr), conn_addrlen), + SyscallSucceeds()); +} + +// TCPLinger2TimeoutAfterClose creates a pair of connected sockets +// then closes one end to trigger FIN_WAIT2 state for the closed endpoint. +// It then sleeps for the TCP_LINGER2 timeout and verifies that bind/ +// connecting the same address succeeds. +// +// TCP timers are not S/R today, this can cause this test to be flaky when run +// under random S/R due to timer being reset on a restore. +TEST_P(SocketInetLoopbackIsolatedTest, TCPLinger2TimeoutAfterClose) { + SocketInetTestParam const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + sockaddr_storage conn_bound_addr; + socklen_t conn_addrlen = connector.addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), AsSockAddr(&conn_bound_addr), &conn_addrlen), + SyscallSucceeds()); + + // Disable cooperative saves after this point as TCP timers are not restored + // across a S/R. + { + DisableSave ds; + constexpr int kTCPLingerTimeout = 5; + EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); + + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); + + absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); + + // ds going out of scope will Re-enable S/R's since at this point the timer + // must have fired and cleaned up the endpoint. + } + + // Now bind and connect a new socket and verify that we can immediately + // rebind the address bound by the conn_fd as it never entered TIME_WAIT. + const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(conn_fd2.get(), AsSockAddr(&conn_bound_addr), conn_addrlen), + SyscallSucceeds()); + ASSERT_THAT( + RetryEINTR(connect)(conn_fd2.get(), AsSockAddr(&conn_addr), conn_addrlen), + SyscallSucceeds()); +} + +INSTANTIATE_TEST_SUITE_P(All, SocketInetLoopbackIsolatedTest, + SocketInetLoopbackTestValues(), + DescribeSocketInetTestParam); + +using SocketMultiProtocolInetLoopbackIsolatedTest = + ::testing::TestWithParam<ProtocolTestParam>; + +TEST_P(SocketMultiProtocolInetLoopbackIsolatedTest, + V4EphemeralPortReservedReuseAddr) { + ProtocolTestParam const& param = GetParam(); + + // Bind the v4 loopback on a v4 socket. + TestAddress const& test_addr = V4Loopback(); + sockaddr_storage bound_addr = test_addr.addr; + const FileDescriptor bound_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + + ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT(bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len), + SyscallSucceeds()); + + // Listen iff TCP. + if (param.type == SOCK_STREAM) { + ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds()); + } + + // Get the port that we bound. + socklen_t bound_addr_len = test_addr.addr_len; + ASSERT_THAT( + getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len), + SyscallSucceeds()); + + // Connect to bind an ephemeral port. + const FileDescriptor connected_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + + ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), AsSockAddr(&bound_addr), + bound_addr_len), + SyscallSucceeds()); + + // Get the ephemeral port. + sockaddr_storage connected_addr = {}; + socklen_t connected_addr_len = sizeof(connected_addr); + ASSERT_THAT(getsockname(connected_fd.get(), AsSockAddr(&connected_addr), + &connected_addr_len), + SyscallSucceeds()); + uint16_t const ephemeral_port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr)); + + // Verify that we actually got an ephemeral port. + ASSERT_NE(ephemeral_port, 0); + + // Verify that the ephemeral port is not reserved. + const FileDescriptor checking_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + ASSERT_THAT(setsockopt(checking_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + EXPECT_THAT( + bind(checking_fd.get(), AsSockAddr(&connected_addr), connected_addr_len), + SyscallSucceeds()); +} + +TEST_P(SocketMultiProtocolInetLoopbackIsolatedTest, + V4MappedEphemeralPortReservedReuseAddr) { + ProtocolTestParam const& param = GetParam(); + + // Bind the v4 loopback on a dual stack socket. + TestAddress const& test_addr = V4MappedLoopback(); + sockaddr_storage bound_addr = test_addr.addr; + const FileDescriptor bound_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + ASSERT_THAT(bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len), + SyscallSucceeds()); + + ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Listen iff TCP. + if (param.type == SOCK_STREAM) { + ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds()); + } + + // Get the port that we bound. + socklen_t bound_addr_len = test_addr.addr_len; + ASSERT_THAT( + getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len), + SyscallSucceeds()); + + // Connect to bind an ephemeral port. + const FileDescriptor connected_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), AsSockAddr(&bound_addr), + bound_addr_len), + SyscallSucceeds()); + + // Get the ephemeral port. + sockaddr_storage connected_addr = {}; + socklen_t connected_addr_len = sizeof(connected_addr); + ASSERT_THAT(getsockname(connected_fd.get(), AsSockAddr(&connected_addr), + &connected_addr_len), + SyscallSucceeds()); + uint16_t const ephemeral_port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr)); + + // Verify that we actually got an ephemeral port. + ASSERT_NE(ephemeral_port, 0); + + // Verify that the ephemeral port is not reserved. + const FileDescriptor checking_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + ASSERT_THAT(setsockopt(checking_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + EXPECT_THAT( + bind(checking_fd.get(), AsSockAddr(&connected_addr), connected_addr_len), + SyscallSucceeds()); +} + +TEST_P(SocketMultiProtocolInetLoopbackIsolatedTest, + V6EphemeralPortReservedReuseAddr) { + ProtocolTestParam const& param = GetParam(); + + // Bind the v6 loopback on a dual stack socket. + TestAddress const& test_addr = V6Loopback(); + sockaddr_storage bound_addr = test_addr.addr; + const FileDescriptor bound_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + ASSERT_THAT(bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Listen iff TCP. + if (param.type == SOCK_STREAM) { + ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds()); + } + + // Get the port that we bound. + socklen_t bound_addr_len = test_addr.addr_len; + ASSERT_THAT( + getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len), + SyscallSucceeds()); + + // Connect to bind an ephemeral port. + const FileDescriptor connected_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), AsSockAddr(&bound_addr), + bound_addr_len), + SyscallSucceeds()); + + // Get the ephemeral port. + sockaddr_storage connected_addr = {}; + socklen_t connected_addr_len = sizeof(connected_addr); + ASSERT_THAT(getsockname(connected_fd.get(), AsSockAddr(&connected_addr), + &connected_addr_len), + SyscallSucceeds()); + uint16_t const ephemeral_port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr)); + + // Verify that we actually got an ephemeral port. + ASSERT_NE(ephemeral_port, 0); + + // Verify that the ephemeral port is not reserved. + const FileDescriptor checking_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + ASSERT_THAT(setsockopt(checking_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + EXPECT_THAT( + bind(checking_fd.get(), AsSockAddr(&connected_addr), connected_addr_len), + SyscallSucceeds()); +} + +INSTANTIATE_TEST_SUITE_P(AllFamilies, + SocketMultiProtocolInetLoopbackIsolatedTest, + ProtocolTestValues(), DescribeProtocolTestParam); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc index 601ae107b..b131213d4 100644 --- a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc +++ b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc @@ -27,6 +27,7 @@ #include "gtest/gtest.h" #include "absl/strings/str_cat.h" #include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_inet_loopback_test_params.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" @@ -38,47 +39,7 @@ namespace testing { namespace { -using ::testing::Gt; - -PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr) { - switch (family) { - case AF_INET: - return static_cast<uint16_t>( - reinterpret_cast<sockaddr_in const*>(&addr)->sin_port); - case AF_INET6: - return static_cast<uint16_t>( - reinterpret_cast<sockaddr_in6 const*>(&addr)->sin6_port); - default: - return PosixError(EINVAL, - absl::StrCat("unknown socket family: ", family)); - } -} - -PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) { - switch (family) { - case AF_INET: - reinterpret_cast<sockaddr_in*>(addr)->sin_port = port; - return NoError(); - case AF_INET6: - reinterpret_cast<sockaddr_in6*>(addr)->sin6_port = port; - return NoError(); - default: - return PosixError(EINVAL, - absl::StrCat("unknown socket family: ", family)); - } -} - -struct TestParam { - TestAddress listener; - TestAddress connector; -}; - -std::string DescribeTestParam(::testing::TestParamInfo<TestParam> const& info) { - return absl::StrCat("Listen", info.param.listener.description, "_Connect", - info.param.connector.description); -} - -using SocketInetLoopbackTest = ::testing::TestWithParam<TestParam>; +using SocketInetLoopbackTest = ::testing::TestWithParam<SocketInetTestParam>; // This test verifies that connect returns EADDRNOTAVAIL if all local ephemeral // ports are already in use for a given destination ip/port. @@ -87,7 +48,7 @@ using SocketInetLoopbackTest = ::testing::TestWithParam<TestParam>; // // FIXME(b/162475855): This test is failing reliably. TEST_P(SocketInetLoopbackTest, DISABLED_TestTCPPortExhaustion) { - auto const& param = GetParam(); + SocketInetTestParam const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -136,51 +97,15 @@ TEST_P(SocketInetLoopbackTest, DISABLED_TestTCPPortExhaustion) { } } -INSTANTIATE_TEST_SUITE_P( - All, SocketInetLoopbackTest, - ::testing::Values( - // Listeners bound to IPv4 addresses refuse connections using IPv6 - // addresses. - TestParam{V4Any(), V4Any()}, TestParam{V4Any(), V4Loopback()}, - TestParam{V4Any(), V4MappedAny()}, - TestParam{V4Any(), V4MappedLoopback()}, - TestParam{V4Loopback(), V4Any()}, TestParam{V4Loopback(), V4Loopback()}, - TestParam{V4Loopback(), V4MappedLoopback()}, - TestParam{V4MappedAny(), V4Any()}, - TestParam{V4MappedAny(), V4Loopback()}, - TestParam{V4MappedAny(), V4MappedAny()}, - TestParam{V4MappedAny(), V4MappedLoopback()}, - TestParam{V4MappedLoopback(), V4Any()}, - TestParam{V4MappedLoopback(), V4Loopback()}, - TestParam{V4MappedLoopback(), V4MappedLoopback()}, - - // Listeners bound to IN6ADDR_ANY accept all connections. - TestParam{V6Any(), V4Any()}, TestParam{V6Any(), V4Loopback()}, - TestParam{V6Any(), V4MappedAny()}, - TestParam{V6Any(), V4MappedLoopback()}, TestParam{V6Any(), V6Any()}, - TestParam{V6Any(), V6Loopback()}, - - // Listeners bound to IN6ADDR_LOOPBACK refuse connections using IPv4 - // addresses. - TestParam{V6Loopback(), V6Any()}, - TestParam{V6Loopback(), V6Loopback()}), - DescribeTestParam); - -struct ProtocolTestParam { - std::string description; - int type; -}; - -std::string DescribeProtocolTestParam( - ::testing::TestParamInfo<ProtocolTestParam> const& info) { - return info.param.description; -} +INSTANTIATE_TEST_SUITE_P(All, SocketInetLoopbackTest, + SocketInetLoopbackTestValues(), + DescribeSocketInetTestParam); using SocketMultiProtocolInetLoopbackTest = ::testing::TestWithParam<ProtocolTestParam>; TEST_P(SocketMultiProtocolInetLoopbackTest, BindAvoidsListeningPortsReuseAddr) { - const auto& param = GetParam(); + ProtocolTestParam const& param = GetParam(); // UDP sockets are allowed to bind/listen on the port w/ SO_REUSEADDR, for TCP // this is only permitted if there is no other listening socket. SKIP_IF(param.type != SOCK_STREAM); @@ -222,11 +147,8 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, BindAvoidsListeningPortsReuseAddr) { } } -INSTANTIATE_TEST_SUITE_P( - AllFamilies, SocketMultiProtocolInetLoopbackTest, - ::testing::Values(ProtocolTestParam{"TCP", SOCK_STREAM}, - ProtocolTestParam{"UDP", SOCK_DGRAM}), - DescribeProtocolTestParam); +INSTANTIATE_TEST_SUITE_P(AllFamilies, SocketMultiProtocolInetLoopbackTest, + ProtocolTestValues(), DescribeProtocolTestParam); } // namespace diff --git a/test/syscalls/linux/socket_inet_loopback_test_params.h b/test/syscalls/linux/socket_inet_loopback_test_params.h new file mode 100644 index 000000000..42b48eb8a --- /dev/null +++ b/test/syscalls/linux/socket_inet_loopback_test_params.h @@ -0,0 +1,86 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_INET_LOOPBACK_TEST_PARAMS_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_INET_LOOPBACK_TEST_PARAMS_H_ + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_test_util.h" + +namespace gvisor { +namespace testing { + +struct SocketInetTestParam { + TestAddress listener; + TestAddress connector; +}; + +inline std::string DescribeSocketInetTestParam( + ::testing::TestParamInfo<SocketInetTestParam> const& info) { + return absl::StrCat("Listen", info.param.listener.description, "_Connect", + info.param.connector.description); +} + +inline auto SocketInetLoopbackTestValues() { + return ::testing::Values( + // Listeners bound to IPv4 addresses refuse connections using IPv6 + // addresses. + SocketInetTestParam{V4Any(), V4Any()}, + SocketInetTestParam{V4Any(), V4Loopback()}, + SocketInetTestParam{V4Any(), V4MappedAny()}, + SocketInetTestParam{V4Any(), V4MappedLoopback()}, + SocketInetTestParam{V4Loopback(), V4Any()}, + SocketInetTestParam{V4Loopback(), V4Loopback()}, + SocketInetTestParam{V4Loopback(), V4MappedLoopback()}, + SocketInetTestParam{V4MappedAny(), V4Any()}, + SocketInetTestParam{V4MappedAny(), V4Loopback()}, + SocketInetTestParam{V4MappedAny(), V4MappedAny()}, + SocketInetTestParam{V4MappedAny(), V4MappedLoopback()}, + SocketInetTestParam{V4MappedLoopback(), V4Any()}, + SocketInetTestParam{V4MappedLoopback(), V4Loopback()}, + SocketInetTestParam{V4MappedLoopback(), V4MappedLoopback()}, + + // Listeners bound to IN6ADDR_ANY accept all connections. + SocketInetTestParam{V6Any(), V4Any()}, + SocketInetTestParam{V6Any(), V4Loopback()}, + SocketInetTestParam{V6Any(), V4MappedAny()}, + SocketInetTestParam{V6Any(), V4MappedLoopback()}, + SocketInetTestParam{V6Any(), V6Any()}, + SocketInetTestParam{V6Any(), V6Loopback()}, + + // Listeners bound to IN6ADDR_LOOPBACK refuse connections using IPv4 + // addresses. + SocketInetTestParam{V6Loopback(), V6Any()}, + SocketInetTestParam{V6Loopback(), V6Loopback()}); +} + +struct ProtocolTestParam { + std::string description; + int type; +}; + +inline std::string DescribeProtocolTestParam( + ::testing::TestParamInfo<ProtocolTestParam> const& info) { + return info.param.description; +} + +inline auto ProtocolTestValues() { + return ::testing::Values(ProtocolTestParam{"TCP", SOCK_STREAM}, + ProtocolTestParam{"UDP", SOCK_DGRAM}); +} + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_INET_LOOPBACK_TEST_PARAMS_H_ diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc index 83c33ec8d..5e36472b4 100644 --- a/test/syscalls/linux/socket_test_util.cc +++ b/test/syscalls/linux/socket_test_util.cc @@ -948,5 +948,124 @@ uint16_t ICMPChecksum(struct icmphdr icmphdr, const char* payload, return csum; } +PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr) { + switch (family) { + case AF_INET: + return static_cast<uint16_t>( + reinterpret_cast<sockaddr_in const*>(&addr)->sin_port); + case AF_INET6: + return static_cast<uint16_t>( + reinterpret_cast<sockaddr_in6 const*>(&addr)->sin6_port); + default: + return PosixError(EINVAL, + absl::StrCat("unknown socket family: ", family)); + } +} + +PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) { + switch (family) { + case AF_INET: + reinterpret_cast<sockaddr_in*>(addr)->sin_port = port; + return NoError(); + case AF_INET6: + reinterpret_cast<sockaddr_in6*>(addr)->sin6_port = port; + return NoError(); + default: + return PosixError(EINVAL, + absl::StrCat("unknown socket family: ", family)); + } +} + +void SetupTimeWaitClose(const TestAddress* listener, + const TestAddress* connector, bool reuse, + bool accept_close, sockaddr_storage* listen_addr, + sockaddr_storage* conn_bound_addr) { + // Create the listening socket. + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener->family(), SOCK_STREAM, IPPROTO_TCP)); + if (reuse) { + ASSERT_THAT(setsockopt(listen_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + } + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(listen_addr), listener->addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener->addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener->family(), *listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector->family(), SOCK_STREAM, IPPROTO_TCP)); + + // We disable saves after this point as a S/R causes the netstack seed + // to be regenerated which changes what ports/ISN is picked for a given + // tuple (src ip,src port, dst ip, dst port). This can cause the final + // SYN to use a sequence number that looks like one from the current + // connection in TIME_WAIT and will not be accepted causing the test + // to timeout. + // + // TODO(gvisor.dev/issue/940): S/R portSeed/portHint + DisableSave ds; + + sockaddr_storage conn_addr = connector->addr; + ASSERT_NO_ERRNO(SetAddrPort(connector->family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr), + connector->addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + socklen_t conn_addrlen = connector->addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), AsSockAddr(conn_bound_addr), &conn_addrlen), + SyscallSucceeds()); + + FileDescriptor active_closefd, passive_closefd; + if (accept_close) { + active_closefd = std::move(accepted); + passive_closefd = std::move(conn_fd); + } else { + active_closefd = std::move(conn_fd); + passive_closefd = std::move(accepted); + } + + // shutdown to trigger TIME_WAIT. + ASSERT_THAT(shutdown(active_closefd.get(), SHUT_WR), SyscallSucceeds()); + { + constexpr int kTimeout = 10000; + pollfd pfd = { + .fd = passive_closefd.get(), + .events = POLLIN, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + ASSERT_EQ(pfd.revents, POLLIN); + } + ASSERT_THAT(shutdown(passive_closefd.get(), SHUT_WR), SyscallSucceeds()); + { + constexpr int kTimeout = 10000; + constexpr int16_t want_events = POLLHUP; + pollfd pfd = { + .fd = active_closefd.get(), + .events = want_events, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + } + + // This sleep is needed to reduce flake to ensure that the passive-close + // ensures the state transitions to CLOSE from LAST_ACK. + absl::SleepFor(absl::Seconds(1)); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h index 76dc090e0..df4c26f26 100644 --- a/test/syscalls/linux/socket_test_util.h +++ b/test/syscalls/linux/socket_test_util.h @@ -564,6 +564,18 @@ inline sockaddr* AsSockAddr(sockaddr_un* s) { return reinterpret_cast<sockaddr*>(s); } +PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr); + +PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port); + +// setupTimeWaitClose sets up a socket endpoint in TIME_WAIT state. +// Callers can choose to perform active close on either ends of the connection +// and also specify if they want to enabled SO_REUSEADDR. +void SetupTimeWaitClose(const TestAddress* listener, + const TestAddress* connector, bool reuse, + bool accept_close, sockaddr_storage* listen_addr, + sockaddr_storage* conn_bound_addr); + namespace internal { PosixErrorOr<int> TryPortAvailable(int port, AddressFamily family, SocketType type, bool reuse_addr); diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 5bfdecc79..183819faf 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -1182,6 +1182,62 @@ TEST_P(SimpleTcpSocketTest, SelfConnectSend) { EXPECT_THAT(shutdown(s.get(), SHUT_WR), SyscallSucceedsWithValue(0)); } +TEST_P(SimpleTcpSocketTest, SelfConnectSendShutdownWrite) { + // Initialize address to the loopback one. + sockaddr_storage addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t addrlen = sizeof(addr); + + const FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(s.get(), AsSockAddr(&addr), addrlen), SyscallSucceeds()); + // Get the bound port. + ASSERT_THAT(getsockname(s.get(), AsSockAddr(&addr), &addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen), + SyscallSucceeds()); + + // Write enough data to fill send and receive buffers. + size_t write_size = 24 << 20; // 24 MiB. + std::vector<char> writebuf(write_size); + + ScopedThread t([&s]() { + absl::SleepFor(absl::Milliseconds(250)); + ASSERT_THAT(shutdown(s.get(), SHUT_WR), SyscallSucceeds()); + }); + + // Try to send the whole thing. + int n; + ASSERT_THAT(n = SendFd(s.get(), writebuf.data(), writebuf.size(), 0), + SyscallFailsWithErrno(EPIPE)); +} + +TEST_P(SimpleTcpSocketTest, SelfConnectRecvShutdownRead) { + // Initialize address to the loopback one. + sockaddr_storage addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t addrlen = sizeof(addr); + + const FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(s.get(), AsSockAddr(&addr), addrlen), SyscallSucceeds()); + // Get the bound port. + ASSERT_THAT(getsockname(s.get(), AsSockAddr(&addr), &addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen), + SyscallSucceeds()); + + ScopedThread t([&s]() { + absl::SleepFor(absl::Milliseconds(250)); + ASSERT_THAT(shutdown(s.get(), SHUT_RD), SyscallSucceeds()); + }); + + char buf[1]; + EXPECT_THAT(recv(s.get(), buf, 0, 0), SyscallSucceedsWithValue(0)); +} + void NonBlockingConnect(int family, int16_t pollMask) { const FileDescriptor listener = ASSERT_NO_ERRNO_AND_VALUE(Socket(family, SOCK_STREAM, IPPROTO_TCP)); diff --git a/test/util/posix_error.h b/test/util/posix_error.h index 9ca09b77c..40853cb21 100644 --- a/test/util/posix_error.h +++ b/test/util/posix_error.h @@ -385,7 +385,7 @@ class PosixErrorIsMatcher { }; // Returns a gMock matcher that matches a PosixError or PosixErrorOr<> whose -// whose error code matches code_matcher, and whose error message matches +// error code matches code_matcher, and whose error message matches // message_matcher. template <typename ErrorCodeMatcher> PosixErrorIsMatcher PosixErrorIs( @@ -395,6 +395,14 @@ PosixErrorIsMatcher PosixErrorIs( std::move(message_matcher)); } +// Returns a gMock matcher that matches a PosixError or PosixErrorOr<> whose +// error code matches code_matcher. +template <typename ErrorCodeMatcher> +PosixErrorIsMatcher PosixErrorIs(ErrorCodeMatcher&& code_matcher) { + return PosixErrorIsMatcher(std::forward<ErrorCodeMatcher>(code_matcher), + ::testing::_); +} + // Returns a gMock matcher that matches a PosixErrorOr<> which is ok() and // value matches the inner matcher. template <typename InnerMatcher> |