Diffstat (limited to 'pkg')
542 files changed, 9764 insertions, 6932 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index a461bb65e..eb004a7f6 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -15,12 +15,12 @@ go_library( "bpf.go", "capability.go", "clone.go", + "context.go", "dev.go", "elf.go", "epoll.go", "epoll_amd64.go", "epoll_arm64.go", - "errors.go", "errqueue.go", "eventfd.go", "exec.go", @@ -77,6 +77,8 @@ go_library( deps = [ "//pkg/abi", "//pkg/bits", + "//pkg/context", + "//pkg/hostarch", "//pkg/marshal", "//pkg/marshal/primitive", ], diff --git a/pkg/abi/linux/context.go b/pkg/abi/linux/context.go new file mode 100644 index 000000000..d2dbba183 --- /dev/null +++ b/pkg/abi/linux/context.go @@ -0,0 +1,36 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/context" +) + +// contextID is the linux package's type for context.Context.Value keys. +type contextID int + +const ( + // CtxSignalNoInfoFunc is a Context.Value key for a function to send signals. + CtxSignalNoInfoFunc contextID = iota +) + +// SignalNoInfoFuncFromContext returns a callback function that can be used to send a +// signal to the given context. +func SignalNoInfoFuncFromContext(ctx context.Context) func(Signal) error { + if f := ctx.Value(CtxSignalNoInfoFunc); f != nil { + return f.(func(Signal) error) + } + return nil +} diff --git a/pkg/abi/linux/errno/BUILD b/pkg/abi/linux/errno/BUILD new file mode 100644 index 000000000..d003342d5 --- /dev/null +++ b/pkg/abi/linux/errno/BUILD @@ -0,0 +1,9 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "errno", + srcs = ["errno.go"], + visibility = ["//visibility:public"], +) diff --git a/pkg/abi/linux/errors.go b/pkg/abi/linux/errno/errno.go index b08b2687e..5a09c6605 100644 --- a/pkg/abi/linux/errors.go +++ b/pkg/abi/linux/errno/errno.go @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -package linux +// Package errno holds errno codes for abi/linux. +package errno // Errno represents a Linux errno value. -type Errno int +type Errno uint32 // Errno values from include/uapi/asm-generic/errno-base.h. const ( @@ -60,8 +61,8 @@ const ( ENOLCK ENOSYS ENOTEMPTY - ELOOP //40 - _ // Skip for EWOULDBLOCK = EAGAIN + ELOOP // 40 + _ // Skip for EWOULDBLOCK = EAGAIN. ENOMSG //42 EIDRM ECHRNG @@ -78,13 +79,13 @@ const ( ENOANO EBADRQC EBADSLT - _ // Skip for EDEADLOCK = EDEADLK + _ // Skip for EDEADLOCK = EDEADLK. EBFONT ENOSTR // 60 ENODATA ETIME ENOSR - ENONET + _ // Skip for ENOENT = ENONET. 
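The blank `_` entries above are how the iota-numbered errno table stays aligned with Linux, where EWOULDBLOCK, EDEADLOCK and (after this change) ENONET are aliases for earlier codes rather than fresh numbers. A standalone sketch of the idiom, using a hypothetical subset of the values rather than the real errno package:

```go
package main

import "fmt"

// Errno mirrors the uint32 errno type above; this is a hypothetical subset of
// the table, only to show how `_` placeholders keep the iota numbering in
// step with asm-generic/errno.h.
type Errno uint32

const (
	EDEADLK Errno = iota + 35 // 35
	ENAMETOOLONG              // 36
	ENOLCK                    // 37
	ENOSYS                    // 38
	ENOTEMPTY                 // 39
	ELOOP                     // 40
	_                         // 41 skipped: EWOULDBLOCK aliases EAGAIN (11).
	ENOMSG                    // 42
)

// Aliases are declared outside the iota block so they reuse existing values.
const EWOULDBLOCK = Errno(11) // EAGAIN

func main() {
	// The skipped slot keeps ENOMSG at 42, matching Linux.
	fmt.Println(ELOOP, ENOMSG, EWOULDBLOCK) // 40 42 11
}
```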
ENOPKG EREMOTE ENOLINK @@ -160,4 +161,5 @@ const ( const ( EWOULDBLOCK = EAGAIN EDEADLOCK = EDEADLK + ENONET = ENOENT ) diff --git a/pkg/abi/linux/ioctl_tun.go b/pkg/abi/linux/ioctl_tun.go index c59c9c136..ea4fdca0f 100644 --- a/pkg/abi/linux/ioctl_tun.go +++ b/pkg/abi/linux/ioctl_tun.go @@ -26,4 +26,7 @@ const ( IFF_TAP = 0x0002 IFF_NO_PI = 0x1000 IFF_NOFILTER = 0x1000 + + // According to linux/if_tun.h "This flag has no real effect" + IFF_ONE_QUEUE = 0x2000 ) diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go index b088b207c..f8c0e891e 100644 --- a/pkg/abi/linux/netfilter_ipv6.go +++ b/pkg/abi/linux/netfilter_ipv6.go @@ -41,7 +41,6 @@ const ( // IP6T_ORIGINAL_DST is the ip6tables SOL_IPV6 socket option. Corresponds to // the value in include/uapi/linux/netfilter_ipv6/ip6_tables.h. -// TODO(gvisor.dev/issue/3549): Support IPv6 original destination. const IP6T_ORIGINAL_DST = 80 // IP6TReplace is the argument for the IP6T_SO_SET_REPLACE sockopt. It diff --git a/pkg/abi/linux/signal.go b/pkg/abi/linux/signal.go index 6ca57ffbb..06a4c6401 100644 --- a/pkg/abi/linux/signal.go +++ b/pkg/abi/linux/signal.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/bits" + "gvisor.dev/gvisor/pkg/hostarch" ) const ( @@ -165,7 +166,7 @@ const ( SIG_IGN = 1 ) -// Signal action flags for rt_sigaction(2), from uapi/asm-generic/signal.h +// Signal action flags for rt_sigaction(2), from uapi/asm-generic/signal.h. const ( SA_NOCLDSTOP = 0x00000001 SA_NOCLDWAIT = 0x00000002 @@ -179,21 +180,17 @@ const ( SA_ONESHOT = SA_RESETHAND ) -// Signal info types. +// Signal stack flags for signalstack(2), from include/uapi/linux/signal.h. const ( - SI_MASK = 0xffff0000 - SI_KILL = 0 << 16 - SI_TIMER = 1 << 16 - SI_POLL = 2 << 16 - SI_FAULT = 3 << 16 - SI_CHLD = 4 << 16 - SI_RT = 5 << 16 - SI_MESGQ = 6 << 16 - SI_SYS = 7 << 16 + SS_ONSTACK = 1 + SS_DISABLE = 2 ) // SIGPOLL si_codes. const ( + // SI_POLL is defined as __SI_POLL in Linux 2.6. + SI_POLL = 2 << 16 + // POLL_IN indicates that data input available. POLL_IN = SI_POLL | 1 @@ -213,6 +210,75 @@ const ( POLL_HUP = SI_POLL | 6 ) +// Possible values for si_code. +const ( + // SI_USER is sent by kill, sigsend, raise. + SI_USER = 0 + + // SI_KERNEL is sent by the kernel from somewhere. + SI_KERNEL = 0x80 + + // SI_QUEUE is sent by sigqueue. + SI_QUEUE = -1 + + // SI_TIMER is sent by timer expiration. + SI_TIMER = -2 + + // SI_MESGQ is sent by real time mesq state change. + SI_MESGQ = -3 + + // SI_ASYNCIO is sent by AIO completion. + SI_ASYNCIO = -4 + + // SI_SIGIO is sent by queued SIGIO. + SI_SIGIO = -5 + + // SI_TKILL is sent by tkill system call. + SI_TKILL = -6 + + // SI_DETHREAD is sent by execve() killing subsidiary threads. + SI_DETHREAD = -7 + + // SI_ASYNCNL is sent by glibc async name lookup completion. + SI_ASYNCNL = -60 +) + +// CLD_* codes are only meaningful for SIGCHLD. +const ( + // CLD_EXITED indicates that a task exited. + CLD_EXITED = 1 + + // CLD_KILLED indicates that a task was killed by a signal. + CLD_KILLED = 2 + + // CLD_DUMPED indicates that a task was killed by a signal and then dumped + // core. + CLD_DUMPED = 3 + + // CLD_TRAPPED indicates that a task was stopped by ptrace. + CLD_TRAPPED = 4 + + // CLD_STOPPED indicates that a thread group completed a group stop. + CLD_STOPPED = 5 + + // CLD_CONTINUED indicates that a group-stopped thread group was continued. + CLD_CONTINUED = 6 +) + +// SYS_* codes are only meaningful for SIGSYS. 
+const ( + // SYS_SECCOMP indicates that a signal originates from seccomp. + SYS_SECCOMP = 1 +) + +// Possible values for Sigevent.Notify, aka struct sigevent::sigev_notify. +const ( + SIGEV_SIGNAL = 0 + SIGEV_NONE = 1 + SIGEV_THREAD = 2 + SIGEV_THREAD_ID = 4 +) + // Sigevent represents struct sigevent. // // +marshal @@ -227,10 +293,252 @@ type Sigevent struct { UnRemainder [44]byte } -// Possible values for Sigevent.Notify, aka struct sigevent::sigev_notify. -const ( - SIGEV_SIGNAL = 0 - SIGEV_NONE = 1 - SIGEV_THREAD = 2 - SIGEV_THREAD_ID = 4 -) +// SigAction represents struct sigaction. +// +// +marshal +// +stateify savable +type SigAction struct { + Handler uint64 + Flags uint64 + Restorer uint64 + Mask SignalSet +} + +// SignalStack represents information about a user stack, and is equivalent to +// stack_t. +// +// +marshal +// +stateify savable +type SignalStack struct { + Addr uint64 + Flags uint32 + _ uint32 + Size uint64 +} + +// Contains checks if the stack pointer is within this stack. +func (s *SignalStack) Contains(sp hostarch.Addr) bool { + return hostarch.Addr(s.Addr) < sp && sp <= hostarch.Addr(s.Addr+s.Size) +} + +// Top returns the stack's top address. +func (s *SignalStack) Top() hostarch.Addr { + return hostarch.Addr(s.Addr + s.Size) +} + +// IsEnabled returns true iff this signal stack is marked as enabled. +func (s *SignalStack) IsEnabled() bool { + return s.Flags&SS_DISABLE == 0 +} + +// SignalInfo represents information about a signal being delivered, and is +// equivalent to struct siginfo in linux kernel(linux/include/uapi/asm-generic/siginfo.h). +// +// +marshal +// +stateify savable +type SignalInfo struct { + Signo int32 // Signal number + Errno int32 // Errno value + Code int32 // Signal code + _ uint32 + + // struct siginfo::_sifields is a union. In SignalInfo, fields in the union + // are accessed through methods. + // + // For reference, here is the definition of _sifields: (_sigfault._trapno, + // which does not exist on x86, omitted for clarity) + // + // union { + // int _pad[SI_PAD_SIZE]; + // + // /* kill() */ + // struct { + // __kernel_pid_t _pid; /* sender's pid */ + // __ARCH_SI_UID_T _uid; /* sender's uid */ + // } _kill; + // + // /* POSIX.1b timers */ + // struct { + // __kernel_timer_t _tid; /* timer id */ + // int _overrun; /* overrun count */ + // char _pad[sizeof( __ARCH_SI_UID_T) - sizeof(int)]; + // sigval_t _sigval; /* same as below */ + // int _sys_private; /* not to be passed to user */ + // } _timer; + // + // /* POSIX.1b signals */ + // struct { + // __kernel_pid_t _pid; /* sender's pid */ + // __ARCH_SI_UID_T _uid; /* sender's uid */ + // sigval_t _sigval; + // } _rt; + // + // /* SIGCHLD */ + // struct { + // __kernel_pid_t _pid; /* which child */ + // __ARCH_SI_UID_T _uid; /* sender's uid */ + // int _status; /* exit code */ + // __ARCH_SI_CLOCK_T _utime; + // __ARCH_SI_CLOCK_T _stime; + // } _sigchld; + // + // /* SIGILL, SIGFPE, SIGSEGV, SIGBUS */ + // struct { + // void *_addr; /* faulting insn/memory ref. */ + // short _addr_lsb; /* LSB of the reported address */ + // } _sigfault; + // + // /* SIGPOLL */ + // struct { + // __ARCH_SI_BAND_T _band; /* POLL_IN, POLL_OUT, POLL_MSG */ + // int _fd; + // } _sigpoll; + // + // /* SIGSYS */ + // struct { + // void *_call_addr; /* calling user insn */ + // int _syscall; /* triggering system call number */ + // unsigned int _arch; /* AUDIT_ARCH_* of syscall */ + // } _sigsys; + // } _sifields; + // + // _sifields is padded so that the size of siginfo is SI_MAX_SIZE = 128 + // bytes. 
+ Fields [128 - 16]byte +} + +// FixSignalCodeForUser fixes up si_code. +// +// The si_code we get from Linux may contain the kernel-specific code in the +// top 16 bits if it's positive (e.g., from ptrace). Linux's +// copy_siginfo_to_user does +// err |= __put_user((short)from->si_code, &to->si_code); +// to mask out those bits and we need to do the same. +func (s *SignalInfo) FixSignalCodeForUser() { + if s.Code > 0 { + s.Code &= 0x0000ffff + } +} + +// PID returns the si_pid field. +func (s *SignalInfo) PID() int32 { + return int32(hostarch.ByteOrder.Uint32(s.Fields[0:4])) +} + +// SetPID mutates the si_pid field. +func (s *SignalInfo) SetPID(val int32) { + hostarch.ByteOrder.PutUint32(s.Fields[0:4], uint32(val)) +} + +// UID returns the si_uid field. +func (s *SignalInfo) UID() int32 { + return int32(hostarch.ByteOrder.Uint32(s.Fields[4:8])) +} + +// SetUID mutates the si_uid field. +func (s *SignalInfo) SetUID(val int32) { + hostarch.ByteOrder.PutUint32(s.Fields[4:8], uint32(val)) +} + +// Sigval returns the sigval field, which is aliased to both si_int and si_ptr. +func (s *SignalInfo) Sigval() uint64 { + return hostarch.ByteOrder.Uint64(s.Fields[8:16]) +} + +// SetSigval mutates the sigval field. +func (s *SignalInfo) SetSigval(val uint64) { + hostarch.ByteOrder.PutUint64(s.Fields[8:16], val) +} + +// TimerID returns the si_timerid field. +func (s *SignalInfo) TimerID() TimerID { + return TimerID(hostarch.ByteOrder.Uint32(s.Fields[0:4])) +} + +// SetTimerID sets the si_timerid field. +func (s *SignalInfo) SetTimerID(val TimerID) { + hostarch.ByteOrder.PutUint32(s.Fields[0:4], uint32(val)) +} + +// Overrun returns the si_overrun field. +func (s *SignalInfo) Overrun() int32 { + return int32(hostarch.ByteOrder.Uint32(s.Fields[4:8])) +} + +// SetOverrun sets the si_overrun field. +func (s *SignalInfo) SetOverrun(val int32) { + hostarch.ByteOrder.PutUint32(s.Fields[4:8], uint32(val)) +} + +// Addr returns the si_addr field. +func (s *SignalInfo) Addr() uint64 { + return hostarch.ByteOrder.Uint64(s.Fields[0:8]) +} + +// SetAddr sets the si_addr field. +func (s *SignalInfo) SetAddr(val uint64) { + hostarch.ByteOrder.PutUint64(s.Fields[0:8], val) +} + +// Status returns the si_status field. +func (s *SignalInfo) Status() int32 { + return int32(hostarch.ByteOrder.Uint32(s.Fields[8:12])) +} + +// SetStatus mutates the si_status field. +func (s *SignalInfo) SetStatus(val int32) { + hostarch.ByteOrder.PutUint32(s.Fields[8:12], uint32(val)) +} + +// CallAddr returns the si_call_addr field. +func (s *SignalInfo) CallAddr() uint64 { + return hostarch.ByteOrder.Uint64(s.Fields[0:8]) +} + +// SetCallAddr mutates the si_call_addr field. +func (s *SignalInfo) SetCallAddr(val uint64) { + hostarch.ByteOrder.PutUint64(s.Fields[0:8], val) +} + +// Syscall returns the si_syscall field. +func (s *SignalInfo) Syscall() int32 { + return int32(hostarch.ByteOrder.Uint32(s.Fields[8:12])) +} + +// SetSyscall mutates the si_syscall field. +func (s *SignalInfo) SetSyscall(val int32) { + hostarch.ByteOrder.PutUint32(s.Fields[8:12], uint32(val)) +} + +// Arch returns the si_arch field. +func (s *SignalInfo) Arch() uint32 { + return hostarch.ByteOrder.Uint32(s.Fields[12:16]) +} + +// SetArch mutates the si_arch field. +func (s *SignalInfo) SetArch(val uint32) { + hostarch.ByteOrder.PutUint32(s.Fields[12:16], val) +} + +// Band returns the si_band field. +func (s *SignalInfo) Band() int64 { + return int64(hostarch.ByteOrder.Uint64(s.Fields[0:8])) +} + +// SetBand mutates the si_band field. 
+func (s *SignalInfo) SetBand(val int64) { + // Note: this assumes the platform uses `long` as `__ARCH_SI_BAND_T`. + // On some platforms, which gVisor doesn't support, `__ARCH_SI_BAND_T` is + // `int`. See siginfo.h. + hostarch.ByteOrder.PutUint64(s.Fields[0:8], uint64(val)) +} + +// FD returns the si_fd field. +func (s *SignalInfo) FD() uint32 { + return hostarch.ByteOrder.Uint32(s.Fields[8:12]) +} + +// SetFD mutates the si_fd field. +func (s *SignalInfo) SetFD(val uint32) { + hostarch.ByteOrder.PutUint32(s.Fields[8:12], val) +} diff --git a/pkg/abi/linux/xattr.go b/pkg/abi/linux/xattr.go index 8ef837f27..1fa7a4f4f 100644 --- a/pkg/abi/linux/xattr.go +++ b/pkg/abi/linux/xattr.go @@ -23,6 +23,12 @@ const ( XATTR_CREATE = 1 XATTR_REPLACE = 2 + XATTR_SECURITY_PREFIX = "security." + XATTR_SECURITY_PREFIX_LEN = len(XATTR_SECURITY_PREFIX) + + XATTR_SYSTEM_PREFIX = "system." + XATTR_SYSTEM_PREFIX_LEN = len(XATTR_SYSTEM_PREFIX) + XATTR_TRUSTED_PREFIX = "trusted." XATTR_TRUSTED_PREFIX_LEN = len(XATTR_TRUSTED_PREFIX) diff --git a/pkg/control/server/server.go b/pkg/control/server/server.go index 629dae8f4..889568177 100644 --- a/pkg/control/server/server.go +++ b/pkg/control/server/server.go @@ -22,6 +22,7 @@ package server import ( "os" + "time" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sync" @@ -65,13 +66,13 @@ func (s *Server) Wait() { // Stop stops the server. Note that this function should only be called once // and the server should not be used afterwards. -func (s *Server) Stop() { +func (s *Server) Stop(timeout time.Duration) { s.socket.Close() s.Wait() // This will cause existing clients to be terminated safely. If the // registered handlers have a Stop callback, it will be called. - s.server.Stop() + s.server.Stop(timeout) } // StartServing starts listening for connect and spawns the main service diff --git a/pkg/crypto/crypto_stdlib.go b/pkg/crypto/crypto_stdlib.go index 74a55a123..514592b08 100644 --- a/pkg/crypto/crypto_stdlib.go +++ b/pkg/crypto/crypto_stdlib.go @@ -22,11 +22,11 @@ import ( // EcdsaVerify verifies the signature in r, s of hash using ECDSA and the // public key, pub. Its return value records whether the signature is valid. -func EcdsaVerify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) bool { - return ecdsa.Verify(pub, hash, r, s) +func EcdsaVerify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) (bool, error) { + return ecdsa.Verify(pub, hash, r, s), nil } // SumSha384 returns the SHA384 checksum of the data. -func SumSha384(data []byte) (sum384 [sha512.Size384]byte) { - return sha512.Sum384(data) +func SumSha384(data []byte) ([sha512.Size384]byte, error) { + return sha512.Sum384(data), nil } diff --git a/pkg/errors/BUILD b/pkg/errors/BUILD new file mode 100644 index 000000000..36ea5da0c --- /dev/null +++ b/pkg/errors/BUILD @@ -0,0 +1,10 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "errors", + srcs = ["errors.go"], + visibility = ["//:sandbox"], + deps = ["//pkg/abi/linux/errno"], +) diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go new file mode 100644 index 000000000..7984a770e --- /dev/null +++ b/pkg/errors/errors.go @@ -0,0 +1,40 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
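The SignalInfo accessors above overlay siginfo's union by reading and writing fixed offsets inside the 112-byte Fields array (the 128-byte struct minus its 16-byte header). A self-contained sketch of the same technique, with a simplified stand-in struct and encoding/binary in place of gVisor's hostarch.ByteOrder:

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// fakeSignalInfo mimics the layout above: a fixed header followed by a byte
// array that is reinterpreted per signal, like siginfo's _sifields union.
type fakeSignalInfo struct {
	Signo, Errno, Code int32
	_                  uint32
	Fields             [128 - 16]byte
}

// PID reads si_pid from the first four bytes of the union, as the kill()
// branch lays it out. siginfo uses native byte order, which is little-endian
// on the architectures gVisor supports (amd64, arm64).
func (s *fakeSignalInfo) PID() int32 {
	return int32(binary.LittleEndian.Uint32(s.Fields[0:4]))
}

// SetUID writes si_uid at offset 4, immediately after si_pid.
func (s *fakeSignalInfo) SetUID(uid int32) {
	binary.LittleEndian.PutUint32(s.Fields[4:8], uint32(uid))
}

func main() {
	var si fakeSignalInfo
	binary.LittleEndian.PutUint32(si.Fields[0:4], 1234) // si_pid
	si.SetUID(1000)
	fmt.Println(si.PID(), binary.LittleEndian.Uint32(si.Fields[4:8])) // 1234 1000
}
```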
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package errors holds the standardized error definition for gVisor. +package errors + +import ( + "gvisor.dev/gvisor/pkg/abi/linux/errno" +) + +// Error represents a syscall errno with a descriptive message. +type Error struct { + errno errno.Errno + message string +} + +// New creates a new *Error. +func New(err errno.Errno, message string) *Error { + return &Error{ + errno: err, + message: message, + } +} + +// Error implements error.Error. +func (e *Error) Error() string { return e.message } + +// Errno returns the underlying errno.Errno value. +func (e *Error) Errno() errno.Errno { return e.errno } diff --git a/pkg/errors/linuxerr/BUILD b/pkg/errors/linuxerr/BUILD new file mode 100644 index 000000000..201727780 --- /dev/null +++ b/pkg/errors/linuxerr/BUILD @@ -0,0 +1,26 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "linuxerr", + srcs = ["linuxerr.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/abi/linux/errno", + "//pkg/errors", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +go_test( + name = "linuxerr_test", + srcs = ["linuxerr_test.go"], + deps = [ + ":linuxerr", + "//pkg/abi/linux/errno", + "//pkg/errors", + "//pkg/syserror", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/errors/linuxerr/linuxerr.go b/pkg/errors/linuxerr/linuxerr.go new file mode 100644 index 000000000..9246f2e89 --- /dev/null +++ b/pkg/errors/linuxerr/linuxerr.go @@ -0,0 +1,347 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"),; +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package linuxerr contains syscall error codes exported as an error interface +// pointers. This allows for fast comparison and return operations comperable +// to unix.Errno constants. 
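The new pkg/errors type above pairs an errno.Errno with a fixed message and satisfies the standard error interface, so one *Error per code can be allocated once and then returned and compared by pointer. A usage sketch, assuming the gvisor.dev import paths shown in the BUILD files:

```go
package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/abi/linux/errno"
	gerrors "gvisor.dev/gvisor/pkg/errors"
)

func main() {
	// A single allocation up front; callers then return and compare the same
	// *Error pointer instead of formatting strings on every failure.
	errBadHandle := gerrors.New(errno.EBADF, "bad file number")

	var err error = errBadHandle // satisfies the error interface
	fmt.Println(err)             // "bad file number"

	// The underlying errno is still available for ABI translation.
	fmt.Println(errBadHandle.Errno() == errno.EBADF) // true
}
```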
+package linuxerr + +import ( + "fmt" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux/errno" + "gvisor.dev/gvisor/pkg/errors" +) + +const maxErrno uint32 = errno.EHWPOISON + 1 + +var ( + NOERROR = errors.New(errno.NOERRNO, "not an error") + EPERM = errors.New(errno.EPERM, "operation not permitted") + ENOENT = errors.New(errno.ENOENT, "no such file or directory") + ESRCH = errors.New(errno.ESRCH, "no such process") + EINTR = errors.New(errno.EINTR, "interrupted system call") + EIO = errors.New(errno.EIO, "I/O error") + ENXIO = errors.New(errno.ENXIO, "no such device or address") + E2BIG = errors.New(errno.E2BIG, "argument list too long") + ENOEXEC = errors.New(errno.ENOEXEC, "exec format error") + EBADF = errors.New(errno.EBADF, "bad file number") + ECHILD = errors.New(errno.ECHILD, "no child processes") + EAGAIN = errors.New(errno.EAGAIN, "try again") + ENOMEM = errors.New(errno.ENOMEM, "out of memory") + EACCES = errors.New(errno.EACCES, "permission denied") + EFAULT = errors.New(errno.EFAULT, "bad address") + ENOTBLK = errors.New(errno.ENOTBLK, "block device required") + EBUSY = errors.New(errno.EBUSY, "device or resource busy") + EEXIST = errors.New(errno.EEXIST, "file exists") + EXDEV = errors.New(errno.EXDEV, "cross-device link") + ENODEV = errors.New(errno.ENODEV, "no such device") + ENOTDIR = errors.New(errno.ENOTDIR, "not a directory") + EISDIR = errors.New(errno.EISDIR, "is a directory") + EINVAL = errors.New(errno.EINVAL, "invalid argument") + ENFILE = errors.New(errno.ENFILE, "file table overflow") + EMFILE = errors.New(errno.EMFILE, "too many open files") + ENOTTY = errors.New(errno.ENOTTY, "not a typewriter") + ETXTBSY = errors.New(errno.ETXTBSY, "text file busy") + EFBIG = errors.New(errno.EFBIG, "file too large") + ENOSPC = errors.New(errno.ENOSPC, "no space left on device") + ESPIPE = errors.New(errno.ESPIPE, "illegal seek") + EROFS = errors.New(errno.EROFS, "read-only file system") + EMLINK = errors.New(errno.EMLINK, "too many links") + EPIPE = errors.New(errno.EPIPE, "broken pipe") + EDOM = errors.New(errno.EDOM, "math argument out of domain of func") + ERANGE = errors.New(errno.ERANGE, "math result not representable") + + // Errno values from include/uapi/asm-generic/errno.h. 
+ EDEADLK = errors.New(errno.EDEADLK, "resource deadlock would occur") + ENAMETOOLONG = errors.New(errno.ENAMETOOLONG, "file name too long") + ENOLCK = errors.New(errno.ENOLCK, "no record locks available") + ENOSYS = errors.New(errno.ENOSYS, "invalid system call number") + ENOTEMPTY = errors.New(errno.ENOTEMPTY, "directory not empty") + ELOOP = errors.New(errno.ELOOP, "too many symbolic links encountered") + ENOMSG = errors.New(errno.ENOMSG, "no message of desired type") + EIDRM = errors.New(errno.EIDRM, "identifier removed") + ECHRNG = errors.New(errno.ECHRNG, "channel number out of range") + EL2NSYNC = errors.New(errno.EL2NSYNC, "level 2 not synchronized") + EL3HLT = errors.New(errno.EL3HLT, "level 3 halted") + EL3RST = errors.New(errno.EL3RST, "level 3 reset") + ELNRNG = errors.New(errno.ELNRNG, "link number out of range") + EUNATCH = errors.New(errno.EUNATCH, "protocol driver not attached") + ENOCSI = errors.New(errno.ENOCSI, "no CSI structure available") + EL2HLT = errors.New(errno.EL2HLT, "level 2 halted") + EBADE = errors.New(errno.EBADE, "invalid exchange") + EBADR = errors.New(errno.EBADR, "invalid request descriptor") + EXFULL = errors.New(errno.EXFULL, "exchange full") + ENOANO = errors.New(errno.ENOANO, "no anode") + EBADRQC = errors.New(errno.EBADRQC, "invalid request code") + EBADSLT = errors.New(errno.EBADSLT, "invalid slot") + EBFONT = errors.New(errno.EBFONT, "bad font file format") + ENOSTR = errors.New(errno.ENOSTR, "device not a stream") + ENODATA = errors.New(errno.ENODATA, "no data available") + ETIME = errors.New(errno.ETIME, "timer expired") + ENOSR = errors.New(errno.ENOSR, "out of streams resources") + ENOPKG = errors.New(errno.ENOPKG, "package not installed") + EREMOTE = errors.New(errno.EREMOTE, "object is remote") + ENOLINK = errors.New(errno.ENOLINK, "link has been severed") + EADV = errors.New(errno.EADV, "advertise error") + ESRMNT = errors.New(errno.ESRMNT, "srmount error") + ECOMM = errors.New(errno.ECOMM, "communication error on send") + EPROTO = errors.New(errno.EPROTO, "protocol error") + EMULTIHOP = errors.New(errno.EMULTIHOP, "multihop attempted") + EDOTDOT = errors.New(errno.EDOTDOT, "RFS specific error") + EBADMSG = errors.New(errno.EBADMSG, "not a data message") + EOVERFLOW = errors.New(errno.EOVERFLOW, "value too large for defined data type") + ENOTUNIQ = errors.New(errno.ENOTUNIQ, "name not unique on network") + EBADFD = errors.New(errno.EBADFD, "file descriptor in bad state") + EREMCHG = errors.New(errno.EREMCHG, "remote address changed") + ELIBACC = errors.New(errno.ELIBACC, "can not access a needed shared library") + ELIBBAD = errors.New(errno.ELIBBAD, "accessing a corrupted shared library") + ELIBSCN = errors.New(errno.ELIBSCN, ".lib section in a.out corrupted") + ELIBMAX = errors.New(errno.ELIBMAX, "attempting to link in too many shared libraries") + ELIBEXEC = errors.New(errno.ELIBEXEC, "cannot exec a shared library directly") + EILSEQ = errors.New(errno.EILSEQ, "illegal byte sequence") + ERESTART = errors.New(errno.ERESTART, "interrupted system call should be restarted") + ESTRPIPE = errors.New(errno.ESTRPIPE, "streams pipe error") + EUSERS = errors.New(errno.EUSERS, "too many users") + ENOTSOCK = errors.New(errno.ENOTSOCK, "socket operation on non-socket") + EDESTADDRREQ = errors.New(errno.EDESTADDRREQ, "destination address required") + EMSGSIZE = errors.New(errno.EMSGSIZE, "message too long") + EPROTOTYPE = errors.New(errno.EPROTOTYPE, "protocol wrong type for socket") + ENOPROTOOPT = errors.New(errno.ENOPROTOOPT, "protocol not 
available") + EPROTONOSUPPORT = errors.New(errno.EPROTONOSUPPORT, "protocol not supported") + ESOCKTNOSUPPORT = errors.New(errno.ESOCKTNOSUPPORT, "socket type not supported") + EOPNOTSUPP = errors.New(errno.EOPNOTSUPP, "operation not supported on transport endpoint") + EPFNOSUPPORT = errors.New(errno.EPFNOSUPPORT, "protocol family not supported") + EAFNOSUPPORT = errors.New(errno.EAFNOSUPPORT, "address family not supported by protocol") + EADDRINUSE = errors.New(errno.EADDRINUSE, "address already in use") + EADDRNOTAVAIL = errors.New(errno.EADDRNOTAVAIL, "cannot assign requested address") + ENETDOWN = errors.New(errno.ENETDOWN, "network is down") + ENETUNREACH = errors.New(errno.ENETUNREACH, "network is unreachable") + ENETRESET = errors.New(errno.ENETRESET, "network dropped connection because of reset") + ECONNABORTED = errors.New(errno.ECONNABORTED, "software caused connection abort") + ECONNRESET = errors.New(errno.ECONNRESET, "connection reset by peer") + ENOBUFS = errors.New(errno.ENOBUFS, "no buffer space available") + EISCONN = errors.New(errno.EISCONN, "transport endpoint is already connected") + ENOTCONN = errors.New(errno.ENOTCONN, "transport endpoint is not connected") + ESHUTDOWN = errors.New(errno.ESHUTDOWN, "cannot send after transport endpoint shutdown") + ETOOMANYREFS = errors.New(errno.ETOOMANYREFS, "too many references: cannot splice") + ETIMEDOUT = errors.New(errno.ETIMEDOUT, "connection timed out") + ECONNREFUSED = errors.New(errno.ECONNREFUSED, "connection refused") + EHOSTDOWN = errors.New(errno.EHOSTDOWN, "host is down") + EHOSTUNREACH = errors.New(errno.EHOSTUNREACH, "no route to host") + EALREADY = errors.New(errno.EALREADY, "operation already in progress") + EINPROGRESS = errors.New(errno.EINPROGRESS, "operation now in progress") + ESTALE = errors.New(errno.ESTALE, "stale file handle") + EUCLEAN = errors.New(errno.EUCLEAN, "structure needs cleaning") + ENOTNAM = errors.New(errno.ENOTNAM, "not a XENIX named type file") + ENAVAIL = errors.New(errno.ENAVAIL, "no XENIX semaphores available") + EISNAM = errors.New(errno.EISNAM, "is a named type file") + EREMOTEIO = errors.New(errno.EREMOTEIO, "remote I/O error") + EDQUOT = errors.New(errno.EDQUOT, "quota exceeded") + ENOMEDIUM = errors.New(errno.ENOMEDIUM, "no medium found") + EMEDIUMTYPE = errors.New(errno.EMEDIUMTYPE, "wrong medium type") + ECANCELED = errors.New(errno.ECANCELED, "operation Canceled") + ENOKEY = errors.New(errno.ENOKEY, "required key not available") + EKEYEXPIRED = errors.New(errno.EKEYEXPIRED, "key has expired") + EKEYREVOKED = errors.New(errno.EKEYREVOKED, "key has been revoked") + EKEYREJECTED = errors.New(errno.EKEYREJECTED, "key was rejected by service") + EOWNERDEAD = errors.New(errno.EOWNERDEAD, "owner died") + ENOTRECOVERABLE = errors.New(errno.ENOTRECOVERABLE, "state not recoverable") + ERFKILL = errors.New(errno.ERFKILL, "operation not possible due to RF-kill") + EHWPOISON = errors.New(errno.EHWPOISON, "memory page has hardware error") + + // Errors equivalent to other errors. + EWOULDBLOCK = EAGAIN + EDEADLOCK = EDEADLK + ENONET = ENOENT +) + +// A nil *errors.Error denotes no error and is placed at the 0 index of +// errorSlice. Thus, any other empty index should not be nil or a valid error. +// This marks that index as an invalid error so any comparison to nil or a +// valid linuxerr fails. 
+var errNotValidError = errors.New(errno.Errno(maxErrno), "not a valid error") + +// The following errorSlice holds errors by errno for fast translation between +// errnos (especially uint32(sycall.Errno)) and *Error. +var errorSlice = []*errors.Error{ + // Errno values from include/uapi/asm-generic/errno-base.h. + errno.NOERRNO: NOERROR, + errno.EPERM: EPERM, + errno.ENOENT: ENOENT, + errno.ESRCH: ESRCH, + errno.EINTR: EINTR, + errno.EIO: EIO, + errno.ENXIO: ENXIO, + errno.E2BIG: E2BIG, + errno.ENOEXEC: ENOEXEC, + errno.EBADF: EBADF, + errno.ECHILD: ECHILD, + errno.EAGAIN: EAGAIN, + errno.ENOMEM: ENOMEM, + errno.EACCES: EACCES, + errno.EFAULT: EFAULT, + errno.ENOTBLK: ENOTBLK, + errno.EBUSY: EBUSY, + errno.EEXIST: EEXIST, + errno.EXDEV: EXDEV, + errno.ENODEV: ENODEV, + errno.ENOTDIR: ENOTDIR, + errno.EISDIR: EISDIR, + errno.EINVAL: EINVAL, + errno.ENFILE: ENFILE, + errno.EMFILE: EMFILE, + errno.ENOTTY: ENOTTY, + errno.ETXTBSY: ETXTBSY, + errno.EFBIG: EFBIG, + errno.ENOSPC: ENOSPC, + errno.ESPIPE: ESPIPE, + errno.EROFS: EROFS, + errno.EMLINK: EMLINK, + errno.EPIPE: EPIPE, + errno.EDOM: EDOM, + errno.ERANGE: ERANGE, + + // Errno values from include/uapi/asm-generic/errno.h. + errno.EDEADLK: EDEADLK, + errno.ENAMETOOLONG: ENAMETOOLONG, + errno.ENOLCK: ENOLCK, + errno.ENOSYS: ENOSYS, + errno.ENOTEMPTY: ENOTEMPTY, + errno.ELOOP: ELOOP, + errno.ELOOP + 1: errNotValidError, // No valid errno between ELOOP and ENOMSG. + errno.ENOMSG: ENOMSG, + errno.EIDRM: EIDRM, + errno.ECHRNG: ECHRNG, + errno.EL2NSYNC: EL2NSYNC, + errno.EL3HLT: EL3HLT, + errno.EL3RST: EL3RST, + errno.ELNRNG: ELNRNG, + errno.EUNATCH: EUNATCH, + errno.ENOCSI: ENOCSI, + errno.EL2HLT: EL2HLT, + errno.EBADE: EBADE, + errno.EBADR: EBADR, + errno.EXFULL: EXFULL, + errno.ENOANO: ENOANO, + errno.EBADRQC: EBADRQC, + errno.EBADSLT: EBADSLT, + errno.EBADSLT + 1: errNotValidError, // No valid errno between EBADSLT and ENOPKG. + errno.EBFONT: EBFONT, + errno.ENOSTR: ENOSTR, + errno.ENODATA: ENODATA, + errno.ETIME: ETIME, + errno.ENOSR: ENOSR, + errno.ENOSR + 1: errNotValidError, // No valid errno betweeen ENOSR and ENOPKG. 
+ errno.ENOPKG: ENOPKG, + errno.EREMOTE: EREMOTE, + errno.ENOLINK: ENOLINK, + errno.EADV: EADV, + errno.ESRMNT: ESRMNT, + errno.ECOMM: ECOMM, + errno.EPROTO: EPROTO, + errno.EMULTIHOP: EMULTIHOP, + errno.EDOTDOT: EDOTDOT, + errno.EBADMSG: EBADMSG, + errno.EOVERFLOW: EOVERFLOW, + errno.ENOTUNIQ: ENOTUNIQ, + errno.EBADFD: EBADFD, + errno.EREMCHG: EREMCHG, + errno.ELIBACC: ELIBACC, + errno.ELIBBAD: ELIBBAD, + errno.ELIBSCN: ELIBSCN, + errno.ELIBMAX: ELIBMAX, + errno.ELIBEXEC: ELIBEXEC, + errno.EILSEQ: EILSEQ, + errno.ERESTART: ERESTART, + errno.ESTRPIPE: ESTRPIPE, + errno.EUSERS: EUSERS, + errno.ENOTSOCK: ENOTSOCK, + errno.EDESTADDRREQ: EDESTADDRREQ, + errno.EMSGSIZE: EMSGSIZE, + errno.EPROTOTYPE: EPROTOTYPE, + errno.ENOPROTOOPT: ENOPROTOOPT, + errno.EPROTONOSUPPORT: EPROTONOSUPPORT, + errno.ESOCKTNOSUPPORT: ESOCKTNOSUPPORT, + errno.EOPNOTSUPP: EOPNOTSUPP, + errno.EPFNOSUPPORT: EPFNOSUPPORT, + errno.EAFNOSUPPORT: EAFNOSUPPORT, + errno.EADDRINUSE: EADDRINUSE, + errno.EADDRNOTAVAIL: EADDRNOTAVAIL, + errno.ENETDOWN: ENETDOWN, + errno.ENETUNREACH: ENETUNREACH, + errno.ENETRESET: ENETRESET, + errno.ECONNABORTED: ECONNABORTED, + errno.ECONNRESET: ECONNRESET, + errno.ENOBUFS: ENOBUFS, + errno.EISCONN: EISCONN, + errno.ENOTCONN: ENOTCONN, + errno.ESHUTDOWN: ESHUTDOWN, + errno.ETOOMANYREFS: ETOOMANYREFS, + errno.ETIMEDOUT: ETIMEDOUT, + errno.ECONNREFUSED: ECONNREFUSED, + errno.EHOSTDOWN: EHOSTDOWN, + errno.EHOSTUNREACH: EHOSTUNREACH, + errno.EALREADY: EALREADY, + errno.EINPROGRESS: EINPROGRESS, + errno.ESTALE: ESTALE, + errno.EUCLEAN: EUCLEAN, + errno.ENOTNAM: ENOTNAM, + errno.ENAVAIL: ENAVAIL, + errno.EISNAM: EISNAM, + errno.EREMOTEIO: EREMOTEIO, + errno.EDQUOT: EDQUOT, + errno.ENOMEDIUM: ENOMEDIUM, + errno.EMEDIUMTYPE: EMEDIUMTYPE, + errno.ECANCELED: ECANCELED, + errno.ENOKEY: ENOKEY, + errno.EKEYEXPIRED: EKEYEXPIRED, + errno.EKEYREVOKED: EKEYREVOKED, + errno.EKEYREJECTED: EKEYREJECTED, + errno.EOWNERDEAD: EOWNERDEAD, + errno.ENOTRECOVERABLE: ENOTRECOVERABLE, + errno.ERFKILL: ERFKILL, + errno.EHWPOISON: EHWPOISON, +} + +// ErrorFromErrno gets an error from the list and panics if an invalid entry is requested. +func ErrorFromErrno(e errno.Errno) *errors.Error { + err := errorSlice[e] + // Done this way because a single comparison in benchmarks is 2-3 faster + // than something like ( if err == nil && err > 0 ). + if err != errNotValidError { + return err + } + panic(fmt.Sprintf("invalid error requested with errno: %d", e)) +} + +// Equals compars a linuxerr to a given error +// TODO(b/34162363): Remove when syserror is removed. +func Equals(e *errors.Error, err error) bool { + if err == nil { + return e == NOERROR || e == nil + } + if e == nil { + return err == NOERROR || err == unix.Errno(0) + } + + switch err.(type) { + case *errors.Error: + return e == err + case unix.Errno, error: + return unix.Errno(e.Errno()) == err + } + return false +} diff --git a/pkg/errors/linuxerr/linuxerr_test.go b/pkg/errors/linuxerr/linuxerr_test.go new file mode 100644 index 000000000..b1250d24f --- /dev/null +++ b/pkg/errors/linuxerr/linuxerr_test.go @@ -0,0 +1,306 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
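ErrorFromErrno and Equals above are the two bridges out of this table: the former maps a raw errno (for example one returned by a host syscall) back to the shared *errors.Error, and the latter compares a linuxerr value against whatever error type a caller still holds during the syserror migration. A small usage sketch under the same import paths:

```go
package main

import (
	"fmt"

	"golang.org/x/sys/unix"
	"gvisor.dev/gvisor/pkg/abi/linux/errno"
	"gvisor.dev/gvisor/pkg/errors/linuxerr"
)

func main() {
	// Translate a host errno from a raw syscall into the shared *errors.Error.
	_, err := unix.Seek(-1, 0, 0) // fails with EBADF
	if e, ok := err.(unix.Errno); ok {
		fmt.Println(linuxerr.ErrorFromErrno(errno.Errno(e)) == linuxerr.EBADF) // true
	}

	// Equals hides the unix.Errno vs *errors.Error distinction for callers
	// that have not yet been migrated.
	fmt.Println(linuxerr.Equals(linuxerr.EBADF, unix.EBADF))      // true
	fmt.Println(linuxerr.Equals(linuxerr.EBADF, linuxerr.EAGAIN)) // false
}
```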
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syserror_test + +import ( + "errors" + "io" + "io/fs" + "syscall" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux/errno" + gErrors "gvisor.dev/gvisor/pkg/errors" + "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/syserror" +) + +var globalError error + +func BenchmarkAssignUnix(b *testing.B) { + for i := b.N; i > 0; i-- { + globalError = unix.EINVAL + } +} + +func BenchmarkAssignLinuxerr(b *testing.B) { + for i := b.N; i > 0; i-- { + globalError = linuxerr.EINVAL + } +} + +func BenchmarkAssignSyserror(b *testing.B) { + for i := b.N; i > 0; i-- { + globalError = linuxerr.ENOMSG + } +} + +func BenchmarkCompareUnix(b *testing.B) { + globalError = unix.EAGAIN + j := 0 + for i := b.N; i > 0; i-- { + if globalError == unix.EINVAL { + j++ + } + } +} + +func BenchmarkCompareLinuxerr(b *testing.B) { + globalError = linuxerr.E2BIG + j := 0 + for i := b.N; i > 0; i-- { + if globalError == linuxerr.EINVAL { + j++ + } + } +} + +func BenchmarkCompareSyserror(b *testing.B) { + globalError = syserror.EAGAIN + j := 0 + for i := b.N; i > 0; i-- { + if globalError == syserror.EACCES { + j++ + } + } +} + +func BenchmarkSwitchUnix(b *testing.B) { + globalError = unix.EPERM + j := 0 + for i := b.N; i > 0; i-- { + switch globalError { + case unix.EINVAL: + j++ + case unix.EINTR: + j += 2 + case unix.EAGAIN: + j += 3 + } + } +} + +func BenchmarkSwitchLinuxerr(b *testing.B) { + globalError = linuxerr.EPERM + j := 0 + for i := b.N; i > 0; i-- { + switch globalError { + case linuxerr.EINVAL: + j++ + case linuxerr.EINTR: + j += 2 + case linuxerr.EAGAIN: + j += 3 + } + } +} + +func BenchmarkSwitchSyserror(b *testing.B) { + globalError = syserror.EPERM + j := 0 + for i := b.N; i > 0; i-- { + switch globalError { + case syserror.EACCES: + j++ + case syserror.EINTR: + j += 2 + case syserror.EAGAIN: + j += 3 + } + } +} + +func BenchmarkReturnUnix(b *testing.B) { + var localError error + f := func() error { + return unix.EINVAL + } + for i := b.N; i > 0; i-- { + localError = f() + } + if localError != nil { + return + } +} + +func BenchmarkReturnLinuxerr(b *testing.B) { + var localError error + f := func() error { + return linuxerr.EINVAL + } + for i := b.N; i > 0; i-- { + localError = f() + } + if localError != nil { + return + } +} + +func BenchmarkConvertUnixLinuxerr(b *testing.B) { + var localError error + for i := b.N; i > 0; i-- { + localError = linuxerr.ErrorFromErrno(errno.Errno(unix.EINVAL)) + } + if localError != nil { + return + } +} + +func BenchmarkConvertUnixLinuxerrZero(b *testing.B) { + var localError error + for i := b.N; i > 0; i-- { + localError = linuxerr.ErrorFromErrno(errno.Errno(0)) + } + if localError != nil { + return + } +} + +type translationTestTable struct { + fn string + errIn error + syscallErrorIn unix.Errno + expectedBool bool + expectedTranslation unix.Errno +} + +func TestErrorTranslation(t *testing.T) { + myError := errors.New("My test error") + myError2 := errors.New("Another test error") + testTable := []translationTestTable{ + {"TranslateError", myError, 0, false, 0}, + {"TranslateError", myError2, 0, false, 0}, + 
{"AddErrorTranslation", myError, unix.EAGAIN, true, 0}, + {"AddErrorTranslation", myError, unix.EAGAIN, false, 0}, + {"AddErrorTranslation", myError, unix.EPERM, false, 0}, + {"TranslateError", myError, 0, true, unix.EAGAIN}, + {"TranslateError", myError2, 0, false, 0}, + {"AddErrorTranslation", myError2, unix.EPERM, true, 0}, + {"AddErrorTranslation", myError2, unix.EPERM, false, 0}, + {"AddErrorTranslation", myError2, unix.EAGAIN, false, 0}, + {"TranslateError", myError, 0, true, unix.EAGAIN}, + {"TranslateError", myError2, 0, true, unix.EPERM}, + } + for _, tt := range testTable { + switch tt.fn { + case "TranslateError": + err, ok := syserror.TranslateError(tt.errIn) + if ok != tt.expectedBool { + t.Fatalf("%v(%v) => %v expected %v", tt.fn, tt.errIn, ok, tt.expectedBool) + } else if err != tt.expectedTranslation { + t.Fatalf("%v(%v) (error) => %v expected %v", tt.fn, tt.errIn, err, tt.expectedTranslation) + } + case "AddErrorTranslation": + ok := syserror.AddErrorTranslation(tt.errIn, tt.syscallErrorIn) + if ok != tt.expectedBool { + t.Fatalf("%v(%v) => %v expected %v", tt.fn, tt.errIn, ok, tt.expectedBool) + } + default: + t.Fatalf("Unknown function %v", tt.fn) + } + } +} + +func TestSyscallErrnoToErrors(t *testing.T) { + for _, tc := range []struct { + errno syscall.Errno + err *gErrors.Error + }{ + {errno: syscall.EACCES, err: linuxerr.EACCES}, + {errno: syscall.EAGAIN, err: linuxerr.EAGAIN}, + {errno: syscall.EBADF, err: linuxerr.EBADF}, + {errno: syscall.EBUSY, err: linuxerr.EBUSY}, + {errno: syscall.EDOM, err: linuxerr.EDOM}, + {errno: syscall.EEXIST, err: linuxerr.EEXIST}, + {errno: syscall.EFAULT, err: linuxerr.EFAULT}, + {errno: syscall.EFBIG, err: linuxerr.EFBIG}, + {errno: syscall.EINTR, err: linuxerr.EINTR}, + {errno: syscall.EINVAL, err: linuxerr.EINVAL}, + {errno: syscall.EIO, err: linuxerr.EIO}, + {errno: syscall.ENOTDIR, err: linuxerr.ENOTDIR}, + {errno: syscall.ENOTTY, err: linuxerr.ENOTTY}, + {errno: syscall.EPERM, err: linuxerr.EPERM}, + {errno: syscall.EPIPE, err: linuxerr.EPIPE}, + {errno: syscall.ESPIPE, err: linuxerr.ESPIPE}, + {errno: syscall.EWOULDBLOCK, err: linuxerr.EAGAIN}, + } { + t.Run(tc.errno.Error(), func(t *testing.T) { + e := linuxerr.ErrorFromErrno(errno.Errno(tc.errno)) + if e != tc.err { + t.Fatalf("Mismatch errors: want: %+v (%d) got: %+v %d", tc.err, tc.err.Errno(), e, e.Errno()) + } + }) + } +} + +// TestEqualsMethod tests that the Equals method correctly compares syerror, +// unix.Errno and linuxerr. +// TODO (b/34162363): Remove this. 
+func TestEqualsMethod(t *testing.T) { + for _, tc := range []struct { + name string + linuxErr []*gErrors.Error + err []error + equal bool + }{ + { + name: "compare nil", + linuxErr: []*gErrors.Error{nil, linuxerr.NOERROR}, + err: []error{nil, linuxerr.NOERROR, unix.Errno(0)}, + equal: true, + }, + { + name: "linuxerr nil error not", + linuxErr: []*gErrors.Error{nil, linuxerr.NOERROR}, + err: []error{unix.Errno(1), linuxerr.EPERM, syserror.EACCES}, + equal: false, + }, + { + name: "linuxerr not nil error nil", + linuxErr: []*gErrors.Error{linuxerr.ENOENT}, + err: []error{nil, unix.Errno(0), linuxerr.NOERROR}, + equal: false, + }, + { + name: "equal errors", + linuxErr: []*gErrors.Error{linuxerr.ESRCH}, + err: []error{linuxerr.ESRCH, syserror.ESRCH, unix.Errno(linuxerr.ESRCH.Errno())}, + equal: true, + }, + { + name: "unequal errors", + linuxErr: []*gErrors.Error{linuxerr.ENOENT}, + err: []error{linuxerr.ESRCH, syserror.ESRCH, unix.Errno(linuxerr.ESRCH.Errno())}, + equal: false, + }, + { + name: "other error", + linuxErr: []*gErrors.Error{nil, linuxerr.NOERROR, linuxerr.E2BIG, linuxerr.EINVAL}, + err: []error{fs.ErrInvalid, io.EOF}, + equal: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + for _, le := range tc.linuxErr { + for _, e := range tc.err { + if linuxerr.Equals(le, e) != tc.equal { + t.Fatalf("Expected %t from Equals method for linuxerr: %s %T and error: %s %T", tc.equal, le, le, e, e) + } + } + } + }) + } +} diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD index 9730b88c1..c810c7946 100644 --- a/pkg/flipcall/BUILD +++ b/pkg/flipcall/BUILD @@ -10,9 +10,7 @@ go_library( "flipcall_unsafe.go", "futex_linux.go", "io.go", - "packet_window_allocator.go", - "packet_window_mmap_amd64.go", - "packet_window_mmap_arm64.go", + "packet_window.go", ], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go index 8d8309a73..f0e4ff487 100644 --- a/pkg/flipcall/flipcall.go +++ b/pkg/flipcall/flipcall.go @@ -22,6 +22,7 @@ import ( "sync/atomic" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/memutil" ) // An Endpoint provides the ability to synchronously transfer data and control @@ -96,9 +97,9 @@ func (ep *Endpoint) Init(side EndpointSide, pwd PacketWindowDescriptor, opts ... 
if pwd.Length > math.MaxUint32 { return fmt.Errorf("packet window size (%d) exceeds maximum (%d)", pwd.Length, math.MaxUint32) } - m, e := packetWindowMmap(pwd) - if e != 0 { - return fmt.Errorf("failed to mmap packet window: %v", e) + m, err := memutil.MapFile(0, uintptr(pwd.Length), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset)) + if err != nil { + return fmt.Errorf("failed to mmap packet window: %v", err) } ep.packet = m ep.dataCap = uint32(pwd.Length) - uint32(PacketHeaderBytes) diff --git a/pkg/flipcall/packet_window_allocator.go b/pkg/flipcall/packet_window.go index 9122c97b7..9122c97b7 100644 --- a/pkg/flipcall/packet_window_allocator.go +++ b/pkg/flipcall/packet_window.go diff --git a/pkg/iovec/BUILD b/pkg/iovec/BUILD deleted file mode 100644 index f4e9a6af9..000000000 --- a/pkg/iovec/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "iovec", - srcs = ["iovec.go"], - visibility = ["//:sandbox"], - deps = ["@org_golang_x_sys//unix:go_default_library"], -) - -go_test( - name = "iovec_test", - size = "small", - srcs = ["iovec_test.go"], - library = ":iovec", - deps = ["@org_golang_x_sys//unix:go_default_library"], -) diff --git a/pkg/iovec/iovec.go b/pkg/iovec/iovec.go deleted file mode 100644 index a281c05b6..000000000 --- a/pkg/iovec/iovec.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -// Package iovec provides helpers to interact with vectorized I/O on host -// system. -package iovec - -import ( - "golang.org/x/sys/unix" -) - -// MaxIovs is the maximum number of iovecs host platform can accept. -var MaxIovs = 1024 - -// Builder is a builder for slice of unix.Iovec. -type Builder struct { - iovec []unix.Iovec - storage [8]unix.Iovec - - // overflow tracks the last buffer when iovec length is at MaxIovs. - overflow []byte -} - -// Add adds buf to b preparing to be written. Zero-length buf won't be added. -func (b *Builder) Add(buf []byte) { - if len(buf) == 0 { - return - } - if b.iovec == nil { - b.iovec = b.storage[:0] - } - if len(b.iovec) >= MaxIovs { - b.addByAppend(buf) - return - } - - b.iovec = append(b.iovec, unix.Iovec{Base: &buf[0]}) - b.iovec[len(b.iovec)-1].SetLen(len(buf)) - - // Keep the last buf if iovec is at max capacity. We will need to append to it - // for later bufs. - if len(b.iovec) == MaxIovs { - n := len(buf) - b.overflow = buf[:n:n] - } -} - -func (b *Builder) addByAppend(buf []byte) { - b.overflow = append(b.overflow, buf...) - b.iovec[len(b.iovec)-1] = unix.Iovec{Base: &b.overflow[0]} - b.iovec[len(b.iovec)-1].SetLen(len(b.overflow)) -} - -// Build returns the final Iovec slice. The length of returned iovec will not -// excceed MaxIovs. 
-func (b *Builder) Build() []unix.Iovec { - return b.iovec -} diff --git a/pkg/iovec/iovec_test.go b/pkg/iovec/iovec_test.go deleted file mode 100644 index f6deb4208..000000000 --- a/pkg/iovec/iovec_test.go +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -package iovec - -import ( - "bytes" - "fmt" - "testing" - "unsafe" - - "golang.org/x/sys/unix" -) - -func TestBuilderEmpty(t *testing.T) { - var builder Builder - iovecs := builder.Build() - if got, want := len(iovecs), 0; got != want { - t.Errorf("len(iovecs) = %d, want %d", got, want) - } -} - -func TestBuilderBuild(t *testing.T) { - a := []byte{1, 2} - b := []byte{3, 4, 5} - - var builder Builder - builder.Add(a) - builder.Add(b) - builder.Add(nil) // Nil slice won't be added. - builder.Add([]byte{}) // Empty slice won't be added. - iovecs := builder.Build() - - if got, want := len(iovecs), 2; got != want { - t.Fatalf("len(iovecs) = %d, want %d", got, want) - } - for i, data := range [][]byte{a, b} { - if got, want := *iovecs[i].Base, data[0]; got != want { - t.Fatalf("*iovecs[%d].Base = %d, want %d", i, got, want) - } - if got, want := iovecs[i].Len, uint64(len(data)); got != want { - t.Fatalf("iovecs[%d].Len = %d, want %d", i, got, want) - } - } -} - -func TestBuilderBuildMaxIov(t *testing.T) { - for _, test := range []struct { - numIov int - }{ - { - numIov: MaxIovs - 1, - }, - { - numIov: MaxIovs, - }, - { - numIov: MaxIovs + 1, - }, - { - numIov: MaxIovs + 10, - }, - } { - name := fmt.Sprintf("numIov=%v", test.numIov) - t.Run(name, func(t *testing.T) { - var data []byte - var builder Builder - for i := 0; i < test.numIov; i++ { - buf := []byte{byte(i)} - builder.Add(buf) - data = append(data, buf...) - } - iovec := builder.Build() - - // Check the expected length of iovec. - wantNum := test.numIov - if wantNum > MaxIovs { - wantNum = MaxIovs - } - if got, want := len(iovec), wantNum; got != want { - t.Errorf("len(iovec) = %d, want %d", got, want) - } - - // Test a real read-write. 
- var fds [2]int - if err := unix.Pipe(fds[:]); err != nil { - t.Fatalf("Pipe: %v", err) - } - defer unix.Close(fds[0]) - defer unix.Close(fds[1]) - - wrote, _, e := unix.RawSyscall(unix.SYS_WRITEV, uintptr(fds[1]), uintptr(unsafe.Pointer(&iovec[0])), uintptr(len(iovec))) - if int(wrote) != len(data) || e != 0 { - t.Fatalf("writev: %v, %v; want %v, 0", wrote, e, len(data)) - } - - got := make([]byte, len(data)) - if n, err := unix.Read(fds[0], got); n != len(got) || err != nil { - t.Fatalf("read: %v, %v; want %v, nil", n, err, len(got)) - } - - if !bytes.Equal(got, data) { - t.Errorf("read: got data %v, want %v", got, data) - } - }) - } -} diff --git a/pkg/marshal/primitive/BUILD b/pkg/marshal/primitive/BUILD index 190b57c29..6e5ce136d 100644 --- a/pkg/marshal/primitive/BUILD +++ b/pkg/marshal/primitive/BUILD @@ -12,9 +12,7 @@ go_library( "//:sandbox", ], deps = [ - "//pkg/context", "//pkg/hostarch", "//pkg/marshal", - "//pkg/usermem", ], ) diff --git a/pkg/marshal/primitive/primitive.go b/pkg/marshal/primitive/primitive.go index 6f38992b7..1c49cf082 100644 --- a/pkg/marshal/primitive/primitive.go +++ b/pkg/marshal/primitive/primitive.go @@ -19,10 +19,8 @@ package primitive import ( "io" - "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" - "gvisor.dev/gvisor/pkg/usermem" ) // Int8 is a marshal.Marshallable implementation for int8. @@ -400,26 +398,3 @@ func CopyStringOut(cc marshal.CopyContext, addr hostarch.Addr, src string) (int, srcP := ByteSlice(src) return srcP.CopyOut(cc, addr) } - -// IOCopyContext wraps an object implementing hostarch.IO to implement -// marshal.CopyContext. -type IOCopyContext struct { - Ctx context.Context - IO usermem.IO - Opts usermem.IOOpts -} - -// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer. -func (i *IOCopyContext) CopyScratchBuffer(size int) []byte { - return make([]byte, size) -} - -// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. -func (i *IOCopyContext) CopyOutBytes(addr hostarch.Addr, b []byte) (int, error) { - return i.IO.CopyOut(i.Ctx, addr, b, i.Opts) -} - -// CopyInBytes implements marshal.CopyContext.CopyInBytes. -func (i *IOCopyContext) CopyInBytes(addr hostarch.Addr, b []byte) (int, error) { - return i.IO.CopyIn(i.Ctx, addr, b, i.Opts) -} diff --git a/pkg/memutil/BUILD b/pkg/memutil/BUILD index 9d07d98b4..bea595286 100644 --- a/pkg/memutil/BUILD +++ b/pkg/memutil/BUILD @@ -4,7 +4,11 @@ package(licenses = ["notice"]) go_library( name = "memutil", - srcs = ["memutil_unsafe.go"], + srcs = [ + "memfd_linux_unsafe.go", + "memutil.go", + "mmap.go", + ], visibility = ["//visibility:public"], deps = ["@org_golang_x_sys//unix:go_default_library"], ) diff --git a/pkg/memutil/memutil_unsafe.go b/pkg/memutil/memfd_linux_unsafe.go index 6676d1ce3..504382213 100644 --- a/pkg/memutil/memutil_unsafe.go +++ b/pkg/memutil/memfd_linux_unsafe.go @@ -14,7 +14,6 @@ // +build linux -// Package memutil provides a wrapper for the memfd_create() system call. package memutil import ( diff --git a/pkg/tcpip/time.s b/pkg/memutil/memutil.go index fb37360ac..3185882fd 100644 --- a/pkg/tcpip/time.s +++ b/pkg/memutil/memutil.go @@ -12,4 +12,5 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Empty assembly file so empty func definitions work. +// Package memutil provides utilities for working with shared memory files. 
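The flipcall change above drops its private packetWindowMmap helpers in favour of the shared memutil.MapFile wrapper defined in pkg/memutil/mmap.go just below. A hedged usage sketch that maps an anonymous memfd; unix.MemfdCreate and the 4 KiB size are illustrative choices, not part of this change:

```go
package main

import (
	"fmt"

	"golang.org/x/sys/unix"
	"gvisor.dev/gvisor/pkg/memutil"
)

func main() {
	// Create a 4 KiB anonymous shared memory file to map.
	fd, err := unix.MemfdCreate("example", 0)
	if err != nil {
		panic(err)
	}
	defer unix.Close(fd)
	const length = 4096
	if err := unix.Ftruncate(fd, length); err != nil {
		panic(err)
	}

	// MapFile is a thin wrapper over mmap(2): the caller passes the usual six
	// arguments and gets back the mapping address or a unix.Errno.
	addr, err := memutil.MapFile(0, length, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED, uintptr(fd), 0)
	if err != nil {
		panic(err)
	}
	defer unix.RawSyscall(unix.SYS_MUNMAP, addr, length, 0)

	fmt.Printf("mapped %d bytes at %#x\n", length, addr)
}
```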
+package memutil diff --git a/pkg/flipcall/packet_window_mmap_amd64.go b/pkg/memutil/mmap.go index ced587a2a..7c939293f 100644 --- a/pkg/flipcall/packet_window_mmap_amd64.go +++ b/pkg/memutil/mmap.go @@ -12,12 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -package flipcall +package memutil -import "golang.org/x/sys/unix" +import ( + "golang.org/x/sys/unix" +) -// Return a memory mapping of the pwd in memory that can be shared outside the sandbox. -func packetWindowMmap(pwd PacketWindowDescriptor) (uintptr, unix.Errno) { - m, _, err := unix.RawSyscall6(unix.SYS_MMAP, 0, uintptr(pwd.Length), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset)) - return m, err +// MapFile returns a memory mapping configured by the given options as per +// mmap(2). +func MapFile(addr, len, prot, flags, fd, offset uintptr) (uintptr, error) { + m, _, e := unix.RawSyscall6(unix.SYS_MMAP, addr, len, prot, flags, fd, offset) + if e != 0 { + return 0, e + } + return m, nil } diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go index ac7868ad9..0b961d3d9 100644 --- a/pkg/merkletree/merkletree.go +++ b/pkg/merkletree/merkletree.go @@ -151,21 +151,21 @@ type VerityDescriptor struct { Mode uint32 UID uint32 GID uint32 - Children map[string]struct{} + Children []string SymlinkTarget string RootHash []byte } -func (d *VerityDescriptor) String() string { +func (d *VerityDescriptor) encode() []byte { b := new(bytes.Buffer) e := gob.NewEncoder(b) - e.Encode(d.Children) - return fmt.Sprintf("Name: %s, Size: %d, Mode: %d, UID: %d, GID: %d, Children: %v, Symlink: %s, RootHash: %v", d.Name, d.FileSize, d.Mode, d.UID, d.GID, b.Bytes(), d.SymlinkTarget, d.RootHash) + e.Encode(d) + return b.Bytes() } // verify generates a hash from d, and compares it with expected. func (d *VerityDescriptor) verify(expected []byte, hashAlgorithms int) error { - h, err := hashData([]byte(d.String()), hashAlgorithms) + h, err := hashData(d.encode(), hashAlgorithms) if err != nil { return err } @@ -210,7 +210,7 @@ type GenerateParams struct { GID uint32 // Children is a map of children names for a directory. It should be // empty for a regular file. - Children map[string]struct{} + Children []string // SymlinkTarget is the target path of a symlink file, or "" if the file is not a symlink. SymlinkTarget string // HashAlgorithms is the algorithms used to hash data. @@ -242,7 +242,7 @@ func Generate(params *GenerateParams) ([]byte, error) { // If file is a symlink do not generate root hash for file content. if params.SymlinkTarget != "" { - return hashData([]byte(descriptor.String()), params.HashAlgorithms) + return hashData(descriptor.encode(), params.HashAlgorithms) } layout, err := InitLayout(params.Size, params.HashAlgorithms, params.DataAndTreeInSameFile) @@ -315,7 +315,7 @@ func Generate(params *GenerateParams) ([]byte, error) { numBlocks = (numBlocks + layout.hashesPerBlock() - 1) / layout.hashesPerBlock() } descriptor.RootHash = root - return hashData([]byte(descriptor.String()), params.HashAlgorithms) + return hashData(descriptor.encode(), params.HashAlgorithms) } // VerifyParams contains the params used to verify a portion of a file against @@ -339,7 +339,7 @@ type VerifyParams struct { GID uint32 // Children is a map of children names for a directory. It should be // empty for a regular file. 
- Children map[string]struct{} + Children []string // SymlinkTarget is the target path of a symlink file, or "" if the file is not a symlink. SymlinkTarget string // HashAlgorithms is the algorithms used to hash data. diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go index 5d6f8df1b..1447fd139 100644 --- a/pkg/merkletree/merkletree_test.go +++ b/pkg/merkletree/merkletree_test.go @@ -206,112 +206,112 @@ func TestGenerate(t *testing.T) { data: bytes.Repeat([]byte{0}, hostarch.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, - expectedHash: []byte{9, 115, 238, 230, 38, 140, 195, 70, 207, 144, 202, 118, 23, 113, 32, 129, 226, 239, 177, 69, 161, 26, 14, 113, 16, 37, 30, 96, 19, 148, 132, 27}, + expectedHash: []byte{78, 38, 225, 107, 61, 246, 26, 6, 71, 163, 254, 97, 112, 200, 87, 232, 190, 87, 231, 160, 119, 124, 61, 229, 49, 126, 90, 223, 134, 51, 77, 182}, }, { name: "OnePageZeroesSHA256SameFile", data: bytes.Repeat([]byte{0}, hostarch.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, - expectedHash: []byte{9, 115, 238, 230, 38, 140, 195, 70, 207, 144, 202, 118, 23, 113, 32, 129, 226, 239, 177, 69, 161, 26, 14, 113, 16, 37, 30, 96, 19, 148, 132, 27}, + expectedHash: []byte{78, 38, 225, 107, 61, 246, 26, 6, 71, 163, 254, 97, 112, 200, 87, 232, 190, 87, 231, 160, 119, 124, 61, 229, 49, 126, 90, 223, 134, 51, 77, 182}, }, { name: "OnePageZeroesSHA512SeparateFile", data: bytes.Repeat([]byte{0}, hostarch.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: false, - expectedHash: []byte{127, 8, 95, 11, 83, 101, 51, 39, 170, 235, 39, 43, 135, 243, 145, 118, 148, 58, 27, 155, 182, 205, 44, 47, 5, 223, 215, 17, 35, 16, 43, 104, 43, 11, 8, 88, 171, 7, 249, 243, 14, 62, 126, 218, 23, 159, 237, 237, 42, 226, 39, 25, 87, 48, 253, 191, 116, 213, 37, 3, 187, 152, 154, 14}, + expectedHash: []byte{221, 45, 182, 132, 61, 212, 227, 145, 150, 131, 98, 221, 195, 5, 89, 21, 188, 36, 250, 101, 85, 78, 197, 253, 193, 23, 74, 219, 28, 108, 77, 47, 65, 79, 123, 144, 50, 245, 109, 72, 71, 80, 24, 77, 158, 95, 242, 185, 109, 163, 105, 183, 67, 106, 55, 194, 223, 46, 12, 242, 165, 203, 172, 254}, }, { name: "OnePageZeroesSHA512SameFile", data: bytes.Repeat([]byte{0}, hostarch.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: true, - expectedHash: []byte{127, 8, 95, 11, 83, 101, 51, 39, 170, 235, 39, 43, 135, 243, 145, 118, 148, 58, 27, 155, 182, 205, 44, 47, 5, 223, 215, 17, 35, 16, 43, 104, 43, 11, 8, 88, 171, 7, 249, 243, 14, 62, 126, 218, 23, 159, 237, 237, 42, 226, 39, 25, 87, 48, 253, 191, 116, 213, 37, 3, 187, 152, 154, 14}, + expectedHash: []byte{221, 45, 182, 132, 61, 212, 227, 145, 150, 131, 98, 221, 195, 5, 89, 21, 188, 36, 250, 101, 85, 78, 197, 253, 193, 23, 74, 219, 28, 108, 77, 47, 65, 79, 123, 144, 50, 245, 109, 72, 71, 80, 24, 77, 158, 95, 242, 185, 109, 163, 105, 183, 67, 106, 55, 194, 223, 46, 12, 242, 165, 203, 172, 254}, }, { name: "MultiplePageZeroesSHA256SeparateFile", data: bytes.Repeat([]byte{0}, 128*hostarch.PageSize+1), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, - expectedHash: []byte{247, 158, 42, 215, 180, 106, 0, 28, 77, 64, 132, 162, 74, 65, 250, 161, 243, 66, 129, 44, 197, 8, 145, 14, 94, 206, 156, 184, 145, 145, 20, 185}, + expectedHash: []byte{131, 122, 73, 143, 4, 202, 193, 156, 218, 169, 196, 223, 70, 100, 117, 191, 241, 113, 134, 11, 229, 231, 105, 157, 156, 0, 66, 213, 122, 
145, 174, 8}, }, { name: "MultiplePageZeroesSHA256SameFile", data: bytes.Repeat([]byte{0}, 128*hostarch.PageSize+1), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, - expectedHash: []byte{247, 158, 42, 215, 180, 106, 0, 28, 77, 64, 132, 162, 74, 65, 250, 161, 243, 66, 129, 44, 197, 8, 145, 14, 94, 206, 156, 184, 145, 145, 20, 185}, + expectedHash: []byte{131, 122, 73, 143, 4, 202, 193, 156, 218, 169, 196, 223, 70, 100, 117, 191, 241, 113, 134, 11, 229, 231, 105, 157, 156, 0, 66, 213, 122, 145, 174, 8}, }, { name: "MultiplePageZeroesSHA512SeparateFile", data: bytes.Repeat([]byte{0}, 128*hostarch.PageSize+1), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: false, - expectedHash: []byte{100, 121, 14, 30, 104, 200, 142, 182, 190, 78, 23, 68, 157, 174, 23, 75, 174, 250, 250, 25, 66, 45, 235, 103, 129, 49, 78, 127, 173, 154, 121, 35, 37, 115, 60, 217, 26, 205, 253, 253, 236, 145, 107, 109, 232, 19, 72, 92, 4, 191, 181, 205, 191, 57, 234, 177, 144, 235, 143, 30, 15, 197, 109, 81}, + expectedHash: []byte{211, 48, 232, 110, 240, 51, 99, 241, 123, 138, 42, 76, 94, 86, 59, 200, 3, 246, 137, 148, 189, 226, 111, 103, 146, 29, 12, 218, 40, 182, 33, 99, 193, 163, 238, 26, 184, 13, 165, 187, 68, 173, 139, 9, 208, 59, 0, 192, 180, 50, 221, 35, 43, 119, 194, 16, 64, 84, 116, 63, 158, 195, 194, 226}, }, { name: "MultiplePageZeroesSHA512SameFile", data: bytes.Repeat([]byte{0}, 128*hostarch.PageSize+1), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: true, - expectedHash: []byte{100, 121, 14, 30, 104, 200, 142, 182, 190, 78, 23, 68, 157, 174, 23, 75, 174, 250, 250, 25, 66, 45, 235, 103, 129, 49, 78, 127, 173, 154, 121, 35, 37, 115, 60, 217, 26, 205, 253, 253, 236, 145, 107, 109, 232, 19, 72, 92, 4, 191, 181, 205, 191, 57, 234, 177, 144, 235, 143, 30, 15, 197, 109, 81}, + expectedHash: []byte{211, 48, 232, 110, 240, 51, 99, 241, 123, 138, 42, 76, 94, 86, 59, 200, 3, 246, 137, 148, 189, 226, 111, 103, 146, 29, 12, 218, 40, 182, 33, 99, 193, 163, 238, 26, 184, 13, 165, 187, 68, 173, 139, 9, 208, 59, 0, 192, 180, 50, 221, 35, 43, 119, 194, 16, 64, 84, 116, 63, 158, 195, 194, 226}, }, { name: "SingleASHA256SeparateFile", data: []byte{'a'}, hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, - expectedHash: []byte{90, 124, 194, 100, 206, 242, 75, 152, 47, 249, 16, 27, 136, 161, 223, 228, 121, 241, 126, 158, 126, 122, 100, 120, 117, 15, 81, 78, 201, 133, 119, 111}, + expectedHash: []byte{26, 47, 238, 138, 235, 244, 140, 231, 129, 240, 155, 252, 219, 44, 46, 72, 57, 249, 139, 88, 132, 238, 86, 108, 181, 115, 96, 72, 99, 210, 134, 47}, }, { name: "SingleASHA256SameFile", data: []byte{'a'}, hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, - expectedHash: []byte{90, 124, 194, 100, 206, 242, 75, 152, 47, 249, 16, 27, 136, 161, 223, 228, 121, 241, 126, 158, 126, 122, 100, 120, 117, 15, 81, 78, 201, 133, 119, 111}, + expectedHash: []byte{26, 47, 238, 138, 235, 244, 140, 231, 129, 240, 155, 252, 219, 44, 46, 72, 57, 249, 139, 88, 132, 238, 86, 108, 181, 115, 96, 72, 99, 210, 134, 47}, }, { name: "SingleASHA512SeparateFile", data: []byte{'a'}, hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: false, - expectedHash: []byte{24, 10, 13, 25, 113, 62, 169, 99, 151, 70, 166, 113, 81, 81, 163, 85, 5, 25, 29, 15, 46, 37, 104, 120, 142, 218, 52, 178, 187, 83, 30, 166, 101, 87, 70, 196, 188, 61, 123, 20, 13, 254, 126, 52, 212, 111, 75, 203, 33, 233, 233, 47, 181, 161, 43, 193, 131, 
41, 99, 33, 164, 73, 89, 152}, + expectedHash: []byte{44, 30, 224, 12, 102, 119, 163, 171, 119, 175, 212, 121, 231, 188, 125, 171, 79, 28, 144, 234, 75, 122, 44, 75, 15, 101, 173, 92, 233, 109, 234, 60, 173, 148, 125, 85, 94, 234, 95, 91, 16, 196, 88, 175, 23, 129, 226, 110, 24, 238, 5, 49, 186, 128, 72, 188, 193, 180, 207, 193, 203, 119, 40, 191}, }, { name: "SingleASHA512SameFile", data: []byte{'a'}, hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: true, - expectedHash: []byte{24, 10, 13, 25, 113, 62, 169, 99, 151, 70, 166, 113, 81, 81, 163, 85, 5, 25, 29, 15, 46, 37, 104, 120, 142, 218, 52, 178, 187, 83, 30, 166, 101, 87, 70, 196, 188, 61, 123, 20, 13, 254, 126, 52, 212, 111, 75, 203, 33, 233, 233, 47, 181, 161, 43, 193, 131, 41, 99, 33, 164, 73, 89, 152}, + expectedHash: []byte{44, 30, 224, 12, 102, 119, 163, 171, 119, 175, 212, 121, 231, 188, 125, 171, 79, 28, 144, 234, 75, 122, 44, 75, 15, 101, 173, 92, 233, 109, 234, 60, 173, 148, 125, 85, 94, 234, 95, 91, 16, 196, 88, 175, 23, 129, 226, 110, 24, 238, 5, 49, 186, 128, 72, 188, 193, 180, 207, 193, 203, 119, 40, 191}, }, { name: "OnePageASHA256SeparateFile", data: bytes.Repeat([]byte{'a'}, hostarch.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, - expectedHash: []byte{132, 54, 112, 142, 156, 19, 50, 140, 138, 240, 192, 154, 100, 120, 242, 69, 64, 217, 62, 166, 127, 88, 23, 197, 100, 66, 255, 215, 214, 229, 54, 1}, + expectedHash: []byte{166, 254, 83, 46, 241, 111, 18, 47, 79, 6, 181, 197, 176, 143, 211, 204, 53, 5, 245, 134, 172, 95, 97, 131, 236, 132, 197, 138, 123, 78, 43, 13}, }, { name: "OnePageASHA256SameFile", data: bytes.Repeat([]byte{'a'}, hostarch.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, - expectedHash: []byte{132, 54, 112, 142, 156, 19, 50, 140, 138, 240, 192, 154, 100, 120, 242, 69, 64, 217, 62, 166, 127, 88, 23, 197, 100, 66, 255, 215, 214, 229, 54, 1}, + expectedHash: []byte{166, 254, 83, 46, 241, 111, 18, 47, 79, 6, 181, 197, 176, 143, 211, 204, 53, 5, 245, 134, 172, 95, 97, 131, 236, 132, 197, 138, 123, 78, 43, 13}, }, { name: "OnePageASHA512SeparateFile", data: bytes.Repeat([]byte{'a'}, hostarch.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: false, - expectedHash: []byte{165, 46, 176, 116, 47, 209, 101, 193, 64, 185, 30, 9, 52, 22, 24, 154, 135, 220, 232, 168, 215, 45, 222, 226, 207, 104, 160, 10, 156, 98, 245, 250, 76, 21, 68, 204, 65, 118, 69, 52, 210, 155, 36, 109, 233, 103, 1, 40, 218, 89, 125, 38, 247, 194, 2, 225, 119, 155, 65, 99, 182, 111, 110, 145}, + expectedHash: []byte{23, 69, 6, 79, 39, 232, 90, 246, 62, 55, 4, 229, 47, 36, 230, 24, 233, 47, 55, 36, 26, 139, 196, 78, 242, 12, 194, 77, 109, 81, 151, 188, 63, 201, 127, 235, 81, 214, 91, 200, 19, 232, 240, 14, 197, 1, 99, 224, 18, 213, 203, 242, 44, 102, 25, 62, 90, 189, 106, 107, 129, 61, 115, 39}, }, { name: "OnePageASHA512SameFile", data: bytes.Repeat([]byte{'a'}, hostarch.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: true, - expectedHash: []byte{165, 46, 176, 116, 47, 209, 101, 193, 64, 185, 30, 9, 52, 22, 24, 154, 135, 220, 232, 168, 215, 45, 222, 226, 207, 104, 160, 10, 156, 98, 245, 250, 76, 21, 68, 204, 65, 118, 69, 52, 210, 155, 36, 109, 233, 103, 1, 40, 218, 89, 125, 38, 247, 194, 2, 225, 119, 155, 65, 99, 182, 111, 110, 145}, + expectedHash: []byte{23, 69, 6, 79, 39, 232, 90, 246, 62, 55, 4, 229, 47, 36, 230, 24, 233, 47, 55, 36, 26, 139, 196, 78, 242, 12, 194, 77, 
109, 81, 151, 188, 63, 201, 127, 235, 81, 214, 91, 200, 19, 232, 240, 14, 197, 1, 99, 224, 18, 213, 203, 242, 44, 102, 25, 62, 90, 189, 106, 107, 129, 61, 115, 39}, }, } @@ -324,7 +324,7 @@ func TestGenerate(t *testing.T) { Mode: defaultMode, UID: defaultUID, GID: defaultGID, - Children: make(map[string]struct{}), + Children: []string{}, HashAlgorithms: tc.hashAlgorithms, TreeReader: &tree, TreeWriter: &tree, @@ -366,7 +366,7 @@ func prepareVerify(t *testing.T, dataSize int64, hashAlgorithm int, dataAndTreeI Mode: defaultMode, UID: defaultUID, GID: defaultGID, - Children: make(map[string]struct{}), + Children: []string{}, HashAlgorithms: hashAlgorithm, TreeReader: &tree, TreeWriter: &tree, @@ -398,7 +398,7 @@ func prepareVerify(t *testing.T, dataSize int64, hashAlgorithm int, dataAndTreeI Mode: defaultMode, UID: defaultUID, GID: defaultGID, - Children: make(map[string]struct{}), + Children: []string{}, HashAlgorithms: hashAlgorithm, ReadOffset: verifyStart, ReadSize: verifySize, @@ -627,7 +627,7 @@ func TestVerifyModifiedChildren(t *testing.T) { t.Run(tc.name, func(t *testing.T) { var buf bytes.Buffer _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf) - params.Children["abc"] = struct{}{} + params.Children = append(params.Children, "abc") if _, err := Verify(¶ms); errors.Is(err, nil) { t.Errorf("Verification succeeded when expected to fail") } diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go index fdeee3a5f..4829ae7ce 100644 --- a/pkg/metric/metric.go +++ b/pkg/metric/metric.go @@ -36,17 +36,22 @@ var ( // new metric after initialization. ErrInitializationDone = errors.New("metric cannot be created after initialization is complete") - // createdSentryMetrics indicates that the sentry metrics are created. - createdSentryMetrics = false - // WeirdnessMetric is a metric with fields created to track the number // of weird occurrences such as time fallback, partial_result, vsyscall // count, watchdog startup timeouts and stuck tasks. - WeirdnessMetric *Uint64Metric + WeirdnessMetric = MustCreateNewUint64Metric("/weirdness", true /* sync */, "Increment for weird occurrences of problems such as time fallback, partial result, vsyscalls invoked in the sandbox, watchdog startup timeouts and stuck tasks.", + Field{ + name: "weirdness_type", + allowedValues: []string{"time_fallback", "partial_result", "vsyscall_count", "watchdog_stuck_startup", "watchdog_stuck_tasks"}, + }) // SuspiciousOperationsMetric is a metric with fields created to detect // operations such as opening an executable file to write from a gofer. - SuspiciousOperationsMetric *Uint64Metric + SuspiciousOperationsMetric = MustCreateNewUint64Metric("/suspicious_operations", true /* sync */, "Increment for suspicious operations such as opening an executable file to write from a gofer.", + Field{ + name: "operation_type", + allowedValues: []string{"opened_write_execute_file"}, + }) ) // Uint64Metric encapsulates a uint64 that represents some kind of metric to be @@ -84,17 +89,21 @@ var ( // Precondition: // * All metrics are registered. // * Initialize/Disable has not been called. 
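In the metric hunk above, WeirdnessMetric and SuspiciousOperationsMetric become package-level metrics created with a single allowed field each, and the hunks that follow change Initialize and Disable to return errors instead of panicking. A hedged sketch of the caller side; Increment's variadic field-value form is assumed from the existing metric API rather than shown in this diff:

    // Hedged sketch: emit the registration event, surfacing failure as an
    // error instead of a panic, then bump a fielded metric with one of its
    // allowed field values.
    if err := metric.Initialize(); err != nil {
            log.Warningf("metric initialization failed: %v", err)
    }
    metric.WeirdnessMetric.Increment("time_fallback")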
-func Initialize() { +func Initialize() error { if initialized { - panic("Initialize/Disable called more than once") + return errors.New("metric.Initialize called after metric.Initialize or metric.Disable") } - initialized = true m := pb.MetricRegistration{} for _, v := range allMetrics.m { m.Metrics = append(m.Metrics, v.metadata) } - eventchannel.Emit(&m) + if err := eventchannel.Emit(&m); err != nil { + return fmt.Errorf("unable to emit metric initialize event: %w", err) + } + + initialized = true + return nil } // Disable sends an empty metric registration event over the event channel, @@ -103,16 +112,18 @@ func Initialize() { // Precondition: // * All metrics are registered. // * Initialize/Disable has not been called. -func Disable() { +func Disable() error { if initialized { - panic("Initialize/Disable called more than once") + return errors.New("metric.Disable called after metric.Initialize or metric.Disable") } - initialized = true m := pb.MetricRegistration{} if err := eventchannel.Emit(&m); err != nil { - panic("unable to emit metric disable event: " + err.Error()) + return fmt.Errorf("unable to emit metric disable event: %w", err) } + + initialized = true + return nil } type customUint64Metric struct { @@ -165,8 +176,8 @@ func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.Met } // Metrics can exist without fields. - if len(fields) > 1 { - panic("Sentry metrics support at most one field") + if l := len(fields); l > 1 { + return fmt.Errorf("%d fields provided, must be <= 1", l) } for _, field := range fields { @@ -182,7 +193,7 @@ func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.Met // without fields and panics if it returns an error. func MustRegisterCustomUint64Metric(name string, cumulative, sync bool, description string, value func(...string) uint64, fields ...Field) { if err := RegisterCustomUint64Metric(name, cumulative, sync, pb.MetricMetadata_UNITS_NONE, description, value, fields...); err != nil { - panic(fmt.Sprintf("Unable to register metric %q: %v", name, err)) + panic(fmt.Sprintf("Unable to register metric %q: %s", name, err)) } } @@ -209,7 +220,7 @@ func NewUint64Metric(name string, sync bool, units pb.MetricMetadata_Units, desc func MustCreateNewUint64Metric(name string, sync bool, description string, fields ...Field) *Uint64Metric { m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NONE, description, fields...) 
if err != nil { - panic(fmt.Sprintf("Unable to create metric %q: %v", name, err)) + panic(fmt.Sprintf("Unable to create metric %q: %s", name, err)) } return m } @@ -219,7 +230,7 @@ func MustCreateNewUint64Metric(name string, sync bool, description string, field func MustCreateNewUint64NanosecondsMetric(name string, sync bool, description string) *Uint64Metric { m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NANOSECONDS, description) if err != nil { - panic(fmt.Sprintf("Unable to create metric %q: %v", name, err)) + panic(fmt.Sprintf("Unable to create metric %q: %s", name, err)) } return m } @@ -354,7 +365,7 @@ func EmitMetricUpdate() { m.Metrics = append(m.Metrics, &pb.MetricValue{ Name: k, - Value: &pb.MetricValue_Uint64Value{t}, + Value: &pb.MetricValue_Uint64Value{Uint64Value: t}, }) case map[string]uint64: for fieldValue, metricValue := range t { @@ -369,7 +380,7 @@ func EmitMetricUpdate() { m.Metrics = append(m.Metrics, &pb.MetricValue{ Name: k, FieldValues: []string{fieldValue}, - Value: &pb.MetricValue_Uint64Value{metricValue}, + Value: &pb.MetricValue_Uint64Value{Uint64Value: metricValue}, }) } } @@ -390,26 +401,7 @@ func EmitMetricUpdate() { } } - eventchannel.Emit(&m) -} - -// CreateSentryMetrics creates the sentry metrics during kernel initialization. -func CreateSentryMetrics() { - if createdSentryMetrics { - return + if err := eventchannel.Emit(&m); err != nil { + log.Warningf("Unable to emit metrics: %s", err) } - - createdSentryMetrics = true - - WeirdnessMetric = MustCreateNewUint64Metric("/weirdness", true /* sync */, "Increment for weird occurrences of problems such as time fallback, partial result, vsyscalls invoked in the sandbox, watchdog startup timeouts and stuck tasks.", - Field{ - name: "weirdness_type", - allowedValues: []string{"time_fallback", "partial_result", "vsyscall_count", "watchdog_stuck_startup", "watchdog_stuck_tasks"}, - }) - - SuspiciousOperationsMetric = MustCreateNewUint64Metric("/suspicious_operations", true /* sync */, "Increment for suspicious operations such as opening an executable file to write from a gofer.", - Field{ - name: "operation_type", - allowedValues: []string{"opened_write_execute_file"}, - }) } diff --git a/pkg/metric/metric_test.go b/pkg/metric/metric_test.go index c71dfd460..1b4a9e73a 100644 --- a/pkg/metric/metric_test.go +++ b/pkg/metric/metric_test.go @@ -48,6 +48,8 @@ func (s *sliceEmitter) Reset() { var emitter sliceEmitter func init() { + reset() + eventchannel.AddEmitter(&emitter) } @@ -77,7 +79,9 @@ func TestInitialize(t *testing.T) { t.Fatalf("NewUint64Metric got err %v want nil", err) } - Initialize() + if err := Initialize(); err != nil { + t.Fatalf("Initialize(): %s", err) + } if len(emitter) != 1 { t.Fatalf("Initialize emitted %d events want 1", len(emitter)) @@ -149,7 +153,9 @@ func TestDisable(t *testing.T) { t.Fatalf("NewUint64Metric got err %v want nil", err) } - Disable() + if err := Disable(); err != nil { + t.Fatalf("Disable(): %s", err) + } if len(emitter) != 1 { t.Fatalf("Initialize emitted %d events want 1", len(emitter)) @@ -178,7 +184,9 @@ func TestEmitMetricUpdate(t *testing.T) { t.Fatalf("NewUint64Metric got err %v want nil", err) } - Initialize() + if err := Initialize(); err != nil { + t.Fatalf("Initialize(): %s", err) + } // Don't care about the registration metrics. 
emitter.Reset() @@ -270,7 +278,9 @@ func TestEmitMetricUpdateWithFields(t *testing.T) { t.Fatalf("NewUint64Metric got err %v want nil", err) } - Initialize() + if err := Initialize(); err != nil { + t.Fatalf("Initialize(): %s", err) + } // Don't care about the registration metrics. emitter.Reset() diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD index b2291ef97..2b22b2203 100644 --- a/pkg/p9/BUILD +++ b/pkg/p9/BUILD @@ -22,6 +22,9 @@ go_library( "version.go", ], deps = [ + "//pkg/abi/linux/errno", + "//pkg/errors", + "//pkg/errors/linuxerr", "//pkg/fd", "//pkg/fdchannel", "//pkg/flipcall", diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go index 758e11b13..161b451cc 100644 --- a/pkg/p9/handlers.go +++ b/pkg/p9/handlers.go @@ -23,6 +23,9 @@ import ( "sync/atomic" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux/errno" + "gvisor.dev/gvisor/pkg/errors" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" ) @@ -62,6 +65,45 @@ func newErr(err error) *Rlerror { return &Rlerror{Error: uint32(ExtractErrno(err))} } +// ExtractLinuxerrErrno extracts a *errors.Error from a error, best effort. +// TODO(b/34162363): Merge this with ExtractErrno. +func ExtractLinuxerrErrno(err error) *errors.Error { + switch err { + case os.ErrNotExist: + return linuxerr.ENOENT + case os.ErrExist: + return linuxerr.EEXIST + case os.ErrPermission: + return linuxerr.EACCES + case os.ErrInvalid: + return linuxerr.EINVAL + } + + // Attempt to unwrap. + switch e := err.(type) { + case *errors.Error: + return e + case unix.Errno: + return linuxerr.ErrorFromErrno(errno.Errno(e)) + case *os.PathError: + return ExtractLinuxerrErrno(e.Err) + case *os.SyscallError: + return ExtractLinuxerrErrno(e.Err) + case *os.LinkError: + return ExtractLinuxerrErrno(e.Err) + } + + // Default case. + log.Warningf("unknown error: %v", err) + return linuxerr.EIO +} + +// newErrFromLinuxerr returns an Rlerror from the linuxerr list. +// TODO(b/34162363): Merge this with newErr. +func newErrFromLinuxerr(err error) *Rlerror { + return &Rlerror{Error: uint32(ExtractLinuxerrErrno(err).Errno())} +} + // handler is implemented for server-handled messages. // // See server.go for call information. diff --git a/pkg/p9/server.go b/pkg/p9/server.go index ff1172ed6..241ab44ef 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -19,7 +19,8 @@ import ( "runtime/debug" "sync/atomic" - "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux/errno" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/fdchannel" "gvisor.dev/gvisor/pkg/flipcall" @@ -483,7 +484,7 @@ func (cs *connState) lookupChannel(id uint32) *channel { func (cs *connState) handle(m message) (r message) { if !cs.reqGate.Enter() { // connState.stop() has been called; the connection is shutting down. - r = newErr(unix.ECONNRESET) + r = newErrFromLinuxerr(linuxerr.ECONNRESET) return } defer func() { @@ -498,15 +499,23 @@ func (cs *connState) handle(m message) (r message) { // Wrap in an EFAULT error; we don't really have a // better way to describe this kind of error. It will // usually manifest as a result of the test framework. - r = newErr(unix.EFAULT) + r = newErrFromLinuxerr(linuxerr.EFAULT) } }() if handler, ok := m.(handler); ok { // Call the message handler. r = handler.handle(cs) + // TODO(b/34162363):This is only here to make sure the server works with + // only linuxerr Errors, as the handlers work with both client and server. 
+ // It will be removed a followup, when all the unix.Errno errors are + // replaced with linuxerr. + if rlError, ok := r.(*Rlerror); ok { + e := linuxerr.ErrorFromErrno(errno.Errno(rlError.Error)) + r = newErrFromLinuxerr(e) + } } else { // Produce an ENOSYS error. - r = newErr(unix.ENOSYS) + r = newErrFromLinuxerr(linuxerr.ENOSYS) } return } @@ -553,7 +562,7 @@ func (cs *connState) handleRequest() bool { // If it's not a connection error, but some other protocol error, // we can send a response immediately. cs.sendMu.Lock() - err := send(cs.conn, tag, newErr(err)) + err := send(cs.conn, tag, newErrFromLinuxerr(err)) cs.sendMu.Unlock() if err != nil { log.Debugf("p9.send: %v", err) diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go index 4aecb8007..1bbcae045 100644 --- a/pkg/refs/refcounter.go +++ b/pkg/refs/refcounter.go @@ -261,8 +261,8 @@ func (l *LeakMode) Get() interface{} { } // String implements flag.Value. -func (l *LeakMode) String() string { - switch *l { +func (l LeakMode) String() string { + switch l { case UninitializedLeakChecking: return "uninitialized" case NoLeakChecking: @@ -272,7 +272,7 @@ func (l *LeakMode) String() string { case LeaksLogTraces: return "log-traces" } - panic(fmt.Sprintf("invalid ref leak mode %d", *l)) + panic(fmt.Sprintf("invalid ref leak mode %d", l)) } // leakMode stores the current mode for the reference leak checker. diff --git a/pkg/ring0/kernel_amd64.go b/pkg/ring0/kernel_amd64.go index 41dfd0bf9..f63af8b76 100644 --- a/pkg/ring0/kernel_amd64.go +++ b/pkg/ring0/kernel_amd64.go @@ -254,6 +254,8 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { return } +var sentryXCR0 = xgetbv(0) + // start is the CPU entrypoint. // // This is called from the Start asm stub (see entry_amd64.go); on return the @@ -265,24 +267,10 @@ func start(c *CPU) { WriteGS(kernelAddr(c.kernelEntry)) WriteFS(uintptr(c.registers.Fs_base)) - // Initialize floating point. - // - // Note that on skylake, the valid XCR0 mask reported seems to be 0xff. - // This breaks down as: - // - // bit0 - x87 - // bit1 - SSE - // bit2 - AVX - // bit3-4 - MPX - // bit5-7 - AVX512 - // - // For some reason, enabled MPX & AVX512 on platforms that report them - // seems to be cause a general protection fault. (Maybe there are some - // virtualization issues and these aren't exported to the guest cpuid.) - // This needs further investigation, but we can limit the floating - // point operations to x87, SSE & AVX for now. fninit() - xsetbv(0, validXCR0Mask&0x7) + // Need to sync XCR0 with the host, because xsave and xrstor can be + // called from different contexts. + xsetbv(0, sentryXCR0) // Set the syscall target. wrmsr(_MSR_LSTAR, kernelFunc(sysenter)) diff --git a/pkg/safecopy/BUILD b/pkg/safecopy/BUILD index b77c40279..db5787302 100644 --- a/pkg/safecopy/BUILD +++ b/pkg/safecopy/BUILD @@ -18,6 +18,7 @@ go_library( ], visibility = ["//:sandbox"], deps = [ + "//pkg/abi/linux", "//pkg/syserror", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go index efbc2ddc1..2365b2c0d 100644 --- a/pkg/safecopy/safecopy_unsafe.go +++ b/pkg/safecopy/safecopy_unsafe.go @@ -20,6 +20,7 @@ import ( "unsafe" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" ) // maxRegisterSize is the maximum register size used in memcpy and memclr. 
It @@ -342,12 +343,7 @@ func errorFromFaultSignal(addr uintptr, sig int32) error { // handler however, and if this is function is being used externally then the // same courtesy is expected. func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) error { - var sa struct { - handler uintptr - flags uint64 - restorer uintptr - mask uint64 - } + var sa linux.SigAction const maskLen = 8 // Get the existing signal handler information, and save the current @@ -358,14 +354,14 @@ func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) e } // Fail if there isn't a previous handler. - if sa.handler == 0 { + if sa.Handler == 0 { return fmt.Errorf("previous handler for signal %x isn't set", sig) } - *previous = sa.handler + *previous = uintptr(sa.Handler) // Install our own handler. - sa.handler = handler + sa.Handler = uint64(handler) if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 { return e } diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go index daea51c4d..8ffa1db37 100644 --- a/pkg/seccomp/seccomp.go +++ b/pkg/seccomp/seccomp.go @@ -36,14 +36,10 @@ const ( // Install generates BPF code based on the set of syscalls provided. It only // allows syscalls that conform to the specification. Syscalls that violate the -// specification will trigger RET_KILL_PROCESS, except for the cases below. -// -// RET_TRAP is used in violations, instead of RET_KILL_PROCESS, in the -// following cases: -// 1. Kernel doesn't support RET_KILL_PROCESS: RET_KILL_THREAD only kills the -// offending thread and often keeps the sentry hanging. -// 2. Debug: RET_TRAP generates a panic followed by a stack trace which is -// much easier to debug then RET_KILL_PROCESS which can't be caught. +// specification will trigger RET_KILL_PROCESS. If RET_KILL_PROCESS is not +// supported, violations will trigger RET_TRAP instead. RET_KILL_THREAD is not +// used because it only kills the offending thread and often keeps the sentry +// hanging. // // Be aware that RET_TRAP sends SIGSYS to the process and it may be ignored, // making it possible for the process to continue running after a violation. diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index c9c52530d..068a0c8d9 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -14,12 +14,8 @@ go_library( "arch_x86.go", "arch_x86_impl.go", "auxv.go", - "signal.go", - "signal_act.go", "signal_amd64.go", "signal_arm64.go", - "signal_info.go", - "signal_stack.go", "stack.go", "stack_unsafe.go", "syscalls_amd64.go", @@ -32,6 +28,7 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/cpuid", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/log", "//pkg/marshal", diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go index 290863ee6..c9393b091 100644 --- a/pkg/sentry/arch/arch.go +++ b/pkg/sentry/arch/arch.go @@ -134,21 +134,13 @@ type Context interface { // RegisterMap returns a map of all registers. RegisterMap() (map[string]uintptr, error) - // NewSignalAct returns a new object that is equivalent to struct sigaction - // in the guest architecture. - NewSignalAct() NativeSignalAct - - // NewSignalStack returns a new object that is equivalent to stack_t in the - // guest architecture. - NewSignalStack() NativeSignalStack - // SignalSetup modifies the context in preparation for handling the // given signal. // // st is the stack where the signal handler frame should be // constructed. 
// - // act is the SignalAct that specifies how this signal is being + // act is the SigAction that specifies how this signal is being // handled. // // info is the SignalInfo of the signal being delivered. @@ -157,7 +149,7 @@ type Context interface { // stack is not going to be used). // // sigset is the signal mask before entering the signal handler. - SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error + SignalSetup(st *Stack, act *linux.SigAction, info *linux.SignalInfo, alt *linux.SignalStack, sigset linux.SignalSet) error // SignalRestore restores context after returning from a signal // handler. @@ -167,7 +159,7 @@ type Context interface { // rt is true if SignalRestore is being entered from rt_sigreturn and // false if SignalRestore is being entered from sigreturn. // SignalRestore returns the thread's new signal mask. - SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) + SignalRestore(st *Stack, rt bool) (linux.SignalSet, linux.SignalStack, error) // CPUIDEmulate emulates a CPUID instruction according to current register state. CPUIDEmulate(l log.Logger) diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go index 08789f517..7def71fef 100644 --- a/pkg/sentry/arch/arch_aarch64.go +++ b/pkg/sentry/arch/arch_aarch64.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/arch/fpu" rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" @@ -237,7 +238,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, } return s.PtraceGetRegs(dst) default: - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } } @@ -250,7 +251,7 @@ func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, } return s.PtraceSetRegs(src) default: - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } } diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go index e8e52d3a8..d13e12f8c 100644 --- a/pkg/sentry/arch/arch_x86.go +++ b/pkg/sentry/arch/arch_x86.go @@ -23,6 +23,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/arch/fpu" rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" @@ -361,7 +362,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, case _NT_X86_XSTATE: return s.fpState.PtraceGetXstateRegs(dst, maxlen, s.FeatureSet) default: - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } } @@ -378,7 +379,7 @@ func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, case _NT_X86_XSTATE: return s.fpState.PtraceSetXstateRegs(src, maxlen, s.FeatureSet) default: - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } } diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go deleted file mode 100644 index 67d7edf68..000000000 --- a/pkg/sentry/arch/signal.go +++ /dev/null @@ -1,276 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package arch - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/hostarch" -) - -// SignalAct represents the action that should be taken when a signal is -// delivered, and is equivalent to struct sigaction. -// -// +marshal -// +stateify savable -type SignalAct struct { - Handler uint64 - Flags uint64 - Restorer uint64 // Only used on amd64. - Mask linux.SignalSet -} - -// SerializeFrom implements NativeSignalAct.SerializeFrom. -func (s *SignalAct) SerializeFrom(other *SignalAct) { - *s = *other -} - -// DeserializeTo implements NativeSignalAct.DeserializeTo. -func (s *SignalAct) DeserializeTo(other *SignalAct) { - *other = *s -} - -// SignalStack represents information about a user stack, and is equivalent to -// stack_t. -// -// +marshal -// +stateify savable -type SignalStack struct { - Addr uint64 - Flags uint32 - _ uint32 - Size uint64 -} - -// SerializeFrom implements NativeSignalStack.SerializeFrom. -func (s *SignalStack) SerializeFrom(other *SignalStack) { - *s = *other -} - -// DeserializeTo implements NativeSignalStack.DeserializeTo. -func (s *SignalStack) DeserializeTo(other *SignalStack) { - *other = *s -} - -// SignalInfo represents information about a signal being delivered, and is -// equivalent to struct siginfo in linux kernel(linux/include/uapi/asm-generic/siginfo.h). -// -// +marshal -// +stateify savable -type SignalInfo struct { - Signo int32 // Signal number - Errno int32 // Errno value - Code int32 // Signal code - _ uint32 - - // struct siginfo::_sifields is a union. In SignalInfo, fields in the union - // are accessed through methods. - // - // For reference, here is the definition of _sifields: (_sigfault._trapno, - // which does not exist on x86, omitted for clarity) - // - // union { - // int _pad[SI_PAD_SIZE]; - // - // /* kill() */ - // struct { - // __kernel_pid_t _pid; /* sender's pid */ - // __ARCH_SI_UID_T _uid; /* sender's uid */ - // } _kill; - // - // /* POSIX.1b timers */ - // struct { - // __kernel_timer_t _tid; /* timer id */ - // int _overrun; /* overrun count */ - // char _pad[sizeof( __ARCH_SI_UID_T) - sizeof(int)]; - // sigval_t _sigval; /* same as below */ - // int _sys_private; /* not to be passed to user */ - // } _timer; - // - // /* POSIX.1b signals */ - // struct { - // __kernel_pid_t _pid; /* sender's pid */ - // __ARCH_SI_UID_T _uid; /* sender's uid */ - // sigval_t _sigval; - // } _rt; - // - // /* SIGCHLD */ - // struct { - // __kernel_pid_t _pid; /* which child */ - // __ARCH_SI_UID_T _uid; /* sender's uid */ - // int _status; /* exit code */ - // __ARCH_SI_CLOCK_T _utime; - // __ARCH_SI_CLOCK_T _stime; - // } _sigchld; - // - // /* SIGILL, SIGFPE, SIGSEGV, SIGBUS */ - // struct { - // void *_addr; /* faulting insn/memory ref. 
*/ - // short _addr_lsb; /* LSB of the reported address */ - // } _sigfault; - // - // /* SIGPOLL */ - // struct { - // __ARCH_SI_BAND_T _band; /* POLL_IN, POLL_OUT, POLL_MSG */ - // int _fd; - // } _sigpoll; - // - // /* SIGSYS */ - // struct { - // void *_call_addr; /* calling user insn */ - // int _syscall; /* triggering system call number */ - // unsigned int _arch; /* AUDIT_ARCH_* of syscall */ - // } _sigsys; - // } _sifields; - // - // _sifields is padded so that the size of siginfo is SI_MAX_SIZE = 128 - // bytes. - Fields [128 - 16]byte -} - -// FixSignalCodeForUser fixes up si_code. -// -// The si_code we get from Linux may contain the kernel-specific code in the -// top 16 bits if it's positive (e.g., from ptrace). Linux's -// copy_siginfo_to_user does -// err |= __put_user((short)from->si_code, &to->si_code); -// to mask out those bits and we need to do the same. -func (s *SignalInfo) FixSignalCodeForUser() { - if s.Code > 0 { - s.Code &= 0x0000ffff - } -} - -// PID returns the si_pid field. -func (s *SignalInfo) PID() int32 { - return int32(hostarch.ByteOrder.Uint32(s.Fields[0:4])) -} - -// SetPID mutates the si_pid field. -func (s *SignalInfo) SetPID(val int32) { - hostarch.ByteOrder.PutUint32(s.Fields[0:4], uint32(val)) -} - -// UID returns the si_uid field. -func (s *SignalInfo) UID() int32 { - return int32(hostarch.ByteOrder.Uint32(s.Fields[4:8])) -} - -// SetUID mutates the si_uid field. -func (s *SignalInfo) SetUID(val int32) { - hostarch.ByteOrder.PutUint32(s.Fields[4:8], uint32(val)) -} - -// Sigval returns the sigval field, which is aliased to both si_int and si_ptr. -func (s *SignalInfo) Sigval() uint64 { - return hostarch.ByteOrder.Uint64(s.Fields[8:16]) -} - -// SetSigval mutates the sigval field. -func (s *SignalInfo) SetSigval(val uint64) { - hostarch.ByteOrder.PutUint64(s.Fields[8:16], val) -} - -// TimerID returns the si_timerid field. -func (s *SignalInfo) TimerID() linux.TimerID { - return linux.TimerID(hostarch.ByteOrder.Uint32(s.Fields[0:4])) -} - -// SetTimerID sets the si_timerid field. -func (s *SignalInfo) SetTimerID(val linux.TimerID) { - hostarch.ByteOrder.PutUint32(s.Fields[0:4], uint32(val)) -} - -// Overrun returns the si_overrun field. -func (s *SignalInfo) Overrun() int32 { - return int32(hostarch.ByteOrder.Uint32(s.Fields[4:8])) -} - -// SetOverrun sets the si_overrun field. -func (s *SignalInfo) SetOverrun(val int32) { - hostarch.ByteOrder.PutUint32(s.Fields[4:8], uint32(val)) -} - -// Addr returns the si_addr field. -func (s *SignalInfo) Addr() uint64 { - return hostarch.ByteOrder.Uint64(s.Fields[0:8]) -} - -// SetAddr sets the si_addr field. -func (s *SignalInfo) SetAddr(val uint64) { - hostarch.ByteOrder.PutUint64(s.Fields[0:8], val) -} - -// Status returns the si_status field. -func (s *SignalInfo) Status() int32 { - return int32(hostarch.ByteOrder.Uint32(s.Fields[8:12])) -} - -// SetStatus mutates the si_status field. -func (s *SignalInfo) SetStatus(val int32) { - hostarch.ByteOrder.PutUint32(s.Fields[8:12], uint32(val)) -} - -// CallAddr returns the si_call_addr field. -func (s *SignalInfo) CallAddr() uint64 { - return hostarch.ByteOrder.Uint64(s.Fields[0:8]) -} - -// SetCallAddr mutates the si_call_addr field. -func (s *SignalInfo) SetCallAddr(val uint64) { - hostarch.ByteOrder.PutUint64(s.Fields[0:8], val) -} - -// Syscall returns the si_syscall field. -func (s *SignalInfo) Syscall() int32 { - return int32(hostarch.ByteOrder.Uint32(s.Fields[8:12])) -} - -// SetSyscall mutates the si_syscall field. 
-func (s *SignalInfo) SetSyscall(val int32) { - hostarch.ByteOrder.PutUint32(s.Fields[8:12], uint32(val)) -} - -// Arch returns the si_arch field. -func (s *SignalInfo) Arch() uint32 { - return hostarch.ByteOrder.Uint32(s.Fields[12:16]) -} - -// SetArch mutates the si_arch field. -func (s *SignalInfo) SetArch(val uint32) { - hostarch.ByteOrder.PutUint32(s.Fields[12:16], val) -} - -// Band returns the si_band field. -func (s *SignalInfo) Band() int64 { - return int64(hostarch.ByteOrder.Uint64(s.Fields[0:8])) -} - -// SetBand mutates the si_band field. -func (s *SignalInfo) SetBand(val int64) { - // Note: this assumes the platform uses `long` as `__ARCH_SI_BAND_T`. - // On some platforms, which gVisor doesn't support, `__ARCH_SI_BAND_T` is - // `int`. See siginfo.h. - hostarch.ByteOrder.PutUint64(s.Fields[0:8], uint64(val)) -} - -// FD returns the si_fd field. -func (s *SignalInfo) FD() uint32 { - return hostarch.ByteOrder.Uint32(s.Fields[8:12]) -} - -// SetFD mutates the si_fd field. -func (s *SignalInfo) SetFD(val uint32) { - hostarch.ByteOrder.PutUint32(s.Fields[8:12], val) -} diff --git a/pkg/sentry/arch/signal_act.go b/pkg/sentry/arch/signal_act.go deleted file mode 100644 index d3e2324a8..000000000 --- a/pkg/sentry/arch/signal_act.go +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package arch - -import "gvisor.dev/gvisor/pkg/marshal" - -// Special values for SignalAct.Handler. -const ( - // SignalActDefault is SIG_DFL and specifies that the default behavior for - // a signal should be taken. - SignalActDefault = 0 - - // SignalActIgnore is SIG_IGN and specifies that a signal should be - // ignored. - SignalActIgnore = 1 -) - -// Available signal flags. -const ( - SignalFlagNoCldStop = 0x00000001 - SignalFlagNoCldWait = 0x00000002 - SignalFlagSigInfo = 0x00000004 - SignalFlagRestorer = 0x04000000 - SignalFlagOnStack = 0x08000000 - SignalFlagRestart = 0x10000000 - SignalFlagInterrupt = 0x20000000 - SignalFlagNoDefer = 0x40000000 - SignalFlagResetHandler = 0x80000000 -) - -// IsSigInfo returns true iff this handle expects siginfo. -func (s SignalAct) IsSigInfo() bool { - return s.Flags&SignalFlagSigInfo != 0 -} - -// IsNoDefer returns true iff this SignalAct has the NoDefer flag set. -func (s SignalAct) IsNoDefer() bool { - return s.Flags&SignalFlagNoDefer != 0 -} - -// IsRestart returns true iff this SignalAct has the Restart flag set. -func (s SignalAct) IsRestart() bool { - return s.Flags&SignalFlagRestart != 0 -} - -// IsResetHandler returns true iff this SignalAct has the ResetHandler flag set. -func (s SignalAct) IsResetHandler() bool { - return s.Flags&SignalFlagResetHandler != 0 -} - -// IsOnStack returns true iff this SignalAct has the OnStack flag set. -func (s SignalAct) IsOnStack() bool { - return s.Flags&SignalFlagOnStack != 0 -} - -// HasRestorer returns true iff this SignalAct has the Restorer flag set. 
-func (s SignalAct) HasRestorer() bool { - return s.Flags&SignalFlagRestorer != 0 -} - -// NativeSignalAct is a type that is equivalent to struct sigaction in the -// guest architecture. -type NativeSignalAct interface { - marshal.Marshallable - - // SerializeFrom copies the data in the host SignalAct s into this object. - SerializeFrom(s *SignalAct) - - // DeserializeTo copies the data in this object into the host SignalAct s. - DeserializeTo(s *SignalAct) -} diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go index 082ed92b1..58e28dbba 100644 --- a/pkg/sentry/arch/signal_amd64.go +++ b/pkg/sentry/arch/signal_amd64.go @@ -76,21 +76,11 @@ const ( type UContext64 struct { Flags uint64 Link uint64 - Stack SignalStack + Stack linux.SignalStack MContext SignalContext64 Sigset linux.SignalSet } -// NewSignalAct implements Context.NewSignalAct. -func (c *context64) NewSignalAct() NativeSignalAct { - return &SignalAct{} -} - -// NewSignalStack implements Context.NewSignalStack. -func (c *context64) NewSignalStack() NativeSignalStack { - return &SignalStack{} -} - // From Linux 'arch/x86/include/uapi/asm/sigcontext.h' the following is the // size of the magic cookie at the end of the xsave frame. // @@ -110,7 +100,7 @@ func (c *context64) fpuFrameSize() (size int, useXsave bool) { // SignalSetup implements Context.SignalSetup. (Compare to Linux's // arch/x86/kernel/signal.c:__setup_rt_frame().) -func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error { +func (c *context64) SignalSetup(st *Stack, act *linux.SigAction, info *linux.SignalInfo, alt *linux.SignalStack, sigset linux.SignalSet) error { sp := st.Bottom // "The 128-byte area beyond the location pointed to by %rsp is considered @@ -187,7 +177,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt // Prior to proceeding, figure out if the frame will exhaust the range // for the signal stack. This is not allowed, and should immediately // force signal delivery (reverting to the default handler). - if act.IsOnStack() && alt.IsEnabled() && !alt.Contains(frameBottom) { + if act.Flags&linux.SA_ONSTACK != 0 && alt.IsEnabled() && !alt.Contains(frameBottom) { return unix.EFAULT } @@ -203,7 +193,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt return err } ucAddr := st.Bottom - if act.HasRestorer() { + if act.Flags&linux.SA_RESTORER != 0 { // Push the restorer return address. // Note that this doesn't need to be popped. if _, err := primitive.CopyUint64Out(st, StackBottomMagic, act.Restorer); err != nil { @@ -237,15 +227,15 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt // SignalRestore implements Context.SignalRestore. (Compare to Linux's // arch/x86/kernel/signal.c:sys_rt_sigreturn().) -func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) { +func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, linux.SignalStack, error) { // Copy out the stack frame. var uc UContext64 if _, err := uc.CopyIn(st, StackBottomMagic); err != nil { - return 0, SignalStack{}, err + return 0, linux.SignalStack{}, err } - var info SignalInfo + var info linux.SignalInfo if _, err := info.CopyIn(st, StackBottomMagic); err != nil { - return 0, SignalStack{}, err + return 0, linux.SignalStack{}, err } // Restore registers. 
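With SignalSetup and SignalRestore now taking *linux.SigAction and linux.SignalStack directly, the predicates from the deleted arch/signal_act.go are replaced by inline SA_* flag tests; the arm64 file below makes the same substitution. A hedged sketch of the idiom (helper names are illustrative, the constants come from pkg/abi/linux):

    // Hedged sketch: inline flag tests replace the removed SignalAct helpers.
    func onAltStack(act *linux.SigAction) bool { // was act.IsOnStack()
            return act.Flags&linux.SA_ONSTACK != 0
    }

    func hasRestorer(act *linux.SigAction) bool { // was act.HasRestorer()
            return act.Flags&linux.SA_RESTORER != 0
    }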
diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go index da71fb873..80df90076 100644 --- a/pkg/sentry/arch/signal_arm64.go +++ b/pkg/sentry/arch/signal_arm64.go @@ -61,7 +61,7 @@ type FpsimdContext struct { type UContext64 struct { Flags uint64 Link uint64 - Stack SignalStack + Stack linux.SignalStack Sigset linux.SignalSet // glibc uses a 1024-bit sigset_t _pad [120]byte // (1024 - 64) / 8 = 120 @@ -71,18 +71,8 @@ type UContext64 struct { MContext SignalContext64 } -// NewSignalAct implements Context.NewSignalAct. -func (c *context64) NewSignalAct() NativeSignalAct { - return &SignalAct{} -} - -// NewSignalStack implements Context.NewSignalStack. -func (c *context64) NewSignalStack() NativeSignalStack { - return &SignalStack{} -} - // SignalSetup implements Context.SignalSetup. -func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error { +func (c *context64) SignalSetup(st *Stack, act *linux.SigAction, info *linux.SignalInfo, alt *linux.SignalStack, sigset linux.SignalSet) error { sp := st.Bottom // Construct the UContext64 now since we need its size. @@ -114,7 +104,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt // Prior to proceeding, figure out if the frame will exhaust the range // for the signal stack. This is not allowed, and should immediately // force signal delivery (reverting to the default handler). - if act.IsOnStack() && alt.IsEnabled() && !alt.Contains(frameBottom) { + if act.Flags&linux.SA_ONSTACK != 0 && alt.IsEnabled() && !alt.Contains(frameBottom) { return unix.EFAULT } @@ -137,7 +127,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt c.Regs.Regs[0] = uint64(info.Signo) c.Regs.Regs[1] = uint64(infoAddr) c.Regs.Regs[2] = uint64(ucAddr) - c.Regs.Regs[30] = uint64(act.Restorer) + c.Regs.Regs[30] = act.Restorer // Save the thread's floating point state. c.sigFPState = append(c.sigFPState, c.fpState) @@ -147,15 +137,15 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt } // SignalRestore implements Context.SignalRestore. -func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) { +func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, linux.SignalStack, error) { // Copy out the stack frame. var uc UContext64 if _, err := uc.CopyIn(st, StackBottomMagic); err != nil { - return 0, SignalStack{}, err + return 0, linux.SignalStack{}, err } - var info SignalInfo + var info linux.SignalInfo if _, err := info.CopyIn(st, StackBottomMagic); err != nil { - return 0, SignalStack{}, err + return 0, linux.SignalStack{}, err } // Restore registers. diff --git a/pkg/sentry/arch/signal_info.go b/pkg/sentry/arch/signal_info.go deleted file mode 100644 index f93ee8b46..000000000 --- a/pkg/sentry/arch/signal_info.go +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package arch - -// Possible values for SignalInfo.Code. These values originate from the Linux -// kernel's include/uapi/asm-generic/siginfo.h. -const ( - // SignalInfoUser (properly SI_USER) indicates that a signal was sent from - // a kill() or raise() syscall. - SignalInfoUser = 0 - - // SignalInfoKernel (properly SI_KERNEL) indicates that the signal was sent - // by the kernel. - SignalInfoKernel = 0x80 - - // SignalInfoTimer (properly SI_TIMER) indicates that the signal was sent - // by an expired timer. - SignalInfoTimer = -2 - - // SignalInfoTkill (properly SI_TKILL) indicates that the signal was sent - // from a tkill() or tgkill() syscall. - SignalInfoTkill = -6 - - // CLD_* codes are only meaningful for SIGCHLD. - - // CLD_EXITED indicates that a task exited. - CLD_EXITED = 1 - - // CLD_KILLED indicates that a task was killed by a signal. - CLD_KILLED = 2 - - // CLD_DUMPED indicates that a task was killed by a signal and then dumped - // core. - CLD_DUMPED = 3 - - // CLD_TRAPPED indicates that a task was stopped by ptrace. - CLD_TRAPPED = 4 - - // CLD_STOPPED indicates that a thread group completed a group stop. - CLD_STOPPED = 5 - - // CLD_CONTINUED indicates that a group-stopped thread group was continued. - CLD_CONTINUED = 6 - - // SYS_* codes are only meaningful for SIGSYS. - - // SYS_SECCOMP indicates that a signal originates from seccomp. - SYS_SECCOMP = 1 - - // TRAP_* codes are only meaningful for SIGTRAP. - - // TRAP_BRKPT indicates a breakpoint trap. - TRAP_BRKPT = 1 -) diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go deleted file mode 100644 index c732c7503..000000000 --- a/pkg/sentry/arch/signal_stack.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build 386 amd64 arm64 - -package arch - -import ( - "gvisor.dev/gvisor/pkg/hostarch" - "gvisor.dev/gvisor/pkg/marshal" -) - -const ( - // SignalStackFlagOnStack is possible set on return from getaltstack, - // in order to indicate that the thread is currently on the alt stack. - SignalStackFlagOnStack = 1 - - // SignalStackFlagDisable is a flag to indicate the stack is disabled. - SignalStackFlagDisable = 2 -) - -// IsEnabled returns true iff this signal stack is marked as enabled. -func (s SignalStack) IsEnabled() bool { - return s.Flags&SignalStackFlagDisable == 0 -} - -// Top returns the stack's top address. -func (s SignalStack) Top() hostarch.Addr { - return hostarch.Addr(s.Addr + s.Size) -} - -// SetOnStack marks this signal stack as in use. -// -// Note that there is no corresponding ClearOnStack, and that this should only -// be called on copies that are serialized to userspace. -func (s *SignalStack) SetOnStack() { - s.Flags |= SignalStackFlagOnStack -} - -// Contains checks if the stack pointer is within this stack. 
-func (s *SignalStack) Contains(sp hostarch.Addr) bool { - return hostarch.Addr(s.Addr) < sp && sp <= hostarch.Addr(s.Addr+s.Size) -} - -// NativeSignalStack is a type that is equivalent to stack_t in the guest -// architecture. -type NativeSignalStack interface { - marshal.Marshallable - - // SerializeFrom copies the data in the host SignalStack s into this - // object. - SerializeFrom(s *SignalStack) - - // DeserializeTo copies the data in this object into the host SignalStack - // s. - DeserializeTo(s *SignalStack) -} diff --git a/pkg/sentry/arch/stack.go b/pkg/sentry/arch/stack.go index 65a794c7c..85e3515af 100644 --- a/pkg/sentry/arch/stack.go +++ b/pkg/sentry/arch/stack.go @@ -45,7 +45,7 @@ type Stack struct { } // scratchBufLen is the default length of Stack.scratchBuf. The -// largest structs the stack regularly serializes are arch.SignalInfo +// largest structs the stack regularly serializes are linux.SignalInfo // and arch.UContext64. We'll set the default size as the larger of // the two, arch.UContext64. var scratchBufLen = (*UContext64)(nil).SizeBytes() diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go index 367849e75..221e98a01 100644 --- a/pkg/sentry/control/proc.go +++ b/pkg/sentry/control/proc.go @@ -99,6 +99,9 @@ type ExecArgs struct { // PIDNamespace is the pid namespace for the process being executed. PIDNamespace *kernel.PIDNamespace + + // Limits is the limit set for the process being executed. + Limits *limits.LimitSet } // String prints the arguments as a string. @@ -151,6 +154,10 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI if pidns == nil { pidns = proc.Kernel.RootPIDNamespace() } + limitSet := args.Limits + if limitSet == nil { + limitSet = limits.NewLimitSet() + } initArgs := kernel.CreateProcessArgs{ Filename: args.Filename, Argv: args.Argv, @@ -161,7 +168,7 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI Credentials: creds, FDTable: fdTable, Umask: 0022, - Limits: limits.NewLimitSet(), + Limits: limitSet, MaxSymlinkTraversals: linux.MaxSymlinkTraversals, UTSNamespace: proc.Kernel.RootUTSNamespace(), IPCNamespace: proc.Kernel.RootIPCNamespace(), diff --git a/pkg/sentry/devices/tundev/BUILD b/pkg/sentry/devices/tundev/BUILD index 8b38d574d..37229e7ba 100644 --- a/pkg/sentry/devices/tundev/BUILD +++ b/pkg/sentry/devices/tundev/BUILD @@ -9,6 +9,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/sentry/arch", "//pkg/sentry/fsimpl/devtmpfs", diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go index a12eeb8e7..4ef91a600 100644 --- a/pkg/sentry/devices/tundev/tundev.go +++ b/pkg/sentry/devices/tundev/tundev.go @@ -18,6 +18,7 @@ package tundev import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" @@ -81,7 +82,7 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg } stack, ok := t.NetworkContext().(*netstack.Stack) if !ok { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } var req linux.IFReq diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index 0dc100f9b..74adbfa55 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -48,6 +48,7 @@ go_library( "//pkg/abi/linux", "//pkg/amutex", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/log", 
"//pkg/p9", @@ -110,6 +111,7 @@ go_test( deps = [ ":fs", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/ramfs", "//pkg/sentry/fs/tmpfs", diff --git a/pkg/sentry/fs/attr.go b/pkg/sentry/fs/attr.go index b90f7c1be..4c99944e7 100644 --- a/pkg/sentry/fs/attr.go +++ b/pkg/sentry/fs/attr.go @@ -478,6 +478,20 @@ func (f FilePermissions) AnyRead() bool { return f.User.Read || f.Group.Read || f.Other.Read } +// HasSetUIDOrGID returns true if either the setuid or setgid bit is set. +func (f FilePermissions) HasSetUIDOrGID() bool { + return f.SetUID || f.SetGID +} + +// DropSetUIDAndMaybeGID turns off setuid, and turns off setgid if f allows +// group execution. +func (f *FilePermissions) DropSetUIDAndMaybeGID() { + f.SetUID = false + if f.Group.Execute { + f.SetGID = false + } +} + // FileOwner represents ownership of a file. // // +stateify savable diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go index 5aa668873..a8591052c 100644 --- a/pkg/sentry/fs/copy_up.go +++ b/pkg/sentry/fs/copy_up.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -161,7 +162,7 @@ func doCopyUp(ctx context.Context, d *Dirent) error { // then try to take copyMu for writing here, we'd deadlock. t := d.Inode.overlay.lower.StableAttr.Type if t != RegularFile && t != Directory && t != Symlink { - return syserror.EINVAL + return linuxerr.EINVAL } // Wait to get exclusive access to the upper Inode. @@ -410,7 +411,7 @@ func copyAttributesLocked(ctx context.Context, upper *Inode, lower *Inode) error return err } lowerXattr, err := lower.ListXattr(ctx, linux.XATTR_SIZE_MAX) - if err != nil && err != syserror.EOPNOTSUPP { + if err != nil && !linuxerr.Equals(linuxerr.EOPNOTSUPP, err) { return err } diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD index 23a3a9a2d..e28a8961b 100644 --- a/pkg/sentry/fs/dev/BUILD +++ b/pkg/sentry/fs/dev/BUILD @@ -18,6 +18,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/rand", "//pkg/safemem", diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go index e84ba7a5d..c62effd52 100644 --- a/pkg/sentry/fs/dev/dev.go +++ b/pkg/sentry/fs/dev/dev.go @@ -16,6 +16,7 @@ package dev import ( + "fmt" "math" "gvisor.dev/gvisor/pkg/context" @@ -90,6 +91,11 @@ func newSymlink(ctx context.Context, target string, msrc *fs.MountSource) *fs.In // New returns the root node of a device filesystem. func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode { + shm, err := tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc, nil /* parent */) + if err != nil { + panic(fmt.Sprintf("tmpfs.NewDir failed: %v", err)) + } + contents := map[string]*fs.Inode{ "fd": newSymlink(ctx, "/proc/self/fd", msrc), "stdin": newSymlink(ctx, "/proc/self/fd/0", msrc), @@ -108,7 +114,7 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode { "random": newMemDevice(ctx, newRandomDevice(ctx, fs.RootOwner, 0444), msrc, randomDevMinor), "urandom": newMemDevice(ctx, newRandomDevice(ctx, fs.RootOwner, 0444), msrc, urandomDevMinor), - "shm": tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc), + "shm": shm, // A devpts is typically mounted at /dev/pts to provide // pseudoterminal support. 
Place an empty directory there for diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go index 77e8d222a..5674978bd 100644 --- a/pkg/sentry/fs/dev/net_tun.go +++ b/pkg/sentry/fs/dev/net_tun.go @@ -17,6 +17,7 @@ package dev import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -102,7 +103,7 @@ func (n *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io user } stack, ok := t.NetworkContext().(*netstack.Stack) if !ok { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } var req linux.IFReq diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index 9d5d40954..e21c9d78e 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -22,6 +22,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" @@ -963,7 +964,7 @@ func (d *Dirent) mount(ctx context.Context, inode *Inode) (newChild *Dirent, err // // See Linux equivalent in fs/namespace.c:do_add_mount. if IsSymlink(inode.StableAttr) { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } // Dirent that'll replace d. @@ -1439,7 +1440,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string // replaced is the dirent that is being overwritten by rename. replaced, err := newParent.walk(ctx, root, newName, false /* may unlock */) if err != nil { - if err != syserror.ENOENT { + if !linuxerr.Equals(linuxerr.ENOENT, err) { return err } diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD index 2120f2bad..1bd2055d0 100644 --- a/pkg/sentry/fs/fdpipe/BUILD +++ b/pkg/sentry/fs/fdpipe/BUILD @@ -13,6 +13,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fd", "//pkg/fdnotifier", "//pkg/log", @@ -38,6 +39,7 @@ go_test( library = ":fdpipe", deps = [ "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fd", "//pkg/fdnotifier", "//pkg/hostarch", diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go index 757b7d511..f8a29816b 100644 --- a/pkg/sentry/fs/fdpipe/pipe.go +++ b/pkg/sentry/fs/fdpipe/pipe.go @@ -20,6 +20,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/log" @@ -158,7 +159,7 @@ func (p *pipeOperations) Write(ctx context.Context, file *fs.File, src usermem.I // isBlockError unwraps os errors and checks if they are caused by EAGAIN or // EWOULDBLOCK. This is so they can be transformed into syserror.ErrWouldBlock. 
func isBlockError(err error) bool { - if err == syserror.EAGAIN || err == syserror.EWOULDBLOCK { + if linuxerr.Equals(linuxerr.EAGAIN, err) || linuxerr.Equals(linuxerr.EWOULDBLOCK, err) { return true } if pe, ok := err.(*os.PathError); ok { diff --git a/pkg/sentry/fs/fdpipe/pipe_test.go b/pkg/sentry/fs/fdpipe/pipe_test.go index ab0e9dac7..6ea49cbb7 100644 --- a/pkg/sentry/fs/fdpipe/pipe_test.go +++ b/pkg/sentry/fs/fdpipe/pipe_test.go @@ -21,14 +21,14 @@ import ( "testing" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" - - "gvisor.dev/gvisor/pkg/hostarch" ) func singlePipeFD() (int, error) { @@ -214,7 +214,7 @@ func TestPipeRequest(t *testing.T) { { desc: "Fsync on pipe returns EINVAL", context: &Fsync{}, - err: unix.EINVAL, + err: linuxerr.EINVAL, }, { desc: "Seek on pipe returns ESPIPE", diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go index 696613f3a..7e2f107e0 100644 --- a/pkg/sentry/fs/file_overlay.go +++ b/pkg/sentry/fs/file_overlay.go @@ -18,6 +18,7 @@ import ( "io" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -417,7 +418,7 @@ func (f *overlayFileOperations) FifoSize(ctx context.Context, overlayFile *File) err = f.onTop(ctx, overlayFile, func(file *File, ops FileOperations) error { sz, ok := ops.(FifoSizer) if !ok { - return syserror.EINVAL + return linuxerr.EINVAL } rv, err = sz.FifoSize(ctx, file) return err @@ -432,11 +433,11 @@ func (f *overlayFileOperations) SetFifoSize(size int64) (rv int64, err error) { if f.upper == nil { // Named pipes cannot be copied up and changes to the lower are prohibited. - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } sz, ok := f.upper.FileOperations.(FifoSizer) if !ok { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } return sz.SetFifoSize(size) } diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index 6469cc3a9..ebc90b41f 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -76,6 +76,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/log", "//pkg/safemem", diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go index dc9efa5df..c3525ba8e 100644 --- a/pkg/sentry/fs/fsutil/file.go +++ b/pkg/sentry/fs/fsutil/file.go @@ -18,6 +18,7 @@ import ( "io" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -63,12 +64,12 @@ func SeekWithDirCursor(ctx context.Context, file *fs.File, whence fs.SeekWhence, switch inode.StableAttr.Type { case fs.RegularFile, fs.SpecialFile, fs.BlockDevice: if offset < 0 { - return current, syserror.EINVAL + return current, linuxerr.EINVAL } return offset, nil case fs.Directory, fs.SpecialDirectory: if offset != 0 { - return current, syserror.EINVAL + return current, linuxerr.EINVAL } // SEEK_SET to 0 moves the directory "cursor" to the beginning. 
if dirCursor != nil { @@ -76,22 +77,22 @@ func SeekWithDirCursor(ctx context.Context, file *fs.File, whence fs.SeekWhence, } return 0, nil default: - return current, syserror.EINVAL + return current, linuxerr.EINVAL } case fs.SeekCurrent: switch inode.StableAttr.Type { case fs.RegularFile, fs.SpecialFile, fs.BlockDevice: if current+offset < 0 { - return current, syserror.EINVAL + return current, linuxerr.EINVAL } return current + offset, nil case fs.Directory, fs.SpecialDirectory: if offset != 0 { - return current, syserror.EINVAL + return current, linuxerr.EINVAL } return current, nil default: - return current, syserror.EINVAL + return current, linuxerr.EINVAL } case fs.SeekEnd: switch inode.StableAttr.Type { @@ -103,14 +104,14 @@ func SeekWithDirCursor(ctx context.Context, file *fs.File, whence fs.SeekWhence, } sz := uattr.Size if sz+offset < 0 { - return current, syserror.EINVAL + return current, linuxerr.EINVAL } return sz + offset, nil // FIXME(b/34778850): This is not universally correct. // Remove SpecialDirectory. case fs.SpecialDirectory: if offset != 0 { - return current, syserror.EINVAL + return current, linuxerr.EINVAL } // SEEK_END to 0 moves the directory "cursor" to the end. // @@ -121,12 +122,12 @@ func SeekWithDirCursor(ctx context.Context, file *fs.File, whence fs.SeekWhence, // futile (EOF will always be the result). return fs.FileMaxOffset, nil default: - return current, syserror.EINVAL + return current, linuxerr.EINVAL } } // Not a valid seek request. - return current, syserror.EINVAL + return current, linuxerr.EINVAL } // FileGenericSeek implements fs.FileOperations.Seek for files that use a @@ -152,7 +153,7 @@ type FileNoSeek struct{} // Seek implements fs.FileOperations.Seek. func (FileNoSeek) Seek(context.Context, *fs.File, fs.SeekWhence, int64) (int64, error) { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // FilePipeSeek implements fs.FileOperations.Seek and can be used for files @@ -178,7 +179,7 @@ type FileNoFsync struct{} // Fsync implements fs.FileOperations.Fsync. func (FileNoFsync) Fsync(context.Context, *fs.File, int64, int64, fs.SyncType) error { - return syserror.EINVAL + return linuxerr.EINVAL } // FileNoopFsync implements fs.FileOperations.Fsync for files that don't need @@ -345,7 +346,7 @@ func NewFileStaticContentReader(b []byte) FileStaticContentReader { // Read implements fs.FileOperations.Read. func (scr *FileStaticContentReader) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if offset >= int64(len(scr.content)) { return 0, nil @@ -367,7 +368,7 @@ type FileNoRead struct{} // Read implements fs.FileOperations.Read. func (FileNoRead) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // FileNoWrite implements fs.FileOperations.Write to return EINVAL. @@ -375,7 +376,7 @@ type FileNoWrite struct{} // Write implements fs.FileOperations.Write. func (FileNoWrite) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // FileNoopRead implement fs.FileOperations.Read as a noop. 
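The fdpipe isBlockError hunk above (and the matching host/util.go change further down) swaps direct sentinel comparisons such as err == syserror.EAGAIN for linuxerr.Equals(linuxerr.EAGAIN, err). Below is a minimal, standard-library-only sketch of the comparison problem such a helper addresses: once the same errno can reach a caller either as a bare value or wrapped (as isBlockError's *os.PathError branch already anticipates), identity comparison stops being enough. The errnoIs helper and the values in main are illustrative only; they are not the linuxerr API.

package main

import (
	"errors"
	"fmt"
	"os"
	"syscall"
)

// errnoIs reports whether err represents the given errno, whether err is the
// bare errno value or wraps it (for example inside an *os.PathError). It is a
// stand-in built on errors.Is, not the sentry's linuxerr helper.
func errnoIs(want syscall.Errno, err error) bool {
	return errors.Is(err, want)
}

func main() {
	bare := error(syscall.EAGAIN)
	wrapped := &os.PathError{Op: "write", Path: "/tmp/pipe", Err: syscall.EAGAIN}

	fmt.Println(errnoIs(syscall.EAGAIN, bare))    // true
	fmt.Println(errnoIs(syscall.EAGAIN, wrapped)) // true
	fmt.Println(bare == error(wrapped))           // false: plain == misses the wrapped case
}

The real helper presumably also treats the old syserror sentinels and the new linuxerr values for the same errno as equal, which is what lets the migration land file by file without flag-day breakage.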
diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go index e1e38b498..8ac3738e9 100644 --- a/pkg/sentry/fs/fsutil/host_mappable.go +++ b/pkg/sentry/fs/fsutil/host_mappable.go @@ -155,12 +155,20 @@ func (h *HostMappable) DecRef(fr memmap.FileRange) { // T2: Appends to file causing it to grow // T2: Writes to mapped pages and COW happens // T1: Continues and wronly invalidates the page mapped in step above. -func (h *HostMappable) Truncate(ctx context.Context, newSize int64) error { +func (h *HostMappable) Truncate(ctx context.Context, newSize int64, uattr fs.UnstableAttr) error { h.truncateMu.Lock() defer h.truncateMu.Unlock() mask := fs.AttrMask{Size: true} attr := fs.UnstableAttr{Size: newSize} + + // Truncating a file clears privilege bits. + if uattr.Perms.HasSetUIDOrGID() { + mask.Perms = true + attr.Perms = uattr.Perms + attr.Perms.DropSetUIDAndMaybeGID() + } + if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr, false); err != nil { return err } @@ -193,10 +201,17 @@ func (h *HostMappable) Allocate(ctx context.Context, offset int64, length int64) } // Write writes to the file backing this mappable. -func (h *HostMappable) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { +func (h *HostMappable) Write(ctx context.Context, src usermem.IOSequence, offset int64, uattr fs.UnstableAttr) (int64, error) { h.truncateMu.RLock() + defer h.truncateMu.RUnlock() n, err := src.CopyInTo(ctx, &writer{ctx: ctx, hostMappable: h, off: offset}) - h.truncateMu.RUnlock() + if n > 0 && uattr.Perms.HasSetUIDOrGID() { + mask := fs.AttrMask{Perms: true} + uattr.Perms.DropSetUIDAndMaybeGID() + if err := h.backingFile.SetMaskedAttributes(ctx, mask, uattr, false); err != nil { + return n, err + } + } return n, err } diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go index 85e7e35db..bda07275d 100644 --- a/pkg/sentry/fs/fsutil/inode.go +++ b/pkg/sentry/fs/fsutil/inode.go @@ -17,6 +17,7 @@ package fsutil import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fs" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -376,7 +377,7 @@ func (InodeNotDirectory) RemoveDirectory(context.Context, *fs.Inode, string) err // Rename implements fs.FileOperations.Rename. func (InodeNotDirectory) Rename(context.Context, *fs.Inode, *fs.Inode, string, *fs.Inode, string, bool) error { - return syserror.EINVAL + return linuxerr.EINVAL } // InodeNotSocket can be used by Inodes that are not sockets. @@ -392,7 +393,7 @@ type InodeNotTruncatable struct{} // Truncate implements fs.InodeOperations.Truncate. func (InodeNotTruncatable) Truncate(context.Context, *fs.Inode, int64) error { - return syserror.EINVAL + return linuxerr.EINVAL } // InodeIsDirTruncate implements fs.InodeOperations.Truncate for directories. @@ -416,7 +417,7 @@ type InodeNotRenameable struct{} // Rename implements fs.InodeOperations.Rename. func (InodeNotRenameable) Rename(context.Context, *fs.Inode, *fs.Inode, string, *fs.Inode, string, bool) error { - return syserror.EINVAL + return linuxerr.EINVAL } // InodeNotOpenable can be used by Inodes that cannot be opened. 
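The host_mappable.go Truncate/Write hunks above, together with the inode_cached.go, gofer, and tmpfs changes that follow, make writes and truncates clear a file's privilege bits via HasSetUIDOrGID/DropSetUIDAndMaybeGID. The rule is small enough to restate over a raw Linux mode word; the sketch below is illustrative and does not use the sentry's fs types. The asymmetry follows the usual Linux reading: setgid without group-execute marks mandatory locking rather than a privilege, so only the setgid-plus-group-exec combination is dropped.

package main

import "fmt"

const (
	sISUID = 0o4000 // setuid
	sISGID = 0o2000 // setgid
	sIXGRP = 0o0010 // group execute
)

// dropPrivilegeBits applies the rule the hunks above implement: a write or a
// truncate always clears setuid, and clears setgid only when group-execute is
// set; setgid without group-execute is left untouched.
func dropPrivilegeBits(mode uint32) uint32 {
	mode &^= sISUID
	if mode&sIXGRP != 0 {
		mode &^= sISGID
	}
	return mode
}

func main() {
	fmt.Printf("%04o\n", dropPrivilegeBits(0o6755)) // 0755: both bits cleared
	fmt.Printf("%04o\n", dropPrivilegeBits(0o2640)) // 2640: setgid kept (no group-exec)
}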
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go index 7856b354b..855029b84 100644 --- a/pkg/sentry/fs/fsutil/inode_cached.go +++ b/pkg/sentry/fs/fsutil/inode_cached.go @@ -310,6 +310,12 @@ func (c *CachingInodeOperations) Truncate(ctx context.Context, inode *fs.Inode, now := ktime.NowFromContext(ctx) masked := fs.AttrMask{Size: true} attr := fs.UnstableAttr{Size: size} + if c.attr.Perms.HasSetUIDOrGID() { + masked.Perms = true + attr.Perms = c.attr.Perms + attr.Perms.DropSetUIDAndMaybeGID() + c.attr.Perms = attr.Perms + } if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr, false); err != nil { c.dataMu.Unlock() return err @@ -685,13 +691,14 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { return done, nil } -// maybeGrowFile grows the file's size if data has been written past the old -// size. +// maybeUpdateAttrs updates the file's attributes after a write. It updates +// size if data has been written past the old size, and setuid/setgid if any +// bytes were written. // // Preconditions: // * rw.c.attrMu must be locked. // * rw.c.dataMu must be locked. -func (rw *inodeReadWriter) maybeGrowFile() { +func (rw *inodeReadWriter) maybeUpdateAttrs(nwritten uint64) { // If the write ends beyond the file's previous size, it causes the // file to grow. if rw.offset > rw.c.attr.Size { @@ -705,6 +712,12 @@ func (rw *inodeReadWriter) maybeGrowFile() { rw.c.attr.Usage = rw.offset rw.c.dirtyAttr.Usage = true } + + // If bytes were written, ensure setuid and setgid are cleared. + if nwritten > 0 && rw.c.attr.Perms.HasSetUIDOrGID() { + rw.c.dirtyAttr.Perms = true + rw.c.attr.Perms.DropSetUIDAndMaybeGID() + } } // WriteFromBlocks implements safemem.Writer.WriteFromBlocks. @@ -732,7 +745,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error segMR := seg.Range().Intersect(mr) ims, err := mf.MapInternal(seg.FileRangeOf(segMR), hostarch.Write) if err != nil { - rw.maybeGrowFile() + rw.maybeUpdateAttrs(done) rw.c.dataMu.Unlock() return done, err } @@ -744,7 +757,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error srcs = srcs.DropFirst64(n) rw.c.dirty.MarkDirty(segMR) if err != nil { - rw.maybeGrowFile() + rw.maybeUpdateAttrs(done) rw.c.dataMu.Unlock() return done, err } @@ -765,7 +778,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error srcs = srcs.DropFirst64(n) // Partial writes are fine. But we must stop writing. if n != src.NumBytes() || err != nil { - rw.maybeGrowFile() + rw.maybeUpdateAttrs(done) rw.c.dataMu.Unlock() return done, err } @@ -774,7 +787,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error seg, gap = gap.NextSegment(), FileRangeGapIterator{} } } - rw.maybeGrowFile() + rw.maybeUpdateAttrs(done) rw.c.dataMu.Unlock() return done, nil } diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index 94cb05246..c08301d19 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -26,6 +26,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fd", "//pkg/hostarch", "//pkg/log", diff --git a/pkg/sentry/fs/gofer/cache_policy.go b/pkg/sentry/fs/gofer/cache_policy.go index 07a564e92..f8b7a60fc 100644 --- a/pkg/sentry/fs/gofer/cache_policy.go +++ b/pkg/sentry/fs/gofer/cache_policy.go @@ -139,7 +139,7 @@ func (cp cachePolicy) revalidate(ctx context.Context, name string, parent, child // Walk from parent to child again. 
// - // TODO(b/112031682): If we have a directory FD in the parent + // NOTE(b/112031682): If we have a directory FD in the parent // inodeOperations, then we can use fstatat(2) to get the inode // attributes instead of making this RPC. qids, f, mask, attr, err := parentIops.fileState.file.walkGetAttr(ctx, []string{name}) diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go index bcdb2dda2..73d80d9b5 100644 --- a/pkg/sentry/fs/gofer/file.go +++ b/pkg/sentry/fs/gofer/file.go @@ -92,7 +92,6 @@ func NewFile(ctx context.Context, dirent *fs.Dirent, name string, flags fs.FileF } if flags.Write { if err := dirent.Inode.CheckPermission(ctx, fs.PermMask{Execute: true}); err == nil { - fsmetric.GoferOpensWX.Increment() metric.SuspiciousOperationsMetric.Increment("opened_write_execute_file") log.Warningf("Opened a writable executable: %q", name) } @@ -238,10 +237,20 @@ func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.I // and availability of a host-mappable FD. if f.inodeOperations.session().cachePolicy.useCachingInodeOps(file.Dirent.Inode) { n, err = f.inodeOperations.cachingInodeOps.Write(ctx, src, offset) - } else if f.inodeOperations.fileState.hostMappable != nil { - n, err = f.inodeOperations.fileState.hostMappable.Write(ctx, src, offset) } else { - n, err = src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset)) + uattr, e := f.UnstableAttr(ctx, file) + if e != nil { + return 0, e + } + if f.inodeOperations.fileState.hostMappable != nil { + n, err = f.inodeOperations.fileState.hostMappable.Write(ctx, src, offset, uattr) + } else { + n, err = src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset)) + if n > 0 && uattr.Perms.HasSetUIDOrGID() { + uattr.Perms.DropSetUIDAndMaybeGID() + f.inodeOperations.SetPermissions(ctx, file.Dirent.Inode, uattr.Perms) + } + } } if n == 0 { diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index b97635ec4..da3178527 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -600,11 +600,25 @@ func (i *inodeOperations) Truncate(ctx context.Context, inode *fs.Inode, length if i.session().cachePolicy.useCachingInodeOps(inode) { return i.cachingInodeOps.Truncate(ctx, inode, length) } + + uattr, err := i.fileState.unstableAttr(ctx) + if err != nil { + return err + } + if i.session().cachePolicy == cacheRemoteRevalidating { - return i.fileState.hostMappable.Truncate(ctx, length) + return i.fileState.hostMappable.Truncate(ctx, length, uattr) + } + + mask := p9.SetAttrMask{Size: true} + attr := p9.SetAttr{Size: uint64(length)} + if uattr.Perms.HasSetUIDOrGID() { + mask.Permissions = true + uattr.Perms.DropSetUIDAndMaybeGID() + attr.Permissions = p9.FileMode(uattr.Perms.LinuxMode()) } - return i.fileState.file.setAttr(ctx, p9.SetAttrMask{Size: true}, p9.SetAttr{Size: uint64(length)}) + return i.fileState.file.setAttr(ctx, mask, attr) } // GetXattr implements fs.InodeOperations.GetXattr. diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go index 6b3627813..1a6f353d0 100644 --- a/pkg/sentry/fs/gofer/path.go +++ b/pkg/sentry/fs/gofer/path.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/device" @@ -66,7 +67,7 @@ func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string // Get a p9.File for name. 
qids, newFile, mask, p9attr, err := i.fileState.file.walkGetAttr(ctx, []string{name}) if err != nil { - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { if cp.cacheNegativeDirents() { // Return a negative Dirent. It will stay cached until something // is created over it. @@ -130,7 +131,16 @@ func (i *inodeOperations) Create(ctx context.Context, dir *fs.Inode, name string panic(fmt.Sprintf("Create called with unknown or unset open flags: %v", flags)) } + // If the parent directory has setgid enabled, change the new file's owner. owner := fs.FileOwnerFromContext(ctx) + parentUattr, err := dir.UnstableAttr(ctx) + if err != nil { + return nil, err + } + if parentUattr.Perms.SetGID { + owner.GID = parentUattr.Owner.GID + } + hostFile, err := newFile.create(ctx, name, openFlags, p9.FileMode(perm.LinuxMode()), p9.UID(owner.UID), p9.GID(owner.GID)) if err != nil { // Could not create the file. @@ -225,7 +235,18 @@ func (i *inodeOperations) CreateDirectory(ctx context.Context, dir *fs.Inode, s return syserror.ENAMETOOLONG } + // If the parent directory has setgid enabled, change the new directory's + // owner and enable setgid. owner := fs.FileOwnerFromContext(ctx) + parentUattr, err := dir.UnstableAttr(ctx) + if err != nil { + return err + } + if parentUattr.Perms.SetGID { + owner.GID = parentUattr.Owner.GID + perm.SetGID = true + } + if _, err := i.fileState.file.mkdir(ctx, s, p9.FileMode(perm.LinuxMode()), p9.UID(owner.UID), p9.GID(owner.GID)); err != nil { return err } @@ -278,7 +299,7 @@ func (i *inodeOperations) CreateFifo(ctx context.Context, dir *fs.Inode, name st // N.B. FIFOs use major/minor numbers 0. if _, err := i.fileState.file.mknod(ctx, name, mode, 0, 0, p9.UID(owner.UID), p9.GID(owner.GID)); err != nil { - if i.session().overrides == nil || err != syserror.EPERM { + if i.session().overrides == nil || !linuxerr.Equals(linuxerr.EPERM, err) { return err } // If gofer doesn't support mknod, check if we can create an internal fifo. diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index 3c45f6cc5..24fc6305c 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -28,9 +28,9 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fd", "//pkg/fdnotifier", - "//pkg/iovec", "//pkg/log", "//pkg/marshal/primitive", "//pkg/refs", @@ -40,6 +40,7 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/hostfd", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 46a2dc47d..225244868 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -21,6 +21,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/refs" @@ -213,7 +214,7 @@ func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMess // block (and only for stream sockets). err = syserror.EAGAIN } - if n > 0 && err != syserror.EAGAIN { + if n > 0 && !linuxerr.Equals(linuxerr.EAGAIN, err) { // The caller may need to block to send more data, but // otherwise there isn't anything that can be done about an // error with a partial write. 
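The gofer Create and CreateDirectory hunks above (and the tmpfs NewDir change later in this diff) add setgid-directory inheritance: a child created inside a setgid directory takes the directory's group instead of the creator's, and a new subdirectory also inherits the setgid bit itself. The sketch below restates that rule with throwaway types; fileOwner and applySetgidInheritance are illustrative names, not sentry types.

package main

import "fmt"

const sISGID = 0o2000

type fileOwner struct{ UID, GID uint32 }

// applySetgidInheritance returns the owner and mode a new child should get
// when created by creator inside parent: with setgid on the parent, the child
// takes the parent's GID, and a directory child keeps setgid as well.
func applySetgidInheritance(creator, parent fileOwner, parentMode, childMode uint32, childIsDir bool) (fileOwner, uint32) {
	owner := creator
	if parentMode&sISGID != 0 {
		owner.GID = parent.GID
		if childIsDir {
			childMode |= sISGID
		}
	}
	return owner, childMode
}

func main() {
	creator := fileOwner{UID: 1000, GID: 1000}
	parent := fileOwner{UID: 0, GID: 50}

	o, m := applySetgidInheritance(creator, parent, 0o2775, 0o755, true)
	fmt.Printf("dir:  %+v %04o\n", o, m) // GID 50, mode 2755

	o, m = applySetgidInheritance(creator, parent, 0o2775, 0o644, false)
	fmt.Printf("file: %+v %04o\n", o, m) // GID 50, mode stays 0644
}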
diff --git a/pkg/sentry/fs/host/socket_iovec.go b/pkg/sentry/fs/host/socket_iovec.go index 7380d75e7..fd48aff11 100644 --- a/pkg/sentry/fs/host/socket_iovec.go +++ b/pkg/sentry/fs/host/socket_iovec.go @@ -16,7 +16,7 @@ package host import ( "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/iovec" + "gvisor.dev/gvisor/pkg/sentry/hostfd" "gvisor.dev/gvisor/pkg/syserror" ) @@ -72,7 +72,7 @@ func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovec } } - if iovsRequired > iovec.MaxIovs { + if iovsRequired > hostfd.MaxSendRecvMsgIov { // The kernel will reject our call if we pass this many iovs. // Use a single intermediate buffer instead. b := make([]byte, stopLen) diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go index 1183727ab..2ff520100 100644 --- a/pkg/sentry/fs/host/tty.go +++ b/pkg/sentry/fs/host/tty.go @@ -17,6 +17,7 @@ package host import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -191,7 +192,7 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO if err := t.checkChange(ctx, linux.SIGTTOU); err != nil { // drivers/tty/tty_io.c:tiocspgrp() converts -EIO from // tty_check_change() to -ENOTTY. - if err == syserror.EIO { + if linuxerr.Equals(linuxerr.EIO, err) { return 0, syserror.ENOTTY } return 0, err @@ -211,7 +212,7 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO // pgID must be non-negative. if pgID < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Process group with pgID must exist in this PID namespace. diff --git a/pkg/sentry/fs/host/util.go b/pkg/sentry/fs/host/util.go index ab74724a3..e7db79189 100644 --- a/pkg/sentry/fs/host/util.go +++ b/pkg/sentry/fs/host/util.go @@ -19,12 +19,12 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/syserror" ) func nodeType(s *unix.Stat_t) fs.InodeType { @@ -98,7 +98,7 @@ type dirInfo struct { // isBlockError unwraps os errors and checks if they are caused by EAGAIN or // EWOULDBLOCK. This is so they can be transformed into syserror.ErrWouldBlock. func isBlockError(err error) bool { - if err == syserror.EAGAIN || err == syserror.EWOULDBLOCK { + if linuxerr.Equals(linuxerr.EAGAIN, err) || linuxerr.Equals(linuxerr.EWOULDBLOCK, err) { return true } if pe, ok := err.(*os.PathError); ok { diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index e97afc626..bd1125dcc 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/syserror" @@ -71,7 +72,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name // A file could have been created over a whiteout, so we need to // check if something exists in the upper file system first. 
child, err := parent.upper.Lookup(ctx, name) - if err != nil && err != syserror.ENOENT { + if err != nil && !linuxerr.Equals(linuxerr.ENOENT, err) { // We encountered an error that an overlay cannot handle, // we must propagate it to the caller. parent.copyMu.RUnlock() @@ -125,7 +126,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name // Check the lower file system. child, err := parent.lower.Lookup(ctx, name) // Same song and dance as above. - if err != nil && err != syserror.ENOENT { + if err != nil && !linuxerr.Equals(linuxerr.ENOENT, err) { // Don't leak resources. if upperInode != nil { upperInode.DecRef(ctx) @@ -396,7 +397,7 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena // newName has been removed out from under us. That's fine; // filesystems where that can happen must handle stale // 'replaced'. - if err != nil && err != syserror.ENOENT { + if err != nil && !linuxerr.Equals(linuxerr.ENOENT, err) { return err } if err == nil { diff --git a/pkg/sentry/fs/inode_overlay_test.go b/pkg/sentry/fs/inode_overlay_test.go index aa9851b26..cc5ffa6f1 100644 --- a/pkg/sentry/fs/inode_overlay_test.go +++ b/pkg/sentry/fs/inode_overlay_test.go @@ -18,6 +18,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" @@ -191,11 +192,11 @@ func TestLookup(t *testing.T) { } { t.Run(test.desc, func(t *testing.T) { dirent, err := test.dir.Lookup(ctx, test.name) - if test.found && (err == syserror.ENOENT || dirent.IsNegative()) { + if test.found && (linuxerr.Equals(linuxerr.ENOENT, err) || dirent.IsNegative()) { t.Fatalf("lookup %q expected to find positive dirent, got dirent %v err %v", test.name, dirent, err) } if !test.found { - if err != syserror.ENOENT && !dirent.IsNegative() { + if !linuxerr.Equals(linuxerr.ENOENT, err) && !dirent.IsNegative() { t.Errorf("lookup %q expected to return ENOENT or negative dirent, got dirent %v err %v", test.name, dirent, err) } // Nothing more to check. diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go index 1b83643db..4e07043c7 100644 --- a/pkg/sentry/fs/inotify.go +++ b/pkg/sentry/fs/inotify.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -132,7 +133,7 @@ func (*Inotify) Write(context.Context, *File, usermem.IOSequence, int64) (int64, // Read implements FileOperations.Read. func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ int64) (int64, error) { if dst.NumBytes() < inotifyEventBaseSize { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } i.evMu.Lock() @@ -156,7 +157,7 @@ func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ i // write some events out. return writeLen, nil } - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Linux always dequeues an available event as long as there's enough @@ -183,7 +184,7 @@ func (*Inotify) WriteTo(context.Context, *File, io.Writer, int64, bool) (int64, // Fsync implements FileOperations.Fsync. func (*Inotify) Fsync(context.Context, *File, int64, int64, SyncType) error { - return syserror.EINVAL + return linuxerr.EINVAL } // ReadFrom implements FileOperations.ReadFrom. 
@@ -329,7 +330,7 @@ func (i *Inotify) RmWatch(ctx context.Context, wd int32) error { watch, ok := i.watches[wd] if !ok { i.mu.Unlock() - return syserror.EINVAL + return linuxerr.EINVAL } // Remove the watch from this instance. diff --git a/pkg/sentry/fs/mock.go b/pkg/sentry/fs/mock.go index 1d6ea5736..2a54c1242 100644 --- a/pkg/sentry/fs/mock.go +++ b/pkg/sentry/fs/mock.go @@ -16,6 +16,7 @@ package fs import ( "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/syserror" ) @@ -109,7 +110,7 @@ func (n *MockInodeOperations) SetPermissions(context.Context, *Inode, FilePermis // SetOwner implements fs.InodeOperations.SetOwner. func (*MockInodeOperations) SetOwner(context.Context, *Inode, FileOwner) error { - return syserror.EINVAL + return linuxerr.EINVAL } // SetTimestamps implements fs.InodeOperations.SetTimestamps. diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go index 243098a09..340441974 100644 --- a/pkg/sentry/fs/mounts.go +++ b/pkg/sentry/fs/mounts.go @@ -20,6 +20,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sync" @@ -357,7 +358,7 @@ func (mns *MountNamespace) Unmount(ctx context.Context, node *Dirent, detachOnly orig, ok := mns.mounts[node] if !ok { // node is not a mount point. - return syserror.EINVAL + return linuxerr.EINVAL } if orig.previous == nil { diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go index f96f5a3e5..7e72e47b5 100644 --- a/pkg/sentry/fs/overlay.go +++ b/pkg/sentry/fs/overlay.go @@ -19,11 +19,11 @@ import ( "strings" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) // The virtual filesystem implements an overlay configuration. For a high-level @@ -218,7 +218,7 @@ func newOverlayEntry(ctx context.Context, upper *Inode, lower *Inode, lowerExist // We don't support copying up from character devices, // named pipes, or anything weird (like proc files). log.Warningf("%s not supported in lower filesytem", lower.StableAttr.Type) - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } } return &overlayEntry{ diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index 7af7e0b45..e6d74b949 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -30,6 +30,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/log", "//pkg/sentry/fs", diff --git a/pkg/sentry/fs/proc/exec_args.go b/pkg/sentry/fs/proc/exec_args.go index 24426b225..379429ab2 100644 --- a/pkg/sentry/fs/proc/exec_args.go +++ b/pkg/sentry/fs/proc/exec_args.go @@ -21,11 +21,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -104,7 +104,7 @@ var _ fs.FileOperations = (*execArgFile)(nil) // Read reads the exec arg from the process's address space.. 
func (f *execArgFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } m, err := getTaskMM(f.t) diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 91c35eea9..187e9a921 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -34,7 +35,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -291,7 +291,7 @@ func (n *netSnmp) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s continue } if err := n.s.Statistics(stat, line.prefix); err != nil { - if err == syserror.EOPNOTSUPP { + if linuxerr.Equals(linuxerr.EOPNOTSUPP, err) { log.Infof("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err) } else { log.Warningf("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err) diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go index 2f2a9f920..546b57287 100644 --- a/pkg/sentry/fs/proc/proc.go +++ b/pkg/sentry/fs/proc/proc.go @@ -21,6 +21,7 @@ import ( "strconv" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/fs/proc/device" @@ -130,7 +131,7 @@ func (s *self) Readlink(ctx context.Context, inode *fs.Inode) (string, error) { } // Who is reading this link? - return "", syserror.EINVAL + return "", linuxerr.EINVAL } // threadSelf is more magical than "self" link. @@ -154,7 +155,7 @@ func (s *threadSelf) Readlink(ctx context.Context, inode *fs.Inode) (string, err } // Who is reading this link? - return "", syserror.EINVAL + return "", linuxerr.EINVAL } // Lookup loads an Inode at name into a Dirent. diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go index b998fb75d..085aa6d61 100644 --- a/pkg/sentry/fs/proc/sys.go +++ b/pkg/sentry/fs/proc/sys.go @@ -77,6 +77,27 @@ func (*overcommitMemory) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandl }, 0 } +// +stateify savable +type maxMapCount struct{} + +// NeedsUpdate implements seqfile.SeqSource. +func (*maxMapCount) NeedsUpdate(int64) bool { + return true +} + +// ReadSeqFileData implements seqfile.SeqSource. 
+func (*maxMapCount) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + if h != nil { + return nil, 0 + } + return []seqfile.SeqData{ + { + Buf: []byte("2147483647\n"), + Handle: (*maxMapCount)(nil), + }, + }, 0 +} + func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode { h := hostname{ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), @@ -96,6 +117,7 @@ func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode func (p *proc) newVMDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode { children := map[string]*fs.Inode{ + "max_map_count": seqfile.NewSeqFileInode(ctx, &maxMapCount{}, msrc), "mmap_min_addr": seqfile.NewSeqFileInode(ctx, &mmapMinAddrData{p.k}, msrc), "overcommit_memory": seqfile.NewSeqFileInode(ctx, &overcommitMemory{}, msrc), } diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index 4893af56b..71f37d582 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -28,7 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -592,7 +592,7 @@ func (pf *portRangeFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSe // Port numbers must be uint16s. if ports[0] < 0 || ports[1] < 0 || ports[0] > math.MaxUint16 || ports[1] > math.MaxUint16 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if err := pf.inode.stack.SetPortRange(uint16(ports[0]), uint16(ports[1])); err != nil { diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index ae5ed25f9..7ece1377a 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -867,7 +868,7 @@ var _ fs.FileOperations = (*commFile)(nil) // Read implements fs.FileOperations.Read. func (f *commFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } buf := []byte(f.t.Name() + "\n") @@ -922,7 +923,7 @@ type auxvecFile struct { // Read implements fs.FileOperations.Read. 
func (f *auxvecFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } m, err := getTaskMM(f.t) diff --git a/pkg/sentry/fs/proc/uid_gid_map.go b/pkg/sentry/fs/proc/uid_gid_map.go index 30d5ad4cf..fcdc1e7bd 100644 --- a/pkg/sentry/fs/proc/uid_gid_map.go +++ b/pkg/sentry/fs/proc/uid_gid_map.go @@ -21,12 +21,12 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -108,7 +108,7 @@ const maxIDMapLines = 5 // Read implements fs.FileOperations.Read. func (imfo *idMapFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } var entries []auth.IDMapEntry if imfo.iops.gids { @@ -134,7 +134,7 @@ func (imfo *idMapFileOperations) Write(ctx context.Context, file *fs.File, src u // the file ..." - user_namespaces(7) srclen := src.NumBytes() if srclen >= hostarch.PageSize || offset != 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } b := make([]byte, srclen) if _, err := src.CopyIn(ctx, b); err != nil { @@ -154,7 +154,7 @@ func (imfo *idMapFileOperations) Write(ctx context.Context, file *fs.File, src u } lines := bytes.SplitN(b, []byte("\n"), maxIDMapLines+1) if len(lines) > maxIDMapLines { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } entries := make([]auth.IDMapEntry, len(lines)) @@ -162,7 +162,7 @@ func (imfo *idMapFileOperations) Write(ctx context.Context, file *fs.File, src u var e auth.IDMapEntry _, err := fmt.Sscan(string(l), &e.FirstID, &e.FirstParentID, &e.Length) if err != nil { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } entries[i] = e } diff --git a/pkg/sentry/fs/proc/uptime.go b/pkg/sentry/fs/proc/uptime.go index c0f6fb802..ac896f963 100644 --- a/pkg/sentry/fs/proc/uptime.go +++ b/pkg/sentry/fs/proc/uptime.go @@ -20,10 +20,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -74,7 +74,7 @@ type uptimeFile struct { // Read implements fs.FileOperations.Read. func (f *uptimeFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } now := ktime.NowFromContext(ctx) diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go index 33da82868..ca9f645f6 100644 --- a/pkg/sentry/fs/splice.go +++ b/pkg/sentry/fs/splice.go @@ -19,6 +19,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/syserror" ) @@ -139,7 +140,7 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, // Attempt to do a WriteTo; this is likely the most efficient. 
n, err := src.FileOperations.WriteTo(ctx, src, w, opts.Length, opts.Dup) - if n == 0 && err == syserror.ENOSYS && !opts.Dup { + if n == 0 && linuxerr.Equals(linuxerr.ENOSYS, err) && !opts.Dup { // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also be // more efficient than a copy if buffers are cached or readily // available. (It's unlikely that they can actually be donated). @@ -151,7 +152,7 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, // if we block at some point, we could lose data. If the source is // not a pipe then reading is not destructive; if the destination // is a regular file, then it is guaranteed not to block writing. - if n == 0 && err == syserror.ENOSYS && !opts.Dup && (!dstPipe || !srcPipe) { + if n == 0 && linuxerr.Equals(linuxerr.ENOSYS, err) && !opts.Dup && (!dstPipe || !srcPipe) { // Fallback to an in-kernel copy. n, err = io.Copy(w, &io.LimitedReader{ R: r, diff --git a/pkg/sentry/fs/timerfd/BUILD b/pkg/sentry/fs/timerfd/BUILD index c7977a217..0148b33cf 100644 --- a/pkg/sentry/fs/timerfd/BUILD +++ b/pkg/sentry/fs/timerfd/BUILD @@ -8,6 +8,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/sentry/fs", "//pkg/sentry/fs/anon", diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go index c8ebe256c..093a14c1f 100644 --- a/pkg/sentry/fs/timerfd/timerfd.go +++ b/pkg/sentry/fs/timerfd/timerfd.go @@ -20,6 +20,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/anon" @@ -121,7 +122,7 @@ func (t *TimerOperations) EventUnregister(e *waiter.Entry) { func (t *TimerOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { const sizeofUint64 = 8 if dst.NumBytes() < sizeofUint64 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if val := atomic.SwapUint64(&t.val, 0); val != 0 { var buf [sizeofUint64]byte @@ -138,7 +139,7 @@ func (t *TimerOperations) Read(ctx context.Context, file *fs.File, dst usermem.I // Write implements fs.FileOperations.Write. func (t *TimerOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Notify implements ktime.TimerListener.Notify. diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD index 90398376a..c36a20afe 100644 --- a/pkg/sentry/fs/tmpfs/BUILD +++ b/pkg/sentry/fs/tmpfs/BUILD @@ -15,6 +15,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/safemem", "//pkg/sentry/device", diff --git a/pkg/sentry/fs/tmpfs/fs.go b/pkg/sentry/fs/tmpfs/fs.go index bc117ca6a..b48d475ed 100644 --- a/pkg/sentry/fs/tmpfs/fs.go +++ b/pkg/sentry/fs/tmpfs/fs.go @@ -151,5 +151,5 @@ func (f *Filesystem) Mount(ctx context.Context, device string, flags fs.MountSou } // Construct the tmpfs root. 
- return NewDir(ctx, nil, owner, perms, msrc), nil + return NewDir(ctx, nil, owner, perms, msrc, nil /* parent */) } diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go index f4de8c968..ce6be6386 100644 --- a/pkg/sentry/fs/tmpfs/inode_file.go +++ b/pkg/sentry/fs/tmpfs/inode_file.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -226,6 +227,12 @@ func (f *fileInodeOperations) Truncate(ctx context.Context, _ *fs.Inode, size in now := ktime.NowFromContext(ctx) f.attr.ModificationTime = now f.attr.StatusChangeTime = now + + // Truncating clears privilege bits. + f.attr.Perms.SetUID = false + if f.attr.Perms.Group.Execute { + f.attr.Perms.SetGID = false + } } f.dataMu.Unlock() @@ -363,7 +370,14 @@ func (f *fileInodeOperations) write(ctx context.Context, src usermem.IOSequence, now := ktime.NowFromContext(ctx) f.attr.ModificationTime = now f.attr.StatusChangeTime = now - return src.CopyInTo(ctx, &fileReadWriter{f, offset}) + nwritten, err := src.CopyInTo(ctx, &fileReadWriter{f, offset}) + + // Writing clears privilege bits. + if nwritten > 0 { + f.attr.Perms.DropSetUIDAndMaybeGID() + } + + return nwritten, err } type fileReadWriter struct { @@ -442,7 +456,7 @@ func (rw *fileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) end := fs.WriteEndOffset(rw.offset, int64(srcs.NumBytes())) if end == math.MaxInt64 { // Overflow. - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Check if seals prevent either file growth or all writes. @@ -642,7 +656,7 @@ func GetSeals(inode *fs.Inode) (uint32, error) { return f.seals, nil } // Not a memfd inode. - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // AddSeals adds new file seals to a memfd inode. @@ -670,5 +684,5 @@ func AddSeals(inode *fs.Inode, val uint32) error { return nil } // Not a memfd inode. - return syserror.EINVAL + return linuxerr.EINVAL } diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go index 577052888..6aa8ff331 100644 --- a/pkg/sentry/fs/tmpfs/tmpfs.go +++ b/pkg/sentry/fs/tmpfs/tmpfs.go @@ -87,7 +87,20 @@ type Dir struct { var _ fs.InodeOperations = (*Dir)(nil) // NewDir returns a new directory. -func NewDir(ctx context.Context, contents map[string]*fs.Inode, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode { +func NewDir(ctx context.Context, contents map[string]*fs.Inode, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource, parent *fs.Inode) (*fs.Inode, error) { + // If the parent has setgid enabled, the new directory enables it and changes + // its GID. + if parent != nil { + parentUattr, err := parent.UnstableAttr(ctx) + if err != nil { + return nil, err + } + if parentUattr.Perms.SetGID { + owner.GID = parentUattr.Owner.GID + perms.SetGID = true + } + } + d := &Dir{ ramfsDir: ramfs.NewDir(ctx, contents, owner, perms), kernel: kernel.KernelFromContext(ctx), @@ -101,7 +114,7 @@ func NewDir(ctx context.Context, contents map[string]*fs.Inode, owner fs.FileOwn InodeID: tmpfsDevice.NextIno(), BlockSize: hostarch.PageSize, Type: fs.Directory, - }) + }), nil } // afterLoad is invoked by stateify. 
@@ -219,11 +232,21 @@ func (d *Dir) SetTimestamps(ctx context.Context, i *fs.Inode, ts fs.TimeSpec) er func (d *Dir) newCreateOps() *ramfs.CreateOps { return &ramfs.CreateOps{ NewDir: func(ctx context.Context, dir *fs.Inode, perms fs.FilePermissions) (*fs.Inode, error) { - return NewDir(ctx, nil, fs.FileOwnerFromContext(ctx), perms, dir.MountSource), nil + return NewDir(ctx, nil, fs.FileOwnerFromContext(ctx), perms, dir.MountSource, dir) }, NewFile: func(ctx context.Context, dir *fs.Inode, perms fs.FilePermissions) (*fs.Inode, error) { + // If the parent has setgid enabled, change the GID of the new file. + owner := fs.FileOwnerFromContext(ctx) + parentUattr, err := dir.UnstableAttr(ctx) + if err != nil { + return nil, err + } + if parentUattr.Perms.SetGID { + owner.GID = parentUattr.Owner.GID + } + uattr := fs.WithCurrentTime(ctx, fs.UnstableAttr{ - Owner: fs.FileOwnerFromContext(ctx), + Owner: owner, Perms: perms, // Always start unlinked. Links: 0, diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index 86ada820e..5933cb67b 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -17,6 +17,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/marshal/primitive", "//pkg/refs", diff --git a/pkg/sentry/fs/tty/fs.go b/pkg/sentry/fs/tty/fs.go index 13f4901db..0e5916380 100644 --- a/pkg/sentry/fs/tty/fs.go +++ b/pkg/sentry/fs/tty/fs.go @@ -16,9 +16,9 @@ package tty import ( "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/syserror" ) // ptsDevice is the pseudo-filesystem device. @@ -64,7 +64,7 @@ func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSou // No options are supported. if data != "" { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } return newDir(ctx, fs.NewMountSource(ctx, &superOperations{}, f, flags)), nil diff --git a/pkg/sentry/fs/user/BUILD b/pkg/sentry/fs/user/BUILD index 66e949c95..4acc73ee0 100644 --- a/pkg/sentry/fs/user/BUILD +++ b/pkg/sentry/fs/user/BUILD @@ -12,6 +12,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fspath", "//pkg/log", "//pkg/sentry/fs", diff --git a/pkg/sentry/fs/user/path.go b/pkg/sentry/fs/user/path.go index 124bc95ed..f6eaab2bd 100644 --- a/pkg/sentry/fs/user/path.go +++ b/pkg/sentry/fs/user/path.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -93,7 +94,7 @@ func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name s binPath := path.Join(p, name) traversals := uint(linux.MaxSymlinkTraversals) d, err := mns.FindInode(ctx, root, nil, binPath, &traversals) - if err == syserror.ENOENT || err == syserror.EACCES { + if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.EACCES, err) { // Didn't find it here. continue } @@ -142,7 +143,7 @@ func resolveVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNam Flags: linux.O_RDONLY, } dentry, err := root.Mount().Filesystem().VirtualFilesystem().OpenAt(ctx, creds, pop, opts) - if err == syserror.ENOENT || err == syserror.EACCES { + if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.EACCES, err) { // Didn't find it here. 
continue } diff --git a/pkg/sentry/fs/user/user_test.go b/pkg/sentry/fs/user/user_test.go index 12b786224..7f8fa8038 100644 --- a/pkg/sentry/fs/user/user_test.go +++ b/pkg/sentry/fs/user/user_test.go @@ -104,7 +104,10 @@ func TestGetExecUserHome(t *testing.T) { t.Run(name, func(t *testing.T) { ctx := contexttest.Context(t) msrc := fs.NewPseudoMountSource(ctx) - rootInode := tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc) + rootInode, err := tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc, nil /* parent */) + if err != nil { + t.Fatalf("tmpfs.NewDir failed: %v", err) + } mns, err := fs.NewMountNamespace(ctx, rootInode) if err != nil { diff --git a/pkg/sentry/fsimpl/cgroupfs/BUILD b/pkg/sentry/fsimpl/cgroupfs/BUILD index 37efb641a..4c9c5b344 100644 --- a/pkg/sentry/fsimpl/cgroupfs/BUILD +++ b/pkg/sentry/fsimpl/cgroupfs/BUILD @@ -31,6 +31,7 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/coverage", + "//pkg/errors/linuxerr", "//pkg/log", "//pkg/refs", "//pkg/refsvfs2", diff --git a/pkg/sentry/fsimpl/cgroupfs/base.go b/pkg/sentry/fsimpl/cgroupfs/base.go index 6512e9cdb..4290ffe0d 100644 --- a/pkg/sentry/fsimpl/cgroupfs/base.go +++ b/pkg/sentry/fsimpl/cgroupfs/base.go @@ -23,10 +23,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -133,6 +133,17 @@ func (c *cgroupInode) Controllers() []kernel.CgroupController { return c.fs.kcontrollers } +// tasks returns a snapshot of the tasks inside the cgroup. +func (c *cgroupInode) tasks() []*kernel.Task { + c.fs.tasksMu.RLock() + defer c.fs.tasksMu.RUnlock() + ts := make([]*kernel.Task, 0, len(c.ts)) + for t := range c.ts { + ts = append(ts, t) + } + return ts +} + // Enter implements kernel.CgroupImpl.Enter. func (c *cgroupInode) Enter(t *kernel.Task) { c.fs.tasksMu.Lock() @@ -163,10 +174,7 @@ func (d *cgroupProcsData) Generate(ctx context.Context, buf *bytes.Buffer) error pgids := make(map[kernel.ThreadID]struct{}) - d.fs.tasksMu.RLock() - defer d.fs.tasksMu.RUnlock() - - for task := range d.ts { + for _, task := range d.tasks() { // Map dedups pgid, since iterating over all tasks produces multiple // entries for the group leaders. if pgid := currPidns.IDOfThreadGroup(task.ThreadGroup()); pgid != 0 { @@ -205,10 +213,7 @@ func (d *tasksData) Generate(ctx context.Context, buf *bytes.Buffer) error { var pids []kernel.ThreadID - d.fs.tasksMu.RLock() - defer d.fs.tasksMu.RUnlock() - - for task := range d.ts { + for _, task := range d.tasks() { if pid := currPidns.IDOfTask(task); pid != 0 { pids = append(pids, pid) } @@ -248,7 +253,7 @@ func parseInt64FromString(ctx context.Context, src usermem.IOSequence, offset in // Note: This also handles zero-len writes if offset is beyond the end // of src, or src is empty. ctx.Warningf("cgroupfs.parseInt64FromString: failed to parse %q: %v", string(buf), err) - return 0, int64(n), syserror.EINVAL + return 0, int64(n), linuxerr.EINVAL } return val, int64(n), nil diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go index 54050de3c..b5883cbd2 100644 --- a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go +++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go @@ -49,8 +49,9 @@ // // kernel.CgroupRegistry.mu // cgroupfs.filesystem.mu -// Task.mu -// cgroupfs.filesystem.tasksMu. 
+// kernel.TaskSet.mu +// kernel.Task.mu +// cgroupfs.filesystem.tasksMu. package cgroupfs import ( @@ -61,6 +62,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -166,7 +168,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt maxCachedDentries, err = strconv.ParseUint(str, 10, 64) if err != nil { ctx.Warningf("sys.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } } @@ -194,7 +196,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt if _, ok := mopts["all"]; ok { if len(wantControllers) > 0 { ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: other controllers specified with all: %v", wantControllers) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } delete(mopts, "all") @@ -208,7 +210,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt if len(mopts) != 0 { ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: unknown options: %v", mopts) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } k := kernel.KernelFromContext(ctx) diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD index 6af3c3781..50b4c02ef 100644 --- a/pkg/sentry/fsimpl/devpts/BUILD +++ b/pkg/sentry/fsimpl/devpts/BUILD @@ -29,6 +29,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/log", "//pkg/marshal", "//pkg/marshal/primitive", diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go index e75954105..7a488e9fd 100644 --- a/pkg/sentry/fsimpl/devpts/devpts.go +++ b/pkg/sentry/fsimpl/devpts/devpts.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -56,7 +57,7 @@ func (*FilesystemType) Name() string { func (fstype *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { // No data allowed. 
if opts.Data != "" { - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } fstype.initOnce.Do(func() { diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go index 93c031c89..1374fd3be 100644 --- a/pkg/sentry/fsimpl/devpts/master.go +++ b/pkg/sentry/fsimpl/devpts/master.go @@ -17,6 +17,7 @@ package devpts import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" @@ -80,7 +81,7 @@ func (mi *masterInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs // SetStat implements kernfs.Inode.SetStat func (mi *masterInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { if opts.Stat.Mask&linux.STATX_SIZE != 0 { - return syserror.EINVAL + return linuxerr.EINVAL } return mi.InodeAttrs.SetStat(ctx, vfsfs, creds, opts) } diff --git a/pkg/sentry/fsimpl/devpts/replica.go b/pkg/sentry/fsimpl/devpts/replica.go index 96d2054cb..81572b991 100644 --- a/pkg/sentry/fsimpl/devpts/replica.go +++ b/pkg/sentry/fsimpl/devpts/replica.go @@ -17,6 +17,7 @@ package devpts import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" @@ -92,7 +93,7 @@ func (ri *replicaInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vf // SetStat implements kernfs.Inode.SetStat func (ri *replicaInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { if opts.Stat.Mask&linux.STATX_SIZE != 0 { - return syserror.EINVAL + return linuxerr.EINVAL } return ri.InodeAttrs.SetStat(ctx, vfsfs, creds, opts) } diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index 2dbc6bfd5..5e8b464a0 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -47,6 +47,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fd", "//pkg/fspath", "//pkg/log", @@ -88,13 +89,13 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fspath", "//pkg/marshal/primitive", "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/ext/disklayout", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", - "//pkg/syserror", "//pkg/test/testutil", "//pkg/usermem", "@com_github_google_go_cmp//cmp:go_default_library", diff --git a/pkg/sentry/fsimpl/ext/block_map_file.go b/pkg/sentry/fsimpl/ext/block_map_file.go index 1165234f9..79719faed 100644 --- a/pkg/sentry/fsimpl/ext/block_map_file.go +++ b/pkg/sentry/fsimpl/ext/block_map_file.go @@ -18,6 +18,7 @@ import ( "io" "math" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/syserror" ) @@ -84,7 +85,7 @@ func (f *blockMapFile) ReadAt(dst []byte, off int64) (int, error) { } if off < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } offset := uint64(off) diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go index 512b70ede..cc067c20e 100644 --- a/pkg/sentry/fsimpl/ext/directory.go +++ b/pkg/sentry/fsimpl/ext/directory.go @@ -17,12 +17,12 @@ package ext import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" 
"gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) // directory represents a directory inode. It holds the childList in memory. @@ -218,7 +218,7 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba // Seek implements vfs.FileDescriptionImpl.Seek. func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { if whence != linux.SEEK_SET && whence != linux.SEEK_CUR { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } dir := fd.inode().impl.(*directory) @@ -234,7 +234,7 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in if offset < 0 { // lseek(2) specifies that EINVAL should be returned if the resulting offset // is negative. - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } n := int64(len(dir.childMap)) diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go index 38fb7962b..80854b501 100644 --- a/pkg/sentry/fsimpl/ext/ext.go +++ b/pkg/sentry/fsimpl/ext/ext.go @@ -22,12 +22,12 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) // Name is the name of this filesystem. @@ -133,13 +133,13 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // mount(2) specifies that EINVAL should be returned if the superblock is // invalid. fs.vfsfs.DecRef(ctx) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } // Refuse to mount if the filesystem is incompatible. if !isCompatible(fs.sb) { fs.vfsfs.DecRef(ctx) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } fs.bgs, err = readBlockGroups(dev, fs.sb) diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go index d9fd4590c..db712e71f 100644 --- a/pkg/sentry/fsimpl/ext/ext_test.go +++ b/pkg/sentry/fsimpl/ext/ext_test.go @@ -26,12 +26,12 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/test/testutil" "gvisor.dev/gvisor/pkg/usermem" ) @@ -173,7 +173,7 @@ func TestSeek(t *testing.T) { } // EINVAL should be returned if the resulting offset is negative. - if _, err := fd.Seek(ctx, -1, linux.SEEK_SET); err != syserror.EINVAL { + if _, err := fd.Seek(ctx, -1, linux.SEEK_SET); !linuxerr.Equals(linuxerr.EINVAL, err) { t.Errorf("expected error EINVAL but got %v", err) } @@ -187,7 +187,7 @@ func TestSeek(t *testing.T) { } // EINVAL should be returned if the resulting offset is negative. - if _, err := fd.Seek(ctx, -(size + 2), linux.SEEK_CUR); err != syserror.EINVAL { + if _, err := fd.Seek(ctx, -(size + 2), linux.SEEK_CUR); !linuxerr.Equals(linuxerr.EINVAL, err) { t.Errorf("expected error EINVAL but got %v", err) } @@ -204,7 +204,7 @@ func TestSeek(t *testing.T) { } // EINVAL should be returned if the resulting offset is negative. 
- if _, err := fd.Seek(ctx, -(size + 1), linux.SEEK_END); err != syserror.EINVAL { + if _, err := fd.Seek(ctx, -(size + 1), linux.SEEK_END); !linuxerr.Equals(linuxerr.EINVAL, err) { t.Errorf("expected error EINVAL but got %v", err) } } diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go index 778460107..f449bc8bd 100644 --- a/pkg/sentry/fsimpl/ext/extent_file.go +++ b/pkg/sentry/fsimpl/ext/extent_file.go @@ -18,6 +18,7 @@ import ( "io" "sort" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/syserror" ) @@ -65,7 +66,7 @@ func (f *extentFile) buildExtTree() error { if f.root.Header.NumEntries > 4 { // read(2) specifies that EINVAL should be returned if the file is unsuitable // for reading. - return syserror.EINVAL + return linuxerr.EINVAL } f.root.Entries = make([]disklayout.ExtentEntryPair, f.root.Header.NumEntries) @@ -145,7 +146,7 @@ func (f *extentFile) ReadAt(dst []byte, off int64) (int, error) { } if off < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if uint64(off) >= f.regFile.inode.diskInode.Size() { diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go index d4fc484a2..1d2eaa0d4 100644 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ b/pkg/sentry/fsimpl/ext/filesystem.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -344,7 +345,7 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st } symlink, ok := inode.impl.(*symlink) if !ok { - return "", syserror.EINVAL + return "", linuxerr.EINVAL } return symlink.target, nil } diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go index 4a555bf72..b3df2337f 100644 --- a/pkg/sentry/fsimpl/ext/inode.go +++ b/pkg/sentry/fsimpl/ext/inode.go @@ -19,6 +19,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -147,7 +148,7 @@ func newInode(fs *filesystem, inodeNum uint32) (*inode, error) { return &f.inode, nil default: // TODO(b/134676337): Return appropriate errors for sockets, pipes and devices. 
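The ext tests above stop comparing errors with == and use linuxerr.Equals instead. A standalone sketch of the same assertion style, with a local equalsErrno helper standing in for linuxerr.Equals and a toy seek function standing in for the FD under test:

package main

import (
	"errors"
	"fmt"

	"golang.org/x/sys/unix"
)

// equalsErrno unwraps err to an errno and compares it against want.
func equalsErrno(want unix.Errno, err error) bool {
	var got unix.Errno
	return errors.As(err, &got) && got == want
}

func seek(offset int64) (int64, error) {
	if offset < 0 {
		// lseek(2): EINVAL if the resulting offset would be negative.
		return 0, unix.EINVAL
	}
	return offset, nil
}

func main() {
	if _, err := seek(-1); !equalsErrno(unix.EINVAL, err) {
		fmt.Printf("expected EINVAL, got %v\n", err)
	} else {
		fmt.Println("got EINVAL as expected")
	}
}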
- return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } } diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go index 5ad9befcd..9a094716a 100644 --- a/pkg/sentry/fsimpl/ext/regular_file.go +++ b/pkg/sentry/fsimpl/ext/regular_file.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -139,10 +140,10 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) ( case linux.SEEK_END: offset += int64(fd.inode().diskInode.Size()) default: - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } fd.off = offset return offset, nil diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD index 3a4777fbe..871df5984 100644 --- a/pkg/sentry/fsimpl/fuse/BUILD +++ b/pkg/sentry/fsimpl/fuse/BUILD @@ -46,6 +46,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/log", "//pkg/marshal", @@ -76,6 +77,7 @@ go_test( library = ":fuse", deps = [ "//pkg/abi/linux", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/marshal", "//pkg/sentry/fsimpl/testutil", diff --git a/pkg/sentry/fsimpl/fuse/connection_test.go b/pkg/sentry/fsimpl/fuse/connection_test.go index 78ea6a31e..1fddd858e 100644 --- a/pkg/sentry/fsimpl/fuse/connection_test.go +++ b/pkg/sentry/fsimpl/fuse/connection_test.go @@ -19,9 +19,9 @@ import ( "testing" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/syserror" ) // TestConnectionInitBlock tests if initialization @@ -104,7 +104,7 @@ func TestConnectionAbort(t *testing.T) { // After abort, Call() should return directly with ENOTCONN. req := conn.NewRequest(creds, 0, 0, 0, testObj) _, err = conn.Call(task, req) - if err != syserror.ENOTCONN { + if !linuxerr.Equals(linuxerr.ENOTCONN, err) { t.Fatalf("Incorrect error code received for Call() after connection aborted") } diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go index 5d2bae14e..0d0eed543 100644 --- a/pkg/sentry/fsimpl/fuse/dev.go +++ b/pkg/sentry/fsimpl/fuse/dev.go @@ -18,6 +18,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -149,7 +150,7 @@ func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.R // If the read buffer is too small, error out. if dst.NumBytes() < int64(minBuffSize) { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } fd.mu.Lock() @@ -293,7 +294,7 @@ func (fd *DeviceFD) writeLocked(ctx context.Context, src usermem.IOSequence, opt // Assert that the header isn't read into the writeBuf yet. if fd.writeCursor >= hdrLen { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // We don't have the full common response header yet. @@ -322,7 +323,7 @@ func (fd *DeviceFD) writeLocked(ctx context.Context, src usermem.IOSequence, opt if !ok { // Server sent us a response for a request we never sent, // or for which we already received a reply (e.g. aborted), an unlikely event. 
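The Seek implementations touched above all follow the same lseek(2) shape: resolve the whence base, then reject a negative result with EINVAL. A compact standalone sketch, using x/sys/unix constants so it builds outside the sentry:

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

func seek(cur, size, offset int64, whence int) (int64, error) {
	switch whence {
	case unix.SEEK_SET:
		// offset is used as-is.
	case unix.SEEK_CUR:
		offset += cur
	case unix.SEEK_END:
		offset += size
	default:
		return 0, unix.EINVAL
	}
	if offset < 0 {
		return 0, unix.EINVAL
	}
	return offset, nil
}

func main() {
	fmt.Println(seek(10, 100, -5, unix.SEEK_CUR)) // 5 <nil>
	fmt.Println(seek(10, 100, -5, unix.SEEK_SET)) // 0 EINVAL
}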
- return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } delete(fd.completions, hdr.Unique) @@ -434,7 +435,7 @@ func (fd *DeviceFD) sendError(ctx context.Context, errno int32, unique linux.FUS if !ok { // A response for a request we never sent, // or for which we already received a reply (e.g. aborted). - return syserror.EINVAL + return linuxerr.EINVAL } delete(fd.completions, respHdr.Unique) diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index 167c899e2..be5bcd6af 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" @@ -121,30 +122,30 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt deviceDescriptorStr, ok := mopts["fd"] if !ok { ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option fd missing") - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } delete(mopts, "fd") deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */) if err != nil { ctx.Debugf("fusefs.FilesystemType.GetFilesystem: invalid fd: %q (%v)", deviceDescriptorStr, err) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } kernelTask := kernel.TaskFromContext(ctx) if kernelTask == nil { log.Warningf("%s.GetFilesystem: couldn't get kernel task from context", fsType.Name()) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } fuseFDGeneric := kernelTask.GetFileVFS2(int32(deviceDescriptor)) if fuseFDGeneric == nil { - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } defer fuseFDGeneric.DecRef(ctx) fuseFD, ok := fuseFDGeneric.Impl().(*DeviceFD) if !ok { log.Warningf("%s.GetFilesystem: device FD is %T, not a FUSE device", fsType.Name, fuseFDGeneric) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } // Parse and set all the other supported FUSE mount options. 
@@ -154,17 +155,17 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt uid, err := strconv.ParseUint(uidStr, 10, 32) if err != nil { log.Warningf("%s.GetFilesystem: invalid user_id: user_id=%s", fsType.Name(), uidStr) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } kuid := creds.UserNamespace.MapToKUID(auth.UID(uid)) if !kuid.Ok() { ctx.Warningf("fusefs.FilesystemType.GetFilesystem: unmapped uid: %d", uid) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } fsopts.uid = kuid } else { ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option user_id missing") - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } if gidStr, ok := mopts["group_id"]; ok { @@ -172,17 +173,17 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt gid, err := strconv.ParseUint(gidStr, 10, 32) if err != nil { log.Warningf("%s.GetFilesystem: invalid group_id: group_id=%s", fsType.Name(), gidStr) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } kgid := creds.UserNamespace.MapToKGID(auth.GID(gid)) if !kgid.Ok() { ctx.Warningf("fusefs.FilesystemType.GetFilesystem: unmapped gid: %d", gid) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } fsopts.gid = kgid } else { ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option group_id missing") - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } if modeStr, ok := mopts["rootmode"]; ok { @@ -190,12 +191,12 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt mode, err := strconv.ParseUint(modeStr, 8, 32) if err != nil { log.Warningf("%s.GetFilesystem: invalid mode: %q", fsType.Name(), modeStr) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } fsopts.rootMode = linux.FileMode(mode) } else { ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option rootmode missing") - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } // Set the maxInFlightRequests option. @@ -206,7 +207,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt maxRead, err := strconv.ParseUint(maxReadStr, 10, 32) if err != nil { log.Warningf("%s.GetFilesystem: invalid max_read: max_read=%s", fsType.Name(), maxReadStr) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } if maxRead < fuseMinMaxRead { maxRead = fuseMinMaxRead @@ -229,7 +230,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // Check for unparsed options. if len(mopts) != 0 { log.Warningf("%s.GetFilesystem: unsupported or unknown options: %v", fsType.Name(), mopts) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } // Create a new FUSE filesystem. @@ -258,7 +259,7 @@ func newFUSEFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, fsTyp conn, err := newFUSEConnection(ctx, fuseFD, opts) if err != nil { log.Warningf("fuse.NewFUSEFilesystem: NewFUSEConnection failed with error: %v", err) - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } fs := &filesystem{ @@ -418,7 +419,7 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentr kernelTask := kernel.TaskFromContext(ctx) if kernelTask == nil { log.Warningf("fusefs.Inode.Open: couldn't get kernel task from context") - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } // Build the request. 
@@ -440,7 +441,7 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentr if err != nil { return nil, err } - if err := res.Error(); err == syserror.ENOSYS && !isDir { + if err := res.Error(); linuxerr.Equals(linuxerr.ENOSYS, err) && !isDir { i.fs.conn.noOpen = true } else if err != nil { return nil, err @@ -512,7 +513,7 @@ func (i *inode) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) kernelTask := kernel.TaskFromContext(ctx) if kernelTask == nil { log.Warningf("fusefs.Inode.NewFile: couldn't get kernel task from context", i.nodeID) - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } in := linux.FUSECreateIn{ CreateMeta: linux.FUSECreateMeta{ @@ -552,7 +553,7 @@ func (i *inode) Unlink(ctx context.Context, name string, child kernfs.Inode) err kernelTask := kernel.TaskFromContext(ctx) if kernelTask == nil { log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID) - return syserror.EINVAL + return linuxerr.EINVAL } in := linux.FUSEUnlinkIn{Name: name} req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in) @@ -596,7 +597,7 @@ func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMo kernelTask := kernel.TaskFromContext(ctx) if kernelTask == nil { log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID) - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, payload) res, err := i.fs.conn.Call(kernelTask, req) @@ -626,13 +627,13 @@ func (i *inode) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, // Readlink implements kernfs.Inode.Readlink. func (i *inode) Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) { if i.Mode().FileType()&linux.S_IFLNK == 0 { - return "", syserror.EINVAL + return "", linuxerr.EINVAL } if len(i.link) == 0 { kernelTask := kernel.TaskFromContext(ctx) if kernelTask == nil { log.Warningf("fusefs.Inode.Readlink: couldn't get kernel task from context") - return "", syserror.EINVAL + return "", linuxerr.EINVAL } req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_READLINK, &linux.FUSEEmptyIn{}) res, err := i.fs.conn.Call(kernelTask, req) @@ -728,7 +729,7 @@ func (i *inode) getAttr(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOp task := kernel.TaskFromContext(ctx) if task == nil { log.Warningf("couldn't get kernel task from context") - return linux.FUSEAttr{}, syserror.EINVAL + return linux.FUSEAttr{}, linuxerr.EINVAL } creds := auth.CredentialsFromContext(ctx) @@ -833,7 +834,7 @@ func (i *inode) setAttr(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre task := kernel.TaskFromContext(ctx) if task == nil { log.Warningf("couldn't get kernel task from context") - return syserror.EINVAL + return linuxerr.EINVAL } // We should retain the original file type when assigning new mode. 
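In inode.Open above, an ENOSYS reply to FUSE_OPEN sets conn.noOpen so later opens skip the round trip instead of failing. A hedged, generic sketch of that probe-once-and-degrade pattern; the conn type here is a stand-in, not the sentry's connection:

package main

import (
	"errors"
	"fmt"

	"golang.org/x/sys/unix"
)

// conn is a hypothetical stand-in for the FUSE connection state.
type conn struct {
	noOpen bool // set once the server reports FUSE_OPEN is unimplemented
}

func (c *conn) callOpen() error {
	// Pretend the server replies ENOSYS the first time.
	return unix.ENOSYS
}

func (c *conn) open() error {
	if c.noOpen {
		return nil // skip the round trip entirely
	}
	err := c.callOpen()
	if errors.Is(err, unix.ENOSYS) {
		// The server does not implement open; remember that and carry on.
		c.noOpen = true
		return nil
	}
	return err
}

func main() {
	c := &conn{}
	fmt.Println(c.open(), c.noOpen) // <nil> true
	fmt.Println(c.open(), c.noOpen) // <nil> true (no second probe)
}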
diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go index 66ea889f9..35d0ab6f4 100644 --- a/pkg/sentry/fsimpl/fuse/read_write.go +++ b/pkg/sentry/fsimpl/fuse/read_write.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -39,7 +40,7 @@ func (fs *filesystem) ReadInPages(ctx context.Context, fd *regularFileFD, off ui t := kernel.TaskFromContext(ctx) if t == nil { log.Warningf("fusefs.Read: couldn't get kernel task from context") - return nil, 0, syserror.EINVAL + return nil, 0, linuxerr.EINVAL } // Round up to a multiple of page size. @@ -155,7 +156,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, t := kernel.TaskFromContext(ctx) if t == nil { log.Warningf("fusefs.Read: couldn't get kernel task from context") - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // One request cannnot exceed either maxWrite or maxPages. diff --git a/pkg/sentry/fsimpl/fuse/regular_file.go b/pkg/sentry/fsimpl/fuse/regular_file.go index 5bdd096c3..a0802cd32 100644 --- a/pkg/sentry/fsimpl/fuse/regular_file.go +++ b/pkg/sentry/fsimpl/fuse/regular_file.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -39,7 +40,7 @@ type regularFileFD struct { // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Check that flags are supported. @@ -56,7 +57,7 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs } else if size > math.MaxUint32 { // FUSE only supports uint32 for size. // Overflow. - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // TODO(gvisor.dev/issue/3678): Add direct IO support. @@ -143,7 +144,7 @@ func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts // final offset should be ignored by PWrite. func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) { if offset < 0 { - return 0, offset, syserror.EINVAL + return 0, offset, linuxerr.EINVAL } // Check that flags are supported. @@ -171,11 +172,11 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off if srclen > math.MaxUint32 { // FUSE only supports uint32 for size. // Overflow. - return 0, offset, syserror.EINVAL + return 0, offset, linuxerr.EINVAL } if end := offset + srclen; end < offset { // Overflow. 
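The FUSE read/write paths above bound request sizes to what a uint32 can carry on the wire and reject offsets whose end position would overflow. A small sketch of those two guards (checkWrite is hypothetical):

package main

import (
	"fmt"
	"math"

	"golang.org/x/sys/unix"
)

// checkWrite validates a write of srclen bytes at offset, mirroring the
// overflow guards in the FUSE pwrite path above (simplified).
func checkWrite(offset, srclen int64) error {
	if offset < 0 {
		return unix.EINVAL
	}
	if srclen > math.MaxUint32 {
		// FUSE carries sizes as uint32 on the wire.
		return unix.EINVAL
	}
	if end := offset + srclen; end < offset {
		// Signed overflow of the end position.
		return unix.EINVAL
	}
	return nil
}

func main() {
	fmt.Println(checkWrite(0, 4096))                    // <nil>
	fmt.Println(checkWrite(math.MaxInt64, 4096))        // EINVAL (overflow)
	fmt.Println(checkWrite(0, int64(math.MaxUint32)+1)) // EINVAL (too large)
}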
- return 0, offset, syserror.EINVAL + return 0, offset, linuxerr.EINVAL } srclen, err = vfs.CheckLimit(ctx, offset, srclen) diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index 368272f12..752060044 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -49,6 +49,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fd", "//pkg/fdnotifier", "//pkg/fspath", diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index 177e42649..5c48a9fee 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/refsvfs2" @@ -28,7 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) func (d *dentry) isDir() bool { @@ -297,7 +297,7 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in switch whence { case linux.SEEK_SET: if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if offset == 0 { // Ensure that the next call to fd.IterDirents() calls @@ -309,13 +309,13 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in case linux.SEEK_CUR: offset += fd.off if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Don't clear fd.dirents in this case, even if offset == 0. fd.off = offset return fd.off, nil default: - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } } diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 91ec4a142..067b7aac1 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/fsimpl/host" @@ -255,7 +256,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name) if err != nil { - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { parent.cacheNegativeLookupLocked(name) } return nil, err @@ -382,7 +383,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir return syserror.EEXIST } checkExistence := func() error { - if child, err := fs.getChildLocked(ctx, parent, name, &ds); err != nil && err != syserror.ENOENT { + if child, err := fs.getChildLocked(ctx, parent, name, &ds); err != nil && !linuxerr.Equals(linuxerr.ENOENT, err) { return err } else if child != nil { return syserror.EEXIST @@ -469,7 +470,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b name := rp.Component() if dir { if name == "." { - return syserror.EINVAL + return linuxerr.EINVAL } if name == ".." 
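getChildLocked above starts caching a negative lookup whenever the remote walk returns ENOENT, so repeated lookups of a missing name skip the 9P round trip. A standalone sketch of that idea with a plain map standing in for the dentry cache:

package main

import (
	"errors"
	"fmt"

	"golang.org/x/sys/unix"
)

type dir struct {
	negative map[string]bool   // names known not to exist
	remote   map[string]string // the "remote" view of the directory
	trips    int               // round trips to the remote side, for illustration
}

func (d *dir) lookup(name string) (string, error) {
	if d.negative[name] {
		return "", unix.ENOENT // served from the negative cache
	}
	d.trips++
	v, ok := d.remote[name]
	if !ok {
		d.negative[name] = true // cache the miss
		return "", unix.ENOENT
	}
	return v, nil
}

func main() {
	d := &dir{negative: map[string]bool{}, remote: map[string]string{"a": "inode-1"}}
	_, err1 := d.lookup("missing")
	_, err2 := d.lookup("missing")
	fmt.Println(errors.Is(err1, unix.ENOENT), errors.Is(err2, unix.ENOENT), d.trips) // true true 1
}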
{ return syserror.ENOTEMPTY @@ -715,7 +716,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v mode |= linux.S_ISGID } if _, err := parent.file.mkdir(ctx, name, p9.FileMode(mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid)); err != nil { - if !opts.ForSyntheticMountpoint || err == syserror.EEXIST { + if !opts.ForSyntheticMountpoint || linuxerr.Equals(linuxerr.EEXIST, err) { return err } ctx.Infof("Failed to create remote directory %q: %v; falling back to synthetic directory", name, err) @@ -752,7 +753,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, ds **[]*dentry) error { creds := rp.Credentials() _, err := parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)) - if err != syserror.EPERM { + if !linuxerr.Equals(linuxerr.EPERM, err) { return err } @@ -765,7 +766,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v case err == nil: // Step succeeded, another file exists. return syserror.EEXIST - case err != syserror.ENOENT: + case !linuxerr.Equals(linuxerr.ENOENT, err): // Unexpected error. return err } @@ -862,7 +863,7 @@ afterTrailingSymlink: // Determine whether or not we need to create a file. parent.dirMu.Lock() child, _, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) - if err == syserror.ENOENT && mayCreate { + if linuxerr.Equals(linuxerr.ENOENT, err) && mayCreate { if parent.isSynthetic() { parent.dirMu.Unlock() return nil, syserror.EPERM @@ -942,7 +943,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open return nil, syserror.EISDIR } if opts.Flags&linux.O_DIRECT != 0 { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } if !d.isSynthetic() { if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, false /* write */, false /* trunc */); err != nil { @@ -998,7 +999,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open func (d *dentry) openSocketByConnecting(ctx context.Context, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { if opts.Flags&linux.O_DIRECT != 0 { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } fdObj, err := d.file.connect(ctx, p9.AnonymousSocket) if err != nil { @@ -1019,7 +1020,7 @@ func (d *dentry) openSocketByConnecting(ctx context.Context, opts *vfs.OpenOptio func (d *dentry) openSpecialFile(ctx context.Context, mnt *vfs.Mount, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { ats := vfs.AccessTypesForOpenFlags(opts) if opts.Flags&linux.O_DIRECT != 0 { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } // We assume that the server silently inserts O_NONBLOCK in the open flags // for all named pipes (because all existing gofers do this). @@ -1033,7 +1034,7 @@ func (d *dentry) openSpecialFile(ctx context.Context, mnt *vfs.Mount, opts *vfs. retry: h, err := openHandle(ctx, d.file, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0) if err != nil { - if isBlockingOpenOfNamedPipe && ats == vfs.MayWrite && err == syserror.ENXIO { + if isBlockingOpenOfNamedPipe && ats == vfs.MayWrite && linuxerr.Equals(linuxerr.ENXIO, err) { // An attempt to open a named pipe with O_WRONLY|O_NONBLOCK fails // with ENXIO if opening the same named pipe with O_WRONLY would // block because there are no readers of the pipe. 
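openSpecialFile above explains that a non-blocking O_WRONLY open of a named pipe fails with ENXIO while there is no reader, and retries to emulate a blocking open. A host-level sketch of that loop using x/sys/unix directly; the path and retry interval are made up:

package main

import (
	"errors"
	"fmt"
	"time"

	"golang.org/x/sys/unix"
)

// openPipeWriter emulates a blocking O_WRONLY open of a FIFO on top of
// non-blocking opens: ENXIO means "no reader yet", so wait and retry.
func openPipeWriter(path string) (int, error) {
	for {
		fd, err := unix.Open(path, unix.O_WRONLY|unix.O_NONBLOCK, 0)
		if err == nil {
			return fd, nil
		}
		if !errors.Is(err, unix.ENXIO) {
			return -1, err
		}
		time.Sleep(100 * time.Millisecond) // a reader may show up later
	}
}

func main() {
	// Assumes /tmp/fifo exists (mkfifo /tmp/fifo); purely illustrative.
	fd, err := openPipeWriter("/tmp/fifo")
	fmt.Println(fd, err)
	if err == nil {
		unix.Close(fd)
	}
}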
@@ -1187,18 +1188,14 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st return "", err } if !d.isSymlink() { - return "", syserror.EINVAL + return "", linuxerr.EINVAL } return d.readlink(ctx, rp.Mount()) } // RenameAt implements vfs.FilesystemImpl.RenameAt. func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error { - if opts.Flags != 0 { - // Requires 9P support. - return syserror.EINVAL - } - + // Resolve newParent first to verify that it's on this Mount. var ds *[]*dentry fs.renameMu.Lock() defer fs.renameMuUnlockAndCheckCaching(ctx, &ds) @@ -1206,8 +1203,21 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if err != nil { return err } + + if opts.Flags&^linux.RENAME_NOREPLACE != 0 { + return linuxerr.EINVAL + } + if fs.opts.interop == InteropModeShared && opts.Flags&linux.RENAME_NOREPLACE != 0 { + // Requires 9P support to synchronize with other remote filesystem + // users. + return linuxerr.EINVAL + } + newName := rp.Component() if newName == "." || newName == ".." { + if opts.Flags&linux.RENAME_NOREPLACE != 0 { + return syserror.EEXIST + } return syserror.EBUSY } mnt := rp.Mount() @@ -1251,7 +1261,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } if renamed.isDir() { if renamed == newParent || genericIsAncestorDentry(renamed, newParent) { - return syserror.EINVAL + return linuxerr.EINVAL } if oldParent != newParent { if err := renamed.checkPermissions(creds, vfs.MayWrite); err != nil { @@ -1275,11 +1285,14 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa return syserror.ENOENT } replaced, err := fs.getChildLocked(ctx, newParent, newName, &ds) - if err != nil && err != syserror.ENOENT { + if err != nil && !linuxerr.Equals(linuxerr.ENOENT, err) { return err } var replacedVFSD *vfs.Dentry if replaced != nil { + if opts.Flags&linux.RENAME_NOREPLACE != 0 { + return syserror.EEXIST + } replacedVFSD = &replaced.vfsd if replaced.isDir() { if !renamed.isDir() { diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 21692d2ac..c7ebd435c 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -46,6 +46,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" @@ -318,7 +319,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt mfp := pgalloc.MemoryFileProviderFromContext(ctx) if mfp == nil { ctx.Warningf("gofer.FilesystemType.GetFilesystem: context does not provide a pgalloc.MemoryFileProvider") - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } mopts := vfs.GenericParseMountOptions(opts.Data) @@ -354,7 +355,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fsopts.interop = InteropModeShared default: ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid cache policy: %s=%s", moptCache, cache) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } } @@ -365,7 +366,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt dfltuid, err := strconv.ParseUint(dfltuidstr, 10, 32) if err != nil { ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: %s=%s", moptDfltUID, dfltuidstr) - 
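The gofer RenameAt change above accepts RENAME_NOREPLACE, rejects every other flag, refuses NOREPLACE under shared interop where it would need server-side 9P support, and reports EEXIST when the target exists. A compact sketch of that validation order with local stand-ins for the filesystem state:

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

// checkRenameFlags mirrors the ordering above: unknown flags are EINVAL,
// NOREPLACE needs exclusive knowledge of the remote tree, and a NOREPLACE
// rename onto an existing name is EEXIST. All inputs are stand-ins.
func checkRenameFlags(flags uint32, sharedInterop, targetExists bool) error {
	if flags&^unix.RENAME_NOREPLACE != 0 {
		return unix.EINVAL
	}
	noReplace := flags&unix.RENAME_NOREPLACE != 0
	if noReplace && sharedInterop {
		return unix.EINVAL // would need server-side support to be race free
	}
	if noReplace && targetExists {
		return unix.EEXIST
	}
	return nil
}

func main() {
	fmt.Println(checkRenameFlags(unix.RENAME_NOREPLACE, false, false)) // <nil>
	fmt.Println(checkRenameFlags(unix.RENAME_NOREPLACE, false, true))  // EEXIST
	fmt.Println(checkRenameFlags(unix.RENAME_EXCHANGE, false, false))  // EINVAL
}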
return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } // In Linux, dfltuid is interpreted as a UID and is converted to a KUID // in the caller's user namespace, but goferfs isn't @@ -378,7 +379,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt dfltgid, err := strconv.ParseUint(dfltgidstr, 10, 32) if err != nil { ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: %s=%s", moptDfltGID, dfltgidstr) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } fsopts.dfltgid = auth.KGID(dfltgid) } @@ -390,7 +391,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt msize, err := strconv.ParseUint(msizestr, 10, 32) if err != nil { ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid message size: %s=%s", moptMsize, msizestr) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } fsopts.msize = uint32(msize) } @@ -409,7 +410,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt maxCachedDentries, err := strconv.ParseUint(str, 10, 64) if err != nil { ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid dentry cache limit: %s=%s", moptDentryCacheLimit, str) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } fsopts.maxCachedDentries = maxCachedDentries } @@ -433,14 +434,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // Check for unparsed options. if len(mopts) != 0 { ctx.Warningf("gofer.FilesystemType.GetFilesystem: unknown options: %v", mopts) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } // Handle internal options. iopts, ok := opts.InternalData.(InternalFilesystemOptions) if opts.InternalData != nil && !ok { ctx.Warningf("gofer.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted gofer.InternalFilesystemOptions", opts.InternalData) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } // If !ok, iopts being the zero value is correct. 
@@ -503,7 +504,7 @@ func getFDFromMountOptionsMap(ctx context.Context, mopts map[string]string) (int trans, ok := mopts[moptTransport] if !ok || trans != transportModeFD { ctx.Warningf("gofer.getFDFromMountOptionsMap: transport must be specified as '%s=%s'", moptTransport, transportModeFD) - return -1, syserror.EINVAL + return -1, linuxerr.EINVAL } delete(mopts, moptTransport) @@ -511,28 +512,28 @@ func getFDFromMountOptionsMap(ctx context.Context, mopts map[string]string) (int rfdstr, ok := mopts[moptReadFD] if !ok { ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD must be specified as '%s=<file descriptor>'", moptReadFD) - return -1, syserror.EINVAL + return -1, linuxerr.EINVAL } delete(mopts, moptReadFD) rfd, err := strconv.Atoi(rfdstr) if err != nil { ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid read FD: %s=%s", moptReadFD, rfdstr) - return -1, syserror.EINVAL + return -1, linuxerr.EINVAL } wfdstr, ok := mopts[moptWriteFD] if !ok { ctx.Warningf("gofer.getFDFromMountOptionsMap: write FD must be specified as '%s=<file descriptor>'", moptWriteFD) - return -1, syserror.EINVAL + return -1, linuxerr.EINVAL } delete(mopts, moptWriteFD) wfd, err := strconv.Atoi(wfdstr) if err != nil { ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid write FD: %s=%s", moptWriteFD, wfdstr) - return -1, syserror.EINVAL + return -1, linuxerr.EINVAL } if rfd != wfd { ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD (%d) and write FD (%d) must be equal", rfd, wfd) - return -1, syserror.EINVAL + return -1, linuxerr.EINVAL } return rfd, nil } @@ -1110,7 +1111,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs case linux.S_IFDIR: return syserror.EISDIR default: - return syserror.EINVAL + return linuxerr.EINVAL } } @@ -1282,9 +1283,12 @@ func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) } func (d *dentry) checkXattrPermissions(creds *auth.Credentials, name string, ats vfs.AccessTypes) error { - // We only support xattrs prefixed with "user." (see b/148380782). Currently, - // there is no need to expose any other xattrs through a gofer. - if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { + // Deny access to the "security" and "system" namespaces since applications + // may expect these to affect kernel behavior in unimplemented ways + // (b/148380782). Allow all other extended attributes to be passed through + // to the remote filesystem. This is inconsistent with Linux's 9p client, + // but consistent with other filesystems (e.g. FUSE). + if strings.HasPrefix(name, linux.XATTR_SECURITY_PREFIX) || strings.HasPrefix(name, linux.XATTR_SYSTEM_PREFIX) { return syserror.EOPNOTSUPP } mode := linux.FileMode(atomic.LoadUint32(&d.mode)) @@ -1684,7 +1688,7 @@ func (d *dentry) setDeleted() { } func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) { - if d.file.isNil() || !d.userXattrSupported() { + if d.file.isNil() { return nil, nil } xattrMap, err := d.file.listXattr(ctx, size) @@ -1693,10 +1697,7 @@ func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size ui } xattrs := make([]string, 0, len(xattrMap)) for x := range xattrMap { - // We only support xattrs in the user.* namespace. 
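checkXattrPermissions above switches from an allow-list (only user.*) to a deny-list that blocks just the security.* and system.* namespaces and passes everything else to the remote filesystem. A standalone sketch of the prefix check, with the prefixes spelled out rather than taken from the sentry's linux package:

package main

import (
	"fmt"
	"strings"

	"golang.org/x/sys/unix"
)

const (
	xattrSecurityPrefix = "security."
	xattrSystemPrefix   = "system."
)

// checkXattrName denies namespaces that applications may expect to change
// kernel behavior, and lets everything else through.
func checkXattrName(name string) error {
	if strings.HasPrefix(name, xattrSecurityPrefix) || strings.HasPrefix(name, xattrSystemPrefix) {
		return unix.EOPNOTSUPP
	}
	return nil
}

func main() {
	fmt.Println(checkXattrName("user.mytag"))             // <nil>
	fmt.Println(checkXattrName("trusted.overlay.opaque")) // <nil>
	fmt.Println(checkXattrName("security.selinux"))       // EOPNOTSUPP
}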
- if strings.HasPrefix(x, linux.XATTR_USER_PREFIX) { - xattrs = append(xattrs, x) - } + xattrs = append(xattrs, x) } return xattrs, nil } @@ -1731,13 +1732,6 @@ func (d *dentry) removeXattr(ctx context.Context, creds *auth.Credentials, name return d.file.removeXattr(ctx, name) } -// Extended attributes in the user.* namespace are only supported for regular -// files and directories. -func (d *dentry) userXattrSupported() bool { - filetype := linux.FileMode(atomic.LoadUint32(&d.mode)).FileType() - return filetype == linux.ModeRegular || filetype == linux.ModeDirectory -} - // Preconditions: // * !d.isSynthetic(). // * d.isRegularFile() || d.isDir(). @@ -1770,7 +1764,7 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool openReadable := !d.readFile.isNil() || read openWritable := !d.writeFile.isNil() || write h, err := openHandle(ctx, d.file, openReadable, openWritable, trunc) - if err == syserror.EACCES && (openReadable != read || openWritable != write) { + if linuxerr.Equals(linuxerr.EACCES, err) && (openReadable != read || openWritable != write) { // It may not be possible to use a single handle for both // reading and writing, since permissions on the file may have // changed to e.g. disallow reading after previously being diff --git a/pkg/sentry/fsimpl/gofer/host_named_pipe.go b/pkg/sentry/fsimpl/gofer/host_named_pipe.go index c7bf10007..398288ee3 100644 --- a/pkg/sentry/fsimpl/gofer/host_named_pipe.go +++ b/pkg/sentry/fsimpl/gofer/host_named_pipe.go @@ -21,6 +21,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/syserror" ) @@ -78,7 +79,7 @@ func nonblockingPipeHasWriter(fd int32) (bool, error) { defer tempPipeMu.Unlock() // Copy 1 byte from fd into the temporary pipe. n, err := unix.Tee(int(fd), tempPipeWriteFD, 1, unix.SPLICE_F_NONBLOCK) - if err == syserror.EAGAIN { + if linuxerr.Equals(linuxerr.EAGAIN, err) { // The pipe represented by fd is empty, but has a writer. return true, nil } diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index 0a954c138..89eab04cd 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" @@ -60,7 +61,6 @@ func newRegularFileFD(mnt *vfs.Mount, d *dentry, flags uint32) (*regularFileFD, return nil, err } if fd.vfsfd.IsWritable() && (atomic.LoadUint32(&d.mode)&0111 != 0) { - fsmetric.GoferOpensWX.Increment() metric.SuspiciousOperationsMetric.Increment("opened_write_execute_file") } if atomic.LoadInt32(&d.mmapFD) >= 0 { @@ -125,7 +125,7 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs }() if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Check that flags are supported. @@ -195,7 +195,7 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off // offset should be ignored by PWrite. func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) { if offset < 0 { - return 0, offset, syserror.EINVAL + return 0, offset, linuxerr.EINVAL } // Check that flags are supported. 
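nonblockingPipeHasWriter above tells an empty-but-writable pipe apart from one with no writer by Tee-ing a byte into a scratch pipe and checking for EAGAIN. A hedged host-level sketch of that probe; the scratch pipe is created inline here instead of being cached:

package main

import (
	"errors"
	"fmt"

	"golang.org/x/sys/unix"
)

// pipeHasWriter reports whether the read end fd still has data or a writer,
// without consuming anything: Tee copies instead of moving bytes.
func pipeHasWriter(fd int) (bool, error) {
	var scratch [2]int
	if err := unix.Pipe(scratch[:]); err != nil {
		return false, err
	}
	defer unix.Close(scratch[0])
	defer unix.Close(scratch[1])

	n, err := unix.Tee(fd, scratch[1], 1, unix.SPLICE_F_NONBLOCK)
	switch {
	case errors.Is(err, unix.EAGAIN):
		return true, nil // empty right now, but a writer is still attached
	case err != nil:
		return false, err
	default:
		return n > 0, nil // data is available
	}
}

func main() {
	var p [2]int
	if err := unix.Pipe(p[:]); err != nil {
		panic(err)
	}
	unix.Write(p[1], []byte("x"))
	fmt.Println(pipeHasWriter(p[0])) // true <nil>
}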
@@ -298,7 +298,7 @@ func (fd *regularFileFD) writeCache(ctx context.Context, d *dentry, offset int64 pgstart := hostarch.PageRoundDown(uint64(offset)) pgend, ok := hostarch.PageRoundUp(uint64(offset + src.NumBytes())) if !ok { - return syserror.EINVAL + return linuxerr.EINVAL } mr := memmap.MappableRange{pgstart, pgend} var freed []memmap.FileRange @@ -663,10 +663,10 @@ func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int6 offset = size } default: - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } return offset, nil } @@ -679,28 +679,28 @@ func (fd *regularFileFD) Sync(ctx context.Context) error { // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { d := fd.dentry() - switch d.fs.opts.interop { - case InteropModeExclusive: - // Any mapping is fine. - case InteropModeWritethrough: - // Shared writable mappings require a host FD, since otherwise we can't - // synchronously flush memory-mapped writes to the remote file. - if opts.Private || !opts.MaxPerms.Write { - break - } - fallthrough - case InteropModeShared: - // All mappings require a host FD to be coherent with other filesystem - // users. - if d.fs.opts.forcePageCache { - // Whether or not we have a host FD, we're not allowed to use it. - return syserror.ENODEV - } - if atomic.LoadInt32(&d.mmapFD) < 0 { - return syserror.ENODEV + // Force sentry page caching at your own risk. + if !d.fs.opts.forcePageCache { + switch d.fs.opts.interop { + case InteropModeExclusive: + // Any mapping is fine. + case InteropModeWritethrough: + // Shared writable mappings require a host FD, since otherwise we + // can't synchronously flush memory-mapped writes to the remote + // file. + if opts.Private || !opts.MaxPerms.Write { + break + } + fallthrough + case InteropModeShared: + // All mappings require a host FD to be coherent with other + // filesystem users. + if atomic.LoadInt32(&d.mmapFD) < 0 { + return syserror.ENODEV + } + default: + panic(fmt.Sprintf("unknown InteropMode %v", d.fs.opts.interop)) } - default: - panic(fmt.Sprintf("unknown InteropMode %v", d.fs.opts.interop)) } // After this point, d may be used as a memmap.Mappable. d.pf.hostFileMapperInitOnce.Do(d.pf.hostFileMapper.Init) @@ -709,12 +709,12 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt } func (d *dentry) mayCachePages() bool { - if d.fs.opts.interop == InteropModeShared { - return false - } if d.fs.opts.forcePageCache { return true } + if d.fs.opts.interop == InteropModeShared { + return false + } return atomic.LoadInt32(&d.mmapFD) >= 0 } diff --git a/pkg/sentry/fsimpl/gofer/save_restore.go b/pkg/sentry/fsimpl/gofer/save_restore.go index 83e841a51..e67422a2f 100644 --- a/pkg/sentry/fsimpl/gofer/save_restore.go +++ b/pkg/sentry/fsimpl/gofer/save_restore.go @@ -21,13 +21,13 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) type saveRestoreContextID int @@ -92,7 +92,7 @@ func (fd *specialFileFD) savePipeData(ctx context.Context) error { fd.buf = append(fd.buf, buf[:n]...) 
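The reworked ConfigureMMap and mayCachePages above make forcePageCache the first thing consulted, then fall back to the interop mode and the availability of a host FD. A tiny decision-table sketch of that ordering (types and names are illustrative, not the gofer's):

package main

import "fmt"

// interopMode mirrors the gofer option referenced above; values are stand-ins.
type interopMode int

const (
	exclusive interopMode = iota
	writethrough
	shared
)

// mayCachePages reflects the reordered logic: forcing sentry page caching
// wins over everything, shared interop otherwise disables caching, and the
// remaining modes cache only when a host FD is available for mmap.
func mayCachePages(forcePageCache bool, mode interopMode, haveMmapFD bool) bool {
	if forcePageCache {
		return true
	}
	if mode == shared {
		return false
	}
	return haveMmapFD
}

func main() {
	fmt.Println(mayCachePages(true, shared, false))     // true
	fmt.Println(mayCachePages(false, shared, true))     // false
	fmt.Println(mayCachePages(false, exclusive, false)) // false
}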
} if err != nil { - if err == io.EOF || err == syserror.EAGAIN { + if err == io.EOF || linuxerr.Equals(linuxerr.EAGAIN, err) { break } return err diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index dc019ebd5..2a922d120 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -20,6 +20,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/p9" @@ -101,7 +102,6 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, flags uint32) (*speci d.fs.specialFileFDs[fd] = struct{}{} d.fs.syncMu.Unlock() if fd.vfsfd.IsWritable() && (atomic.LoadUint32(&d.mode)&0111 != 0) { - fsmetric.GoferOpensWX.Increment() metric.SuspiciousOperationsMetric.Increment("opened_write_execute_file") } if h.fd >= 0 { @@ -184,7 +184,7 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs }() if fd.seekable && offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Check that flags are supported. @@ -229,7 +229,7 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs // Just buffer the read instead. buf := make([]byte, dst.NumBytes()) n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset)) - if err == syserror.EAGAIN { + if linuxerr.Equals(linuxerr.EAGAIN, err) { err = syserror.ErrWouldBlock } if n == 0 { @@ -264,7 +264,7 @@ func (fd *specialFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off // offset should be ignored by PWrite. func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) { if fd.seekable && offset < 0 { - return 0, offset, syserror.EINVAL + return 0, offset, linuxerr.EINVAL } // Check that flags are supported. @@ -317,7 +317,7 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off return 0, offset, copyErr } n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:copied])), uint64(offset)) - if err == syserror.EAGAIN { + if linuxerr.Equals(linuxerr.EAGAIN, err) { err = syserror.ErrWouldBlock } // Update offset if the offset is valid. diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index b94dfeb7f..476545d00 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -45,10 +45,10 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fdnotifier", "//pkg/fspath", "//pkg/hostarch", - "//pkg/iovec", "//pkg/log", "//pkg/marshal/primitive", "//pkg/refs", diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index a81f550b1..4d2b282a0 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -24,6 +24,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/hostarch" @@ -109,7 +110,7 @@ type inode struct { func newInode(ctx context.Context, fs *filesystem, hostFD int, savable bool, fileType linux.FileMode, isTTY bool) (*inode, error) { // Determine if hostFD is seekable. 
_, err := unix.Seek(hostFD, 0, linux.SEEK_CUR) - seekable := err != syserror.ESPIPE + seekable := !linuxerr.Equals(linuxerr.ESPIPE, err) // We expect regular files to be seekable, as this is required for them to // be memory-mappable. if !seekable && fileType == unix.S_IFREG { @@ -289,10 +290,10 @@ func (i *inode) Mode() linux.FileMode { // Stat implements kernfs.Inode.Stat. func (i *inode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { if opts.Mask&linux.STATX__RESERVED != 0 { - return linux.Statx{}, syserror.EINVAL + return linux.Statx{}, linuxerr.EINVAL } if opts.Sync&linux.AT_STATX_SYNC_TYPE == linux.AT_STATX_SYNC_TYPE { - return linux.Statx{}, syserror.EINVAL + return linux.Statx{}, linuxerr.EINVAL } fs := vfsfs.Impl().(*filesystem) @@ -301,7 +302,7 @@ func (i *inode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOp mask := opts.Mask & linux.STATX_ALL var s unix.Statx_t err := unix.Statx(i.hostFD, "", int(unix.AT_EMPTY_PATH|opts.Sync), int(mask), &s) - if err == syserror.ENOSYS { + if linuxerr.Equals(linuxerr.ENOSYS, err) { // Fallback to fstat(2), if statx(2) is not supported on the host. // // TODO(b/151263641): Remove fallback. @@ -425,7 +426,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre } if m&linux.STATX_SIZE != 0 { if hostStat.Mode&linux.S_IFMT != linux.S_IFREG { - return syserror.EINVAL + return linuxerr.EINVAL } if err := unix.Ftruncate(i.hostFD, int64(s.Size)); err != nil { return err @@ -730,7 +731,7 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i switch whence { case linux.SEEK_SET: if offset < 0 { - return f.offset, syserror.EINVAL + return f.offset, linuxerr.EINVAL } f.offset = offset @@ -740,7 +741,7 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i return f.offset, syserror.EOVERFLOW } if f.offset+offset < 0 { - return f.offset, syserror.EINVAL + return f.offset, linuxerr.EINVAL } f.offset += offset @@ -756,7 +757,7 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i return f.offset, syserror.EOVERFLOW } if size+offset < 0 { - return f.offset, syserror.EINVAL + return f.offset, linuxerr.EINVAL } f.offset = size + offset @@ -773,7 +774,7 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i default: // Invalid whence. - return f.offset, syserror.EINVAL + return f.offset, linuxerr.EINVAL } return f.offset, nil diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go index ca85f5601..8cce36212 100644 --- a/pkg/sentry/fsimpl/host/socket.go +++ b/pkg/sentry/fsimpl/host/socket.go @@ -21,6 +21,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/socket/control" @@ -160,7 +161,7 @@ func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMess // block (and only for stream sockets). err = syserror.EAGAIN } - if n > 0 && err != syserror.EAGAIN { + if n > 0 && !linuxerr.Equals(linuxerr.EAGAIN, err) { // The caller may need to block to send more data, but // otherwise there isn't anything that can be done about an // error with a partial write. 
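newInode above decides whether a host FD is seekable by issuing a no-op lseek and checking for ESPIPE. A standalone version of that probe:

package main

import (
	"errors"
	"fmt"

	"golang.org/x/sys/unix"
)

// isSeekable probes an FD with a no-op lseek; pipes and sockets fail with
// ESPIPE, while regular files and most devices succeed.
func isSeekable(fd int) bool {
	_, err := unix.Seek(fd, 0, unix.SEEK_CUR)
	return !errors.Is(err, unix.ESPIPE)
}

func main() {
	var p [2]int
	if err := unix.Pipe(p[:]); err != nil {
		panic(err)
	}
	defer unix.Close(p[0])
	defer unix.Close(p[1])

	f, err := unix.Open("/dev/null", unix.O_RDONLY, 0)
	if err != nil {
		panic(err)
	}
	defer unix.Close(f)

	fmt.Println(isSeekable(p[0])) // false: pipes cannot seek
	fmt.Println(isSeekable(f))    // true
}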
diff --git a/pkg/sentry/fsimpl/host/socket_iovec.go b/pkg/sentry/fsimpl/host/socket_iovec.go index b123a63ee..e090bb725 100644 --- a/pkg/sentry/fsimpl/host/socket_iovec.go +++ b/pkg/sentry/fsimpl/host/socket_iovec.go @@ -16,7 +16,7 @@ package host import ( "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/iovec" + "gvisor.dev/gvisor/pkg/sentry/hostfd" "gvisor.dev/gvisor/pkg/syserror" ) @@ -70,7 +70,7 @@ func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovec } } - if iovsRequired > iovec.MaxIovs { + if iovsRequired > hostfd.MaxSendRecvMsgIov { // The kernel will reject our call if we pass this many iovs. // Use a single intermediate buffer instead. b := make([]byte, stopLen) diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go index 0f9e20a84..c7bf563f0 100644 --- a/pkg/sentry/fsimpl/host/tty.go +++ b/pkg/sentry/fsimpl/host/tty.go @@ -17,6 +17,7 @@ package host import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -211,7 +212,7 @@ func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch if err := t.checkChange(ctx, linux.SIGTTOU); err != nil { // drivers/tty/tty_io.c:tiocspgrp() converts -EIO from tty_check_change() // to -ENOTTY. - if err == syserror.EIO { + if linuxerr.Equals(linuxerr.EIO, err) { return 0, syserror.ENOTTY } return 0, err @@ -230,7 +231,7 @@ func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch // pgID must be non-negative. if pgID < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Process group with pgID must exist in this PID namespace. diff --git a/pkg/sentry/fsimpl/host/util.go b/pkg/sentry/fsimpl/host/util.go index 63b465859..95d7ebe2e 100644 --- a/pkg/sentry/fsimpl/host/util.go +++ b/pkg/sentry/fsimpl/host/util.go @@ -17,7 +17,7 @@ package host import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/errors/linuxerr" ) func toTimespec(ts linux.StatxTimestamp, omit bool) unix.Timespec { @@ -44,5 +44,5 @@ func timespecToStatxTimestamp(ts unix.Timespec) linux.StatxTimestamp { // isBlockError checks if an error is EAGAIN or EWOULDBLOCK. // If so, they can be transformed into syserror.ErrWouldBlock. 
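buildIovec above now compares against hostfd.MaxSendRecvMsgIov and falls back to one flat intermediate buffer when a message would need more iovecs than the kernel accepts (UIO_MAXIOV is 1024 on Linux). A simplified sketch of that coalescing fallback; maxIovs here is a local stand-in for the sentry constant:

package main

import "fmt"

// maxIovs stands in for the kernel's UIO_MAXIOV limit.
const maxIovs = 1024

// coalesce returns bufs unchanged when the kernel would accept them as an
// iovec array, and otherwise copies everything into one flat buffer.
func coalesce(bufs [][]byte) [][]byte {
	if len(bufs) <= maxIovs {
		return bufs
	}
	total := 0
	for _, b := range bufs {
		total += len(b)
	}
	flat := make([]byte, 0, total)
	for _, b := range bufs {
		flat = append(flat, b...)
	}
	return [][]byte{flat}
}

func main() {
	many := make([][]byte, 2000)
	for i := range many {
		many[i] = []byte{byte(i)}
	}
	fmt.Println(len(coalesce(many)))               // 1
	fmt.Println(len(coalesce([][]byte{{1}, {2}}))) // 2
}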
func isBlockError(err error) bool { - return err == syserror.EAGAIN || err == syserror.EWOULDBLOCK + return linuxerr.Equals(linuxerr.EAGAIN, err) || linuxerr.Equals(linuxerr.EWOULDBLOCK, err) } diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index b7d13cced..d53937db6 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -104,6 +104,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fspath", "//pkg/hostarch", "//pkg/log", @@ -135,6 +136,7 @@ go_test( ":kernfs", "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/log", "//pkg/refs", "//pkg/refsvfs2", diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index e55111af0..8b008dc10 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -248,10 +249,10 @@ func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int panic(fmt.Sprintf("Invalid GenericDirectoryFD.seekEnd = %v", fd.seekEnd)) } default: - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } fd.off = offset return offset, nil diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index f50b0fb08..1a314f59e 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" @@ -411,7 +412,7 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v defer rp.Mount().EndWrite() childI, err := parent.inode.NewDir(ctx, pc, opts) if err != nil { - if !opts.ForSyntheticMountpoint || err == syserror.EEXIST { + if !opts.ForSyntheticMountpoint || linuxerr.Equals(linuxerr.EEXIST, err) { return err } childI = newSyntheticDirectory(ctx, rp.Credentials(), opts.Mode) @@ -546,7 +547,7 @@ afterTrailingSymlink: } // Determine whether or not we need to create a file. child, err := fs.stepExistingLocked(ctx, rp, parent, false /* mayFollowSymlinks */) - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { // Already checked for searchability above; now check for writability. if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil { return nil, err @@ -622,7 +623,7 @@ func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st } if !d.isSymlink() { fs.mu.RUnlock() - return "", syserror.EINVAL + return "", linuxerr.EINVAL } // Inode.Readlink() cannot be called holding fs locks. @@ -635,12 +636,6 @@ func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st // RenameAt implements vfs.FilesystemImpl.RenameAt. func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error { - // Only RENAME_NOREPLACE is supported. 
- if opts.Flags&^linux.RENAME_NOREPLACE != 0 { - return syserror.EINVAL - } - noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0 - fs.mu.Lock() defer fs.processDeferredDecRefs(ctx) defer fs.mu.Unlock() @@ -651,6 +646,13 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if err != nil { return err } + + // Only RENAME_NOREPLACE is supported. + if opts.Flags&^linux.RENAME_NOREPLACE != 0 { + return linuxerr.EINVAL + } + noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0 + mnt := rp.Mount() if mnt != oldParentVD.Mount() { return syserror.EXDEV @@ -683,10 +685,12 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } return syserror.EBUSY } - switch err := checkCreateLocked(ctx, rp.Credentials(), newName, dstDir); err { - case nil: + + err = checkCreateLocked(ctx, rp.Credentials(), newName, dstDir) + switch { + case err == nil: // Ok, continue with rename as replacement. - case syserror.EEXIST: + case linuxerr.Equals(linuxerr.EEXIST, err): if noReplace { // Won't overwrite existing node since RENAME_NOREPLACE was requested. return syserror.EEXIST diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 3d0866ecf..62872946e 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -158,12 +159,12 @@ type InodeNotSymlink struct{} // Readlink implements Inode.Readlink. func (InodeNotSymlink) Readlink(context.Context, *vfs.Mount) (string, error) { - return "", syserror.EINVAL + return "", linuxerr.EINVAL } // Getlink implements Inode.Getlink. func (InodeNotSymlink) Getlink(context.Context, *vfs.Mount) (vfs.VirtualDentry, string, error) { - return vfs.VirtualDentry{}, "", syserror.EINVAL + return vfs.VirtualDentry{}, "", linuxerr.EINVAL } // InodeAttrs partially implements the Inode interface, specifically the diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index 1cd3137e6..de046ce1f 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -22,6 +22,7 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" @@ -318,10 +319,10 @@ func TestDirFDReadWrite(t *testing.T) { defer fd.DecRef(sys.Ctx) // Read/Write should fail for directory FDs. 
- if _, err := fd.Read(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.ReadOptions{}); err != syserror.EISDIR { + if _, err := fd.Read(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.ReadOptions{}); !linuxerr.Equals(linuxerr.EISDIR, err) { t.Fatalf("Read for directory FD failed with unexpected error: %v", err) } - if _, err := fd.Write(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.WriteOptions{}); err != syserror.EBADF { + if _, err := fd.Write(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.WriteOptions{}); !linuxerr.Equals(linuxerr.EBADF, err) { t.Fatalf("Write for directory FD failed with unexpected error: %v", err) } } diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD index 5504476c8..ed730e215 100644 --- a/pkg/sentry/fsimpl/overlay/BUILD +++ b/pkg/sentry/fsimpl/overlay/BUILD @@ -29,6 +29,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fspath", "//pkg/hostarch", "//pkg/log", diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go index 45aa5a494..8fd51e9d0 100644 --- a/pkg/sentry/fsimpl/overlay/copy_up.go +++ b/pkg/sentry/fsimpl/overlay/copy_up.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -349,7 +350,7 @@ func (d *dentry) copyXattrsLocked(ctx context.Context) error { lowerXattrs, err := vfsObj.ListXattrAt(ctx, d.fs.creds, lowerPop, 0) if err != nil { - if err == syserror.EOPNOTSUPP { + if linuxerr.Equals(linuxerr.EOPNOTSUPP, err) { // There are no guarantees as to the contents of lowerXattrs. return nil } diff --git a/pkg/sentry/fsimpl/overlay/directory.go b/pkg/sentry/fsimpl/overlay/directory.go index df4492346..417a7c630 100644 --- a/pkg/sentry/fsimpl/overlay/directory.go +++ b/pkg/sentry/fsimpl/overlay/directory.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -256,7 +257,7 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in switch whence { case linux.SEEK_SET: if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if offset == 0 { // Ensure that the next call to fd.IterDirents() calls @@ -268,13 +269,13 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in case linux.SEEK_CUR: offset += fd.off if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Don't clear fd.dirents in this case, even if offset == 0. 
fd.off = offset return fd.off, nil default: - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } } diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index 46c500427..e792677f5 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -218,7 +219,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str Start: parentVD, Path: childPath, }, &vfs.GetDentryOptions{}) - if err == syserror.ENOENT || err == syserror.ENAMETOOLONG { + if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.ENAMETOOLONG, err) { // The file doesn't exist on this layer. Proceed to the next one. return true } @@ -352,7 +353,7 @@ func (fs *filesystem) lookupLayerLocked(ctx context.Context, parent *dentry, nam }, &vfs.StatOptions{ Mask: linux.STATX_TYPE, }) - if err == syserror.ENOENT || err == syserror.ENAMETOOLONG { + if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.ENAMETOOLONG, err) { // The file doesn't exist on this layer. Proceed to the next // one. return true @@ -811,7 +812,7 @@ afterTrailingSymlink: // Determine whether or not we need to create a file. parent.dirMu.Lock() child, topLookupLayer, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) - if err == syserror.ENOENT && mayCreate { + if linuxerr.Equals(linuxerr.ENOENT, err) && mayCreate { fd, err := fs.createAndOpenLocked(ctx, rp, parent, &opts, &ds, topLookupLayer == lookupLayerUpperWhiteout) parent.dirMu.Unlock() return fd, err @@ -871,7 +872,7 @@ func (d *dentry) openCopiedUp(ctx context.Context, rp *vfs.ResolvingPath, opts * return nil, syserror.EISDIR } if opts.Flags&linux.O_DIRECT != 0 { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } fd := &directoryFD{} fd.LockFD.Init(&d.locks) @@ -1017,10 +1018,7 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st // RenameAt implements vfs.FilesystemImpl.RenameAt. func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error { - if opts.Flags != 0 { - return syserror.EINVAL - } - + // Resolve newParent first to verify that it's on this Mount. var ds *[]*dentry fs.renameMu.Lock() defer fs.renameMuUnlockAndCheckDrop(ctx, &ds) @@ -1028,8 +1026,16 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if err != nil { return err } + + if opts.Flags&^linux.RENAME_NOREPLACE != 0 { + return linuxerr.EINVAL + } + newName := rp.Component() if newName == "." || newName == ".." 
{ + if opts.Flags&linux.RENAME_NOREPLACE != 0 { + return syserror.EEXIST + } return syserror.EBUSY } mnt := rp.Mount() @@ -1059,7 +1065,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } if renamed.isDir() { if renamed == newParent || genericIsAncestorDentry(renamed, newParent) { - return syserror.EINVAL + return linuxerr.EINVAL } if oldParent != newParent { if err := renamed.checkPermissions(creds, vfs.MayWrite); err != nil { @@ -1089,10 +1095,13 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa whiteouts map[string]bool ) replaced, replacedLayer, err = fs.getChildLocked(ctx, newParent, newName, &ds) - if err != nil && err != syserror.ENOENT { + if err != nil && !linuxerr.Equals(linuxerr.ENOENT, err) { return err } if replaced != nil { + if opts.Flags&linux.RENAME_NOREPLACE != 0 { + return syserror.EEXIST + } replacedVFSD = &replaced.vfsd if replaced.isDir() { if !renamed.isDir() { @@ -1169,7 +1178,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa Root: replaced.upperVD, Start: replaced.upperVD, Path: fspath.Parse(whiteoutName), - }); err != nil && err != syserror.EEXIST { + }); err != nil && !linuxerr.Equals(linuxerr.EEXIST, err) { panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to recreate deleted whiteout after RenameAt failure: %v", err)) } } @@ -1277,7 +1286,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error defer rp.Mount().EndWrite() name := rp.Component() if name == "." { - return syserror.EINVAL + return linuxerr.EINVAL } if name == ".." { return syserror.ENOTEMPTY @@ -1336,7 +1345,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error Root: child.upperVD, Start: child.upperVD, Path: fspath.Parse(whiteoutName), - }); err != nil && err != syserror.EEXIST { + }); err != nil && !linuxerr.Equals(linuxerr.EEXIST, err) { panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to recreate deleted whiteout after RmdirAt failure: %v", err)) } } diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go index 454c20d4f..4c7243764 100644 --- a/pkg/sentry/fsimpl/overlay/overlay.go +++ b/pkg/sentry/fsimpl/overlay/overlay.go @@ -40,6 +40,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -135,7 +136,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fsopts, ok := fsoptsRaw.(FilesystemOptions) if fsoptsRaw != nil && !ok { ctx.Infof("overlay.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted overlay.FilesystemOptions or nil", fsoptsRaw) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } vfsroot := vfs.RootFromContext(ctx) if vfsroot.Ok() { @@ -145,7 +146,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt if upperPathname, ok := mopts["upperdir"]; ok { if fsopts.UpperRoot.Ok() { ctx.Infof("overlay.FilesystemType.GetFilesystem: both upperdir and FilesystemOptions.UpperRoot are specified") - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } delete(mopts, "upperdir") // Linux overlayfs also requires a workdir when upperdir is @@ -154,7 +155,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt upperPath := 
fspath.Parse(upperPathname) if !upperPath.Absolute { ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ Root: vfsroot, @@ -181,7 +182,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt if lowerPathnamesStr, ok := mopts["lowerdir"]; ok { if len(fsopts.LowerRoots) != 0 { ctx.Infof("overlay.FilesystemType.GetFilesystem: both lowerdir and FilesystemOptions.LowerRoots are specified") - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } delete(mopts, "lowerdir") lowerPathnames := strings.Split(lowerPathnamesStr, ":") @@ -189,7 +190,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt lowerPath := fspath.Parse(lowerPathname) if !lowerPath.Absolute { ctx.Infof("overlay.FilesystemType.GetFilesystem: lowerdir %q must be absolute", lowerPathname) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } lowerRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ Root: vfsroot, @@ -216,21 +217,21 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt if len(mopts) != 0 { ctx.Infof("overlay.FilesystemType.GetFilesystem: unused options: %v", mopts) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } if len(fsopts.LowerRoots) == 0 { ctx.Infof("overlay.FilesystemType.GetFilesystem: at least one lower layer is required") - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() { ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lower layers are required when no upper layer is present") - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK if len(fsopts.LowerRoots) > maxLowerLayers { ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lower layers specified, maximum %d", len(fsopts.LowerRoots), maxLowerLayers) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } // Take extra references held by the filesystem. @@ -283,7 +284,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt ctx.Infof("overlay.FilesystemType.GetFilesystem: filesystem root is a whiteout") root.destroyLocked(ctx) fs.vfsfs.DecRef(ctx) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } root.mode = uint32(rootStat.Mode) root.uid = rootStat.UID diff --git a/pkg/sentry/fsimpl/overlay/regular_file.go b/pkg/sentry/fsimpl/overlay/regular_file.go index 43bfd69a3..82491a0f8 100644 --- a/pkg/sentry/fsimpl/overlay/regular_file.go +++ b/pkg/sentry/fsimpl/overlay/regular_file.go @@ -207,9 +207,10 @@ func (fd *regularFileFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) e return err } - // Changing owners may clear one or both of the setuid and setgid bits, - // so we may have to update opts before setting d.mode. - if opts.Stat.Mask&(linux.STATX_UID|linux.STATX_GID) != 0 { + // Changing owners or truncating may clear one or both of the setuid and + // setgid bits, so we may have to update opts before setting d.mode. 
+ inotifyMask := opts.Stat.Mask + if opts.Stat.Mask&(linux.STATX_UID|linux.STATX_GID|linux.STATX_SIZE) != 0 { stat, err := wrappedFD.Stat(ctx, vfs.StatOptions{ Mask: linux.STATX_MODE, }) @@ -218,10 +219,14 @@ func (fd *regularFileFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) e } opts.Stat.Mode = stat.Mode opts.Stat.Mask |= linux.STATX_MODE + // Don't generate inotify IN_ATTRIB for size-only changes (truncations). + if opts.Stat.Mask&(linux.STATX_UID|linux.STATX_GID) != 0 { + inotifyMask |= linux.STATX_MODE + } } d.updateAfterSetStatLocked(&opts) - if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { + if ev := vfs.InotifyEventFromStatMask(inotifyMask); ev != 0 { d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent) } return nil diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index 2b628bd55..1d3d2d95f 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -81,6 +81,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/log", "//pkg/refs", @@ -119,6 +120,7 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fspath", "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/testutil", @@ -127,7 +129,6 @@ go_test( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", - "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go index ce8f55b1f..f2697c12d 100644 --- a/pkg/sentry/fsimpl/proc/filesystem.go +++ b/pkg/sentry/fsimpl/proc/filesystem.go @@ -21,11 +21,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" ) const ( @@ -76,7 +76,7 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF maxCachedDentries, err = strconv.ParseUint(str, 10, 64) if err != nil { ctx.Warningf("proc.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } } diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index b294dfd6a..9187f5b11 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fsbridge" @@ -325,7 +326,7 @@ func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset in // the file ..." 
- user_namespaces(7) srclen := src.NumBytes() if srclen >= hostarch.PageSize || offset != 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } b := make([]byte, srclen) if _, err := src.CopyIn(ctx, b); err != nil { @@ -345,7 +346,7 @@ func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset in } lines := bytes.SplitN(b, []byte("\n"), maxIDMapLines+1) if len(lines) > maxIDMapLines { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } entries := make([]auth.IDMapEntry, len(lines)) @@ -353,7 +354,7 @@ func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset in var e auth.IDMapEntry _, err := fmt.Sscan(string(l), &e.FirstID, &e.FirstParentID, &e.Length) if err != nil { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } entries[i] = e } @@ -461,10 +462,10 @@ func (fd *memFD) Seek(ctx context.Context, offset int64, whence int32) (int64, e case linux.SEEK_CUR: offset += fd.offset default: - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } fd.offset = offset return offset, nil diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go index 177cb828f..ab47ea5a7 100644 --- a/pkg/sentry/fsimpl/proc/task_net.go +++ b/pkg/sentry/fsimpl/proc/task_net.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" @@ -33,7 +34,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -679,7 +679,7 @@ func (d *netSnmpData) Generate(ctx context.Context, buf *bytes.Buffer) error { continue } if err := d.stack.Statistics(stat, line.prefix); err != nil { - if err == syserror.EOPNOTSUPP { + if linuxerr.Equals(linuxerr.EOPNOTSUPP, err) { log.Infof("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err) } else { log.Warningf("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err) diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go index 045ed7a2d..2def1ca48 100644 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -53,7 +54,7 @@ func (s *selfSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error t := kernel.TaskFromContext(ctx) if t == nil { // Who is reading this link? - return "", syserror.EINVAL + return "", linuxerr.EINVAL } tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup()) if tgid == 0 { @@ -94,7 +95,7 @@ func (s *threadSelfSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, t := kernel.TaskFromContext(ctx) if t == nil { // Who is reading this link? 
- return "", syserror.EINVAL + return "", linuxerr.EINVAL } tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup()) tid := s.pidns.IDOfTask(t) diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index 88ab49048..99f64a9d8 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/inet" @@ -28,7 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/usermem" ) @@ -55,6 +55,7 @@ func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k * }), }), "vm": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "max_map_count": fs.newInode(ctx, root, 0444, newStaticFile("2147483647\n")), "mmap_min_addr": fs.newInode(ctx, root, 0444, &mmapMinAddrData{k: k}), "overcommit_memory": fs.newInode(ctx, root, 0444, newStaticFile("0\n")), }), @@ -208,7 +209,7 @@ func (d *tcpSackData) Generate(ctx context.Context, buf *bytes.Buffer) error { func (d *tcpSackData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { if offset != 0 { // No need to handle partial writes thus far. - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if src.NumBytes() == 0 { return 0, nil @@ -256,7 +257,7 @@ func (d *tcpRecoveryData) Generate(ctx context.Context, buf *bytes.Buffer) error func (d *tcpRecoveryData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { if offset != 0 { // No need to handle partial writes thus far. - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if src.NumBytes() == 0 { return 0, nil @@ -310,7 +311,7 @@ func (d *tcpMemData) Generate(ctx context.Context, buf *bytes.Buffer) error { func (d *tcpMemData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { if offset != 0 { // No need to handle partial writes thus far. - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if src.NumBytes() == 0 { return 0, nil @@ -395,7 +396,7 @@ func (ipf *ipForwarding) Generate(ctx context.Context, buf *bytes.Buffer) error func (ipf *ipForwarding) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { if offset != 0 { // No need to handle partial writes thus far. - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if src.NumBytes() == 0 { return 0, nil @@ -448,7 +449,7 @@ func (pr *portRange) Generate(ctx context.Context, buf *bytes.Buffer) error { func (pr *portRange) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { if offset != 0 { // No need to handle partial writes thus far. - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if src.NumBytes() == 0 { return 0, nil @@ -466,7 +467,7 @@ func (pr *portRange) Write(ctx context.Context, src usermem.IOSequence, offset i // Port numbers must be uint16s. 
if ports[0] < 0 || ports[1] < 0 || ports[0] > math.MaxUint16 || ports[1] > math.MaxUint16 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if err := pr.stack.SetPortRange(uint16(ports[0]), uint16(ports[1])); err != nil { diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index e534fbca8..14f806c3c 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -23,13 +23,13 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -227,7 +227,7 @@ func TestTasks(t *testing.T) { defer fd.DecRef(s.Ctx) buf := make([]byte, 1) bufIOSeq := usermem.BytesIOSequence(buf) - if _, err := fd.Read(s.Ctx, bufIOSeq, vfs.ReadOptions{}); err != syserror.EISDIR { + if _, err := fd.Read(s.Ctx, bufIOSeq, vfs.ReadOptions{}); !linuxerr.Equals(linuxerr.EISDIR, err) { t.Errorf("wrong error reading directory: %v", err) } } @@ -237,7 +237,7 @@ func TestTasks(t *testing.T) { s.Creds, s.PathOpAtRoot("/proc/9999"), &vfs.OpenOptions{}, - ); err != syserror.ENOENT { + ); !linuxerr.Equals(linuxerr.ENOENT, err) { t.Fatalf("wrong error from vfsfs.OpenAt(/proc/9999): %v", err) } } diff --git a/pkg/sentry/fsimpl/proc/yama.go b/pkg/sentry/fsimpl/proc/yama.go index e039ec45e..7240563d7 100644 --- a/pkg/sentry/fsimpl/proc/yama.go +++ b/pkg/sentry/fsimpl/proc/yama.go @@ -21,11 +21,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -56,7 +56,7 @@ func (s *yamaPtraceScope) Generate(ctx context.Context, buf *bytes.Buffer) error func (s *yamaPtraceScope) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { if offset != 0 { // Ignore partial writes. - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if src.NumBytes() == 0 { return 0, nil @@ -73,7 +73,7 @@ func (s *yamaPtraceScope) Write(ctx context.Context, src usermem.IOSequence, off // We do not support YAMA levels > YAMA_SCOPE_RELATIONAL. 
if v < linux.YAMA_SCOPE_DISABLED || v > linux.YAMA_SCOPE_RELATIONAL { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } atomic.StoreInt32(s.level, v) diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD index 09043b572..1af0a5cbc 100644 --- a/pkg/sentry/fsimpl/sys/BUILD +++ b/pkg/sentry/fsimpl/sys/BUILD @@ -26,6 +26,7 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/coverage", + "//pkg/errors/linuxerr", "//pkg/log", "//pkg/refs", "//pkg/refsvfs2", diff --git a/pkg/sentry/fsimpl/sys/kcov.go b/pkg/sentry/fsimpl/sys/kcov.go index b13f141a8..d06aea162 100644 --- a/pkg/sentry/fsimpl/sys/kcov.go +++ b/pkg/sentry/fsimpl/sys/kcov.go @@ -17,6 +17,7 @@ package sys import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -85,7 +86,7 @@ func (fd *kcovFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallAr case linux.KCOV_DISABLE: if arg != 0 { // This arg is unused; it should be 0. - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } return 0, fd.kcov.DisableTrace(ctx) default: diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index 14eb10dcd..546f54a5a 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/coverage" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -74,7 +75,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt maxCachedDentries, err = strconv.ParseUint(str, 10, 64) if err != nil { ctx.Warningf("sys.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } } diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD index c766164c7..b3f9d1010 100644 --- a/pkg/sentry/fsimpl/testutil/BUILD +++ b/pkg/sentry/fsimpl/testutil/BUILD @@ -17,7 +17,6 @@ go_library( "//pkg/fspath", "//pkg/hostarch", "//pkg/memutil", - "//pkg/metric", "//pkg/sentry/fsbridge", "//pkg/sentry/fsimpl/tmpfs", "//pkg/sentry/kernel", diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go index 33e52ce64..473b41cff 100644 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ b/pkg/sentry/fsimpl/testutil/kernel.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/memutil" - "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -63,8 +62,6 @@ func Boot() (*kernel.Kernel, error) { return nil, fmt.Errorf("creating platform: %v", err) } - metric.CreateSentryMetrics() - kernel.VFS2Enabled = true k := &kernel.Kernel{ Platform: plat, @@ -83,10 +80,7 @@ func Boot() (*kernel.Kernel, error) { } // Create timekeeper. 
- tk, err := kernel.NewTimekeeper(k, vdso.ParamPage.FileRange()) - if err != nil { - return nil, fmt.Errorf("creating timekeeper: %v", err) - } + tk := kernel.NewTimekeeper(k, vdso.ParamPage.FileRange()) tk.SetClocks(time.NewCalibratedClocks()) creds := auth.NewRootCredentials(auth.NewRootUserNamespace()) @@ -181,7 +175,7 @@ func createMemoryFile() (*pgalloc.MemoryFile, error) { memfile := os.NewFile(uintptr(memfd), memfileName) mf, err := pgalloc.NewMemoryFile(memfile, pgalloc.MemoryFileOpts{}) if err != nil { - memfile.Close() + _ = memfile.Close() return nil, fmt.Errorf("error creating pgalloc.MemoryFile: %v", err) } return mf, nil diff --git a/pkg/sentry/fsimpl/timerfd/BUILD b/pkg/sentry/fsimpl/timerfd/BUILD index 7ce7dc429..e6980a314 100644 --- a/pkg/sentry/fsimpl/timerfd/BUILD +++ b/pkg/sentry/fsimpl/timerfd/BUILD @@ -8,6 +8,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/sentry/kernel/time", "//pkg/sentry/vfs", diff --git a/pkg/sentry/fsimpl/timerfd/timerfd.go b/pkg/sentry/fsimpl/timerfd/timerfd.go index cbb8b67c5..655a1c76a 100644 --- a/pkg/sentry/fsimpl/timerfd/timerfd.go +++ b/pkg/sentry/fsimpl/timerfd/timerfd.go @@ -19,6 +19,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -69,7 +70,7 @@ func New(ctx context.Context, vfsObj *vfs.VirtualFilesystem, clock ktime.Clock, func (tfd *TimerFileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { const sizeofUint64 = 8 if dst.NumBytes() < sizeofUint64 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if val := atomic.SwapUint64(&tfd.val, 0); val != 0 { var buf [sizeofUint64]byte diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index e21fddd7f..ae612aae0 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -58,6 +58,7 @@ go_library( "//pkg/abi/linux", "//pkg/amutex", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fspath", "//pkg/hostarch", "//pkg/log", @@ -118,6 +119,7 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fspath", "//pkg/sentry/contexttest", "//pkg/sentry/fs/lock", diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go index e8d256495..c25494c0b 100644 --- a/pkg/sentry/fsimpl/tmpfs/directory.go +++ b/pkg/sentry/fsimpl/tmpfs/directory.go @@ -19,10 +19,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) // +stateify savable @@ -196,10 +196,10 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in case linux.SEEK_CUR: offset += fd.off default: - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // If the offset isn't changing (e.g. 
due to lseek(0, SEEK_CUR)), don't diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index 766289e60..590f7118a 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -300,7 +301,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v case linux.S_IFSOCK: childInode = fs.newSocketFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode, opts.Endpoint, parentDir) default: - return syserror.EINVAL + return linuxerr.EINVAL } child := fs.newDentry(childInode) parentDir.insertChildLocked(child, name) @@ -488,7 +489,7 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st } symlink, ok := d.inode.impl.(*symlink) if !ok { - return "", syserror.EINVAL + return "", linuxerr.EINVAL } symlink.inode.touchAtime(rp.Mount()) return symlink.target, nil @@ -496,20 +497,24 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st // RenameAt implements vfs.FilesystemImpl.RenameAt. func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error { - if opts.Flags != 0 { - // TODO(b/145974740): Support renameat2 flags. - return syserror.EINVAL - } - - // Resolve newParent first to verify that it's on this Mount. + // Resolve newParentDir first to verify that it's on this Mount. fs.mu.Lock() defer fs.mu.Unlock() newParentDir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry)) if err != nil { return err } + + if opts.Flags&^linux.RENAME_NOREPLACE != 0 { + // TODO(b/145974740): Support other renameat2 flags. + return linuxerr.EINVAL + } + newName := rp.Component() if newName == "." || newName == ".." { + if opts.Flags&linux.RENAME_NOREPLACE != 0 { + return syserror.EEXIST + } return syserror.EBUSY } mnt := rp.Mount() @@ -537,7 +542,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // mounted filesystem. if renamed.inode.isDir() { if renamed == &newParentDir.dentry || genericIsAncestorDentry(renamed, &newParentDir.dentry) { - return syserror.EINVAL + return linuxerr.EINVAL } if oldParentDir != newParentDir { // Writability is needed to change renamed's "..". @@ -556,6 +561,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } replaced, ok := newParentDir.childMap[newName] if ok { + if opts.Flags&linux.RENAME_NOREPLACE != 0 { + return syserror.EEXIST + } replacedDir, ok := replaced.inode.impl.(*directory) if ok { if !renamed.inode.isDir() { @@ -639,7 +647,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error } name := rp.Component() if name == "." { - return syserror.EINVAL + return linuxerr.EINVAL } if name == ".." { return syserror.ENOTEMPTY @@ -815,7 +823,7 @@ func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, si if err != nil { return nil, err } - return d.inode.listXattr(size) + return d.inode.listXattr(rp.Credentials(), size) } // GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. 
diff --git a/pkg/sentry/fsimpl/tmpfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go index 2f856ce36..418c7994e 100644 --- a/pkg/sentry/fsimpl/tmpfs/pipe_test.go +++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -114,7 +115,7 @@ func TestNonblockingWriteError(t *testing.T) { } openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY | linux.O_NONBLOCK} _, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) - if err != syserror.ENXIO { + if !linuxerr.Equals(linuxerr.ENXIO, err) { t.Fatalf("expected ENXIO, but got error: %v", err) } } diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index c45bddff6..0bc1911d9 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -366,7 +367,7 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs fsmetric.TmpfsReads.Increment() if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Check that flags are supported. RWF_DSYNC/RWF_SYNC can be ignored since @@ -407,7 +408,7 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off // final offset should be ignored by PWrite. func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) { if offset < 0 { - return 0, offset, syserror.EINVAL + return 0, offset, linuxerr.EINVAL } // Check that flags are supported. RWF_DSYNC/RWF_SYNC can be ignored since @@ -432,7 +433,7 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off } if end := offset + srclen; end < offset { // Overflow. 
- return 0, offset, syserror.EINVAL + return 0, offset, linuxerr.EINVAL } srclen, err = vfs.CheckLimit(ctx, offset, srclen) @@ -476,10 +477,10 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) ( case linux.SEEK_END: offset += int64(atomic.LoadUint64(&fd.inode().impl.(*regularFile).size)) default: - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } fd.off = offset return offset, nil @@ -684,7 +685,7 @@ exitLoop: func GetSeals(fd *vfs.FileDescription) (uint32, error) { f, ok := fd.Impl().(*regularFileFD) if !ok { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } rf := f.inode().impl.(*regularFile) rf.dataMu.RLock() @@ -696,7 +697,7 @@ func GetSeals(fd *vfs.FileDescription) (uint32, error) { func AddSeals(fd *vfs.FileDescription, val uint32) error { f, ok := fd.Impl().(*regularFileFD) if !ok { - return syserror.EINVAL + return linuxerr.EINVAL } rf := f.inode().impl.(*regularFile) rf.mapsMu.Lock() diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 9ae25ce9e..bc40aad0d 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -36,6 +36,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -138,7 +139,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt mode, err := strconv.ParseUint(modeStr, 8, 32) if err != nil { ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: invalid mode: %q", modeStr) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } rootMode = linux.FileMode(mode & 07777) } @@ -149,12 +150,12 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt uid, err := strconv.ParseUint(uidStr, 10, 32) if err != nil { ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: invalid uid: %q", uidStr) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } kuid := creds.UserNamespace.MapToKUID(auth.UID(uid)) if !kuid.Ok() { ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: unmapped uid: %d", uid) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } rootKUID = kuid } @@ -165,18 +166,18 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt gid, err := strconv.ParseUint(gidStr, 10, 32) if err != nil { ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: invalid gid: %q", gidStr) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } kgid := creds.UserNamespace.MapToKGID(auth.GID(gid)) if !kgid.Ok() { ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: unmapped gid: %d", gid) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } rootKGID = kgid } if len(mopts) != 0 { ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: unknown options: %v", mopts) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } devMinor, err := vfsObj.GetAnonBlockDevMinor() @@ -557,7 +558,7 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs. 
case *directory: return syserror.EISDIR default: - return syserror.EINVAL + return linuxerr.EINVAL } } if mask&linux.STATX_UID != 0 { @@ -717,44 +718,63 @@ func (i *inode) touchCMtimeLocked() { atomic.StoreInt64(&i.ctime, now) } -func (i *inode) listXattr(size uint64) ([]string, error) { - return i.xattrs.ListXattr(size) +func checkXattrName(name string) error { + // Linux's tmpfs supports "security" and "trusted" xattr namespaces, and + // (depending on build configuration) POSIX ACL xattr namespaces + // ("system.posix_acl_access" and "system.posix_acl_default"). We don't + // support POSIX ACLs or the "security" namespace (b/148380782). + if strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX) { + return nil + } + // We support the "user" namespace because we have tests that depend on + // this feature. + if strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { + return nil + } + return syserror.EOPNOTSUPP +} + +func (i *inode) listXattr(creds *auth.Credentials, size uint64) ([]string, error) { + return i.xattrs.ListXattr(creds, size) } func (i *inode) getXattr(creds *auth.Credentials, opts *vfs.GetXattrOptions) (string, error) { - if err := i.checkXattrPermissions(creds, opts.Name, vfs.MayRead); err != nil { + if err := checkXattrName(opts.Name); err != nil { return "", err } - return i.xattrs.GetXattr(opts) + mode := linux.FileMode(atomic.LoadUint32(&i.mode)) + kuid := auth.KUID(atomic.LoadUint32(&i.uid)) + kgid := auth.KGID(atomic.LoadUint32(&i.gid)) + if err := vfs.GenericCheckPermissions(creds, vfs.MayRead, mode, kuid, kgid); err != nil { + return "", err + } + return i.xattrs.GetXattr(creds, mode, kuid, opts) } func (i *inode) setXattr(creds *auth.Credentials, opts *vfs.SetXattrOptions) error { - if err := i.checkXattrPermissions(creds, opts.Name, vfs.MayWrite); err != nil { + if err := checkXattrName(opts.Name); err != nil { return err } - return i.xattrs.SetXattr(opts) -} - -func (i *inode) removeXattr(creds *auth.Credentials, name string) error { - if err := i.checkXattrPermissions(creds, name, vfs.MayWrite); err != nil { + mode := linux.FileMode(atomic.LoadUint32(&i.mode)) + kuid := auth.KUID(atomic.LoadUint32(&i.uid)) + kgid := auth.KGID(atomic.LoadUint32(&i.gid)) + if err := vfs.GenericCheckPermissions(creds, vfs.MayWrite, mode, kuid, kgid); err != nil { return err } - return i.xattrs.RemoveXattr(name) + return i.xattrs.SetXattr(creds, mode, kuid, opts) } -func (i *inode) checkXattrPermissions(creds *auth.Credentials, name string, ats vfs.AccessTypes) error { - // We currently only support extended attributes in the user.* and - // trusted.* namespaces. See b/148380782. 
- if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) && !strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX) { - return syserror.EOPNOTSUPP +func (i *inode) removeXattr(creds *auth.Credentials, name string) error { + if err := checkXattrName(name); err != nil { + return err } mode := linux.FileMode(atomic.LoadUint32(&i.mode)) kuid := auth.KUID(atomic.LoadUint32(&i.uid)) kgid := auth.KGID(atomic.LoadUint32(&i.gid)) - if err := vfs.GenericCheckPermissions(creds, ats, mode, kuid, kgid); err != nil { + if err := vfs.GenericCheckPermissions(creds, vfs.MayWrite, mode, kuid, kgid); err != nil { return err } - return vfs.CheckXattrPermissions(creds, ats, mode, kuid, name) + return i.xattrs.RemoveXattr(creds, mode, kuid, name) } // fileDescription is embedded by tmpfs implementations of @@ -807,7 +827,7 @@ func (fd *fileDescription) StatFS(ctx context.Context) (linux.Statfs, error) { // ListXattr implements vfs.FileDescriptionImpl.ListXattr. func (fd *fileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) { - return fd.inode().listXattr(size) + return fd.inode().listXattr(auth.CredentialsFromContext(ctx), size) } // GetXattr implements vfs.FileDescriptionImpl.GetXattr. diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD index d473a922d..1d855234c 100644 --- a/pkg/sentry/fsimpl/verity/BUILD +++ b/pkg/sentry/fsimpl/verity/BUILD @@ -13,6 +13,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fspath", "//pkg/hostarch", "//pkg/marshal/primitive", @@ -41,6 +42,7 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fspath", "//pkg/sentry/arch", "//pkg/sentry/fsimpl/testutil", @@ -48,7 +50,6 @@ go_test( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", - "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 3582d14c9..b5735a86d 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/merkletree" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -195,7 +196,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi // The Merkle tree file for the child should have been created and // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. - if err == syserror.ENOENT || err == syserror.ENODATA { + if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.ENODATA, err) { return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) } if err != nil { @@ -218,7 +219,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi // The parent Merkle tree file should have been created. If it's // missing, it indicates an unexpected modification to the file system. 
- if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) } if err != nil { @@ -238,7 +239,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi // The Merkle tree file for the child should have been created and // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. - if err == syserror.ENOENT || err == syserror.ENODATA { + if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.ENODATA, err) { return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { @@ -261,7 +262,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi Root: parent.lowerVD, Start: parent.lowerVD, }, &vfs.StatOptions{}) - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err)) } if err != nil { @@ -282,7 +283,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi Mode: uint32(parentStat.Mode), UID: parentStat.UID, GID: parentStat.GID, - Children: parent.childrenNames, + Children: parent.childrenList, HashAlgorithms: fs.alg.toLinuxHashAlg(), ReadOffset: int64(offset), ReadSize: int64(merkletree.DigestSize(fs.alg.toLinuxHashAlg())), @@ -327,7 +328,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry }, &vfs.OpenOptions{ Flags: linux.O_RDONLY, }) - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { return fs.alertIntegrityViolation(fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err)) } if err != nil { @@ -341,7 +342,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry Size: sizeOfStringInt32, }) - if err == syserror.ENODATA { + if linuxerr.Equals(linuxerr.ENODATA, err) { return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { @@ -359,7 +360,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry Size: sizeOfStringInt32, }) - if err == syserror.ENODATA { + if linuxerr.Equals(linuxerr.ENODATA, err) { return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenOffsetXattr, childPath, err)) } if err != nil { @@ -375,7 +376,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry Size: sizeOfStringInt32, }) - if err == syserror.ENODATA { + if linuxerr.Equals(linuxerr.ENODATA, err) { return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenSizeXattr, childPath, err)) } if err != nil { @@ -403,6 +404,9 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry var buf bytes.Buffer d.hashMu.RLock() + + d.generateChildrenList() + params := &merkletree.VerifyParams{ Out: &buf, Tree: &fdReader, @@ -411,7 +415,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry Mode: uint32(stat.Mode), UID: stat.UID, GID: stat.GID, - Children: d.childrenNames, + Children: d.childrenList, HashAlgorithms: fs.alg.toLinuxHashAlg(), ReadOffset: 0, // Set read size to 0 so only the metadata is verified. 
@@ -465,7 +469,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s } childVD, err := parent.getLowerAt(ctx, vfsObj, name) - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { // The file was previously accessed. If the // file does not exist now, it indicates an // unexpected modification to the file system. @@ -480,7 +484,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s // The Merkle tree file was previous accessed. If it // does not exist now, it indicates an unexpected // modification to the file system. - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { return nil, fs.alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", path)) } if err != nil { @@ -551,7 +555,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, } childVD, err := parent.getLowerAt(ctx, vfsObj, name) - if parent.verityEnabled() && err == syserror.ENOENT { + if parent.verityEnabled() && linuxerr.Equals(linuxerr.ENOENT, err) { return nil, fs.alertIntegrityViolation(fmt.Sprintf("file %s expected but not found", parentPath+"/"+name)) } if err != nil { @@ -564,7 +568,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, childMerkleVD, err := parent.getLowerAt(ctx, vfsObj, merklePrefix+name) if err != nil { - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { if parent.verityEnabled() { return nil, fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath+"/"+name)) } @@ -854,7 +858,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // The file should exist, as we succeeded in finding its dentry. If it's // missing, it indicates an unexpected modification to the file system. if err != nil { - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("File %s expected but not found", path)) } return nil, err @@ -877,7 +881,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // dentry. If it's missing, it indicates an unexpected modification to // the file system. 
if err != nil { - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) } return nil, err @@ -902,7 +906,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf Flags: linux.O_WRONLY | linux.O_APPEND, }) if err != nil { - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) } return nil, err @@ -919,7 +923,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf Flags: linux.O_WRONLY | linux.O_APPEND, }) if err != nil { - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD) return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) } diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index fa7696ad6..2227b542a 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -39,12 +39,14 @@ import ( "encoding/json" "fmt" "math" + "sort" "strconv" "strings" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" @@ -251,7 +253,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt hash, err := hex.DecodeString(encodedRootHash) if err != nil { ctx.Warningf("verity.FilesystemType.GetFilesystem: Failed to decode root hash: %v", err) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } rootHash = hash } @@ -269,19 +271,19 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // Check for unparsed options. if len(mopts) != 0 { ctx.Warningf("verity.FilesystemType.GetFilesystem: unknown options: %v", mopts) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } // Handle internal options. iopts, ok := opts.InternalData.(InternalFilesystemOptions) if len(lowerPathname) == 0 && !ok { ctx.Warningf("verity.FilesystemType.GetFilesystem: missing verity configs") - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } if len(lowerPathname) != 0 { if ok { ctx.Warningf("verity.FilesystemType.GetFilesystem: unexpected verity configs with specified lower path") - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } iopts = InternalFilesystemOptions{ AllowRuntimeEnable: len(rootHash) == 0, @@ -300,7 +302,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt lowerPath := fspath.Parse(lowerPathname) if !lowerPath.Absolute { ctx.Infof("verity.FilesystemType.GetFilesystem: lower_path %q must be absolute", lowerPathname) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } var err error mountedLowerVD, err = vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ @@ -358,7 +360,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // If runtime enable is allowed, the root merkle tree may be absent. We // should create the tree file. 
- if err == syserror.ENOENT && fs.allowRuntimeEnable { + if linuxerr.Equals(linuxerr.ENOENT, err) && fs.allowRuntimeEnable { lowerMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ Root: lowerVD, Start: lowerVD, @@ -439,7 +441,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt if !d.isDir() { ctx.Warningf("verity root must be a directory") - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } if !fs.allowRuntimeEnable { @@ -451,7 +453,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt Name: childrenOffsetXattr, Size: sizeOfStringInt32, }) - if err == syserror.ENOENT || err == syserror.ENODATA { + if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.ENODATA, err) { return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenOffsetXattr, err)) } if err != nil { @@ -470,7 +472,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt Name: childrenSizeXattr, Size: sizeOfStringInt32, }) - if err == syserror.ENOENT || err == syserror.ENODATA { + if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.ENODATA, err) { return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenSizeXattr, err)) } if err != nil { @@ -487,7 +489,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt }, &vfs.OpenOptions{ Flags: linux.O_RDONLY, }) - if err == syserror.ENOENT { + if linuxerr.Equals(linuxerr.ENOENT, err) { return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to open root Merkle file: %v", err)) } if err != nil { @@ -508,6 +510,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt if err := fs.verifyStatAndChildrenLocked(ctx, d, stat); err != nil { return nil, nil, err } + d.generateChildrenList() } d.vfsd.Init(d) @@ -564,6 +567,11 @@ type dentry struct { // populated by enableVerity. childrenNames is also protected by dirMu. childrenNames map[string]struct{} + // childrenList is a complete sorted list of childrenNames. This list + // is generated when verity is enabled, or the first time the file is + // verified in non runtime enable mode. + childrenList []string + // lowerVD is the VirtualDentry in the underlying file system. It is // never modified after initialized. lowerVD vfs.VirtualDentry @@ -749,6 +757,17 @@ func (d *dentry) verityEnabled() bool { return !d.fs.allowRuntimeEnable || len(d.hash) != 0 } +// generateChildrenList generates a sorted childrenList from childrenNames, and +// cache it in d for hashing. +func (d *dentry) generateChildrenList() { + if len(d.childrenList) == 0 && len(d.childrenNames) != 0 { + for child := range d.childrenNames { + d.childrenList = append(d.childrenList, child) + } + sort.Strings(d.childrenList) + } +} + // getLowerAt returns the dentry in the underlying file system, which is // represented by filename relative to d. func (d *dentry) getLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, filename string) (vfs.VirtualDentry, error) { @@ -868,6 +887,10 @@ func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCa fd.mu.Lock() defer fd.mu.Unlock() + if _, err := fd.lowerFD.Seek(ctx, fd.off, linux.SEEK_SET); err != nil { + return err + } + var ds []vfs.Dirent err := fd.lowerFD.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error { // Do not include the Merkle tree files. 
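The new generateChildrenList above exists because Go randomizes map iteration order, so hashing directly over childrenNames would not be reproducible. A minimal self-contained sketch of the same idea (names here are illustrative, not the verity types):

package main

import (
	"fmt"
	"sort"
)

// sortedChildren converts a set of child names into a deterministic,
// sorted slice suitable for hashing.
func sortedChildren(children map[string]struct{}) []string {
	names := make([]string, 0, len(children))
	for name := range children {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}

func main() {
	m := map[string]struct{}{"b": {}, "a": {}, "c": {}}
	fmt.Println(sortedChildren(m)) // [a b c], regardless of map iteration order
}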
@@ -890,8 +913,8 @@ func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCa return err } - // The result should contain all children plus "." and "..". - if fd.d.verityEnabled() && len(ds) != len(fd.d.childrenNames)+2 { + // The result should be a part of all children plus "." and "..", counting from fd.off. + if fd.d.verityEnabled() && len(ds) != len(fd.d.childrenNames)+2-int(fd.off) { return fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Unexpected children number %d", len(ds))) } @@ -917,14 +940,14 @@ func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) case linux.SEEK_END: n = int64(fd.d.size) default: - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if offset > math.MaxInt64-n { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } offset += n if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } fd.off = offset return offset, nil @@ -958,10 +981,12 @@ func (fd *fileDescription) generateMerkleLocked(ctx context.Context) ([]byte, ui return nil, 0, err } + fd.d.generateChildrenList() + params := &merkletree.GenerateParams{ TreeReader: &merkleReader, TreeWriter: &merkleWriter, - Children: fd.d.childrenNames, + Children: fd.d.childrenList, HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), Name: fd.d.name, Mode: uint32(stat.Mode), @@ -1003,7 +1028,7 @@ func (fd *fileDescription) generateMerkleLocked(ctx context.Context) ([]byte, ui default: // TODO(b/167728857): Investigate whether and how we should // enable other types of file. - return nil, 0, syserror.EINVAL + return nil, 0, linuxerr.EINVAL } hash, err := merkletree.Generate(params) return hash, uint64(params.Size), err @@ -1121,7 +1146,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) { func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest hostarch.Addr) (uintptr, error) { t := kernel.TaskFromContext(ctx) if t == nil { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } var metadata linux.DigestMetadata @@ -1174,7 +1199,7 @@ func (fd *fileDescription) verityFlags(ctx context.Context, flags hostarch.Addr) t := kernel.TaskFromContext(ctx) if t == nil { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } _, err := primitive.CopyInt32Out(t, flags, f) return 0, err @@ -1223,7 +1248,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of // The Merkle tree file for the child should have been created and // contains the expected xattrs. If the xattr does not exist, it // indicates unexpected modifications to the file system. - if err == syserror.ENODATA { + if linuxerr.Equals(linuxerr.ENODATA, err) { return 0, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) } if err != nil { @@ -1257,7 +1282,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of Mode: fd.d.mode, UID: fd.d.uid, GID: fd.d.gid, - Children: fd.d.childrenNames, + Children: fd.d.childrenList, HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), ReadOffset: offset, ReadSize: dst.NumBytes(), @@ -1345,7 +1370,7 @@ func (fd *fileDescription) Translate(ctx context.Context, required, optional mem // The Merkle tree file for the child should have been created and // contains the expected xattrs. If the xattr does not exist, it // indicates unexpected modifications to the file system. 
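The Seek hunk above guards the `offset += n` arithmetic against int64 overflow and negative results before accepting a new position. A small standalone sketch of that check, assuming a non-negative base as in the original:

package main

import (
	"errors"
	"fmt"
	"math"
)

var errInval = errors.New("EINVAL")

// seekOffset adds a user-supplied offset to a non-negative base (current
// position or file size) while rejecting overflow and negative results.
func seekOffset(base, offset int64) (int64, error) {
	if offset > math.MaxInt64-base { // base+offset would overflow int64
		return 0, errInval
	}
	off := base + offset
	if off < 0 { // would seek before the start of the file
		return 0, errInval
	}
	return off, nil
}

func main() {
	fmt.Println(seekOffset(100, 10))          // 110 <nil>
	fmt.Println(seekOffset(math.MaxInt64, 1)) // 0 EINVAL
	fmt.Println(seekOffset(100, -200))        // 0 EINVAL
}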
- if err == syserror.ENODATA { + if linuxerr.Equals(linuxerr.ENODATA, err) { return nil, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) } if err != nil { @@ -1429,7 +1454,7 @@ func (r *mmapReadSeeker) ReadAt(p []byte, off int64) (int, error) { // mapped region. readOffset := off - int64(r.Offset) if readOffset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } bs.DropFirst64(uint64(readOffset)) view := bs.TakeFirst64(uint64(len(p))) diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go index 5c78a0019..65465b814 100644 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" @@ -31,7 +32,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -476,7 +476,7 @@ func TestOpenNonexistentFile(t *testing.T) { // Ensure open an unexpected file in the parent directory fails with // ENOENT rather than verification failure. - if _, err = openVerityAt(ctx, vfsObj, root, filename+"abc", linux.O_RDONLY, linux.ModeRegular); err != syserror.ENOENT { + if _, err = openVerityAt(ctx, vfsObj, root, filename+"abc", linux.O_RDONLY, linux.ModeRegular); !linuxerr.Equals(linuxerr.ENOENT, err) { t.Errorf("OpenAt unexpected error: %v", err) } } @@ -767,7 +767,7 @@ func TestOpenDeletedFileFails(t *testing.T) { } // Ensure reopening the verity enabled file fails. - if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != syserror.EIO { + if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); !linuxerr.Equals(linuxerr.EIO, err) { t.Errorf("got OpenAt error: %v, expected EIO", err) } }) @@ -829,7 +829,7 @@ func TestOpenRenamedFileFails(t *testing.T) { } // Ensure reopening the verity enabled file fails. - if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != syserror.EIO { + if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); !linuxerr.Equals(linuxerr.EIO, err) { t.Errorf("got OpenAt error: %v, expected EIO", err) } }) @@ -1063,14 +1063,14 @@ func TestDeletedSymlinkFileReadFails(t *testing.T) { Root: root, Start: root, Path: fspath.Parse(symlink), - }); err != syserror.EIO { + }); !linuxerr.Equals(linuxerr.EIO, err) { t.Fatalf("ReadlinkAt succeeded with modified symlink: %v", err) } if tc.testWalk { fileInSymlinkDirectory := symlink + "/verity-test-file" // Ensure opening the verity enabled file in the symlink directory fails. 
- if _, err := openVerityAt(ctx, vfsObj, root, fileInSymlinkDirectory, linux.O_RDONLY, linux.ModeRegular); err != syserror.EIO { + if _, err := openVerityAt(ctx, vfsObj, root, fileInSymlinkDirectory, linux.O_RDONLY, linux.ModeRegular); !linuxerr.Equals(linuxerr.EIO, err) { t.Errorf("Open succeeded with modified symlink: %v", err) } } @@ -1195,14 +1195,14 @@ func TestModifiedSymlinkFileReadFails(t *testing.T) { Root: root, Start: root, Path: fspath.Parse(symlink), - }); err != syserror.EIO { + }); !linuxerr.Equals(linuxerr.EIO, err) { t.Fatalf("ReadlinkAt succeeded with modified symlink: %v", err) } if tc.testWalk { fileInSymlinkDirectory := symlink + "/verity-test-file" // Ensure opening the verity enabled file in the symlink directory fails. - if _, err := openVerityAt(ctx, vfsObj, root, fileInSymlinkDirectory, linux.O_RDONLY, linux.ModeRegular); err != syserror.EIO { + if _, err := openVerityAt(ctx, vfsObj, root, fileInSymlinkDirectory, linux.O_RDONLY, linux.ModeRegular); !linuxerr.Equals(linuxerr.EIO, err) { t.Errorf("Open succeeded with modified symlink: %v", err) } } diff --git a/pkg/sentry/fsmetric/fsmetric.go b/pkg/sentry/fsmetric/fsmetric.go index 7e535b527..17d0d5025 100644 --- a/pkg/sentry/fsmetric/fsmetric.go +++ b/pkg/sentry/fsmetric/fsmetric.go @@ -42,7 +42,6 @@ var ( // Metrics that only apply to fs/gofer and fsimpl/gofer. var ( - GoferOpensWX = metric.MustCreateNewUint64Metric("/gofer/opened_write_execute_file", true /* sync */, "Number of times a executable file was opened writably from a gofer.") GoferOpens9P = metric.MustCreateNewUint64Metric("/gofer/opens_9p", false /* sync */, "Number of times a file was opened from a gofer and did not have a host file descriptor.") GoferOpensHost = metric.MustCreateNewUint64Metric("/gofer/opens_host", false /* sync */, "Number of times a file was opened from a gofer and did have a host file descriptor.") GoferReads9P = metric.MustCreateNewUint64Metric("/gofer/reads_9p", false /* sync */, "Number of 9P file reads from a gofer.") diff --git a/pkg/sentry/hostfd/hostfd_linux.go b/pkg/sentry/hostfd/hostfd_linux.go index 1cabc848f..e103e7296 100644 --- a/pkg/sentry/hostfd/hostfd_linux.go +++ b/pkg/sentry/hostfd/hostfd_linux.go @@ -14,5 +14,10 @@ package hostfd -// maxIov is the maximum permitted size of a struct iovec array. -const maxIov = 1024 // UIO_MAXIOV +// MaxReadWriteIov is the maximum permitted size of a struct iovec array in a +// readv, writev, preadv, or pwritev host syscall. +const MaxReadWriteIov = 1024 // UIO_MAXIOV + +// MaxSendRecvMsgIov is the maximum permitted size of a struct iovec array in a +// sendmsg or recvmsg host syscall. +const MaxSendRecvMsgIov = 1024 // UIO_MAXIOV diff --git a/pkg/sentry/hostfd/hostfd_unsafe.go b/pkg/sentry/hostfd/hostfd_unsafe.go index 03c6d2a16..a43311eb4 100644 --- a/pkg/sentry/hostfd/hostfd_unsafe.go +++ b/pkg/sentry/hostfd/hostfd_unsafe.go @@ -23,6 +23,11 @@ import ( "gvisor.dev/gvisor/pkg/safemem" ) +const ( + sizeofIovec = unsafe.Sizeof(unix.Iovec{}) + sizeofMsghdr = unsafe.Sizeof(unix.Msghdr{}) +) + // Preadv2 reads up to dsts.NumBytes() bytes from host file descriptor fd into // dsts. offset and flags are interpreted as for preadv2(2). 
// @@ -44,9 +49,9 @@ func Preadv2(fd int32, dsts safemem.BlockSeq, offset int64, flags uint32) (uint6 } } else { iovs := safemem.IovecsFromBlockSeq(dsts) - if len(iovs) > maxIov { - log.Debugf("hostfd.Preadv2: truncating from %d iovecs to %d", len(iovs), maxIov) - iovs = iovs[:maxIov] + if len(iovs) > MaxReadWriteIov { + log.Debugf("hostfd.Preadv2: truncating from %d iovecs to %d", len(iovs), MaxReadWriteIov) + iovs = iovs[:MaxReadWriteIov] } n, _, e = unix.Syscall6(unix.SYS_PREADV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags)) } @@ -80,9 +85,9 @@ func Pwritev2(fd int32, srcs safemem.BlockSeq, offset int64, flags uint32) (uint } } else { iovs := safemem.IovecsFromBlockSeq(srcs) - if len(iovs) > maxIov { - log.Debugf("hostfd.Preadv2: truncating from %d iovecs to %d", len(iovs), maxIov) - iovs = iovs[:maxIov] + if len(iovs) > MaxReadWriteIov { + log.Debugf("hostfd.Preadv2: truncating from %d iovecs to %d", len(iovs), MaxReadWriteIov) + iovs = iovs[:MaxReadWriteIov] } n, _, e = unix.Syscall6(unix.SYS_PWRITEV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags)) } diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index a1ec6daab..26614b029 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -32,7 +32,7 @@ go_template_instance( out = "seqatomic_taskgoroutineschedinfo_unsafe.go", package = "kernel", suffix = "TaskGoroutineSchedInfo", - template = "//pkg/sync:generic_seqatomic", + template = "//pkg/sync/seqatomic:generic_seqatomic", types = { "Value": "TaskGoroutineSchedInfo", }, @@ -218,6 +218,7 @@ go_library( ":uncaught_signal_go_proto", "//pkg/abi", "//pkg/abi/linux", + "//pkg/abi/linux/errno", "//pkg/amutex", "//pkg/bits", "//pkg/bpf", @@ -225,6 +226,8 @@ go_library( "//pkg/context", "//pkg/coverage", "//pkg/cpuid", + "//pkg/errors", + "//pkg/errors/linuxerr", "//pkg/eventchannel", "//pkg/fspath", "//pkg/goid", @@ -298,6 +301,7 @@ go_test( deps = [ "//pkg/abi", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/sentry/arch", "//pkg/sentry/contexttest", @@ -309,6 +313,5 @@ go_test( "//pkg/sentry/time", "//pkg/sentry/usage", "//pkg/sync", - "//pkg/syserror", ], ) diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD index 869e49ebc..7a1a36454 100644 --- a/pkg/sentry/kernel/auth/BUILD +++ b/pkg/sentry/kernel/auth/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "atomicptr_credentials_unsafe.go", package = "auth", suffix = "Credentials", - template = "//pkg/sync:generic_atomicptr", + template = "//pkg/sync/atomicptr:generic_atomicptr", types = { "Value": "Credentials", }, @@ -63,6 +63,7 @@ go_library( "//pkg/abi/linux", "//pkg/bits", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/log", "//pkg/sync", "//pkg/syserror", diff --git a/pkg/sentry/kernel/auth/credentials.go b/pkg/sentry/kernel/auth/credentials.go index 6862f2ef5..32c344399 100644 --- a/pkg/sentry/kernel/auth/credentials.go +++ b/pkg/sentry/kernel/auth/credentials.go @@ -16,6 +16,7 @@ package auth import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/syserror" ) @@ -125,7 +126,7 @@ func NewUserCredentials(kuid KUID, kgid KGID, extraKGIDs []KGID, capabilities *T creds.EffectiveCaps = capabilities.EffectiveCaps creds.BoundingCaps = capabilities.BoundingCaps creds.InheritableCaps = capabilities.InheritableCaps - // TODO(nlacasse): Support ambient capabilities. 
+ // TODO(gvisor.dev/issue/3166): Support ambient capabilities. } else { // If no capabilities are specified, grant capabilities consistent with // setresuid + setresgid from NewRootCredentials to the given uid and @@ -203,7 +204,7 @@ func (c *Credentials) UseUID(uid UID) (KUID, error) { // uid must be mapped. kuid := c.UserNamespace.MapToKUID(uid) if !kuid.Ok() { - return NoID, syserror.EINVAL + return NoID, linuxerr.EINVAL } // If c has CAP_SETUID, then it can use any UID in its user namespace. if c.HasCapability(linux.CAP_SETUID) { @@ -222,7 +223,7 @@ func (c *Credentials) UseUID(uid UID) (KUID, error) { func (c *Credentials) UseGID(gid GID) (KGID, error) { kgid := c.UserNamespace.MapToKGID(gid) if !kgid.Ok() { - return NoID, syserror.EINVAL + return NoID, linuxerr.EINVAL } if c.HasCapability(linux.CAP_SETGID) { return kgid, nil @@ -239,7 +240,7 @@ func (c *Credentials) UseGID(gid GID) (KGID, error) { func (c *Credentials) SetUID(uid UID) error { kuid := c.UserNamespace.MapToKUID(uid) if !kuid.Ok() { - return syserror.EINVAL + return linuxerr.EINVAL } c.RealKUID = kuid c.EffectiveKUID = kuid @@ -253,7 +254,7 @@ func (c *Credentials) SetUID(uid UID) error { func (c *Credentials) SetGID(gid GID) error { kgid := c.UserNamespace.MapToKGID(gid) if !kgid.Ok() { - return syserror.EINVAL + return linuxerr.EINVAL } c.RealKGID = kgid c.EffectiveKGID = kgid diff --git a/pkg/sentry/kernel/auth/id_map.go b/pkg/sentry/kernel/auth/id_map.go index 28cbe159d..955b6d40b 100644 --- a/pkg/sentry/kernel/auth/id_map.go +++ b/pkg/sentry/kernel/auth/id_map.go @@ -17,6 +17,7 @@ package auth import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/syserror" ) @@ -110,7 +111,7 @@ func (ns *UserNamespace) SetUIDMap(ctx context.Context, entries []IDMapEntry) er } // "At least one line must be written to the file." if len(entries) == 0 { - return syserror.EINVAL + return linuxerr.EINVAL } // """ // In order for a process to write to the /proc/[pid]/uid_map @@ -170,11 +171,11 @@ func (ns *UserNamespace) trySetUIDMap(entries []IDMapEntry) error { // checks for NoID. lastID := e.FirstID + e.Length if lastID <= e.FirstID { - return syserror.EINVAL + return linuxerr.EINVAL } lastParentID := e.FirstParentID + e.Length if lastParentID <= e.FirstParentID { - return syserror.EINVAL + return linuxerr.EINVAL } // "3. The mapped user IDs (group IDs) must in turn have a mapping in // the parent user namespace." @@ -186,10 +187,10 @@ func (ns *UserNamespace) trySetUIDMap(entries []IDMapEntry) error { } // If either of these Adds fail, we have an overlapping range. 
if !ns.uidMapFromParent.Add(idMapRange{e.FirstParentID, lastParentID}, e.FirstID) { - return syserror.EINVAL + return linuxerr.EINVAL } if !ns.uidMapToParent.Add(idMapRange{e.FirstID, lastID}, e.FirstParentID) { - return syserror.EINVAL + return linuxerr.EINVAL } } return nil @@ -205,7 +206,7 @@ func (ns *UserNamespace) SetGIDMap(ctx context.Context, entries []IDMapEntry) er return syserror.EPERM } if len(entries) == 0 { - return syserror.EINVAL + return linuxerr.EINVAL } if !c.HasCapabilityIn(linux.CAP_SETGID, ns) { return syserror.EPERM @@ -239,20 +240,20 @@ func (ns *UserNamespace) trySetGIDMap(entries []IDMapEntry) error { for _, e := range entries { lastID := e.FirstID + e.Length if lastID <= e.FirstID { - return syserror.EINVAL + return linuxerr.EINVAL } lastParentID := e.FirstParentID + e.Length if lastParentID <= e.FirstParentID { - return syserror.EINVAL + return linuxerr.EINVAL } if !ns.parent.allIDsMapped(&ns.parent.gidMapToParent, e.FirstParentID, lastParentID) { return syserror.EPERM } if !ns.gidMapFromParent.Add(idMapRange{e.FirstParentID, lastParentID}, e.FirstID) { - return syserror.EINVAL + return linuxerr.EINVAL } if !ns.gidMapToParent.Add(idMapRange{e.FirstID, lastID}, e.FirstParentID) { - return syserror.EINVAL + return linuxerr.EINVAL } } return nil diff --git a/pkg/sentry/kernel/cgroup.go b/pkg/sentry/kernel/cgroup.go index 0fbf27f64..c93ef6ac1 100644 --- a/pkg/sentry/kernel/cgroup.go +++ b/pkg/sentry/kernel/cgroup.go @@ -181,7 +181,23 @@ func (r *CgroupRegistry) FindHierarchy(ctypes []CgroupControllerType) *vfs.Files for _, h := range r.hierarchies { if h.match(ctypes) { - h.fs.IncRef() + if !h.fs.TryIncRef() { + // Racing with filesystem destruction, namely h.fs.Release. + // Since we hold r.mu, we know the hierarchy hasn't been + // unregistered yet, but its associated filesystem is tearing + // down. + // + // If we simply indicate the hierarchy wasn't found without + // cleaning up the registry, the caller can race with the + // unregister and find itself temporarily unable to create a new + // hierarchy with a subset of the relevant controllers. + // + // To keep the result of FindHierarchy consistent with the + // uniqueness of controllers enforced by Register, drop the + // dying hierarchy now. The eventual unregister by the FS + // teardown will become a no-op. + return nil + } return h.fs } } @@ -230,12 +246,17 @@ func (r *CgroupRegistry) Register(cs []CgroupController, fs cgroupFS) error { return nil } -// Unregister removes a previously registered hierarchy from the registry. If -// the controller was not previously registered, Unregister is a no-op. +// Unregister removes a previously registered hierarchy from the registry. If no +// such hierarchy is registered, Unregister is a no-op. func (r *CgroupRegistry) Unregister(hid uint32) { r.mu.Lock() - defer r.mu.Unlock() + r.unregisterLocked(hid) + r.mu.Unlock() +} +// Precondition: Caller must hold r.mu. 
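The FindHierarchy change above swaps an unconditional IncRef for TryIncRef so a lookup cannot resurrect a hierarchy whose filesystem is already tearing down. A minimal sketch of the conditional-increment pattern (this is not the gVisor refs package, just an illustration using sync/atomic):

package main

import (
	"fmt"
	"sync/atomic"
)

type refCounted struct {
	refs int64
}

// tryIncRef increments the count only if the object is still live (refs > 0);
// a zero count means teardown has already begun and the caller must treat the
// object as gone.
func (r *refCounted) tryIncRef() bool {
	for {
		n := atomic.LoadInt64(&r.refs)
		if n <= 0 {
			return false
		}
		if atomic.CompareAndSwapInt64(&r.refs, n, n+1) {
			return true
		}
	}
}

func main() {
	live := &refCounted{refs: 1}
	dead := &refCounted{refs: 0}
	fmt.Println(live.tryIncRef()) // true
	fmt.Println(dead.tryIncRef()) // false: racing with teardown, report "not found"
}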
+// +checklocks:r.mu +func (r *CgroupRegistry) unregisterLocked(hid uint32) { if h, ok := r.hierarchies[hid]; ok { for name, _ := range h.controllers { delete(r.controllers, name) diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD index f855f038b..6b2dd09da 100644 --- a/pkg/sentry/kernel/fasync/BUILD +++ b/pkg/sentry/kernel/fasync/BUILD @@ -8,13 +8,12 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", - "//pkg/sentry/arch", + "//pkg/errors/linuxerr", "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", "//pkg/sync", - "//pkg/syserror", "//pkg/waiter", ], ) diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go index dbbbaeeb0..473987a79 100644 --- a/pkg/sentry/kernel/fasync/fasync.go +++ b/pkg/sentry/kernel/fasync/fasync.go @@ -17,13 +17,12 @@ package fasync import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -125,9 +124,9 @@ func (a *FileAsync) Callback(e *waiter.Entry, mask waiter.EventMask) { if !permCheck { return } - signalInfo := &arch.SignalInfo{ + signalInfo := &linux.SignalInfo{ Signo: int32(linux.SIGIO), - Code: arch.SignalInfoKernel, + Code: linux.SI_KERNEL, } if a.signal != 0 { signalInfo.Signo = int32(a.signal) @@ -249,7 +248,7 @@ func (a *FileAsync) Signal() linux.Signal { // to send SIGIO. func (a *FileAsync) SetSignal(signal linux.Signal) error { if signal != 0 && !signal.IsValid() { - return syserror.EINVAL + return linuxerr.EINVAL } a.mu.Lock() defer a.mu.Unlock() diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 62777faa8..8786a70b5 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -23,12 +23,12 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) // FDFlags define flags for an individual descriptor. @@ -156,7 +156,7 @@ func (f *FDTable) dropVFS2(ctx context.Context, file *vfs.FileDescription) { // Release any POSIX lock possibly held by the FDTable. 
if file.SupportsLocks() { err := file.UnlockPOSIX(ctx, f, lock.LockRange{0, lock.LockEOF}) - if err != nil && err != syserror.ENOLCK { + if err != nil && !linuxerr.Equals(linuxerr.ENOLCK, err) { panic(fmt.Sprintf("UnlockPOSIX failed: %v", err)) } } diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index a75686cf3..0606d32a8 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "atomicptr_bucket_unsafe.go", package = "futex", suffix = "Bucket", - template = "//pkg/sync:generic_atomicptr", + template = "//pkg/sync/atomicptr:generic_atomicptr", types = { "Value": "bucket", }, @@ -37,6 +37,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/log", "//pkg/sentry/memmap", diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go index 0427cf3f4..5c64ce11e 100644 --- a/pkg/sentry/kernel/futex/futex.go +++ b/pkg/sentry/kernel/futex/futex.go @@ -20,6 +20,7 @@ package futex import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sync" @@ -332,7 +333,7 @@ func getKey(t Target, addr hostarch.Addr, private bool) (Key, error) { // Ensure the address is aligned. // It must be a DWORD boundary. if addr&0x3 != 0 { - return Key{}, syserror.EINVAL + return Key{}, linuxerr.EINVAL } if private { return Key{Kind: KindPrivate, Offset: uint64(addr)}, nil @@ -790,7 +791,7 @@ func (m *Manager) unlockPILocked(t Target, addr hostarch.Addr, tid uint32, b *bu return err } if prev != cur { - return syserror.EINVAL + return linuxerr.EINVAL } b.wakeWaiterLocked(next) diff --git a/pkg/sentry/kernel/kcov.go b/pkg/sentry/kernel/kcov.go index 4b943106b..941cc373f 100644 --- a/pkg/sentry/kernel/kcov.go +++ b/pkg/sentry/kernel/kcov.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/coverage" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -131,13 +132,13 @@ func (kcov *Kcov) InitTrace(size uint64) error { // To simplify all the logic around mapping, we require that the length of the // shared region is a multiple of the system page size. if (8*size)&(hostarch.PageSize-1) != 0 { - return syserror.EINVAL + return linuxerr.EINVAL } // We need space for at least two uint64s to hold current position and a // single PC. if size < 2 || size > kcovAreaSizeMax { - return syserror.EINVAL + return linuxerr.EINVAL } kcov.size = size @@ -157,7 +158,7 @@ func (kcov *Kcov) EnableTrace(ctx context.Context, traceKind uint8) error { // KCOV_ENABLE must be preceded by KCOV_INIT_TRACE and an mmap call. if kcov.mode != linux.KCOV_MODE_INIT || kcov.mappable == nil { - return syserror.EINVAL + return linuxerr.EINVAL } switch traceKind { @@ -167,7 +168,7 @@ func (kcov *Kcov) EnableTrace(ctx context.Context, traceKind uint8) error { // We do not support KCOV_MODE_TRACE_CMP. 
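The InitTrace hunk above validates the requested kcov buffer before any mapping is set up: the byte length (8 bytes per entry) must be page-aligned, and there must be room for the position word plus at least one PC. A standalone sketch of that validation, with illustrative constants rather than the real limits:

package main

import (
	"errors"
	"fmt"
)

const (
	pageSize        = 4096
	kcovAreaSizeMax = 512 << 10 // illustrative upper bound, not the real constant
)

// validateTraceSize mirrors the shape of the checks in the hunk above.
func validateTraceSize(size uint64) error {
	if (8*size)&(pageSize-1) != 0 {
		return errors.New("EINVAL: byte length not page-aligned")
	}
	if size < 2 || size > kcovAreaSizeMax {
		return errors.New("EINVAL: size out of range")
	}
	return nil
}

func main() {
	fmt.Println(validateTraceSize(512)) // 512*8 = 4096 bytes: ok
	fmt.Println(validateTraceSize(100)) // 800 bytes: not page-aligned
}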
return syserror.ENOTSUP default: - return syserror.EINVAL + return linuxerr.EINVAL } if kcov.owningTask != nil && kcov.owningTask != t { @@ -195,7 +196,7 @@ func (kcov *Kcov) DisableTrace(ctx context.Context) error { } if t != kcov.owningTask { - return syserror.EINVAL + return linuxerr.EINVAL } kcov.mode = linux.KCOV_MODE_INIT kcov.owningTask = nil @@ -237,7 +238,7 @@ func (kcov *Kcov) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) erro defer kcov.mu.Unlock() if kcov.mode != linux.KCOV_MODE_INIT { - return syserror.EINVAL + return linuxerr.EINVAL } if kcov.mappable == nil { diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index e6e9da898..352c36ba9 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -143,12 +143,6 @@ type Kernel struct { // to CreateProcess, and is protected by extMu. globalInit *ThreadGroup - // realtimeClock is a ktime.Clock based on timekeeper's Realtime. - realtimeClock *timekeeperClock - - // monotonicClock is a ktime.Clock based on timekeeper's Monotonic. - monotonicClock *timekeeperClock - // syslog is the kernel log. syslog syslog @@ -354,19 +348,19 @@ type InitKernelArgs struct { // before calling Init. func (k *Kernel) Init(args InitKernelArgs) error { if args.FeatureSet == nil { - return fmt.Errorf("FeatureSet is nil") + return fmt.Errorf("args.FeatureSet is nil") } if args.Timekeeper == nil { - return fmt.Errorf("Timekeeper is nil") + return fmt.Errorf("args.Timekeeper is nil") } if args.Timekeeper.clocks == nil { return fmt.Errorf("must call Timekeeper.SetClocks() before Kernel.Init()") } if args.RootUserNamespace == nil { - return fmt.Errorf("RootUserNamespace is nil") + return fmt.Errorf("args.RootUserNamespace is nil") } if args.ApplicationCores == 0 { - return fmt.Errorf("ApplicationCores is 0") + return fmt.Errorf("args.ApplicationCores is 0") } k.featureSet = args.FeatureSet @@ -395,8 +389,6 @@ func (k *Kernel) Init(args InitKernelArgs) error { } k.extraAuxv = args.ExtraAuxv k.vdso = args.Vdso - k.realtimeClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Realtime} - k.monotonicClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Monotonic} k.futexes = futex.NewManager() k.netlinkPorts = port.New() k.ptraceExceptions = make(map[*Task]*Task) @@ -529,6 +521,8 @@ func (k *Kernel) SaveTo(ctx context.Context, w wire.Writer) error { } log.Infof("CPUID save took [%s].", time.Since(cpuidStart)) + // Save the timekeeper's state. + // Save the kernel state. kernelStart := time.Now() stats, err := state.Save(ctx, w, k) @@ -654,12 +648,12 @@ func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error { defer k.tasks.mu.RUnlock() for t := range k.tasks.Root.tids { // We can skip locking Task.mu here since the kernel is paused. - if mm := t.image.MemoryManager; mm != nil { - if _, ok := invalidated[mm]; !ok { - if err := mm.InvalidateUnsavable(ctx); err != nil { + if memMgr := t.image.MemoryManager; memMgr != nil { + if _, ok := invalidated[memMgr]; !ok { + if err := memMgr.InvalidateUnsavable(ctx); err != nil { return err } - invalidated[mm] = struct{}{} + invalidated[memMgr] = struct{}{} } } // I really wish we just had a sync.Map of all MMs... @@ -673,7 +667,7 @@ func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error { } // LoadFrom returns a new Kernel loaded from args. 
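The LoadFrom change in the next hunk threads a timeReady channel through restore and closes it once the clocks are set. Closing a channel is Go's one-shot broadcast: every goroutine blocked on it is released at once, and a nil channel simply means nobody is waiting. A minimal illustration:

package main

import (
	"fmt"
	"sync"
)

func main() {
	timeReady := make(chan struct{})
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			<-timeReady // blocks until the channel is closed
			fmt.Printf("worker %d: time restored\n", id)
		}(i)
	}
	// ... restore clocks ...
	close(timeReady) // one-shot broadcast to all waiters
	wg.Wait()
}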
-func (k *Kernel) LoadFrom(ctx context.Context, r wire.Reader, net inet.Stack, clocks sentrytime.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error { +func (k *Kernel) LoadFrom(ctx context.Context, r wire.Reader, timeReady chan struct{}, net inet.Stack, clocks sentrytime.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error { loadStart := time.Now() initAppCores := k.applicationCores @@ -720,6 +714,11 @@ func (k *Kernel) LoadFrom(ctx context.Context, r wire.Reader, net inet.Stack, cl log.Infof("Overall load took [%s]", time.Since(loadStart)) k.Timekeeper().SetClocks(clocks) + + if timeReady != nil { + close(timeReady) + } + if net != nil { net.Resume() } @@ -1101,7 +1100,7 @@ func (k *Kernel) Start() error { } k.started = true - k.cpuClockTicker = ktime.NewTimer(k.monotonicClock, newKernelCPUClockTicker(k)) + k.cpuClockTicker = ktime.NewTimer(k.timekeeper.monotonicClock, newKernelCPUClockTicker(k)) k.cpuClockTicker.Swap(ktime.Setting{ Enabled: true, Period: linux.ClockTick, @@ -1256,7 +1255,7 @@ func (k *Kernel) incRunningTasks() { // These cause very different value of cpuClock. But again, since // nothing was running while the ticker was disabled, those differences // don't matter. - setting, exp := k.cpuClockTickerSetting.At(k.monotonicClock.Now()) + setting, exp := k.cpuClockTickerSetting.At(k.timekeeper.monotonicClock.Now()) if exp > 0 { atomic.AddUint64(&k.cpuClock, exp) } @@ -1339,7 +1338,7 @@ func (k *Kernel) Unpause() { // context is used only for debugging to describe how the signal was received. // // Preconditions: Kernel must have an init process. -func (k *Kernel) SendExternalSignal(info *arch.SignalInfo, context string) { +func (k *Kernel) SendExternalSignal(info *linux.SignalInfo, context string) { k.extMu.Lock() defer k.extMu.Unlock() k.sendExternalSignal(info, context) @@ -1347,7 +1346,7 @@ func (k *Kernel) SendExternalSignal(info *arch.SignalInfo, context string) { // SendExternalSignalThreadGroup injects a signal into an specific ThreadGroup. // This function doesn't skip signals like SendExternalSignal does. -func (k *Kernel) SendExternalSignalThreadGroup(tg *ThreadGroup, info *arch.SignalInfo) error { +func (k *Kernel) SendExternalSignalThreadGroup(tg *ThreadGroup, info *linux.SignalInfo) error { k.extMu.Lock() defer k.extMu.Unlock() return tg.SendSignal(info) @@ -1355,7 +1354,7 @@ func (k *Kernel) SendExternalSignalThreadGroup(tg *ThreadGroup, info *arch.Signa // SendContainerSignal sends the given signal to all processes inside the // namespace that match the given container ID. -func (k *Kernel) SendContainerSignal(cid string, info *arch.SignalInfo) error { +func (k *Kernel) SendContainerSignal(cid string, info *linux.SignalInfo) error { k.extMu.Lock() defer k.extMu.Unlock() k.tasks.mu.RLock() @@ -1466,12 +1465,12 @@ func (k *Kernel) ApplicationCores() uint { // RealtimeClock returns the application CLOCK_REALTIME clock. func (k *Kernel) RealtimeClock() ktime.Clock { - return k.realtimeClock + return k.timekeeper.realtimeClock } // MonotonicClock returns the application CLOCK_MONOTONIC clock. func (k *Kernel) MonotonicClock() ktime.Clock { - return k.monotonicClock + return k.timekeeper.monotonicClock } // CPUClockNow returns the current value of k.cpuClock. @@ -1551,31 +1550,6 @@ func (k *Kernel) SetSaveError(err error) { } } -var _ tcpip.Clock = (*Kernel)(nil) - -// NowNanoseconds implements tcpip.Clock.NowNanoseconds. 
-func (k *Kernel) NowNanoseconds() int64 { - now, err := k.timekeeper.GetTime(sentrytime.Realtime) - if err != nil { - panic("Kernel.NowNanoseconds: " + err.Error()) - } - return now -} - -// NowMonotonic implements tcpip.Clock.NowMonotonic. -func (k *Kernel) NowMonotonic() int64 { - now, err := k.timekeeper.GetTime(sentrytime.Monotonic) - if err != nil { - panic("Kernel.NowMonotonic: " + err.Error()) - } - return now -} - -// AfterFunc implements tcpip.Clock.AfterFunc. -func (k *Kernel) AfterFunc(d time.Duration, f func()) tcpip.Timer { - return ktime.TcpipAfterFunc(k.realtimeClock, d, f) -} - // SetMemoryFile sets Kernel.mf. SetMemoryFile must be called before Init or // LoadFrom. func (k *Kernel) SetMemoryFile(mf *pgalloc.MemoryFile) { @@ -1783,7 +1757,7 @@ func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) { }) t := TaskFromContext(ctx) - k.unimplementedSyscallEmitter.Emit(&uspb.UnimplementedSyscall{ + _, _ = k.unimplementedSyscallEmitter.Emit(&uspb.UnimplementedSyscall{ Tid: int32(t.ThreadID()), Registers: t.Arch().StateData().Proto(), }) @@ -1858,7 +1832,9 @@ func (k *Kernel) PopulateNewCgroupHierarchy(root Cgroup) { return } t.mu.Lock() - t.enterCgroupLocked(root) + // A task can be in the cgroup if it has been created after the + // cgroup hierarchy was registered. + t.enterCgroupIfNotYetLocked(root) t.mu.Unlock() }) k.tasks.mu.RUnlock() @@ -1874,7 +1850,7 @@ func (k *Kernel) ReleaseCgroupHierarchy(hid uint32) { return } t.mu.Lock() - for cg, _ := range t.cgroups { + for cg := range t.cgroups { if cg.HierarchyID() == hid { t.leaveCgroupLocked(cg) } diff --git a/pkg/sentry/kernel/pending_signals.go b/pkg/sentry/kernel/pending_signals.go index 77a35b788..af455c434 100644 --- a/pkg/sentry/kernel/pending_signals.go +++ b/pkg/sentry/kernel/pending_signals.go @@ -17,7 +17,6 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" - "gvisor.dev/gvisor/pkg/sentry/arch" ) const ( @@ -65,7 +64,7 @@ type pendingSignalQueue struct { type pendingSignal struct { // pendingSignalEntry links into a pendingSignalList. pendingSignalEntry - *arch.SignalInfo + *linux.SignalInfo // If timer is not nil, it is the IntervalTimer which sent this signal. timer *IntervalTimer @@ -75,7 +74,7 @@ type pendingSignal struct { // on failure (if the given signal's queue is full). // // Preconditions: info represents a valid signal. -func (p *pendingSignals) enqueue(info *arch.SignalInfo, timer *IntervalTimer) bool { +func (p *pendingSignals) enqueue(info *linux.SignalInfo, timer *IntervalTimer) bool { sig := linux.Signal(info.Signo) q := &p.signals[sig.Index()] if sig.IsStandard() { @@ -93,7 +92,7 @@ func (p *pendingSignals) enqueue(info *arch.SignalInfo, timer *IntervalTimer) bo // dequeue dequeues and returns any pending signal not masked by mask. If no // unmasked signals are pending, dequeue returns nil. -func (p *pendingSignals) dequeue(mask linux.SignalSet) *arch.SignalInfo { +func (p *pendingSignals) dequeue(mask linux.SignalSet) *linux.SignalInfo { // "Real-time signals are delivered in a guaranteed order. Multiple // real-time signals of the same type are delivered in the order they were // sent. 
If different real-time signals are sent to a process, they are @@ -111,7 +110,7 @@ func (p *pendingSignals) dequeue(mask linux.SignalSet) *arch.SignalInfo { return p.dequeueSpecific(linux.Signal(lowestPendingUnblockedBit + 1)) } -func (p *pendingSignals) dequeueSpecific(sig linux.Signal) *arch.SignalInfo { +func (p *pendingSignals) dequeueSpecific(sig linux.Signal) *linux.SignalInfo { q := &p.signals[sig.Index()] ps := q.pendingSignalList.Front() if ps == nil { diff --git a/pkg/sentry/kernel/pending_signals_state.go b/pkg/sentry/kernel/pending_signals_state.go index ca8b4e164..e77f1a254 100644 --- a/pkg/sentry/kernel/pending_signals_state.go +++ b/pkg/sentry/kernel/pending_signals_state.go @@ -14,13 +14,11 @@ package kernel -import ( - "gvisor.dev/gvisor/pkg/sentry/arch" -) +import "gvisor.dev/gvisor/pkg/abi/linux" // +stateify savable type savedPendingSignal struct { - si *arch.SignalInfo + si *linux.SignalInfo timer *IntervalTimer } diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 34c617b08..94ebac7c5 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -21,6 +21,7 @@ go_library( "//pkg/abi/linux", "//pkg/amutex", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/marshal/primitive", "//pkg/safemem", @@ -47,6 +48,7 @@ go_test( library = ":pipe", deps = [ "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/sentry/contexttest", "//pkg/sentry/fs", "//pkg/syserror", diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go index 6497dc4ba..2321d26dc 100644 --- a/pkg/sentry/kernel/pipe/node.go +++ b/pkg/sentry/kernel/pipe/node.go @@ -17,6 +17,7 @@ package pipe import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sync" @@ -130,7 +131,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi return rw, nil default: - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } } diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go index d6fb0fdb8..d25cf658e 100644 --- a/pkg/sentry/kernel/pipe/node_test.go +++ b/pkg/sentry/kernel/pipe/node_test.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/syserror" @@ -258,7 +259,7 @@ func TestNonblockingWriteOpenFileNoReaders(t *testing.T) { ctx := newSleeperContext(t) f := NewInodeOperations(ctx, perms, newNamedPipe(t)) - if _, err := testOpen(ctx, t, f, fs.FileFlags{Write: true, NonBlocking: true}, nil); err != syserror.ENXIO { + if _, err := testOpen(ctx, t, f, fs.FileFlags{Write: true, NonBlocking: true}, nil); !linuxerr.Equals(linuxerr.ENXIO, err) { t.Fatalf("Nonblocking open for write failed unexpected error %v.", err) } } diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 06769931a..979ea10bf 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -22,6 +22,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -428,7 +429,7 @@ func (p *Pipe) FifoSize(context.Context, *fs.File) (int64, error) { // SetFifoSize implements fs.FifoSizer.SetFifoSize. 
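In the SetFifoSize hunk that follows, a negative size is rejected with EINVAL and anything below the minimum pipe size is silently rounded up, mirroring fcntl(F_SETPIPE_SZ) behavior. A rough sketch of that policy (the constant here is illustrative; gVisor uses its page-size-based MinimumPipeSize):

package main

import (
	"errors"
	"fmt"
)

const minimumPipeSize = 4096 // illustrative: one page

// setFifoSize rejects negative sizes and rounds small requests up to the minimum.
func setFifoSize(size int64) (int64, error) {
	if size < 0 {
		return 0, errors.New("EINVAL")
	}
	if size < minimumPipeSize {
		size = minimumPipeSize
	}
	return size, nil
}

func main() {
	fmt.Println(setFifoSize(-1))   // 0 EINVAL
	fmt.Println(setFifoSize(100))  // 4096 <nil>
	fmt.Println(setFifoSize(8192)) // 8192 <nil>
}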
func (p *Pipe) SetFifoSize(size int64) (int64, error) { if size < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if size < MinimumPipeSize { size = MinimumPipeSize // Per spec. diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index 2d89b9ccd..3fa5d1d2f 100644 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -86,6 +86,12 @@ func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error) if n > 0 { p.Notify(waiter.ReadableEvents) } + if err == unix.EPIPE { + // If we are returning EPIPE send SIGPIPE to the task. + if sendSig := linux.SignalNoInfoFuncFromContext(ctx); sendSig != nil { + sendSig(linux.SIGPIPE) + } + } return n, err } @@ -129,7 +135,7 @@ func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArgume v = math.MaxInt32 // Silently truncate. } // Copy result to userspace. - iocc := primitive.IOCopyContext{ + iocc := usermem.IOCopyContext{ IO: io, Ctx: ctx, Opts: usermem.IOOpts{ diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 95b948edb..623375417 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -17,6 +17,7 @@ package pipe import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -90,7 +91,7 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s readable := vfs.MayReadFileWithOpenFlags(statusFlags) writable := vfs.MayWriteFileWithOpenFlags(statusFlags) if !readable && !writable { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } fd, err := vp.newFD(mnt, vfsd, statusFlags, locks) @@ -415,7 +416,7 @@ func Tee(ctx context.Context, dst, src *VFSPipeFD, count int64) (int64, error) { // Preconditions: count > 0. func spliceOrTee(ctx context.Context, dst, src *VFSPipeFD, count int64, removeFromSrc bool) (int64, error) { if dst.pipe == src.pipe { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } lockTwoPipes(dst.pipe, src.pipe) diff --git a/pkg/sentry/kernel/posixtimer.go b/pkg/sentry/kernel/posixtimer.go index 2e861a5a8..049cc07df 100644 --- a/pkg/sentry/kernel/posixtimer.go +++ b/pkg/sentry/kernel/posixtimer.go @@ -18,7 +18,7 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/errors/linuxerr" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/syserror" ) @@ -97,7 +97,7 @@ func (it *IntervalTimer) ResumeTimer() { } // Preconditions: it.target's signal mutex must be locked. -func (it *IntervalTimer) updateDequeuedSignalLocked(si *arch.SignalInfo) { +func (it *IntervalTimer) updateDequeuedSignalLocked(si *linux.SignalInfo) { it.sigpending = false if it.sigorphan { return @@ -138,9 +138,9 @@ func (it *IntervalTimer) Notify(exp uint64, setting ktime.Setting) (ktime.Settin it.sigpending = true it.sigorphan = false it.overrunCur += exp - 1 - si := &arch.SignalInfo{ + si := &linux.SignalInfo{ Signo: int32(it.signo), - Code: arch.SignalInfoTimer, + Code: linux.SI_TIMER, } si.SetTimerID(it.id) si.SetSigval(it.sigval) @@ -215,16 +215,16 @@ func (t *Task) IntervalTimerCreate(c ktime.Clock, sigev *linux.Sigevent) (linux. 
target, ok := t.tg.pidns.tasks[ThreadID(sigev.Tid)] t.tg.pidns.owner.mu.RUnlock() if !ok || target.tg != t.tg { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } it.target = target default: - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if sigev.Notify != linux.SIGEV_NONE { it.signo = linux.Signal(sigev.Signo) if !it.signo.IsValid() { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } } it.timer = ktime.NewTimer(c, it) @@ -239,7 +239,7 @@ func (t *Task) IntervalTimerDelete(id linux.TimerID) error { defer t.tg.timerMu.Unlock() it := t.tg.timers[id] if it == nil { - return syserror.EINVAL + return linuxerr.EINVAL } delete(t.tg.timers, id) it.DestroyTimer() @@ -252,7 +252,7 @@ func (t *Task) IntervalTimerSettime(id linux.TimerID, its linux.Itimerspec, abs defer t.tg.timerMu.Unlock() it := t.tg.timers[id] if it == nil { - return linux.Itimerspec{}, syserror.EINVAL + return linux.Itimerspec{}, linuxerr.EINVAL } newS, err := ktime.SettingFromItimerspec(its, abs, it.timer.Clock()) @@ -270,7 +270,7 @@ func (t *Task) IntervalTimerGettime(id linux.TimerID) (linux.Itimerspec, error) defer t.tg.timerMu.Unlock() it := t.tg.timers[id] if it == nil { - return linux.Itimerspec{}, syserror.EINVAL + return linux.Itimerspec{}, linuxerr.EINVAL } tm, s := it.timer.Get() @@ -286,7 +286,7 @@ func (t *Task) IntervalTimerGetoverrun(id linux.TimerID) (int32, error) { defer t.tg.timerMu.Unlock() it := t.tg.timers[id] if it == nil { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // By timer_create(2) invariant, either it.target == nil (in which case // it.overrunLast is immutably 0) or t.tg == it.target.tg; and the fact diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go index 57c7659e7..1c6100efe 100644 --- a/pkg/sentry/kernel/ptrace.go +++ b/pkg/sentry/kernel/ptrace.go @@ -19,9 +19,9 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -295,7 +295,7 @@ func (t *Task) isYAMADescendantOfLocked(ancestor *Task) bool { // Precondition: the TaskSet mutex must be locked (for reading or writing). func (t *Task) hasYAMAExceptionForLocked(tracer *Task) bool { - allowed, ok := t.k.ptraceExceptions[t] + allowed, ok := t.k.ptraceExceptions[t.tg.leader] if !ok { return false } @@ -394,7 +394,7 @@ func (t *Task) ptraceTrapLocked(code int32) { t.trapStopPending = false t.tg.signalHandlers.mu.Unlock() t.ptraceCode = code - t.ptraceSiginfo = &arch.SignalInfo{ + t.ptraceSiginfo = &linux.SignalInfo{ Signo: int32(linux.SIGTRAP), Code: code, } @@ -402,7 +402,7 @@ func (t *Task) ptraceTrapLocked(code int32) { t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) if t.beginPtraceStopLocked() { tracer := t.Tracer() - tracer.signalStop(t, arch.CLD_TRAPPED, int32(linux.SIGTRAP)) + tracer.signalStop(t, linux.CLD_TRAPPED, int32(linux.SIGTRAP)) tracer.tg.eventQueue.Notify(EventTraceeStop) } } @@ -542,9 +542,9 @@ func (t *Task) ptraceAttach(target *Task, seize bool, opts uintptr) error { // "Unlike PTRACE_ATTACH, PTRACE_SEIZE does not stop the process." 
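The hasYAMAExceptionForLocked fix above looks up ptraceExceptions by the tracee's thread-group leader instead of the individual task, since PR_SET_PTRACER exceptions apply to the whole thread group. A toy illustration of why the map key matters (the task type here is a stand-in, not the kernel's):

package main

import "fmt"

type task struct {
	name   string
	leader *task // thread-group leader (itself, for the leader)
}

func main() {
	leader := &task{name: "tg-leader"}
	leader.leader = leader
	worker := &task{name: "worker-thread", leader: leader}

	// Exceptions are recorded once, against the thread-group leader.
	exceptions := map[*task]string{leader: "allowed-tracer"}

	// Keying on the thread misses the entry; keying on the leader finds it.
	_, byThread := exceptions[worker]
	_, byLeader := exceptions[worker.leader]
	fmt.Println(byThread, byLeader) // false true
}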
- // ptrace(2) if !seize { - target.sendSignalLocked(&arch.SignalInfo{ + target.sendSignalLocked(&linux.SignalInfo{ Signo: int32(linux.SIGSTOP), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, }, false /* group */) } // Undocumented Linux feature: If the tracee is already group-stopped (and @@ -586,7 +586,7 @@ func (t *Task) exitPtrace() { for target := range t.ptraceTracees { if target.ptraceOpts.ExitKill { target.tg.signalHandlers.mu.Lock() - target.sendSignalLocked(&arch.SignalInfo{ + target.sendSignalLocked(&linux.SignalInfo{ Signo: int32(linux.SIGKILL), }, false /* group */) target.tg.signalHandlers.mu.Unlock() @@ -652,7 +652,7 @@ func (t *Task) forgetTracerLocked() { // Preconditions: // * The signal mutex must be locked. // * The caller must be running on the task goroutine. -func (t *Task) ptraceSignalLocked(info *arch.SignalInfo) bool { +func (t *Task) ptraceSignalLocked(info *linux.SignalInfo) bool { if linux.Signal(info.Signo) == linux.SIGKILL { return false } @@ -678,7 +678,7 @@ func (t *Task) ptraceSignalLocked(info *arch.SignalInfo) bool { t.ptraceSiginfo = info t.Debugf("Entering signal-delivery-stop for signal %d", info.Signo) if t.beginPtraceStopLocked() { - tracer.signalStop(t, arch.CLD_TRAPPED, info.Signo) + tracer.signalStop(t, linux.CLD_TRAPPED, info.Signo) tracer.tg.eventQueue.Notify(EventTraceeStop) } return true @@ -829,7 +829,7 @@ func (t *Task) ptraceClone(kind ptraceCloneKind, child *Task, opts *CloneOptions if child.ptraceSeized { child.trapStopPending = true } else { - child.pendingSignals.enqueue(&arch.SignalInfo{ + child.pendingSignals.enqueue(&linux.SignalInfo{ Signo: int32(linux.SIGSTOP), }, nil) } @@ -893,9 +893,9 @@ func (t *Task) ptraceExec(oldTID ThreadID) { } t.tg.signalHandlers.mu.Lock() defer t.tg.signalHandlers.mu.Unlock() - t.sendSignalLocked(&arch.SignalInfo{ + t.sendSignalLocked(&linux.SignalInfo{ Signo: int32(linux.SIGTRAP), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, }, false /* group */) } @@ -995,7 +995,7 @@ func (t *Task) ptraceSetOptionsLocked(opts uintptr) error { linux.PTRACE_O_TRACEVFORK | linux.PTRACE_O_TRACEVFORKDONE) if opts&^valid != 0 { - return syserror.EINVAL + return linuxerr.EINVAL } t.ptraceOpts = ptraceOptions{ ExitKill: opts&linux.PTRACE_O_EXITKILL != 0, @@ -1222,27 +1222,27 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data hostarch.Addr) error { t.tg.pidns.owner.mu.RLock() defer t.tg.pidns.owner.mu.RUnlock() if target.ptraceSiginfo == nil { - return syserror.EINVAL + return linuxerr.EINVAL } _, err := target.ptraceSiginfo.CopyOut(t, data) return err case linux.PTRACE_SETSIGINFO: - var info arch.SignalInfo + var info linux.SignalInfo if _, err := info.CopyIn(t, data); err != nil { return err } t.tg.pidns.owner.mu.RLock() defer t.tg.pidns.owner.mu.RUnlock() if target.ptraceSiginfo == nil { - return syserror.EINVAL + return linuxerr.EINVAL } target.ptraceSiginfo = &info return nil case linux.PTRACE_GETSIGMASK: if addr != linux.SignalSetSize { - return syserror.EINVAL + return linuxerr.EINVAL } mask := target.SignalMask() _, err := mask.CopyOut(t, data) @@ -1250,7 +1250,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data hostarch.Addr) error { case linux.PTRACE_SETSIGMASK: if addr != linux.SignalSetSize { - return syserror.EINVAL + return linuxerr.EINVAL } var mask linux.SignalSet if _, err := mask.CopyIn(t, data); err != nil { diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go index 4bc5bca44..2344565cd 100644 --- a/pkg/sentry/kernel/rseq.go +++ b/pkg/sentry/kernel/rseq.go @@ 
-18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/hostcpu" "gvisor.dev/gvisor/pkg/syserror" @@ -59,20 +60,20 @@ func (t *Task) RSeqAvailable() bool { func (t *Task) SetRSeq(addr hostarch.Addr, length, signature uint32) error { if t.rseqAddr != 0 { if t.rseqAddr != addr { - return syserror.EINVAL + return linuxerr.EINVAL } if t.rseqSignature != signature { - return syserror.EINVAL + return linuxerr.EINVAL } return syserror.EBUSY } // rseq must be aligned and correctly sized. if addr&(linux.AlignOfRSeq-1) != 0 { - return syserror.EINVAL + return linuxerr.EINVAL } if length != linux.SizeOfRSeq { - return syserror.EINVAL + return linuxerr.EINVAL } if _, ok := t.MemoryManager().CheckIORange(addr, linux.SizeOfRSeq); !ok { return syserror.EFAULT @@ -103,13 +104,13 @@ func (t *Task) SetRSeq(addr hostarch.Addr, length, signature uint32) error { // Preconditions: The caller must be running on the task goroutine. func (t *Task) ClearRSeq(addr hostarch.Addr, length, signature uint32) error { if t.rseqAddr == 0 { - return syserror.EINVAL + return linuxerr.EINVAL } if t.rseqAddr != addr { - return syserror.EINVAL + return linuxerr.EINVAL } if length != linux.SizeOfRSeq { - return syserror.EINVAL + return linuxerr.EINVAL } if t.rseqSignature != signature { return syserror.EPERM @@ -152,10 +153,10 @@ func (t *Task) SetOldRSeqCriticalRegion(r OldRSeqCriticalRegion) error { return nil } if r.CriticalSection.Start >= r.CriticalSection.End { - return syserror.EINVAL + return linuxerr.EINVAL } if r.CriticalSection.Contains(r.Restart) { - return syserror.EINVAL + return linuxerr.EINVAL } // TODO(jamieliu): check that r.CriticalSection and r.Restart are in // the application address range, for consistency with Linux. @@ -187,7 +188,7 @@ func (t *Task) SetOldRSeqCPUAddr(addr hostarch.Addr) error { // unfortunate, but unlikely in a correct program. 
if err := t.rseqUpdateCPU(); err != nil { t.oldRSeqCPUAddr = 0 - return syserror.EINVAL // yes, EINVAL, not err or EFAULT + return linuxerr.EINVAL // yes, EINVAL, not err or EFAULT } return nil } diff --git a/pkg/sentry/kernel/seccomp.go b/pkg/sentry/kernel/seccomp.go index a95e174a2..54ca43c2e 100644 --- a/pkg/sentry/kernel/seccomp.go +++ b/pkg/sentry/kernel/seccomp.go @@ -39,11 +39,11 @@ func dataAsBPFInput(t *Task, d *linux.SeccompData) bpf.Input { } } -func seccompSiginfo(t *Task, errno, sysno int32, ip hostarch.Addr) *arch.SignalInfo { - si := &arch.SignalInfo{ +func seccompSiginfo(t *Task, errno, sysno int32, ip hostarch.Addr) *linux.SignalInfo { + si := &linux.SignalInfo{ Signo: int32(linux.SIGSYS), Errno: errno, - Code: arch.SYS_SECCOMP, + Code: linux.SYS_SECCOMP, } si.SetCallAddr(uint64(ip)) si.SetSyscall(sysno) diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD index 65e5427c1..a787c00a8 100644 --- a/pkg/sentry/kernel/semaphore/BUILD +++ b/pkg/sentry/kernel/semaphore/BUILD @@ -25,6 +25,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/log", "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index fe2ab1662..2dbc8353a 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -35,10 +36,10 @@ const ( // Maximum number of semaphore sets. setsMax = linux.SEMMNI - // Maximum number of semaphroes in a semaphore set. + // Maximum number of semaphores in a semaphore set. semsMax = linux.SEMMSL - // Maximum number of semaphores in all semaphroe sets. + // Maximum number of semaphores in all semaphore sets. semsTotalMax = linux.SEMMNS ) @@ -127,7 +128,7 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry { // exists. func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linux.FileMode, private, create, exclusive bool) (*Set, error) { if nsems < 0 || nsems > semsMax { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } r.mu.Lock() @@ -147,7 +148,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu // Validate parameters. if nsems > int32(set.Size()) { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } if create && exclusive { return nil, syserror.EEXIST @@ -163,7 +164,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu // Zero is only valid if an existing set is found. if nsems == 0 { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } // Apply system limits. @@ -171,10 +172,10 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu // Map semaphores and map indexes in a registry are of the same size, // check map semaphores only here for the system limit. if len(r.semaphores) >= setsMax { - return nil, syserror.EINVAL + return nil, syserror.ENOSPC } if r.totalSems() > int(semsTotalMax-nsems) { - return nil, syserror.EINVAL + return nil, syserror.ENOSPC } // Finally create a new set. @@ -220,7 +221,7 @@ func (r *Registry) HighestIndex() int32 { defer r.mu.Unlock() // By default, highest used index is 0 even though - // there is no semaphroe set. + // there is no semaphore set. 
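In the semaphore FindOrCreate hunk above, hitting the registry-wide limits now reports ENOSPC instead of EINVAL, reserving EINVAL for genuinely invalid arguments, which matches semget(2). A standalone sketch of that distinction with illustrative limits:

package main

import (
	"errors"
	"fmt"
)

const (
	semsMax = 32000 // max semaphores per set (SEMMSL-like, illustrative)
	setsMax = 32000 // max semaphore sets (SEMMNI-like, illustrative)
)

// checkCreate separates argument errors from resource-exhaustion errors.
func checkCreate(nsems int32, existingSets int) error {
	if nsems < 0 || nsems > semsMax {
		return errors.New("EINVAL") // caller asked for something nonsensical
	}
	if existingSets >= setsMax {
		return errors.New("ENOSPC") // arguments fine, system limit reached
	}
	return nil
}

func main() {
	fmt.Println(checkCreate(-1, 0))       // EINVAL
	fmt.Println(checkCreate(10, setsMax)) // ENOSPC
	fmt.Println(checkCreate(10, 1))       // <nil>
}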
var highestIndex int32 for index := range r.indexes { if index > highestIndex { @@ -238,7 +239,7 @@ func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error { set := r.semaphores[id] if set == nil { - return syserror.EINVAL + return linuxerr.EINVAL } index, found := r.findIndexByID(id) if !found { @@ -702,7 +703,9 @@ func (s *Set) checkPerms(creds *auth.Credentials, reqPerms fs.PermMask) bool { return s.checkCapability(creds) } -// destroy destroys the set. Caller must hold 's.mu'. +// destroy destroys the set. +// +// Preconditions: Caller must hold 's.mu'. func (s *Set) destroy() { // Notify all waiters. They will fail on the next attempt to execute // operations and return error. diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go index 0cd9e2533..973d708a3 100644 --- a/pkg/sentry/kernel/sessions.go +++ b/pkg/sentry/kernel/sessions.go @@ -16,7 +16,6 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/syserror" ) @@ -233,7 +232,7 @@ func (pg *ProcessGroup) Session() *Session { // SendSignal sends a signal to all processes inside the process group. It is // analagous to kernel/signal.c:kill_pgrp. -func (pg *ProcessGroup) SendSignal(info *arch.SignalInfo) error { +func (pg *ProcessGroup) SendSignal(info *linux.SignalInfo) error { tasks := pg.originator.TaskSet() tasks.mu.RLock() defer tasks.mu.RUnlock() @@ -370,6 +369,11 @@ func (tg *ThreadGroup) CreateProcessGroup() error { // Get the ID for this thread in the current namespace. id := tg.pidns.tgids[tg] + // Check whether a process still exists or not. + if id == 0 { + return syserror.ESRCH + } + // Per above, check for a Session leader or existing group. for s := tg.pidns.owner.sessions.Front(); s != nil; s = s.Next() { if s.leader.pidns != tg.pidns { diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index 1c3c0794f..5b69333fe 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -28,6 +28,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/log", "//pkg/refs", diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index a73f1bdca..7a6e91004 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -38,6 +38,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -145,7 +146,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size ui // // Note that 'private' always implies the creation of a new segment // whether IPC_CREAT is specified or not. - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } r.mu.Lock() @@ -175,7 +176,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size ui if size > shm.size { // "A segment for the given key exists, but size is greater than // the size of that segment." 
- man shmget(2) - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } if create && exclusive { @@ -200,7 +201,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size ui if val, ok := hostarch.Addr(size).RoundUp(); ok { sizeAligned = uint64(val) } else { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } if numPages := sizeAligned / hostarch.PageSize; r.totalPages+numPages > linux.SHMALL { @@ -652,7 +653,7 @@ func (s *Shm) Set(ctx context.Context, ds *linux.ShmidDS) error { uid := creds.UserNamespace.MapToKUID(auth.UID(ds.ShmPerm.UID)) gid := creds.UserNamespace.MapToKGID(auth.GID(ds.ShmPerm.GID)) if !uid.Ok() || !gid.Ok() { - return syserror.EINVAL + return linuxerr.EINVAL } // User may only modify the lower 9 bits of the mode. All the other bits are diff --git a/pkg/sentry/kernel/signal.go b/pkg/sentry/kernel/signal.go index 2488ae7d5..e08474d25 100644 --- a/pkg/sentry/kernel/signal.go +++ b/pkg/sentry/kernel/signal.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" ) @@ -36,7 +35,7 @@ const SignalPanic = linux.SIGUSR2 // context is used only for debugging to differentiate these cases. // // Preconditions: Kernel must have an init process. -func (k *Kernel) sendExternalSignal(info *arch.SignalInfo, context string) { +func (k *Kernel) sendExternalSignal(info *linux.SignalInfo, context string) { switch linux.Signal(info.Signo) { case linux.SIGURG: // Sent by the Go 1.14+ runtime for asynchronous goroutine preemption. @@ -60,18 +59,18 @@ func (k *Kernel) sendExternalSignal(info *arch.SignalInfo, context string) { } // SignalInfoPriv returns a SignalInfo equivalent to Linux's SEND_SIG_PRIV. -func SignalInfoPriv(sig linux.Signal) *arch.SignalInfo { - return &arch.SignalInfo{ +func SignalInfoPriv(sig linux.Signal) *linux.SignalInfo { + return &linux.SignalInfo{ Signo: int32(sig), - Code: arch.SignalInfoKernel, + Code: linux.SI_KERNEL, } } // SignalInfoNoInfo returns a SignalInfo equivalent to Linux's SEND_SIG_NOINFO. -func SignalInfoNoInfo(sig linux.Signal, sender, receiver *Task) *arch.SignalInfo { - info := &arch.SignalInfo{ +func SignalInfoNoInfo(sig linux.Signal, sender, receiver *Task) *linux.SignalInfo { + info := &linux.SignalInfo{ Signo: int32(sig), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, } info.SetPID(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg))) info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) diff --git a/pkg/sentry/kernel/signal_handlers.go b/pkg/sentry/kernel/signal_handlers.go index 768fda220..147cc41bb 100644 --- a/pkg/sentry/kernel/signal_handlers.go +++ b/pkg/sentry/kernel/signal_handlers.go @@ -16,7 +16,6 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sync" ) @@ -30,14 +29,14 @@ type SignalHandlers struct { mu sync.Mutex `state:"nosave"` // actions is the action to be taken upon receiving each signal. - actions map[linux.Signal]arch.SignalAct + actions map[linux.Signal]linux.SigAction } // NewSignalHandlers returns a new SignalHandlers specifying all default // actions. 
func NewSignalHandlers() *SignalHandlers { return &SignalHandlers{ - actions: make(map[linux.Signal]arch.SignalAct), + actions: make(map[linux.Signal]linux.SigAction), } } @@ -59,9 +58,9 @@ func (sh *SignalHandlers) CopyForExec() *SignalHandlers { sh.mu.Lock() defer sh.mu.Unlock() for sig, act := range sh.actions { - if act.Handler == arch.SignalActIgnore { - sh2.actions[sig] = arch.SignalAct{ - Handler: arch.SignalActIgnore, + if act.Handler == linux.SIG_IGN { + sh2.actions[sig] = linux.SigAction{ + Handler: linux.SIG_IGN, } } } @@ -73,15 +72,15 @@ func (sh *SignalHandlers) IsIgnored(sig linux.Signal) bool { sh.mu.Lock() defer sh.mu.Unlock() sa, ok := sh.actions[sig] - return ok && sa.Handler == arch.SignalActIgnore + return ok && sa.Handler == linux.SIG_IGN } // dequeueActionLocked returns the SignalAct that should be used to handle sig. // // Preconditions: sh.mu must be locked. -func (sh *SignalHandlers) dequeueAction(sig linux.Signal) arch.SignalAct { +func (sh *SignalHandlers) dequeueAction(sig linux.Signal) linux.SigAction { act := sh.actions[sig] - if act.IsResetHandler() { + if act.Flags&linux.SA_RESETHAND != 0 { delete(sh.actions, sig) } return act diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD index 76d472292..1110ecca5 100644 --- a/pkg/sentry/kernel/signalfd/BUILD +++ b/pkg/sentry/kernel/signalfd/BUILD @@ -9,6 +9,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/sentry/fs", "//pkg/sentry/fs/anon", "//pkg/sentry/fs/fsutil", diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go index f58ec4194..47958e2d4 100644 --- a/pkg/sentry/kernel/signalfd/signalfd.go +++ b/pkg/sentry/kernel/signalfd/signalfd.go @@ -18,6 +18,7 @@ package signalfd import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/anon" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -64,7 +65,7 @@ func New(ctx context.Context, mask linux.SignalSet) (*fs.File, error) { t := kernel.TaskFromContext(ctx) if t == nil { // No task context? Not valid. - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } // name matches fs/signalfd.c:signalfd4. dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[signalfd]") diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index be1371855..d211e4d82 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -21,8 +21,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -33,7 +33,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -151,7 +150,7 @@ type Task struct { // which the SA_ONSTACK flag is set. // // signalStack is exclusive to the task goroutine. - signalStack arch.SignalStack + signalStack linux.SignalStack // signalQueue is a set of registered waiters for signal-related events. // @@ -395,7 +394,7 @@ type Task struct { // ptraceSiginfo is analogous to Linux's task_struct::last_siginfo. // // ptraceSiginfo is protected by the TaskSet mutex. 
- ptraceSiginfo *arch.SignalInfo + ptraceSiginfo *linux.SignalInfo // ptraceEventMsg is the value set by PTRACE_EVENT stops and returned to // the tracer by ptrace(PTRACE_GETEVENTMSG). @@ -847,21 +846,19 @@ func (t *Task) OOMScoreAdj() int32 { // value should be between -1000 and 1000 inclusive. func (t *Task) SetOOMScoreAdj(adj int32) error { if adj > 1000 || adj < -1000 { - return syserror.EINVAL + return linuxerr.EINVAL } atomic.StoreInt32(&t.tg.oomScoreAdj, adj) return nil } -// UID returns t's uid. -// TODO(gvisor.dev/issue/170): This method is not namespaced yet. -func (t *Task) UID() uint32 { +// KUID returns t's kuid. +func (t *Task) KUID() uint32 { return uint32(t.Credentials().EffectiveKUID) } -// GID returns t's gid. -// TODO(gvisor.dev/issue/170): This method is not namespaced yet. -func (t *Task) GID() uint32 { +// KGID returns t's kgid. +func (t *Task) KGID() uint32 { return uint32(t.Credentials().EffectiveKGID) } diff --git a/pkg/sentry/kernel/task_acct.go b/pkg/sentry/kernel/task_acct.go index e574997f7..dd364ae50 100644 --- a/pkg/sentry/kernel/task_acct.go +++ b/pkg/sentry/kernel/task_acct.go @@ -18,10 +18,10 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/syserror" ) // Getitimer implements getitimer(2). @@ -44,7 +44,7 @@ func (t *Task) Getitimer(id int32) (linux.ItimerVal, error) { s, _ = t.tg.itimerProfSetting.At(tm) t.tg.signalHandlers.mu.Unlock() default: - return linux.ItimerVal{}, syserror.EINVAL + return linux.ItimerVal{}, linuxerr.EINVAL } val, iv := ktime.SpecFromSetting(tm, s) return linux.ItimerVal{ @@ -105,7 +105,7 @@ func (t *Task) Setitimer(id int32, newitv linux.ItimerVal) (linux.ItimerVal, err return linux.ItimerVal{}, err } default: - return linux.ItimerVal{}, syserror.EINVAL + return linux.ItimerVal{}, linuxerr.EINVAL } oldval, oldiv := ktime.SpecFromSetting(tm, olds) return linux.ItimerVal{ diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go index ecbe8f920..07533d982 100644 --- a/pkg/sentry/kernel/task_block.go +++ b/pkg/sentry/kernel/task_block.go @@ -19,6 +19,7 @@ import ( "runtime/trace" "time" + "gvisor.dev/gvisor/pkg/errors/linuxerr" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -45,7 +46,7 @@ func (t *Task) BlockWithTimeout(C chan struct{}, haveTimeout bool, timeout time. err := t.BlockWithDeadline(C, true, deadline) // Timeout, explicitly return a remaining duration of 0. - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { return 0, err } diff --git a/pkg/sentry/kernel/task_cgroup.go b/pkg/sentry/kernel/task_cgroup.go index 25d2504fa..7c138e80f 100644 --- a/pkg/sentry/kernel/task_cgroup.go +++ b/pkg/sentry/kernel/task_cgroup.go @@ -85,6 +85,14 @@ func (t *Task) enterCgroupLocked(c Cgroup) { c.Enter(t) } +// +checklocks:t.mu +func (t *Task) enterCgroupIfNotYetLocked(c Cgroup) { + if _, ok := t.cgroups[c]; ok { + return + } + t.enterCgroupLocked(c) +} + // LeaveCgroups removes t out from all its cgroups. 
func (t *Task) LeaveCgroups() { t.mu.Lock() diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index 405771f3f..76fb0e2cb 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" "gvisor.dev/gvisor/pkg/cleanup" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/syserror" @@ -142,25 +143,25 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { // address, any set of signal handlers must refer to the same address // space. if !opts.NewSignalHandlers && opts.NewAddressSpace { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // In order for the behavior of thread-group-directed signals to be sane, // all tasks in a thread group must share signal handlers. if !opts.NewThreadGroup && opts.NewSignalHandlers { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // All tasks in a thread group must be in the same PID namespace. if !opts.NewThreadGroup && (opts.NewPIDNamespace || t.childPIDNamespace != nil) { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // The two different ways of specifying a new PID namespace are // incompatible. if opts.NewPIDNamespace && t.childPIDNamespace != nil { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Thread groups and FS contexts cannot span user namespaces. if opts.NewUserNamespace && (!opts.NewThreadGroup || !opts.NewFSContext) { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Pull task registers and FPU state, a cloned task will inherit the @@ -463,14 +464,14 @@ func (t *Task) Unshare(opts *SharingOptions) error { // sense that clone(2) allows a task to share signal handlers and address // spaces with tasks in other thread groups. 
if opts.NewAddressSpace || opts.NewSignalHandlers { - return syserror.EINVAL + return linuxerr.EINVAL } creds := t.Credentials() if opts.NewThreadGroup { t.tg.signalHandlers.mu.Lock() if t.tg.tasksCount != 1 { t.tg.signalHandlers.mu.Unlock() - return syserror.EINVAL + return linuxerr.EINVAL } t.tg.signalHandlers.mu.Unlock() // This isn't racy because we're the only living task, and therefore diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index 70b0699dc..c82d9e82b 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -17,6 +17,7 @@ package kernel import ( "time" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -113,6 +114,10 @@ func (t *Task) contextValue(key interface{}, isTaskGoroutine bool) interface{} { return t.k.RealtimeClock() case limits.CtxLimits: return t.tg.limits + case linux.CtxSignalNoInfoFunc: + return func(sig linux.Signal) error { + return t.SendSignal(SignalInfoNoInfo(sig, t, t)) + } case pgalloc.CtxMemoryFile: return t.k.mf case pgalloc.CtxMemoryFileProvider: diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index d9897e802..cf8571262 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -66,7 +66,6 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -181,7 +180,7 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { t.tg.signalHandlers = t.tg.signalHandlers.CopyForExec() t.endStopCond.L = &t.tg.signalHandlers.mu // "Any alternate signal stack is not preserved (sigaltstack(2))." - execve(2) - t.signalStack = arch.SignalStack{Flags: arch.SignalStackFlagDisable} + t.signalStack = linux.SignalStack{Flags: linux.SS_DISABLE} // "The termination signal is reset to SIGCHLD (see clone(2))." t.tg.terminationSignal = linux.SIGCHLD // execed indicates that the process can no longer join a process group diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index b1af1a7ef..d115b8783 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -28,9 +28,9 @@ import ( "errors" "fmt" "strconv" + "strings" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" @@ -50,6 +50,23 @@ type ExitStatus struct { Signo int } +func (es ExitStatus) String() string { + var b strings.Builder + if code := es.Code; code != 0 { + if b.Len() != 0 { + b.WriteByte(' ') + } + _, _ = fmt.Fprintf(&b, "Code=%d", code) + } + if signal := es.Signo; signal != 0 { + if b.Len() != 0 { + b.WriteByte(' ') + } + _, _ = fmt.Fprintf(&b, "Signal=%d", signal) + } + return b.String() +} + // Signaled returns true if the ExitStatus indicates that the exiting task or // thread group was killed by a signal. func (es ExitStatus) Signaled() bool { @@ -122,12 +139,12 @@ func (t *Task) killLocked() { if t.stop != nil && t.stop.Killable() { t.endInternalStopLocked() } - t.pendingSignals.enqueue(&arch.SignalInfo{ + t.pendingSignals.enqueue(&linux.SignalInfo{ Signo: int32(linux.SIGKILL), // Linux just sets SIGKILL in the pending signal bitmask without // enqueueing an actual siginfo, such that // kernel/signal.c:collect_signal() initializes si_code to SI_USER. 
- Code: arch.SignalInfoUser, + Code: linux.SI_USER, }, nil) t.interrupt() } @@ -332,7 +349,7 @@ func (t *Task) exitThreadGroup() bool { // signalStop must be called with t's signal mutex unlocked. t.tg.signalHandlers.mu.Unlock() if notifyParent && t.tg.leader.parent != nil { - t.tg.leader.parent.signalStop(t, arch.CLD_STOPPED, int32(sig)) + t.tg.leader.parent.signalStop(t, linux.CLD_STOPPED, int32(sig)) t.tg.leader.parent.tg.eventQueue.Notify(EventChildGroupStop) } return last @@ -353,7 +370,7 @@ func (t *Task) exitChildren() { continue } other.signalHandlers.mu.Lock() - other.leader.sendSignalLocked(&arch.SignalInfo{ + other.leader.sendSignalLocked(&linux.SignalInfo{ Signo: int32(linux.SIGKILL), }, true /* group */) other.signalHandlers.mu.Unlock() @@ -368,9 +385,9 @@ func (t *Task) exitChildren() { // wait for a parent to reap them.) for c := range t.children { if sig := c.ParentDeathSignal(); sig != 0 { - siginfo := &arch.SignalInfo{ + siginfo := &linux.SignalInfo{ Signo: int32(sig), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, } siginfo.SetPID(int32(c.tg.pidns.tids[t])) siginfo.SetUID(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow())) @@ -652,10 +669,10 @@ func (t *Task) exitNotifyLocked(fromPtraceDetach bool) { t.parent.tg.signalHandlers.mu.Lock() if t.tg.terminationSignal == linux.SIGCHLD || fromPtraceDetach { if act, ok := t.parent.tg.signalHandlers.actions[linux.SIGCHLD]; ok { - if act.Handler == arch.SignalActIgnore { + if act.Handler == linux.SIG_IGN { t.exitParentAcked = true signalParent = false - } else if act.Flags&arch.SignalFlagNoCldWait != 0 { + } else if act.Flags&linux.SA_NOCLDWAIT != 0 { t.exitParentAcked = true } } @@ -705,17 +722,17 @@ func (t *Task) exitNotifyLocked(fromPtraceDetach bool) { } // Preconditions: The TaskSet mutex must be locked. -func (t *Task) exitNotificationSignal(sig linux.Signal, receiver *Task) *arch.SignalInfo { - info := &arch.SignalInfo{ +func (t *Task) exitNotificationSignal(sig linux.Signal, receiver *Task) *linux.SignalInfo { + info := &linux.SignalInfo{ Signo: int32(sig), } info.SetPID(int32(receiver.tg.pidns.tids[t])) info.SetUID(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) if t.exitStatus.Signaled() { - info.Code = arch.CLD_KILLED + info.Code = linux.CLD_KILLED info.SetStatus(int32(t.exitStatus.Signo)) } else { - info.Code = arch.CLD_EXITED + info.Code = linux.CLD_EXITED info.SetStatus(int32(t.exitStatus.Code)) } // TODO(b/72102453): Set utime, stime. diff --git a/pkg/sentry/kernel/task_identity.go b/pkg/sentry/kernel/task_identity.go index 0325967e4..29f154ebd 100644 --- a/pkg/sentry/kernel/task_identity.go +++ b/pkg/sentry/kernel/task_identity.go @@ -16,6 +16,7 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/syserror" @@ -47,7 +48,7 @@ func (t *Task) HasCapability(cp linux.Capability) bool { func (t *Task) SetUID(uid auth.UID) error { // setuid considers -1 to be invalid. if !uid.Ok() { - return syserror.EINVAL + return linuxerr.EINVAL } t.mu.Lock() @@ -56,7 +57,7 @@ func (t *Task) SetUID(uid auth.UID) error { creds := t.Credentials() kuid := creds.UserNamespace.MapToKUID(uid) if !kuid.Ok() { - return syserror.EINVAL + return linuxerr.EINVAL } // "setuid() sets the effective user ID of the calling process. 
If the // effective UID of the caller is root (more precisely: if the caller has @@ -87,14 +88,14 @@ func (t *Task) SetREUID(r, e auth.UID) error { if r.Ok() { newR = creds.UserNamespace.MapToKUID(r) if !newR.Ok() { - return syserror.EINVAL + return linuxerr.EINVAL } } newE := creds.EffectiveKUID if e.Ok() { newE = creds.UserNamespace.MapToKUID(e) if !newE.Ok() { - return syserror.EINVAL + return linuxerr.EINVAL } } if !creds.HasCapability(linux.CAP_SETUID) { @@ -223,7 +224,7 @@ func (t *Task) setKUIDsUncheckedLocked(newR, newE, newS auth.KUID) { // SetGID implements the semantics of setgid(2). func (t *Task) SetGID(gid auth.GID) error { if !gid.Ok() { - return syserror.EINVAL + return linuxerr.EINVAL } t.mu.Lock() @@ -232,7 +233,7 @@ func (t *Task) SetGID(gid auth.GID) error { creds := t.Credentials() kgid := creds.UserNamespace.MapToKGID(gid) if !kgid.Ok() { - return syserror.EINVAL + return linuxerr.EINVAL } if creds.HasCapability(linux.CAP_SETGID) { t.setKGIDsUncheckedLocked(kgid, kgid, kgid) @@ -255,14 +256,14 @@ func (t *Task) SetREGID(r, e auth.GID) error { if r.Ok() { newR = creds.UserNamespace.MapToKGID(r) if !newR.Ok() { - return syserror.EINVAL + return linuxerr.EINVAL } } newE := creds.EffectiveKGID if e.Ok() { newE = creds.UserNamespace.MapToKGID(e) if !newE.Ok() { - return syserror.EINVAL + return linuxerr.EINVAL } } if !creds.HasCapability(linux.CAP_SETGID) { @@ -349,7 +350,7 @@ func (t *Task) SetExtraGIDs(gids []auth.GID) error { for i, gid := range gids { kgid := creds.UserNamespace.MapToKGID(gid) if !kgid.Ok() { - return syserror.EINVAL + return linuxerr.EINVAL } kgids[i] = kgid } diff --git a/pkg/sentry/kernel/task_image.go b/pkg/sentry/kernel/task_image.go index bd5543d4e..c132c27ef 100644 --- a/pkg/sentry/kernel/task_image.go +++ b/pkg/sentry/kernel/task_image.go @@ -17,7 +17,7 @@ package kernel import ( "fmt" - "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/abi/linux/errno" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -27,7 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/syserr" ) -var errNoSyscalls = syserr.New("no syscall table found", linux.ENOEXEC) +var errNoSyscalls = syserr.New("no syscall table found", errno.ENOEXEC) // Auxmap contains miscellaneous data for the task. type Auxmap map[string]interface{} diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go index 9ba5f8d78..9d9fa76a6 100644 --- a/pkg/sentry/kernel/task_sched.go +++ b/pkg/sentry/kernel/task_sched.go @@ -23,12 +23,12 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/hostcpu" "gvisor.dev/gvisor/pkg/sentry/kernel/sched" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/syserror" ) // TaskGoroutineState is a coarse representation of the current execution @@ -536,7 +536,7 @@ func (tg *ThreadGroup) updateCPUTimersEnabledLocked() { // appropriate for /proc/[pid]/status. 
func (t *Task) StateStatus() string { switch s := t.TaskGoroutineSchedInfo().State; s { - case TaskGoroutineNonexistent: + case TaskGoroutineNonexistent, TaskGoroutineRunningSys: t.tg.pidns.owner.mu.RLock() defer t.tg.pidns.owner.mu.RUnlock() switch t.exitState { @@ -546,16 +546,16 @@ func (t *Task) StateStatus() string { return "X (dead)" default: // The task goroutine can't exit before passing through - // runExitNotify, so this indicates that the task has been created, - // but the task goroutine hasn't yet started. The Linux equivalent - // is struct task_struct::state == TASK_NEW + // runExitNotify, so if s == TaskGoroutineNonexistent, the task has + // been created but the task goroutine hasn't yet started. The + // Linux equivalent is struct task_struct::state == TASK_NEW // (kernel/fork.c:copy_process() => // kernel/sched/core.c:sched_fork()), but the TASK_NEW bit is // masked out by TASK_REPORT for /proc/[pid]/status, leaving only // TASK_RUNNING. return "R (running)" } - case TaskGoroutineRunningSys, TaskGoroutineRunningApp: + case TaskGoroutineRunningApp: return "R (running)" case TaskGoroutineBlockedInterruptible: return "S (sleeping)" @@ -601,7 +601,7 @@ func (t *Task) SetCPUMask(mask sched.CPUSet) error { // Ensure that at least 1 CPU is still allowed. if mask.NumCPUs() == 0 { - return syserror.EINVAL + return linuxerr.EINVAL } if t.k.useHostCores { diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index c2b9fc08f..f54c774cb 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -22,6 +22,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/eventchannel" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -86,7 +87,7 @@ var defaultActions = map[linux.Signal]SignalAction{ } // computeAction figures out what to do given a signal number -// and an arch.SignalAct. SIGSTOP always results in a SignalActionStop, +// and an linux.SigAction. SIGSTOP always results in a SignalActionStop, // and SIGKILL always results in a SignalActionTerm. // Signal 0 is always ignored as many programs use it for various internal functions // and don't expect it to do anything. @@ -97,7 +98,7 @@ var defaultActions = map[linux.Signal]SignalAction{ // 0, the default action is taken; // 1, the signal is ignored; // anything else, the function returns SignalActionHandler. -func computeAction(sig linux.Signal, act arch.SignalAct) SignalAction { +func computeAction(sig linux.Signal, act linux.SigAction) SignalAction { switch sig { case linux.SIGSTOP: return SignalActionStop @@ -108,9 +109,9 @@ func computeAction(sig linux.Signal, act arch.SignalAct) SignalAction { } switch act.Handler { - case arch.SignalActDefault: + case linux.SIG_DFL: return defaultActions[sig] - case arch.SignalActIgnore: + case linux.SIG_IGN: return SignalActionIgnore default: return SignalActionHandler @@ -127,7 +128,7 @@ var StopSignals = linux.MakeSignalSet(linux.SIGSTOP, linux.SIGTSTP, linux.SIGTTI // If there are no pending unmasked signals, dequeueSignalLocked returns nil. // // Preconditions: t.tg.signalHandlers.mu must be locked. 
-func (t *Task) dequeueSignalLocked(mask linux.SignalSet) *arch.SignalInfo { +func (t *Task) dequeueSignalLocked(mask linux.SignalSet) *linux.SignalInfo { if info := t.pendingSignals.dequeue(mask); info != nil { return info } @@ -155,7 +156,7 @@ func (t *Task) PendingSignals() linux.SignalSet { } // deliverSignal delivers the given signal and returns the following run state. -func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunState { +func (t *Task) deliverSignal(info *linux.SignalInfo, act linux.SigAction) taskRunState { sigact := computeAction(linux.Signal(info.Signo), act) if t.haveSyscallReturn { @@ -172,7 +173,7 @@ func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunS fallthrough case sre == syserror.ERESTART_RESTARTBLOCK: fallthrough - case (sre == syserror.ERESTARTSYS && !act.IsRestart()): + case (sre == syserror.ERESTARTSYS && act.Flags&linux.SA_RESTART == 0): t.Debugf("Not restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo) t.Arch().SetReturn(uintptr(-ExtractErrno(syserror.EINTR, -1))) default: @@ -236,7 +237,7 @@ func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunS // deliverSignalToHandler changes the task's userspace state to enter the given // user-configured handler for the given signal. -func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct) error { +func (t *Task) deliverSignalToHandler(info *linux.SignalInfo, act linux.SigAction) error { // Signal delivery to an application handler interrupts restartable // sequences. t.rseqInterrupt() @@ -248,8 +249,8 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct) // N.B. This is a *copy* of the alternate stack that the user's signal // handler expects to see in its ucontext (even if it's not in use). alt := t.signalStack - if act.IsOnStack() && alt.IsEnabled() { - alt.SetOnStack() + if act.Flags&linux.SA_ONSTACK != 0 && alt.IsEnabled() { + alt.Flags |= linux.SS_ONSTACK if !alt.Contains(sp) { sp = hostarch.Addr(alt.Top()) } @@ -289,7 +290,7 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct) // Add our signal mask. newMask := t.signalMask | act.Mask - if !act.IsNoDefer() { + if act.Flags&linux.SA_NODEFER == 0 { newMask |= linux.SignalSetOf(linux.Signal(info.Signo)) } t.SetSignalMask(newMask) @@ -326,7 +327,7 @@ func (t *Task) SignalReturn(rt bool) (*SyscallControl, error) { // Preconditions: // * The caller must be running on the task goroutine. // * t.exitState < TaskExitZombie. -func (t *Task) Sigtimedwait(set linux.SignalSet, timeout time.Duration) (*arch.SignalInfo, error) { +func (t *Task) Sigtimedwait(set linux.SignalSet, timeout time.Duration) (*linux.SignalInfo, error) { // set is the set of signals we're interested in; invert it to get the set // of signals to block. mask := ^(set &^ UnblockableSignals) @@ -370,10 +371,10 @@ func (t *Task) Sigtimedwait(set linux.SignalSet, timeout time.Duration) (*arch.S // The following errors may be returned: // // syserror.ESRCH - The task has exited. -// syserror.EINVAL - The signal is not valid. +// linuxerr.EINVAL - The signal is not valid. // syserror.EAGAIN - THe signal is realtime, and cannot be queued. 
// -func (t *Task) SendSignal(info *arch.SignalInfo) error { +func (t *Task) SendSignal(info *linux.SignalInfo) error { t.tg.pidns.owner.mu.RLock() defer t.tg.pidns.owner.mu.RUnlock() t.tg.signalHandlers.mu.Lock() @@ -382,7 +383,7 @@ func (t *Task) SendSignal(info *arch.SignalInfo) error { } // SendGroupSignal sends the given signal to t's thread group. -func (t *Task) SendGroupSignal(info *arch.SignalInfo) error { +func (t *Task) SendGroupSignal(info *linux.SignalInfo) error { t.tg.pidns.owner.mu.RLock() defer t.tg.pidns.owner.mu.RUnlock() t.tg.signalHandlers.mu.Lock() @@ -392,7 +393,7 @@ func (t *Task) SendGroupSignal(info *arch.SignalInfo) error { // SendSignal sends the given signal to tg, using tg's leader to determine if // the signal is blocked. -func (tg *ThreadGroup) SendSignal(info *arch.SignalInfo) error { +func (tg *ThreadGroup) SendSignal(info *linux.SignalInfo) error { tg.pidns.owner.mu.RLock() defer tg.pidns.owner.mu.RUnlock() tg.signalHandlers.mu.Lock() @@ -400,11 +401,11 @@ func (tg *ThreadGroup) SendSignal(info *arch.SignalInfo) error { return tg.leader.sendSignalLocked(info, true /* group */) } -func (t *Task) sendSignalLocked(info *arch.SignalInfo, group bool) error { +func (t *Task) sendSignalLocked(info *linux.SignalInfo, group bool) error { return t.sendSignalTimerLocked(info, group, nil) } -func (t *Task) sendSignalTimerLocked(info *arch.SignalInfo, group bool, timer *IntervalTimer) error { +func (t *Task) sendSignalTimerLocked(info *linux.SignalInfo, group bool, timer *IntervalTimer) error { if t.exitState == TaskExitDead { return syserror.ESRCH } @@ -413,7 +414,7 @@ func (t *Task) sendSignalTimerLocked(info *arch.SignalInfo, group bool, timer *I return nil } if !sig.IsValid() { - return syserror.EINVAL + return linuxerr.EINVAL } // Signal side effects apply even if the signal is ultimately discarded. @@ -572,9 +573,9 @@ func (t *Task) forceSignal(sig linux.Signal, unconditional bool) { func (t *Task) forceSignalLocked(sig linux.Signal, unconditional bool) { blocked := linux.SignalSetOf(sig)&t.signalMask != 0 act := t.tg.signalHandlers.actions[sig] - ignored := act.Handler == arch.SignalActIgnore + ignored := act.Handler == linux.SIG_IGN if blocked || ignored || unconditional { - act.Handler = arch.SignalActDefault + act.Handler = linux.SIG_DFL t.tg.signalHandlers.actions[sig] = act if blocked { t.setSignalMaskLocked(t.signalMask &^ linux.SignalSetOf(sig)) @@ -641,17 +642,17 @@ func (t *Task) SetSavedSignalMask(mask linux.SignalSet) { } // SignalStack returns the task-private signal stack. -func (t *Task) SignalStack() arch.SignalStack { +func (t *Task) SignalStack() linux.SignalStack { t.p.PullFullState(t.MemoryManager().AddressSpace(), t.Arch()) alt := t.signalStack if t.onSignalStack(alt) { - alt.Flags |= arch.SignalStackFlagOnStack + alt.Flags |= linux.SS_ONSTACK } return alt } // onSignalStack returns true if the task is executing on the given signal stack. -func (t *Task) onSignalStack(alt arch.SignalStack) bool { +func (t *Task) onSignalStack(alt linux.SignalStack) bool { sp := hostarch.Addr(t.Arch().Stack()) return alt.Contains(sp) } @@ -661,30 +662,30 @@ func (t *Task) onSignalStack(alt arch.SignalStack) bool { // This value may not be changed if the task is currently executing on the // signal stack, i.e. if t.onSignalStack returns true. In this case, this // function will return false. Otherwise, true is returned. 
-func (t *Task) SetSignalStack(alt arch.SignalStack) bool { +func (t *Task) SetSignalStack(alt linux.SignalStack) bool { // Check that we're not executing on the stack. if t.onSignalStack(t.signalStack) { return false } - if alt.Flags&arch.SignalStackFlagDisable != 0 { + if alt.Flags&linux.SS_DISABLE != 0 { // Don't record anything beyond the flags. - t.signalStack = arch.SignalStack{ - Flags: arch.SignalStackFlagDisable, + t.signalStack = linux.SignalStack{ + Flags: linux.SS_DISABLE, } } else { // Mask out irrelevant parts: only disable matters. - alt.Flags &= arch.SignalStackFlagDisable + alt.Flags &= linux.SS_DISABLE t.signalStack = alt } return true } -// SetSignalAct atomically sets the thread group's signal action for signal sig +// SetSigAction atomically sets the thread group's signal action for signal sig // to *actptr (if actptr is not nil) and returns the old signal action. -func (tg *ThreadGroup) SetSignalAct(sig linux.Signal, actptr *arch.SignalAct) (arch.SignalAct, error) { +func (tg *ThreadGroup) SetSigAction(sig linux.Signal, actptr *linux.SigAction) (linux.SigAction, error) { if !sig.IsValid() { - return arch.SignalAct{}, syserror.EINVAL + return linux.SigAction{}, linuxerr.EINVAL } tg.pidns.owner.mu.RLock() @@ -695,7 +696,7 @@ func (tg *ThreadGroup) SetSignalAct(sig linux.Signal, actptr *arch.SignalAct) (a oldact := sh.actions[sig] if actptr != nil { if sig == linux.SIGKILL || sig == linux.SIGSTOP { - return oldact, syserror.EINVAL + return oldact, linuxerr.EINVAL } act := *actptr @@ -718,48 +719,6 @@ func (tg *ThreadGroup) SetSignalAct(sig linux.Signal, actptr *arch.SignalAct) (a return oldact, nil } -// CopyOutSignalAct converts the given SignalAct into an architecture-specific -// type and then copies it out to task memory. -func (t *Task) CopyOutSignalAct(addr hostarch.Addr, s *arch.SignalAct) error { - n := t.Arch().NewSignalAct() - n.SerializeFrom(s) - _, err := n.CopyOut(t, addr) - return err -} - -// CopyInSignalAct copies an architecture-specific sigaction type from task -// memory and then converts it into a SignalAct. -func (t *Task) CopyInSignalAct(addr hostarch.Addr) (arch.SignalAct, error) { - n := t.Arch().NewSignalAct() - var s arch.SignalAct - if _, err := n.CopyIn(t, addr); err != nil { - return s, err - } - n.DeserializeTo(&s) - return s, nil -} - -// CopyOutSignalStack converts the given SignalStack into an -// architecture-specific type and then copies it out to task memory. -func (t *Task) CopyOutSignalStack(addr hostarch.Addr, s *arch.SignalStack) error { - n := t.Arch().NewSignalStack() - n.SerializeFrom(s) - _, err := n.CopyOut(t, addr) - return err -} - -// CopyInSignalStack copies an architecture-specific stack_t from task memory -// and then converts it into a SignalStack. -func (t *Task) CopyInSignalStack(addr hostarch.Addr) (arch.SignalStack, error) { - n := t.Arch().NewSignalStack() - var s arch.SignalStack - if _, err := n.CopyIn(t, addr); err != nil { - return s, err - } - n.DeserializeTo(&s) - return s, nil -} - // groupStop is a TaskStop placed on tasks that have received a stop signal // (SIGSTOP, SIGTSTP, SIGTTIN, SIGTTOU). (The term "group-stop" originates from // the ptrace man page.) @@ -774,7 +733,7 @@ func (*groupStop) Killable() bool { return true } // previously-dequeued stop signal. // // Preconditions: The caller must be running on the task goroutine. 
-func (t *Task) initiateGroupStop(info *arch.SignalInfo) { +func (t *Task) initiateGroupStop(info *linux.SignalInfo) { t.tg.pidns.owner.mu.RLock() defer t.tg.pidns.owner.mu.RUnlock() t.tg.signalHandlers.mu.Lock() @@ -909,8 +868,8 @@ func (t *Task) signalStop(target *Task, code int32, status int32) { t.tg.signalHandlers.mu.Lock() defer t.tg.signalHandlers.mu.Unlock() act, ok := t.tg.signalHandlers.actions[linux.SIGCHLD] - if !ok || (act.Handler != arch.SignalActIgnore && act.Flags&arch.SignalFlagNoCldStop == 0) { - sigchld := &arch.SignalInfo{ + if !ok || (act.Handler != linux.SIG_IGN && act.Flags&linux.SA_NOCLDSTOP == 0) { + sigchld := &linux.SignalInfo{ Signo: int32(linux.SIGCHLD), Code: code, } @@ -955,14 +914,14 @@ func (*runInterrupt) execute(t *Task) taskRunState { // notified its tracer accordingly. But it's consistent with // Linux... if intr { - tracer.signalStop(t.tg.leader, arch.CLD_STOPPED, int32(sig)) + tracer.signalStop(t.tg.leader, linux.CLD_STOPPED, int32(sig)) if !notifyParent { tracer.tg.eventQueue.Notify(EventGroupContinue | EventTraceeStop | EventChildGroupStop) } else { tracer.tg.eventQueue.Notify(EventGroupContinue | EventTraceeStop) } } else { - tracer.signalStop(t.tg.leader, arch.CLD_CONTINUED, int32(sig)) + tracer.signalStop(t.tg.leader, linux.CLD_CONTINUED, int32(sig)) tracer.tg.eventQueue.Notify(EventGroupContinue) } } @@ -974,10 +933,10 @@ func (*runInterrupt) execute(t *Task) taskRunState { // SIGCHLD is a standard signal, so the latter would always be // dropped. Hence sending only the former is equivalent. if intr { - t.tg.leader.parent.signalStop(t.tg.leader, arch.CLD_STOPPED, int32(sig)) + t.tg.leader.parent.signalStop(t.tg.leader, linux.CLD_STOPPED, int32(sig)) t.tg.leader.parent.tg.eventQueue.Notify(EventGroupContinue | EventChildGroupStop) } else { - t.tg.leader.parent.signalStop(t.tg.leader, arch.CLD_CONTINUED, int32(sig)) + t.tg.leader.parent.signalStop(t.tg.leader, linux.CLD_CONTINUED, int32(sig)) t.tg.leader.parent.tg.eventQueue.Notify(EventGroupContinue) } } @@ -1018,7 +977,7 @@ func (*runInterrupt) execute(t *Task) taskRunState { // without requiring an extra PTRACE_GETSIGINFO call." - // "Group-stop", ptrace(2) t.ptraceCode = int32(sig) | linux.PTRACE_EVENT_STOP<<8 - t.ptraceSiginfo = &arch.SignalInfo{ + t.ptraceSiginfo = &linux.SignalInfo{ Signo: int32(sig), Code: t.ptraceCode, } @@ -1029,7 +988,7 @@ func (*runInterrupt) execute(t *Task) taskRunState { t.ptraceSiginfo = nil } if t.beginPtraceStopLocked() { - tracer.signalStop(t, arch.CLD_STOPPED, int32(sig)) + tracer.signalStop(t, linux.CLD_STOPPED, int32(sig)) // For consistency with Linux, if the parent and tracer are in the // same thread group, deduplicate notification signals. if notifyParent && tracer.tg == t.tg.leader.parent.tg { @@ -1047,7 +1006,7 @@ func (*runInterrupt) execute(t *Task) taskRunState { t.tg.signalHandlers.mu.Unlock() } if notifyParent { - t.tg.leader.parent.signalStop(t.tg.leader, arch.CLD_STOPPED, int32(sig)) + t.tg.leader.parent.signalStop(t.tg.leader, linux.CLD_STOPPED, int32(sig)) t.tg.leader.parent.tg.eventQueue.Notify(EventChildGroupStop) } t.tg.pidns.owner.mu.RUnlock() @@ -1101,7 +1060,7 @@ func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState { if sig != linux.Signal(info.Signo) { info.Signo = int32(sig) info.Errno = 0 - info.Code = arch.SignalInfoUser + info.Code = linux.SI_USER // pid isn't a valid field for all signal numbers, but Linux // doesn't care (kernel/signal.c:ptrace_signal()). 
// diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index 32031cd70..41fd2d471 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -18,7 +18,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" @@ -131,7 +130,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { runState: (*runApp)(nil), interruptChan: make(chan struct{}, 1), signalMask: cfg.SignalMask, - signalStack: arch.SignalStack{Flags: arch.SignalStackFlagDisable}, + signalStack: linux.SignalStack{Flags: linux.SS_DISABLE}, image: *image, fsContext: cfg.FSContext, fdTable: cfg.FDTable, diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go index 601fc0d3a..409b712d8 100644 --- a/pkg/sentry/kernel/task_syscall.go +++ b/pkg/sentry/kernel/task_syscall.go @@ -22,6 +22,8 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" + "gvisor.dev/gvisor/pkg/errors" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/metric" @@ -357,7 +359,7 @@ func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, calle t.Arch().SetReturn(uintptr(rval)) } else { t.Debugf("vsyscall %d, caller %x: emulated syscall returned error: %v", sysno, t.Arch().Value(caller), err) - if err == syserror.EFAULT { + if linuxerr.Equals(linuxerr.EFAULT, err) { t.forceSignal(linux.SIGSEGV, false /* unconditional */) t.SendSignal(SignalInfoPriv(linux.SIGSEGV)) // A return is not emulated in this case. @@ -379,6 +381,8 @@ func ExtractErrno(err error, sysno int) int { return 0 case unix.Errno: return int(err) + case *errors.Error: + return int(err.Errno()) case syserror.SyscallRestartErrno: return int(err) case *memmap.BusError: diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go index fc6d9438a..7935d15a6 100644 --- a/pkg/sentry/kernel/task_usermem.go +++ b/pkg/sentry/kernel/task_usermem.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/syserror" @@ -202,7 +203,7 @@ func (t *Task) CopyInIovecs(addr hostarch.Addr, numIovecs int) (hostarch.AddrRan base := hostarch.Addr(hostarch.ByteOrder.Uint64(b[0:8])) length := hostarch.ByteOrder.Uint64(b[8:16]) if length > math.MaxInt64 { - return hostarch.AddrRangeSeq{}, syserror.EINVAL + return hostarch.AddrRangeSeq{}, linuxerr.EINVAL } ar, ok := t.MemoryManager().CheckIORange(base, int64(length)) if !ok { @@ -270,7 +271,7 @@ func (t *Task) SingleIOSequence(addr hostarch.Addr, length int, opts usermem.IOO // Preconditions: Same as Task.CopyInIovecs. 
func (t *Task) IovecsIOSequence(addr hostarch.Addr, iovcnt int, opts usermem.IOOpts) (usermem.IOSequence, error) { if iovcnt < 0 || iovcnt > linux.UIO_MAXIOV { - return usermem.IOSequence{}, syserror.EINVAL + return usermem.IOSequence{}, linuxerr.EINVAL } ars, err := t.CopyInIovecs(addr, iovcnt) if err != nil { diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index b92e98fa1..8ae00c649 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -19,7 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -279,7 +279,7 @@ func (k *Kernel) NewThreadGroup(mntns *fs.MountNamespace, pidns *PIDNamespace, s limits: limits, mounts: mntns, } - tg.itimerRealTimer = ktime.NewTimer(k.monotonicClock, &itimerRealListener{tg: tg}) + tg.itimerRealTimer = ktime.NewTimer(k.timekeeper.monotonicClock, &itimerRealListener{tg: tg}) tg.timers = make(map[linux.TimerID]*IntervalTimer) tg.oldRSeqCritical.Store(&OldRSeqCriticalRegion{}) return tg @@ -358,7 +358,7 @@ func (tg *ThreadGroup) SetControllingTTY(tty *TTY, steal bool, isReadable bool) // "The calling process must be a session leader and not have a // controlling terminal already." - tty_ioctl(4) if tg.processGroup.session.leader != tg || tg.tty != nil { - return syserror.EINVAL + return linuxerr.EINVAL } creds := auth.CredentialsFromContext(tg.leader) @@ -446,10 +446,10 @@ func (tg *ThreadGroup) ReleaseControllingTTY(tty *TTY) error { othertg.signalHandlers.mu.Lock() othertg.tty = nil if othertg.processGroup == tg.processGroup.session.foreground { - if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGHUP)}, true /* group */); err != nil { + if err := othertg.leader.sendSignalLocked(&linux.SignalInfo{Signo: int32(linux.SIGHUP)}, true /* group */); err != nil { lastErr = err } - if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGCONT)}, true /* group */); err != nil { + if err := othertg.leader.sendSignalLocked(&linux.SignalInfo{Signo: int32(linux.SIGCONT)}, true /* group */); err != nil { lastErr = err } } @@ -490,10 +490,10 @@ func (tg *ThreadGroup) SetForegroundProcessGroup(tty *TTY, pgid ProcessGroupID) tg.signalHandlers.mu.Lock() defer tg.signalHandlers.mu.Unlock() - // TODO(b/129283598): "If tcsetpgrp() is called by a member of a - // background process group in its session, and the calling process is - // not blocking or ignoring SIGTTOU, a SIGTTOU signal is sent to all - // members of this background process group." + // TODO(gvisor.dev/issue/6148): "If tcsetpgrp() is called by a member of a + // background process group in its session, and the calling process is not + // blocking or ignoring SIGTTOU, a SIGTTOU signal is sent to all members of + // this background process group." // tty must be the controlling terminal. if tg.tty != tty { @@ -502,7 +502,7 @@ func (tg *ThreadGroup) SetForegroundProcessGroup(tty *TTY, pgid ProcessGroupID) // pgid must be positive. if pgid < 0 { - return -1, syserror.EINVAL + return -1, linuxerr.EINVAL } // pg must not be empty. 
Empty process groups are removed from their diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD index 2817aa3ba..e293d9a0f 100644 --- a/pkg/sentry/kernel/time/BUILD +++ b/pkg/sentry/kernel/time/BUILD @@ -13,8 +13,8 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/sync", - "//pkg/syserror", "//pkg/waiter", ], ) diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go index f61a8e164..191b92811 100644 --- a/pkg/sentry/kernel/time/time.go +++ b/pkg/sentry/kernel/time/time.go @@ -22,8 +22,8 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -322,7 +322,7 @@ func SettingFromSpec(value time.Duration, interval time.Duration, c Clock) (Sett // interpreted as a time relative to now. func SettingFromSpecAt(value time.Duration, interval time.Duration, now Time) (Setting, error) { if value < 0 { - return Setting{}, syserror.EINVAL + return Setting{}, linuxerr.EINVAL } if value == 0 { return Setting{Period: interval}, nil @@ -338,7 +338,7 @@ func SettingFromSpecAt(value time.Duration, interval time.Duration, now Time) (S // interpreted as an absolute time. func SettingFromAbsSpec(value Time, interval time.Duration) (Setting, error) { if value.Before(ZeroTime) { - return Setting{}, syserror.EINVAL + return Setting{}, linuxerr.EINVAL } if value.IsZero() { return Setting{Period: interval}, nil @@ -458,25 +458,6 @@ func NewTimer(clock Clock, listener TimerListener) *Timer { return t } -// After waits for the duration to elapse according to clock and then sends a -// notification on the returned channel. The timer is started immediately and -// will fire exactly once. The second return value is the start time used with -// the duration. -// -// Callers must call Timer.Destroy. -func After(clock Clock, duration time.Duration) (*Timer, Time, <-chan struct{}) { - notifier, tchan := NewChannelNotifier() - t := NewTimer(clock, notifier) - now := clock.Now() - - t.Swap(Setting{ - Enabled: true, - Period: 0, - Next: now.Add(duration), - }) - return t, now, tchan -} - // init initializes Timer state that is not preserved across save/restore. If // init has already been called, calling it again is a no-op. // diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go index 7c4fefb16..6255bae7a 100644 --- a/pkg/sentry/kernel/timekeeper.go +++ b/pkg/sentry/kernel/timekeeper.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" sentrytime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" ) // Timekeeper manages all of the kernel clocks. @@ -39,6 +40,12 @@ type Timekeeper struct { // It is set only once, by SetClocks. clocks sentrytime.Clocks `state:"nosave"` + // realtimeClock is a ktime.Clock based on timekeeper's Realtime. + realtimeClock *timekeeperClock + + // monotonicClock is a ktime.Clock based on timekeeper's Monotonic. + monotonicClock *timekeeperClock + // bootTime is the realtime when the system "booted". i.e., when // SetClocks was called in the initial (not restored) run. bootTime ktime.Time @@ -90,10 +97,13 @@ type Timekeeper struct { // NewTimekeeper does not take ownership of paramPage. // // SetClocks must be called on the returned Timekeeper before it is usable. 
-func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage memmap.FileRange) (*Timekeeper, error) { - return &Timekeeper{ +func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage memmap.FileRange) *Timekeeper { + t := Timekeeper{ params: NewVDSOParamPage(mfp, paramPage), - }, nil + } + t.realtimeClock = &timekeeperClock{tk: &t, c: sentrytime.Realtime} + t.monotonicClock = &timekeeperClock{tk: &t, c: sentrytime.Monotonic} + return &t } // SetClocks the backing clock source. @@ -167,6 +177,32 @@ func (t *Timekeeper) SetClocks(c sentrytime.Clocks) { } } +var _ tcpip.Clock = (*Timekeeper)(nil) + +// Now implements tcpip.Clock. +func (t *Timekeeper) Now() time.Time { + nsec, err := t.GetTime(sentrytime.Realtime) + if err != nil { + panic("timekeeper.GetTime(sentrytime.Realtime): " + err.Error()) + } + return time.Unix(0, nsec) +} + +// NowMonotonic implements tcpip.Clock. +func (t *Timekeeper) NowMonotonic() tcpip.MonotonicTime { + nsec, err := t.GetTime(sentrytime.Monotonic) + if err != nil { + panic("timekeeper.GetTime(sentrytime.Monotonic): " + err.Error()) + } + var mt tcpip.MonotonicTime + return mt.Add(time.Duration(nsec) * time.Nanosecond) +} + +// AfterFunc implements tcpip.Clock. +func (t *Timekeeper) AfterFunc(d time.Duration, f func()) tcpip.Timer { + return ktime.TcpipAfterFunc(t.realtimeClock, d, f) +} + // startUpdater starts an update goroutine that keeps the clocks updated. // // mu must be held. diff --git a/pkg/sentry/kernel/timekeeper_test.go b/pkg/sentry/kernel/timekeeper_test.go index dfc3c0719..b6039505a 100644 --- a/pkg/sentry/kernel/timekeeper_test.go +++ b/pkg/sentry/kernel/timekeeper_test.go @@ -17,12 +17,12 @@ package kernel import ( "testing" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/pgalloc" sentrytime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/syserror" ) // mockClocks is a sentrytime.Clocks that simply returns the times in the @@ -45,7 +45,7 @@ func (c *mockClocks) GetTime(id sentrytime.ClockID) (int64, error) { case sentrytime.Realtime: return c.realtime, nil default: - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } } diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD index 4c65215fa..54bfed644 100644 --- a/pkg/sentry/loader/BUILD +++ b/pkg/sentry/loader/BUILD @@ -17,8 +17,10 @@ go_library( deps = [ "//pkg/abi", "//pkg/abi/linux", + "//pkg/abi/linux/errno", "//pkg/context", "//pkg/cpuid", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/log", "//pkg/rand", diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index 8fc3e2a79..13ab7ea23 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -517,13 +518,13 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, in start, ok = start.AddLength(uint64(offset)) if !ok { ctx.Infof(fmt.Sprintf("Start %#x + offset %#x overflows?", start, offset)) - return loadedELF{}, syserror.EINVAL + return loadedELF{}, linuxerr.EINVAL } end, ok = end.AddLength(uint64(offset)) if !ok { ctx.Infof(fmt.Sprintf("End %#x + offset %#x overflows?", end, offset)) - return loadedELF{}, syserror.EINVAL + return loadedELF{}, linuxerr.EINVAL } 
info.entry, ok = info.entry.AddLength(uint64(offset)) @@ -621,7 +622,7 @@ func loadInitialELF(ctx context.Context, m *mm.MemoryManager, fs *cpuid.FeatureS func loadInterpreterELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, initial loadedELF) (loadedELF, error) { info, err := parseHeader(ctx, f) if err != nil { - if err == syserror.ENOEXEC { + if linuxerr.Equals(linuxerr.ENOEXEC, err) { // Bad interpreter. err = syserror.ELIBBAD } diff --git a/pkg/sentry/loader/interpreter.go b/pkg/sentry/loader/interpreter.go index 3886b4d33..3e302d92c 100644 --- a/pkg/sentry/loader/interpreter.go +++ b/pkg/sentry/loader/interpreter.go @@ -59,7 +59,7 @@ func parseInterpreterScript(ctx context.Context, filename string, f fsbridge.Fil // Linux silently truncates the remainder of the line if it exceeds // interpMaxLineLength. i := bytes.IndexByte(line, '\n') - if i > 0 { + if i >= 0 { line = line[:i] } diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go index 47e3775a3..8240173ae 100644 --- a/pkg/sentry/loader/loader.go +++ b/pkg/sentry/loader/loader.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/abi/linux/errno" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/hostarch" @@ -237,7 +238,7 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V // loaded.end is available for its use. e, ok := loaded.end.RoundUp() if !ok { - return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("brk overflows: %#x", loaded.end), linux.ENOEXEC) + return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("brk overflows: %#x", loaded.end), errno.ENOEXEC) } args.MemoryManager.BrkSetup(ctx, e) diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go index fd54261fd..054ef1723 100644 --- a/pkg/sentry/loader/vdso.go +++ b/pkg/sentry/loader/vdso.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" @@ -58,7 +59,7 @@ type byteFullReader struct { func (b *byteFullReader) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) { if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if offset >= int64(len(b.data)) { return 0, io.EOF diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index b417c2da7..69aff21b6 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -125,6 +125,7 @@ go_library( "//pkg/abi/linux", "//pkg/atomicbitops", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/log", "//pkg/refs", @@ -156,6 +157,7 @@ go_test( library = ":mm", deps = [ "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/sentry/arch", "//pkg/sentry/contexttest", @@ -163,7 +165,6 @@ go_test( "//pkg/sentry/memmap", "//pkg/sentry/pgalloc", "//pkg/sentry/platform", - "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index 346866d3c..8426fc90e 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -17,6 +17,7 @@ package mm import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" @@ -158,7 +159,7 @@ func (ctx *AIOContext) Prepare() error { defer ctx.mu.Unlock() if ctx.dead { // Context 
died after the caller looked it up. - return syserror.EINVAL + return linuxerr.EINVAL } if ctx.outstanding >= ctx.maxOutstanding { // Context is busy. @@ -297,7 +298,7 @@ func (m *aioMappable) InodeID() uint64 { // Msync implements memmap.MappingIdentity.Msync. func (m *aioMappable) Msync(ctx context.Context, mr memmap.MappableRange) error { // Linux: aio_ring_fops.fsync == NULL - return syserror.EINVAL + return linuxerr.EINVAL } // AddMapping implements memmap.Mappable.AddMapping. @@ -325,7 +326,7 @@ func (m *aioMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, s // Linux's fs/aio.c:aio_ring_mremap(). mm, ok := ms.(*MemoryManager) if !ok { - return syserror.EINVAL + return linuxerr.EINVAL } am := &mm.aioManager am.mu.Lock() @@ -333,12 +334,12 @@ func (m *aioMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, s oldID := uint64(srcAR.Start) aioCtx, ok := am.contexts[oldID] if !ok { - return syserror.EINVAL + return linuxerr.EINVAL } aioCtx.mu.Lock() defer aioCtx.mu.Unlock() if aioCtx.dead { - return syserror.EINVAL + return linuxerr.EINVAL } // Use the new ID for the AIOContext. am.contexts[uint64(dstAR.Start)] = aioCtx @@ -399,7 +400,7 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint id := uint64(addr) if !mm.aioManager.newAIOContext(events, id) { mm.MUnmap(ctx, addr, aioRingBufferSize) - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } return id, nil } diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go index 1304b0a2f..84cb8158d 100644 --- a/pkg/sentry/mm/mm_test.go +++ b/pkg/sentry/mm/mm_test.go @@ -18,6 +18,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/contexttest" @@ -25,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -171,7 +171,7 @@ func TestIOAfterUnmap(t *testing.T) { } n, err = mm.CopyIn(ctx, addr, b, usermem.IOOpts{}) - if err != syserror.EFAULT { + if !linuxerr.Equals(linuxerr.EFAULT, err) { t.Errorf("CopyIn got err %v want EFAULT", err) } if n != 0 { @@ -212,7 +212,7 @@ func TestIOAfterMProtect(t *testing.T) { // Without IgnorePermissions, CopyOut should no longer succeed. n, err = mm.CopyOut(ctx, addr, b, usermem.IOOpts{}) - if err != syserror.EFAULT { + if !linuxerr.Equals(linuxerr.EFAULT, err) { t.Errorf("CopyOut got err %v want EFAULT", err) } if n != 0 { @@ -249,7 +249,7 @@ func TestAIOPrepareAfterDestroy(t *testing.T) { mm.DestroyAIOContext(ctx, id) // Prepare should fail because aioCtx should be destroyed. - if err := aioCtx.Prepare(); err != syserror.EINVAL { + if err := aioCtx.Prepare(); !linuxerr.Equals(linuxerr.EINVAL, err) { t.Errorf("aioCtx.Prepare got err %v want nil", err) } else if err == nil { aioCtx.CancelPendingRequest() diff --git a/pkg/sentry/mm/shm.go b/pkg/sentry/mm/shm.go index 3130be80c..94d5112a1 100644 --- a/pkg/sentry/mm/shm.go +++ b/pkg/sentry/mm/shm.go @@ -16,16 +16,16 @@ package mm import ( "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/kernel/shm" - "gvisor.dev/gvisor/pkg/syserror" ) // DetachShm unmaps a sysv shared memory segment. func (mm *MemoryManager) DetachShm(ctx context.Context, addr hostarch.Addr) error { if addr != addr.RoundDown() { // "... 
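The test hunks above replace identity comparisons (`err == syserror.EFAULT`) with a helper like linuxerr.Equals, because errors are now typed values that may be wrapped or re-created. A rough standalone sketch of why code-based comparison is more robust than `==` — the errno type and equals helper here are illustrative, not the actual linuxerr implementation:

package main

import (
	"errors"
	"fmt"
)

// errno is a stand-in for a typed Linux error value.
type errno struct {
	code int
	name string
}

func (e *errno) Error() string { return e.name }

var (
	errEINVAL = &errno{22, "EINVAL"}
	errEFAULT = &errno{14, "EFAULT"}
)

// equals compares by errno code rather than pointer identity, so a wrapped
// or separately constructed error with the same code still matches.
func equals(want *errno, got error) bool {
	var e *errno
	if errors.As(got, &e) {
		return e.code == want.code
	}
	return false
}

func main() {
	wrapped := fmt.Errorf("copy failed: %w", errEFAULT)
	fmt.Println(wrapped == error(errEFAULT))  // false: identity comparison misses the wrap
	fmt.Println(equals(errEFAULT, wrapped))   // true
	fmt.Println(equals(errEINVAL, wrapped))   // false
}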
shmaddr is not aligned on a page boundary." - man shmdt(2) - return syserror.EINVAL + return linuxerr.EINVAL } var detached *shm.Shm @@ -48,7 +48,7 @@ func (mm *MemoryManager) DetachShm(ctx context.Context, addr hostarch.Addr) erro if detached == nil { // There is no shared memory segment attached at addr. - return syserror.EINVAL + return linuxerr.EINVAL } // Remove all vmas that could have been created by the same attach. diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go index e748b7ff8..feafe76c1 100644 --- a/pkg/sentry/mm/special_mappable.go +++ b/pkg/sentry/mm/special_mappable.go @@ -16,6 +16,7 @@ package mm import ( "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" @@ -144,11 +145,11 @@ func (m *SpecialMappable) Length() uint64 { // leak (b/143656263). Delete this function along with VFS1. func NewSharedAnonMappable(length uint64, mfp pgalloc.MemoryFileProvider) (*SpecialMappable, error) { if length == 0 { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } alignedLen, ok := hostarch.Addr(length).RoundUp() if !ok { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } fr, err := mfp.MemoryFile().Allocate(uint64(alignedLen), usage.Anonymous) if err != nil { diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go index 7ad6b7c21..7b6715815 100644 --- a/pkg/sentry/mm/syscalls.go +++ b/pkg/sentry/mm/syscalls.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" @@ -74,7 +75,7 @@ func (mm *MemoryManager) HandleUserFault(ctx context.Context, addr hostarch.Addr // MMap establishes a memory mapping. func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (hostarch.Addr, error) { if opts.Length == 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } length, ok := hostarch.Addr(opts.Length).RoundUp() if !ok { @@ -85,7 +86,7 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (hostar if opts.Mappable != nil { // Offset must be aligned. if hostarch.Addr(opts.Offset).RoundDown() != hostarch.Addr(opts.Offset) { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Offset + length must not overflow. if end := opts.Offset + opts.Length; end < opts.Offset { @@ -99,7 +100,7 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (hostar // MAP_FIXED requires addr to be page-aligned; non-fixed mappings // don't. if opts.Fixed { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } opts.Addr = opts.Addr.RoundDown() } @@ -108,10 +109,10 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (hostar return 0, syserror.EACCES } if opts.Unmap && !opts.Fixed { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if opts.GrowsDown && opts.Mappable != nil { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Get the new vma. @@ -281,18 +282,18 @@ func (mm *MemoryManager) MapStack(ctx context.Context) (hostarch.AddrRange, erro // MUnmap implements the semantics of Linux's munmap(2). 
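DetachShm and the mmap-path checks above all reject unaligned addresses with EINVAL before doing any work, per shmdt(2) and mmap(2). A tiny sketch of that precondition, with a hand-rolled roundDown standing in for hostarch.Addr.RoundDown:

package main

import "fmt"

const pageSize = 4096

// roundDown masks off the low bits, mirroring hostarch.Addr.RoundDown.
func roundDown(addr uintptr) uintptr { return addr &^ (pageSize - 1) }

// checkAligned is the shape of the DetachShm / MUnmap precondition: an
// unaligned address is rejected with EINVAL before any unmapping happens.
func checkAligned(addr uintptr) error {
	if addr != roundDown(addr) {
		return fmt.Errorf("EINVAL: %#x is not page aligned", addr)
	}
	return nil
}

func main() {
	fmt.Println(checkAligned(0x200000))      // nil
	fmt.Println(checkAligned(0x200000 + 12)) // EINVAL
}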
func (mm *MemoryManager) MUnmap(ctx context.Context, addr hostarch.Addr, length uint64) error { if addr != addr.RoundDown() { - return syserror.EINVAL + return linuxerr.EINVAL } if length == 0 { - return syserror.EINVAL + return linuxerr.EINVAL } la, ok := hostarch.Addr(length).RoundUp() if !ok { - return syserror.EINVAL + return linuxerr.EINVAL } ar, ok := addr.ToRange(uint64(la)) if !ok { - return syserror.EINVAL + return linuxerr.EINVAL } mm.mappingMu.Lock() @@ -331,7 +332,7 @@ const ( func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr hostarch.Addr, oldSize uint64, newSize uint64, opts MRemapOpts) (hostarch.Addr, error) { // "Note that old_address has to be page aligned." - mremap(2) if oldAddr.RoundDown() != oldAddr { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Linux treats an old_size that rounds up to 0 as 0, which is otherwise a @@ -340,13 +341,13 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr hostarch.Addr, oldS oldSize = uint64(oldSizeAddr) newSizeAddr, ok := hostarch.Addr(newSize).RoundUp() if !ok || newSizeAddr == 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } newSize = uint64(newSizeAddr) oldEnd, ok := oldAddr.AddLength(oldSize) if !ok { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } mm.mappingMu.Lock() @@ -450,15 +451,15 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr hostarch.Addr, oldS case MRemapMustMove: newAddr := opts.NewAddr if newAddr.RoundDown() != newAddr { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } var ok bool newAR, ok = newAddr.ToRange(newSize) if !ok { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if (hostarch.AddrRange{oldAddr, oldEnd}).Overlaps(newAR) { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Check that the new region is valid. @@ -504,7 +505,7 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr hostarch.Addr, oldS if vma := vseg.ValuePtr(); vma.mappable != nil { // Check that offset+length does not overflow. if vma.off+uint64(newAR.Length()) < vma.off { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Inform the Mappable, if any, of the new mapping. if err := vma.mappable.CopyMapping(ctx, mm, oldAR, newAR, vseg.mappableOffsetAt(oldAR.Start), vma.canWriteMappableLocked()); err != nil { @@ -590,7 +591,7 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr hostarch.Addr, oldS // MProtect implements the semantics of Linux's mprotect(2). func (mm *MemoryManager) MProtect(addr hostarch.Addr, length uint64, realPerms hostarch.AccessType, growsDown bool) error { if addr.RoundDown() != addr { - return syserror.EINVAL + return linuxerr.EINVAL } if length == 0 { return nil @@ -618,7 +619,7 @@ func (mm *MemoryManager) MProtect(addr hostarch.Addr, length uint64, realPerms h } if growsDown { if !vseg.ValuePtr().growsDown { - return syserror.EINVAL + return linuxerr.EINVAL } if ar.End <= vseg.Start() { return syserror.ENOMEM @@ -711,7 +712,7 @@ func (mm *MemoryManager) Brk(ctx context.Context, addr hostarch.Addr) (hostarch. 
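The MUnmap/MRemap hunks above convert every failed RoundUp, AddLength, or ToRange into EINVAL; all three are overflow-aware operations. A simplified sketch of that overflow-safe arithmetic, under the assumption of a 4 KiB page size and without gVisor's hostarch types:

package main

import (
	"errors"
	"fmt"
	"math"
)

const pageSize = 4096

// roundUp rounds a length up to a page boundary, reporting overflow the way
// hostarch.Addr.RoundUp does via its ok result.
func roundUp(length uint64) (uint64, bool) {
	sum := length + pageSize - 1
	if sum < length { // wrapped around
		return 0, false
	}
	return sum &^ (pageSize - 1), true
}

// toRange builds [addr, addr+length) and fails on overflow, mirroring the
// addr.ToRange checks that the hunks turn into EINVAL.
func toRange(addr, length uint64) (start, end uint64, err error) {
	if end = addr + length; end < addr {
		return 0, 0, errors.New("EINVAL: range overflows the address space")
	}
	return addr, end, nil
}

func main() {
	if la, ok := roundUp(10); ok {
		fmt.Println(toRange(0x7f0000000000, la))
	}
	_, ok := roundUp(math.MaxUint64)
	fmt.Println("overflow detected:", !ok)
}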
if addr < mm.brk.Start { addr = mm.brk.End mm.mappingMu.Unlock() - return addr, syserror.EINVAL + return addr, linuxerr.EINVAL } // TODO(gvisor.dev/issue/156): This enforces RLIMIT_DATA, but is @@ -780,7 +781,7 @@ func (mm *MemoryManager) MLock(ctx context.Context, addr hostarch.Addr, length u la, _ := hostarch.Addr(length + addr.PageOffset()).RoundUp() ar, ok := addr.RoundDown().ToRange(uint64(la)) if !ok { - return syserror.EINVAL + return linuxerr.EINVAL } mm.mappingMu.Lock() @@ -855,10 +856,10 @@ func (mm *MemoryManager) MLock(ctx context.Context, addr hostarch.Addr, length u mm.activeMu.Unlock() mm.mappingMu.RUnlock() // Linux: mm/mlock.c:__mlock_posix_error_return() - if err == syserror.EFAULT { + if linuxerr.Equals(linuxerr.EFAULT, err) { return syserror.ENOMEM } - if err == syserror.ENOMEM { + if linuxerr.Equals(linuxerr.ENOMEM, err) { return syserror.EAGAIN } return err @@ -898,7 +899,7 @@ type MLockAllOpts struct { // depending on opts. func (mm *MemoryManager) MLockAll(ctx context.Context, opts MLockAllOpts) error { if !opts.Current && !opts.Future { - return syserror.EINVAL + return linuxerr.EINVAL } mm.mappingMu.Lock() @@ -979,13 +980,13 @@ func (mm *MemoryManager) NumaPolicy(addr hostarch.Addr) (linux.NumaPolicy, uint6 // SetNumaPolicy implements the semantics of Linux's mbind(). func (mm *MemoryManager) SetNumaPolicy(addr hostarch.Addr, length uint64, policy linux.NumaPolicy, nodemask uint64) error { if !addr.IsPageAligned() { - return syserror.EINVAL + return linuxerr.EINVAL } // Linux allows this to overflow. la, _ := hostarch.Addr(length).RoundUp() ar, ok := addr.ToRange(uint64(la)) if !ok { - return syserror.EINVAL + return linuxerr.EINVAL } if ar.Length() == 0 { return nil @@ -1021,7 +1022,7 @@ func (mm *MemoryManager) SetNumaPolicy(addr hostarch.Addr, length uint64, policy func (mm *MemoryManager) SetDontFork(addr hostarch.Addr, length uint64, dontfork bool) error { ar, ok := addr.ToRange(length) if !ok { - return syserror.EINVAL + return linuxerr.EINVAL } mm.mappingMu.Lock() @@ -1047,7 +1048,7 @@ func (mm *MemoryManager) SetDontFork(addr hostarch.Addr, length uint64, dontfork func (mm *MemoryManager) Decommit(addr hostarch.Addr, length uint64) error { ar, ok := addr.ToRange(length) if !ok { - return syserror.EINVAL + return linuxerr.EINVAL } mm.mappingMu.RLock() @@ -1063,7 +1064,7 @@ func (mm *MemoryManager) Decommit(addr hostarch.Addr, length uint64) error { for vseg := mm.vmas.LowerBoundSegment(ar.Start); vseg.Ok() && vseg.Start() < ar.End; vseg = vseg.NextSegment() { vma := vseg.ValuePtr() if vma.mlockMode != memmap.MLockNone { - return syserror.EINVAL + return linuxerr.EINVAL } vsegAR := vseg.Range().Intersect(ar) // pseg should already correspond to either this vma or a later one, @@ -1114,7 +1115,7 @@ type MSyncOpts struct { // MSync implements the semantics of Linux's msync(). func (mm *MemoryManager) MSync(ctx context.Context, addr hostarch.Addr, length uint64, opts MSyncOpts) error { if addr != addr.RoundDown() { - return syserror.EINVAL + return linuxerr.EINVAL } if length == 0 { return nil diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index b81292c46..d1a883da4 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -1062,10 +1062,20 @@ func (f *MemoryFile) runReclaim() { break } - // If ManualZeroing is in effect, pages will be zeroed on allocation - // and may not be freed by decommitFile, so calling decommitFile is - // unnecessary. 
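The MLock hunk above keeps the same error translation as Linux's mm/mlock.c:__mlock_posix_error_return(), only expressed through linuxerr.Equals: a fault while populating pages surfaces as ENOMEM, and memory pressure as EAGAIN. A compact sketch of that mapping with stand-in error values:

package main

import (
	"errors"
	"fmt"
)

var (
	errEFAULT = errors.New("EFAULT")
	errENOMEM = errors.New("ENOMEM")
	errEAGAIN = errors.New("EAGAIN")
)

// mlockError reproduces the translation in the MLock hunk: EFAULT becomes
// ENOMEM and ENOMEM becomes EAGAIN; anything else passes through.
func mlockError(err error) error {
	switch {
	case errors.Is(err, errEFAULT):
		return errENOMEM
	case errors.Is(err, errENOMEM):
		return errEAGAIN
	default:
		return err
	}
}

func main() {
	fmt.Println(mlockError(errEFAULT)) // ENOMEM
	fmt.Println(mlockError(errENOMEM)) // EAGAIN
}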
- if !f.opts.ManualZeroing { + if f.opts.ManualZeroing { + // If ManualZeroing is in effect, only hugepage-aligned regions may + // be safely passed to decommitFile. Pages will be zeroed on + // reallocation, so we don't need to perform any manual zeroing + // here, whether or not decommitFile succeeds. + if startAddr, ok := hostarch.Addr(fr.Start).HugeRoundUp(); ok { + if endAddr := hostarch.Addr(fr.End).HugeRoundDown(); startAddr < endAddr { + decommitFR := memmap.FileRange{uint64(startAddr), uint64(endAddr)} + if err := f.decommitFile(decommitFR); err != nil { + log.Warningf("Reclaim failed to decommit %v: %v", decommitFR, err) + } + } + } + } else { if err := f.decommitFile(fr); err != nil { log.Warningf("Reclaim failed to decommit %v: %v", fr, err) // Zero the pages manually. This won't reduce memory usage, but at diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index b307832fd..8a490b3de 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -6,6 +6,8 @@ go_library( name = "kvm", srcs = [ "address_space.go", + "address_space_amd64.go", + "address_space_arm64.go", "bluepill.go", "bluepill_allocator.go", "bluepill_amd64.go", @@ -77,6 +79,7 @@ go_test( "requires-kvm", ], deps = [ + "//pkg/abi/linux", "//pkg/hostarch", "//pkg/ring0", "//pkg/ring0/pagetables", diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go index 5524e8727..9929caebb 100644 --- a/pkg/sentry/platform/kvm/address_space.go +++ b/pkg/sentry/platform/kvm/address_space.go @@ -85,15 +85,6 @@ type addressSpace struct { dirtySet *dirtySet } -// invalidate is the implementation for Invalidate. -func (as *addressSpace) invalidate() { - as.dirtySet.forEach(as.machine, func(c *vCPU) { - if c.active.get() == as { // If this happens to be active, - c.BounceToKernel() // ... force a kernel transition. - } - }) -} - // Invalidate interrupts all dirty contexts. func (as *addressSpace) Invalidate() { as.mu.Lock() diff --git a/pkg/flipcall/packet_window_mmap_arm64.go b/pkg/sentry/platform/kvm/address_space_amd64.go index 87ad1a4a1..d11d38679 100644 --- a/pkg/flipcall/packet_window_mmap_arm64.go +++ b/pkg/sentry/platform/kvm/address_space_amd64.go @@ -1,4 +1,4 @@ -// Copyright 2020 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,14 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build arm64 +package kvm -package flipcall - -import "golang.org/x/sys/unix" - -// Return a memory mapping of the pwd in memory that can be shared outside the sandbox. -func packetWindowMmap(pwd PacketWindowDescriptor) (uintptr, unix.Errno) { - m, _, err := unix.RawSyscall6(unix.SYS_MMAP, 0, uintptr(pwd.Length), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset)) - return m, err +// invalidate is the implementation for Invalidate. +func (as *addressSpace) invalidate() { + as.dirtySet.forEach(as.machine, func(c *vCPU) { + if c.active.get() == as { // If this happens to be active, + c.BounceToKernel() // ... force a kernel transition. 
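The pgalloc reclaim hunk above only decommits the hugepage-aligned interior of a reclaimed range when ManualZeroing is in effect. A sketch of that rounding, assuming 2 MiB huge pages and eliding the overflow check that HugeRoundUp performs:

package main

import "fmt"

const hugePageSize = 2 << 20 // 2 MiB, the x86-64 huge page size

// hugeAlignedSubrange shrinks [start, end) to the largest hugepage-aligned
// sub-range, the way the ManualZeroing path rounds the range before calling
// decommit. ok is false when no aligned region fits.
func hugeAlignedSubrange(start, end uint64) (s, e uint64, ok bool) {
	s = (start + hugePageSize - 1) &^ (hugePageSize - 1) // HugeRoundUp
	e = end &^ (hugePageSize - 1)                        // HugeRoundDown
	return s, e, s < e
}

func main() {
	// A range spanning several huge pages keeps only the aligned middle.
	fmt.Println(hugeAlignedSubrange(3<<20+4096, 9<<20-4096)) // 4 MiB .. 8 MiB, true
	// A range inside a single huge page has nothing safe to decommit.
	fmt.Println(hugeAlignedSubrange(2<<20+4096, 2<<20+8192)) // ok == false
}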
+ } + }) } diff --git a/pkg/tcpip/transport/tcp/rcv_state.go b/pkg/sentry/platform/kvm/address_space_arm64.go index 2bf21a2e7..fb954418b 100644 --- a/pkg/tcpip/transport/tcp/rcv_state.go +++ b/pkg/sentry/platform/kvm/address_space_arm64.go @@ -1,4 +1,4 @@ -// Copyright 2019 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,18 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package tcp +package kvm import ( - "time" + "gvisor.dev/gvisor/pkg/ring0" ) -// saveLastRcvdAckTime is invoked by stateify. -func (r *receiver) saveLastRcvdAckTime() unixTime { - return unixTime{r.lastRcvdAckTime.Unix(), r.lastRcvdAckTime.UnixNano()} -} - -// loadLastRcvdAckTime is invoked by stateify. -func (r *receiver) loadLastRcvdAckTime(unix unixTime) { - r.lastRcvdAckTime = time.Unix(unix.second, unix.nano) +// invalidate is the implementation for Invalidate. +func (as *addressSpace) invalidate() { + bluepill(as.pageTables.Allocator.(*allocator).cpu) + ring0.FlushTlbAll() } diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index 578852c3f..9e5c52923 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -25,29 +25,6 @@ import ( var ( // The action for bluepillSignal is changed by sigaction(). bluepillSignal = unix.SIGILL - - // vcpuSErrBounce is the event of system error for bouncing KVM. - vcpuSErrBounce = kvmVcpuEvents{ - exception: exception{ - sErrPending: 1, - }, - } - - // vcpuSErrNMI is the event of system error to trigger sigbus. - vcpuSErrNMI = kvmVcpuEvents{ - exception: exception{ - sErrPending: 1, - sErrHasEsr: 1, - sErrEsr: _ESR_ELx_SERR_NMI, - }, - } - - // vcpuExtDabt is the event of ext_dabt. - vcpuExtDabt = kvmVcpuEvents{ - exception: exception{ - extDabtPending: 1, - }, - } ) // getTLS returns the value of TPIDR_EL0 register. diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go index 07fc4f216..f105fdbd0 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go @@ -80,11 +80,18 @@ func getHypercallID(addr uintptr) int { // //go:nosplit func bluepillStopGuest(c *vCPU) { + // vcpuSErrBounce is the event of system error for bouncing KVM. + vcpuSErrBounce := &kvmVcpuEvents{ + exception: exception{ + sErrPending: 1, + }, + } + if _, _, errno := unix.RawSyscall( // escapes: no. unix.SYS_IOCTL, uintptr(c.fd), _KVM_SET_VCPU_EVENTS, - uintptr(unsafe.Pointer(&vcpuSErrBounce))); errno != 0 { + uintptr(unsafe.Pointer(vcpuSErrBounce))); errno != 0 { throw("bounce sErr injection failed") } } @@ -93,12 +100,21 @@ func bluepillStopGuest(c *vCPU) { // //go:nosplit func bluepillSigBus(c *vCPU) { + // vcpuSErrNMI is the event of system error to trigger sigbus. + vcpuSErrNMI := &kvmVcpuEvents{ + exception: exception{ + sErrPending: 1, + sErrHasEsr: 1, + sErrEsr: _ESR_ELx_SERR_NMI, + }, + } + // Host must support ARM64_HAS_RAS_EXTN. if _, _, errno := unix.RawSyscall( // escapes: no. 
unix.SYS_IOCTL, uintptr(c.fd), _KVM_SET_VCPU_EVENTS, - uintptr(unsafe.Pointer(&vcpuSErrNMI))); errno != 0 { + uintptr(unsafe.Pointer(vcpuSErrNMI))); errno != 0 { if errno == unix.EINVAL { throw("No ARM64_HAS_RAS_EXTN feature in host.") } @@ -110,11 +126,18 @@ func bluepillSigBus(c *vCPU) { // //go:nosplit func bluepillExtDabt(c *vCPU) { + // vcpuExtDabt is the event of ext_dabt. + vcpuExtDabt := &kvmVcpuEvents{ + exception: exception{ + extDabtPending: 1, + }, + } + if _, _, errno := unix.RawSyscall( // escapes: no. unix.SYS_IOCTL, uintptr(c.fd), _KVM_SET_VCPU_EVENTS, - uintptr(unsafe.Pointer(&vcpuExtDabt))); errno != 0 { + uintptr(unsafe.Pointer(vcpuExtDabt))); errno != 0 { throw("ext_dabt injection failed") } } diff --git a/pkg/sentry/platform/kvm/context.go b/pkg/sentry/platform/kvm/context.go index f4d4473a8..183e741ea 100644 --- a/pkg/sentry/platform/kvm/context.go +++ b/pkg/sentry/platform/kvm/context.go @@ -17,6 +17,7 @@ package kvm import ( "sync/atomic" + "gvisor.dev/gvisor/pkg/abi/linux" pkgcontext "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" @@ -32,15 +33,15 @@ type context struct { // machine is the parent machine, and is immutable. machine *machine - // info is the arch.SignalInfo cached for this context. - info arch.SignalInfo + // info is the linux.SignalInfo cached for this context. + info linux.SignalInfo // interrupt is the interrupt context. interrupt interrupt.Forwarder } // Switch runs the provided context in the given address space. -func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, _ int32) (*arch.SignalInfo, hostarch.AccessType, error) { +func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, _ int32) (*linux.SignalInfo, hostarch.AccessType, error) { as := mm.AddressSpace() localAS := as.(*addressSpace) diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go index b8dd1e4a5..b1cab89a0 100644 --- a/pkg/sentry/platform/kvm/kvm_amd64_test.go +++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go @@ -19,6 +19,7 @@ package kvm import ( "testing" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -30,7 +31,7 @@ func TestSegments(t *testing.T) { applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { testutil.SetTestSegments(regs) for { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -55,7 +56,7 @@ func stmxcsr(addr *uint32) func TestMXCSR(t *testing.T) { applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - var si arch.SignalInfo + var si linux.SignalInfo switchOpts := ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go index ceff09a60..fe570aff9 100644 --- a/pkg/sentry/platform/kvm/kvm_test.go +++ b/pkg/sentry/platform/kvm/kvm_test.go @@ -22,6 +22,7 @@ import ( "time" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" @@ -157,7 +158,7 @@ func applicationTest(t testHarness, useHostMappings bool, target func(), fn func func TestApplicationSyscall(t *testing.T) { 
applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -171,7 +172,7 @@ func TestApplicationSyscall(t *testing.T) { return false }) applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -188,7 +189,7 @@ func TestApplicationSyscall(t *testing.T) { func TestApplicationFault(t *testing.T) { applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { testutil.SetTouchTarget(regs, nil) // Cause fault. - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -203,7 +204,7 @@ func TestApplicationFault(t *testing.T) { }) applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { testutil.SetTouchTarget(regs, nil) // Cause fault. - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -221,7 +222,7 @@ func TestRegistersSyscall(t *testing.T) { applicationTest(t, true, testutil.TwiddleRegsSyscall, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { testutil.SetTestRegs(regs) // Fill values for all registers. for { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -244,7 +245,7 @@ func TestRegistersFault(t *testing.T) { applicationTest(t, true, testutil.TwiddleRegsFault, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { testutil.SetTestRegs(regs) // Fill values for all registers. for { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -270,7 +271,7 @@ func TestBounce(t *testing.T) { time.Sleep(time.Millisecond) c.BounceToKernel() }() - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -285,7 +286,7 @@ func TestBounce(t *testing.T) { time.Sleep(time.Millisecond) c.BounceToKernel() }() - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -317,7 +318,7 @@ func TestBounceStress(t *testing.T) { c.BounceToKernel() }() randomSleep() - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -338,7 +339,7 @@ func TestInvalidate(t *testing.T) { applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { testutil.SetTouchTarget(regs, &data) // Read legitimate value. for { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -353,7 +354,7 @@ func TestInvalidate(t *testing.T) { // Unmap the page containing data & invalidate. 
pt.Unmap(hostarch.Addr(reflect.ValueOf(&data).Pointer() & ^uintptr(hostarch.PageSize-1)), hostarch.PageSize) for { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -371,13 +372,13 @@ func TestInvalidate(t *testing.T) { } // IsFault returns true iff the given signal represents a fault. -func IsFault(err error, si *arch.SignalInfo) bool { +func IsFault(err error, si *linux.SignalInfo) bool { return err == platform.ErrContextSignal && si.Signo == int32(unix.SIGSEGV) } func TestEmptyAddressSpace(t *testing.T) { applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -391,7 +392,7 @@ func TestEmptyAddressSpace(t *testing.T) { return false }) applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -467,7 +468,7 @@ func BenchmarkApplicationSyscall(b *testing.B) { a int // Count for ErrContextInterrupt. ) applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -504,7 +505,7 @@ func BenchmarkWorldSwitchToUserRoundtrip(b *testing.B) { a int ) applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index 9a2337654..7a10fd812 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -23,11 +23,11 @@ import ( "runtime/debug" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" ktime "gvisor.dev/gvisor/pkg/sentry/time" @@ -264,10 +264,10 @@ func (c *vCPU) setSystemTime() error { // nonCanonical generates a canonical address return. // //go:nosplit -func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) { - *info = arch.SignalInfo{ +func nonCanonical(addr uint64, signal int32, info *linux.SignalInfo) (hostarch.AccessType, error) { + *info = linux.SignalInfo{ Signo: signal, - Code: arch.SignalInfoKernel, + Code: linux.SI_KERNEL, } info.SetAddr(addr) // Include address. return hostarch.NoAccess, platform.ErrContextSignal @@ -276,7 +276,7 @@ func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (hostarch.Ac // fault generates an appropriate fault return. // //go:nosplit -func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) { +func (c *vCPU) fault(signal int32, info *linux.SignalInfo) (hostarch.AccessType, error) { bluepill(c) // Probably no-op, but may not be. 
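The nonCanonical helper above builds a SI_KERNEL SignalInfo whenever a register fails the canonical-form check before entering the guest. A standalone sketch of the check itself, assuming classic 48-bit x86-64 virtual addressing (5-level paging would widen it):

package main

import "fmt"

// isCanonical reports whether a 64-bit virtual address is canonical on
// x86-64 with 48-bit virtual addresses: bits 63..47 must all equal bit 47.
// This is the property a guard like ring0.IsCanonical enforces before the
// nonCanonical fallback fires.
func isCanonical(addr uint64) bool {
	top := addr >> 47
	return top == 0 || top == (1<<17)-1
}

func main() {
	fmt.Println(isCanonical(0x00007fffffffffff)) // true: top of the lower half
	fmt.Println(isCanonical(0xffff800000000000)) // true: bottom of the upper half
	fmt.Println(isCanonical(0x0000800000000000)) // false: inside the non-canonical hole
}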
faultAddr := ring0.ReadCR2() code, user := c.ErrorCode() @@ -287,7 +287,7 @@ func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType, return hostarch.NoAccess, platform.ErrContextInterrupt } // Reset the pointed SignalInfo. - *info = arch.SignalInfo{Signo: signal} + *info = linux.SignalInfo{Signo: signal} info.SetAddr(uint64(faultAddr)) accessType := hostarch.AccessType{ Read: code&(1<<1) == 0, @@ -325,7 +325,7 @@ func prefaultFloatingPointState(data *fpu.State) { } // SwitchToUser unpacks architectural-details. -func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (hostarch.AccessType, error) { +func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) (hostarch.AccessType, error) { // Check for canonical addresses. if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Rip) { return nonCanonical(regs.Rip, int32(unix.SIGSEGV), info) @@ -371,7 +371,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) return c.fault(int32(unix.SIGSEGV), info) case ring0.Debug, ring0.Breakpoint: - *info = arch.SignalInfo{ + *info = linux.SignalInfo{ Signo: int32(unix.SIGTRAP), Code: 1, // TRAP_BRKPT (breakpoint). } @@ -383,9 +383,9 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) ring0.BoundRangeExceeded, ring0.InvalidTSS, ring0.StackSegmentFault: - *info = arch.SignalInfo{ + *info = linux.SignalInfo{ Signo: int32(unix.SIGSEGV), - Code: arch.SignalInfoKernel, + Code: linux.SI_KERNEL, } info.SetAddr(switchOpts.Registers.Rip) // Include address. if vector == ring0.GeneralProtectionFault { @@ -397,7 +397,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) return hostarch.AccessType{}, platform.ErrContextSignal case ring0.InvalidOpcode: - *info = arch.SignalInfo{ + *info = linux.SignalInfo{ Signo: int32(unix.SIGILL), Code: 1, // ILL_ILLOPC (illegal opcode). } @@ -405,7 +405,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) return hostarch.AccessType{}, platform.ErrContextSignal case ring0.DivideByZero: - *info = arch.SignalInfo{ + *info = linux.SignalInfo{ Signo: int32(unix.SIGFPE), Code: 1, // FPE_INTDIV (divide by zero). } @@ -413,7 +413,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) return hostarch.AccessType{}, platform.ErrContextSignal case ring0.Overflow: - *info = arch.SignalInfo{ + *info = linux.SignalInfo{ Signo: int32(unix.SIGFPE), Code: 2, // FPE_INTOVF (integer overflow). } @@ -422,7 +422,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) case ring0.X87FloatingPointException, ring0.SIMDFloatingPointException: - *info = arch.SignalInfo{ + *info = linux.SignalInfo{ Signo: int32(unix.SIGFPE), Code: 7, // FPE_FLTINV (invalid operation). } @@ -433,7 +433,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) return hostarch.NoAccess, platform.ErrContextInterrupt case ring0.AlignmentCheck: - *info = arch.SignalInfo{ + *info = linux.SignalInfo{ Signo: int32(unix.SIGBUS), Code: 2, // BUS_ADRERR (physical address does not exist). } @@ -469,7 +469,7 @@ func availableRegionsForSetMem() (phyRegions []physicalRegion) { } func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { - // Map all the executible regions so that all the entry functions + // Map all the executable regions so that all the entry functions // are mapped in the upper half. 
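The fault path above derives the access type from the x86-64 page-fault error code (the hunk shows the Read field; the remaining fields follow the same architectural bits). A small sketch of that decoding — bit 1 is the write flag and bit 4 marks an instruction fetch:

package main

import "fmt"

// accessType mirrors the shape of hostarch.AccessType for this sketch.
type accessType struct{ Read, Write, Execute bool }

// decodePageFault interprets the x86-64 page-fault error code the way the
// KVM fault path does: bit 1 (W) distinguishes reads from writes and
// bit 4 (I/D) marks instruction fetches.
func decodePageFault(code uint64) accessType {
	return accessType{
		Read:    code&(1<<1) == 0,
		Write:   code&(1<<1) != 0,
		Execute: code&(1<<4) != 0,
	}
}

func main() {
	fmt.Printf("%+v\n", decodePageFault(0x2))  // write fault
	fmt.Printf("%+v\n", decodePageFault(0x14)) // instruction fetch
	fmt.Printf("%+v\n", decodePageFault(0x0))  // read fault
}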
applyVirtualRegions(func(vr virtualRegion) { if excludeVirtualRegion(vr) || vr.filename == "[vsyscall]" { @@ -485,7 +485,7 @@ func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { pageTable.Map( hostarch.Addr(ring0.KernelStartAddress|r.virtual), r.length, - pagetables.MapOpts{AccessType: hostarch.Execute}, + pagetables.MapOpts{AccessType: hostarch.Execute, Global: true}, physical) } }) @@ -498,7 +498,7 @@ func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { pageTable.Map( hostarch.Addr(ring0.KernelStartAddress|start), regionLen, - pagetables.MapOpts{AccessType: hostarch.ReadWrite}, + pagetables.MapOpts{AccessType: hostarch.ReadWrite, Global: true}, physical) } } diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index 8926b1d9f..edaccf9bc 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -21,10 +21,10 @@ import ( "sync/atomic" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" ) @@ -126,10 +126,10 @@ func availableRegionsForSetMem() (phyRegions []physicalRegion) { // nonCanonical generates a canonical address return. // //go:nosplit -func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) { - *info = arch.SignalInfo{ +func nonCanonical(addr uint64, signal int32, info *linux.SignalInfo) (hostarch.AccessType, error) { + *info = linux.SignalInfo{ Signo: signal, - Code: arch.SignalInfoKernel, + Code: linux.SI_KERNEL, } info.SetAddr(addr) // Include address. return hostarch.NoAccess, platform.ErrContextSignal @@ -157,7 +157,7 @@ func isWriteFault(code uint64) bool { // fault generates an appropriate fault return. // //go:nosplit -func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) { +func (c *vCPU) fault(signal int32, info *linux.SignalInfo) (hostarch.AccessType, error) { bluepill(c) // Probably no-op, but may not be. faultAddr := c.GetFaultAddr() code, user := c.ErrorCode() @@ -170,7 +170,7 @@ func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType, } // Reset the pointed SignalInfo. 
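mapUpperHalf above now maps the kernel upper half with Global: true, so those entries can survive address-space switches in the TLB. A rough sketch of the hardware-level effect on an x86-64 page-table entry — this is not gVisor's pagetables package, just the relevant PTE bits:

package main

import "fmt"

// A few x86-64 page-table entry bits, enough to show what a
// MapOpts{..., Global: true} mapping turns on at the PTE level.
const (
	ptePresent  = 1 << 0
	pteWritable = 1 << 1
	pteUser     = 1 << 2
	pteGlobal   = 1 << 8 // entry survives CR3 switches (with CR4.PGE enabled)
	pteNoExec   = 1 << 63
)

func encodePTE(phys uint64, writable, exec, global bool) uint64 {
	pte := phys&^0xfff | ptePresent
	if writable {
		pte |= pteWritable
	}
	if !exec {
		pte |= pteNoExec
	}
	if global {
		pte |= pteGlobal
	}
	return pte
}

func main() {
	// Kernel text mapped executable, read-only and global, as in mapUpperHalf.
	fmt.Printf("%#x\n", encodePTE(0x1234000, false, true, true))
}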
- *info = arch.SignalInfo{Signo: signal} + *info = linux.SignalInfo{Signo: signal} info.SetAddr(uint64(faultAddr)) ret := code & _ESR_ELx_FSC diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 92edc992b..f6aa519b1 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -23,10 +23,10 @@ import ( "unsafe" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" ktime "gvisor.dev/gvisor/pkg/sentry/time" @@ -140,22 +140,15 @@ func (c *vCPU) initArchState() error { // vbar_el1 reg.id = _KVM_ARM64_REGS_VBAR_EL1 - - fromLocation := reflect.ValueOf(ring0.Vectors).Pointer() - offset := fromLocation & (1<<11 - 1) - if offset != 0 { - offset = 1<<11 - offset - } - - toLocation := fromLocation + offset - data = uint64(ring0.KernelStartAddress | toLocation) + vectorLocation := reflect.ValueOf(ring0.Vectors).Pointer() + data = uint64(ring0.KernelStartAddress | vectorLocation) if err := c.setOneRegister(®); err != nil { return err } // Use the address of the exception vector table as // the MMIO address base. - arm64HypercallMMIOBase = toLocation + arm64HypercallMMIOBase = vectorLocation // Initialize the PCID database. if hasGuestPCID { @@ -272,7 +265,7 @@ func (c *vCPU) getOneRegister(reg *kvmOneReg) error { } // SwitchToUser unpacks architectural-details. -func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (hostarch.AccessType, error) { +func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) (hostarch.AccessType, error) { // Check for canonical addresses. if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Pc) { return nonCanonical(regs.Pc, int32(unix.SIGSEGV), info) @@ -319,14 +312,14 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) case ring0.El0SyncUndef: return c.fault(int32(unix.SIGILL), info) case ring0.El0SyncDbg: - *info = arch.SignalInfo{ + *info = linux.SignalInfo{ Signo: int32(unix.SIGTRAP), Code: 1, // TRAP_BRKPT (breakpoint). } info.SetAddr(switchOpts.Registers.Pc) // Include address. return hostarch.AccessType{}, platform.ErrContextSignal case ring0.El0SyncSpPc: - *info = arch.SignalInfo{ + *info = linux.SignalInfo{ Signo: int32(unix.SIGBUS), Code: 2, // BUS_ADRERR (physical address does not exist). } diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go index ef7814a6f..a26bc2316 100644 --- a/pkg/sentry/platform/platform.go +++ b/pkg/sentry/platform/platform.go @@ -195,8 +195,8 @@ type Context interface { // - nil: The Context invoked a system call. // // - ErrContextSignal: The Context was interrupted by a signal. The - // returned *arch.SignalInfo contains information about the signal. If - // arch.SignalInfo.Signo == SIGSEGV, the returned hostarch.AccessType + // returned *linux.SignalInfo contains information about the signal. If + // linux.SignalInfo.Signo == SIGSEGV, the returned hostarch.AccessType // contains the access type of the triggering fault. The caller owns // the returned SignalInfo. // @@ -207,7 +207,7 @@ type Context interface { // concurrent call to Switch(). // // - ErrContextCPUPreempted: See the definition of that error for details. 
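The initArchState hunk above drops the manual rounding of the vector-table address: VBAR_EL1 requires 2^11-byte alignment, and the code now relies on ring0.Vectors already being 2 KiB aligned. For reference, a standalone version of the arithmetic that was removed:

package main

import "fmt"

// alignUp2KB reproduces the adjustment the old initArchState performed
// before this change: if the vector table address wasn't 2 KiB aligned, it
// was rounded up before being written to VBAR_EL1.
func alignUp2KB(addr uintptr) uintptr {
	offset := addr & (1<<11 - 1)
	if offset != 0 {
		offset = 1<<11 - offset
	}
	return addr + offset
}

func main() {
	fmt.Printf("%#x\n", alignUp2KB(0x400800)) // already aligned: unchanged
	fmt.Printf("%#x\n", alignUp2KB(0x400804)) // rounded up to the next 2 KiB boundary
}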
- Switch(ctx context.Context, mm MemoryManager, ac arch.Context, cpu int32) (*arch.SignalInfo, hostarch.AccessType, error) + Switch(ctx context.Context, mm MemoryManager, ac arch.Context, cpu int32) (*linux.SignalInfo, hostarch.AccessType, error) // PullFullState() pulls a full state of the application thread. // diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go index 828458ce2..319b0cf1d 100644 --- a/pkg/sentry/platform/ptrace/ptrace.go +++ b/pkg/sentry/platform/ptrace/ptrace.go @@ -73,7 +73,7 @@ var ( type context struct { // signalInfo is the signal info, if and when a signal is received. - signalInfo arch.SignalInfo + signalInfo linux.SignalInfo // interrupt is the interrupt context. interrupt interrupt.Forwarder @@ -96,7 +96,7 @@ type context struct { } // Switch runs the provided context in the given address space. -func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, cpu int32) (*arch.SignalInfo, hostarch.AccessType, error) { +func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, cpu int32) (*linux.SignalInfo, hostarch.AccessType, error) { as := mm.AddressSpace() s := as.(*subprocess) isSyscall := s.switchToApp(c, ac) diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go index facb96011..cc93396a9 100644 --- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go +++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go @@ -101,7 +101,7 @@ func (t *thread) setFPRegs(fpState *fpu.State, fpLen uint64, useXsave bool) erro } // getSignalInfo retrieves information about the signal that caused the stop. -func (t *thread) getSignalInfo(si *arch.SignalInfo) error { +func (t *thread) getSignalInfo(si *linux.SignalInfo) error { _, _, errno := unix.RawSyscall6( unix.SYS_PTRACE, unix.PTRACE_GETSIGINFO, diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 9c73a725a..0931795c5 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -20,6 +20,7 @@ import ( "runtime" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/procid" @@ -524,7 +525,7 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool { // Check for interrupts, and ensure that future interrupts will signal t. if !c.interrupt.Enable(t) { // Pending interrupt; simulate. - c.signalInfo = arch.SignalInfo{Signo: int32(platform.SignalInterrupt)} + c.signalInfo = linux.SignalInfo{Signo: int32(platform.SignalInterrupt)} return false } defer c.interrupt.Disable() diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go index 9252c0bd7..90b1ead56 100644 --- a/pkg/sentry/platform/ptrace/subprocess_amd64.go +++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go @@ -155,7 +155,7 @@ func initChildProcessPPID(initregs *arch.Registers, ppid int32) { // // Note that this should only be called after verifying that the signalInfo has // been generated by the kernel. 
-func patchSignalInfo(regs *arch.Registers, signalInfo *arch.SignalInfo) { +func patchSignalInfo(regs *arch.Registers, signalInfo *linux.SignalInfo) { if linux.Signal(signalInfo.Signo) == linux.SIGSYS { signalInfo.Signo = int32(linux.SIGSEGV) diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go index c0cbc0686..e4257e3bf 100644 --- a/pkg/sentry/platform/ptrace/subprocess_arm64.go +++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go @@ -138,7 +138,7 @@ func initChildProcessPPID(initregs *arch.Registers, ppid int32) { // // Note that this should only be called after verifying that the signalInfo has // been generated by the kernel. -func patchSignalInfo(regs *arch.Registers, signalInfo *arch.SignalInfo) { +func patchSignalInfo(regs *arch.Registers, signalInfo *linux.SignalInfo) { if linux.Signal(signalInfo.Signo) == linux.SIGSYS { signalInfo.Signo = int32(linux.SIGSEGV) diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sentry/sighandling/sighandling_unsafe.go index d6a2fbe34..3fe5c6770 100644 --- a/pkg/sentry/sighandling/sighandling_unsafe.go +++ b/pkg/sentry/sighandling/sighandling_unsafe.go @@ -21,25 +21,16 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" ) -// FIXME(gvisor.dev/issue/214): Move to pkg/abi/linux along with definitions in -// pkg/sentry/arch. -type sigaction struct { - handler uintptr - flags uint64 - restorer uintptr - mask uint64 -} - // IgnoreChildStop sets the SA_NOCLDSTOP flag, causing child processes to not // generate SIGCHLD when they stop. func IgnoreChildStop() error { - var sa sigaction + var sa linux.SigAction // Get the existing signal handler information, and set the flag. if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(unix.SIGCHLD), 0, uintptr(unsafe.Pointer(&sa)), linux.SignalSetSize, 0, 0); e != 0 { return e } - sa.flags |= linux.SA_NOCLDSTOP + sa.Flags |= linux.SA_NOCLDSTOP if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(unix.SIGCHLD), uintptr(unsafe.Pointer(&sa)), 0, linux.SignalSetSize, 0, 0); e != 0 { return e } diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index 2029e7cf4..e1d310b1b 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/abi/linux", "//pkg/bits", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/marshal", "//pkg/marshal/primitive", diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 235b9c306..64958b6ec 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" @@ -473,17 +474,17 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) for i := 0; i < len(buf); { if i+linux.SizeOfControlMessageHeader > len(buf) { - return cmsgs, syserror.EINVAL + return cmsgs, linuxerr.EINVAL } var h linux.ControlMessageHeader h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader]) if h.Length < uint64(linux.SizeOfControlMessageHeader) { - return socket.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, linuxerr.EINVAL } if h.Length > uint64(len(buf)-i) { - return socket.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, 
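IgnoreChildStop above is a read-modify-write of the SIGCHLD disposition: fetch the current sigaction, OR in SA_NOCLDSTOP, write it back. A sketch of that pattern with the two rt_sigaction syscalls abstracted away as callbacks — the sigAction type and constant here are modeled for illustration, not the real linux package:

package main

import "fmt"

const saNoCldStop = 0x00000001 // SA_NOCLDSTOP

// sigAction models the fields the hunk touches; the real code uses
// linux.SigAction with the rt_sigaction syscall.
type sigAction struct {
	Handler  uintptr
	Flags    uint64
	Restorer uintptr
	Mask     uint64
}

// ignoreChildStop shows the read-modify-write shape of the change. The
// get/set callbacks stand in for the "get old action" and "set new action"
// rt_sigaction calls.
func ignoreChildStop(get func() sigAction, set func(sigAction) error) error {
	sa := get()
	sa.Flags |= saNoCldStop
	return set(sa)
}

func main() {
	current := sigAction{Flags: 0x04000000} // e.g. SA_RESTORER already set
	err := ignoreChildStop(
		func() sigAction { return current },
		func(sa sigAction) error { current = sa; return nil },
	)
	fmt.Printf("err=%v flags=%#x\n", err, current.Flags)
}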
linuxerr.EINVAL } i += linux.SizeOfControlMessageHeader @@ -497,7 +498,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) numRights := rightsSize / linux.SizeOfControlMessageRight if len(fds)+numRights > linux.SCM_MAX_FD { - return socket.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, linuxerr.EINVAL } for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight { @@ -508,7 +509,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) case linux.SCM_CREDENTIALS: if length < linux.SizeOfControlMessageCredentials { - return socket.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, linuxerr.EINVAL } var creds linux.ControlMessageCredentials @@ -522,7 +523,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) case linux.SO_TIMESTAMP: if length < linux.SizeOfTimeval { - return socket.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, linuxerr.EINVAL } var ts linux.Timeval ts.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval]) @@ -532,13 +533,13 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) default: // Unknown message type. - return socket.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, linuxerr.EINVAL } case linux.SOL_IP: switch h.Type { case linux.IP_TOS: if length < linux.SizeOfControlMessageTOS { - return socket.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, linuxerr.EINVAL } cmsgs.IP.HasTOS = true var tos primitive.Uint8 @@ -548,7 +549,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) case linux.IP_PKTINFO: if length < linux.SizeOfControlMessageIPPacketInfo { - return socket.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, linuxerr.EINVAL } cmsgs.IP.HasIPPacketInfo = true @@ -561,7 +562,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) case linux.IP_RECVORIGDSTADDR: var addr linux.SockAddrInet if length < addr.SizeBytes() { - return socket.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, linuxerr.EINVAL } addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()]) cmsgs.IP.OriginalDstAddress = &addr @@ -570,7 +571,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) case linux.IP_RECVERR: var errCmsg linux.SockErrCMsgIPv4 if length < errCmsg.SizeBytes() { - return socket.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, linuxerr.EINVAL } errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) @@ -578,13 +579,13 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) i += bits.AlignUp(length, width) default: - return socket.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, linuxerr.EINVAL } case linux.SOL_IPV6: switch h.Type { case linux.IPV6_TCLASS: if length < linux.SizeOfControlMessageTClass { - return socket.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, linuxerr.EINVAL } cmsgs.IP.HasTClass = true var tclass primitive.Uint32 @@ -595,7 +596,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) case linux.IPV6_RECVORIGDSTADDR: var addr linux.SockAddrInet6 if length < addr.SizeBytes() { - return socket.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, linuxerr.EINVAL } addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()]) cmsgs.IP.OriginalDstAddress = &addr @@ -604,7 +605,7 @@ func 
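The Parse hunks above validate every control-message header before trusting it: the header must fit in the remaining buffer, and its claimed length must be at least a header and no more than what is actually left. A simplified walk with the same bounds checks; the 16-byte header layout matches struct cmsghdr on 64-bit Linux, and alignment handling is elided:

package main

import (
	"encoding/binary"
	"errors"
	"fmt"
)

// cmsgHeaderSize matches struct cmsghdr on 64-bit Linux:
// 8-byte length + 4-byte level + 4-byte type.
const cmsgHeaderSize = 16

var errInvalid = errors.New("EINVAL")

// walkCmsgs applies the same bounds checks the Parse hunks enforce before
// interpreting a control-message header.
func walkCmsgs(buf []byte) error {
	for i := 0; i < len(buf); {
		if i+cmsgHeaderSize > len(buf) {
			return errInvalid
		}
		length := binary.LittleEndian.Uint64(buf[i:])
		if length < cmsgHeaderSize || length > uint64(len(buf)-i) {
			return errInvalid
		}
		level := binary.LittleEndian.Uint32(buf[i+8:])
		typ := binary.LittleEndian.Uint32(buf[i+12:])
		fmt.Printf("cmsg level=%d type=%d len=%d\n", level, typ, length)
		i += int(length) // the real code also aligns i up to the word size here
	}
	return nil
}

func main() {
	msg := make([]byte, cmsgHeaderSize)
	binary.LittleEndian.PutUint64(msg, cmsgHeaderSize) // empty payload
	binary.LittleEndian.PutUint32(msg[8:], 1)          // SOL_SOCKET
	binary.LittleEndian.PutUint32(msg[12:], 29)        // SO_TIMESTAMP
	fmt.Println(walkCmsgs(msg))

	binary.LittleEndian.PutUint64(msg, 1<<20) // lying header length
	fmt.Println(walkCmsgs(msg))
}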
Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) case linux.IPV6_RECVERR: var errCmsg linux.SockErrCMsgIPv6 if length < errCmsg.SizeBytes() { - return socket.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, linuxerr.EINVAL } errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) @@ -612,10 +613,10 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) i += bits.AlignUp(length, width) default: - return socket.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, linuxerr.EINVAL } default: - return socket.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, linuxerr.EINVAL } } diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index 3c6511ead..3950caa0f 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -18,6 +18,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fdnotifier", "//pkg/hostarch", "//pkg/log", diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 52ae4bc9c..38cb2c99c 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -20,6 +20,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" @@ -67,23 +68,6 @@ type socketOperations struct { socketOpsCommon } -// socketOpsCommon contains the socket operations common to VFS1 and VFS2. -// -// +stateify savable -type socketOpsCommon struct { - socket.SendReceiveTimeout - - family int // Read-only. - stype linux.SockType // Read-only. - protocol int // Read-only. - queue waiter.Queue - - // fd is the host socket fd. It must have O_NONBLOCK, so that operations - // will return EWOULDBLOCK instead of blocking on the host. This allows us to - // handle blocking behavior independently in the sentry. - fd int -} - var _ = socket.Socket(&socketOperations{}) func newSocketFile(ctx context.Context, family int, stype linux.SockType, protocol int, fd int, nonblock bool) (*fs.File, *syserr.Error) { @@ -103,29 +87,6 @@ func newSocketFile(ctx context.Context, family int, stype linux.SockType, protoc return fs.NewFile(ctx, dirent, fs.FileFlags{NonBlocking: nonblock, Read: true, Write: true, NonSeekable: true}, s), nil } -// Release implements fs.FileOperations.Release. -func (s *socketOpsCommon) Release(context.Context) { - fdnotifier.RemoveFD(int32(s.fd)) - unix.Close(s.fd) -} - -// Readiness implements waiter.Waitable.Readiness. -func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { - return fdnotifier.NonBlockingPoll(int32(s.fd), mask) -} - -// EventRegister implements waiter.Waitable.EventRegister. -func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) { - s.queue.EventRegister(e, mask) - fdnotifier.UpdateFD(int32(s.fd)) -} - -// EventUnregister implements waiter.Waitable.EventUnregister. -func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { - s.queue.EventUnregister(e) - fdnotifier.UpdateFD(int32(s.fd)) -} - // Ioctl implements fs.FileOperations.Ioctl. 
func (s *socketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { return ioctl(ctx, s.fd, io, args) @@ -177,6 +138,96 @@ func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO return int64(n), err } +// Socket implements socket.Provider.Socket. +func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) { + // Check that we are using the host network stack. + stack := t.NetworkContext() + if stack == nil { + return nil, nil + } + if _, ok := stack.(*Stack); !ok { + return nil, nil + } + + // Only accept TCP and UDP. + stype := stypeflags & linux.SOCK_TYPE_MASK + switch stype { + case unix.SOCK_STREAM: + switch protocol { + case 0, unix.IPPROTO_TCP: + // ok + default: + return nil, nil + } + case unix.SOCK_DGRAM: + switch protocol { + case 0, unix.IPPROTO_UDP: + // ok + default: + return nil, nil + } + default: + return nil, nil + } + + // Conservatively ignore all flags specified by the application and add + // SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0 + // to simplify the syscall filters, since 0 and IPPROTO_* are equivalent. + fd, err := unix.Socket(p.family, int(stype)|unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC, 0) + if err != nil { + return nil, syserr.FromError(err) + } + return newSocketFile(t, p.family, stype, protocol, fd, stypeflags&unix.SOCK_NONBLOCK != 0) +} + +// Pair implements socket.Provider.Pair. +func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { + // Not supported by AF_INET/AF_INET6. + return nil, nil, nil +} + +// LINT.ThenChange(./socket_vfs2.go) + +// socketOpsCommon contains the socket operations common to VFS1 and VFS2. +// +// +stateify savable +type socketOpsCommon struct { + socket.SendReceiveTimeout + + family int // Read-only. + stype linux.SockType // Read-only. + protocol int // Read-only. + queue waiter.Queue + + // fd is the host socket fd. It must have O_NONBLOCK, so that operations + // will return EWOULDBLOCK instead of blocking on the host. This allows us to + // handle blocking behavior independently in the sentry. + fd int +} + +// Release implements fs.FileOperations.Release. +func (s *socketOpsCommon) Release(context.Context) { + fdnotifier.RemoveFD(int32(s.fd)) + unix.Close(s.fd) +} + +// Readiness implements waiter.Waitable.Readiness. +func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { + return fdnotifier.NonBlockingPoll(int32(s.fd), mask) +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + s.queue.EventRegister(e, mask) + fdnotifier.UpdateFD(int32(s.fd)) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { + s.queue.EventUnregister(e) + fdnotifier.UpdateFD(int32(s.fd)) +} + // Connect implements socket.Socket.Connect. func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { if len(sockaddr) > sizeofSockaddr { @@ -596,6 +647,17 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b return 0, syserr.ErrInvalidArgument } + // If the src is zero-length, call SENDTO directly with a null buffer in + // order to generate poll/epoll notifications. 
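The relocated Socket provider above always creates host sockets with SOCK_NONBLOCK and SOCK_CLOEXEC and a protocol of 0, so blocking is handled in the sentry and the syscall filters stay simple. A minimal host-side equivalent using golang.org/x/sys/unix on Linux:

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

func main() {
	// Same host-socket creation pattern as the Socket hunk: force
	// SOCK_NONBLOCK so blocking is handled by the caller, SOCK_CLOEXEC so
	// the fd doesn't leak across exec, and protocol 0 (UDP for SOCK_DGRAM).
	fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC, 0)
	if err != nil {
		fmt.Println("socket:", err)
		return
	}
	defer unix.Close(fd)
	fmt.Println("nonblocking UDP socket fd:", fd)
}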
+ if src.NumBytes() == 0 { + sysflags := flags | unix.MSG_DONTWAIT + n, _, errno := unix.Syscall6(unix.SYS_SENDTO, uintptr(s.fd), 0, 0, uintptr(sysflags), uintptr(firstBytePtr(to)), uintptr(len(to))) + if errno != 0 { + return 0, syserr.FromError(errno) + } + return int(n), nil + } + space := uint64(control.CmsgsSpace(t, controlMessages)) if space > maxControlLen { space = maxControlLen @@ -653,7 +715,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b } if ch != nil { if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = syserror.ErrWouldBlock } break @@ -709,56 +771,6 @@ type socketProvider struct { family int } -// Socket implements socket.Provider.Socket. -func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) { - // Check that we are using the host network stack. - stack := t.NetworkContext() - if stack == nil { - return nil, nil - } - if _, ok := stack.(*Stack); !ok { - return nil, nil - } - - // Only accept TCP and UDP. - stype := stypeflags & linux.SOCK_TYPE_MASK - switch stype { - case unix.SOCK_STREAM: - switch protocol { - case 0, unix.IPPROTO_TCP: - // ok - default: - return nil, nil - } - case unix.SOCK_DGRAM: - switch protocol { - case 0, unix.IPPROTO_UDP: - // ok - default: - return nil, nil - } - default: - return nil, nil - } - - // Conservatively ignore all flags specified by the application and add - // SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0 - // to simplify the syscall filters, since 0 and IPPROTO_* are equivalent. - fd, err := unix.Socket(p.family, int(stype)|unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC, 0) - if err != nil { - return nil, syserr.FromError(err) - } - return newSocketFile(t, p.family, stype, protocol, fd, stypeflags&unix.SOCK_NONBLOCK != 0) -} - -// Pair implements socket.Provider.Pair. -func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { - // Not supported by AF_INET/AF_INET6. 
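The SendMsg hunk above issues a SENDTO with a null buffer for zero-length writes instead of skipping the syscall, because even an empty datagram generates poll/epoll notifications on the peer. A runnable Linux demonstration of that property with golang.org/x/sys/unix (loopback UDP, kernel-chosen port):

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

func main() {
	// A zero-length datagram is a legal send and still wakes up the
	// receiver, which is why the hunk performs the SENDTO even when the
	// source has no bytes.
	recv, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
	if err != nil {
		panic(err)
	}
	defer unix.Close(recv)
	if err := unix.Bind(recv, &unix.SockaddrInet4{Addr: [4]byte{127, 0, 0, 1}}); err != nil {
		panic(err)
	}
	to, err := unix.Getsockname(recv) // learn the kernel-chosen port
	if err != nil {
		panic(err)
	}

	send, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
	if err != nil {
		panic(err)
	}
	defer unix.Close(send)
	if err := unix.Sendto(send, nil, unix.MSG_DONTWAIT, to); err != nil {
		panic(err)
	}

	buf := make([]byte, 16)
	n, _, err := unix.Recvfrom(recv, buf, 0)
	fmt.Println(n, err) // 0 <nil>: the empty datagram was delivered
}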
- return nil, nil, nil -} - -// LINT.ThenChange(./socket_vfs2.go) - func init() { for _, family := range []int{unix.AF_INET, unix.AF_INET6} { socket.RegisterProvider(family, &socketProvider{family}) diff --git a/pkg/sentry/socket/hostinet/socket_unsafe.go b/pkg/sentry/socket/hostinet/socket_unsafe.go index d3be2d825..86dc879d5 100644 --- a/pkg/sentry/socket/hostinet/socket_unsafe.go +++ b/pkg/sentry/socket/hostinet/socket_unsafe.go @@ -67,7 +67,23 @@ func ioctl(ctx context.Context, fd int, io usermem.IO, args arch.SyscallArgument AddressSpaceActive: true, }) return 0, err - + case unix.SIOCGIFFLAGS: + cc := &usermem.IOCopyContext{ + Ctx: ctx, + IO: io, + Opts: usermem.IOOpts{ + AddressSpaceActive: true, + }, + } + var ifr linux.IFReq + if _, err := ifr.CopyIn(cc, args[2].Pointer()); err != nil { + return 0, err + } + if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), cmd, uintptr(unsafe.Pointer(&ifr))); errno != 0 { + return 0, translateIOSyscallError(errno) + } + _, err := ifr.CopyOut(cc, args[2].Pointer()) + return 0, err default: return 0, syserror.ENOTTY } diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index 61b2c9755..608474fa1 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -25,6 +25,7 @@ go_library( "//pkg/log", "//pkg/marshal", "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", "//pkg/syserr", "//pkg/tcpip", "//pkg/tcpip/header", diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index 6fc7781ad..3f1b4a17b 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -19,20 +19,12 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) -// TODO(gvisor.dev/issue/170): The following per-matcher params should be -// supported: -// - Table name -// - Match size -// - User size -// - Hooks -// - Proto -// - Family - // matchMaker knows how to (un)marshal the matcher named name(). type matchMaker interface { // name is the matcher name as stored in the xt_entry_match struct. @@ -43,7 +35,7 @@ type matchMaker interface { // unmarshal converts from the ABI matcher struct to an // stack.Matcher. - unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) + unmarshal(task *kernel.Task, buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) } type matcher interface { @@ -94,12 +86,12 @@ func marshalEntryMatch(name string, data []byte) []byte { return buf } -func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) { +func unmarshalMatcher(task *kernel.Task, match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) { matchMaker, ok := matchMakers[match.Name.String()] if !ok { return nil, fmt.Errorf("unsupported matcher with name %q", match.Name.String()) } - return matchMaker.unmarshal(buf, filter) + return matchMaker.unmarshal(task, buf, filter) } // targetMaker knows how to (un)marshal a target. 
Once registered, diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index cb78ef60b..d8bd86292 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -123,7 +124,7 @@ func getEntries4(table stack.Table, tablename linux.TableName) (linux.KernelIPTG return entries, info } -func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) { +func modifyEntries4(task *kernel.Task, stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) { nflog("set entries: setting entries in table %q", replace.Name.String()) // Convert input into a list of rules and their offsets. @@ -148,23 +149,19 @@ func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, return nil, syserr.ErrInvalidArgument } - // TODO(gvisor.dev/issue/170): We should support more IPTIP - // filtering fields. filter, err := filterFromIPTIP(entry.IP) if err != nil { nflog("bad iptip: %v", err) return nil, syserr.ErrInvalidArgument } - // TODO(gvisor.dev/issue/170): Matchers and targets can specify - // that they only work for certain protocols, hooks, tables. // Get matchers. matchersSize := entry.TargetOffset - linux.SizeOfIPTEntry if len(optVal) < int(matchersSize) { nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal)) return nil, syserr.ErrInvalidArgument } - matchers, err := parseMatchers(filter, optVal[:matchersSize]) + matchers, err := parseMatchers(task, filter, optVal[:matchersSize]) if err != nil { nflog("failed to parse matchers: %v", err) return nil, syserr.ErrInvalidArgument diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go index 5cb7fe4aa..c68230847 100644 --- a/pkg/sentry/socket/netfilter/ipv6.go +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -126,7 +127,7 @@ func getEntries6(table stack.Table, tablename linux.TableName) (linux.KernelIP6T return entries, info } -func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) { +func modifyEntries6(task *kernel.Task, stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) { nflog("set entries: setting entries in table %q", replace.Name.String()) // Convert input into a list of rules and their offsets. @@ -151,23 +152,19 @@ func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, return nil, syserr.ErrInvalidArgument } - // TODO(gvisor.dev/issue/170): We should support more IPTIP - // filtering fields. filter, err := filterFromIP6TIP(entry.IPv6) if err != nil { nflog("bad iptip: %v", err) return nil, syserr.ErrInvalidArgument } - // TODO(gvisor.dev/issue/170): Matchers and targets can specify - // that they only work for certain protocols, hooks, tables. // Get matchers. 
matchersSize := entry.TargetOffset - linux.SizeOfIP6TEntry if len(optVal) < int(matchersSize) { nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal)) return nil, syserr.ErrInvalidArgument } - matchers, err := parseMatchers(filter, optVal[:matchersSize]) + matchers, err := parseMatchers(task, filter, optVal[:matchersSize]) if err != nil { nflog("failed to parse matchers: %v", err) return nil, syserr.ErrInvalidArgument diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index f42d73178..e3eade180 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -58,8 +58,8 @@ var nameToID = map[string]stack.TableID{ // DefaultLinuxTables returns the rules of stack.DefaultTables() wrapped for // compatibility with netfilter extensions. -func DefaultLinuxTables() *stack.IPTables { - tables := stack.DefaultTables() +func DefaultLinuxTables(seed uint32) *stack.IPTables { + tables := stack.DefaultTables(seed) tables.VisitTargets(func(oldTarget stack.Target) stack.Target { switch val := oldTarget.(type) { case *stack.AcceptTarget: @@ -174,13 +174,12 @@ func setHooksAndUnderflow(info *linux.IPTGetinfo, table stack.Table, offset uint // SetEntries sets iptables rules for a single table. See // net/ipv4/netfilter/ip_tables.c:translate_table for reference. -func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { +func SetEntries(task *kernel.Task, stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { var replace linux.IPTReplace replaceBuf := optVal[:linux.SizeOfIPTReplace] optVal = optVal[linux.SizeOfIPTReplace:] replace.UnmarshalBytes(replaceBuf) - // TODO(gvisor.dev/issue/170): Support other tables. var table stack.Table switch replace.Name.String() { case filterTable: @@ -188,16 +187,16 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { case natTable: table = stack.EmptyNATTable() default: - nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) + nflog("unknown iptables table %q", replace.Name.String()) return syserr.ErrInvalidArgument } var err *syserr.Error var offsets map[uint32]int if ipv6 { - offsets, err = modifyEntries6(stk, optVal, &replace, &table) + offsets, err = modifyEntries6(task, stk, optVal, &replace, &table) } else { - offsets, err = modifyEntries4(stk, optVal, &replace, &table) + offsets, err = modifyEntries4(task, stk, optVal, &replace, &table) } if err != nil { return err @@ -272,7 +271,6 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { table.Rules[ruleIdx] = rule } - // TODO(gvisor.dev/issue/170): Support other chains. // Since we don't support FORWARD, yet, make sure all other chains point to // ACCEPT rules. for hook, ruleIdx := range table.BuiltinChains { @@ -287,7 +285,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { } } - // TODO(gvisor.dev/issue/170): Check the following conditions: + // TODO(gvisor.dev/issue/6167): Check the following conditions: // - There are no loops. // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. @@ -297,7 +295,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { // parseMatchers parses 0 or more matchers from optVal. optVal should contain // only the matchers. 
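// Aside: the loop below walks a packed sequence of variable-length records,
// each carrying its own total size (MatchSize in xt_entry_match), and advances
// with optVal = optVal[match.MatchSize:]. A minimal stand-alone sketch of that
// pattern, using a hypothetical 2-byte little-endian size header in place of
// the real xt_entry_match layout (assumes "encoding/binary" and "fmt"):
func walkSizePrefixedRecords(buf []byte, visit func(payload []byte) error) error {
	const hdrLen = 2
	for len(buf) > 0 {
		if len(buf) < hdrLen {
			return fmt.Errorf("truncated record header: only %d bytes left", len(buf))
		}
		size := int(binary.LittleEndian.Uint16(buf))
		if size < hdrLen || size > len(buf) {
			return fmt.Errorf("bad record size %d with %d bytes remaining", size, len(buf))
		}
		if err := visit(buf[hdrLen:size]); err != nil {
			return err
		}
		// Advance past the whole record, header included.
		buf = buf[size:]
	}
	return nil
}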
-func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, error) { +func parseMatchers(task *kernel.Task, filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, error) { nflog("set entries: parsing matchers of size %d", len(optVal)) var matchers []stack.Matcher for len(optVal) > 0 { @@ -321,13 +319,13 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, } // Parse the specific matcher. - matcher, err := unmarshalMatcher(match, filter, optVal[linux.SizeOfXTEntryMatch:match.MatchSize]) + matcher, err := unmarshalMatcher(task, match, filter, optVal[linux.SizeOfXTEntryMatch:match.MatchSize]) if err != nil { return nil, fmt.Errorf("failed to create matcher: %v", err) } matchers = append(matchers, matcher) - // TODO(gvisor.dev/issue/170): Check the revision field. + // TODO(gvisor.dev/issue/6167): Check the revision field. optVal = optVal[match.MatchSize:] } diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go index 60845cab3..6eff2ae65 100644 --- a/pkg/sentry/socket/netfilter/owner_matcher.go +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -19,6 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -40,8 +42,8 @@ func (ownerMarshaler) name() string { func (ownerMarshaler) marshal(mr matcher) []byte { matcher := mr.(*OwnerMatcher) iptOwnerInfo := linux.IPTOwnerInfo{ - UID: matcher.uid, - GID: matcher.gid, + UID: uint32(matcher.uid), + GID: uint32(matcher.gid), } // Support for UID and GID match. @@ -63,7 +65,7 @@ func (ownerMarshaler) marshal(mr matcher) []byte { } // unmarshal implements matchMaker.unmarshal. -func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { +func (ownerMarshaler) unmarshal(task *kernel.Task, buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { if len(buf) < linux.SizeOfIPTOwnerInfo { return nil, fmt.Errorf("buf has insufficient size for owner match: %d", len(buf)) } @@ -72,11 +74,12 @@ func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack. // exceed what's strictly necessary to hold matchData. var matchData linux.IPTOwnerInfo matchData.UnmarshalUnsafe(buf[:linux.SizeOfIPTOwnerInfo]) - nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData) + nflog("parsed IPTOwnerInfo: %+v", matchData) var owner OwnerMatcher - owner.uid = matchData.UID - owner.gid = matchData.GID + creds := task.Credentials() + owner.uid = creds.UserNamespace.MapToKUID(auth.UID(matchData.UID)) + owner.gid = creds.UserNamespace.MapToKGID(auth.GID(matchData.GID)) // Check flags. if matchData.Match&linux.XT_OWNER_UID != 0 { @@ -97,8 +100,8 @@ func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack. // OwnerMatcher matches against a UID and/or GID. type OwnerMatcher struct { - uid uint32 - gid uint32 + uid auth.KUID + gid auth.KGID matchUID bool matchGID bool invertUID bool @@ -113,7 +116,6 @@ func (*OwnerMatcher) name() string { // Match implements Matcher.Match. func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) { // Support only for OUTPUT chain. - // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also. 
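// Aside: the rule's UID/GID were mapped into kernel IDs (MapToKUID/MapToKGID)
// by unmarshal above, so the UID/GID checks below compare in KUID/KGID space.
// Each check reduces to "match if (IDs are equal) XOR invert"; a compact
// stand-alone restatement with IDs simplified to uint32 for illustration:
func ownerIDMatches(pktID, ruleID uint32, invert bool) bool {
	return (pktID == ruleID) != invert
}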
if hook != stack.Output { return false, true } @@ -126,7 +128,7 @@ func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ str var matches bool // Check for UID match. if om.matchUID { - if pkt.Owner.UID() == om.uid { + if auth.KUID(pkt.Owner.KUID()) == om.uid { matches = true } if matches == om.invertUID { @@ -137,7 +139,7 @@ func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ str // Check for GID match. if om.matchGID { matches = false - if pkt.Owner.GID() == om.gid { + if auth.KGID(pkt.Owner.KGID()) == om.gid { matches = true } if matches == om.invertGID { diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index fa5456eee..ea56f39c1 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -331,7 +331,6 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return nil, syserr.ErrInvalidArgument } - // TODO(gvisor.dev/issue/170): Check if the flags are valid. // Also check if we need to map ports or IP. // For now, redirect target only supports destination port change. // Port range and IP range are not supported yet. @@ -340,7 +339,6 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return nil, syserr.ErrInvalidArgument } - // TODO(gvisor.dev/issue/170): Port range is not supported yet. if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort { nflog("redirectTargetMaker: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort) return nil, syserr.ErrInvalidArgument @@ -420,7 +418,6 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar return nil, syserr.ErrInvalidArgument } - // TODO(gvisor.dev/issue/3549): Check for other flags. // For now, redirect target only supports destination change. if natRange.Flags != linux.NF_NAT_RANGE_PROTO_SPECIFIED { nflog("nfNATTargetMaker: invalid range flags %d", natRange.Flags) @@ -502,7 +499,6 @@ func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta return nil, syserr.ErrInvalidArgument } - // TODO(gvisor.dev/issue/170): Port range is not supported yet. if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort { nflog("snatTargetMakerV4: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort) return nil, syserr.ErrInvalidArgument @@ -594,7 +590,6 @@ func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta // translateToStandardTarget translates from the value in a // linux.XTStandardTarget to an stack.Verdict. func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (target, *syserr.Error) { - // TODO(gvisor.dev/issue/170): Support other verdicts. switch val { case -linux.NF_ACCEPT - 1: return &acceptTarget{stack.AcceptTarget{ diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 95bb9826e..e5b73a976 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -50,7 +51,7 @@ func (tcpMarshaler) marshal(mr matcher) []byte { } // unmarshal implements matchMaker.unmarshal. 
-func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { +func (tcpMarshaler) unmarshal(_ *kernel.Task, buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { if len(buf) < linux.SizeOfXTTCP { return nil, fmt.Errorf("buf has insufficient size for TCP match: %d", len(buf)) } @@ -95,8 +96,6 @@ func (*TCPMatcher) name() string { // Match implements Matcher.Match. func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) { - // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved - // into the stack.Check codepath as matchers are added. switch pkt.NetworkProtocolNumber { case header.IPv4ProtocolNumber: netHeader := header.IPv4(pkt.NetworkHeader().View()) diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index fb8be27e6..aa72ee70c 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -50,7 +51,7 @@ func (udpMarshaler) marshal(mr matcher) []byte { } // unmarshal implements matchMaker.unmarshal. -func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { +func (udpMarshaler) unmarshal(_ *kernel.Task, buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { if len(buf) < linux.SizeOfXTUDP { return nil, fmt.Errorf("buf has insufficient size for UDP match: %d", len(buf)) } @@ -92,8 +93,6 @@ func (*UDPMatcher) name() string { // Match implements Matcher.Match. func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) { - // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved - // into the stack.Check codepath as matchers are added. switch pkt.NetworkProtocolNumber { case header.IPv4ProtocolNumber: netHeader := header.IPv4(pkt.NetworkHeader().View()) diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 64cd263da..ed85404da 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -14,8 +14,10 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/abi/linux/errno", "//pkg/bits", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/marshal", "//pkg/marshal/primitive", diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 280563d09..d53f23a9a 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -20,7 +20,9 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/abi/linux/errno" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" @@ -56,7 +58,7 @@ const ( maxSendBufferSize = 4 << 20 // 4MB ) -var errNoFilter = syserr.New("no filter attached", linux.ENOENT) +var errNoFilter = syserr.New("no filter attached", errno.ENOENT) // netlinkSocketDevice is the netlink socket virtual device. 
var netlinkSocketDevice = device.NewAnonDevice() @@ -558,7 +560,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain } return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index 9561b7c25..e828982eb 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -19,7 +19,9 @@ go_library( ], deps = [ "//pkg/abi/linux", + "//pkg/abi/linux/errno", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/log", "//pkg/marshal", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 0b64a24c3..11f75628c 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -36,7 +36,9 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/abi/linux/errno" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/marshal" @@ -77,41 +79,59 @@ func mustCreateGauge(name, description string) *tcpip.StatCounter { // Metrics contains metrics exported by netstack. var Metrics = tcpip.Stats{ - UnknownProtocolRcvdPackets: mustCreateMetric("/netstack/unknown_protocol_received_packets", "Number of packets received by netstack that were for an unknown or unsupported protocol."), - MalformedRcvdPackets: mustCreateMetric("/netstack/malformed_received_packets", "Number of packets received by netstack that were deemed malformed."), - DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped by netstack due to full queues."), + DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped at the transport layer."), + NICs: tcpip.NICStats{ + UnknownL3ProtocolRcvdPackets: mustCreateMetric("/netstack/nic/unknown_l3_protocol_received_packets", "Number of packets received that were for an unknown or unsupported L3 protocol."), + UnknownL4ProtocolRcvdPackets: mustCreateMetric("/netstack/nic/unknown_l4_protocol_received_packets", "Number of packets received that were for an unknown or unsupported L4 protocol."), + MalformedL4RcvdPackets: mustCreateMetric("/netstack/nic/malformed_l4_received_packets", "Number of packets received that failed L4 header parsing."), + Tx: tcpip.NICPacketStats{ + Packets: mustCreateMetric("/netstack/nic/tx/packets", "Number of packets transmitted."), + Bytes: mustCreateMetric("/netstack/nic/tx/bytes", "Number of bytes transmitted."), + }, + Rx: tcpip.NICPacketStats{ + Packets: mustCreateMetric("/netstack/nic/rx/packets", "Number of packets received."), + Bytes: mustCreateMetric("/netstack/nic/rx/bytes", "Number of bytes received."), + }, + DisabledRx: tcpip.NICPacketStats{ + Packets: mustCreateMetric("/netstack/nic/disabled_rx/packets", "Number of packets received on disabled NICs."), + Bytes: mustCreateMetric("/netstack/nic/disabled_rx/bytes", "Number of bytes received on disabled NICs."), + }, + Neighbor: tcpip.NICNeighborStats{ + UnreachableEntryLookups: mustCreateMetric("/netstack/nic/neighbor/unreachable_entry_loopups", "Number of lookups performed on a neighbor entry in Unreachable state."), + }, + }, ICMP: tcpip.ICMPStats{ V4: 
tcpip.ICMPv4Stats{ PacketsSent: tcpip.ICMPv4SentPacketStats{ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_request", "Number of ICMPv4 echo request packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Number of ICMPv4 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Number of ICMPv4 destination unreachable packets sent by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Number of ICMPv4 source quench packets sent by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Number of ICMPv4 redirect packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Number of ICMPv4 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Number of ICMPv4 parameter problem packets sent by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Number of ICMPv4 timestamp packets sent by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Number of ICMPv4 timestamp reply packets sent by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Number of ICMPv4 information request packets sent by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Number of ICMPv4 information reply packets sent by netstack."), + EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_request", "Number of ICMPv4 echo request packets sent."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Number of ICMPv4 echo reply packets sent."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Number of ICMPv4 destination unreachable packets sent."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Number of ICMPv4 source quench packets sent."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Number of ICMPv4 redirect packets sent."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Number of ICMPv4 time exceeded packets sent."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Number of ICMPv4 parameter problem packets sent."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Number of ICMPv4 timestamp packets sent."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Number of ICMPv4 timestamp reply packets sent."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Number of ICMPv4 information request packets sent."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Number of ICMPv4 information reply packets sent."), }, - Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Number of ICMPv4 packets dropped by netstack due to link layer errors."), - RateLimited: mustCreateMetric("/netstack/icmp/v4/packets_sent/rate_limited", "Number of ICMPv4 packets dropped by netstack due to rate limit being exceeded."), + Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Number of ICMPv4 packets dropped due to link layer errors."), + RateLimited: 
mustCreateMetric("/netstack/icmp/v4/packets_sent/rate_limited", "Number of ICMPv4 packets dropped due to rate limit being exceeded."), }, PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_request", "Number of ICMPv4 echo request packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Number of ICMPv4 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Number of ICMPv4 destination unreachable packets received by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Number of ICMPv4 source quench packets received by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Number of ICMPv4 redirect packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Number of ICMPv4 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Number of ICMPv4 parameter problem packets received by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Number of ICMPv4 timestamp packets received by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Number of ICMPv4 timestamp reply packets received by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Number of ICMPv4 information request packets received by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Number of ICMPv4 information reply packets received by netstack."), + EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_request", "Number of ICMPv4 echo request packets received."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Number of ICMPv4 echo reply packets received."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Number of ICMPv4 destination unreachable packets received."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Number of ICMPv4 source quench packets received."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Number of ICMPv4 redirect packets received."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Number of ICMPv4 time exceeded packets received."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Number of ICMPv4 parameter problem packets received."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Number of ICMPv4 timestamp packets received."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Number of ICMPv4 timestamp reply packets received."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Number of ICMPv4 information request packets received."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Number of ICMPv4 information reply packets received."), }, Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Number of ICMPv4 packets received that the transport layer could not parse."), }, @@ -119,40 +139,40 @@ var 
Metrics = tcpip.Stats{ V6: tcpip.ICMPv6Stats{ PacketsSent: tcpip.ICMPv6SentPacketStats{ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Number of ICMPv6 echo request packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Number of ICMPv6 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Number of ICMPv6 destination unreachable packets sent by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Number of ICMPv6 packet too big packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Number of ICMPv6 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Number of ICMPv6 parameter problem packets sent by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Number of ICMPv6 router solicit packets sent by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Number of ICMPv6 router advert packets sent by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets sent by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Number of ICMPv6 neighbor advert packets sent by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Number of ICMPv6 redirect message packets sent by netstack."), - MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_query", "Number of ICMPv6 multicast listener query packets sent by netstack."), - MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent by netstack."), - MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent by netstack."), + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Number of ICMPv6 echo request packets sent."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Number of ICMPv6 echo reply packets sent."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Number of ICMPv6 destination unreachable packets sent."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Number of ICMPv6 packet too big packets sent."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Number of ICMPv6 time exceeded packets sent."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Number of ICMPv6 parameter problem packets sent."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Number of ICMPv6 router solicit packets sent."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Number of ICMPv6 router advert packets sent."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets sent."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Number of ICMPv6 neighbor advert 
packets sent."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Number of ICMPv6 redirect message packets sent."), + MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_query", "Number of ICMPv6 multicast listener query packets sent."), + MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent."), + MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent."), }, - Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Number of ICMPv6 packets dropped by netstack due to link layer errors."), - RateLimited: mustCreateMetric("/netstack/icmp/v6/packets_sent/rate_limited", "Number of ICMPv6 packets dropped by netstack due to rate limit being exceeded."), + Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Number of ICMPv6 packets dropped due to link layer errors."), + RateLimited: mustCreateMetric("/netstack/icmp/v6/packets_sent/rate_limited", "Number of ICMPv6 packets dropped due to rate limit being exceeded."), }, PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Number of ICMPv6 echo request packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Number of ICMPv6 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Number of ICMPv6 destination unreachable packets received by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Number of ICMPv6 packet too big packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Number of ICMPv6 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Number of ICMPv6 parameter problem packets received by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Number of ICMPv6 router solicit packets received by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Number of ICMPv6 router advert packets received by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets received by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Number of ICMPv6 neighbor advert packets received by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Number of ICMPv6 redirect message packets received by netstack."), - MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_query", "Number of ICMPv6 multicast listener query packets received by netstack."), - MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent by netstack."), - MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent by netstack."), + EchoRequest: 
mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Number of ICMPv6 echo request packets received."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Number of ICMPv6 echo reply packets received."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Number of ICMPv6 destination unreachable packets received."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Number of ICMPv6 packet too big packets received."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Number of ICMPv6 time exceeded packets received."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Number of ICMPv6 parameter problem packets received."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Number of ICMPv6 router solicit packets received."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Number of ICMPv6 router advert packets received."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets received."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Number of ICMPv6 neighbor advert packets received."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Number of ICMPv6 redirect message packets received."), + MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_query", "Number of ICMPv6 multicast listener query packets received."), + MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent."), + MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent."), }, Unrecognized: mustCreateMetric("/netstack/icmp/v6/packets_received/unrecognized", "Number of ICMPv6 packets received that the transport layer does not know how to parse."), Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Number of ICMPv6 packets received that the transport layer could not parse."), @@ -163,23 +183,23 @@ var Metrics = tcpip.Stats{ IGMP: tcpip.IGMPStats{ PacketsSent: tcpip.IGMPSentPacketStats{ IGMPPacketStats: tcpip.IGMPPacketStats{ - MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Number of IGMP Membership Query messages sent by netstack."), - V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Number of IGMPv1 Membership Report messages sent by netstack."), - V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Number of IGMPv2 Membership Report messages sent by netstack."), - LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Number of IGMP Leave Group messages sent by netstack."), + MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Number of IGMP Membership Query messages sent."), + V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Number of IGMPv1 Membership Report messages sent."), + V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Number of IGMPv2 Membership Report messages sent."), + LeaveGroup: 
mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Number of IGMP Leave Group messages sent."), }, - Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Number of IGMP packets dropped by netstack due to link layer errors."), + Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Number of IGMP packets dropped due to link layer errors."), }, PacketsReceived: tcpip.IGMPReceivedPacketStats{ IGMPPacketStats: tcpip.IGMPPacketStats{ - MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Number of IGMP Membership Query messages received by netstack."), - V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Number of IGMPv1 Membership Report messages received by netstack."), - V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Number of IGMPv2 Membership Report messages received by netstack."), - LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Number of IGMP Leave Group messages received by netstack."), + MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Number of IGMP Membership Query messages received."), + V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Number of IGMPv1 Membership Report messages received."), + V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Number of IGMPv2 Membership Report messages received."), + LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Number of IGMP Leave Group messages received."), }, - Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Number of IGMP packets received by netstack that could not be parsed."), + Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Number of IGMP packets received that could not be parsed."), ChecksumErrors: mustCreateMetric("/netstack/igmp/packets_received/checksum_errors", "Number of received IGMP packets with bad checksums."), - Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Number of unrecognized IGMP packets received by netstack."), + Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Number of unrecognized IGMP packets received."), }, }, IP: tcpip.IPStats{ @@ -205,7 +225,8 @@ var Metrics = tcpip.Stats{ LinkLocalSource: mustCreateMetric("/netstack/ip/forwarding/link_local_source_address", "Number of IP packets received which could not be forwarded due to a link-local source address."), LinkLocalDestination: mustCreateMetric("/netstack/ip/forwarding/link_local_destination_address", "Number of IP packets received which could not be forwarded due to a link-local destination address."), ExtensionHeaderProblem: mustCreateMetric("/netstack/ip/forwarding/extension_header_problem", "Number of IP packets received which could not be forwarded due to a problem processing their IPv6 extension headers."), - PacketTooBig: mustCreateMetric("/netstack/ip/forwarding/packet_too_big", "Number of IP packets received which could not fit within the outgoing MTU."), + PacketTooBig: mustCreateMetric("/netstack/ip/forwarding/packet_too_big", "Number of IP packets received which could not be forwarded because they could not fit within the outgoing MTU."), + HostUnreachable: mustCreateMetric("/netstack/ip/forwarding/host_unreachable", "Number of IP packets received which could not be forwarded due to unresolvable next hop."), 
Errors: mustCreateMetric("/netstack/ip/forwarding/errors", "Number of IP packets which couldn't be forwarded."), }, }, @@ -270,7 +291,7 @@ const DefaultTTL = 64 const sizeOfInt32 int = 4 -var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL) +var errStackType = syserr.New("expected but did not receive a netstack.Stack", errno.EINVAL) // commonEndpoint represents the intersection of a tcpip.Endpoint and a // transport.Endpoint. @@ -1126,7 +1147,14 @@ func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, // TODO(b/64800844): Translate fields once they are added to // tcpip.TCPInfoOption. - info := linux.TCPInfo{} + info := linux.TCPInfo{ + State: uint8(v.State), + RTO: uint32(v.RTO / time.Microsecond), + RTT: uint32(v.RTT / time.Microsecond), + RTTVar: uint32(v.RTTVar / time.Microsecond), + SndSsthresh: v.SndSsthresh, + SndCwnd: v.SndCwnd, + } switch v.CcState { case tcpip.RTORecovery: info.CaState = linux.TCP_CA_Loss @@ -1137,11 +1165,6 @@ func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, case tcpip.Open: info.CaState = linux.TCP_CA_Open } - info.RTO = uint32(v.RTO / time.Microsecond) - info.RTT = uint32(v.RTT / time.Microsecond) - info.RTTVar = uint32(v.RTTVar / time.Microsecond) - info.SndSsthresh = v.SndSsthresh - info.SndCwnd = v.SndCwnd // In netstack reorderSeen is updated only when RACK is enabled. // We only track whether the reordering is seen, which is @@ -1777,11 +1800,6 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := hostarch.ByteOrder.Uint32(optVal) - - if v == 0 { - socket.SetSockOptEmitUnimplementedEvent(t, name) - } - ep.SocketOptions().SetOutOfBandInline(v != 0) return nil @@ -2089,10 +2107,10 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return syserr.ErrNoDevice } // Stack must be a netstack stack. - return netfilter.SetEntries(stack.(*Stack).Stack, optVal, true) + return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, true) case linux.IP6T_SO_SET_ADD_COUNTERS: - // TODO(gvisor.dev/issue/170): Counter support. + log.Infof("IP6T_SO_SET_ADD_COUNTERS is not supported") return nil default: @@ -2332,10 +2350,10 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return syserr.ErrNoDevice } // Stack must be a netstack stack. - return netfilter.SetEntries(stack.(*Stack).Stack, optVal, false) + return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, false) case linux.IPT_SO_SET_ADD_COUNTERS: - // TODO(gvisor.dev/issue/170): Counter support. + log.Infof("IPT_SO_SET_ADD_COUNTERS is not supported") return nil case linux.IP_ADD_SOURCE_MEMBERSHIP, @@ -2792,7 +2810,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if n > 0 { return n, msgFlags, senderAddr, senderAddrLen, controlMessages, nil } - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain } return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) @@ -2860,7 +2878,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b // became available between when we last checked and when we setup // the notification. 
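// Aside: the block-with-deadline pattern below, where a timeout is reported as
// "try again" rather than a hard failure, has a standard-library analogue. A
// minimal sketch using context deadlines and errors.Is (errWouldBlock is a
// hypothetical sentinel; assumes the standard "context" and "errors" packages):
var errWouldBlock = errors.New("operation would block")

func waitWithDeadline(ctx context.Context, ready <-chan struct{}) error {
	select {
	case <-ready:
		return nil
	case <-ctx.Done():
		if errors.Is(ctx.Err(), context.DeadlineExceeded) {
			// Translate the deadline into an EAGAIN-style "retry later" result.
			return errWouldBlock
		}
		return ctx.Err()
	}
}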
if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { return int(total), syserr.ErrTryAgain } // handleIOError will consume errors from t.Block if needed. diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index eef5e6519..9d343b671 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/syserr" @@ -110,19 +111,19 @@ func convertAddr(addr inet.InterfaceAddr) (tcpip.ProtocolAddress, error) { switch addr.Family { case linux.AF_INET: if len(addr.Addr) != header.IPv4AddressSize { - return protocolAddress, syserror.EINVAL + return protocolAddress, linuxerr.EINVAL } if addr.PrefixLen > header.IPv4AddressSize*8 { - return protocolAddress, syserror.EINVAL + return protocolAddress, linuxerr.EINVAL } protocol = ipv4.ProtocolNumber address = tcpip.Address(addr.Addr) case linux.AF_INET6: if len(addr.Addr) != header.IPv6AddressSize { - return protocolAddress, syserror.EINVAL + return protocolAddress, linuxerr.EINVAL } if addr.PrefixLen > header.IPv6AddressSize*8 { - return protocolAddress, syserror.EINVAL + return protocolAddress, linuxerr.EINVAL } protocol = ipv6.ProtocolNumber address = tcpip.Address(addr.Addr) diff --git a/pkg/sentry/socket/netstack/tun.go b/pkg/sentry/socket/netstack/tun.go index 288dd0c9e..e67fe9700 100644 --- a/pkg/sentry/socket/netstack/tun.go +++ b/pkg/sentry/socket/netstack/tun.go @@ -16,7 +16,7 @@ package netstack import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/tcpip/link/tun" ) @@ -40,8 +40,8 @@ func LinuxToTUNFlags(flags uint16) (tun.Flags, error) { // Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) // when there is no sk_filter. See __tun_chr_ioctl() in // net/drivers/tun.c. - if flags&^uint16(linux.IFF_TUN|linux.IFF_TAP|linux.IFF_NO_PI) != 0 { - return tun.Flags{}, syserror.EINVAL + if flags&^uint16(linux.IFF_TUN|linux.IFF_TAP|linux.IFF_NO_PI|linux.IFF_ONE_QUEUE) != 0 { + return tun.Flags{}, linuxerr.EINVAL } return tun.Flags{ TUN: flags&linux.IFF_TUN != 0, diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 353f4ade0..f5da3c509 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -659,7 +659,6 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) return &out, uint32(sockAddrInet6Size) case linux.AF_PACKET: - // TODO(gvisor.dev/issue/173): Return protocol too. var out linux.SockAddrLink out.Family = linux.AF_PACKET out.InterfaceIndex = int32(addr.NIC) @@ -749,7 +748,6 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - // TODO(gvisor.dev/issue/173): Return protocol too. 
return tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index c9cbefb3a..5c3cdef6a 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -39,6 +39,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fspath", "//pkg/hostarch", "//pkg/log", diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index db7b1affe..8ccdadae9 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -23,6 +23,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" @@ -518,7 +519,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b } if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = syserror.ErrWouldBlock } break @@ -719,7 +720,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if total > 0 { err = nil } - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain } return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.FromError(err) diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index c39e317ff..08a00a12f 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -17,6 +17,7 @@ package unix import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" @@ -236,7 +237,7 @@ func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { Mode: linux.FileMode(linux.S_IFSOCK | uint(stat.Mode)&^t.FSContext().Umask()), Endpoint: bep, }) - if err == syserror.EEXIST { + if linuxerr.Equals(linuxerr.EEXIST, err) { return syserr.ErrAddressInUse } return syserr.FromError(err) diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD index 3e801182c..7f02807c5 100644 --- a/pkg/sentry/state/BUILD +++ b/pkg/sentry/state/BUILD @@ -13,6 +13,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/log", "//pkg/sentry/inet", "//pkg/sentry/kernel", @@ -20,7 +21,6 @@ go_library( "//pkg/sentry/vfs", "//pkg/sentry/watchdog", "//pkg/state/statefile", - "//pkg/syserror", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go index 167754537..e9d544f3d 100644 --- a/pkg/sentry/state/state.go +++ b/pkg/sentry/state/state.go @@ -20,6 +20,7 @@ import ( "io" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -27,7 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sentry/watchdog" "gvisor.dev/gvisor/pkg/state/statefile" - "gvisor.dev/gvisor/pkg/syserror" ) var previousMetadata map[string]string @@ -88,7 +88,7 @@ func (opts SaveOpts) Save(ctx context.Context, k *kernel.Kernel, w *watchdog.Wat // ENOSPC is a state file error. 
This error can only come from // writing the state file, and not from fs.FileOperations.Fsync // because we wrap those in kernel.TaskSet.flushWritesToFiles. - if err == syserror.ENOSPC { + if linuxerr.Equals(linuxerr.ENOSPC, err) { err = ErrStateFile{err} } @@ -110,7 +110,7 @@ type LoadOpts struct { } // Load loads the given kernel, setting the provided platform and stack. -func (opts LoadOpts) Load(ctx context.Context, k *kernel.Kernel, n inet.Stack, clocks time.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error { +func (opts LoadOpts) Load(ctx context.Context, k *kernel.Kernel, timeReady chan struct{}, n inet.Stack, clocks time.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error { // Open the file. r, m, err := statefile.NewReader(opts.Source, opts.Key) if err != nil { @@ -120,5 +120,5 @@ func (opts LoadOpts) Load(ctx context.Context, k *kernel.Kernel, n inet.Stack, c previousMetadata = m // Restore the Kernel object graph. - return k.LoadFrom(ctx, r, n, clocks, vfsOpts) + return k.LoadFrom(ctx, r, timeReady, n, clocks, vfsOpts) } diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index 1fbbd133c..369541c7a 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -11,6 +11,7 @@ go_library( "futex.go", "linux64_amd64.go", "linux64_arm64.go", + "mmap.go", "open.go", "poll.go", "ptrace.go", @@ -35,7 +36,6 @@ go_library( "//pkg/sentry/socket", "//pkg/sentry/socket/netlink", "//pkg/sentry/syscalls/linux", - "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/strace/clone.go b/pkg/sentry/strace/clone.go index ab1060426..bfb4d7f5c 100644 --- a/pkg/sentry/strace/clone.go +++ b/pkg/sentry/strace/clone.go @@ -15,98 +15,98 @@ package strace import ( - "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi" + "gvisor.dev/gvisor/pkg/abi/linux" ) // CloneFlagSet is the set of clone(2) flags. 
var CloneFlagSet = abi.FlagSet{ { - Flag: unix.CLONE_VM, + Flag: linux.CLONE_VM, Name: "CLONE_VM", }, { - Flag: unix.CLONE_FS, + Flag: linux.CLONE_FS, Name: "CLONE_FS", }, { - Flag: unix.CLONE_FILES, + Flag: linux.CLONE_FILES, Name: "CLONE_FILES", }, { - Flag: unix.CLONE_SIGHAND, + Flag: linux.CLONE_SIGHAND, Name: "CLONE_SIGHAND", }, { - Flag: unix.CLONE_PTRACE, + Flag: linux.CLONE_PTRACE, Name: "CLONE_PTRACE", }, { - Flag: unix.CLONE_VFORK, + Flag: linux.CLONE_VFORK, Name: "CLONE_VFORK", }, { - Flag: unix.CLONE_PARENT, + Flag: linux.CLONE_PARENT, Name: "CLONE_PARENT", }, { - Flag: unix.CLONE_THREAD, + Flag: linux.CLONE_THREAD, Name: "CLONE_THREAD", }, { - Flag: unix.CLONE_NEWNS, + Flag: linux.CLONE_NEWNS, Name: "CLONE_NEWNS", }, { - Flag: unix.CLONE_SYSVSEM, + Flag: linux.CLONE_SYSVSEM, Name: "CLONE_SYSVSEM", }, { - Flag: unix.CLONE_SETTLS, + Flag: linux.CLONE_SETTLS, Name: "CLONE_SETTLS", }, { - Flag: unix.CLONE_PARENT_SETTID, + Flag: linux.CLONE_PARENT_SETTID, Name: "CLONE_PARENT_SETTID", }, { - Flag: unix.CLONE_CHILD_CLEARTID, + Flag: linux.CLONE_CHILD_CLEARTID, Name: "CLONE_CHILD_CLEARTID", }, { - Flag: unix.CLONE_DETACHED, + Flag: linux.CLONE_DETACHED, Name: "CLONE_DETACHED", }, { - Flag: unix.CLONE_UNTRACED, + Flag: linux.CLONE_UNTRACED, Name: "CLONE_UNTRACED", }, { - Flag: unix.CLONE_CHILD_SETTID, + Flag: linux.CLONE_CHILD_SETTID, Name: "CLONE_CHILD_SETTID", }, { - Flag: unix.CLONE_NEWUTS, + Flag: linux.CLONE_NEWUTS, Name: "CLONE_NEWUTS", }, { - Flag: unix.CLONE_NEWIPC, + Flag: linux.CLONE_NEWIPC, Name: "CLONE_NEWIPC", }, { - Flag: unix.CLONE_NEWUSER, + Flag: linux.CLONE_NEWUSER, Name: "CLONE_NEWUSER", }, { - Flag: unix.CLONE_NEWPID, + Flag: linux.CLONE_NEWPID, Name: "CLONE_NEWPID", }, { - Flag: unix.CLONE_NEWNET, + Flag: linux.CLONE_NEWNET, Name: "CLONE_NEWNET", }, { - Flag: unix.CLONE_IO, + Flag: linux.CLONE_IO, Name: "CLONE_IO", }, } diff --git a/pkg/sentry/strace/linux64_amd64.go b/pkg/sentry/strace/linux64_amd64.go index d66befe81..6ce1bb592 100644 --- a/pkg/sentry/strace/linux64_amd64.go +++ b/pkg/sentry/strace/linux64_amd64.go @@ -33,7 +33,7 @@ var linuxAMD64 = SyscallMap{ 6: makeSyscallInfo("lstat", Path, Stat), 7: makeSyscallInfo("poll", PollFDs, Hex, Hex), 8: makeSyscallInfo("lseek", Hex, Hex, Hex), - 9: makeSyscallInfo("mmap", Hex, Hex, Hex, Hex, FD, Hex), + 9: makeSyscallInfo("mmap", Hex, Hex, MmapProt, MmapFlags, FD, Hex), 10: makeSyscallInfo("mprotect", Hex, Hex, Hex), 11: makeSyscallInfo("munmap", Hex, Hex), 12: makeSyscallInfo("brk", Hex), diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go index 1a2d7d75f..ce5594301 100644 --- a/pkg/sentry/strace/linux64_arm64.go +++ b/pkg/sentry/strace/linux64_arm64.go @@ -246,7 +246,7 @@ var linuxARM64 = SyscallMap{ 219: makeSyscallInfo("keyctl", Hex, Hex, Hex, Hex, Hex), 220: makeSyscallInfo("clone", CloneFlags, Hex, Hex, Hex, Hex), 221: makeSyscallInfo("execve", Path, ExecveStringVector, ExecveStringVector), - 222: makeSyscallInfo("mmap", Hex, Hex, Hex, Hex, FD, Hex), + 222: makeSyscallInfo("mmap", Hex, Hex, MmapProt, MmapFlags, FD, Hex), 223: makeSyscallInfo("fadvise64", FD, Hex, Hex, Hex), 224: makeSyscallInfo("swapon", Hex, Hex), 225: makeSyscallInfo("swapoff", Hex), diff --git a/pkg/sentry/strace/mmap.go b/pkg/sentry/strace/mmap.go new file mode 100644 index 000000000..0035be586 --- /dev/null +++ b/pkg/sentry/strace/mmap.go @@ -0,0 +1,92 @@ +// Copyright 2021 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package strace + +import ( + "gvisor.dev/gvisor/pkg/abi" + "gvisor.dev/gvisor/pkg/abi/linux" +) + +// ProtectionFlagSet represents the protection to mmap(2). +var ProtectionFlagSet = abi.FlagSet{ + { + Flag: linux.PROT_READ, + Name: "PROT_READ", + }, + { + Flag: linux.PROT_WRITE, + Name: "PROT_WRITE", + }, + { + Flag: linux.PROT_EXEC, + Name: "PROT_EXEC", + }, +} + +// MmapFlagSet is the set of mmap(2) flags. +var MmapFlagSet = abi.FlagSet{ + { + Flag: linux.MAP_SHARED, + Name: "MAP_SHARED", + }, + { + Flag: linux.MAP_PRIVATE, + Name: "MAP_PRIVATE", + }, + { + Flag: linux.MAP_FIXED, + Name: "MAP_FIXED", + }, + { + Flag: linux.MAP_ANONYMOUS, + Name: "MAP_ANONYMOUS", + }, + { + Flag: linux.MAP_GROWSDOWN, + Name: "MAP_GROWSDOWN", + }, + { + Flag: linux.MAP_DENYWRITE, + Name: "MAP_DENYWRITE", + }, + { + Flag: linux.MAP_EXECUTABLE, + Name: "MAP_EXECUTABLE", + }, + { + Flag: linux.MAP_LOCKED, + Name: "MAP_LOCKED", + }, + { + Flag: linux.MAP_NORESERVE, + Name: "MAP_NORESERVE", + }, + { + Flag: linux.MAP_POPULATE, + Name: "MAP_POPULATE", + }, + { + Flag: linux.MAP_NONBLOCK, + Name: "MAP_NONBLOCK", + }, + { + Flag: linux.MAP_STACK, + Name: "MAP_STACK", + }, + { + Flag: linux.MAP_HUGETLB, + Name: "MAP_HUGETLB", + }, +} diff --git a/pkg/sentry/strace/open.go b/pkg/sentry/strace/open.go index 5769360da..e7c7649f4 100644 --- a/pkg/sentry/strace/open.go +++ b/pkg/sentry/strace/open.go @@ -15,61 +15,61 @@ package strace import ( - "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi" + "gvisor.dev/gvisor/pkg/abi/linux" ) // OpenMode represents the mode to open(2) a file. var OpenMode = abi.ValueSet{ - unix.O_RDWR: "O_RDWR", - unix.O_WRONLY: "O_WRONLY", - unix.O_RDONLY: "O_RDONLY", + linux.O_RDWR: "O_RDWR", + linux.O_WRONLY: "O_WRONLY", + linux.O_RDONLY: "O_RDONLY", } // OpenFlagSet is the set of open(2) flags. 
var OpenFlagSet = abi.FlagSet{ { - Flag: unix.O_APPEND, + Flag: linux.O_APPEND, Name: "O_APPEND", }, { - Flag: unix.O_ASYNC, + Flag: linux.O_ASYNC, Name: "O_ASYNC", }, { - Flag: unix.O_CLOEXEC, + Flag: linux.O_CLOEXEC, Name: "O_CLOEXEC", }, { - Flag: unix.O_CREAT, + Flag: linux.O_CREAT, Name: "O_CREAT", }, { - Flag: unix.O_DIRECT, + Flag: linux.O_DIRECT, Name: "O_DIRECT", }, { - Flag: unix.O_DIRECTORY, + Flag: linux.O_DIRECTORY, Name: "O_DIRECTORY", }, { - Flag: unix.O_EXCL, + Flag: linux.O_EXCL, Name: "O_EXCL", }, { - Flag: unix.O_NOATIME, + Flag: linux.O_NOATIME, Name: "O_NOATIME", }, { - Flag: unix.O_NOCTTY, + Flag: linux.O_NOCTTY, Name: "O_NOCTTY", }, { - Flag: unix.O_NOFOLLOW, + Flag: linux.O_NOFOLLOW, Name: "O_NOFOLLOW", }, { - Flag: unix.O_NONBLOCK, + Flag: linux.O_NONBLOCK, Name: "O_NONBLOCK", }, { @@ -77,18 +77,22 @@ var OpenFlagSet = abi.FlagSet{ Name: "O_PATH", }, { - Flag: unix.O_SYNC, + Flag: linux.O_SYNC, Name: "O_SYNC", }, { - Flag: unix.O_TRUNC, + Flag: linux.O_TMPFILE, + Name: "O_TMPFILE", + }, + { + Flag: linux.O_TRUNC, Name: "O_TRUNC", }, } func open(val uint64) string { - s := OpenMode.Parse(val & unix.O_ACCMODE) - if flags := OpenFlagSet.Parse(val &^ unix.O_ACCMODE); flags != "" { + s := OpenMode.Parse(val & linux.O_ACCMODE) + if flags := OpenFlagSet.Parse(val &^ linux.O_ACCMODE); flags != "" { s += "|" + flags } return s diff --git a/pkg/sentry/strace/signal.go b/pkg/sentry/strace/signal.go index e5b379a20..5afc9525b 100644 --- a/pkg/sentry/strace/signal.go +++ b/pkg/sentry/strace/signal.go @@ -130,8 +130,8 @@ func sigAction(t *kernel.Task, addr hostarch.Addr) string { return "null" } - sa, err := t.CopyInSignalAct(addr) - if err != nil { + var sa linux.SigAction + if _, err := sa.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error copying sigaction: %v)", addr, err) } diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go index ec5d5f846..af7088847 100644 --- a/pkg/sentry/strace/strace.go +++ b/pkg/sentry/strace/strace.go @@ -489,6 +489,10 @@ func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlo output = append(output, epollEvents(t, args[arg].Pointer(), 0 /* numEvents */, uint64(maximumBlobSize))) case SelectFDSet: output = append(output, fdSet(t, int(args[0].Int()), args[arg].Pointer())) + case MmapProt: + output = append(output, ProtectionFlagSet.Parse(uint64(args[arg].Uint()))) + case MmapFlags: + output = append(output, MmapFlagSet.Parse(uint64(args[arg].Uint()))) case Oct: output = append(output, "0o"+strconv.FormatUint(args[arg].Uint64(), 8)) case Hex: diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go index 7e69b9279..5893443a7 100644 --- a/pkg/sentry/strace/syscalls.go +++ b/pkg/sentry/strace/syscalls.go @@ -238,6 +238,12 @@ const ( // EpollEvents is an array of struct epoll_event. It is the events // argument in epoll_wait(2)/epoll_pwait(2). EpollEvents + + // MmapProt is the protection argument in mmap(2). + MmapProt + + // MmapFlags is the flags argument in mmap(2). 
+ MmapFlags ) // defaultFormat is the syscall argument format to use if the actual format is diff --git a/pkg/sentry/syscalls/BUILD b/pkg/sentry/syscalls/BUILD index b8d1bd415..f2c55588f 100644 --- a/pkg/sentry/syscalls/BUILD +++ b/pkg/sentry/syscalls/BUILD @@ -11,6 +11,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", + "//pkg/errors/linuxerr", "//pkg/sentry/arch", "//pkg/sentry/kernel", "//pkg/sentry/kernel/epoll", diff --git a/pkg/sentry/syscalls/epoll.go b/pkg/sentry/syscalls/epoll.go index 3b4d79889..02debfc7e 100644 --- a/pkg/sentry/syscalls/epoll.go +++ b/pkg/sentry/syscalls/epoll.go @@ -18,6 +18,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/epoll" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -163,7 +164,7 @@ func WaitEpoll(t *kernel.Task, fd int32, max int, timeoutInNanos int64) ([]linux } if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { return nil, nil } diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 408a6c422..a2f612f45 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -64,6 +64,7 @@ go_library( "//pkg/abi/linux", "//pkg/bpf", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/log", "//pkg/marshal", diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go index 6eabfd219..165922332 100644 --- a/pkg/sentry/syscalls/linux/error.go +++ b/pkg/sentry/syscalls/linux/error.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -94,13 +95,13 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, errOrig, intr er if errno, ok := syserror.TranslateError(errOrig); ok { translatedErr = errno } - switch translatedErr { - case io.EOF: + switch { + case translatedErr == io.EOF: // EOF is always consumed. If this is a partial read/write // (result != 0), the application will see that, otherwise // they will see 0. return true, nil - case syserror.EFBIG: + case linuxerr.Equals(linuxerr.EFBIG, translatedErr): t := kernel.TaskFromContext(ctx) if t == nil { panic("I/O error should only occur from a context associated with a Task") @@ -113,7 +114,7 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, errOrig, intr er // Simultaneously send a SIGXFSZ per setrlimit(2). t.SendSignal(kernel.SignalInfoNoInfo(linux.SIGXFSZ, t, t)) return true, syserror.EFBIG - case syserror.EINTR: + case linuxerr.Equals(linuxerr.EINTR, translatedErr): // The syscall was interrupted. Return nil if it completed // partially, otherwise return the error code that the syscall // needs (to indicate to the kernel what it should do). @@ -128,21 +129,21 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, errOrig, intr er return true, errOrig } - switch translatedErr { - case syserror.EINTR: + switch { + case linuxerr.Equals(linuxerr.EINTR, translatedErr): // Syscall interrupted, but completed a partial // read/write. Like ErrWouldBlock, since we have a // partial read/write, we consume the error and return // the partial result. 
return true, nil - case syserror.EFAULT: + case linuxerr.Equals(linuxerr.EFAULT, translatedErr): // EFAULT is only shown the user if nothing was // read/written. If we read something (this case), they see // a partial read/write. They will then presumably try again // with an incremented buffer, which will EFAULT with // result == 0. return true, nil - case syserror.EPIPE: + case linuxerr.Equals(linuxerr.EPIPE, translatedErr): // Writes to a pipe or socket will return EPIPE if the other // side is gone. The partial write is returned. EPIPE will be // returned on the next call. @@ -150,15 +151,17 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, errOrig, intr er // TODO(gvisor.dev/issue/161): In some cases SIGPIPE should // also be sent to the application. return true, nil - case syserror.ENOSPC: + case linuxerr.Equals(linuxerr.ENOSPC, translatedErr): // Similar to EPIPE. Return what we wrote this time, and let // ENOSPC be returned on the next call. return true, nil - case syserror.ECONNRESET, syserror.ETIMEDOUT: + case linuxerr.Equals(linuxerr.ECONNRESET, translatedErr): + fallthrough + case linuxerr.Equals(linuxerr.ETIMEDOUT, translatedErr): // For TCP sendfile connections, we may have a reset or timeout. But we // should just return n as the result. return true, nil - case syserror.EWOULDBLOCK: + case linuxerr.Equals(linuxerr.EWOULDBLOCK, translatedErr): // Syscall would block, but completed a partial read/write. // This case should only be returned by IssueIO for nonblocking // files. Since we have a partial read/write, we consume diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 090c5ffcb..1732064ef 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -18,6 +18,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -187,7 +188,7 @@ var AMD64 = &kernel.SyscallTable{ 132: syscalls.Supported("utime", Utime), 133: syscalls.PartiallySupported("mknod", Mknod, "Device creation is not generally supported. 
Only regular file and FIFO creation are supported.", nil), 134: syscalls.Error("uselib", syserror.ENOSYS, "Obsolete", nil), - 135: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil), + 135: syscalls.ErrorWithEvent("personality", linuxerr.EINVAL, "Unable to change personality.", nil), 136: syscalls.ErrorWithEvent("ustat", syserror.ENOSYS, "Needs filesystem support.", nil), 137: syscalls.PartiallySupported("statfs", Statfs, "Depends on the backing file system implementation.", nil), 138: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil), @@ -521,7 +522,7 @@ var ARM64 = &kernel.SyscallTable{ 89: syscalls.CapError("acct", linux.CAP_SYS_PACCT, "", nil), 90: syscalls.Supported("capget", Capget), 91: syscalls.Supported("capset", Capset), - 92: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil), + 92: syscalls.ErrorWithEvent("personality", linuxerr.EINVAL, "Unable to change personality.", nil), 93: syscalls.Supported("exit", Exit), 94: syscalls.Supported("exit_group", ExitGroup), 95: syscalls.Supported("waitid", Waitid), diff --git a/pkg/sentry/syscalls/linux/sigset.go b/pkg/sentry/syscalls/linux/sigset.go index e8c2d8f9e..9dea78085 100644 --- a/pkg/sentry/syscalls/linux/sigset.go +++ b/pkg/sentry/syscalls/linux/sigset.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" @@ -29,7 +30,7 @@ import ( // syscalls are moved into this package, then they can be unexported. func CopyInSigSet(t *kernel.Task, sigSetAddr hostarch.Addr, size uint) (linux.SignalSet, error) { if size != linux.SignalSetSize { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } b := t.CopyScratchBuffer(8) if _, err := t.CopyInBytes(sigSetAddr, b); err != nil { diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go index 70e8569a8..a93fc635b 100644 --- a/pkg/sentry/syscalls/linux/sys_aio.go +++ b/pkg/sentry/syscalls/linux/sys_aio.go @@ -17,6 +17,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -42,7 +43,7 @@ func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca return 0, nil, err } if idIn != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } id, err := t.MemoryManager().NewAIOContext(t, uint32(nrEvents)) @@ -66,7 +67,7 @@ func IoDestroy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys ctx := t.MemoryManager().DestroyAIOContext(t, id) if ctx == nil { // Does not exist. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Drain completed requests amd wait for pending requests until there are no @@ -97,12 +98,12 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S // Sanity check arguments. if minEvents < 0 || minEvents > events { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } ctx, ok := t.MemoryManager().LookupAIOContext(t, id) if !ok { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Setup the timeout. 
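The hunks above replace identity comparisons such as err == syserror.ETIMEDOUT with linuxerr.Equals, which matches an error against a canonical errno value rather than a specific sentinel instance. A minimal sketch of the resulting style, using only the linuxerr identifiers that already appear in this diff; the helper name isPartialWriteRecoverable is hypothetical and not part of the sentry:

package sketch

import "gvisor.dev/gvisor/pkg/errors/linuxerr"

// isPartialWriteRecoverable reports whether a short read/write should be
// surfaced to the application as a partial result, with the error deferred
// to the next call, in the spirit of the error.go switch above.
func isPartialWriteRecoverable(err error) bool {
	switch {
	case linuxerr.Equals(linuxerr.EINTR, err),
		linuxerr.Equals(linuxerr.EPIPE, err),
		linuxerr.Equals(linuxerr.ENOSPC, err),
		linuxerr.Equals(linuxerr.EWOULDBLOCK, err):
		return true
	default:
		return false
	}
}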
@@ -114,7 +115,7 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S return 0, nil, err } if !d.Valid() { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } deadline = t.Kernel().MonotonicClock().Now().Add(d.ToDuration()) haveDeadline = true @@ -134,7 +135,7 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S var err error v, err = waitForRequest(ctx, t, haveDeadline, deadline) if err != nil { - if count > 0 || err == syserror.ETIMEDOUT { + if count > 0 || linuxerr.Equals(linuxerr.ETIMEDOUT, err) { return uintptr(count), nil, nil } return 0, nil, syserror.ConvertIntr(err, syserror.EINTR) @@ -171,7 +172,7 @@ func waitForRequest(ctx *mm.AIOContext, t *kernel.Task, haveDeadline bool, deadl done := ctx.WaitChannel() if done == nil { // Context has been destroyed. - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } if err := t.BlockWithDeadline(done, haveDeadline, deadline); err != nil { return nil, err @@ -184,7 +185,7 @@ func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error) bytes := int(cb.Bytes) if bytes < 0 { // Linux also requires that this field fit in ssize_t. - return usermem.IOSequence{}, syserror.EINVAL + return usermem.IOSequence{}, linuxerr.EINVAL } // Since this I/O will be asynchronous with respect to t's task goroutine, @@ -206,7 +207,7 @@ func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error) default: // Not a supported command. - return usermem.IOSequence{}, syserror.EINVAL + return usermem.IOSequence{}, linuxerr.EINVAL } } @@ -286,7 +287,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr host // Check that it is an eventfd. if _, ok := eventFile.FileOperations.(*eventfd.EventOperations); !ok { // Not an event FD. - return syserror.EINVAL + return linuxerr.EINVAL } } @@ -299,14 +300,14 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr host switch cb.OpCode { case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITE, linux.IOCB_CMD_PWRITEV: if cb.Offset < 0 { - return syserror.EINVAL + return linuxerr.EINVAL } } // Prepare the request. 
ctx, ok := t.MemoryManager().LookupAIOContext(t, id) if !ok { - return syserror.EINVAL + return linuxerr.EINVAL } if err := ctx.Prepare(); err != nil { return err @@ -335,7 +336,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc addr := args[2].Pointer() if nrEvents < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } for i := int32(0); i < nrEvents; i++ { diff --git a/pkg/sentry/syscalls/linux/sys_capability.go b/pkg/sentry/syscalls/linux/sys_capability.go index d3b85e11b..782bcb94f 100644 --- a/pkg/sentry/syscalls/linux/sys_capability.go +++ b/pkg/sentry/syscalls/linux/sys_capability.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -24,7 +25,7 @@ import ( func lookupCaps(t *kernel.Task, tid kernel.ThreadID) (permitted, inheritable, effective auth.CapabilitySet, err error) { if tid < 0 { - err = syserror.EINVAL + err = linuxerr.EINVAL return } if tid > 0 { @@ -97,7 +98,7 @@ func Capget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, err } if dataAddr != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } return 0, nil, nil } @@ -144,6 +145,6 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if _, err := hdr.CopyOut(t, hdrAddr); err != nil { return 0, nil, err } - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } } diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go index 69cbc98d0..daa151bb4 100644 --- a/pkg/sentry/syscalls/linux/sys_epoll.go +++ b/pkg/sentry/syscalls/linux/sys_epoll.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -31,7 +32,7 @@ import ( func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { flags := args[0].Int() if flags & ^linux.EPOLL_CLOEXEC != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } closeOnExec := flags&linux.EPOLL_CLOEXEC != 0 @@ -48,7 +49,7 @@ func EpollCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S size := args[0].Int() if size <= 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } fd, err := syscalls.CreateEpoll(t, false) @@ -101,7 +102,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc mask |= waiter.EventHUp | waiter.EventErr return 0, nil, syscalls.UpdateEpoll(t, epfd, fd, flags, mask, data) default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } } diff --git a/pkg/sentry/syscalls/linux/sys_eventfd.go b/pkg/sentry/syscalls/linux/sys_eventfd.go index 3b4f879e4..7ba9a755e 100644 --- a/pkg/sentry/syscalls/linux/sys_eventfd.go +++ b/pkg/sentry/syscalls/linux/sys_eventfd.go @@ -16,11 +16,11 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/eventfd" - "gvisor.dev/gvisor/pkg/syserror" ) // Eventfd2 implements linux syscall eventfd2(2). 
@@ -30,7 +30,7 @@ func Eventfd2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc allOps := uint(linux.EFD_SEMAPHORE | linux.EFD_NONBLOCK | linux.EFD_CLOEXEC) if flags & ^allOps != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } event := eventfd.New(t, uint64(initVal), flags&linux.EFD_SEMAPHORE != 0) diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 9cd238efd..3d45341f2 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -18,6 +18,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -275,7 +276,7 @@ func mknodAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, mode linux.FileMod default: // "EINVAL - mode requested creation of something other than a // regular file, device special file, FIFO or socket." - mknod(2) - return syserror.EINVAL + return linuxerr.EINVAL } }) } @@ -394,8 +395,8 @@ func createAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, flags uint, mode } var newFile *fs.File - switch err { - case nil: + switch { + case err == nil: // Like sys_open, check for a few things about the // filesystem before trying to get a reference to the // fs.File. The same constraints on Check apply. @@ -418,7 +419,7 @@ func createAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, flags uint, mode return syserror.ConvertIntr(err, syserror.ERESTARTSYS) } defer newFile.DecRef(t) - case syserror.ENOENT: + case linuxerr.Equals(linuxerr.ENOENT, err): // File does not exist. Proceed with creation. // Do we have write permissions on the parent? @@ -527,7 +528,7 @@ func accessAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, mode uint) error // Sanity check the mode. if mode&^(rOK|wOK|xOK) != 0 { - return syserror.EINVAL + return linuxerr.EINVAL } return fileOpOn(t, dirFD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error { @@ -843,7 +844,7 @@ func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC flags := args[2].Uint() if oldfd == newfd { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } oldFile := t.GetFile(oldfd) @@ -905,7 +906,7 @@ func fSetOwn(t *kernel.Task, fd int, file *fs.File, who int32) error { if who < 0 { // Check for overflow before flipping the sign. if who-1 > who { - return syserror.EINVAL + return linuxerr.EINVAL } pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(-who)) a.SetOwnerProcessGroup(t, pg) @@ -976,7 +977,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case 2: sw = fs.SeekEnd default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Compute the lock offset. @@ -995,7 +996,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } off = uattr.Size default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Compute the lock range. 
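EpollCreate1 and Eventfd2 above, like Linkat, Fchownat, MemfdCreate, Mremap, Msync, pipe2 and GetRandom further down, share one validation idiom: any flag bit outside the supported mask is rejected with EINVAL before any work is done. A standalone illustration of that check; checkFlags and allowed are illustrative names, not sentry APIs:

package sketch

import "gvisor.dev/gvisor/pkg/errors/linuxerr"

// checkFlags rejects any bit of flags that is not covered by allowed,
// the same flags &^ allowed != 0 test used by the syscall handlers here.
func checkFlags(flags, allowed uint64) error {
	if flags&^allowed != 0 {
		return linuxerr.EINVAL
	}
	return nil
}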
@@ -1043,7 +1044,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall file.Dirent.Inode.LockCtx.Posix.UnlockRegion(t.FDTable(), rng) return 0, nil, nil default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } case linux.F_GETOWN: return uintptr(fGetOwn(t, file)), nil, nil @@ -1085,7 +1086,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall a.SetOwnerProcessGroup(t, pg) return 0, nil, nil default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } case linux.F_GET_SEALS: val, err := tmpfs.GetSeals(file.Dirent.Inode) @@ -1099,14 +1100,14 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.F_GETPIPE_SZ: sz, ok := file.FileOperations.(fs.FifoSizer) if !ok { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } size, err := sz.FifoSize(t, file) return uintptr(size), nil, err case linux.F_SETPIPE_SZ: sz, ok := file.FileOperations.(fs.FifoSizer) if !ok { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } n, err := sz.SetFifoSize(int64(args[2].Int())) return uintptr(n), nil, err @@ -1118,7 +1119,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, a.SetSignal(linux.Signal(args[2].Int())) default: // Everything else is not yet supported. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } } @@ -1131,7 +1132,7 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys // Note: offset is allowed to be negative. if length < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } file := t.GetFile(fd) @@ -1153,7 +1154,7 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys case linux.POSIX_FADV_DONTNEED: case linux.POSIX_FADV_NOREUSE: default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Sure, whatever. @@ -1178,12 +1179,12 @@ func mkdirAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, mode linux.FileMod // Does this directory exist already? remainingTraversals := uint(linux.MaxSymlinkTraversals) f, err := t.MountNamespace().FindInode(t, root, d, name, &remainingTraversals) - switch err { - case nil: + switch { + case err == nil: // The directory existed. defer f.DecRef(t) return syserror.EEXIST - case syserror.EACCES: + case linuxerr.Equals(linuxerr.EACCES, err): // Permission denied while walking to the directory. return err default: @@ -1236,7 +1237,7 @@ func rmdirAt(t *kernel.Task, dirFD int32, addr hostarch.Addr) error { // dot vs. double dots. switch name { case ".": - return syserror.EINVAL + return linuxerr.EINVAL case "..": return syserror.ENOTEMPTY } @@ -1431,7 +1432,7 @@ func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // Sanity check flags. 
if flags&^(linux.AT_SYMLINK_FOLLOW|linux.AT_EMPTY_PATH) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } resolve := flags&linux.AT_SYMLINK_FOLLOW == linux.AT_SYMLINK_FOLLOW @@ -1464,8 +1465,8 @@ func readlinkAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, bufAddr hostarc } s, err := d.Inode.Readlink(t) - if err == syserror.ENOLINK { - return syserror.EINVAL + if linuxerr.Equals(linuxerr.ENOLINK, err) { + return linuxerr.EINVAL } if err != nil { return err @@ -1557,7 +1558,7 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc length := args[1].Int64() if length < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */) @@ -1565,13 +1566,13 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc return 0, nil, err } if dirPath { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if uint64(length) >= t.ThreadGroup().Limits().Get(limits.FileSize).Cur { - t.SendSignal(&arch.SignalInfo{ + t.SendSignal(&linux.SignalInfo{ Signo: int32(linux.SIGXFSZ), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, }) return 0, nil, syserror.EFBIG } @@ -1583,7 +1584,7 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // In contrast to open(O_TRUNC), truncate(2) is only valid for file // types. if !fs.IsFile(d.Inode.StableAttr) { - return syserror.EINVAL + return linuxerr.EINVAL } // Reject truncation if the access permissions do not allow truncation. @@ -1617,24 +1618,24 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys // Reject truncation if the file flags do not permit this operation. // This is different from truncate(2) above. if !file.Flags().Write { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // In contrast to open(O_TRUNC), truncate(2) is only valid for file // types. Note that this is different from truncate(2) above, where a // directory returns EISDIR. if !fs.IsFile(file.Dirent.Inode.StableAttr) { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if length < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if uint64(length) >= t.ThreadGroup().Limits().Get(limits.FileSize).Cur { - t.SendSignal(&arch.SignalInfo{ + t.SendSignal(&linux.SignalInfo{ Signo: int32(linux.SIGXFSZ), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, }) return 0, nil, syserror.EFBIG } @@ -1673,14 +1674,16 @@ func chown(t *kernel.Task, d *fs.Dirent, uid auth.UID, gid auth.GID) error { if err != nil { return err } + c := t.Credentials() hasCap := d.Inode.CheckCapability(t, linux.CAP_CHOWN) isOwner := uattr.Owner.UID == c.EffectiveKUID + var clearPrivilege bool if uid.Ok() { kuid := c.UserNamespace.MapToKUID(uid) // Valid UID must be supplied if UID is to be changed. if !kuid.Ok() { - return syserror.EINVAL + return linuxerr.EINVAL } // "Only a privileged process (CAP_CHOWN) may change the owner @@ -1693,13 +1696,18 @@ func chown(t *kernel.Task, d *fs.Dirent, uid auth.UID, gid auth.GID) error { return syserror.EPERM } + // The setuid and setgid bits are cleared during a chown. + if uattr.Owner.UID != kuid { + clearPrivilege = true + } + owner.UID = kuid } if gid.Ok() { kgid := c.UserNamespace.MapToKGID(gid) // Valid GID must be supplied if GID is to be changed. 
if !kgid.Ok() { - return syserror.EINVAL + return linuxerr.EINVAL } // "The owner of a file may change the group of the file to any @@ -1711,6 +1719,11 @@ func chown(t *kernel.Task, d *fs.Dirent, uid auth.UID, gid auth.GID) error { return syserror.EPERM } + // The setuid and setgid bits are cleared during a chown. + if uattr.Owner.GID != kgid { + clearPrivilege = true + } + owner.GID = kgid } @@ -1721,10 +1734,14 @@ func chown(t *kernel.Task, d *fs.Dirent, uid auth.UID, gid auth.GID) error { if err := d.Inode.SetOwner(t, d, owner); err != nil { return err } + // Clear privilege bits if needed and they are set. + if clearPrivilege && uattr.Perms.HasSetUIDOrGID() && !fs.IsDir(d.Inode.StableAttr) { + uattr.Perms.DropSetUIDAndMaybeGID() + if !d.Inode.SetPermissions(t, d, uattr.Perms) { + return syserror.EPERM + } + } - // When the owner or group are changed by an unprivileged user, - // chown(2) also clears the set-user-ID and set-group-ID bits, but - // we do not support them. return nil } @@ -1792,7 +1809,7 @@ func Fchownat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc flags := args[4].Int() if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } return 0, nil, chownAt(t, dirFD, addr, flags&linux.AT_SYMLINK_NOFOLLOW == 0, flags&linux.AT_EMPTY_PATH != 0, uid, gid) @@ -1897,7 +1914,7 @@ func utimes(t *kernel.Task, dirFD int32, addr hostarch.Addr, ts fs.TimeSpec, res if addr == 0 && dirFD != linux.AT_FDCWD { if !resolve { // Linux returns EINVAL in this case. See utimes.c. - return syserror.EINVAL + return linuxerr.EINVAL } f := t.GetFile(dirFD) if f == nil { @@ -1980,7 +1997,7 @@ func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, err } if !timespecIsValid(times[0]) || !timespecIsValid(times[1]) { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // If both are UTIME_OMIT, this is a noop. @@ -2015,7 +2032,7 @@ func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } if times[0].Usec >= 1e6 || times[0].Usec < 0 || times[1].Usec >= 1e6 || times[1].Usec < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } ts = fs.TimeSpec{ @@ -2101,7 +2118,7 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys defer file.DecRef(t) if offset < 0 || length <= 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if mode != 0 { t.Kernel().EmitUnimplementedEvent(t) @@ -2124,9 +2141,9 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, syserror.EFBIG } if uint64(size) >= t.ThreadGroup().Limits().Get(limits.FileSize).Cur { - t.SendSignal(&arch.SignalInfo{ + t.SendSignal(&linux.SignalInfo{ Signo: int32(linux.SIGXFSZ), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, }) return 0, nil, syserror.EFBIG } @@ -2191,7 +2208,7 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall file.Dirent.Inode.LockCtx.BSD.UnlockRegion(file, rng) default: // flock(2): EINVAL operation is invalid. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } return 0, nil, nil @@ -2210,7 +2227,7 @@ func MemfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S if flags&^memfdAllFlags != 0 { // Unknown bits in flags. 
- return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } allowSeals := flags&linux.MFD_ALLOW_SEALING != 0 @@ -2221,7 +2238,7 @@ func MemfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S return 0, nil, err } if len(name) > memfdMaxNameLen { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } name = memfdPrefix + name diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go index eeea1613b..a9e9126cc 100644 --- a/pkg/sentry/syscalls/linux/sys_futex.go +++ b/pkg/sentry/syscalls/linux/sys_futex.go @@ -18,6 +18,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -210,7 +211,7 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // WAIT_BITSET uses an absolute timeout which is either // CLOCK_MONOTONIC or CLOCK_REALTIME. if mask == 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } n, err := futexWaitAbsolute(t, clockRealtime, timespec, forever, addr, private, uint32(val), mask) return n, nil, err @@ -224,7 +225,7 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.FUTEX_WAKE_BITSET: if mask == 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if val <= 0 { // The Linux kernel wakes one waiter even if val is @@ -295,7 +296,7 @@ func SetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel length := args[1].SizeT() if length != uint(linux.SizeOfRobustListHead) { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } t.SetRobustList(head) return 0, nil, nil @@ -310,7 +311,7 @@ func GetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel sizeAddr := args[2].Pointer() if tid < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } ot := t diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go index bbba71d8f..355fbd766 100644 --- a/pkg/sentry/syscalls/linux/sys_getdents.go +++ b/pkg/sentry/syscalls/linux/sys_getdents.go @@ -19,6 +19,7 @@ import ( "io" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -38,7 +39,7 @@ func Getdents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc minSize := int(smallestDirent(t.Arch())) if size < minSize { // size is smaller than smallest possible dirent. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } n, err := getdents(t, fd, addr, size, (*dirent).Serialize) @@ -54,7 +55,7 @@ func Getdents64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy minSize := int(smallestDirent64(t.Arch())) if size < minSize { // size is smaller than smallest possible dirent. 
- return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } n, err := getdents(t, fd, addr, size, (*dirent).Serialize64) diff --git a/pkg/sentry/syscalls/linux/sys_identity.go b/pkg/sentry/syscalls/linux/sys_identity.go index a29d307e5..50fcadb58 100644 --- a/pkg/sentry/syscalls/linux/sys_identity.go +++ b/pkg/sentry/syscalls/linux/sys_identity.go @@ -15,10 +15,10 @@ package linux import ( + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/syserror" ) const ( @@ -142,7 +142,7 @@ func Setresgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys func Getgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { size := int(args[0].Int()) if size < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } kgids := t.Credentials().ExtraKGIDs // "If size is zero, list is not modified, but the total number of @@ -151,7 +151,7 @@ func Getgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return uintptr(len(kgids)), nil, nil } if size < len(kgids) { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } gids := make([]auth.GID, len(kgids)) for i, kgid := range kgids { @@ -167,7 +167,7 @@ func Getgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys func Setgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { size := args[0].Int() if size < 0 || size > maxNGroups { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if size == 0 { return 0, nil, t.SetExtraGIDs(nil) diff --git a/pkg/sentry/syscalls/linux/sys_inotify.go b/pkg/sentry/syscalls/linux/sys_inotify.go index cf47bb9dd..48c8dbdca 100644 --- a/pkg/sentry/syscalls/linux/sys_inotify.go +++ b/pkg/sentry/syscalls/linux/sys_inotify.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/anon" @@ -30,7 +31,7 @@ func InotifyInit1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. flags := int(args[0].Int()) if flags&^allFlags != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } dirent := fs.NewDirent(t, anon.NewInode(t), "inotify") @@ -72,7 +73,7 @@ func fdToInotify(t *kernel.Task, fd int32) (*fs.Inotify, *fs.File, error) { if !ok { // Not an inotify fd. file.DecRef(t) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } return ino, file, nil @@ -91,7 +92,7 @@ func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kern // "EINVAL: The given event mask contains no valid events." 
// -- inotify_add_watch(2) if validBits := mask & linux.ALL_INOTIFY_BITS; validBits == 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } ino, file, err := fdToInotify(t, fd) diff --git a/pkg/sentry/syscalls/linux/sys_lseek.go b/pkg/sentry/syscalls/linux/sys_lseek.go index 0046347cb..c16c63ecc 100644 --- a/pkg/sentry/syscalls/linux/sys_lseek.go +++ b/pkg/sentry/syscalls/linux/sys_lseek.go @@ -15,6 +15,7 @@ package linux import ( + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -44,7 +45,7 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case 2: sw = fs.SeekEnd default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } offset, serr := file.Seek(t, sw, offset) diff --git a/pkg/sentry/syscalls/linux/sys_membarrier.go b/pkg/sentry/syscalls/linux/sys_membarrier.go index 63ee5d435..4b67f2536 100644 --- a/pkg/sentry/syscalls/linux/sys_membarrier.go +++ b/pkg/sentry/syscalls/linux/sys_membarrier.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" @@ -29,7 +30,7 @@ func Membarrier(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy switch cmd { case linux.MEMBARRIER_CMD_QUERY: if flags != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } var supportedCommands uintptr if t.Kernel().Platform.HaveGlobalMemoryBarrier() { @@ -46,10 +47,10 @@ func Membarrier(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return supportedCommands, nil, nil case linux.MEMBARRIER_CMD_GLOBAL, linux.MEMBARRIER_CMD_GLOBAL_EXPEDITED, linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED: if flags != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if !t.Kernel().Platform.HaveGlobalMemoryBarrier() { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if cmd == linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED && !t.MemoryManager().IsMembarrierPrivateEnabled() { return 0, nil, syserror.EPERM @@ -57,28 +58,28 @@ func Membarrier(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, t.Kernel().Platform.GlobalMemoryBarrier() case linux.MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED: if flags != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if !t.Kernel().Platform.HaveGlobalMemoryBarrier() { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // no-op return 0, nil, nil case linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED: if flags != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if !t.Kernel().Platform.HaveGlobalMemoryBarrier() { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } t.MemoryManager().EnableMembarrierPrivate() return 0, nil, nil case linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED_RSEQ: if flags&^linux.MEMBARRIER_CMD_FLAG_CPU != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if !t.RSeqAvailable() { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if !t.MemoryManager().IsMembarrierRSeqEnabled() { return 0, nil, syserror.EPERM @@ -88,16 +89,16 @@ func Membarrier(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, t.Kernel().Platform.PreemptAllCPUs() case linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_RSEQ: if flags != 0 { - return 0, nil, 
syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if !t.RSeqAvailable() { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } t.MemoryManager().EnableMembarrierRSeq() return 0, nil, nil default: // Probably a command we don't implement. t.Kernel().EmitUnimplementedEvent(t) - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } } diff --git a/pkg/sentry/syscalls/linux/sys_mempolicy.go b/pkg/sentry/syscalls/linux/sys_mempolicy.go index 6d27f4292..62ec3e27f 100644 --- a/pkg/sentry/syscalls/linux/sys_mempolicy.go +++ b/pkg/sentry/syscalls/linux/sys_mempolicy.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -43,7 +44,7 @@ func copyInNodemask(t *kernel.Task, addr hostarch.Addr, maxnode uint32) (uint64, // maxnode-1, not maxnode, as the number of bits. bits := maxnode - 1 if bits > hostarch.PageSize*8 { // also handles overflow from maxnode == 0 - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if bits == 0 { return 0, nil @@ -58,12 +59,12 @@ func copyInNodemask(t *kernel.Task, addr hostarch.Addr, maxnode uint32) (uint64, // Check that only allowed bits in the first unsigned long in the nodemask // are set. if val&^allowedNodemask != 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Check that all remaining bits in the nodemask are 0. for i := 8; i < len(buf); i++ { if buf[i] != 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } } return val, nil @@ -74,7 +75,7 @@ func copyOutNodemask(t *kernel.Task, addr hostarch.Addr, maxnode uint32, val uin // bits. bits := maxnode - 1 if bits > hostarch.PageSize*8 { // also handles overflow from maxnode == 0 - return syserror.EINVAL + return linuxerr.EINVAL } if bits == 0 { return nil @@ -110,7 +111,7 @@ func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. flags := args[4].Uint() if flags&^(linux.MPOL_F_NODE|linux.MPOL_F_ADDR|linux.MPOL_F_MEMS_ALLOWED) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } nodeFlag := flags&linux.MPOL_F_NODE != 0 addrFlag := flags&linux.MPOL_F_ADDR != 0 @@ -119,7 +120,7 @@ func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. // "EINVAL: The value specified by maxnode is less than the number of node // IDs supported by the system." - get_mempolicy(2) if nodemask != 0 && maxnode < maxNodes { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // "If flags specifies MPOL_F_MEMS_ALLOWED [...], the mode argument is @@ -130,7 +131,7 @@ func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. // "It is not permitted to combine MPOL_F_MEMS_ALLOWED with either // MPOL_F_ADDR or MPOL_F_NODE." if nodeFlag || addrFlag { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if err := copyOutNodemask(t, nodemask, maxnode, allowedNodemask); err != nil { return 0, nil, err @@ -184,7 +185,7 @@ func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. // mm/mempolicy.c:do_get_mempolicy() doesn't special-case NULL; it will // just (usually) fail to find a VMA at address 0 and return EFAULT. if addr != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // "If flags is specified as 0, then information about the calling thread's @@ -198,7 +199,7 @@ func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. 
policy, nodemaskVal := t.NumaPolicy() if nodeFlag { if policy&^linux.MPOL_MODE_FLAGS != linux.MPOL_INTERLEAVE { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } policy = linux.MPOL_DEFAULT // maxNodes == 1 } @@ -240,7 +241,7 @@ func Mbind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall flags := args[5].Uint() if flags&^linux.MPOL_MF_VALID != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // "If MPOL_MF_MOVE_ALL is passed in flags ... [the] calling thread must be // privileged (CAP_SYS_NICE) to use this flag." - mbind(2) @@ -264,11 +265,11 @@ func copyInMempolicyNodemask(t *kernel.Task, modeWithFlags linux.NumaPolicy, nod mode := linux.NumaPolicy(modeWithFlags &^ linux.MPOL_MODE_FLAGS) if flags == linux.MPOL_MODE_FLAGS { // Can't specify both mode flags simultaneously. - return 0, 0, syserror.EINVAL + return 0, 0, linuxerr.EINVAL } if mode < 0 || mode >= linux.MPOL_MAX { // Must specify a valid mode. - return 0, 0, syserror.EINVAL + return 0, 0, linuxerr.EINVAL } var nodemaskVal uint64 @@ -285,22 +286,22 @@ func copyInMempolicyNodemask(t *kernel.Task, modeWithFlags linux.NumaPolicy, nod // "nodemask must be specified as NULL." - set_mempolicy(2). This is inaccurate; // Linux allows a nodemask to be specified, as long as it is empty. if nodemaskVal != 0 { - return 0, 0, syserror.EINVAL + return 0, 0, linuxerr.EINVAL } case linux.MPOL_BIND, linux.MPOL_INTERLEAVE: // These require a non-empty nodemask. if nodemaskVal == 0 { - return 0, 0, syserror.EINVAL + return 0, 0, linuxerr.EINVAL } case linux.MPOL_PREFERRED: // This permits an empty nodemask, as long as no flags are set. if nodemaskVal == 0 && flags != 0 { - return 0, 0, syserror.EINVAL + return 0, 0, linuxerr.EINVAL } case linux.MPOL_LOCAL: // This requires an empty nodemask and no flags set ... if nodemaskVal != 0 || flags != 0 { - return 0, 0, syserror.EINVAL + return 0, 0, linuxerr.EINVAL } // ... and is implemented as MPOL_PREFERRED. mode = linux.MPOL_PREFERRED diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go index 70da0707d..74279c82b 100644 --- a/pkg/sentry/syscalls/linux/sys_mmap.go +++ b/pkg/sentry/syscalls/linux/sys_mmap.go @@ -18,13 +18,13 @@ import ( "bytes" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/syserror" - - "gvisor.dev/gvisor/pkg/hostarch" ) // Brk implements linux syscall brk(2). @@ -51,7 +51,7 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC // Require exactly one of MAP_PRIVATE and MAP_SHARED. if private == shared { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } opts := memmap.MMapOpts{ @@ -132,7 +132,7 @@ func Mremap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal newAddr := args[4].Pointer() if flags&^(linux.MREMAP_MAYMOVE|linux.MREMAP_FIXED) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } mayMove := flags&linux.MREMAP_MAYMOVE != 0 fixed := flags&linux.MREMAP_FIXED != 0 @@ -147,7 +147,7 @@ func Mremap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal case !mayMove && fixed: // "If MREMAP_FIXED is specified, then MREMAP_MAYMOVE must also be // specified." 
- mremap(2) - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } rv, err := t.MemoryManager().MRemap(t, oldAddr, oldSize, newSize, mm.MRemapOpts{ @@ -178,7 +178,7 @@ func Madvise(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // "The Linux implementation requires that the address addr be // page-aligned, and allows length to be zero." - madvise(2) if addr.RoundDown() != addr { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if length == 0 { return 0, nil, nil @@ -186,7 +186,7 @@ func Madvise(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // Not explicitly stated: length need not be page-aligned. lenAddr, ok := hostarch.Addr(length).RoundUp() if !ok { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } length = uint64(lenAddr) @@ -217,7 +217,7 @@ func Madvise(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca return 0, nil, syserror.EPERM default: // If adv is not a valid value tell the caller. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } } @@ -228,7 +228,7 @@ func Mincore(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca vec := args[2].Pointer() if addr != addr.RoundDown() { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // "The length argument need not be a multiple of the page size, but since // residency information is returned for whole pages, length is effectively @@ -265,11 +265,11 @@ func Msync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // semantics that are (currently) equivalent to specifying MS_ASYNC." - // msync(2) if flags&^(linux.MS_ASYNC|linux.MS_SYNC|linux.MS_INVALIDATE) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } sync := flags&linux.MS_SYNC != 0 if sync && flags&linux.MS_ASYNC != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } err := t.MemoryManager().MSync(t, addr, uint64(length), mm.MSyncOpts{ Sync: sync, @@ -295,7 +295,7 @@ func Mlock2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal flags := args[2].Int() if flags&^(linux.MLOCK_ONFAULT) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } mode := memmap.MLockEager @@ -318,7 +318,7 @@ func Mlockall(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc flags := args[0].Int() if flags&^(linux.MCL_CURRENT|linux.MCL_FUTURE|linux.MCL_ONFAULT) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } mode := memmap.MLockEager diff --git a/pkg/sentry/syscalls/linux/sys_mount.go b/pkg/sentry/syscalls/linux/sys_mount.go index 864d2138c..8bf4e9f06 100644 --- a/pkg/sentry/syscalls/linux/sys_mount.go +++ b/pkg/sentry/syscalls/linux/sys_mount.go @@ -16,12 +16,12 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" - - "gvisor.dev/gvisor/pkg/hostarch" ) // Mount implements Linux syscall mount(2). @@ -83,7 +83,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // unknown or unsupported flags are passed. Since we don't implement // everything, we fail explicitly on flags that are unimplemented. 
if flags&(unsupportedOps|unsupportedFlags) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } rsys, ok := fs.FindFilesystem(fsType) @@ -107,7 +107,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall rootInode, err := rsys.Mount(t, sourcePath, superFlags, data, nil) if err != nil { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if err := fileOpOn(t, linux.AT_FDCWD, targetPath, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error { @@ -130,7 +130,7 @@ func Umount2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca const unsupported = linux.MNT_FORCE | linux.MNT_EXPIRE if flags&unsupported != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } path, _, err := copyInPath(t, addr, false /* allowEmpty */) diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go index d95034347..5925c2263 100644 --- a/pkg/sentry/syscalls/linux/sys_pipe.go +++ b/pkg/sentry/syscalls/linux/sys_pipe.go @@ -16,13 +16,13 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" - "gvisor.dev/gvisor/pkg/syserror" ) // LINT.IfChange @@ -30,7 +30,7 @@ import ( // pipe2 implements the actual system call with flags. func pipe2(t *kernel.Task, addr hostarch.Addr, flags uint) (uintptr, error) { if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize) diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go index da548a14a..f2056d850 100644 --- a/pkg/sentry/syscalls/linux/sys_poll.go +++ b/pkg/sentry/syscalls/linux/sys_poll.go @@ -18,6 +18,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -128,7 +129,7 @@ func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time. // Wait for a notification. timeout, err = t.BlockWithTimeout(ch, !forever, timeout) if err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = nil } return timeout, 0, err @@ -157,7 +158,7 @@ func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time. // CopyInPollFDs copies an array of struct pollfd unless nfds exceeds the max. func CopyInPollFDs(t *kernel.Task, addr hostarch.Addr, nfds uint) ([]linux.PollFD, error) { if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } pfd := make([]linux.PollFD, nfds) @@ -217,7 +218,7 @@ func CopyInFDSet(t *kernel.Task, addr hostarch.Addr, nBytes, nBitsInLastPartialB func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs hostarch.Addr, timeout time.Duration) (uintptr, error) { if nfds < 0 || nfds > fileCap { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Calculate the size of the fd sets (one bit per fd). 
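pollBlock above treats an expired timeout as success with no ready events (ETIMEDOUT is swallowed), while poll, Ppoll, Select and Pselect below turn an interrupt into a transparently restarted syscall. A condensed sketch of that post-processing, assuming only the linuxerr and syserror identifiers already used in this diff; normalizeWaitError is a hypothetical helper name:

package sketch

import (
	"gvisor.dev/gvisor/pkg/errors/linuxerr"
	"gvisor.dev/gvisor/pkg/syserror"
)

// normalizeWaitError mirrors how the poll family reports a blocking wait:
// ETIMEDOUT means "nothing became ready" rather than failure, and EINTR is
// rewritten so the kernel restarts the call without returning to userspace.
func normalizeWaitError(err error) error {
	switch {
	case err == nil, linuxerr.Equals(linuxerr.ETIMEDOUT, err):
		return nil
	case linuxerr.Equals(linuxerr.EINTR, err):
		return syserror.ERESTARTNOHAND
	default:
		return err
	}
}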
@@ -404,7 +405,7 @@ func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) { func poll(t *kernel.Task, pfdAddr hostarch.Addr, nfds uint, timeout time.Duration) (uintptr, error) { remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout) // On an interrupt poll(2) is restarted with the remaining timeout. - if err == syserror.EINTR { + if linuxerr.Equals(linuxerr.EINTR, err) { t.SetSyscallRestartBlock(&pollRestartBlock{ pfdAddr: pfdAddr, nfds: nfds, @@ -463,7 +464,7 @@ func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // // Note that this means that if err is nil but copyErr is not, copyErr is // ignored. This is consistent with Linux. - if err == syserror.EINTR && copyErr == nil { + if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil { err = syserror.ERESTARTNOHAND } return n, nil, err @@ -485,7 +486,7 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, err } if timeval.Sec < 0 || timeval.Usec < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } timeout = time.Duration(timeval.ToNsecCapped()) } @@ -493,7 +494,7 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout) copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr) // See comment in Ppoll. - if err == syserror.EINTR && copyErr == nil { + if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil { err = syserror.ERESTARTNOHAND } return n, nil, err @@ -538,7 +539,7 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout) copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr) // See comment in Ppoll. - if err == syserror.EINTR && copyErr == nil { + if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil { err = syserror.ERESTARTNOHAND } return n, nil, err diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go index 9890dd946..534f1e632 100644 --- a/pkg/sentry/syscalls/linux/sys_prctl.go +++ b/pkg/sentry/syscalls/linux/sys_prctl.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -38,7 +39,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.PR_SET_PDEATHSIG: sig := linux.Signal(args[1].Int()) if sig != 0 && !sig.IsValid() { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } t.SetParentDeathSignal(sig) return 0, nil, nil @@ -69,7 +70,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall d = mm.UserDumpable default: // N.B. Userspace may not pass SUID_DUMP_ROOT. 
- return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } t.MemoryManager().SetDumpability(d) return 0, nil, nil @@ -90,7 +91,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } else if val == 1 { t.SetKeepCaps(true) } else { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } return 0, nil, nil @@ -98,7 +99,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.PR_SET_NAME: addr := args[1].Pointer() name, err := t.CopyInString(addr, linux.TASK_COMM_LEN-1) - if err != nil && err != syserror.ENAMETOOLONG { + if err != nil && !linuxerr.Equals(linuxerr.ENAMETOOLONG, err) { return 0, nil, err } t.SetName(name) @@ -155,12 +156,12 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall t.Kernel().EmitUnimplementedEvent(t) fallthrough default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } case linux.PR_SET_NO_NEW_PRIVS: if args[1].Int() != 1 || args[2].Int() != 0 || args[3].Int() != 0 || args[4].Int() != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // PR_SET_NO_NEW_PRIVS is assumed to always be set. // See kernel.Task.updateCredsForExecLocked. @@ -168,7 +169,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.PR_GET_NO_NEW_PRIVS: if args[1].Int() != 0 || args[2].Int() != 0 || args[3].Int() != 0 || args[4].Int() != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } return 1, nil, nil @@ -184,7 +185,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall default: tracer := t.PIDNamespace().TaskWithID(kernel.ThreadID(pid)) if tracer == nil { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } t.SetYAMAException(tracer) return 0, nil, nil @@ -193,7 +194,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.PR_SET_SECCOMP: if args[1].Int() != linux.SECCOMP_MODE_FILTER { // Unsupported mode. 
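The PR_SET_NAME case above deliberately tolerates ENAMETOOLONG: the comm string is silently truncated to TASK_COMM_LEN-1 and only other copy-in failures abort the call. A hedged sketch of just that predicate (acceptTruncated is an illustrative helper, not a function in the tree):

    package example

    import "gvisor.dev/gvisor/pkg/errors/linuxerr"

    // acceptTruncated reports whether a copy-in error should be ignored because
    // the string was merely truncated, which PR_SET_NAME tolerates.
    func acceptTruncated(err error) bool {
        return err == nil || linuxerr.Equals(linuxerr.ENAMETOOLONG, err)
    }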
- return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } return 0, nil, seccomp(t, linux.SECCOMP_SET_MODE_FILTER, 0, args[2].Pointer()) @@ -204,7 +205,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.PR_CAPBSET_READ: cp := linux.Capability(args[1].Uint64()) if !cp.Ok() { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } var rv uintptr if auth.CapabilitySetOf(cp)&t.Credentials().BoundingCaps != 0 { @@ -215,7 +216,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.PR_CAPBSET_DROP: cp := linux.Capability(args[1].Uint64()) if !cp.Ok() { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } return 0, nil, t.DropBoundingCapability(cp) @@ -240,7 +241,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall t.Kernel().EmitUnimplementedEvent(t) fallthrough default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } return 0, nil, nil diff --git a/pkg/sentry/syscalls/linux/sys_random.go b/pkg/sentry/syscalls/linux/sys_random.go index ae545f80f..ec6c80de5 100644 --- a/pkg/sentry/syscalls/linux/sys_random.go +++ b/pkg/sentry/syscalls/linux/sys_random.go @@ -18,14 +18,14 @@ import ( "io" "math" + "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" - - "gvisor.dev/gvisor/pkg/hostarch" ) const ( @@ -47,7 +47,7 @@ func GetRandom(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys // Flags are checked for validity but otherwise ignored. See above. if flags & ^(_GRND_NONBLOCK|_GRND_RANDOM) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if length > math.MaxInt32 { diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go index 13e5e3a51..4064467a9 100644 --- a/pkg/sentry/syscalls/linux/sys_read.go +++ b/pkg/sentry/syscalls/linux/sys_read.go @@ -18,6 +18,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -58,7 +59,7 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC // Check that the size is legitimate. si := int(size) if si < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get the destination of the read. @@ -93,18 +94,18 @@ func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys // Check that the size is valid. if int(size) < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Check that the offset is legitimate and does not overflow. if offset < 0 || offset+int64(size) < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Return EINVAL; if the underlying file type does not support readahead, // then Linux will return EINVAL to indicate as much. In the future, we // may extend this function to actually support readahead hints. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Pread64 implements linux syscall pread64(2). @@ -122,7 +123,7 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // Check that the offset is legitimate and does not overflow. 
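The readahead and pread64 hunks repeat one range check worth calling out: the offset must be non-negative and offset+size must not wrap around int64. A condensed sketch of that check, assuming a size taken from args[].SizeT() as in the handlers above (checkOffset is an illustrative name):

    package example

    import "gvisor.dev/gvisor/pkg/errors/linuxerr"

    // checkOffset applies the pread64/readahead range check: reject negative
    // offsets and offset+size combinations that overflow int64.
    func checkOffset(offset int64, size uint) error {
        if offset < 0 || offset+int64(size) < 0 {
            return linuxerr.EINVAL
        }
        return nil
    }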
if offset < 0 || offset+int64(size) < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Is reading at an offset supported? @@ -138,7 +139,7 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // Check that the size is legitimate. si := int(size) if si < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get the destination of the read. @@ -199,7 +200,7 @@ func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // Check that the offset is legitimate. if offset < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Is reading at an offset supported? @@ -248,7 +249,7 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // Check that the offset is legitimate. if offset < -1 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Is reading at an offset supported? @@ -331,7 +332,7 @@ func readv(t *kernel.Task, f *fs.File, dst usermem.IOSequence) (int64, error) { // Wait for a notification that we should retry. if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = syserror.ErrWouldBlock } break diff --git a/pkg/sentry/syscalls/linux/sys_rlimit.go b/pkg/sentry/syscalls/linux/sys_rlimit.go index e64246d57..ca78c2ab2 100644 --- a/pkg/sentry/syscalls/linux/sys_rlimit.go +++ b/pkg/sentry/syscalls/linux/sys_rlimit.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -129,7 +130,7 @@ func Getrlimit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys resource, ok := limits.FromLinuxResource[int(args[0].Int())] if !ok { // Return err; unknown limit. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } addr := args[1].Pointer() rlim, err := newRlimit(t) @@ -150,7 +151,7 @@ func Setrlimit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys resource, ok := limits.FromLinuxResource[int(args[0].Int())] if !ok { // Return err; unknown limit. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } addr := args[1].Pointer() rlim, err := newRlimit(t) @@ -170,7 +171,7 @@ func Prlimit64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys resource, ok := limits.FromLinuxResource[int(args[1].Int())] if !ok { // Return err; unknown limit. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } newRlimAddr := args[2].Pointer() oldRlimAddr := args[3].Pointer() @@ -185,7 +186,7 @@ func Prlimit64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } if tid < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } ot := t if tid > 0 { diff --git a/pkg/sentry/syscalls/linux/sys_rseq.go b/pkg/sentry/syscalls/linux/sys_rseq.go index 90db10ea6..5fe196647 100644 --- a/pkg/sentry/syscalls/linux/sys_rseq.go +++ b/pkg/sentry/syscalls/linux/sys_rseq.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" @@ -43,6 +44,6 @@ func RSeq(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC return 0, nil, t.ClearRSeq(addr, length, signature) default: // Unknown flag. 
- return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } } diff --git a/pkg/sentry/syscalls/linux/sys_rusage.go b/pkg/sentry/syscalls/linux/sys_rusage.go index ac5c98a54..a689abcc9 100644 --- a/pkg/sentry/syscalls/linux/sys_rusage.go +++ b/pkg/sentry/syscalls/linux/sys_rusage.go @@ -16,11 +16,11 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/usage" - "gvisor.dev/gvisor/pkg/syserror" ) func getrusage(t *kernel.Task, which int32) linux.Rusage { @@ -76,7 +76,7 @@ func Getrusage(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys addr := args[1].Pointer() if which != linux.RUSAGE_SELF && which != linux.RUSAGE_CHILDREN && which != linux.RUSAGE_THREAD { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } ru := getrusage(t, which) diff --git a/pkg/sentry/syscalls/linux/sys_sched.go b/pkg/sentry/syscalls/linux/sys_sched.go index bfcf44b6f..b5e7b70b5 100644 --- a/pkg/sentry/syscalls/linux/sys_sched.go +++ b/pkg/sentry/syscalls/linux/sys_sched.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" @@ -38,10 +39,10 @@ func SchedGetparam(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel pid := args[0].Int() param := args[1].Pointer() if param == 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if pid < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if pid != 0 && t.PIDNamespace().TaskWithID(kernel.ThreadID(pid)) == nil { return 0, nil, syserror.ESRCH @@ -58,7 +59,7 @@ func SchedGetparam(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel func SchedGetscheduler(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { pid := args[0].Int() if pid < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if pid != 0 && t.PIDNamespace().TaskWithID(kernel.ThreadID(pid)) == nil { return 0, nil, syserror.ESRCH @@ -72,20 +73,20 @@ func SchedSetscheduler(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ke policy := args[1].Int() param := args[2].Pointer() if pid < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if policy != onlyScheduler { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if pid != 0 && t.PIDNamespace().TaskWithID(kernel.ThreadID(pid)) == nil { return 0, nil, syserror.ESRCH } var r SchedParam if _, err := r.CopyIn(t, param); err != nil { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if r.schedPriority != onlyPriority { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } return 0, nil, nil } diff --git a/pkg/sentry/syscalls/linux/sys_seccomp.go b/pkg/sentry/syscalls/linux/sys_seccomp.go index e16d6ff3f..b0dc84b8d 100644 --- a/pkg/sentry/syscalls/linux/sys_seccomp.go +++ b/pkg/sentry/syscalls/linux/sys_seccomp.go @@ -17,10 +17,10 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/syserror" ) // userSockFprog is equivalent to Linux's struct sock_fprog on amd64. 
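The sched_* handlers above keep two distinct failure modes: a malformed pid is EINVAL, while a well-formed pid that names no task is ESRCH (which this change leaves on syserror). A condensed sketch of that ordering, with a lookup callback standing in for t.PIDNamespace().TaskWithID (validatePID is an illustrative name):

    package example

    import (
        "gvisor.dev/gvisor/pkg/errors/linuxerr"
        "gvisor.dev/gvisor/pkg/syserror"
    )

    // validatePID rejects negative pids with EINVAL and unknown pids with
    // ESRCH; pid 0 means "the calling task" and always passes.
    func validatePID(pid int32, exists func(int32) bool) error {
        if pid < 0 {
            return linuxerr.EINVAL
        }
        if pid != 0 && !exists(pid) {
            return syserror.ESRCH
        }
        return nil
    }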
@@ -44,7 +44,7 @@ func seccomp(t *kernel.Task, mode, flags uint64, addr hostarch.Addr) error { // We only support SECCOMP_SET_MODE_FILTER at the moment. if mode != linux.SECCOMP_SET_MODE_FILTER { // Unsupported mode. - return syserror.EINVAL + return linuxerr.EINVAL } tsync := flags&linux.SECCOMP_FILTER_FLAG_TSYNC != 0 @@ -52,7 +52,7 @@ func seccomp(t *kernel.Task, mode, flags uint64, addr hostarch.Addr) error { // The only flag we support now is SECCOMP_FILTER_FLAG_TSYNC. if flags&^linux.SECCOMP_FILTER_FLAG_TSYNC != 0 { // Unsupported flag. - return syserror.EINVAL + return linuxerr.EINVAL } var fprog userSockFprog @@ -66,7 +66,7 @@ func seccomp(t *kernel.Task, mode, flags uint64, addr hostarch.Addr) error { compiledFilter, err := bpf.Compile(filter) if err != nil { t.Debugf("Invalid seccomp-bpf filter: %v", err) - return syserror.EINVAL + return linuxerr.EINVAL } return t.AppendSyscallFilter(compiledFilter, tsync) diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go index c84260080..d12a4303b 100644 --- a/pkg/sentry/syscalls/linux/sys_sem.go +++ b/pkg/sentry/syscalls/linux/sys_sem.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -61,10 +62,10 @@ func Semtimedop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy nsops := args[2].SizeT() timespecAddr := args[3].Pointer() if nsops <= 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if nsops > opsMax { - return 0, nil, syserror.E2BIG + return 0, nil, linuxerr.E2BIG } ops := make([]linux.Sembuf, nsops) @@ -77,11 +78,11 @@ func Semtimedop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, err } if timeout.Sec < 0 || timeout.Nsec < 0 || timeout.Nsec >= 1e9 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if err := semTimedOp(t, id, ops, true, timeout.ToDuration()); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { return 0, nil, syserror.EAGAIN } return 0, nil, err @@ -96,10 +97,10 @@ func Semop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall nsops := args[2].SizeT() if nsops <= 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if nsops > opsMax { - return 0, nil, syserror.E2BIG + return 0, nil, linuxerr.E2BIG } ops := make([]linux.Sembuf, nsops) @@ -113,7 +114,7 @@ func semTimedOp(t *kernel.Task, id int32, ops []linux.Sembuf, haveTimeout bool, set := t.IPCNamespace().SemaphoreRegistry().FindByID(id) if set == nil { - return syserror.EINVAL + return linuxerr.EINVAL } creds := auth.CredentialsFromContext(t) pid := t.Kernel().GlobalInit().PIDNamespace().IDOfThreadGroup(t.ThreadGroup()) @@ -232,7 +233,7 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return uintptr(semid), nil, err default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } } @@ -246,17 +247,17 @@ func ipcSet(t *kernel.Task, id int32, uid auth.UID, gid auth.GID, perms fs.FileP r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) if set == nil { - return syserror.EINVAL + return linuxerr.EINVAL } creds := auth.CredentialsFromContext(t) kuid := creds.UserNamespace.MapToKUID(uid) if !kuid.Ok() { - return syserror.EINVAL + return linuxerr.EINVAL } kgid := creds.UserNamespace.MapToKGID(gid) if !kgid.Ok() { - return syserror.EINVAL + return 
linuxerr.EINVAL } owner := fs.FileOwner{UID: kuid, GID: kgid} return set.Change(t, creds, owner, perms) @@ -266,7 +267,7 @@ func ipcStat(t *kernel.Task, id int32) (*linux.SemidDS, error) { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) if set == nil { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } creds := auth.CredentialsFromContext(t) return set.GetStat(creds) @@ -276,7 +277,7 @@ func semStat(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByIndex(index) if set == nil { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } creds := auth.CredentialsFromContext(t) ds, err := set.GetStat(creds) @@ -289,7 +290,7 @@ func semStat(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) { func semStatAny(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) { set := t.IPCNamespace().SemaphoreRegistry().FindByIndex(index) if set == nil { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } creds := auth.CredentialsFromContext(t) ds, err := set.GetStatAny(creds) @@ -303,7 +304,7 @@ func setVal(t *kernel.Task, id int32, num int32, val int16) error { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) if set == nil { - return syserror.EINVAL + return linuxerr.EINVAL } creds := auth.CredentialsFromContext(t) pid := t.Kernel().GlobalInit().PIDNamespace().IDOfThreadGroup(t.ThreadGroup()) @@ -314,7 +315,7 @@ func setValAll(t *kernel.Task, id int32, array hostarch.Addr) error { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) if set == nil { - return syserror.EINVAL + return linuxerr.EINVAL } vals := make([]uint16, set.Size()) if _, err := primitive.CopyUint16SliceIn(t, array, vals); err != nil { @@ -329,7 +330,7 @@ func getVal(t *kernel.Task, id int32, num int32) (int16, error) { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) if set == nil { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } creds := auth.CredentialsFromContext(t) return set.GetVal(num, creds) @@ -339,7 +340,7 @@ func getValAll(t *kernel.Task, id int32, array hostarch.Addr) error { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) if set == nil { - return syserror.EINVAL + return linuxerr.EINVAL } creds := auth.CredentialsFromContext(t) vals, err := set.GetValAll(creds) @@ -354,7 +355,7 @@ func getPID(t *kernel.Task, id int32, num int32) (int32, error) { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) if set == nil { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } creds := auth.CredentialsFromContext(t) gpid, err := set.GetPID(num, creds) @@ -373,7 +374,7 @@ func getZCnt(t *kernel.Task, id int32, num int32) (uint16, error) { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) if set == nil { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } creds := auth.CredentialsFromContext(t) return set.CountZeroWaiters(num, creds) @@ -383,7 +384,7 @@ func getNCnt(t *kernel.Task, id int32, num int32) (uint16, error) { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) if set == nil { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } creds := auth.CredentialsFromContext(t) return set.CountNegativeWaiters(num, creds) diff --git a/pkg/sentry/syscalls/linux/sys_shm.go b/pkg/sentry/syscalls/linux/sys_shm.go index 584064143..3e3a952ce 100644 --- a/pkg/sentry/syscalls/linux/sys_shm.go +++ b/pkg/sentry/syscalls/linux/sys_shm.go @@ -16,10 +16,10 @@ package linux import ( 
"gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/shm" - "gvisor.dev/gvisor/pkg/syserror" ) // Shmget implements shmget(2). @@ -51,7 +51,7 @@ func findSegment(t *kernel.Task, id shm.ID) (*shm.Shm, error) { segment := r.FindByID(id) if segment == nil { // No segment with provided id. - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } return segment, nil } @@ -64,7 +64,7 @@ func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall segment, err := findSegment(t, id) if err != nil { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } defer segment.DecRef(t) @@ -106,7 +106,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal case linux.IPC_STAT: segment, err := findSegment(t, id) if err != nil { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } defer segment.DecRef(t) @@ -130,7 +130,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // Remaining commands refer to a specific segment. segment, err := findSegment(t, id) if err != nil { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } defer segment.DecRef(t) @@ -155,6 +155,6 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, nil default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } } diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go index 53b12dc41..4d659e5cf 100644 --- a/pkg/sentry/syscalls/linux/sys_signal.go +++ b/pkg/sentry/syscalls/linux/sys_signal.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -84,13 +85,13 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if !mayKill(t, target, sig) { return 0, nil, syserror.EPERM } - info := &arch.SignalInfo{ + info := &linux.SignalInfo{ Signo: int32(sig), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, } info.SetPID(int32(target.PIDNamespace().IDOfTask(t))) info.SetUID(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow())) - if err := target.SendGroupSignal(info); err != syserror.ESRCH { + if err := target.SendGroupSignal(info); !linuxerr.Equals(linuxerr.ESRCH, err) { return 0, nil, err } } @@ -123,14 +124,14 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC // depend on the iteration order. We at least implement the // semantics documented by the man page: "On success (at least // one signal was sent), zero is returned." - info := &arch.SignalInfo{ + info := &linux.SignalInfo{ Signo: int32(sig), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, } info.SetPID(int32(tg.PIDNamespace().IDOfTask(t))) info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) err := tg.SendSignal(info) - if err == syserror.ESRCH { + if linuxerr.Equals(linuxerr.ESRCH, err) { // ESRCH is ignored because it means the task // exited while we were iterating. 
This is a // race which would not normally exist on @@ -167,14 +168,14 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC continue } - info := &arch.SignalInfo{ + info := &linux.SignalInfo{ Signo: int32(sig), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, } info.SetPID(int32(tg.PIDNamespace().IDOfTask(t))) info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) // See note above regarding ESRCH race above. - if err := tg.SendSignal(info); err != syserror.ESRCH { + if err := tg.SendSignal(info); !linuxerr.Equals(linuxerr.ESRCH, err) { lastErr = err } } @@ -184,10 +185,10 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC } } -func tkillSigInfo(sender, receiver *kernel.Task, sig linux.Signal) *arch.SignalInfo { - info := &arch.SignalInfo{ +func tkillSigInfo(sender, receiver *kernel.Task, sig linux.Signal) *linux.SignalInfo { + info := &linux.SignalInfo{ Signo: int32(sig), - Code: arch.SignalInfoTkill, + Code: linux.SI_TKILL, } info.SetPID(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup()))) info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) @@ -202,7 +203,7 @@ func Tkill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // N.B. Inconsistent with man page, linux actually rejects calls with // tid <=0 by EINVAL. This isn't the same for all signal calls. if tid <= 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } target := t.PIDNamespace().TaskWithID(tid) @@ -225,7 +226,7 @@ func Tgkill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // N.B. Inconsistent with man page, linux actually rejects calls with // tgid/tid <=0 by EINVAL. This isn't the same for all signal calls. 
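Kill and the tkill/tgkill helper above now build the siginfo from abi/linux rather than the arch package: the struct is linux.SignalInfo and the si_code constants are linux.SI_USER and linux.SI_TKILL. A sketch of the construction, with plain pid/uid arguments standing in for the PID-namespace and user-namespace translation the real handlers perform:

    package example

    import "gvisor.dev/gvisor/pkg/abi/linux"

    // userSignalInfo builds the siginfo that kill(2) delivers: SI_USER plus
    // the sender's pid and uid as seen by the receiving task.
    func userSignalInfo(sig linux.Signal, pid, uid int32) *linux.SignalInfo {
        info := &linux.SignalInfo{
            Signo: int32(sig),
            Code:  linux.SI_USER,
        }
        info.SetPID(pid)
        info.SetUID(uid)
        return info
    }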
if tgid <= 0 || tid <= 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } targetTG := t.PIDNamespace().ThreadGroupWithID(tgid) @@ -248,23 +249,23 @@ func RtSigaction(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S sigsetsize := args[3].SizeT() if sigsetsize != linux.SignalSetSize { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } - var newactptr *arch.SignalAct + var newactptr *linux.SigAction if newactarg != 0 { - newact, err := t.CopyInSignalAct(newactarg) - if err != nil { + var newact linux.SigAction + if _, err := newact.CopyIn(t, newactarg); err != nil { return 0, nil, err } newactptr = &newact } - oldact, err := t.ThreadGroup().SetSignalAct(sig, newactptr) + oldact, err := t.ThreadGroup().SetSigAction(sig, newactptr) if err != nil { return 0, nil, err } if oldactarg != 0 { - if err := t.CopyOutSignalAct(oldactarg, &oldact); err != nil { + if _, err := oldact.CopyOut(t, oldactarg); err != nil { return 0, nil, err } } @@ -291,7 +292,7 @@ func RtSigprocmask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel sigsetsize := args[3].SizeT() if sigsetsize != linux.SignalSetSize { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } oldmask := t.SignalMask() if setaddr != 0 { @@ -308,7 +309,7 @@ func RtSigprocmask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel case linux.SIG_SETMASK: t.SetSignalMask(mask) default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } } if oldaddr != 0 { @@ -325,13 +326,12 @@ func Sigaltstack(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S alt := t.SignalStack() if oldaddr != 0 { - if err := t.CopyOutSignalStack(oldaddr, &alt); err != nil { + if _, err := alt.CopyOut(t, oldaddr); err != nil { return 0, nil, err } } if setaddr != 0 { - alt, err := t.CopyInSignalStack(setaddr) - if err != nil { + if _, err := alt.CopyIn(t, setaddr); err != nil { return 0, nil, err } // The signal stack cannot be changed if the task is currently @@ -378,7 +378,7 @@ func RtSigtimedwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne return 0, nil, err } if !d.Valid() { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } timeout = time.Duration(d.ToNsecCapped()) } else { @@ -410,7 +410,7 @@ func RtSigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne // We must ensure that the Signo is set (Linux overrides this in the // same way), and that the code is in the allowed set. This same logic // appears below in RtSigtgqueueinfo and should be kept in sync. - var info arch.SignalInfo + var info linux.SignalInfo if _, err := info.CopyIn(t, infoAddr); err != nil { return 0, nil, err } @@ -426,7 +426,7 @@ func RtSigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne // If the sender is not the receiver, it can't use si_codes used by the // kernel or SI_TKILL. - if (info.Code >= 0 || info.Code == arch.SignalInfoTkill) && target != t { + if (info.Code >= 0 || info.Code == linux.SI_TKILL) && target != t { return 0, nil, syserror.EPERM } @@ -434,7 +434,7 @@ func RtSigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne return 0, nil, syserror.EPERM } - if err := target.SendGroupSignal(&info); err != syserror.ESRCH { + if err := target.SendGroupSignal(&info); !linuxerr.Equals(linuxerr.ESRCH, err) { return 0, nil, err } } @@ -450,11 +450,11 @@ func RtTgsigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ker // N.B. 
Inconsistent with man page, linux actually rejects calls with // tgid/tid <=0 by EINVAL. This isn't the same for all signal calls. if tgid <= 0 || tid <= 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Copy in the info. See RtSigqueueinfo above. - var info arch.SignalInfo + var info linux.SignalInfo if _, err := info.CopyIn(t, infoAddr); err != nil { return 0, nil, err } @@ -469,7 +469,7 @@ func RtTgsigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ker // If the sender is not the receiver, it can't use si_codes used by the // kernel or SI_TKILL. - if (info.Code >= 0 || info.Code == arch.SignalInfoTkill) && target != t { + if (info.Code >= 0 || info.Code == linux.SI_TKILL) && target != t { return 0, nil, syserror.EPERM } @@ -525,7 +525,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset hostarch.Addr, sigsetsize u // Always check for valid flags, even if not creating. if flags&^(linux.SFD_NONBLOCK|linux.SFD_CLOEXEC) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Is this a change to an existing signalfd? @@ -545,7 +545,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset hostarch.Addr, sigsetsize u } // Not a signalfd. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Create a new file. diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index e07917613..6638ad60f 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -18,6 +18,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" @@ -117,7 +118,7 @@ type multipleMessageHeader64 struct { // from the untrusted address space range. func CaptureAddress(t *kernel.Task, addr hostarch.Addr, addrlen uint32) ([]byte, error) { if addrlen > maxAddrLen { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } addrBuf := make([]byte, addrlen) @@ -139,7 +140,7 @@ func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr h } if int32(bufLen) < 0 { - return syserror.EINVAL + return linuxerr.EINVAL } // Write the length unconditionally. @@ -173,7 +174,7 @@ func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // Check and initialize the flags. if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Create the new socket. @@ -205,7 +206,7 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // Check and initialize the flags. if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } fileFlags := fs.SettableFileFlags{ @@ -277,7 +278,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr, flags int) (uintptr, error) { // Check that no unsupported flags are passed in. if flags & ^(linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Get socket from the file descriptor. @@ -305,7 +306,7 @@ func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr, if peerRequested { // NOTE(magi): Linux does not give you an error if it can't // write the data back out so neither do we. 
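RtSigaction and Sigaltstack above drop the arch-specific CopyInSignalAct/CopyOutSignalAct helpers in favour of the marshalling methods on linux.SigAction and the signal stack type, with the task itself as the copy context. A minimal sketch of the copy-in side, mirroring the shape used in the hunk (copyInSigAction is an illustrative wrapper, not a helper in the tree):

    package example

    import (
        "gvisor.dev/gvisor/pkg/abi/linux"
        "gvisor.dev/gvisor/pkg/hostarch"
        "gvisor.dev/gvisor/pkg/sentry/kernel"
    )

    // copyInSigAction reads a struct sigaction from user memory using the
    // marshalling methods generated for linux.SigAction.
    func copyInSigAction(t *kernel.Task, addr hostarch.Addr) (linux.SigAction, error) {
        var act linux.SigAction
        _, err := act.CopyIn(t, addr)
        return act, err
    }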
- if err := writeAddress(t, peer, peerLen, addr, addrLen); err == syserror.EINVAL { + if err := writeAddress(t, peer, peerLen, addr, addrLen); linuxerr.Equals(linuxerr.EINVAL, err) { return 0, err } } @@ -421,7 +422,7 @@ func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc switch how { case linux.SHUT_RD, linux.SHUT_WR, linux.SHUT_RDWR: default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } return 0, nil, s.Shutdown(t, int(how)).ToError() @@ -454,7 +455,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, err } if optLen < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Call syscall implementation then copy both value and value len out. @@ -530,10 +531,10 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } if optLen < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if optLen > maxOptLen { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } buf := t.CopyScratchBuffer(int(optLen)) if _, err := t.CopyInBytes(optValAddr, buf); err != nil { @@ -612,7 +613,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if t.Arch().Width() != 8 { // We only handle 64-bit for now. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get socket from the file descriptor. @@ -630,7 +631,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // Reject flags that we don't handle yet. if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if file.Flags().NonBlocking { @@ -660,7 +661,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if t.Arch().Width() != 8 { // We only handle 64-bit for now. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if vlen > linux.UIO_MAXIOV { @@ -669,7 +670,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Reject flags that we don't handle yet. if flags & ^(baseRecvFlags|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get socket from the file descriptor. @@ -697,7 +698,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc return 0, nil, err } if !ts.Valid() { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } deadline = t.Kernel().MonotonicClock().Now().Add(ts.ToDuration()) haveDeadline = true @@ -829,12 +830,12 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr hostarch.Addr, flags // recvfrom and recv syscall handlers. func recvFrom(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags int32, namePtr hostarch.Addr, nameLenPtr hostarch.Addr) (uintptr, error) { if int(bufLen) < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Reject flags that we don't handle yet. if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CONFIRM) != 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Get socket from the file descriptor. @@ -907,7 +908,7 @@ func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if t.Arch().Width() != 8 { // We only handle 64-bit for now. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get socket from the file descriptor. 
@@ -925,7 +926,7 @@ func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // Reject flags that we don't handle yet. if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if file.Flags().NonBlocking { @@ -945,7 +946,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if t.Arch().Width() != 8 { // We only handle 64-bit for now. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if vlen > linux.UIO_MAXIOV { @@ -967,7 +968,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Reject flags that we don't handle yet. if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if file.Flags().NonBlocking { @@ -1073,7 +1074,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr hostar func sendTo(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags int32, namePtr hostarch.Addr, nameLen uint32) (uintptr, error) { bl := int(bufLen) if bl < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Get socket from the file descriptor. diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index 134051124..88bee61ef 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -27,7 +28,7 @@ import ( // doSplice implements a blocking splice operation. func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonBlocking bool) (int64, error) { if opts.Length < 0 || opts.SrcStart < 0 || opts.DstStart < 0 || (opts.SrcStart+opts.Length < 0) { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if opts.Length == 0 { return 0, nil @@ -125,13 +126,13 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Verify that the outfile Append flag is not set. if outFile.Flags().Append { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Verify that we have a regular infile. This is a requirement; the // same check appears in Linux (fs/splice.c:splice_direct_to_actor). if !fs.IsRegular(inFile.Dirent.Inode.StableAttr) { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } var ( @@ -190,7 +191,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // Check for invalid flags. if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get files. 
@@ -230,7 +231,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } if outOffset != 0 { if !outFile.Flags().Pwrite { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } var offset int64 @@ -248,7 +249,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } if inOffset != 0 { if !inFile.Flags().Pread { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } var offset int64 @@ -267,10 +268,10 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // We may not refer to the same pipe; otherwise it's a continuous loop. if inFileAttr.InodeID == outFileAttr.InodeID { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Splice data. @@ -298,7 +299,7 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo // Check for invalid flags. if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get files. @@ -316,12 +317,12 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo // All files must be pipes. if !fs.IsPipe(inFile.Dirent.Inode.StableAttr) || !fs.IsPipe(outFile.Dirent.Inode.StableAttr) { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // We may not refer to the same pipe; see above. if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // The operation is non-blocking if anything is non-blocking. diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go index 2338ba44b..103b13c10 100644 --- a/pkg/sentry/syscalls/linux/sys_stat.go +++ b/pkg/sentry/syscalls/linux/sys_stat.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -139,13 +140,13 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall statxAddr := args[4].Pointer() if mask&linux.STATX__RESERVED != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if flags&^(linux.AT_SYMLINK_NOFOLLOW|linux.AT_EMPTY_PATH|linux.AT_STATX_SYNC_TYPE) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if flags&linux.AT_STATX_SYNC_TYPE == linux.AT_STATX_SYNC_TYPE { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } path, dirPath, err := copyInPath(t, pathAddr, flags&linux.AT_EMPTY_PATH != 0) diff --git a/pkg/sentry/syscalls/linux/sys_sync.go b/pkg/sentry/syscalls/linux/sys_sync.go index 5ebd4461f..3f0e6c02e 100644 --- a/pkg/sentry/syscalls/linux/sys_sync.go +++ b/pkg/sentry/syscalls/linux/sys_sync.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -86,13 +87,13 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel uflags := args[3].Uint() if offset < 0 || offset+nbytes < offset { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if uflags&^(linux.SYNC_FILE_RANGE_WAIT_BEFORE| linux.SYNC_FILE_RANGE_WRITE| linux.SYNC_FILE_RANGE_WAIT_AFTER) != 
0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if nbytes == 0 { diff --git a/pkg/sentry/syscalls/linux/sys_syslog.go b/pkg/sentry/syscalls/linux/sys_syslog.go index 40c8bb061..ba372f9e3 100644 --- a/pkg/sentry/syscalls/linux/sys_syslog.go +++ b/pkg/sentry/syscalls/linux/sys_syslog.go @@ -15,6 +15,7 @@ package linux import ( + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" @@ -40,7 +41,7 @@ func Syslog(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal switch command { case _SYSLOG_ACTION_READ_ALL: if size < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if size > logBufLen { size = logBufLen diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go index 3185ea527..d99dd5131 100644 --- a/pkg/sentry/syscalls/linux/sys_thread.go +++ b/pkg/sentry/syscalls/linux/sys_thread.go @@ -19,6 +19,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -112,7 +113,7 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr host } if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } atEmptyPath := flags&linux.AT_EMPTY_PATH != 0 if !atEmptyPath && len(pathname) == 0 { @@ -260,7 +261,7 @@ func parseCommonWaitOptions(wopts *kernel.WaitOptions, options int) error { wopts.NonCloneTasks = true wopts.CloneTasks = true default: - return syserror.EINVAL + return linuxerr.EINVAL } if options&linux.WCONTINUED != 0 { wopts.Events |= kernel.EventGroupContinue @@ -277,7 +278,7 @@ func parseCommonWaitOptions(wopts *kernel.WaitOptions, options int) error { // wait4 waits for the given child process to exit. func wait4(t *kernel.Task, pid int, statusAddr hostarch.Addr, options int, rusageAddr hostarch.Addr) (uintptr, error) { if options&^(linux.WNOHANG|linux.WUNTRACED|linux.WCONTINUED|linux.WNOTHREAD|linux.WALL|linux.WCLONE) != 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } wopts := kernel.WaitOptions{ Events: kernel.EventExit | kernel.EventTraceeStop, @@ -358,10 +359,10 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal rusageAddr := args[4].Pointer() if options&^(linux.WNOHANG|linux.WEXITED|linux.WSTOPPED|linux.WCONTINUED|linux.WNOWAIT|linux.WNOTHREAD|linux.WALL|linux.WCLONE) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if options&(linux.WEXITED|linux.WSTOPPED|linux.WCONTINUED) == 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } wopts := kernel.WaitOptions{ Events: kernel.EventTraceeStop, @@ -374,7 +375,7 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal case linux.P_PGID: wopts.SpecificPGID = kernel.ProcessGroupID(id) default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if err := parseCommonWaitOptions(&wopts, options); err != nil { @@ -398,7 +399,7 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // out the fields it would set for a successful waitid in this case // as well. 
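Waitid above applies two separate option checks before doing any work: no bits outside the supported set, and at least one of WEXITED/WSTOPPED/WCONTINUED must be present. A condensed sketch of both checks using the constants from abi/linux (checkWaitidOptions is an illustrative name):

    package example

    import (
        "gvisor.dev/gvisor/pkg/abi/linux"
        "gvisor.dev/gvisor/pkg/errors/linuxerr"
    )

    // checkWaitidOptions rejects unknown option bits and option sets that name
    // no wait event, matching the two EINVAL paths in Waitid.
    func checkWaitidOptions(options int) error {
        const known = linux.WNOHANG | linux.WEXITED | linux.WSTOPPED |
            linux.WCONTINUED | linux.WNOWAIT | linux.WNOTHREAD | linux.WALL | linux.WCLONE
        if options&^known != 0 {
            return linuxerr.EINVAL
        }
        if options&(linux.WEXITED|linux.WSTOPPED|linux.WCONTINUED) == 0 {
            return linuxerr.EINVAL
        }
        return nil
    }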
if infop != 0 { - var si arch.SignalInfo + var si linux.SignalInfo _, err = si.CopyOut(t, infop) } } @@ -413,7 +414,7 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if infop == 0 { return 0, nil, nil } - si := arch.SignalInfo{ + si := linux.SignalInfo{ Signo: int32(linux.SIGCHLD), } si.SetPID(int32(wr.TID)) @@ -423,24 +424,24 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal s := unix.WaitStatus(wr.Status) switch { case s.Exited(): - si.Code = arch.CLD_EXITED + si.Code = linux.CLD_EXITED si.SetStatus(int32(s.ExitStatus())) case s.Signaled(): - si.Code = arch.CLD_KILLED + si.Code = linux.CLD_KILLED si.SetStatus(int32(s.Signal())) case s.CoreDump(): - si.Code = arch.CLD_DUMPED + si.Code = linux.CLD_DUMPED si.SetStatus(int32(s.Signal())) case s.Stopped(): if wr.Event == kernel.EventTraceeStop { - si.Code = arch.CLD_TRAPPED + si.Code = linux.CLD_TRAPPED si.SetStatus(int32(s.TrapCause())) } else { - si.Code = arch.CLD_STOPPED + si.Code = linux.CLD_STOPPED si.SetStatus(int32(s.StopSignal())) } case s.Continued(): - si.Code = arch.CLD_CONTINUED + si.Code = linux.CLD_CONTINUED si.SetStatus(int32(linux.SIGCONT)) default: t.Warningf("waitid got incomprehensible wait status %d", s) @@ -528,7 +529,7 @@ func SchedGetaffinity(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ker // in an array of "unsigned long" so the buffer needs to // be a multiple of the word size. if size&(t.Arch().Width()-1) > 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } var task *kernel.Task @@ -545,7 +546,7 @@ func SchedGetaffinity(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ker // The buffer needs to be big enough to hold a cpumask with // all possible cpus. if size < mask.Size() { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } _, err := t.CopyOutBytes(maskAddr, mask) @@ -594,7 +595,7 @@ func Setpgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca } tg = ot.ThreadGroup() if tg.Leader() != ot { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Setpgid only operates on child threadgroups. @@ -609,7 +610,7 @@ func Setpgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if pgid == 0 { pgid = defaultPGID } else if pgid < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // If the pgid is the same as the group, then create a new one. Otherwise, @@ -712,7 +713,7 @@ func Getpriority(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S // PRIO_USER and PRIO_PGRP have no further implementation yet. return 0, nil, nil default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } } @@ -754,7 +755,7 @@ func Setpriority(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S // PRIO_USER and PRIO_PGRP have no further implementation yet. 
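The waitid switch above now takes its si_code values from abi/linux (CLD_EXITED, CLD_KILLED, and so on) instead of the arch package. A simplified, slightly reordered sketch of that mapping: core dumps are tested before plain signals so CLD_DUMPED is reachable, and the ptrace CLD_TRAPPED distinction is omitted (childCode is an illustrative helper):

    package example

    import (
        "golang.org/x/sys/unix"

        "gvisor.dev/gvisor/pkg/abi/linux"
    )

    // childCode translates a wait status into the SIGCHLD si_code constants
    // defined in abi/linux.
    func childCode(s unix.WaitStatus) int32 {
        switch {
        case s.Exited():
            return linux.CLD_EXITED
        case s.Signaled() && s.CoreDump():
            return linux.CLD_DUMPED
        case s.Signaled():
            return linux.CLD_KILLED
        case s.Stopped():
            return linux.CLD_STOPPED
        case s.Continued():
            return linux.CLD_CONTINUED
        default:
            return 0
        }
    }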
return 0, nil, nil default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } return 0, nil, nil diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go index 83b777bbd..d75bb9c4f 100644 --- a/pkg/sentry/syscalls/linux/sys_time.go +++ b/pkg/sentry/syscalls/linux/sys_time.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -75,7 +76,7 @@ func ClockGetres(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S } if _, err := getClock(t, clockID); err != nil { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if addr == 0 { @@ -94,12 +95,12 @@ type cpuClocker interface { func getClock(t *kernel.Task, clockID int32) (ktime.Clock, error) { if clockID < 0 { if !isValidCPUClock(clockID) { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } targetTask := targetTask(t, clockID) if targetTask == nil { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } var target cpuClocker @@ -116,7 +117,7 @@ func getClock(t *kernel.Task, clockID int32) (ktime.Clock, error) { // CPUCLOCK_SCHED is approximated by CPUCLOCK_PROF. return target.CPUClock(), nil default: - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } } @@ -138,7 +139,7 @@ func getClock(t *kernel.Task, clockID int32) (ktime.Clock, error) { case linux.CLOCK_THREAD_CPUTIME_ID: return t.CPUClock(), nil default: - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } } @@ -180,21 +181,21 @@ func Time(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC // // +stateify savable type clockNanosleepRestartBlock struct { - c ktime.Clock - duration time.Duration - rem hostarch.Addr + c ktime.Clock + end ktime.Time + rem hostarch.Addr } // Restart implements kernel.SyscallRestartBlock.Restart. func (n *clockNanosleepRestartBlock) Restart(t *kernel.Task) (uintptr, error) { - return 0, clockNanosleepFor(t, n.c, n.duration, n.rem) + return 0, clockNanosleepUntil(t, n.c, n.end, n.rem, true) } // clockNanosleepUntil blocks until a specified time. // // If blocking is interrupted, the syscall is restarted with the original // arguments. -func clockNanosleepUntil(t *kernel.Task, c ktime.Clock, ts linux.Timespec) error { +func clockNanosleepUntil(t *kernel.Task, c ktime.Clock, end ktime.Time, rem hostarch.Addr, needRestartBlock bool) error { notifier, tchan := ktime.NewChannelNotifier() timer := ktime.NewTimer(c, notifier) @@ -202,43 +203,22 @@ func clockNanosleepUntil(t *kernel.Task, c ktime.Clock, ts linux.Timespec) error timer.Swap(ktime.Setting{ Period: 0, Enabled: true, - Next: ktime.FromTimespec(ts), + Next: end, }) err := t.BlockWithTimer(nil, tchan) timer.Destroy() - // Did we just block until the timeout happened? - if err == syserror.ETIMEDOUT { - return nil - } - - return syserror.ConvertIntr(err, syserror.ERESTARTNOHAND) -} - -// clockNanosleepFor blocks for a specified duration. -// -// If blocking is interrupted, the syscall is restarted with the remaining -// duration timeout. -func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem hostarch.Addr) error { - timer, start, tchan := ktime.After(c, dur) - - err := t.BlockWithTimer(nil, tchan) - - after := c.Now() - - timer.Destroy() - - switch err { - case syserror.ETIMEDOUT: + switch { + case linuxerr.Equals(linuxerr.ETIMEDOUT, err): // Slept for entire timeout. 
return nil - case syserror.ErrInterrupted: + case err == syserror.ErrInterrupted: // Interrupted. - remaining := dur - after.Sub(start) - if remaining < 0 { - remaining = time.Duration(0) + remaining := end.Sub(c.Now()) + if remaining <= 0 { + return nil } // Copy out remaining time. @@ -248,14 +228,16 @@ func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem hos return err } } - - // Arrange for a restart with the remaining duration. - t.SetSyscallRestartBlock(&clockNanosleepRestartBlock{ - c: c, - duration: remaining, - rem: rem, - }) - return syserror.ERESTART_RESTARTBLOCK + if needRestartBlock { + // Arrange for a restart with the remaining duration. + t.SetSyscallRestartBlock(&clockNanosleepRestartBlock{ + c: c, + end: end, + rem: rem, + }) + return syserror.ERESTART_RESTARTBLOCK + } + return syserror.ERESTARTNOHAND default: panic(fmt.Sprintf("Impossible BlockWithTimer error %v", err)) } @@ -272,13 +254,14 @@ func Nanosleep(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } if !ts.Valid() { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Just like linux, we cap the timeout with the max number that int64 can // represent which is roughly 292 years. dur := time.Duration(ts.ToNsecCapped()) * time.Nanosecond - return 0, nil, clockNanosleepFor(t, t.Kernel().MonotonicClock(), dur, rem) + c := t.Kernel().MonotonicClock() + return 0, nil, clockNanosleepUntil(t, c, c.Now().Add(dur), rem, true) } // ClockNanosleep implements linux syscall clock_nanosleep(2). @@ -294,7 +277,7 @@ func ClockNanosleep(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne } if !req.Valid() { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Only allow clock constants also allowed by Linux. @@ -302,7 +285,7 @@ func ClockNanosleep(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne if clockID != linux.CLOCK_REALTIME && clockID != linux.CLOCK_MONOTONIC && clockID != linux.CLOCK_PROCESS_CPUTIME_ID { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } } @@ -312,11 +295,11 @@ func ClockNanosleep(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne } if flags&linux.TIMER_ABSTIME != 0 { - return 0, nil, clockNanosleepUntil(t, c, req) + return 0, nil, clockNanosleepUntil(t, c, ktime.FromTimespec(req), 0, false) } dur := time.Duration(req.ToNsecCapped()) * time.Nanosecond - return 0, nil, clockNanosleepFor(t, c, dur, rem) + return 0, nil, clockNanosleepUntil(t, c, c.Now().Add(dur), rem, true) } // Gettimeofday implements linux syscall gettimeofday(2). 
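The sys_time.go rewrite above folds clockNanosleepFor into clockNanosleepUntil: the sleep is tracked as an absolute deadline, so restarting after an interrupt only has to recompute end minus now rather than carry a shrinking duration. A simplified, stand-alone illustration of that bookkeeping using the standard library clock (the real code works against the sentry's ktime.Clock, not time.Now):

    package example

    import "time"

    // remainingSleep returns how much of a sleep ending at the absolute time
    // end is still outstanding; zero means the deadline already passed and the
    // sleep can complete successfully, as in the interrupted case above.
    func remainingSleep(end time.Time) time.Duration {
        remaining := time.Until(end)
        if remaining < 0 {
            return 0
        }
        return remaining
    }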
diff --git a/pkg/sentry/syscalls/linux/sys_timerfd.go b/pkg/sentry/syscalls/linux/sys_timerfd.go index cadd9d348..a8e88b814 100644 --- a/pkg/sentry/syscalls/linux/sys_timerfd.go +++ b/pkg/sentry/syscalls/linux/sys_timerfd.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/timerfd" @@ -30,7 +31,7 @@ func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel flags := args[1].Int() if flags&^(linux.TFD_CLOEXEC|linux.TFD_NONBLOCK) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } var c ktime.Clock @@ -40,7 +41,7 @@ func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel case linux.CLOCK_MONOTONIC, linux.CLOCK_BOOTTIME: c = t.Kernel().MonotonicClock() default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } f := timerfd.NewFile(t, c) defer f.DecRef(t) @@ -66,7 +67,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne oldValAddr := args[3].Pointer() if flags&^(linux.TFD_TIMER_ABSTIME) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } f := t.GetFile(fd) @@ -77,7 +78,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne tf, ok := f.FileOperations.(*timerfd.TimerOperations) if !ok { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } var newVal linux.Itimerspec @@ -111,7 +112,7 @@ func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne tf, ok := f.FileOperations.(*timerfd.TimerOperations) if !ok { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } tm, s := tf.GetTime() diff --git a/pkg/sentry/syscalls/linux/sys_tls_amd64.go b/pkg/sentry/syscalls/linux/sys_tls_amd64.go index 6ddd30d5c..32272c267 100644 --- a/pkg/sentry/syscalls/linux/sys_tls_amd64.go +++ b/pkg/sentry/syscalls/linux/sys_tls_amd64.go @@ -18,6 +18,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -48,7 +49,7 @@ func ArchPrctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys t.Kernel().EmitUnimplementedEvent(t) fallthrough default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } return 0, nil, nil diff --git a/pkg/sentry/syscalls/linux/sys_utsname.go b/pkg/sentry/syscalls/linux/sys_utsname.go index 66c5974f5..7fffb189e 100644 --- a/pkg/sentry/syscalls/linux/sys_utsname.go +++ b/pkg/sentry/syscalls/linux/sys_utsname.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" @@ -60,7 +61,7 @@ func Setdomainname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel return 0, nil, syserror.EPERM } if size < 0 || size > linux.UTSLen { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } name, err := t.CopyInString(nameAddr, int(size)) @@ -82,7 +83,7 @@ func Sethostname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S return 0, nil, syserror.EPERM } if size < 0 || size > linux.UTSLen { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } name := make([]byte, size) diff --git a/pkg/sentry/syscalls/linux/sys_write.go 
b/pkg/sentry/syscalls/linux/sys_write.go index 95bfe6606..998b5fde6 100644 --- a/pkg/sentry/syscalls/linux/sys_write.go +++ b/pkg/sentry/syscalls/linux/sys_write.go @@ -18,6 +18,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -58,7 +59,7 @@ func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Check that the size is legitimate. si := int(size) if si < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get the source of the write. @@ -89,7 +90,7 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Check that the offset is legitimate and does not overflow. if offset < 0 || offset+int64(size) < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Is writing at an offset supported? @@ -105,7 +106,7 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Check that the size is legitimate. si := int(size) if si < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get the source of the write. @@ -166,7 +167,7 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // Check that the offset is legitimate. if offset < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Is writing at an offset supported? @@ -219,7 +220,7 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Check that the offset is legitimate. if offset < -1 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Is writing at an offset supported? @@ -301,7 +302,7 @@ func writev(t *kernel.Task, f *fs.File, src usermem.IOSequence) (int64, error) { // Wait for a notification that we should retry. if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = syserror.ErrWouldBlock } break diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go index 28ad6a60e..da6651062 100644 --- a/pkg/sentry/syscalls/linux/sys_xattr.go +++ b/pkg/sentry/syscalls/linux/sys_xattr.go @@ -18,6 +18,7 @@ import ( "strings" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -182,7 +183,7 @@ func setXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSymlink // setXattr implements setxattr(2) from the given *fs.Dirent. 
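
The writev hunk above swaps err == syserror.ETIMEDOUT for linuxerr.Equals(linuxerr.ETIMEDOUT, err). Once errors may come from two different packages, plain interface equality silently stops matching, so the comparison has to key on the errno value instead. The sketch below reproduces that failure mode with two made-up error types and an errno-based helper; it is a standalone analogue of the idea, not the actual linuxerr implementation.

package main

import "fmt"

// legacyError and richError are illustrative stand-ins for errors produced
// by two different packages that both represent the same errno.
type legacyError struct {
	errno uint32
	msg   string
}

type richError struct {
	errno uint32
	msg   string
}

func (e *legacyError) Error() string { return e.msg }
func (e *richError) Error() string   { return e.msg }

func errnoOf(err error) (uint32, bool) {
	switch e := err.(type) {
	case *legacyError:
		return e.errno, true
	case *richError:
		return e.errno, true
	}
	return 0, false
}

// equalsErrno reports whether err carries the given errno, regardless of
// which concrete type produced it.
func equalsErrno(want uint32, err error) bool {
	got, ok := errnoOf(err)
	return ok && got == want
}

func main() {
	const ETIMEDOUT = 110
	var err error = &richError{errno: ETIMEDOUT, msg: "connection timed out"}
	legacy := &legacyError{errno: ETIMEDOUT, msg: "operation timed out"}
	fmt.Println(err == error(legacy))        // false: different dynamic types
	fmt.Println(equalsErrno(ETIMEDOUT, err)) // true: same errno
}
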
func setXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr hostarch.Addr, size uint64, flags uint32) error { if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 { - return syserror.EINVAL + return linuxerr.EINVAL } name, err := copyInXattrName(t, nameAddr) @@ -195,7 +196,7 @@ func setXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr hostarch.Addr, s } if size > linux.XATTR_SIZE_MAX { - return syserror.E2BIG + return linuxerr.E2BIG } buf := make([]byte, size) if _, err := t.CopyInBytes(valueAddr, buf); err != nil { @@ -217,7 +218,7 @@ func setXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr hostarch.Addr, s func copyInXattrName(t *kernel.Task, nameAddr hostarch.Addr) (string, error) { name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1) if err != nil { - if err == syserror.ENAMETOOLONG { + if linuxerr.Equals(linuxerr.ENAMETOOLONG, err) { return "", syserror.ERANGE } return "", err @@ -333,7 +334,7 @@ func listXattr(t *kernel.Task, d *fs.Dirent, addr hostarch.Addr, size uint64) (i listSize := xattrListSize(xattrs) if listSize > linux.XATTR_SIZE_MAX { - return 0, syserror.E2BIG + return 0, linuxerr.E2BIG } if uint64(listSize) > requestedSize { return 0, syserror.ERANGE diff --git a/pkg/sentry/syscalls/linux/timespec.go b/pkg/sentry/syscalls/linux/timespec.go index 3edc922eb..b327e27d6 100644 --- a/pkg/sentry/syscalls/linux/timespec.go +++ b/pkg/sentry/syscalls/linux/timespec.go @@ -18,6 +18,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" @@ -103,7 +104,7 @@ func copyTimespecInToDuration(t *kernel.Task, timespecAddr hostarch.Addr) (time. return 0, err } if !timespec.Valid() { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } timeout = time.Duration(timespec.ToNsecCapped()) } diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index 5ce0bc714..a73f096ff 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -41,6 +41,7 @@ go_library( "//pkg/abi/linux", "//pkg/bits", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fspath", "//pkg/gohacks", "//pkg/hostarch", diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go index fd1863ef3..d81df637f 100644 --- a/pkg/sentry/syscalls/linux/vfs2/aio.go +++ b/pkg/sentry/syscalls/linux/vfs2/aio.go @@ -17,6 +17,8 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/eventfd" @@ -26,8 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" - - "gvisor.dev/gvisor/pkg/hostarch" ) // IoSubmit implements linux syscall io_submit(2). @@ -37,7 +37,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc addr := args[2].Pointer() if nrEvents < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } for i := int32(0); i < nrEvents; i++ { @@ -90,7 +90,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // submitCallback processes a single callback. 
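
setXattr above rejects anything outside XATTR_CREATE|XATTR_REPLACE with flags&^(allowed) != 0, the same whitelist idiom these hunks apply to dup3, pipe2, inotify and others further down. A small self-contained example of the &^ (AND NOT) trick; the flag names and values here are hypothetical stand-ins for the Linux constants.

package main

import "fmt"

const (
	flagCreate  = 0x1 // illustrative bits, not the real XATTR_* values
	flagReplace = 0x2
	allowed     = flagCreate | flagReplace
)

// checkFlags rejects any bit outside the allowed set: flags &^ allowed
// clears the known bits, so a nonzero result means an unknown bit was set.
func checkFlags(flags uint32) error {
	if flags&^allowed != 0 {
		return fmt.Errorf("EINVAL: unknown flag bits %#x", flags&^allowed)
	}
	return nil
}

func main() {
	fmt.Println(checkFlags(flagCreate))          // <nil>
	fmt.Println(checkFlags(flagCreate | 0x1000)) // EINVAL: unknown flag bits 0x1000
}
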
func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr hostarch.Addr) error { if cb.Reserved2 != 0 { - return syserror.EINVAL + return linuxerr.EINVAL } fd := t.GetFileVFS2(cb.FD) @@ -110,7 +110,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr host // Check that it is an eventfd. if _, ok := eventFD.Impl().(*eventfd.EventFileDescription); !ok { - return syserror.EINVAL + return linuxerr.EINVAL } } @@ -123,14 +123,14 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr host switch cb.OpCode { case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITE, linux.IOCB_CMD_PWRITEV: if cb.Offset < 0 { - return syserror.EINVAL + return linuxerr.EINVAL } } // Prepare the request. aioCtx, ok := t.MemoryManager().LookupAIOContext(t, id) if !ok { - return syserror.EINVAL + return linuxerr.EINVAL } if err := aioCtx.Prepare(); err != nil { return err @@ -200,7 +200,7 @@ func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error) bytes := int(cb.Bytes) if bytes < 0 { // Linux also requires that this field fit in ssize_t. - return usermem.IOSequence{}, syserror.EINVAL + return usermem.IOSequence{}, linuxerr.EINVAL } // Since this I/O will be asynchronous with respect to t's task goroutine, @@ -222,6 +222,6 @@ func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error) default: // Not a supported command. - return usermem.IOSequence{}, syserror.EINVAL + return usermem.IOSequence{}, linuxerr.EINVAL } } diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go index 047d955b6..d3bb3a3e1 100644 --- a/pkg/sentry/syscalls/linux/vfs2/epoll.go +++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -34,7 +35,7 @@ var sizeofEpollEvent = (*linux.EpollEvent)(nil).SizeBytes() func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { flags := args[0].Int() if flags&^linux.EPOLL_CLOEXEC != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } file, err := t.Kernel().VFS().NewEpollInstanceFD(t) @@ -59,7 +60,7 @@ func EpollCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S // "Since Linux 2.6.8, the size argument is ignored, but must be greater // than zero" - epoll_create(2) if size <= 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } file, err := t.Kernel().VFS().NewEpollInstanceFD(t) @@ -89,7 +90,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc defer epfile.DecRef(t) ep, ok := epfile.Impl().(*vfs.EpollInstance) if !ok { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } file := t.GetFileVFS2(fd) if file == nil { @@ -97,7 +98,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } defer file.DecRef(t) if epfile == file { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } var event linux.EpollEvent @@ -115,14 +116,14 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } return 0, nil, ep.ModifyInterest(file, fd, event) default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } } func waitEpoll(t *kernel.Task, epfd int32, eventsAddr hostarch.Addr, maxEvents int, timeoutInNanos int64) (uintptr, 
*kernel.SyscallControl, error) { var _EP_MAX_EVENTS = math.MaxInt32 / sizeofEpollEvent // Linux: fs/eventpoll.c:EP_MAX_EVENTS if maxEvents <= 0 || maxEvents > _EP_MAX_EVENTS { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } epfile := t.GetFileVFS2(epfd) @@ -132,7 +133,7 @@ func waitEpoll(t *kernel.Task, epfd int32, eventsAddr hostarch.Addr, maxEvents i defer epfile.DecRef(t) ep, ok := epfile.Impl().(*vfs.EpollInstance) if !ok { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Allocate space for a few events on the stack for the common case in @@ -174,7 +175,7 @@ func waitEpoll(t *kernel.Task, epfd int32, eventsAddr hostarch.Addr, maxEvents i haveDeadline = true } if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = nil } return 0, nil, err diff --git a/pkg/sentry/syscalls/linux/vfs2/eventfd.go b/pkg/sentry/syscalls/linux/vfs2/eventfd.go index 807f909da..0dcf1fbff 100644 --- a/pkg/sentry/syscalls/linux/vfs2/eventfd.go +++ b/pkg/sentry/syscalls/linux/vfs2/eventfd.go @@ -16,10 +16,10 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/eventfd" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/syserror" ) // Eventfd2 implements linux syscall eventfd2(2). @@ -29,7 +29,7 @@ func Eventfd2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc allOps := uint(linux.EFD_SEMAPHORE | linux.EFD_NONBLOCK | linux.EFD_CLOEXEC) if flags & ^allOps != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } vfsObj := t.Kernel().VFS() diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go index 3315398a4..7b1e1da78 100644 --- a/pkg/sentry/syscalls/linux/vfs2/execve.go +++ b/pkg/sentry/syscalls/linux/vfs2/execve.go @@ -16,7 +16,9 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -24,8 +26,6 @@ import ( slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - - "gvisor.dev/gvisor/pkg/hostarch" ) // Execve implements linux syscall execve(2). 
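
waitEpoll above caps maxEvents at math.MaxInt32 / sizeofEpollEvent, mirroring Linux's EP_MAX_EVENTS so that the size of the user events buffer can never overflow an int. The standalone sketch below computes the same bound; the epollEvent struct is a hand-rolled stand-in for linux.EpollEvent, and the sentry actually takes the size from the marshalled type's SizeBytes() rather than unsafe.Sizeof.

package main

import (
	"fmt"
	"math"
	"unsafe"
)

// epollEvent mimics the 12-byte layout used here: a 32-bit event mask
// followed by 64 bits of user data.
type epollEvent struct {
	Events uint32
	Data   [2]int32
}

func main() {
	sizeofEpollEvent := int(unsafe.Sizeof(epollEvent{})) // 12
	// Linux fs/eventpoll.c bounds maxevents so that
	// maxevents * sizeof(struct epoll_event) always fits in an int.
	epMaxEvents := math.MaxInt32 / sizeofEpollEvent
	fmt.Println(sizeofEpollEvent, epMaxEvents)

	maxEvents := 1 << 30
	if maxEvents <= 0 || maxEvents > epMaxEvents {
		fmt.Println("EINVAL: maxevents out of range")
	}
}
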
@@ -48,7 +48,7 @@ func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr hostarch.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) { if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } pathname, err := t.CopyInString(pathnameAddr, linux.PATH_MAX) diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go index 1a31898e8..ea34ff471 100644 --- a/pkg/sentry/syscalls/linux/vfs2/fd.go +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go @@ -16,6 +16,7 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" @@ -86,7 +87,7 @@ func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC flags := args[2].Uint() if oldfd == newfd { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } return dup3(t, oldfd, newfd, flags) @@ -94,7 +95,7 @@ func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC func dup3(t *kernel.Task, oldfd, newfd int32, flags uint32) (uintptr, *kernel.SyscallControl, error) { if flags&^linux.O_CLOEXEC != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } file := t.GetFileVFS2(oldfd) @@ -169,7 +170,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if who < 0 { // Check for overflow before flipping the sign. if who-1 > who { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } ownerType = linux.F_OWNER_PGRP who = -who @@ -232,7 +233,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, a.SetSignal(linux.Signal(args[2].Int())) default: // Everything else is not yet supported. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } } @@ -269,7 +270,7 @@ func setAsyncOwner(t *kernel.Task, fd int, file *vfs.FileDescription, ownerType, case linux.F_OWNER_TID, linux.F_OWNER_PID, linux.F_OWNER_PGRP: // Acceptable type. default: - return syserror.EINVAL + return linuxerr.EINVAL } a := file.SetAsyncHandler(fasync.NewVFS2(fd)).(*fasync.FileAsync) @@ -301,7 +302,7 @@ func setAsyncOwner(t *kernel.Task, fd int, file *vfs.FileDescription, ownerType, a.SetOwnerProcessGroup(t, pg) return nil default: - return syserror.EINVAL + return linuxerr.EINVAL } } @@ -319,7 +320,7 @@ func posixTestLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDes case linux.F_WRLCK: typ = lock.WriteLock default: - return syserror.EINVAL + return linuxerr.EINVAL } r, err := file.ComputeLockRange(t, uint64(flock.Start), uint64(flock.Len), flock.Whence) if err != nil { @@ -382,7 +383,7 @@ func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescrip return file.UnlockPOSIX(t, t.FDTable(), r) default: - return syserror.EINVAL + return linuxerr.EINVAL } } @@ -395,7 +396,7 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys // Note: offset is allowed to be negative. 
if length < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } file := t.GetFileVFS2(fd) @@ -421,7 +422,7 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys case linux.POSIX_FADV_DONTNEED: case linux.POSIX_FADV_NOREUSE: default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Sure, whatever. diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go index 36aa1d3ae..534355237 100644 --- a/pkg/sentry/syscalls/linux/vfs2/filesystem.go +++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go @@ -16,12 +16,12 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - - "gvisor.dev/gvisor/pkg/hostarch" ) // Link implements Linux syscall link(2). @@ -43,7 +43,7 @@ func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal func linkat(t *kernel.Task, olddirfd int32, oldpathAddr hostarch.Addr, newdirfd int32, newpathAddr hostarch.Addr, flags int32) error { if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_FOLLOW) != 0 { - return syserror.EINVAL + return linuxerr.EINVAL } if flags&linux.AT_EMPTY_PATH != 0 && !t.HasCapability(linux.CAP_DAC_READ_SEARCH) { return syserror.ENOENT @@ -290,7 +290,7 @@ func Unlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc flags := args[2].Int() if flags&^linux.AT_REMOVEDIR != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if flags&linux.AT_REMOVEDIR != 0 { diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go index b41a3056a..8ace31af3 100644 --- a/pkg/sentry/syscalls/linux/vfs2/getdents.go +++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go @@ -17,13 +17,13 @@ package vfs2 import ( "fmt" + "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" - - "gvisor.dev/gvisor/pkg/hostarch" ) // Getdents implements Linux syscall getdents(2). 
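
The Fcntl F_SETOWN hunk above (and the matching Ioctl hunk a little further down) treats a negative owner as a process group and guards the sign flip with who-1 > who, which is true only for the single int32 value whose negation is not representable. A standalone illustration; pgrpFromWho is a hypothetical helper name.

package main

import (
	"fmt"
	"math"
)

// pgrpFromWho interprets a negative F_SETOWN-style argument as a process
// group, rejecting the one value (MinInt32) for which -who would overflow.
func pgrpFromWho(who int32) (int32, error) {
	if who >= 0 {
		return 0, fmt.Errorf("not a process group")
	}
	// Signed values wrap in Go, so who-1 exceeds who only when who is
	// MinInt32, the negative value with no positive counterpart.
	if who-1 > who {
		return 0, fmt.Errorf("EINVAL: -who overflows int32")
	}
	return -who, nil
}

func main() {
	fmt.Println(pgrpFromWho(-42))           // 42 <nil>
	fmt.Println(pgrpFromWho(math.MinInt32)) // 0 EINVAL: -who overflows int32
}
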
@@ -100,7 +100,7 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error { size := 8 + 8 + 2 + 1 + 1 + len(dirent.Name) size = (size + 7) &^ 7 // round up to multiple of 8 if size > cb.remaining { - return syserror.EINVAL + return linuxerr.EINVAL } buf = cb.t.CopyScratchBuffer(size) hostarch.ByteOrder.PutUint64(buf[0:8], dirent.Ino) @@ -134,7 +134,7 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error { size := 8 + 8 + 2 + 1 + 1 + len(dirent.Name) size = (size + 7) &^ 7 // round up to multiple of sizeof(long) if size > cb.remaining { - return syserror.EINVAL + return linuxerr.EINVAL } buf = cb.t.CopyScratchBuffer(size) hostarch.ByteOrder.PutUint64(buf[0:8], dirent.Ino) diff --git a/pkg/sentry/syscalls/linux/vfs2/inotify.go b/pkg/sentry/syscalls/linux/vfs2/inotify.go index 11753d8e5..7a2e9e75d 100644 --- a/pkg/sentry/syscalls/linux/vfs2/inotify.go +++ b/pkg/sentry/syscalls/linux/vfs2/inotify.go @@ -16,6 +16,7 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -28,7 +29,7 @@ const allFlags = linux.IN_NONBLOCK | linux.IN_CLOEXEC func InotifyInit1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { flags := args[0].Int() if flags&^allFlags != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } ino, err := vfs.NewInotifyFD(t, t.Kernel().VFS(), uint32(flags)) @@ -67,7 +68,7 @@ func fdToInotify(t *kernel.Task, fd int32) (*vfs.Inotify, *vfs.FileDescription, if !ok { // Not an inotify fd. f.DecRef(t) - return nil, nil, syserror.EINVAL + return nil, nil, linuxerr.EINVAL } return ino, f, nil @@ -82,7 +83,7 @@ func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kern // "EINVAL: The given event mask contains no valid events." // -- inotify_add_watch(2) if mask&linux.ALL_INOTIFY_BITS == 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // "IN_DONT_FOLLOW: Don't dereference pathname if it is a symbolic link." diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go index c7c3fed57..9852e3fe4 100644 --- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go +++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go @@ -16,6 +16,7 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -99,7 +100,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if who < 0 { // Check for overflow before flipping the sign. if who-1 > who { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } ownerType = linux.F_OWNER_PGRP who = -who diff --git a/pkg/sentry/syscalls/linux/vfs2/lock.go b/pkg/sentry/syscalls/linux/vfs2/lock.go index d1452a04d..80cb3ba09 100644 --- a/pkg/sentry/syscalls/linux/vfs2/lock.go +++ b/pkg/sentry/syscalls/linux/vfs2/lock.go @@ -16,6 +16,7 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -57,7 +58,7 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } default: // flock(2): EINVAL operation is invalid. 
- return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } return 0, nil, nil diff --git a/pkg/sentry/syscalls/linux/vfs2/memfd.go b/pkg/sentry/syscalls/linux/vfs2/memfd.go index c4c0f9e0a..70c2cf5a5 100644 --- a/pkg/sentry/syscalls/linux/vfs2/memfd.go +++ b/pkg/sentry/syscalls/linux/vfs2/memfd.go @@ -16,10 +16,10 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/syserror" ) const ( @@ -35,7 +35,7 @@ func MemfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S if flags&^memfdAllFlags != 0 { // Unknown bits in flags. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } allowSeals := flags&linux.MFD_ALLOW_SEALING != 0 diff --git a/pkg/sentry/syscalls/linux/vfs2/mmap.go b/pkg/sentry/syscalls/linux/vfs2/mmap.go index c961545f6..db8d59899 100644 --- a/pkg/sentry/syscalls/linux/vfs2/mmap.go +++ b/pkg/sentry/syscalls/linux/vfs2/mmap.go @@ -16,13 +16,13 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/syserror" - - "gvisor.dev/gvisor/pkg/hostarch" ) // Mmap implements Linux syscall mmap(2). @@ -38,7 +38,7 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC // Require exactly one of MAP_PRIVATE and MAP_SHARED. if private == shared { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } opts := memmap.MMapOpts{ diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go index dd93430e2..667e48744 100644 --- a/pkg/sentry/syscalls/linux/vfs2/mount.go +++ b/pkg/sentry/syscalls/linux/vfs2/mount.go @@ -16,12 +16,12 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - - "gvisor.dev/gvisor/pkg/hostarch" ) // Mount implements Linux syscall mount(2). @@ -84,7 +84,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // unknown or unsupported flags are passed. Since we don't implement // everything, we fail explicitly on flags that are unimplemented. 
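
The Mmap hunk above keeps the "exactly one of MAP_PRIVATE and MAP_SHARED" rule as a single private == shared comparison, which rejects the neither-set and both-set cases with one test. A standalone version of that check; checkSharing is an illustrative name and the flag values are the usual Linux ones.

package main

import "fmt"

const (
	mapShared  = 0x01
	mapPrivate = 0x02
)

// checkSharing enforces that exactly one of MAP_PRIVATE and MAP_SHARED is
// set: when the two booleans agree, either both flags or neither flag was
// supplied, and mmap(2) rejects both cases.
func checkSharing(flags uint64) error {
	private := flags&mapPrivate != 0
	shared := flags&mapShared != 0
	if private == shared {
		return fmt.Errorf("EINVAL: need exactly one of MAP_PRIVATE, MAP_SHARED")
	}
	return nil
}

func main() {
	fmt.Println(checkSharing(mapPrivate))             // <nil>
	fmt.Println(checkSharing(0))                      // EINVAL
	fmt.Println(checkSharing(mapPrivate | mapShared)) // EINVAL
}
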
if flags&(unsupportedOps|unsupportedFlags) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } var opts vfs.MountOptions @@ -130,7 +130,7 @@ func Umount2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca const unsupported = linux.MNT_FORCE | linux.MNT_EXPIRE if flags&unsupported != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } path, err := copyInPath(t, addr) diff --git a/pkg/sentry/syscalls/linux/vfs2/pipe.go b/pkg/sentry/syscalls/linux/vfs2/pipe.go index c6fc1954c..07a89cf4e 100644 --- a/pkg/sentry/syscalls/linux/vfs2/pipe.go +++ b/pkg/sentry/syscalls/linux/vfs2/pipe.go @@ -16,14 +16,13 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/pipefs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" - - "gvisor.dev/gvisor/pkg/hostarch" ) // Pipe implements Linux syscall pipe(2). @@ -41,7 +40,7 @@ func Pipe2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall func pipe2(t *kernel.Task, addr hostarch.Addr, flags int32) error { if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 { - return syserror.EINVAL + return linuxerr.EINVAL } r, w, err := pipefs.NewConnectedPipeFDs(t, t.Kernel().PipeMount(), uint32(flags&linux.O_NONBLOCK)) if err != nil { diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go index a69c80edd..ea95dd78c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/poll.go +++ b/pkg/sentry/syscalls/linux/vfs2/poll.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -132,7 +133,7 @@ func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time. // Wait for a notification. timeout, err = t.BlockWithTimeout(ch, haveTimeout, timeout) if err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = nil } return timeout, 0, err @@ -161,7 +162,7 @@ func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time. // copyInPollFDs copies an array of struct pollfd unless nfds exceeds the max. func copyInPollFDs(t *kernel.Task, addr hostarch.Addr, nfds uint) ([]linux.PollFD, error) { if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } pfd := make([]linux.PollFD, nfds) @@ -221,7 +222,7 @@ func CopyInFDSet(t *kernel.Task, addr hostarch.Addr, nBytes, nBitsInLastPartialB func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs hostarch.Addr, timeout time.Duration) (uintptr, error) { if nfds < 0 || nfds > fileCap { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Calculate the size of the fd sets (one bit per fd). @@ -410,7 +411,7 @@ func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) { func poll(t *kernel.Task, pfdAddr hostarch.Addr, nfds uint, timeout time.Duration) (uintptr, error) { remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout) // On an interrupt poll(2) is restarted with the remaining timeout. 
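
doSelect above bounds nfds and then works on fd sets packed one bit per descriptor, handing CopyInFDSet the byte count plus the number of meaningful bits in a trailing partial byte. The helper below shows one plausible way to compute those two quantities; fdSetSize is an illustrative name rather than the sentry's.

package main

import "fmt"

// fdSetSize returns how many bytes of an fd_set select(2) actually consults
// for nfds descriptors, and how many bits of the final byte are meaningful
// when nfds is not a multiple of 8 (0 means the last byte is fully used).
func fdSetSize(nfds int) (nBytes, nBitsInLastPartialByte int) {
	nBytes = (nfds + 7) / 8
	nBitsInLastPartialByte = nfds % 8
	return nBytes, nBitsInLastPartialByte
}

func main() {
	for _, nfds := range []int{0, 1, 8, 13, 1024} {
		b, bits := fdSetSize(nfds)
		fmt.Printf("nfds=%4d -> %3d bytes, %d bits used in the last byte\n", nfds, b, bits)
	}
}
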
- if err == syserror.EINTR { + if linuxerr.Equals(linuxerr.EINTR, err) { t.SetSyscallRestartBlock(&pollRestartBlock{ pfdAddr: pfdAddr, nfds: nfds, @@ -462,7 +463,7 @@ func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // // Note that this means that if err is nil but copyErr is not, copyErr is // ignored. This is consistent with Linux. - if err == syserror.EINTR && copyErr == nil { + if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil { err = syserror.ERESTARTNOHAND } return n, nil, err @@ -484,7 +485,7 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, err } if timeval.Sec < 0 || timeval.Usec < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } timeout = time.Duration(timeval.ToNsecCapped()) } @@ -492,7 +493,7 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout) copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr) // See comment in Ppoll. - if err == syserror.EINTR && copyErr == nil { + if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil { err = syserror.ERESTARTNOHAND } return n, nil, err @@ -539,7 +540,7 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout) copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr) // See comment in Ppoll. - if err == syserror.EINTR && copyErr == nil { + if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil { err = syserror.ERESTARTNOHAND } return n, nil, err @@ -561,7 +562,7 @@ func copyTimespecInToDuration(t *kernel.Task, timespecAddr hostarch.Addr) (time. return 0, err } if !timespec.Valid() { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } timeout = time.Duration(timespec.ToNsecCapped()) } @@ -573,7 +574,7 @@ func setTempSignalSet(t *kernel.Task, maskAddr hostarch.Addr, maskSize uint) err return nil } if maskSize != linux.SignalSetSize { - return syserror.EINVAL + return linuxerr.EINVAL } var mask linux.SignalSet if _, err := mask.CopyIn(t, maskAddr); err != nil { diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go index b863d7b84..3e515f6fd 100644 --- a/pkg/sentry/syscalls/linux/vfs2/read_write.go +++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go @@ -18,6 +18,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -49,7 +50,7 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC // Check that the size is legitimate. si := int(size) if si < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get the destination of the read. @@ -120,7 +121,7 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt // Wait for a notification that we should retry. if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = syserror.ErrWouldBlock } break @@ -146,13 +147,13 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // Check that the offset is legitimate and does not overflow. 
if offset < 0 || offset+int64(size) < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Check that the size is legitimate. si := int(size) if si < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get the destination of the read. @@ -183,7 +184,7 @@ func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // Check that the offset is legitimate. if offset < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get the destination of the read. @@ -221,7 +222,7 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // Check that the offset is legitimate. if offset < -1 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get the destination of the read. @@ -275,7 +276,7 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of // Wait for a notification that we should retry. if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = syserror.ErrWouldBlock } break @@ -300,7 +301,7 @@ func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Check that the size is legitimate. si := int(size) if si < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get the source of the write. @@ -371,7 +372,7 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op // Wait for a notification that we should retry. if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = syserror.ErrWouldBlock } break @@ -396,13 +397,13 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Check that the offset is legitimate and does not overflow. if offset < 0 || offset+int64(size) < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Check that the size is legitimate. si := int(size) if si < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get the source of the write. @@ -433,7 +434,7 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // Check that the offset is legitimate. if offset < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get the source of the write. @@ -471,7 +472,7 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Check that the offset is legitimate. if offset < -1 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get the source of the write. @@ -525,7 +526,7 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o // Wait for a notification that we should retry. if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { + if linuxerr.Equals(linuxerr.ETIMEDOUT, err) { err = syserror.ErrWouldBlock } break @@ -587,16 +588,16 @@ func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys // Check that the size is valid. if int(size) < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Check that the offset is legitimate and does not overflow. 
if offset < 0 || offset+int64(size) < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Return EINVAL; if the underlying file type does not support readahead, // then Linux will return EINVAL to indicate as much. In the future, we // may extend this function to actually support readahead hints. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go index c6330c21a..0fbafd6f6 100644 --- a/pkg/sentry/syscalls/linux/vfs2/setstat.go +++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go @@ -16,15 +16,15 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - - "gvisor.dev/gvisor/pkg/hostarch" ) const chmodMask = 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX @@ -105,7 +105,7 @@ func Fchownat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc func fchownat(t *kernel.Task, dirfd int32, pathAddr hostarch.Addr, owner, group, flags int32) error { if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { - return syserror.EINVAL + return linuxerr.EINVAL } path, err := copyInPath(t, pathAddr) @@ -126,7 +126,7 @@ func populateSetStatOptionsForChown(t *kernel.Task, owner, group int32, opts *vf if owner != -1 { kuid := userns.MapToKUID(auth.UID(owner)) if !kuid.Ok() { - return syserror.EINVAL + return linuxerr.EINVAL } opts.Stat.Mask |= linux.STATX_UID opts.Stat.UID = uint32(kuid) @@ -134,7 +134,7 @@ func populateSetStatOptionsForChown(t *kernel.Task, owner, group int32, opts *vf if group != -1 { kgid := userns.MapToKGID(auth.GID(group)) if !kgid.Ok() { - return syserror.EINVAL + return linuxerr.EINVAL } opts.Stat.Mask |= linux.STATX_GID opts.Stat.GID = uint32(kgid) @@ -167,7 +167,7 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc length := args[1].Int64() if length < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } path, err := copyInPath(t, addr) @@ -191,7 +191,7 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys length := args[1].Int64() if length < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } file := t.GetFileVFS2(fd) @@ -201,7 +201,7 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys defer file.DecRef(t) if !file.IsWritable() { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } err := file.SetStat(t, vfs.SetStatOptions{ @@ -233,7 +233,7 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, syserror.ENOTSUP } if offset < 0 || length <= 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } size := offset + length @@ -242,9 +242,9 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } limit := limits.FromContext(t).Get(limits.FileSize).Cur if uint64(size) >= limit { - t.SendSignal(&arch.SignalInfo{ + t.SendSignal(&linux.SignalInfo{ Signo: int32(linux.SIGXFSZ), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, }) return 0, nil, syserror.EFBIG } @@ -340,7 +340,7 @@ func populateSetStatOptionsForUtimes(t *kernel.Task, timesAddr hostarch.Addr, op return err } 
if times[0].Usec < 0 || times[0].Usec > 999999 || times[1].Usec < 0 || times[1].Usec > 999999 { - return syserror.EINVAL + return linuxerr.EINVAL } opts.Stat.Mask = linux.STATX_ATIME | linux.STATX_MTIME opts.Stat.Atime = linux.StatxTimestamp{ @@ -372,7 +372,7 @@ func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } if flags&^linux.AT_SYMLINK_NOFOLLOW != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // "If filename is NULL and dfd refers to an open file, then operate on the @@ -405,7 +405,7 @@ func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr hostarch.Addr, o } if times[0].Nsec != linux.UTIME_OMIT { if times[0].Nsec != linux.UTIME_NOW && (times[0].Nsec < 0 || times[0].Nsec > 999999999) { - return syserror.EINVAL + return linuxerr.EINVAL } opts.Stat.Mask |= linux.STATX_ATIME opts.Stat.Atime = linux.StatxTimestamp{ @@ -415,7 +415,7 @@ func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr hostarch.Addr, o } if times[1].Nsec != linux.UTIME_OMIT { if times[1].Nsec != linux.UTIME_NOW && (times[1].Nsec < 0 || times[1].Nsec > 999999999) { - return syserror.EINVAL + return linuxerr.EINVAL } opts.Stat.Mask |= linux.STATX_MTIME opts.Stat.Mtime = linux.StatxTimestamp{ diff --git a/pkg/sentry/syscalls/linux/vfs2/signal.go b/pkg/sentry/syscalls/linux/vfs2/signal.go index 6163da103..8b219cba7 100644 --- a/pkg/sentry/syscalls/linux/vfs2/signal.go +++ b/pkg/sentry/syscalls/linux/vfs2/signal.go @@ -16,13 +16,13 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/signalfd" "gvisor.dev/gvisor/pkg/sentry/kernel" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/syserror" - - "gvisor.dev/gvisor/pkg/hostarch" ) // sharedSignalfd is shared between the two calls. @@ -35,7 +35,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset hostarch.Addr, sigsetsize u // Always check for valid flags, even if not creating. if flags&^(linux.SFD_NONBLOCK|linux.SFD_CLOEXEC) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Is this a change to an existing signalfd? @@ -55,7 +55,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset hostarch.Addr, sigsetsize u } // Not a signalfd. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } fileFlags := uint32(linux.O_RDWR) diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 69f69e3af..c78c7d951 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -18,6 +18,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -117,7 +118,7 @@ type multipleMessageHeader64 struct { // from the untrusted address space range. func CaptureAddress(t *kernel.Task, addr hostarch.Addr, addrlen uint32) ([]byte, error) { if addrlen > maxAddrLen { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } addrBuf := make([]byte, addrlen) @@ -139,7 +140,7 @@ func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr h } if int32(bufLen) < 0 { - return syserror.EINVAL + return linuxerr.EINVAL } // Write the length unconditionally. 
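
writeAddress above copies a socket address back to userspace and, as the trailing comment says, writes the length unconditionally: the caller learns the true address size even when its buffer was too small, which is how accept(2) and getsockname(2) signal truncation. A byte-slice sketch of that convention, leaving out the user-memory copy and the int32 length validation the real handler performs.

package main

import "fmt"

// writeAddress copies at most len(buf) bytes of addr into buf but always
// reports the full address length, so the caller can compare the reported
// length against its buffer size to detect truncation.
func writeAddress(addr, buf []byte) (copied, fullLen int) {
	fullLen = len(addr)
	copied = copy(buf, addr) // silently truncates when buf is too small
	return copied, fullLen
}

func main() {
	sockaddr := make([]byte, 16) // e.g. a sockaddr_in-sized address
	small := make([]byte, 8)
	n, l := writeAddress(sockaddr, small)
	fmt.Printf("copied %d bytes, address length reported as %d (truncated: %v)\n", n, l, n < l)
}
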
@@ -173,7 +174,7 @@ func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // Check and initialize the flags. if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Create the new socket. @@ -206,7 +207,7 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // Check and initialize the flags. if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Create the socket pair. @@ -281,7 +282,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr, flags int) (uintptr, error) { // Check that no unsupported flags are passed in. if flags & ^(linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Get socket from the file descriptor. @@ -309,7 +310,7 @@ func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr, if peerRequested { // NOTE(magi): Linux does not give you an error if it can't // write the data back out so neither do we. - if err := writeAddress(t, peer, peerLen, addr, addrLen); err == syserror.EINVAL { + if err := writeAddress(t, peer, peerLen, addr, addrLen); linuxerr.Equals(linuxerr.EINVAL, err) { return 0, err } } @@ -425,7 +426,7 @@ func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc switch how { case linux.SHUT_RD, linux.SHUT_WR, linux.SHUT_RDWR: default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } return 0, nil, s.Shutdown(t, int(how)).ToError() @@ -458,7 +459,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, err } if optLen < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Call syscall implementation then copy both value and value len out. @@ -534,10 +535,10 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } if optLen < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if optLen > maxOptLen { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } buf := t.CopyScratchBuffer(int(optLen)) if _, err := t.CopyInBytes(optValAddr, buf); err != nil { @@ -616,7 +617,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if t.Arch().Width() != 8 { // We only handle 64-bit for now. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get socket from the file descriptor. @@ -634,7 +635,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // Reject flags that we don't handle yet. if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 { @@ -664,7 +665,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if t.Arch().Width() != 8 { // We only handle 64-bit for now. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if vlen > linux.UIO_MAXIOV { @@ -673,7 +674,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Reject flags that we don't handle yet. 
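
The Socket and SocketPair hunks above validate the packed type argument with stype & ^(0xf|SOCK_NONBLOCK|SOCK_CLOEXEC): the socket type sits in the low four bits and only the two creation flags may appear above them. A standalone split of that argument; splitSocketType is an illustrative name and the flag values are the Linux/amd64 ones.

package main

import "fmt"

const (
	sockNonblock = 0x800   // SOCK_NONBLOCK == O_NONBLOCK
	sockCloexec  = 0x80000 // SOCK_CLOEXEC == O_CLOEXEC
)

// splitSocketType separates the socket type (low four bits of the second
// socket(2) argument) from the creation flags, rejecting any other bits.
func splitSocketType(stype int32) (typ int32, nonblock, cloexec bool, err error) {
	if stype & ^int32(0xf|sockNonblock|sockCloexec) != 0 {
		return 0, false, false, fmt.Errorf("EINVAL: unknown bits in type argument")
	}
	return stype & 0xf, stype&sockNonblock != 0, stype&sockCloexec != 0, nil
}

func main() {
	const sockStream = 1
	fmt.Println(splitSocketType(sockStream | sockNonblock)) // 1 true false <nil>
	fmt.Println(splitSocketType(sockStream | 0x40))         // unknown bits -> error
}
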
if flags & ^(baseRecvFlags|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get socket from the file descriptor. @@ -701,7 +702,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc return 0, nil, err } if !ts.Valid() { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } deadline = t.Kernel().MonotonicClock().Now().Add(ts.ToDuration()) haveDeadline = true @@ -833,12 +834,12 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr hostarch.Addr, fl // recvfrom and recv syscall handlers. func recvFrom(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags int32, namePtr hostarch.Addr, nameLenPtr hostarch.Addr) (uintptr, error) { if int(bufLen) < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Reject flags that we don't handle yet. if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CONFIRM) != 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Get socket from the file descriptor. @@ -911,7 +912,7 @@ func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if t.Arch().Width() != 8 { // We only handle 64-bit for now. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get socket from the file descriptor. @@ -929,7 +930,7 @@ func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // Reject flags that we don't handle yet. if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 { @@ -949,7 +950,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if t.Arch().Width() != 8 { // We only handle 64-bit for now. - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if vlen > linux.UIO_MAXIOV { @@ -971,7 +972,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Reject flags that we don't handle yet. if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 { @@ -1077,7 +1078,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio func sendTo(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags int32, namePtr hostarch.Addr, nameLen uint32) (uintptr, error) { bl := int(bufLen) if bl < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Get socket from the file descriptor. diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go index 19e175203..6ddc72999 100644 --- a/pkg/sentry/syscalls/linux/vfs2/splice.go +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -18,6 +18,7 @@ import ( "io" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -46,12 +47,12 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal count = int64(kernel.MAX_RW_COUNT) } if count < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Check for invalid flags. 
if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get file descriptions. @@ -82,7 +83,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal inPipeFD, inIsPipe := inFile.Impl().(*pipe.VFSPipeFD) outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD) if !inIsPipe && !outIsPipe { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Copy in offsets. @@ -92,13 +93,13 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.ESPIPE } if inFile.Options().DenyPRead { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if _, err := primitive.CopyInt64In(t, inOffsetPtr, &inOffset); err != nil { return 0, nil, err } if inOffset < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } } outOffset := int64(-1) @@ -107,13 +108,13 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.ESPIPE } if outFile.Options().DenyPWrite { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if _, err := primitive.CopyInt64In(t, outOffsetPtr, &outOffset); err != nil { return 0, nil, err } if outOffset < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } } @@ -189,12 +190,12 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo count = int64(kernel.MAX_RW_COUNT) } if count < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Check for invalid flags. if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Get file descriptions. @@ -225,7 +226,7 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo inPipeFD, inIsPipe := inFile.Impl().(*pipe.VFSPipeFD) outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD) if !inIsPipe || !outIsPipe { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Copy data. @@ -288,7 +289,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Verify that the outFile Append flag is not set. if outFile.StatusFlags()&linux.O_APPEND != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Verify that inFile is a regular file or block device. This is a @@ -298,7 +299,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc return 0, nil, err } else if stat.Mask&linux.STATX_TYPE == 0 || (stat.Mode&linux.S_IFMT != linux.S_IFREG && stat.Mode&linux.S_IFMT != linux.S_IFBLK) { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Copy offset if it exists. @@ -314,16 +315,16 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc offset = int64(offsetP) if offset < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if offset+count < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } } // Validate count. This must come after offset checks. 
if count < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if count == 0 { return 0, nil, nil diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go index 69e77fa99..8a22ed8a5 100644 --- a/pkg/sentry/syscalls/linux/vfs2/stat.go +++ b/pkg/sentry/syscalls/linux/vfs2/stat.go @@ -17,15 +17,15 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - - "gvisor.dev/gvisor/pkg/hostarch" ) // Stat implements Linux syscall stat(2). @@ -53,7 +53,7 @@ func Newfstatat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr hostarch.Addr, flags int32) error { if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { - return syserror.EINVAL + return linuxerr.EINVAL } opts := vfs.StatOptions{ @@ -156,15 +156,15 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall statxAddr := args[4].Pointer() if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW|linux.AT_STATX_SYNC_TYPE) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Make sure that only one sync type option is set. syncType := uint32(flags & linux.AT_STATX_SYNC_TYPE) if syncType != 0 && !bits.IsPowerOfTwo32(syncType) { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if mask&linux.STATX__RESERVED != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } opts := vfs.StatOptions{ @@ -272,7 +272,7 @@ func accessAt(t *kernel.Task, dirfd int32, pathAddr hostarch.Addr, mode uint) er // Sanity check the mode. if mode&^(rOK|wOK|xOK) != 0 { - return syserror.EINVAL + return linuxerr.EINVAL } path, err := copyInPath(t, pathAddr) @@ -315,7 +315,7 @@ func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy func readlinkat(t *kernel.Task, dirfd int32, pathAddr, bufAddr hostarch.Addr, size uint) (uintptr, *kernel.SyscallControl, error) { if int(size) <= 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } path, err := copyInPath(t, pathAddr) diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go index 1f8a5878c..9344a81ce 100644 --- a/pkg/sentry/syscalls/linux/vfs2/sync.go +++ b/pkg/sentry/syscalls/linux/vfs2/sync.go @@ -16,6 +16,7 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" @@ -71,10 +72,10 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel // Check for negative values and overflow. 
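
Sendfile above checks offset+count < 0 only after establishing that offset itself is non-negative, and the SyncFileRange check this comment introduces does the same with offset+nbytes: once both operands are known non-negative, a negative sum can only mean the int64 addition wrapped around. A standalone version of that range check; checkRange is an illustrative name.

package main

import (
	"fmt"
	"math"
)

// checkRange validates an (offset, length) pair the way the pread/sendfile
// style handlers do. Go's fixed-size signed integers wrap on overflow, so
// once offset and length are non-negative, a negative offset+length means
// the end of the range does not fit in an int64.
func checkRange(offset, length int64) error {
	if offset < 0 || length < 0 {
		return fmt.Errorf("EINVAL: negative offset or length")
	}
	if offset+length < 0 {
		return fmt.Errorf("EINVAL: offset+length overflows int64")
	}
	return nil
}

func main() {
	fmt.Println(checkRange(4096, 512))        // <nil>
	fmt.Println(checkRange(math.MaxInt64, 1)) // overflow
	fmt.Println(checkRange(-1, 10))           // negative offset
}
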
if offset < 0 || offset+nbytes < 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } if flags&^(linux.SYNC_FILE_RANGE_WAIT_BEFORE|linux.SYNC_FILE_RANGE_WRITE|linux.SYNC_FILE_RANGE_WAIT_AFTER) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } file := t.GetFileVFS2(fd) diff --git a/pkg/sentry/syscalls/linux/vfs2/timerfd.go b/pkg/sentry/syscalls/linux/vfs2/timerfd.go index 250870c03..0794330c6 100644 --- a/pkg/sentry/syscalls/linux/vfs2/timerfd.go +++ b/pkg/sentry/syscalls/linux/vfs2/timerfd.go @@ -16,6 +16,7 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/timerfd" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -29,7 +30,7 @@ func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel flags := args[1].Int() if flags&^(linux.TFD_CLOEXEC|linux.TFD_NONBLOCK) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } // Timerfds aren't writable per se (their implementation of Write just @@ -47,7 +48,7 @@ func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel case linux.CLOCK_MONOTONIC, linux.CLOCK_BOOTTIME: clock = t.Kernel().MonotonicClock() default: - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } vfsObj := t.Kernel().VFS() file, err := timerfd.New(t, vfsObj, clock, fileFlags) @@ -72,7 +73,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne oldValAddr := args[3].Pointer() if flags&^(linux.TFD_TIMER_ABSTIME) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } file := t.GetFileVFS2(fd) @@ -83,7 +84,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne tfd, ok := file.Impl().(*timerfd.TimerFileDescription) if !ok { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } var newVal linux.Itimerspec @@ -117,7 +118,7 @@ func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne tfd, ok := file.Impl().(*timerfd.TimerFileDescription) if !ok { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } tm, s := tfd.GetTime() diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go index c261050c6..33209a8d0 100644 --- a/pkg/sentry/syscalls/linux/vfs2/xattr.go +++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go @@ -18,6 +18,7 @@ import ( "bytes" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/gohacks" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -178,7 +179,7 @@ func setxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymli flags := args[4].Int() if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 { - return syserror.EINVAL + return linuxerr.EINVAL } path, err := copyInPath(t, pathAddr) @@ -216,7 +217,7 @@ func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys flags := args[4].Int() if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 { - return 0, nil, syserror.EINVAL + return 0, nil, linuxerr.EINVAL } file := t.GetFileVFS2(fd) @@ -295,7 +296,7 @@ func Fremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. 
func copyInXattrName(t *kernel.Task, nameAddr hostarch.Addr) (string, error) { name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1) if err != nil { - if err == syserror.ENAMETOOLONG { + if linuxerr.Equals(linuxerr.ENAMETOOLONG, err) { return "", syserror.ERANGE } return "", err @@ -321,7 +322,7 @@ func copyOutXattrNameList(t *kernel.Task, listAddr hostarch.Addr, size uint, nam } if buf.Len() > int(size) { if size >= linux.XATTR_LIST_MAX { - return 0, syserror.E2BIG + return 0, linuxerr.E2BIG } return 0, syserror.ERANGE } @@ -330,7 +331,7 @@ func copyOutXattrNameList(t *kernel.Task, listAddr hostarch.Addr, size uint, nam func copyInXattrValue(t *kernel.Task, valueAddr hostarch.Addr, size uint) (string, error) { if size > linux.XATTR_SIZE_MAX { - return "", syserror.E2BIG + return "", linuxerr.E2BIG } buf := make([]byte, size) if _, err := t.CopyInBytes(valueAddr, buf); err != nil { @@ -349,7 +350,7 @@ func copyOutXattrValue(t *kernel.Task, valueAddr hostarch.Addr, size uint, value } if len(value) > int(size) { if size >= linux.XATTR_SIZE_MAX { - return 0, syserror.E2BIG + return 0, linuxerr.E2BIG } return 0, syserror.ERANGE } diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index 1f617ca8f..36d999c47 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "seqatomic_parameters_unsafe.go", package = "time", suffix = "Parameters", - template = "//pkg/sync:generic_seqatomic", + template = "//pkg/sync/seqatomic:generic_seqatomic", types = { "Value": "Parameters", }, @@ -25,6 +25,8 @@ go_library( "muldiv_arm64.s", "parameters.go", "sampler.go", + "sampler_amd64.go", + "sampler_arm64.go", "sampler_unsafe.go", "seqatomic_parameters_unsafe.go", "tsc_amd64.s", @@ -32,6 +34,7 @@ go_library( ], visibility = ["//:sandbox"], deps = [ + "//pkg/errors/linuxerr", "//pkg/gohacks", "//pkg/log", "//pkg/metric", diff --git a/pkg/sentry/time/calibrated_clock.go b/pkg/sentry/time/calibrated_clock.go index 39bf1e0de..eed74f6bd 100644 --- a/pkg/sentry/time/calibrated_clock.go +++ b/pkg/sentry/time/calibrated_clock.go @@ -19,10 +19,10 @@ package time import ( "time" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/syserror" ) // CalibratedClock implements a clock that tracks a reference clock. @@ -259,6 +259,6 @@ func (c *CalibratedClocks) GetTime(id ClockID) (int64, error) { case Realtime: return c.realtime.GetTime() default: - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } } diff --git a/pkg/sentry/time/sampler.go b/pkg/sentry/time/sampler.go index 4ac9c4474..24a47f5d5 100644 --- a/pkg/sentry/time/sampler.go +++ b/pkg/sentry/time/sampler.go @@ -21,13 +21,6 @@ import ( ) const ( - // defaultOverheadTSC is the default estimated syscall overhead in TSC cycles. - // It is further refined as syscalls are made. - defaultOverheadCycles = 1 * 1000 - - // maxOverheadCycles is the maximum allowed syscall overhead in TSC cycles. - maxOverheadCycles = 100 * defaultOverheadCycles - // maxSampleLoops is the maximum number of times to try to get a clock sample // under the expected overhead. maxSampleLoops = 5 diff --git a/pkg/sentry/time/sampler_amd64.go b/pkg/sentry/time/sampler_amd64.go new file mode 100644 index 000000000..9f1b4b2fb --- /dev/null +++ b/pkg/sentry/time/sampler_amd64.go @@ -0,0 +1,26 @@ +// Copyright 2018 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//+build amd64 + +package time + +const ( + // defaultOverheadTSC is the default estimated syscall overhead in TSC cycles. + // It is further refined as syscalls are made. + defaultOverheadCycles = 1 * 1000 + + // maxOverheadCycles is the maximum allowed syscall overhead in TSC cycles. + maxOverheadCycles = 100 * defaultOverheadCycles +) diff --git a/pkg/sentry/time/sampler_arm64.go b/pkg/sentry/time/sampler_arm64.go new file mode 100644 index 000000000..4c8d33ae4 --- /dev/null +++ b/pkg/sentry/time/sampler_arm64.go @@ -0,0 +1,42 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//+build arm64 + +package time + +// getCNTFRQ get ARM counter-timer frequency +func getCNTFRQ() TSCValue + +// getDefaultArchOverheadCycles get default OverheadCycles based on +// ARM counter-timer frequency. Usually ARM counter-timer frequency +// is range from 1-50Mhz which is much less than that on x86, so we +// calibrate defaultOverheadCycles for ARM. +func getDefaultArchOverheadCycles() TSCValue { + // estimated the clock frequency on x86 is 1Ghz. + // 1Ghz devided by counter-timer frequency of ARM to get + // frqRatio. defaultOverheadCycles of ARM equals to that on + // x86 devided by frqRatio + cntfrq := getCNTFRQ() + frqRatio := 1000000000 / cntfrq + overheadCycles := (1 * 1000) / frqRatio + return overheadCycles +} + +// defaultOverheadTSC is the default estimated syscall overhead in TSC cycles. +// It is further refined as syscalls are made. +var defaultOverheadCycles = getDefaultArchOverheadCycles() + +// maxOverheadCycles is the maximum allowed syscall overhead in TSC cycles. +var maxOverheadCycles = 100 * defaultOverheadCycles diff --git a/pkg/sentry/time/tsc_arm64.s b/pkg/sentry/time/tsc_arm64.s index da9fa4112..711349fa1 100644 --- a/pkg/sentry/time/tsc_arm64.s +++ b/pkg/sentry/time/tsc_arm64.s @@ -20,3 +20,9 @@ TEXT ·Rdtsc(SB),NOSPLIT,$0-8 WORD $0xd53be040 //MRS CNTVCT_EL0, R0 MOVD R0, ret+0(FP) RET + +TEXT ·getCNTFRQ(SB),NOSPLIT,$0-8 + // Get the virtual counter frequency. + WORD $0xd53be000 //MRS CNTFRQ_EL0, R0 + MOVD R0, ret+0(FP) + RET diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go index 581862ee2..e7073ec87 100644 --- a/pkg/sentry/usage/memory.go +++ b/pkg/sentry/usage/memory.go @@ -132,7 +132,7 @@ func Init() error { // always be the case for a newly mapped page from /dev/shm. 
If we obtain // the shared memory through some other means in the future, we may have to // explicitly zero the page. - mmap, err := unix.Mmap(int(file.Fd()), 0, int(RTMemoryStatsSize), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED) + mmap, err := memutil.MapFile(0, RTMemoryStatsSize, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED, file.Fd(), 0) if err != nil { return fmt.Errorf("error mapping usage file: %v", err) } diff --git a/pkg/sentry/usage/memory_unsafe.go b/pkg/sentry/usage/memory_unsafe.go index 9e0014ca0..bc1531b91 100644 --- a/pkg/sentry/usage/memory_unsafe.go +++ b/pkg/sentry/usage/memory_unsafe.go @@ -21,7 +21,7 @@ import ( // RTMemoryStatsSize is the size of the RTMemoryStats struct. var RTMemoryStatsSize = unsafe.Sizeof(RTMemoryStats{}) -// RTMemoryStatsPointer casts the address of the byte slice into a RTMemoryStats pointer. -func RTMemoryStatsPointer(b []byte) *RTMemoryStats { - return (*RTMemoryStats)(unsafe.Pointer(&b[0])) +// RTMemoryStatsPointer casts addr to a RTMemoryStats pointer. +func RTMemoryStatsPointer(addr uintptr) *RTMemoryStats { + return (*RTMemoryStats)(unsafe.Pointer(addr)) } diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index ac60fe8bf..a2032162d 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -95,6 +95,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/fd", "//pkg/fdnotifier", "//pkg/fspath", @@ -133,6 +134,7 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/sentry/contexttest", "//pkg/sync", "//pkg/syserror", diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go index f48817132..8b3612200 100644 --- a/pkg/sentry/vfs/anonfs.go +++ b/pkg/sentry/vfs/anonfs.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -164,7 +165,7 @@ func (fs *anonFilesystem) ReadlinkAt(ctx context.Context, rp *ResolvingPath) (st if !rp.Done() { return "", syserror.ENOTDIR } - return "", syserror.EINVAL + return "", linuxerr.EINVAL } // RenameAt implements FilesystemImpl.RenameAt. diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index ef8d8a813..63ee0aab3 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsmetric" @@ -273,7 +274,7 @@ func (fd *FileDescription) SetStatusFlags(ctx context.Context, creds *auth.Crede } } if flags&linux.O_DIRECT != 0 && !fd.opts.AllowDirectIO { - return syserror.EINVAL + return linuxerr.EINVAL } // TODO(gvisor.dev/issue/1035): FileDescriptionImpl.SetOAsync()? 
const settableFlags = linux.O_APPEND | linux.O_ASYNC | linux.O_DIRECT | linux.O_NOATIME | linux.O_NONBLOCK @@ -708,8 +709,8 @@ func (fd *FileDescription) ListXattr(ctx context.Context, size uint64) ([]string return names, err } names, err := fd.impl.ListXattr(ctx, size) - if err == syserror.ENOTSUP { - // Linux doesn't actually return ENOTSUP in this case; instead, + if linuxerr.Equals(linuxerr.EOPNOTSUPP, err) { + // Linux doesn't actually return EOPNOTSUPP in this case; instead, // fs/xattr.c:vfs_listxattr() falls back to allowing the security // subsystem to return security extended attributes, which by default // don't exist. @@ -873,7 +874,7 @@ func (fd *FileDescription) ComputeLockRange(ctx context.Context, start uint64, l } off = int64(stat.Size) default: - return lock.LockRange{}, syserror.EINVAL + return lock.LockRange{}, linuxerr.EINVAL } return lock.ComputeRange(int64(start), int64(length), off) diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index 2b6f47b4b..cffb46aab 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -88,25 +89,25 @@ func (FileDescriptionDefaultImpl) EventUnregister(e *waiter.Entry) { // PRead implements FileDescriptionImpl.PRead analogously to // file_operations::read == file_operations::read_iter == NULL in Linux. func (FileDescriptionDefaultImpl) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Read implements FileDescriptionImpl.Read analogously to // file_operations::read == file_operations::read_iter == NULL in Linux. func (FileDescriptionDefaultImpl) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // PWrite implements FileDescriptionImpl.PWrite analogously to // file_operations::write == file_operations::write_iter == NULL in Linux. func (FileDescriptionDefaultImpl) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Write implements FileDescriptionImpl.Write analogously to // file_operations::write == file_operations::write_iter == NULL in Linux. func (FileDescriptionDefaultImpl) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // IterDirents implements FileDescriptionImpl.IterDirents analogously to @@ -125,7 +126,7 @@ func (FileDescriptionDefaultImpl) Seek(ctx context.Context, offset int64, whence // Sync implements FileDescriptionImpl.Sync analogously to // file_operations::fsync == NULL in Linux. func (FileDescriptionDefaultImpl) Sync(ctx context.Context) error { - return syserror.EINVAL + return linuxerr.EINVAL } // ConfigureMMap implements FileDescriptionImpl.ConfigureMMap analogously to @@ -333,10 +334,10 @@ func (fd *DynamicBytesFileDescriptionImpl) Seek(ctx context.Context, offset int6 offset += fd.off default: // fs/seq_file:seq_lseek() rejects SEEK_END etc. 
- return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if offset < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } if offset != fd.lastRead { // Regenerate the file's contents immediately. Compare diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go index 1cd607c0a..566ad856a 100644 --- a/pkg/sentry/vfs/file_description_impl_util_test.go +++ b/pkg/sentry/vfs/file_description_impl_util_test.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -155,10 +156,10 @@ func TestGenCountFD(t *testing.T) { } // Write and PWrite fails. - if _, err := fd.Write(ctx, ioseq, WriteOptions{}); err != syserror.EIO { + if _, err := fd.Write(ctx, ioseq, WriteOptions{}); !linuxerr.Equals(linuxerr.EIO, err) { t.Errorf("Write: got err %v, wanted %v", err, syserror.EIO) } - if _, err := fd.PWrite(ctx, ioseq, 0, WriteOptions{}); err != syserror.EIO { + if _, err := fd.PWrite(ctx, ioseq, 0, WriteOptions{}); !linuxerr.Equals(linuxerr.EIO, err) { t.Errorf("Write: got err %v, wanted %v", err, syserror.EIO) } } @@ -215,10 +216,10 @@ func TestWritable(t *testing.T) { if n, err := fd.Seek(ctx, 1, linux.SEEK_SET); n != 0 && err != nil { t.Errorf("Seek: got err (%v, %v), wanted (0, nil)", n, err) } - if n, err := fd.Write(ctx, writeIOSeq, WriteOptions{}); n != 0 && err != syserror.EINVAL { + if n, err := fd.Write(ctx, writeIOSeq, WriteOptions{}); n != 0 && !linuxerr.Equals(linuxerr.EINVAL, err) { t.Errorf("Write: got err (%v, %v), wanted (0, EINVAL)", n, err) } - if n, err := fd.PWrite(ctx, writeIOSeq, 2, WriteOptions{}); n != 0 && err != syserror.EINVAL { + if n, err := fd.PWrite(ctx, writeIOSeq, 2, WriteOptions{}); n != 0 && !linuxerr.Equals(linuxerr.EINVAL, err) { t.Errorf("PWrite: got err (%v, %v), wanted (0, EINVAL)", n, err) } } diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go index 49d29e20b..a7655bbb5 100644 --- a/pkg/sentry/vfs/inotify.go +++ b/pkg/sentry/vfs/inotify.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/uniqueid" @@ -98,7 +99,7 @@ func NewInotifyFD(ctx context.Context, vfsObj *VirtualFilesystem, flags uint32) // O_CLOEXEC affects file descriptors, so it must be handled outside of vfs. flags &^= linux.O_CLOEXEC if flags&^linux.O_NONBLOCK != 0 { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } id := uniqueid.GlobalFromContext(ctx) @@ -200,7 +201,7 @@ func (*Inotify) Write(ctx context.Context, src usermem.IOSequence, opts WriteOpt // Read implements FileDescriptionImpl.Read. func (i *Inotify) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) { if dst.NumBytes() < inotifyEventBaseSize { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } i.evMu.Lock() @@ -226,7 +227,7 @@ func (i *Inotify) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOpt // write some events out. 
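Many hunks above and below replace direct errno comparisons (`err == syserror.EIO`) with `linuxerr.Equals(linuxerr.EIO, err)`. A hedged sketch of why a comparison helper is preferable to `==` once errors may be wrapped or produced by more than one error package; the `errno` type and `equals` function here are illustrative stand-ins, not the real linuxerr API.

package main

import (
	"errors"
	"fmt"
)

// errno is an illustrative stand-in for an errno-valued error type.
type errno struct {
	code int
	msg  string
}

func (e *errno) Error() string { return e.msg }

var eINVAL = &errno{code: 22, msg: "invalid argument"}

// equals compares by errno value rather than by pointer identity, so it
// still matches when the call site returns a wrapped error.
func equals(want *errno, err error) bool {
	var got *errno
	if !errors.As(err, &got) {
		return false
	}
	return got.code == want.code
}

func main() {
	wrapped := fmt.Errorf("open failed: %w", eINVAL)
	fmt.Println(wrapped == error(eINVAL)) // false: identity comparison breaks on wrapping
	fmt.Println(equals(eINVAL, wrapped))  // true: comparison by errno value
}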
return writeLen, nil } - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } // Linux always dequeues an available event as long as there's enough @@ -360,7 +361,7 @@ func (i *Inotify) RmWatch(ctx context.Context, wd int32) error { w, ok := i.watches[wd] if !ok { i.mu.Unlock() - return syserror.EINVAL + return linuxerr.EINVAL } // Remove the watch from this instance. diff --git a/pkg/sentry/vfs/memxattr/BUILD b/pkg/sentry/vfs/memxattr/BUILD index d8c4d27b9..ea82f4987 100644 --- a/pkg/sentry/vfs/memxattr/BUILD +++ b/pkg/sentry/vfs/memxattr/BUILD @@ -8,6 +8,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", "//pkg/sync", "//pkg/syserror", diff --git a/pkg/sentry/vfs/memxattr/xattr.go b/pkg/sentry/vfs/memxattr/xattr.go index 638b5d830..9b7953fa3 100644 --- a/pkg/sentry/vfs/memxattr/xattr.go +++ b/pkg/sentry/vfs/memxattr/xattr.go @@ -17,7 +17,10 @@ package memxattr import ( + "strings" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -26,6 +29,9 @@ import ( // SimpleExtendedAttributes implements extended attributes using a map of // names to values. // +// SimpleExtendedAttributes calls vfs.CheckXattrPermissions, so callers are not +// required to do so. +// // +stateify savable type SimpleExtendedAttributes struct { // mu protects the below fields. @@ -34,7 +40,11 @@ type SimpleExtendedAttributes struct { } // GetXattr returns the value at 'name'. -func (x *SimpleExtendedAttributes) GetXattr(opts *vfs.GetXattrOptions) (string, error) { +func (x *SimpleExtendedAttributes) GetXattr(creds *auth.Credentials, mode linux.FileMode, kuid auth.KUID, opts *vfs.GetXattrOptions) (string, error) { + if err := vfs.CheckXattrPermissions(creds, vfs.MayRead, mode, kuid, opts.Name); err != nil { + return "", err + } + x.mu.RLock() value, ok := x.xattrs[opts.Name] x.mu.RUnlock() @@ -50,7 +60,11 @@ func (x *SimpleExtendedAttributes) GetXattr(opts *vfs.GetXattrOptions) (string, } // SetXattr sets 'value' at 'name'. -func (x *SimpleExtendedAttributes) SetXattr(opts *vfs.SetXattrOptions) error { +func (x *SimpleExtendedAttributes) SetXattr(creds *auth.Credentials, mode linux.FileMode, kuid auth.KUID, opts *vfs.SetXattrOptions) error { + if err := vfs.CheckXattrPermissions(creds, vfs.MayWrite, mode, kuid, opts.Name); err != nil { + return err + } + x.mu.Lock() defer x.mu.Unlock() if x.xattrs == nil { @@ -73,12 +87,19 @@ func (x *SimpleExtendedAttributes) SetXattr(opts *vfs.SetXattrOptions) error { } // ListXattr returns all names in xattrs. -func (x *SimpleExtendedAttributes) ListXattr(size uint64) ([]string, error) { +func (x *SimpleExtendedAttributes) ListXattr(creds *auth.Credentials, size uint64) ([]string, error) { // Keep track of the size of the buffer needed in listxattr(2) for the list. listSize := 0 x.mu.RLock() names := make([]string, 0, len(x.xattrs)) + haveCap := creds.HasCapability(linux.CAP_SYS_ADMIN) for n := range x.xattrs { + // Hide extended attributes in the "trusted" namespace from + // non-privileged users. This is consistent with Linux's + // fs/xattr.c:simple_xattr_list(). + if !haveCap && strings.HasPrefix(n, linux.XATTR_TRUSTED_PREFIX) { + continue + } names = append(names, n) // Add one byte per null terminator. 
listSize += len(n) + 1 @@ -91,7 +112,11 @@ func (x *SimpleExtendedAttributes) ListXattr(size uint64) ([]string, error) { } // RemoveXattr removes the xattr at 'name'. -func (x *SimpleExtendedAttributes) RemoveXattr(name string) error { +func (x *SimpleExtendedAttributes) RemoveXattr(creds *auth.Credentials, mode linux.FileMode, kuid auth.KUID, name string) error { + if err := vfs.CheckXattrPermissions(creds, vfs.MayWrite, mode, kuid, name); err != nil { + return err + } + x.mu.Lock() defer x.mu.Unlock() if _, ok := x.xattrs[name]; !ok { diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 82fd382c2..03857dfc8 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" @@ -220,7 +221,7 @@ func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Cr vdDentry := vd.dentry vdDentry.mu.Lock() for { - if vdDentry.dead { + if vd.mount.umounted || vdDentry.dead { vdDentry.mu.Unlock() vfs.mountMu.Unlock() vd.DecRef(ctx) @@ -284,7 +285,7 @@ func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentia // UmountAt removes the Mount at the given path. func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *UmountOptions) error { if opts.Flags&^(linux.MNT_FORCE|linux.MNT_DETACH) != 0 { - return syserror.EINVAL + return linuxerr.EINVAL } // MNT_FORCE is currently unimplemented except for the permission check. @@ -301,19 +302,19 @@ func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credenti } defer vd.DecRef(ctx) if vd.dentry != vd.mount.root { - return syserror.EINVAL + return linuxerr.EINVAL } vfs.mountMu.Lock() if mntns := MountNamespaceFromContext(ctx); mntns != nil { defer mntns.DecRef(ctx) if mntns != vd.mount.ns { vfs.mountMu.Unlock() - return syserror.EINVAL + return linuxerr.EINVAL } if vd.mount == vd.mount.ns.root { vfs.mountMu.Unlock() - return syserror.EINVAL + return linuxerr.EINVAL } } diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 87fdcf403..f2aabb905 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -42,6 +42,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -285,7 +286,7 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential if newpop.FollowFinalSymlink { oldVD.DecRef(ctx) ctx.Warningf("VirtualFilesystem.LinkAt: file creation paths can't follow final symlink") - return syserror.EINVAL + return linuxerr.EINVAL } rp := vfs.getResolvingPath(creds, newpop) @@ -321,7 +322,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia } if pop.FollowFinalSymlink { ctx.Warningf("VirtualFilesystem.MkdirAt: file creation paths can't follow final symlink") - return syserror.EINVAL + return linuxerr.EINVAL } // "Under Linux, apart from the permission bits, the S_ISVTX mode bit is // also honored." 
- mkdir(2) @@ -359,7 +360,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia } if pop.FollowFinalSymlink { ctx.Warningf("VirtualFilesystem.MknodAt: file creation paths can't follow final symlink") - return syserror.EINVAL + return linuxerr.EINVAL } rp := vfs.getResolvingPath(creds, pop) @@ -402,13 +403,13 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential // filesystem implementations that do not support it). if opts.Flags&linux.O_TMPFILE != 0 { if opts.Flags&linux.O_DIRECTORY == 0 { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } if opts.Flags&linux.O_CREAT != 0 { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } if opts.Flags&linux.O_ACCMODE == linux.O_RDONLY { - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } } // O_PATH causes most other flags to be ignored. @@ -499,7 +500,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti } if oldpop.FollowFinalSymlink { ctx.Warningf("VirtualFilesystem.RenameAt: source path can't follow final symlink") - return syserror.EINVAL + return linuxerr.EINVAL } oldParentVD, oldName, err := vfs.getParentDirAndName(ctx, creds, oldpop) @@ -521,7 +522,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti if newpop.FollowFinalSymlink { oldParentVD.DecRef(ctx) ctx.Warningf("VirtualFilesystem.RenameAt: destination path can't follow final symlink") - return syserror.EINVAL + return linuxerr.EINVAL } rp := vfs.getResolvingPath(creds, newpop) @@ -561,7 +562,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia } if pop.FollowFinalSymlink { ctx.Warningf("VirtualFilesystem.RmdirAt: file deletion paths can't follow final symlink") - return syserror.EINVAL + return linuxerr.EINVAL } rp := vfs.getResolvingPath(creds, pop) @@ -644,7 +645,7 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent } if pop.FollowFinalSymlink { ctx.Warningf("VirtualFilesystem.SymlinkAt: file creation paths can't follow final symlink") - return syserror.EINVAL + return linuxerr.EINVAL } rp := vfs.getResolvingPath(creds, pop) @@ -678,7 +679,7 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti } if pop.FollowFinalSymlink { ctx.Warningf("VirtualFilesystem.UnlinkAt: file deletion paths can't follow final symlink") - return syserror.EINVAL + return linuxerr.EINVAL } rp := vfs.getResolvingPath(creds, pop) @@ -731,8 +732,8 @@ func (vfs *VirtualFilesystem) ListXattrAt(ctx context.Context, creds *auth.Crede rp.Release(ctx) return names, nil } - if err == syserror.ENOTSUP { - // Linux doesn't actually return ENOTSUP in this case; instead, + if linuxerr.Equals(linuxerr.EOPNOTSUPP, err) { + // Linux doesn't actually return EOPNOTSUPP in this case; instead, // fs/xattr.c:vfs_listxattr() falls back to allowing the security // subsystem to return security extended attributes, which by // default don't exist. @@ -830,14 +831,14 @@ func (vfs *VirtualFilesystem) MkdirAllAt(ctx context.Context, currentPath string Path: fspath.Parse(currentPath), } stat, err := vfs.StatAt(ctx, creds, pop, &StatOptions{Mask: linux.STATX_TYPE}) - switch err { - case nil: + switch { + case err == nil: if stat.Mask&linux.STATX_TYPE == 0 || stat.Mode&linux.FileTypeMask != linux.ModeDirectory { return syserror.ENOTDIR } // Directory already exists. return nil - case syserror.ENOENT: + case linuxerr.Equals(linuxerr.ENOENT, err): // Expected, we will create the dir. 
default: return fmt.Errorf("stat failed for %q during directory creation: %w", currentPath, err) @@ -871,7 +872,7 @@ func (vfs *VirtualFilesystem) MakeSyntheticMountpoint(ctx context.Context, targe Root: root, Start: root, Path: fspath.Parse(target), - }, mkdirOpts); err != nil && err != syserror.EEXIST { + }, mkdirOpts); err != nil && !linuxerr.Equals(linuxerr.EEXIST, err) { return fmt.Errorf("failed to create mountpoint %q: %w", target, err) } return nil diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index dfe85f31d..e8f7d1f01 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -77,11 +77,6 @@ var DefaultOpts = Opts{ // trigger it. const descheduleThreshold = 1 * time.Second -var ( - stuckStartup = metric.MustCreateNewUint64Metric("/watchdog/stuck_startup_detected", true /* sync */, "Incremented once on startup watchdog timeout") - stuckTasks = metric.MustCreateNewUint64Metric("/watchdog/stuck_tasks_detected", true /* sync */, "Cumulative count of stuck tasks detected") -) - // Amount of time to wait before dumping the stack to the log again when the same task(s) remains stuck. var stackDumpSameTaskPeriod = time.Minute @@ -115,14 +110,14 @@ func (a *Action) Get() interface{} { } // String returns Action's string representation. -func (a *Action) String() string { - switch *a { +func (a Action) String() string { + switch a { case LogWarning: return "logWarning" case Panic: return "panic" default: - panic(fmt.Sprintf("Invalid watchdog action: %d", *a)) + panic(fmt.Sprintf("Invalid watchdog action: %d", a)) } } @@ -242,7 +237,6 @@ func (w *Watchdog) waitForStart() { return } - stuckStartup.Increment() metric.WeirdnessMetric.Increment("watchdog_stuck_startup") var buf bytes.Buffer @@ -316,7 +310,6 @@ func (w *Watchdog) runTurn() { // unless they are surrounded by // Task.UninterruptibleSleepStart/Finish. tc = &offender{lastUpdateTime: lastUpdateTime} - stuckTasks.Increment() metric.WeirdnessMetric.Increment("watchdog_stuck_tasks") newTaskFound = true } diff --git a/pkg/shim/BUILD b/pkg/shim/BUILD index fd6127b97..367765209 100644 --- a/pkg/shim/BUILD +++ b/pkg/shim/BUILD @@ -6,6 +6,7 @@ go_library( name = "shim", srcs = [ "api.go", + "debug.go", "epoll.go", "options.go", "service.go", diff --git a/pkg/shim/debug.go b/pkg/shim/debug.go new file mode 100644 index 000000000..49f01990e --- /dev/null +++ b/pkg/shim/debug.go @@ -0,0 +1,48 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package shim + +import ( + "os" + "os/signal" + "runtime" + "sync" + "syscall" + + "github.com/containerd/containerd/log" +) + +var once sync.Once + +func setDebugSigHandler() { + once.Do(func() { + dumpCh := make(chan os.Signal, 1) + signal.Notify(dumpCh, syscall.SIGUSR2) + go func() { + buf := make([]byte, 10240) + for range dumpCh { + for { + n := runtime.Stack(buf, true) + if n >= len(buf) { + buf = make([]byte, 2*len(buf)) + continue + } + log.L.Debugf("User requested stack trace:\n%s", buf[:n]) + } + } + }() + log.L.Debugf("For full process dump run: kill -%d %d", syscall.SIGUSR2, os.Getpid()) + }) +} diff --git a/pkg/shim/proc/BUILD b/pkg/shim/proc/BUILD index 544bdc170..c8527a6d9 100644 --- a/pkg/shim/proc/BUILD +++ b/pkg/shim/proc/BUILD @@ -20,7 +20,9 @@ go_library( "//shim:__subpackages__", ], deps = [ + "//pkg/cleanup", "//pkg/shim/runsc", + "//pkg/shim/utils", "@com_github_containerd_console//:go_default_library", "@com_github_containerd_containerd//errdefs:go_default_library", "@com_github_containerd_containerd//log:go_default_library", diff --git a/pkg/shim/proc/deleted_state.go b/pkg/shim/proc/deleted_state.go index d9b970c4d..b0bbe4d7e 100644 --- a/pkg/shim/proc/deleted_state.go +++ b/pkg/shim/proc/deleted_state.go @@ -22,28 +22,38 @@ import ( "github.com/containerd/console" "github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/pkg/process" + runc "github.com/containerd/go-runc" ) type deletedState struct{} -func (*deletedState) Resize(ws console.WinSize) error { - return fmt.Errorf("cannot resize a deleted process.ss") +func (*deletedState) Resize(console.WinSize) error { + return fmt.Errorf("cannot resize a deleted container/process") } -func (*deletedState) Start(ctx context.Context) error { - return fmt.Errorf("cannot start a deleted process.ss") +func (*deletedState) Start(context.Context) error { + return fmt.Errorf("cannot start a deleted container/process") } -func (*deletedState) Delete(ctx context.Context) error { - return fmt.Errorf("cannot delete a deleted process.ss: %w", errdefs.ErrNotFound) +func (*deletedState) Delete(context.Context) error { + return fmt.Errorf("cannot delete a deleted container/process: %w", errdefs.ErrNotFound) } -func (*deletedState) Kill(ctx context.Context, sig uint32, all bool) error { - return fmt.Errorf("cannot kill a deleted process.ss: %w", errdefs.ErrNotFound) +func (*deletedState) Kill(_ context.Context, signal uint32, _ bool) error { + return handleStoppedKill(signal) } -func (*deletedState) SetExited(status int) {} +func (*deletedState) SetExited(int) {} -func (*deletedState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { +func (*deletedState) Exec(context.Context, string, *ExecConfig) (process.Process, error) { return nil, fmt.Errorf("cannot exec in a deleted state") } + +func (s *deletedState) State(context.Context) (string, error) { + // There is no "deleted" state, closest one is stopped. 
+ return "stopped", nil +} + +func (s *deletedState) Stats(context.Context, string) (*runc.Stats, error) { + return nil, fmt.Errorf("cannot stat a stopped container/process") +} diff --git a/pkg/shim/proc/exec.go b/pkg/shim/proc/exec.go index e7968d9d5..e0f2ae6fa 100644 --- a/pkg/shim/proc/exec.go +++ b/pkg/shim/proc/exec.go @@ -26,11 +26,13 @@ import ( "github.com/containerd/console" "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/log" "github.com/containerd/containerd/pkg/stdio" "github.com/containerd/fifo" runc "github.com/containerd/go-runc" specs "github.com/opencontainers/runtime-spec/specs-go" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/cleanup" "gvisor.dev/gvisor/pkg/shim/runsc" ) @@ -92,6 +94,12 @@ func (e *execProcess) SetExited(status int) { } func (e *execProcess) setExited(status int) { + if !e.exited.IsZero() { + log.L.Debugf("Exec: status already set to %d, ignoring status: %d", e.status, status) + return + } + + log.L.Debugf("Exec: setting status: %d", status) e.status = status e.exited = time.Now() e.parent.Platform.ShutdownConsole(context.Background(), e.console) @@ -105,7 +113,7 @@ func (e *execProcess) Delete(ctx context.Context) error { return e.execState.Delete(ctx) } -func (e *execProcess) delete(ctx context.Context) error { +func (e *execProcess) delete() error { e.wg.Wait() if e.io != nil { for _, c := range e.closers { @@ -113,12 +121,6 @@ func (e *execProcess) delete(ctx context.Context) error { } e.io.Close() } - pidfile := filepath.Join(e.path, fmt.Sprintf("%s.pid", e.id)) - // silently ignore error - os.Remove(pidfile) - internalPidfile := filepath.Join(e.path, fmt.Sprintf("%s-internal.pid", e.id)) - // silently ignore error - os.Remove(internalPidfile) return nil } @@ -145,16 +147,13 @@ func (e *execProcess) Kill(ctx context.Context, sig uint32, _ bool) error { func (e *execProcess) kill(ctx context.Context, sig uint32, _ bool) error { internalPid := e.internalPid - if internalPid != 0 { - if err := e.parent.runtime.Kill(ctx, e.parent.id, int(sig), &runsc.KillOpts{ - Pid: internalPid, - }); err != nil { - // If this returns error, consider the process has - // already stopped. - // - // TODO: Fix after signal handling is fixed. 
- return fmt.Errorf("%s: %w", err.Error(), errdefs.ErrNotFound) - } + if internalPid == 0 { + return nil + } + + opts := runsc.KillOpts{Pid: internalPid} + if err := e.parent.runtime.Kill(ctx, e.parent.id, int(sig), &opts); err != nil { + return fmt.Errorf("%s: %w", err.Error(), errdefs.ErrNotFound) } return nil } @@ -174,42 +173,53 @@ func (e *execProcess) Start(ctx context.Context) error { return e.execState.Start(ctx) } -func (e *execProcess) start(ctx context.Context) (err error) { - var ( - socket *runc.Socket - pidfile = filepath.Join(e.path, fmt.Sprintf("%s.pid", e.id)) - internalPidfile = filepath.Join(e.path, fmt.Sprintf("%s-internal.pid", e.id)) - ) - if e.stdio.Terminal { - if socket, err = runc.NewTempConsoleSocket(); err != nil { +func (e *execProcess) start(ctx context.Context) error { + var socket *runc.Socket + + switch { + case e.stdio.Terminal: + s, err := runc.NewTempConsoleSocket() + if err != nil { return fmt.Errorf("failed to create runc console socket: %w", err) } - defer socket.Close() - } else if e.stdio.IsNull() { - if e.io, err = runc.NewNullIO(); err != nil { + defer s.Close() + socket = s + + case e.stdio.IsNull(): + io, err := runc.NewNullIO() + if err != nil { return fmt.Errorf("creating new NULL IO: %w", err) } - } else { - if e.io, err = runc.NewPipeIO(e.parent.IoUID, e.parent.IoGID, withConditionalIO(e.stdio)); err != nil { + e.io = io + + default: + io, err := runc.NewPipeIO(e.parent.IoUID, e.parent.IoGID, withConditionalIO(e.stdio)) + if err != nil { return fmt.Errorf("failed to create runc io pipes: %w", err) } + e.io = io } + opts := &runsc.ExecOpts{ - PidFile: pidfile, - InternalPidFile: internalPidfile, + PidFile: filepath.Join(e.path, fmt.Sprintf("%s.pid", e.id)), + InternalPidFile: filepath.Join(e.path, fmt.Sprintf("%s-internal.pid", e.id)), IO: e.io, Detach: true, } + defer func() { + _ = os.Remove(opts.PidFile) + _ = os.Remove(opts.InternalPidFile) + }() if socket != nil { opts.ConsoleSocket = socket } + eventCh := e.parent.Monitor.Subscribe() - defer func() { - // Unsubscribe if an error is returned. - if err != nil { - e.parent.Monitor.Unsubscribe(eventCh) - } - }() + cu := cleanup.Make(func() { + e.parent.Monitor.Unsubscribe(eventCh) + }) + defer cu.Clean() + if err := e.parent.runtime.Exec(ctx, e.parent.id, e.spec, opts); err != nil { close(e.waitBlock) return e.parent.runtimeError(err, "OCI runtime exec failed") @@ -237,6 +247,7 @@ func (e *execProcess) start(ctx context.Context) (err error) { return fmt.Errorf("failed to start io pipe copy: %w", err) } } + pid, err := runc.ReadPidFile(opts.PidFile) if err != nil { return fmt.Errorf("failed to retrieve OCI runtime exec pid: %w", err) @@ -247,6 +258,7 @@ func (e *execProcess) start(ctx context.Context) (err error) { return fmt.Errorf("failed to retrieve OCI runtime exec internal pid: %w", err) } e.internalPid = internalPid + go func() { defer e.parent.Monitor.Unsubscribe(eventCh) for event := range eventCh { @@ -260,21 +272,25 @@ func (e *execProcess) start(ctx context.Context) (err error) { } } }() + + cu.Release() // cancel cleanup on success. return nil } -func (e *execProcess) Status(ctx context.Context) (string, error) { +func (e *execProcess) Status(context.Context) (string, error) { e.mu.Lock() defer e.mu.Unlock() // if we don't have a pid then the exec process has just been created if e.pid == 0 { return "created", nil } - // if we have a pid and it can be signaled, the process is running - // TODO(random-liu): Use `runsc kill --pid`. 
- if err := unix.Kill(e.pid, 0); err == nil { - return "running", nil + // This checks that `runsc exec` process is still running. This process has + // the same lifetime as the process executing inside the container. So instead + // of calling `runsc kill --pid`, just do a quick check that `runsc exec` is + // still running. + if err := unix.Kill(e.pid, 0); err != nil { + // Can't signal the process, it must have exited. + return "stopped", nil } - // else if we have a pid but it can nolonger be signaled, it has stopped - return "stopped", nil + return "running", nil } diff --git a/pkg/shim/proc/exec_state.go b/pkg/shim/proc/exec_state.go index 4dcda8b44..8d8ecf541 100644 --- a/pkg/shim/proc/exec_state.go +++ b/pkg/shim/proc/exec_state.go @@ -34,18 +34,21 @@ type execCreatedState struct { p *execProcess } -func (s *execCreatedState) transition(name string) error { - switch name { - case "running": +func (s *execCreatedState) name() string { + return "created" +} + +func (s *execCreatedState) transition(transition stateTransition) { + switch transition { + case running: s.p.execState = &execRunningState{p: s.p} - case "stopped": + case stopped: s.p.execState = &execStoppedState{p: s.p} - case "deleted": + case deleted: s.p.execState = &deletedState{} default: - return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + panic(fmt.Sprintf("invalid state transition %q to %q", s.name(), transition)) } - return nil } func (s *execCreatedState) Resize(ws console.WinSize) error { @@ -56,14 +59,16 @@ func (s *execCreatedState) Start(ctx context.Context) error { if err := s.p.start(ctx); err != nil { return err } - return s.transition("running") + s.transition(running) + return nil } -func (s *execCreatedState) Delete(ctx context.Context) error { - if err := s.p.delete(ctx); err != nil { +func (s *execCreatedState) Delete(context.Context) error { + if err := s.p.delete(); err != nil { return err } - return s.transition("deleted") + s.transition(deleted) + return nil } func (s *execCreatedState) Kill(ctx context.Context, sig uint32, all bool) error { @@ -72,35 +77,35 @@ func (s *execCreatedState) Kill(ctx context.Context, sig uint32, all bool) error func (s *execCreatedState) SetExited(status int) { s.p.setExited(status) - - if err := s.transition("stopped"); err != nil { - panic(err) - } + s.transition(stopped) } type execRunningState struct { p *execProcess } -func (s *execRunningState) transition(name string) error { - switch name { - case "stopped": +func (s *execRunningState) name() string { + return "running" +} + +func (s *execRunningState) transition(transition stateTransition) { + switch transition { + case stopped: s.p.execState = &execStoppedState{p: s.p} default: - return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + panic(fmt.Sprintf("invalid state transition %q to %q", s.name(), transition)) } - return nil } func (s *execRunningState) Resize(ws console.WinSize) error { return s.p.resize(ws) } -func (s *execRunningState) Start(ctx context.Context) error { +func (s *execRunningState) Start(context.Context) error { return fmt.Errorf("cannot start a running process") } -func (s *execRunningState) Delete(ctx context.Context) error { +func (s *execRunningState) Delete(context.Context) error { return fmt.Errorf("cannot delete a running process") } @@ -110,45 +115,46 @@ func (s *execRunningState) Kill(ctx context.Context, sig uint32, all bool) error func (s *execRunningState) SetExited(status int) { s.p.setExited(status) - - if err := 
s.transition("stopped"); err != nil { - panic(err) - } + s.transition(stopped) } type execStoppedState struct { p *execProcess } -func (s *execStoppedState) transition(name string) error { - switch name { - case "deleted": +func (s *execStoppedState) name() string { + return "stopped" +} + +func (s *execStoppedState) transition(transition stateTransition) { + switch transition { + case deleted: s.p.execState = &deletedState{} default: - return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + panic(fmt.Sprintf("invalid state transition %q to %q", s.name(), transition)) } - return nil } -func (s *execStoppedState) Resize(ws console.WinSize) error { +func (s *execStoppedState) Resize(console.WinSize) error { return fmt.Errorf("cannot resize a stopped container") } -func (s *execStoppedState) Start(ctx context.Context) error { +func (s *execStoppedState) Start(context.Context) error { return fmt.Errorf("cannot start a stopped process") } -func (s *execStoppedState) Delete(ctx context.Context) error { - if err := s.p.delete(ctx); err != nil { +func (s *execStoppedState) Delete(context.Context) error { + if err := s.p.delete(); err != nil { return err } - return s.transition("deleted") + s.transition(deleted) + return nil } -func (s *execStoppedState) Kill(ctx context.Context, sig uint32, all bool) error { - return s.p.kill(ctx, sig, all) +func (s *execStoppedState) Kill(_ context.Context, sig uint32, _ bool) error { + return handleStoppedKill(sig) } -func (s *execStoppedState) SetExited(status int) { +func (s *execStoppedState) SetExited(int) { // no op } diff --git a/pkg/shim/proc/init.go b/pkg/shim/proc/init.go index 664465e0d..6bf090813 100644 --- a/pkg/shim/proc/init.go +++ b/pkg/shim/proc/init.go @@ -39,6 +39,8 @@ import ( "gvisor.dev/gvisor/pkg/shim/runsc" ) +const statusStopped = "stopped" + // Init represents an initial process for a container. type Init struct { wg sync.WaitGroup @@ -201,10 +203,15 @@ func (p *Init) ExitedAt() time.Time { func (p *Init) Status(ctx context.Context) (string, error) { p.mu.Lock() defer p.mu.Unlock() + + return p.initState.State(ctx) +} + +func (p *Init) state(ctx context.Context) (string, error) { c, err := p.runtime.State(ctx, p.id) if err != nil { if strings.Contains(err.Error(), "does not exist") { - return "stopped", nil + return statusStopped, nil } return "", p.runtimeError(err, "OCI runtime state failed") } @@ -231,10 +238,7 @@ func (p *Init) start(ctx context.Context) error { status, err := p.runtime.Wait(context.Background(), p.id) if err != nil { log.G(ctx).WithError(err).Errorf("Failed to wait for container %q", p.id) - // TODO(random-liu): Handle runsc kill error. 
- if err := p.killAll(ctx); err != nil { - log.G(ctx).WithError(err).Errorf("Failed to kill container %q", p.id) - } + p.killAllLocked(ctx) status = internalErrorCode } ExitCh <- Exit{ @@ -255,6 +259,12 @@ func (p *Init) SetExited(status int) { } func (p *Init) setExited(status int) { + if !p.exited.IsZero() { + log.L.Debugf("Status already set to %d, ignoring status: %d", p.status, status) + return + } + + log.L.Debugf("Setting status: %d", status) p.exited = time.Now() p.status = status p.Platform.ShutdownConsole(context.Background(), p.console) @@ -270,15 +280,16 @@ func (p *Init) Delete(ctx context.Context) error { } func (p *Init) delete(ctx context.Context) error { - p.killAll(ctx) + p.killAllLocked(ctx) p.wg.Wait() + err := p.runtime.Delete(ctx, p.id, nil) - // ignore errors if a runtime has already deleted the process - // but we still hold metadata and pipes - // - // this is common during a checkpoint, runc will delete the container state - // after a checkpoint and the container will no longer exist within runc if err != nil { + // ignore errors if a runtime has already deleted the process + // but we still hold metadata and pipes + // + // this is common during a checkpoint, runc will delete the container state + // after a checkpoint and the container will no longer exist within runc if strings.Contains(err.Error(), "does not exist") { err = nil } else { @@ -326,29 +337,24 @@ func (p *Init) Kill(ctx context.Context, signal uint32, all bool) error { return p.initState.Kill(ctx, signal, all) } -func (p *Init) kill(context context.Context, signal uint32, all bool) error { +func (p *Init) kill(ctx context.Context, signal uint32, all bool) error { var ( killErr error backoff = 100 * time.Millisecond ) - timeout := 1 * time.Second - for start := time.Now(); time.Now().Sub(start) < timeout; { - c, err := p.runtime.State(context, p.id) + const timeout = time.Second + for start := time.Now(); time.Since(start) < timeout; { + state, err := p.initState.State(ctx) if err != nil { - if strings.Contains(err.Error(), "does not exist") { - return fmt.Errorf("no such process: %w", errdefs.ErrNotFound) - } return p.runtimeError(err, "OCI runtime state failed") } // For runsc, signal only works when container is running state. // If the container is not in running state, directly return // "no such process" - if p.convertStatus(c.Status) == "stopped" { + if state == statusStopped { return fmt.Errorf("no such process: %w", errdefs.ErrNotFound) } - killErr = p.runtime.Kill(context, p.id, int(signal), &runsc.KillOpts{ - All: all, - }) + killErr = p.runtime.Kill(ctx, p.id, int(signal), &runsc.KillOpts{All: all}) if killErr == nil { return nil } @@ -358,22 +364,18 @@ func (p *Init) kill(context context.Context, signal uint32, all bool) error { return p.runtimeError(killErr, "kill timeout") } -// KillAll kills all processes belonging to the init process. -func (p *Init) KillAll(context context.Context) error { +// KillAll kills all processes belonging to the init process. If +// `runsc kill --all` returns error, assume the container has already stopped. +func (p *Init) KillAll(context context.Context) { p.mu.Lock() defer p.mu.Unlock() - return p.killAll(context) + p.killAllLocked(context) } -func (p *Init) killAll(context context.Context) error { - p.runtime.Kill(context, p.id, int(unix.SIGKILL), &runsc.KillOpts{ - All: true, - }) - // Ignore error handling for `runsc kill --all` for now. 
- // * If it doesn't return error, it is good; - // * If it returns error, consider the container has already stopped. - // TODO: Fix `runsc kill --all` error handling. - return nil +func (p *Init) killAllLocked(context context.Context) { + if err := p.runtime.Kill(context, p.id, int(unix.SIGKILL), &runsc.KillOpts{All: true}); err != nil { + log.L.Warningf("Ignoring error killing container %q: %v", p.id, err) + } } // Stdin returns the stdin of the process. @@ -396,7 +398,6 @@ func (p *Init) Exec(ctx context.Context, path string, r *ExecConfig) (process.Pr // exec returns a new exec'd process. func (p *Init) exec(path string, r *ExecConfig) (process.Process, error) { - // process exec request var spec specs.Process if err := json.Unmarshal(r.Spec.Value, &spec); err != nil { return nil, err @@ -420,6 +421,17 @@ func (p *Init) exec(path string, r *ExecConfig) (process.Process, error) { return e, nil } +func (p *Init) Stats(ctx context.Context, id string) (*runc.Stats, error) { + p.mu.Lock() + defer p.mu.Unlock() + + return p.initState.Stats(ctx, id) +} + +func (p *Init) stats(ctx context.Context, id string) (*runc.Stats, error) { + return p.Runtime().Stats(ctx, id) +} + // Stdio returns the stdio of the process. func (p *Init) Stdio() stdio.Stdio { return p.stdio @@ -444,7 +456,7 @@ func (p *Init) runtimeError(rErr error, msg string) error { func (p *Init) convertStatus(status string) string { if status == "created" && !p.Sandbox && p.status == internalErrorCode { // Treat start failure state for non-root container as stopped. - return "stopped" + return statusStopped } return status } diff --git a/pkg/shim/proc/init_state.go b/pkg/shim/proc/init_state.go index 0065fc385..5347ddefe 100644 --- a/pkg/shim/proc/init_state.go +++ b/pkg/shim/proc/init_state.go @@ -19,16 +19,40 @@ import ( "context" "fmt" - "github.com/containerd/console" "github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/pkg/process" + runc "github.com/containerd/go-runc" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/shim/utils" ) +type stateTransition int + +const ( + running stateTransition = iota + stopped + deleted +) + +func (s stateTransition) String() string { + switch s { + case running: + return "running" + case stopped: + return "stopped" + case deleted: + return "deleted" + default: + panic(fmt.Sprintf("unknown state: %d", s)) + } +} + type initState interface { - Resize(console.WinSize) error Start(context.Context) error Delete(context.Context) error Exec(context.Context, string, *ExecConfig) (process.Process, error) + State(ctx context.Context) (string, error) + Stats(context.Context, string) (*runc.Stats, error) Kill(context.Context, uint32, bool) error SetExited(int) } @@ -37,22 +61,21 @@ type createdState struct { p *Init } -func (s *createdState) transition(name string) error { - switch name { - case "running": +func (s *createdState) name() string { + return "created" +} + +func (s *createdState) transition(transition stateTransition) { + switch transition { + case running: s.p.initState = &runningState{p: s.p} - case "stopped": - s.p.initState = &stoppedState{p: s.p} - case "deleted": + case stopped: + s.p.initState = &stoppedState{process: s.p} + case deleted: s.p.initState = &deletedState{} default: - return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + panic(fmt.Sprintf("invalid state transition %q to %q", s.name(), transition)) } - return nil -} - -func (s *createdState) Resize(ws console.WinSize) error { - return s.p.resize(ws) } func (s *createdState) 
Start(ctx context.Context) error { @@ -66,20 +89,20 @@ func (s *createdState) Start(ctx context.Context) error { if !s.p.Sandbox { s.p.io.Close() s.p.setExited(internalErrorCode) - if err := s.transition("stopped"); err != nil { - panic(err) - } + s.transition(stopped) } return err } - return s.transition("running") + s.transition(running) + return nil } func (s *createdState) Delete(ctx context.Context) error { if err := s.p.delete(ctx); err != nil { return err } - return s.transition("deleted") + s.transition(deleted) + return nil } func (s *createdState) Kill(ctx context.Context, sig uint32, all bool) error { @@ -88,40 +111,48 @@ func (s *createdState) Kill(ctx context.Context, sig uint32, all bool) error { func (s *createdState) SetExited(status int) { s.p.setExited(status) - - if err := s.transition("stopped"); err != nil { - panic(err) - } + s.transition(stopped) } func (s *createdState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { return s.p.exec(path, r) } +func (s *createdState) State(ctx context.Context) (string, error) { + state, err := s.p.state(ctx) + if err == nil && state == statusStopped { + s.transition(stopped) + } + return state, err +} + +func (s *createdState) Stats(ctx context.Context, id string) (*runc.Stats, error) { + return s.p.stats(ctx, id) +} + type runningState struct { p *Init } -func (s *runningState) transition(name string) error { - switch name { - case "stopped": - s.p.initState = &stoppedState{p: s.p} - default: - return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) - } - return nil +func (s *runningState) name() string { + return "running" } -func (s *runningState) Resize(ws console.WinSize) error { - return s.p.resize(ws) +func (s *runningState) transition(transition stateTransition) { + switch transition { + case stopped: + s.p.initState = &stoppedState{process: s.p} + default: + panic(fmt.Sprintf("invalid state transition %q to %q", s.name(), transition)) + } } func (s *runningState) Start(ctx context.Context) error { - return fmt.Errorf("cannot start a running process.ss") + return fmt.Errorf("cannot start a running container") } func (s *runningState) Delete(ctx context.Context) error { - return fmt.Errorf("cannot delete a running process.ss") + return fmt.Errorf("cannot delete a running container") } func (s *runningState) Kill(ctx context.Context, sig uint32, all bool) error { @@ -130,53 +161,81 @@ func (s *runningState) Kill(ctx context.Context, sig uint32, all bool) error { func (s *runningState) SetExited(status int) { s.p.setExited(status) + s.transition(stopped) +} + +func (s *runningState) Exec(_ context.Context, path string, r *ExecConfig) (process.Process, error) { + return s.p.exec(path, r) +} - if err := s.transition("stopped"); err != nil { - panic(err) +func (s *runningState) State(ctx context.Context) (string, error) { + state, err := s.p.state(ctx) + if err == nil && state == "stopped" { + s.transition(stopped) } + return state, err } -func (s *runningState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { - return s.p.exec(path, r) +func (s *runningState) Stats(ctx context.Context, id string) (*runc.Stats, error) { + return s.p.stats(ctx, id) } type stoppedState struct { - p *Init + process *Init } -func (s *stoppedState) transition(name string) error { - switch name { - case "deleted": - s.p.initState = &deletedState{} - default: - return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) - } - return nil +func (s 
*stoppedState) name() string { + return "stopped" } -func (s *stoppedState) Resize(ws console.WinSize) error { - return fmt.Errorf("cannot resize a stopped container") +func (s *stoppedState) transition(transition stateTransition) { + switch transition { + case deleted: + s.process.initState = &deletedState{} + default: + panic(fmt.Sprintf("invalid state transition %q to %q", s.name(), transition)) + } } -func (s *stoppedState) Start(ctx context.Context) error { - return fmt.Errorf("cannot start a stopped process.ss") +func (s *stoppedState) Start(context.Context) error { + return fmt.Errorf("cannot start a stopped container") } func (s *stoppedState) Delete(ctx context.Context) error { - if err := s.p.delete(ctx); err != nil { + if err := s.process.delete(ctx); err != nil { return err } - return s.transition("deleted") + s.transition(deleted) + return nil } -func (s *stoppedState) Kill(ctx context.Context, sig uint32, all bool) error { - return errdefs.ToGRPCf(errdefs.ErrNotFound, "process.ss %s not found", s.p.id) +func (s *stoppedState) Kill(_ context.Context, signal uint32, _ bool) error { + return handleStoppedKill(signal) } func (s *stoppedState) SetExited(status int) { - // no op + s.process.setExited(status) } -func (s *stoppedState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { +func (s *stoppedState) Exec(context.Context, string, *ExecConfig) (process.Process, error) { return nil, fmt.Errorf("cannot exec in a stopped state") } + +func (s *stoppedState) State(context.Context) (string, error) { + return "stopped", nil +} + +func (s *stoppedState) Stats(context.Context, string) (*runc.Stats, error) { + return nil, fmt.Errorf("cannot stat a stopped container") +} + +func handleStoppedKill(signal uint32) error { + switch unix.Signal(signal) { + case unix.SIGTERM, unix.SIGKILL: + // Container is already stopped, so everything inside the container has + // already been killed. + return nil + default: + return utils.ErrToGRPCf(errdefs.ErrNotFound, "process not found") + } +} diff --git a/pkg/shim/proc/proc.go b/pkg/shim/proc/proc.go index edba3fca5..89ad3f505 100644 --- a/pkg/shim/proc/proc.go +++ b/pkg/shim/proc/proc.go @@ -17,23 +17,5 @@ // the sandbox process running the container. package proc -import ( - "fmt" -) - // RunscRoot is the path to the root runsc state directory. const RunscRoot = "/run/containerd/runsc" - -func stateName(v interface{}) string { - switch v.(type) { - case *runningState, *execRunningState: - return "running" - case *createdState, *execCreatedState: - return "created" - case *deletedState: - return "deleted" - case *stoppedState: - return "stopped" - } - panic(fmt.Errorf("invalid state %v", v)) -} diff --git a/pkg/shim/runsc/runsc.go b/pkg/shim/runsc/runsc.go index ff0521d73..888cb0bcb 100644 --- a/pkg/shim/runsc/runsc.go +++ b/pkg/shim/runsc/runsc.go @@ -17,6 +17,7 @@ package runsc import ( + "bytes" "context" "encoding/json" "fmt" @@ -73,9 +74,9 @@ type Runsc struct { // List returns all containers created inside the provided runsc root directory. 
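The exec_state.go and init_state.go hunks above replace the string-based `transition(name string) error` methods with a typed `stateTransition` enum whose transition methods panic on an impossible transition, on the reasoning that such a transition is a programming bug rather than a recoverable runtime condition. A small stand-alone sketch of that pattern, using an illustrative `process` type rather than the shim's own state structs:

package main

import "fmt"

// stateTransition is a typed transition value, analogous to the enum above.
type stateTransition int

const (
	running stateTransition = iota
	stopped
	deleted
)

func (s stateTransition) String() string {
	switch s {
	case running:
		return "running"
	case stopped:
		return "stopped"
	case deleted:
		return "deleted"
	default:
		panic(fmt.Sprintf("unknown state: %d", int(s)))
	}
}

// process is an illustrative stand-in for a state-holding object.
type process struct{ state string }

// transition panics on an invalid transition instead of returning an error,
// because reaching it indicates a bug in the caller, not a runtime failure.
func (p *process) transition(t stateTransition) {
	switch t {
	case stopped:
		p.state = "stopped"
	case deleted:
		p.state = "deleted"
	default:
		panic(fmt.Sprintf("invalid state transition %q to %q", p.state, t))
	}
}

func main() {
	p := &process{state: "running"}
	p.transition(stopped)
	fmt.Println(p.state) // stopped
}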
func (r *Runsc) List(context context.Context) ([]*runc.Container, error) { - data, err := cmdOutput(r.command(context, "list", "--format=json"), false) + data, stderr, err := cmdOutput(r.command(context, "list", "--format=json"), false) if err != nil { - return nil, err + return nil, fmt.Errorf("%w: %s", err, stderr) } var out []*runc.Container if err := json.Unmarshal(data, &out); err != nil { @@ -86,9 +87,9 @@ func (r *Runsc) List(context context.Context) ([]*runc.Container, error) { // State returns the state for the container provided by id. func (r *Runsc) State(context context.Context, id string) (*runc.Container, error) { - data, err := cmdOutput(r.command(context, "state", id), true) + data, stderr, err := cmdOutput(r.command(context, "state", id), false) if err != nil { - return nil, fmt.Errorf("%s: %s", err, data) + return nil, fmt.Errorf("%w: %s", err, stderr) } var c runc.Container if err := json.Unmarshal(data, &c); err != nil { @@ -142,9 +143,9 @@ func (r *Runsc) Create(context context.Context, id, bundle string, opts *CreateO } if cmd.Stdout == nil && cmd.Stderr == nil { - data, err := cmdOutput(cmd, true) + out, _, err := cmdOutput(cmd, true) if err != nil { - return fmt.Errorf("%s: %s", err, data) + return fmt.Errorf("%w: %s", err, out) } return nil } @@ -168,15 +169,15 @@ func (r *Runsc) Create(context context.Context, id, bundle string, opts *CreateO } func (r *Runsc) Pause(context context.Context, id string) error { - if _, err := cmdOutput(r.command(context, "pause", id), true); err != nil { - return fmt.Errorf("unable to pause: %w", err) + if out, _, err := cmdOutput(r.command(context, "pause", id), true); err != nil { + return fmt.Errorf("unable to pause: %w: %s", err, out) } return nil } func (r *Runsc) Resume(context context.Context, id string) error { - if _, err := cmdOutput(r.command(context, "resume", id), true); err != nil { - return fmt.Errorf("unable to resume: %w", err) + if out, _, err := cmdOutput(r.command(context, "resume", id), true); err != nil { + return fmt.Errorf("unable to resume: %w: %s", err, out) } return nil } @@ -189,9 +190,9 @@ func (r *Runsc) Start(context context.Context, id string, cio runc.IO) error { } if cmd.Stdout == nil && cmd.Stderr == nil { - data, err := cmdOutput(cmd, true) + out, _, err := cmdOutput(cmd, true) if err != nil { - return fmt.Errorf("%s: %s", err, data) + return fmt.Errorf("%w: %s", err, out) } return nil } @@ -221,12 +222,10 @@ type waitResult struct { } // Wait will wait for a running container, and return its exit status. -// -// TODO(random-liu): Add exec process support. func (r *Runsc) Wait(context context.Context, id string) (int, error) { - data, err := cmdOutput(r.command(context, "wait", id), true) + data, stderr, err := cmdOutput(r.command(context, "wait", id), false) if err != nil { - return 0, fmt.Errorf("%s: %s", err, data) + return 0, fmt.Errorf("%w: %s", err, stderr) } var res waitResult if err := json.Unmarshal(data, &res); err != nil { @@ -294,9 +293,9 @@ func (r *Runsc) Exec(context context.Context, id string, spec specs.Process, opt opts.Set(cmd) } if cmd.Stdout == nil && cmd.Stderr == nil { - data, err := cmdOutput(cmd, true) + out, _, err := cmdOutput(cmd, true) if err != nil { - return fmt.Errorf("%s: %s", err, data) + return fmt.Errorf("%w: %s", err, out) } return nil } @@ -391,20 +390,12 @@ func (r *Runsc) Kill(context context.Context, id string, sig int, opts *KillOpts // Stats return the stats for a container like cpu, memory, and I/O. 
func (r *Runsc) Stats(context context.Context, id string) (*runc.Stats, error) { cmd := r.command(context, "events", "--stats", id) - rd, err := cmd.StdoutPipe() - if err != nil { - return nil, err - } - ec, err := Monitor.Start(cmd) + data, stderr, err := cmdOutput(cmd, false) if err != nil { - return nil, err + return nil, fmt.Errorf("%w: %s", err, stderr) } - defer func() { - rd.Close() - Monitor.Wait(cmd, ec) - }() var e runc.Event - if err := json.NewDecoder(rd).Decode(&e); err != nil { + if err := json.Unmarshal(data, &e); err != nil { log.L.Debugf("Parsing events error: %v", err) return nil, err } @@ -459,9 +450,9 @@ func (r *Runsc) Events(context context.Context, id string, interval time.Duratio // Ps lists all the processes inside the container returning their pids. func (r *Runsc) Ps(context context.Context, id string) ([]int, error) { - data, err := cmdOutput(r.command(context, "ps", "--format", "json", id), true) + data, stderr, err := cmdOutput(r.command(context, "ps", "--format", "json", id), false) if err != nil { - return nil, fmt.Errorf("%s: %s", err, data) + return nil, fmt.Errorf("%w: %s", err, stderr) } var pids []int if err := json.Unmarshal(data, &pids); err != nil { @@ -472,9 +463,9 @@ func (r *Runsc) Ps(context context.Context, id string) ([]int, error) { // Top lists all the processes inside the container returning the full ps data. func (r *Runsc) Top(context context.Context, id string) (*runc.TopResults, error) { - data, err := cmdOutput(r.command(context, "ps", "--format", "table", id), true) + data, stderr, err := cmdOutput(r.command(context, "ps", "--format", "table", id), false) if err != nil { - return nil, fmt.Errorf("%s: %s", err, data) + return nil, fmt.Errorf("%w: %s", err, stderr) } topResults, err := runc.ParsePSOutput(data) @@ -517,9 +508,9 @@ func (r *Runsc) runOrError(cmd *exec.Cmd) error { } return err } - data, err := cmdOutput(cmd, true) + out, _, err := cmdOutput(cmd, true) if err != nil { - return fmt.Errorf("%s: %s", err, data) + return fmt.Errorf("%w: %s", err, out) } return nil } @@ -540,23 +531,29 @@ func (r *Runsc) command(context context.Context, args ...string) *exec.Cmd { return cmd } -func cmdOutput(cmd *exec.Cmd, combined bool) ([]byte, error) { - b := getBuf() - defer putBuf(b) +func cmdOutput(cmd *exec.Cmd, combined bool) ([]byte, []byte, error) { + stdout := getBuf() + defer putBuf(stdout) + cmd.Stdout = stdout + cmd.Stderr = stdout - cmd.Stdout = b - if combined { - cmd.Stderr = b + var stderr *bytes.Buffer + if !combined { + stderr = getBuf() + defer putBuf(stderr) + cmd.Stderr = stderr } ec, err := Monitor.Start(cmd) if err != nil { - return nil, err + return nil, nil, err } status, err := Monitor.Wait(cmd, ec) if err == nil && status != 0 { - err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0]) + err = fmt.Errorf("%q did not terminate sucessfully", cmd.Args[0]) } - - return b.Bytes(), err + if stderr == nil { + return stdout.Bytes(), nil, err + } + return stdout.Bytes(), stderr.Bytes(), err } diff --git a/pkg/shim/service.go b/pkg/shim/service.go index 1f9adcb65..24e3b7a82 100644 --- a/pkg/shim/service.go +++ b/pkg/shim/service.go @@ -81,8 +81,6 @@ const ( // New returns a new shim service that can be used via GRPC. 
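The cmdOutput rewrite above returns stdout and stderr as separate buffers when combined is false, so callers can wrap failures as fmt.Errorf("%w: %s", err, stderr) while keeping stdout clean for JSON parsing. The following is a simplified, self-contained sketch of that pattern using os/exec directly; runSeparate is a hypothetical helper, not the shim's Monitor-based implementation.

// Sketch only: capture stdout and stderr into separate buffers so error
// messages can carry stderr without polluting the parsed stdout.
package main

import (
	"bytes"
	"fmt"
	"os/exec"
)

// runSeparate runs cmd and returns stdout and stderr separately. On failure
// it wraps the error so callers can unwrap it with errors.Is/As.
func runSeparate(cmd *exec.Cmd) (stdout, stderr []byte, err error) {
	var outBuf, errBuf bytes.Buffer
	cmd.Stdout = &outBuf
	cmd.Stderr = &errBuf
	if runErr := cmd.Run(); runErr != nil {
		return outBuf.Bytes(), errBuf.Bytes(),
			fmt.Errorf("%q did not terminate successfully: %w", cmd.Args[0], runErr)
	}
	return outBuf.Bytes(), errBuf.Bytes(), nil
}

func main() {
	out, errOut, err := runSeparate(exec.Command("ls", "/nonexistent"))
	if err != nil {
		// The wrapped error plus stderr mirrors the "%w: %s" pattern above.
		fmt.Printf("error: %v (stderr: %s)\n", err, errOut)
		return
	}
	fmt.Printf("stdout: %s\n", out)
}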
func New(ctx context.Context, id string, publisher shim.Publisher, cancel func()) (shim.Shim, error) { - log.L.Debugf("service.New, id: %s", id) - var opts shim.Opts if ctxOpts := ctx.Value(shim.OptsKey{}); ctxOpts != nil { opts = ctxOpts.(shim.Opts) @@ -304,8 +302,6 @@ func (s *service) Cleanup(ctx context.Context) (*taskAPI.DeleteResponse, error) // Create creates a new initial process and container with the underlying OCI // runtime. func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (*taskAPI.CreateTaskResponse, error) { - log.L.Debugf("Create, id: %s, bundle: %q", r.ID, r.Bundle) - s.mu.Lock() defer s.mu.Unlock() @@ -396,6 +392,9 @@ func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (*ta log.L.Debugf("stdout: %s", r.Stdout) log.L.Debugf("stderr: %s", r.Stderr) log.L.Debugf("***************************") + if log.L.Logger.IsLevelEnabled(logrus.DebugLevel) { + setDebugSigHandler() + } } // Save state before any action is taken to ensure Cleanup() will have all @@ -453,10 +452,10 @@ func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (*ta } process, err := newInit(r.Bundle, filepath.Join(r.Bundle, "work"), ns, s.platform, config, &s.opts, st.Rootfs) if err != nil { - return nil, errdefs.ToGRPC(err) + return nil, utils.ErrToGRPC(err) } if err := process.Create(ctx, config); err != nil { - return nil, errdefs.ToGRPC(err) + return nil, utils.ErrToGRPC(err) } // Set up OOM notification on the sandbox's cgroup. This is done on @@ -506,9 +505,6 @@ func (s *service) Delete(ctx context.Context, r *taskAPI.DeleteRequest) (*taskAP if err != nil { return nil, err } - if p == nil { - return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") - } if err := p.Delete(ctx); err != nil { return nil, err } @@ -534,10 +530,10 @@ func (s *service) Exec(ctx context.Context, r *taskAPI.ExecProcessRequest) (*typ p := s.processes[r.ExecID] s.mu.Unlock() if p != nil { - return nil, errdefs.ToGRPCf(errdefs.ErrAlreadyExists, "id %s", r.ExecID) + return nil, utils.ErrToGRPCf(errdefs.ErrAlreadyExists, "id %s", r.ExecID) } if s.task == nil { - return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + return nil, utils.ErrToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") } process, err := s.task.Exec(ctx, s.bundle, &proc.ExecConfig{ ID: r.ExecID, @@ -548,7 +544,7 @@ func (s *service) Exec(ctx context.Context, r *taskAPI.ExecProcessRequest) (*typ Spec: r.Spec, }) if err != nil { - return nil, errdefs.ToGRPC(err) + return nil, utils.ErrToGRPC(err) } s.mu.Lock() s.processes[r.ExecID] = process @@ -569,7 +565,7 @@ func (s *service) ResizePty(ctx context.Context, r *taskAPI.ResizePtyRequest) (* Height: uint16(r.Height), } if err := p.Resize(ws); err != nil { - return nil, errdefs.ToGRPC(err) + return nil, utils.ErrToGRPC(err) } return empty, nil } @@ -580,10 +576,12 @@ func (s *service) State(ctx context.Context, r *taskAPI.StateRequest) (*taskAPI. p, err := s.getProcess(r.ExecID) if err != nil { + log.L.Debugf("State failed to find process: %v", err) return nil, err } st, err := p.Status(ctx) if err != nil { + log.L.Debugf("State failed: %v", err) return nil, err } status := task.StatusUnknown @@ -596,7 +594,7 @@ func (s *service) State(ctx context.Context, r *taskAPI.StateRequest) (*taskAPI. 
status = task.StatusStopped } sio := p.Stdio() - return &taskAPI.StateResponse{ + res := &taskAPI.StateResponse{ ID: p.ID(), Bundle: s.bundle, Pid: uint32(p.Pid()), @@ -607,7 +605,9 @@ func (s *service) State(ctx context.Context, r *taskAPI.StateRequest) (*taskAPI. Terminal: sio.Terminal, ExitStatus: uint32(p.ExitStatus()), ExitedAt: p.ExitedAt(), - }, nil + } + log.L.Debugf("State succeeded, response: %+v", res) + return res, nil } // Pause the container. @@ -615,7 +615,7 @@ func (s *service) Pause(ctx context.Context, r *taskAPI.PauseRequest) (*types.Em log.L.Debugf("Pause, id: %s", r.ID) if s.task == nil { log.L.Debugf("Pause error, id: %s: container not created", r.ID) - return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + return nil, utils.ErrToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") } err := s.task.Runtime().Pause(ctx, r.ID) if err != nil { @@ -629,7 +629,7 @@ func (s *service) Resume(ctx context.Context, r *taskAPI.ResumeRequest) (*types. log.L.Debugf("Resume, id: %s", r.ID) if s.task == nil { log.L.Debugf("Resume error, id: %s: container not created", r.ID) - return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + return nil, utils.ErrToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") } err := s.task.Runtime().Resume(ctx, r.ID) if err != nil { @@ -646,12 +646,11 @@ func (s *service) Kill(ctx context.Context, r *taskAPI.KillRequest) (*types.Empt if err != nil { return nil, err } - if p == nil { - return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") - } if err := p.Kill(ctx, r.Signal, r.All); err != nil { - return nil, errdefs.ToGRPC(err) + log.L.Debugf("Kill failed: %v", err) + return nil, utils.ErrToGRPC(err) } + log.L.Debugf("Kill succeeded") return empty, nil } @@ -661,7 +660,7 @@ func (s *service) Pids(ctx context.Context, r *taskAPI.PidsRequest) (*taskAPI.Pi pids, err := s.getContainerPids(ctx, r.ID) if err != nil { - return nil, errdefs.ToGRPC(err) + return nil, utils.ErrToGRPC(err) } var processes []*task.ProcessInfo for _, pid := range pids { @@ -707,7 +706,7 @@ func (s *service) CloseIO(ctx context.Context, r *taskAPI.CloseIORequest) (*type // Checkpoint checkpoints the container. func (s *service) Checkpoint(ctx context.Context, r *taskAPI.CheckpointTaskRequest) (*types.Empty, error) { log.L.Debugf("Checkpoint, id: %s", r.ID) - return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) + return empty, utils.ErrToGRPC(errdefs.ErrNotImplemented) } // Connect returns shim information such as the shim's pid. @@ -738,9 +737,9 @@ func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI. log.L.Debugf("Stats, id: %s", r.ID) if s.task == nil { log.L.Debugf("Stats error, id: %s: container not created", r.ID) - return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + return nil, utils.ErrToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") } - stats, err := s.task.Runtime().Stats(ctx, s.id) + stats, err := s.task.Stats(ctx, s.id) if err != nil { log.L.Debugf("Stats error, id: %s: %v", r.ID, err) return nil, err @@ -812,7 +811,7 @@ func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI. // Update updates a running container. 
func (s *service) Update(ctx context.Context, r *taskAPI.UpdateTaskRequest) (*types.Empty, error) { - return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) + return empty, utils.ErrToGRPC(errdefs.ErrNotImplemented) } // Wait waits for a process to exit. @@ -821,17 +820,17 @@ func (s *service) Wait(ctx context.Context, r *taskAPI.WaitRequest) (*taskAPI.Wa p, err := s.getProcess(r.ExecID) if err != nil { + log.L.Debugf("Wait failed to find process: %v", err) return nil, err } - if p == nil { - return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") - } p.Wait() - return &taskAPI.WaitResponse{ + res := &taskAPI.WaitResponse{ ExitStatus: uint32(p.ExitStatus()), ExitedAt: p.ExitedAt(), - }, nil + } + log.L.Debugf("Wait succeeded, response: %+v", res) + return res, nil } func (s *service) processExits(ctx context.Context) { @@ -848,10 +847,7 @@ func (s *service) checkProcesses(ctx context.Context, e proc.Exit) { if ip, ok := p.(*proc.Init); ok { // Ensure all children are killed. log.L.Debugf("Container init process exited, killing all container processes") - if err := ip.KillAll(ctx); err != nil { - log.G(ctx).WithError(err).WithField("id", ip.ID()). - Error("failed to kill init's children") - } + ip.KillAll(ctx) } p.SetExited(e.Status) s.events <- &events.TaskExit{ @@ -909,12 +905,17 @@ func (s *service) forward(ctx context.Context, publisher shim.Publisher) { func (s *service) getProcess(execID string) (process.Process, error) { s.mu.Lock() defer s.mu.Unlock() + if execID == "" { + if s.task == nil { + return nil, utils.ErrToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } return s.task, nil } + p := s.processes[execID] if p == nil { - return nil, errdefs.ToGRPCf(errdefs.ErrNotFound, "process does not exist %s", execID) + return nil, utils.ErrToGRPCf(errdefs.ErrNotFound, "process does not exist %s", execID) } return p, nil } diff --git a/pkg/shim/utils/BUILD b/pkg/shim/utils/BUILD index 54a0aabb7..2eb82f63c 100644 --- a/pkg/shim/utils/BUILD +++ b/pkg/shim/utils/BUILD @@ -6,6 +6,7 @@ go_library( name = "utils", srcs = [ "annotations.go", + "errors.go", "utils.go", "volumes.go", ], @@ -14,14 +15,23 @@ go_library( "//shim:__subpackages__", ], deps = [ + "@com_github_containerd_containerd//errdefs:go_default_library", "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + "@org_golang_google_grpc//codes:go_default_library", + "@org_golang_google_grpc//status:go_default_library", ], ) go_test( name = "utils_test", size = "small", - srcs = ["volumes_test.go"], + srcs = [ + "errors_test.go", + "volumes_test.go", + ], library = ":utils", - deps = ["@com_github_opencontainers_runtime_spec//specs-go:go_default_library"], + deps = [ + "@com_github_containerd_containerd//errdefs:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + ], ) diff --git a/pkg/shim/utils/errors.go b/pkg/shim/utils/errors.go new file mode 100644 index 000000000..971d68c36 --- /dev/null +++ b/pkg/shim/utils/errors.go @@ -0,0 +1,74 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "context" + "errors" + "fmt" + + "github.com/containerd/containerd/errdefs" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// ErrToGRPC wraps containerd's ToGRPC error mapper which depends on +// github.com/pkg/errors to work correctly. Once we upgrade to containerd v1.4, +// this function can go away and we can use errdefs.ToGRPC directly instead. +// +// TODO(gvisor.dev/issue/6232): Remove after upgrading to containerd v1.4 +func ErrToGRPC(err error) error { + return errToGRPCMsg(err, err.Error()) +} + +// ErrToGRPCf maps the error to grpc error codes, assembling the formatting +// string and combining it with the target error string. +// +// TODO(gvisor.dev/issue/6232): Remove after upgrading to containerd v1.4 +func ErrToGRPCf(err error, format string, args ...interface{}) error { + formatted := fmt.Sprintf(format, args...) + msg := fmt.Sprintf("%s: %s", formatted, err.Error()) + return errToGRPCMsg(err, msg) +} + +func errToGRPCMsg(err error, msg string) error { + if err == nil { + return nil + } + if _, ok := status.FromError(err); ok { + return err + } + + switch { + case errors.Is(err, errdefs.ErrInvalidArgument): + return status.Errorf(codes.InvalidArgument, msg) + case errors.Is(err, errdefs.ErrNotFound): + return status.Errorf(codes.NotFound, msg) + case errors.Is(err, errdefs.ErrAlreadyExists): + return status.Errorf(codes.AlreadyExists, msg) + case errors.Is(err, errdefs.ErrFailedPrecondition): + return status.Errorf(codes.FailedPrecondition, msg) + case errors.Is(err, errdefs.ErrUnavailable): + return status.Errorf(codes.Unavailable, msg) + case errors.Is(err, errdefs.ErrNotImplemented): + return status.Errorf(codes.Unimplemented, msg) + case errors.Is(err, context.Canceled): + return status.Errorf(codes.Canceled, msg) + case errors.Is(err, context.DeadlineExceeded): + return status.Errorf(codes.DeadlineExceeded, msg) + } + + return errdefs.ToGRPC(err) +} diff --git a/pkg/shim/utils/errors_test.go b/pkg/shim/utils/errors_test.go new file mode 100644 index 000000000..0a8fe34c8 --- /dev/null +++ b/pkg/shim/utils/errors_test.go @@ -0,0 +1,50 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
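These helpers matter because the shim's handlers return failures over gRPC, and containerd only recognizes them when they carry the right gRPC status code. The test that follows exercises the errdefs round trip; as a complementary illustration, a caller could also check the status code directly. This is a hypothetical usage sketch (the Go import path for the utils package is assumed from its location in the tree):

package main

import (
	"fmt"

	"github.com/containerd/containerd/errdefs"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	"gvisor.dev/gvisor/pkg/shim/utils"
)

func main() {
	// Even a wrapped errdefs error maps to codes.NotFound.
	err := utils.ErrToGRPCf(fmt.Errorf("lookup: %w", errdefs.ErrNotFound), "process %s", "exec-1")
	fmt.Println(status.Code(err) == codes.NotFound) // true

	// containerd clients can round-trip it back to an errdefs error.
	fmt.Println(errdefs.IsNotFound(errdefs.FromGRPC(err))) // true
}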
+ +package utils + +import ( + "fmt" + "testing" + + "github.com/containerd/containerd/errdefs" +) + +func TestGRPCRoundTripsErrors(t *testing.T) { + for _, tc := range []struct { + name string + err error + test func(err error) bool + }{ + { + name: "passthrough", + err: errdefs.ErrNotFound, + test: errdefs.IsNotFound, + }, + { + name: "wrapped", + err: fmt.Errorf("oh no: %w", errdefs.ErrNotFound), + test: errdefs.IsNotFound, + }, + } { + t.Run(tc.name, func(t *testing.T) { + if err := errdefs.FromGRPC(ErrToGRPC(tc.err)); !tc.test(err) { + t.Errorf("errToGRPC got %+v", err) + } + if err := errdefs.FromGRPC(ErrToGRPCf(tc.err, "testing %s", "123")); !tc.test(err) { + t.Errorf("errToGRPCf got %+v", err) + } + }) + } +} diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index 8b3a11c64..73791b456 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -1,5 +1,4 @@ load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template") package( default_visibility = ["//:sandbox"], @@ -8,45 +7,6 @@ package( exports_files(["LICENSE"]) -go_template( - name = "generic_atomicptr", - srcs = ["generic_atomicptr_unsafe.go"], - types = [ - "Value", - ], -) - -go_template( - name = "generic_atomicptrmap", - srcs = ["generic_atomicptrmap_unsafe.go"], - opt_consts = [ - "ShardOrder", - ], - opt_types = [ - "Hasher", - ], - types = [ - "Key", - "Value", - ], - deps = [ - ":sync", - "//pkg/gohacks", - ], -) - -go_template( - name = "generic_seqatomic", - srcs = ["generic_seqatomic_unsafe.go"], - types = [ - "Value", - ], - deps = [ - ":sync", - "//pkg/gohacks", - ], -) - go_library( name = "sync", srcs = [ diff --git a/pkg/sync/README.md b/pkg/sync/README.md index 2183c4e20..be1a01f08 100644 --- a/pkg/sync/README.md +++ b/pkg/sync/README.md @@ -1,4 +1,4 @@ -# Syncutil +# sync This package provides additional synchronization primitives not provided by the Go stdlib 'sync' package. 
It is partially derived from the upstream 'sync' diff --git a/pkg/sync/atomicptrtest/BUILD b/pkg/sync/atomicptr/BUILD index e97553254..a6a7f01ac 100644 --- a/pkg/sync/atomicptrtest/BUILD +++ b/pkg/sync/atomicptr/BUILD @@ -1,14 +1,23 @@ load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) +go_template( + name = "generic_atomicptr", + srcs = ["generic_atomicptr_unsafe.go"], + types = [ + "Value", + ], + visibility = ["//:sandbox"], +) + go_template_instance( name = "atomicptr_int", out = "atomicptr_int_unsafe.go", package = "atomicptr", suffix = "Int", - template = "//pkg/sync:generic_atomicptr", + template = ":generic_atomicptr", types = { "Value": "int", }, diff --git a/pkg/sync/atomicptrtest/atomicptr_test.go b/pkg/sync/atomicptr/atomicptr_test.go index 8fdc5112e..8fdc5112e 100644 --- a/pkg/sync/atomicptrtest/atomicptr_test.go +++ b/pkg/sync/atomicptr/atomicptr_test.go diff --git a/pkg/sync/generic_atomicptr_unsafe.go b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go index 82b6df18c..82b6df18c 100644 --- a/pkg/sync/generic_atomicptr_unsafe.go +++ b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go diff --git a/pkg/sync/atomicptrmaptest/BUILD b/pkg/sync/atomicptrmap/BUILD index 3f71ae97d..b0e218c79 100644 --- a/pkg/sync/atomicptrmaptest/BUILD +++ b/pkg/sync/atomicptrmap/BUILD @@ -1,17 +1,36 @@ load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package( default_visibility = ["//visibility:private"], licenses = ["notice"], ) +go_template( + name = "generic_atomicptrmap", + srcs = ["generic_atomicptrmap_unsafe.go"], + opt_consts = [ + "ShardOrder", + ], + opt_types = [ + "Hasher", + ], + types = [ + "Key", + "Value", + ], + deps = [ + "//pkg/gohacks", + "//pkg/sync", + ], +) + go_template_instance( name = "test_atomicptrmap", out = "test_atomicptrmap_unsafe.go", package = "atomicptrmap", prefix = "test", - template = "//pkg/sync:generic_atomicptrmap", + template = ":generic_atomicptrmap", types = { "Key": "int64", "Value": "testValue", @@ -27,7 +46,7 @@ go_template_instance( package = "atomicptrmap", prefix = "test", suffix = "Sharded", - template = "//pkg/sync:generic_atomicptrmap", + template = ":generic_atomicptrmap", types = { "Key": "int64", "Value": "testValue", diff --git a/pkg/sync/atomicptrmaptest/atomicptrmap.go b/pkg/sync/atomicptrmap/atomicptrmap.go index 867821ce9..867821ce9 100644 --- a/pkg/sync/atomicptrmaptest/atomicptrmap.go +++ b/pkg/sync/atomicptrmap/atomicptrmap.go diff --git a/pkg/sync/atomicptrmaptest/atomicptrmap_test.go b/pkg/sync/atomicptrmap/atomicptrmap_test.go index 75a9997ef..75a9997ef 100644 --- a/pkg/sync/atomicptrmaptest/atomicptrmap_test.go +++ b/pkg/sync/atomicptrmap/atomicptrmap_test.go diff --git a/pkg/sync/generic_atomicptrmap_unsafe.go b/pkg/sync/atomicptrmap/generic_atomicptrmap_unsafe.go index 3e98cb309..3e98cb309 100644 --- a/pkg/sync/generic_atomicptrmap_unsafe.go +++ b/pkg/sync/atomicptrmap/generic_atomicptrmap_unsafe.go diff --git a/pkg/sync/seqatomictest/BUILD b/pkg/sync/seqatomic/BUILD index 5f9164117..60f79ab54 100644 --- a/pkg/sync/seqatomictest/BUILD +++ b/pkg/sync/seqatomic/BUILD @@ -1,14 +1,27 @@ load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_generics:defs.bzl", 
"go_template", "go_template_instance") package(licenses = ["notice"]) +go_template( + name = "generic_seqatomic", + srcs = ["generic_seqatomic_unsafe.go"], + types = [ + "Value", + ], + visibility = ["//:sandbox"], + deps = [ + ":sync", + "//pkg/gohacks", + ], +) + go_template_instance( name = "seqatomic_int", out = "seqatomic_int_unsafe.go", package = "seqatomic", suffix = "Int", - template = "//pkg/sync:generic_seqatomic", + template = ":generic_seqatomic", types = { "Value": "int", }, diff --git a/pkg/sync/generic_seqatomic_unsafe.go b/pkg/sync/seqatomic/generic_seqatomic_unsafe.go index 9578c9c52..9578c9c52 100644 --- a/pkg/sync/generic_seqatomic_unsafe.go +++ b/pkg/sync/seqatomic/generic_seqatomic_unsafe.go diff --git a/pkg/sync/seqatomictest/seqatomic_test.go b/pkg/sync/seqatomic/seqatomic_test.go index 2c4568b07..2c4568b07 100644 --- a/pkg/sync/seqatomictest/seqatomic_test.go +++ b/pkg/sync/seqatomic/seqatomic_test.go diff --git a/pkg/syserr/BUILD b/pkg/syserr/BUILD index 9cc9e3bf2..ceee494fc 100644 --- a/pkg/syserr/BUILD +++ b/pkg/syserr/BUILD @@ -11,7 +11,9 @@ go_library( ], visibility = ["//visibility:public"], deps = [ - "//pkg/abi/linux", + "//pkg/abi/linux/errno", + "//pkg/errors", + "//pkg/errors/linuxerr", "//pkg/syserror", "//pkg/tcpip", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index 90be24e15..eb44f1254 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -17,7 +17,7 @@ package syserr import ( "fmt" - "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/abi/linux/errno" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -25,33 +25,33 @@ import ( // Mapping for tcpip.Error types. var ( - ErrUnknownProtocol = New((&tcpip.ErrUnknownProtocol{}).String(), linux.EINVAL) - ErrUnknownNICID = New((&tcpip.ErrUnknownNICID{}).String(), linux.ENODEV) - ErrUnknownDevice = New((&tcpip.ErrUnknownDevice{}).String(), linux.ENODEV) - ErrUnknownProtocolOption = New((&tcpip.ErrUnknownProtocolOption{}).String(), linux.ENOPROTOOPT) - ErrDuplicateNICID = New((&tcpip.ErrDuplicateNICID{}).String(), linux.EEXIST) - ErrDuplicateAddress = New((&tcpip.ErrDuplicateAddress{}).String(), linux.EEXIST) - ErrAlreadyBound = New((&tcpip.ErrAlreadyBound{}).String(), linux.EINVAL) - ErrInvalidEndpointState = New((&tcpip.ErrInvalidEndpointState{}).String(), linux.EINVAL) - ErrAlreadyConnecting = New((&tcpip.ErrAlreadyConnecting{}).String(), linux.EALREADY) - ErrNoPortAvailable = New((&tcpip.ErrNoPortAvailable{}).String(), linux.EAGAIN) - ErrPortInUse = New((&tcpip.ErrPortInUse{}).String(), linux.EADDRINUSE) - ErrBadLocalAddress = New((&tcpip.ErrBadLocalAddress{}).String(), linux.EADDRNOTAVAIL) - ErrClosedForSend = New((&tcpip.ErrClosedForSend{}).String(), linux.EPIPE) - ErrClosedForReceive = New((&tcpip.ErrClosedForReceive{}).String(), linux.NOERRNO) - ErrTimeout = New((&tcpip.ErrTimeout{}).String(), linux.ETIMEDOUT) - ErrAborted = New((&tcpip.ErrAborted{}).String(), linux.EPIPE) - ErrConnectStarted = New((&tcpip.ErrConnectStarted{}).String(), linux.EINPROGRESS) - ErrDestinationRequired = New((&tcpip.ErrDestinationRequired{}).String(), linux.EDESTADDRREQ) - ErrNotSupported = New((&tcpip.ErrNotSupported{}).String(), linux.EOPNOTSUPP) - ErrQueueSizeNotSupported = New((&tcpip.ErrQueueSizeNotSupported{}).String(), linux.ENOTTY) - ErrNoSuchFile = New((&tcpip.ErrNoSuchFile{}).String(), linux.ENOENT) - ErrInvalidOptionValue = New((&tcpip.ErrInvalidOptionValue{}).String(), linux.EINVAL) - ErrBroadcastDisabled = 
New((&tcpip.ErrBroadcastDisabled{}).String(), linux.EACCES) - ErrNotPermittedNet = New((&tcpip.ErrNotPermitted{}).String(), linux.EPERM) - ErrBadBuffer = New((&tcpip.ErrBadBuffer{}).String(), linux.EFAULT) - ErrMalformedHeader = New((&tcpip.ErrMalformedHeader{}).String(), linux.EINVAL) - ErrInvalidPortRange = New((&tcpip.ErrInvalidPortRange{}).String(), linux.EINVAL) + ErrUnknownProtocol = New((&tcpip.ErrUnknownProtocol{}).String(), errno.EINVAL) + ErrUnknownNICID = New((&tcpip.ErrUnknownNICID{}).String(), errno.ENODEV) + ErrUnknownDevice = New((&tcpip.ErrUnknownDevice{}).String(), errno.ENODEV) + ErrUnknownProtocolOption = New((&tcpip.ErrUnknownProtocolOption{}).String(), errno.ENOPROTOOPT) + ErrDuplicateNICID = New((&tcpip.ErrDuplicateNICID{}).String(), errno.EEXIST) + ErrDuplicateAddress = New((&tcpip.ErrDuplicateAddress{}).String(), errno.EEXIST) + ErrAlreadyBound = New((&tcpip.ErrAlreadyBound{}).String(), errno.EINVAL) + ErrInvalidEndpointState = New((&tcpip.ErrInvalidEndpointState{}).String(), errno.EINVAL) + ErrAlreadyConnecting = New((&tcpip.ErrAlreadyConnecting{}).String(), errno.EALREADY) + ErrNoPortAvailable = New((&tcpip.ErrNoPortAvailable{}).String(), errno.EAGAIN) + ErrPortInUse = New((&tcpip.ErrPortInUse{}).String(), errno.EADDRINUSE) + ErrBadLocalAddress = New((&tcpip.ErrBadLocalAddress{}).String(), errno.EADDRNOTAVAIL) + ErrClosedForSend = New((&tcpip.ErrClosedForSend{}).String(), errno.EPIPE) + ErrClosedForReceive = New((&tcpip.ErrClosedForReceive{}).String(), errno.NOERRNO) + ErrTimeout = New((&tcpip.ErrTimeout{}).String(), errno.ETIMEDOUT) + ErrAborted = New((&tcpip.ErrAborted{}).String(), errno.EPIPE) + ErrConnectStarted = New((&tcpip.ErrConnectStarted{}).String(), errno.EINPROGRESS) + ErrDestinationRequired = New((&tcpip.ErrDestinationRequired{}).String(), errno.EDESTADDRREQ) + ErrNotSupported = New((&tcpip.ErrNotSupported{}).String(), errno.EOPNOTSUPP) + ErrQueueSizeNotSupported = New((&tcpip.ErrQueueSizeNotSupported{}).String(), errno.ENOTTY) + ErrNoSuchFile = New((&tcpip.ErrNoSuchFile{}).String(), errno.ENOENT) + ErrInvalidOptionValue = New((&tcpip.ErrInvalidOptionValue{}).String(), errno.EINVAL) + ErrBroadcastDisabled = New((&tcpip.ErrBroadcastDisabled{}).String(), errno.EACCES) + ErrNotPermittedNet = New((&tcpip.ErrNotPermitted{}).String(), errno.EPERM) + ErrBadBuffer = New((&tcpip.ErrBadBuffer{}).String(), errno.EFAULT) + ErrMalformedHeader = New((&tcpip.ErrMalformedHeader{}).String(), errno.EINVAL) + ErrInvalidPortRange = New((&tcpip.ErrInvalidPortRange{}).String(), errno.EINVAL) ) // TranslateNetstackError converts an error from the tcpip package to a sentry diff --git a/pkg/syserr/syserr.go b/pkg/syserr/syserr.go index d70521f32..765128ff0 100644 --- a/pkg/syserr/syserr.go +++ b/pkg/syserr/syserr.go @@ -21,7 +21,9 @@ import ( "fmt" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/abi/linux/errno" + "gvisor.dev/gvisor/pkg/errors" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/syserror" ) @@ -31,34 +33,33 @@ type Error struct { message string // noTranslation indicates that this Error cannot be translated to a - // linux.Errno. + // errno.Errno. noTranslation bool - // errno is the linux.Errno this Error should be translated to. - errno linux.Errno + // errno is the errno.Errno this Error should be translated to. + errno errno.Errno } // New creates a new Error and adds a translation for it. // // New must only be called at init. 
-func New(message string, linuxTranslation linux.Errno) *Error { +func New(message string, linuxTranslation errno.Errno) *Error { err := &Error{message: message, errno: linuxTranslation} // TODO(b/34162363): Remove this. - errno := linuxTranslation - if errno < 0 || int(errno) >= len(linuxBackwardsTranslations) { - panic(fmt.Sprint("invalid errno: ", errno)) + if int(err.errno) >= len(linuxBackwardsTranslations) { + panic(fmt.Sprint("invalid errno: ", err.errno)) } - e := error(unix.Errno(errno)) + e := error(unix.Errno(err.errno)) // syserror.ErrWouldBlock gets translated to syserror.EWOULDBLOCK and // enables proper blocking semantics. This should temporary address the // class of blocking bugs that keep popping up with the current state of // the error space. - if e == syserror.EWOULDBLOCK { + if err.errno == linuxerr.EWOULDBLOCK.Errno() { e = syserror.ErrWouldBlock } - linuxBackwardsTranslations[errno] = linuxBackwardsTranslation{err: e, ok: true} + linuxBackwardsTranslations[err.errno] = linuxBackwardsTranslation{err: e, ok: true} return err } @@ -69,7 +70,7 @@ func New(message string, linuxTranslation linux.Errno) *Error { // NewDynamic should only be used sparingly and not be used for static error // messages. Errors with static error messages should be declared with New as // global variables. -func NewDynamic(message string, linuxTranslation linux.Errno) *Error { +func NewDynamic(message string, linuxTranslation errno.Errno) *Error { return &Error{message: message, errno: linuxTranslation} } @@ -82,7 +83,7 @@ func NewWithoutTranslation(message string) *Error { return &Error{message: message, noTranslation: true} } -func newWithHost(message string, linuxTranslation linux.Errno, hostErrno unix.Errno) *Error { +func newWithHost(message string, linuxTranslation errno.Errno, hostErrno unix.Errno) *Error { e := New(message, linuxTranslation) addLinuxHostTranslation(hostErrno, e) return e @@ -114,19 +115,19 @@ func (e *Error) ToError() error { if e.noTranslation { panic(fmt.Sprintf("error %q does not support translation", e.message)) } - errno := int(e.errno) - if errno == linux.NOERRNO { + err := int(e.errno) + if err == errno.NOERRNO { return nil } - if errno <= 0 || errno >= len(linuxBackwardsTranslations) || !linuxBackwardsTranslations[errno].ok { - panic(fmt.Sprintf("unknown error %q (%d)", e.message, errno)) + if err >= len(linuxBackwardsTranslations) || !linuxBackwardsTranslations[err].ok { + panic(fmt.Sprintf("unknown error %q (%d)", e.message, err)) } - return linuxBackwardsTranslations[errno].err + return linuxBackwardsTranslations[err].err } // ToLinux converts the Error to a Linux ABI error that can be returned to the // application. -func (e *Error) ToLinux() linux.Errno { +func (e *Error) ToLinux() errno.Errno { if e.noTranslation { panic(fmt.Sprintf("No Linux ABI translation available for %q", e.message)) } @@ -138,137 +139,137 @@ func (e *Error) ToLinux() linux.Errno { // Some of the errors should be replaced with package specific errors and // others should be removed entirely. 
var ( - ErrNotPermitted = newWithHost("operation not permitted", linux.EPERM, unix.EPERM) - ErrNoFileOrDir = newWithHost("no such file or directory", linux.ENOENT, unix.ENOENT) - ErrNoProcess = newWithHost("no such process", linux.ESRCH, unix.ESRCH) - ErrInterrupted = newWithHost("interrupted system call", linux.EINTR, unix.EINTR) - ErrIO = newWithHost("I/O error", linux.EIO, unix.EIO) - ErrDeviceOrAddress = newWithHost("no such device or address", linux.ENXIO, unix.ENXIO) - ErrTooManyArgs = newWithHost("argument list too long", linux.E2BIG, unix.E2BIG) - ErrEcec = newWithHost("exec format error", linux.ENOEXEC, unix.ENOEXEC) - ErrBadFD = newWithHost("bad file number", linux.EBADF, unix.EBADF) - ErrNoChild = newWithHost("no child processes", linux.ECHILD, unix.ECHILD) - ErrTryAgain = newWithHost("try again", linux.EAGAIN, unix.EAGAIN) - ErrNoMemory = newWithHost("out of memory", linux.ENOMEM, unix.ENOMEM) - ErrPermissionDenied = newWithHost("permission denied", linux.EACCES, unix.EACCES) - ErrBadAddress = newWithHost("bad address", linux.EFAULT, unix.EFAULT) - ErrNotBlockDevice = newWithHost("block device required", linux.ENOTBLK, unix.ENOTBLK) - ErrBusy = newWithHost("device or resource busy", linux.EBUSY, unix.EBUSY) - ErrExists = newWithHost("file exists", linux.EEXIST, unix.EEXIST) - ErrCrossDeviceLink = newWithHost("cross-device link", linux.EXDEV, unix.EXDEV) - ErrNoDevice = newWithHost("no such device", linux.ENODEV, unix.ENODEV) - ErrNotDir = newWithHost("not a directory", linux.ENOTDIR, unix.ENOTDIR) - ErrIsDir = newWithHost("is a directory", linux.EISDIR, unix.EISDIR) - ErrInvalidArgument = newWithHost("invalid argument", linux.EINVAL, unix.EINVAL) - ErrFileTableOverflow = newWithHost("file table overflow", linux.ENFILE, unix.ENFILE) - ErrTooManyOpenFiles = newWithHost("too many open files", linux.EMFILE, unix.EMFILE) - ErrNotTTY = newWithHost("not a typewriter", linux.ENOTTY, unix.ENOTTY) - ErrTestFileBusy = newWithHost("text file busy", linux.ETXTBSY, unix.ETXTBSY) - ErrFileTooBig = newWithHost("file too large", linux.EFBIG, unix.EFBIG) - ErrNoSpace = newWithHost("no space left on device", linux.ENOSPC, unix.ENOSPC) - ErrIllegalSeek = newWithHost("illegal seek", linux.ESPIPE, unix.ESPIPE) - ErrReadOnlyFS = newWithHost("read-only file system", linux.EROFS, unix.EROFS) - ErrTooManyLinks = newWithHost("too many links", linux.EMLINK, unix.EMLINK) - ErrBrokenPipe = newWithHost("broken pipe", linux.EPIPE, unix.EPIPE) - ErrDomain = newWithHost("math argument out of domain of func", linux.EDOM, unix.EDOM) - ErrRange = newWithHost("math result not representable", linux.ERANGE, unix.ERANGE) - ErrDeadlock = newWithHost("resource deadlock would occur", linux.EDEADLOCK, unix.EDEADLOCK) - ErrNameTooLong = newWithHost("file name too long", linux.ENAMETOOLONG, unix.ENAMETOOLONG) - ErrNoLocksAvailable = newWithHost("no record locks available", linux.ENOLCK, unix.ENOLCK) - ErrInvalidSyscall = newWithHost("invalid system call number", linux.ENOSYS, unix.ENOSYS) - ErrDirNotEmpty = newWithHost("directory not empty", linux.ENOTEMPTY, unix.ENOTEMPTY) - ErrLinkLoop = newWithHost("too many symbolic links encountered", linux.ELOOP, unix.ELOOP) - ErrNoMessage = newWithHost("no message of desired type", linux.ENOMSG, unix.ENOMSG) - ErrIdentifierRemoved = newWithHost("identifier removed", linux.EIDRM, unix.EIDRM) - ErrChannelOutOfRange = newWithHost("channel number out of range", linux.ECHRNG, unix.ECHRNG) - ErrLevelTwoNotSynced = newWithHost("level 2 not synchronized", linux.EL2NSYNC, unix.EL2NSYNC) - 
ErrLevelThreeHalted = newWithHost("level 3 halted", linux.EL3HLT, unix.EL3HLT) - ErrLevelThreeReset = newWithHost("level 3 reset", linux.EL3RST, unix.EL3RST) - ErrLinkNumberOutOfRange = newWithHost("link number out of range", linux.ELNRNG, unix.ELNRNG) - ErrProtocolDriverNotAttached = newWithHost("protocol driver not attached", linux.EUNATCH, unix.EUNATCH) - ErrNoCSIAvailable = newWithHost("no CSI structure available", linux.ENOCSI, unix.ENOCSI) - ErrLevelTwoHalted = newWithHost("level 2 halted", linux.EL2HLT, unix.EL2HLT) - ErrInvalidExchange = newWithHost("invalid exchange", linux.EBADE, unix.EBADE) - ErrInvalidRequestDescriptor = newWithHost("invalid request descriptor", linux.EBADR, unix.EBADR) - ErrExchangeFull = newWithHost("exchange full", linux.EXFULL, unix.EXFULL) - ErrNoAnode = newWithHost("no anode", linux.ENOANO, unix.ENOANO) - ErrInvalidRequestCode = newWithHost("invalid request code", linux.EBADRQC, unix.EBADRQC) - ErrInvalidSlot = newWithHost("invalid slot", linux.EBADSLT, unix.EBADSLT) - ErrBadFontFile = newWithHost("bad font file format", linux.EBFONT, unix.EBFONT) - ErrNotStream = newWithHost("device not a stream", linux.ENOSTR, unix.ENOSTR) - ErrNoDataAvailable = newWithHost("no data available", linux.ENODATA, unix.ENODATA) - ErrTimerExpired = newWithHost("timer expired", linux.ETIME, unix.ETIME) - ErrStreamsResourceDepleted = newWithHost("out of streams resources", linux.ENOSR, unix.ENOSR) - ErrMachineNotOnNetwork = newWithHost("machine is not on the network", linux.ENONET, unix.ENONET) - ErrPackageNotInstalled = newWithHost("package not installed", linux.ENOPKG, unix.ENOPKG) - ErrIsRemote = newWithHost("object is remote", linux.EREMOTE, unix.EREMOTE) - ErrNoLink = newWithHost("link has been severed", linux.ENOLINK, unix.ENOLINK) - ErrAdvertise = newWithHost("advertise error", linux.EADV, unix.EADV) - ErrSRMount = newWithHost("srmount error", linux.ESRMNT, unix.ESRMNT) - ErrSendCommunication = newWithHost("communication error on send", linux.ECOMM, unix.ECOMM) - ErrProtocol = newWithHost("protocol error", linux.EPROTO, unix.EPROTO) - ErrMultihopAttempted = newWithHost("multihop attempted", linux.EMULTIHOP, unix.EMULTIHOP) - ErrRFS = newWithHost("RFS specific error", linux.EDOTDOT, unix.EDOTDOT) - ErrInvalidDataMessage = newWithHost("not a data message", linux.EBADMSG, unix.EBADMSG) - ErrOverflow = newWithHost("value too large for defined data type", linux.EOVERFLOW, unix.EOVERFLOW) - ErrNetworkNameNotUnique = newWithHost("name not unique on network", linux.ENOTUNIQ, unix.ENOTUNIQ) - ErrFDInBadState = newWithHost("file descriptor in bad state", linux.EBADFD, unix.EBADFD) - ErrRemoteAddressChanged = newWithHost("remote address changed", linux.EREMCHG, unix.EREMCHG) - ErrSharedLibraryInaccessible = newWithHost("can not access a needed shared library", linux.ELIBACC, unix.ELIBACC) - ErrCorruptedSharedLibrary = newWithHost("accessing a corrupted shared library", linux.ELIBBAD, unix.ELIBBAD) - ErrLibSectionCorrupted = newWithHost(".lib section in a.out corrupted", linux.ELIBSCN, unix.ELIBSCN) - ErrTooManySharedLibraries = newWithHost("attempting to link in too many shared libraries", linux.ELIBMAX, unix.ELIBMAX) - ErrSharedLibraryExeced = newWithHost("cannot exec a shared library directly", linux.ELIBEXEC, unix.ELIBEXEC) - ErrIllegalByteSequence = newWithHost("illegal byte sequence", linux.EILSEQ, unix.EILSEQ) - ErrShouldRestart = newWithHost("interrupted system call should be restarted", linux.ERESTART, unix.ERESTART) - ErrStreamPipe = newWithHost("streams pipe error", 
linux.ESTRPIPE, unix.ESTRPIPE) - ErrTooManyUsers = newWithHost("too many users", linux.EUSERS, unix.EUSERS) - ErrNotASocket = newWithHost("socket operation on non-socket", linux.ENOTSOCK, unix.ENOTSOCK) - ErrDestinationAddressRequired = newWithHost("destination address required", linux.EDESTADDRREQ, unix.EDESTADDRREQ) - ErrMessageTooLong = newWithHost("message too long", linux.EMSGSIZE, unix.EMSGSIZE) - ErrWrongProtocolForSocket = newWithHost("protocol wrong type for socket", linux.EPROTOTYPE, unix.EPROTOTYPE) - ErrProtocolNotAvailable = newWithHost("protocol not available", linux.ENOPROTOOPT, unix.ENOPROTOOPT) - ErrProtocolNotSupported = newWithHost("protocol not supported", linux.EPROTONOSUPPORT, unix.EPROTONOSUPPORT) - ErrSocketNotSupported = newWithHost("socket type not supported", linux.ESOCKTNOSUPPORT, unix.ESOCKTNOSUPPORT) - ErrEndpointOperation = newWithHost("operation not supported on transport endpoint", linux.EOPNOTSUPP, unix.EOPNOTSUPP) - ErrProtocolFamilyNotSupported = newWithHost("protocol family not supported", linux.EPFNOSUPPORT, unix.EPFNOSUPPORT) - ErrAddressFamilyNotSupported = newWithHost("address family not supported by protocol", linux.EAFNOSUPPORT, unix.EAFNOSUPPORT) - ErrAddressInUse = newWithHost("address already in use", linux.EADDRINUSE, unix.EADDRINUSE) - ErrAddressNotAvailable = newWithHost("cannot assign requested address", linux.EADDRNOTAVAIL, unix.EADDRNOTAVAIL) - ErrNetworkDown = newWithHost("network is down", linux.ENETDOWN, unix.ENETDOWN) - ErrNetworkUnreachable = newWithHost("network is unreachable", linux.ENETUNREACH, unix.ENETUNREACH) - ErrNetworkReset = newWithHost("network dropped connection because of reset", linux.ENETRESET, unix.ENETRESET) - ErrConnectionAborted = newWithHost("software caused connection abort", linux.ECONNABORTED, unix.ECONNABORTED) - ErrConnectionReset = newWithHost("connection reset by peer", linux.ECONNRESET, unix.ECONNRESET) - ErrNoBufferSpace = newWithHost("no buffer space available", linux.ENOBUFS, unix.ENOBUFS) - ErrAlreadyConnected = newWithHost("transport endpoint is already connected", linux.EISCONN, unix.EISCONN) - ErrNotConnected = newWithHost("transport endpoint is not connected", linux.ENOTCONN, unix.ENOTCONN) - ErrShutdown = newWithHost("cannot send after transport endpoint shutdown", linux.ESHUTDOWN, unix.ESHUTDOWN) - ErrTooManyRefs = newWithHost("too many references: cannot splice", linux.ETOOMANYREFS, unix.ETOOMANYREFS) - ErrTimedOut = newWithHost("connection timed out", linux.ETIMEDOUT, unix.ETIMEDOUT) - ErrConnectionRefused = newWithHost("connection refused", linux.ECONNREFUSED, unix.ECONNREFUSED) - ErrHostDown = newWithHost("host is down", linux.EHOSTDOWN, unix.EHOSTDOWN) - ErrNoRoute = newWithHost("no route to host", linux.EHOSTUNREACH, unix.EHOSTUNREACH) - ErrAlreadyInProgress = newWithHost("operation already in progress", linux.EALREADY, unix.EALREADY) - ErrInProgress = newWithHost("operation now in progress", linux.EINPROGRESS, unix.EINPROGRESS) - ErrStaleFileHandle = newWithHost("stale file handle", linux.ESTALE, unix.ESTALE) - ErrStructureNeedsCleaning = newWithHost("structure needs cleaning", linux.EUCLEAN, unix.EUCLEAN) - ErrIsNamedFile = newWithHost("is a named type file", linux.ENOTNAM, unix.ENOTNAM) - ErrRemoteIO = newWithHost("remote I/O error", linux.EREMOTEIO, unix.EREMOTEIO) - ErrQuotaExceeded = newWithHost("quota exceeded", linux.EDQUOT, unix.EDQUOT) - ErrNoMedium = newWithHost("no medium found", linux.ENOMEDIUM, unix.ENOMEDIUM) - ErrWrongMediumType = newWithHost("wrong medium type", 
linux.EMEDIUMTYPE, unix.EMEDIUMTYPE) - ErrCanceled = newWithHost("operation canceled", linux.ECANCELED, unix.ECANCELED) - ErrNoKey = newWithHost("required key not available", linux.ENOKEY, unix.ENOKEY) - ErrKeyExpired = newWithHost("key has expired", linux.EKEYEXPIRED, unix.EKEYEXPIRED) - ErrKeyRevoked = newWithHost("key has been revoked", linux.EKEYREVOKED, unix.EKEYREVOKED) - ErrKeyRejected = newWithHost("key was rejected by service", linux.EKEYREJECTED, unix.EKEYREJECTED) - ErrOwnerDied = newWithHost("owner died", linux.EOWNERDEAD, unix.EOWNERDEAD) - ErrNotRecoverable = newWithHost("state not recoverable", linux.ENOTRECOVERABLE, unix.ENOTRECOVERABLE) + ErrNotPermitted = newWithHost("operation not permitted", errno.EPERM, unix.EPERM) + ErrNoFileOrDir = newWithHost("no such file or directory", errno.ENOENT, unix.ENOENT) + ErrNoProcess = newWithHost("no such process", errno.ESRCH, unix.ESRCH) + ErrInterrupted = newWithHost("interrupted system call", errno.EINTR, unix.EINTR) + ErrIO = newWithHost("I/O error", errno.EIO, unix.EIO) + ErrDeviceOrAddress = newWithHost("no such device or address", errno.ENXIO, unix.ENXIO) + ErrTooManyArgs = newWithHost("argument list too long", errno.E2BIG, unix.E2BIG) + ErrEcec = newWithHost("exec format error", errno.ENOEXEC, unix.ENOEXEC) + ErrBadFD = newWithHost("bad file number", errno.EBADF, unix.EBADF) + ErrNoChild = newWithHost("no child processes", errno.ECHILD, unix.ECHILD) + ErrTryAgain = newWithHost("try again", errno.EAGAIN, unix.EAGAIN) + ErrNoMemory = newWithHost("out of memory", errno.ENOMEM, unix.ENOMEM) + ErrPermissionDenied = newWithHost("permission denied", errno.EACCES, unix.EACCES) + ErrBadAddress = newWithHost("bad address", errno.EFAULT, unix.EFAULT) + ErrNotBlockDevice = newWithHost("block device required", errno.ENOTBLK, unix.ENOTBLK) + ErrBusy = newWithHost("device or resource busy", errno.EBUSY, unix.EBUSY) + ErrExists = newWithHost("file exists", errno.EEXIST, unix.EEXIST) + ErrCrossDeviceLink = newWithHost("cross-device link", errno.EXDEV, unix.EXDEV) + ErrNoDevice = newWithHost("no such device", errno.ENODEV, unix.ENODEV) + ErrNotDir = newWithHost("not a directory", errno.ENOTDIR, unix.ENOTDIR) + ErrIsDir = newWithHost("is a directory", errno.EISDIR, unix.EISDIR) + ErrInvalidArgument = newWithHost("invalid argument", errno.EINVAL, unix.EINVAL) + ErrFileTableOverflow = newWithHost("file table overflow", errno.ENFILE, unix.ENFILE) + ErrTooManyOpenFiles = newWithHost("too many open files", errno.EMFILE, unix.EMFILE) + ErrNotTTY = newWithHost("not a typewriter", errno.ENOTTY, unix.ENOTTY) + ErrTestFileBusy = newWithHost("text file busy", errno.ETXTBSY, unix.ETXTBSY) + ErrFileTooBig = newWithHost("file too large", errno.EFBIG, unix.EFBIG) + ErrNoSpace = newWithHost("no space left on device", errno.ENOSPC, unix.ENOSPC) + ErrIllegalSeek = newWithHost("illegal seek", errno.ESPIPE, unix.ESPIPE) + ErrReadOnlyFS = newWithHost("read-only file system", errno.EROFS, unix.EROFS) + ErrTooManyLinks = newWithHost("too many links", errno.EMLINK, unix.EMLINK) + ErrBrokenPipe = newWithHost("broken pipe", errno.EPIPE, unix.EPIPE) + ErrDomain = newWithHost("math argument out of domain of func", errno.EDOM, unix.EDOM) + ErrRange = newWithHost("math result not representable", errno.ERANGE, unix.ERANGE) + ErrDeadlock = newWithHost("resource deadlock would occur", errno.EDEADLOCK, unix.EDEADLOCK) + ErrNameTooLong = newWithHost("file name too long", errno.ENAMETOOLONG, unix.ENAMETOOLONG) + ErrNoLocksAvailable = newWithHost("no record locks available", 
errno.ENOLCK, unix.ENOLCK) + ErrInvalidSyscall = newWithHost("invalid system call number", errno.ENOSYS, unix.ENOSYS) + ErrDirNotEmpty = newWithHost("directory not empty", errno.ENOTEMPTY, unix.ENOTEMPTY) + ErrLinkLoop = newWithHost("too many symbolic links encountered", errno.ELOOP, unix.ELOOP) + ErrNoMessage = newWithHost("no message of desired type", errno.ENOMSG, unix.ENOMSG) + ErrIdentifierRemoved = newWithHost("identifier removed", errno.EIDRM, unix.EIDRM) + ErrChannelOutOfRange = newWithHost("channel number out of range", errno.ECHRNG, unix.ECHRNG) + ErrLevelTwoNotSynced = newWithHost("level 2 not synchronized", errno.EL2NSYNC, unix.EL2NSYNC) + ErrLevelThreeHalted = newWithHost("level 3 halted", errno.EL3HLT, unix.EL3HLT) + ErrLevelThreeReset = newWithHost("level 3 reset", errno.EL3RST, unix.EL3RST) + ErrLinkNumberOutOfRange = newWithHost("link number out of range", errno.ELNRNG, unix.ELNRNG) + ErrProtocolDriverNotAttached = newWithHost("protocol driver not attached", errno.EUNATCH, unix.EUNATCH) + ErrNoCSIAvailable = newWithHost("no CSI structure available", errno.ENOCSI, unix.ENOCSI) + ErrLevelTwoHalted = newWithHost("level 2 halted", errno.EL2HLT, unix.EL2HLT) + ErrInvalidExchange = newWithHost("invalid exchange", errno.EBADE, unix.EBADE) + ErrInvalidRequestDescriptor = newWithHost("invalid request descriptor", errno.EBADR, unix.EBADR) + ErrExchangeFull = newWithHost("exchange full", errno.EXFULL, unix.EXFULL) + ErrNoAnode = newWithHost("no anode", errno.ENOANO, unix.ENOANO) + ErrInvalidRequestCode = newWithHost("invalid request code", errno.EBADRQC, unix.EBADRQC) + ErrInvalidSlot = newWithHost("invalid slot", errno.EBADSLT, unix.EBADSLT) + ErrBadFontFile = newWithHost("bad font file format", errno.EBFONT, unix.EBFONT) + ErrNotStream = newWithHost("device not a stream", errno.ENOSTR, unix.ENOSTR) + ErrNoDataAvailable = newWithHost("no data available", errno.ENODATA, unix.ENODATA) + ErrTimerExpired = newWithHost("timer expired", errno.ETIME, unix.ETIME) + ErrStreamsResourceDepleted = newWithHost("out of streams resources", errno.ENOSR, unix.ENOSR) + ErrMachineNotOnNetwork = newWithHost("machine is not on the network", errno.ENONET, unix.ENONET) + ErrPackageNotInstalled = newWithHost("package not installed", errno.ENOPKG, unix.ENOPKG) + ErrIsRemote = newWithHost("object is remote", errno.EREMOTE, unix.EREMOTE) + ErrNoLink = newWithHost("link has been severed", errno.ENOLINK, unix.ENOLINK) + ErrAdvertise = newWithHost("advertise error", errno.EADV, unix.EADV) + ErrSRMount = newWithHost("srmount error", errno.ESRMNT, unix.ESRMNT) + ErrSendCommunication = newWithHost("communication error on send", errno.ECOMM, unix.ECOMM) + ErrProtocol = newWithHost("protocol error", errno.EPROTO, unix.EPROTO) + ErrMultihopAttempted = newWithHost("multihop attempted", errno.EMULTIHOP, unix.EMULTIHOP) + ErrRFS = newWithHost("RFS specific error", errno.EDOTDOT, unix.EDOTDOT) + ErrInvalidDataMessage = newWithHost("not a data message", errno.EBADMSG, unix.EBADMSG) + ErrOverflow = newWithHost("value too large for defined data type", errno.EOVERFLOW, unix.EOVERFLOW) + ErrNetworkNameNotUnique = newWithHost("name not unique on network", errno.ENOTUNIQ, unix.ENOTUNIQ) + ErrFDInBadState = newWithHost("file descriptor in bad state", errno.EBADFD, unix.EBADFD) + ErrRemoteAddressChanged = newWithHost("remote address changed", errno.EREMCHG, unix.EREMCHG) + ErrSharedLibraryInaccessible = newWithHost("can not access a needed shared library", errno.ELIBACC, unix.ELIBACC) + ErrCorruptedSharedLibrary = 
newWithHost("accessing a corrupted shared library", errno.ELIBBAD, unix.ELIBBAD) + ErrLibSectionCorrupted = newWithHost(".lib section in a.out corrupted", errno.ELIBSCN, unix.ELIBSCN) + ErrTooManySharedLibraries = newWithHost("attempting to link in too many shared libraries", errno.ELIBMAX, unix.ELIBMAX) + ErrSharedLibraryExeced = newWithHost("cannot exec a shared library directly", errno.ELIBEXEC, unix.ELIBEXEC) + ErrIllegalByteSequence = newWithHost("illegal byte sequence", errno.EILSEQ, unix.EILSEQ) + ErrShouldRestart = newWithHost("interrupted system call should be restarted", errno.ERESTART, unix.ERESTART) + ErrStreamPipe = newWithHost("streams pipe error", errno.ESTRPIPE, unix.ESTRPIPE) + ErrTooManyUsers = newWithHost("too many users", errno.EUSERS, unix.EUSERS) + ErrNotASocket = newWithHost("socket operation on non-socket", errno.ENOTSOCK, unix.ENOTSOCK) + ErrDestinationAddressRequired = newWithHost("destination address required", errno.EDESTADDRREQ, unix.EDESTADDRREQ) + ErrMessageTooLong = newWithHost("message too long", errno.EMSGSIZE, unix.EMSGSIZE) + ErrWrongProtocolForSocket = newWithHost("protocol wrong type for socket", errno.EPROTOTYPE, unix.EPROTOTYPE) + ErrProtocolNotAvailable = newWithHost("protocol not available", errno.ENOPROTOOPT, unix.ENOPROTOOPT) + ErrProtocolNotSupported = newWithHost("protocol not supported", errno.EPROTONOSUPPORT, unix.EPROTONOSUPPORT) + ErrSocketNotSupported = newWithHost("socket type not supported", errno.ESOCKTNOSUPPORT, unix.ESOCKTNOSUPPORT) + ErrEndpointOperation = newWithHost("operation not supported on transport endpoint", errno.EOPNOTSUPP, unix.EOPNOTSUPP) + ErrProtocolFamilyNotSupported = newWithHost("protocol family not supported", errno.EPFNOSUPPORT, unix.EPFNOSUPPORT) + ErrAddressFamilyNotSupported = newWithHost("address family not supported by protocol", errno.EAFNOSUPPORT, unix.EAFNOSUPPORT) + ErrAddressInUse = newWithHost("address already in use", errno.EADDRINUSE, unix.EADDRINUSE) + ErrAddressNotAvailable = newWithHost("cannot assign requested address", errno.EADDRNOTAVAIL, unix.EADDRNOTAVAIL) + ErrNetworkDown = newWithHost("network is down", errno.ENETDOWN, unix.ENETDOWN) + ErrNetworkUnreachable = newWithHost("network is unreachable", errno.ENETUNREACH, unix.ENETUNREACH) + ErrNetworkReset = newWithHost("network dropped connection because of reset", errno.ENETRESET, unix.ENETRESET) + ErrConnectionAborted = newWithHost("software caused connection abort", errno.ECONNABORTED, unix.ECONNABORTED) + ErrConnectionReset = newWithHost("connection reset by peer", errno.ECONNRESET, unix.ECONNRESET) + ErrNoBufferSpace = newWithHost("no buffer space available", errno.ENOBUFS, unix.ENOBUFS) + ErrAlreadyConnected = newWithHost("transport endpoint is already connected", errno.EISCONN, unix.EISCONN) + ErrNotConnected = newWithHost("transport endpoint is not connected", errno.ENOTCONN, unix.ENOTCONN) + ErrShutdown = newWithHost("cannot send after transport endpoint shutdown", errno.ESHUTDOWN, unix.ESHUTDOWN) + ErrTooManyRefs = newWithHost("too many references: cannot splice", errno.ETOOMANYREFS, unix.ETOOMANYREFS) + ErrTimedOut = newWithHost("connection timed out", errno.ETIMEDOUT, unix.ETIMEDOUT) + ErrConnectionRefused = newWithHost("connection refused", errno.ECONNREFUSED, unix.ECONNREFUSED) + ErrHostDown = newWithHost("host is down", errno.EHOSTDOWN, unix.EHOSTDOWN) + ErrNoRoute = newWithHost("no route to host", errno.EHOSTUNREACH, unix.EHOSTUNREACH) + ErrAlreadyInProgress = newWithHost("operation already in progress", errno.EALREADY, 
unix.EALREADY) + ErrInProgress = newWithHost("operation now in progress", errno.EINPROGRESS, unix.EINPROGRESS) + ErrStaleFileHandle = newWithHost("stale file handle", errno.ESTALE, unix.ESTALE) + ErrStructureNeedsCleaning = newWithHost("structure needs cleaning", errno.EUCLEAN, unix.EUCLEAN) + ErrIsNamedFile = newWithHost("is a named type file", errno.ENOTNAM, unix.ENOTNAM) + ErrRemoteIO = newWithHost("remote I/O error", errno.EREMOTEIO, unix.EREMOTEIO) + ErrQuotaExceeded = newWithHost("quota exceeded", errno.EDQUOT, unix.EDQUOT) + ErrNoMedium = newWithHost("no medium found", errno.ENOMEDIUM, unix.ENOMEDIUM) + ErrWrongMediumType = newWithHost("wrong medium type", errno.EMEDIUMTYPE, unix.EMEDIUMTYPE) + ErrCanceled = newWithHost("operation canceled", errno.ECANCELED, unix.ECANCELED) + ErrNoKey = newWithHost("required key not available", errno.ENOKEY, unix.ENOKEY) + ErrKeyExpired = newWithHost("key has expired", errno.EKEYEXPIRED, unix.EKEYEXPIRED) + ErrKeyRevoked = newWithHost("key has been revoked", errno.EKEYREVOKED, unix.EKEYREVOKED) + ErrKeyRejected = newWithHost("key was rejected by service", errno.EKEYREJECTED, unix.EKEYREJECTED) + ErrOwnerDied = newWithHost("owner died", errno.EOWNERDEAD, unix.EOWNERDEAD) + ErrNotRecoverable = newWithHost("state not recoverable", errno.ENOTRECOVERABLE, unix.ENOTRECOVERABLE) // ErrWouldBlock translates to EWOULDBLOCK which is the same as EAGAIN // on Linux. - ErrWouldBlock = New("operation would block", linux.EWOULDBLOCK) + ErrWouldBlock = New("operation would block", errno.EWOULDBLOCK) ) // FromError converts a generic error to an *Error. @@ -281,6 +282,11 @@ func FromError(err error) *Error { if errno, ok := err.(unix.Errno); ok { return FromHost(errno) } + + if linuxErr, ok := err.(*errors.Error); ok { + return FromHost(unix.Errno(linuxErr.Errno())) + } + if errno, ok := syserror.TranslateError(err); ok { return FromHost(errno) } diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD index 7d2f5adf6..76bee5a64 100644 --- a/pkg/syserror/BUILD +++ b/pkg/syserror/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,12 +8,3 @@ go_library( visibility = ["//visibility:public"], deps = ["@org_golang_x_sys//unix:go_default_library"], ) - -go_test( - name = "syserror_test", - srcs = ["syserror_test.go"], - deps = [ - ":syserror", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go index 56b621357..721f62903 100644 --- a/pkg/syserror/syserror.go +++ b/pkg/syserror/syserror.go @@ -26,9 +26,7 @@ import ( // The following variables have the same meaning as their syscall equivalent. var ( - E2BIG = error(unix.E2BIG) EACCES = error(unix.EACCES) - EADDRINUSE = error(unix.EADDRINUSE) EAGAIN = error(unix.EAGAIN) EBADF = error(unix.EBADF) EBADFD = error(unix.EBADFD) @@ -43,7 +41,6 @@ var ( EFBIG = error(unix.EFBIG) EIDRM = error(unix.EIDRM) EINTR = error(unix.EINTR) - EINVAL = error(unix.EINVAL) EIO = error(unix.EIO) EISDIR = error(unix.EISDIR) ELIBBAD = error(unix.ELIBBAD) diff --git a/pkg/syserror/syserror_test.go b/pkg/syserror/syserror_test.go deleted file mode 100644 index c141e5f6e..000000000 --- a/pkg/syserror/syserror_test.go +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package syserror_test - -import ( - "errors" - "testing" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/syserror" -) - -var globalError error - -func BenchmarkAssignErrno(b *testing.B) { - for i := b.N; i > 0; i-- { - globalError = unix.EINVAL - } -} - -func BenchmarkAssignError(b *testing.B) { - for i := b.N; i > 0; i-- { - globalError = syserror.EINVAL - } -} - -func BenchmarkCompareErrno(b *testing.B) { - globalError = unix.EAGAIN - j := 0 - for i := b.N; i > 0; i-- { - if globalError == unix.EINVAL { - j++ - } - } -} - -func BenchmarkCompareError(b *testing.B) { - globalError = syserror.EAGAIN - j := 0 - for i := b.N; i > 0; i-- { - if globalError == syserror.EINVAL { - j++ - } - } -} - -func BenchmarkSwitchErrno(b *testing.B) { - globalError = unix.EPERM - j := 0 - for i := b.N; i > 0; i-- { - switch globalError { - case unix.EINVAL: - j += 1 - case unix.EINTR: - j += 2 - case unix.EAGAIN: - j += 3 - } - } -} - -func BenchmarkSwitchError(b *testing.B) { - globalError = syserror.EPERM - j := 0 - for i := b.N; i > 0; i-- { - switch globalError { - case syserror.EINVAL: - j += 1 - case syserror.EINTR: - j += 2 - case syserror.EAGAIN: - j += 3 - } - } -} - -type translationTestTable struct { - fn string - errIn error - syscallErrorIn unix.Errno - expectedBool bool - expectedTranslation unix.Errno -} - -func TestErrorTranslation(t *testing.T) { - myError := errors.New("My test error") - myError2 := errors.New("Another test error") - testTable := []translationTestTable{ - {"TranslateError", myError, 0, false, 0}, - {"TranslateError", myError2, 0, false, 0}, - {"AddErrorTranslation", myError, unix.EAGAIN, true, 0}, - {"AddErrorTranslation", myError, unix.EAGAIN, false, 0}, - {"AddErrorTranslation", myError, unix.EPERM, false, 0}, - {"TranslateError", myError, 0, true, unix.EAGAIN}, - {"TranslateError", myError2, 0, false, 0}, - {"AddErrorTranslation", myError2, unix.EPERM, true, 0}, - {"AddErrorTranslation", myError2, unix.EPERM, false, 0}, - {"AddErrorTranslation", myError2, unix.EAGAIN, false, 0}, - {"TranslateError", myError, 0, true, unix.EAGAIN}, - {"TranslateError", myError2, 0, true, unix.EPERM}, - } - for _, tt := range testTable { - switch tt.fn { - case "TranslateError": - err, ok := syserror.TranslateError(tt.errIn) - if ok != tt.expectedBool { - t.Fatalf("%v(%v) => %v expected %v", tt.fn, tt.errIn, ok, tt.expectedBool) - } else if err != tt.expectedTranslation { - t.Fatalf("%v(%v) (error) => %v expected %v", tt.fn, tt.errIn, err, tt.expectedTranslation) - } - case "AddErrorTranslation": - ok := syserror.AddErrorTranslation(tt.errIn, tt.syscallErrorIn) - if ok != tt.expectedBool { - t.Fatalf("%v(%v) => %v expected %v", tt.fn, tt.errIn, ok, tt.expectedBool) - } - default: - t.Fatalf("Unknown function %v", tt.fn) - } - } -} diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index ea46c30da..f00cfd0f5 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -39,12 +39,29 @@ go_library( deps_test( name = "netstack_deps_test", allowed = [ + # gVisor deps. 
+ "//pkg/atomicbitops", + "//pkg/buffer", + "//pkg/context", + "//pkg/gohacks", + "//pkg/goid", + "//pkg/ilist", + "//pkg/linewriter", + "//pkg/log", + "//pkg/rand", + "//pkg/sleep", + "//pkg/state", + "//pkg/state/wire", + "//pkg/sync", + "//pkg/waiter", + + # Other deps. "@com_github_google_btree//:go_default_library", "@org_golang_x_sys//unix:go_default_library", "@org_golang_x_time//rate:go_default_library", ], allowed_prefixes = [ - "//", + "//pkg/tcpip", "@org_golang_x_sys//internal/unsafeheader", ], targets = [ diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 18e6cc3cd..e0dfe5813 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -50,7 +50,7 @@ func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) { ipv4 := header.IPv4(b) if !ipv4.IsValid(len(b)) { - t.Error("Not a valid IPv4 packet") + t.Fatalf("Not a valid IPv4 packet: %x", ipv4) } if !ipv4.IsChecksumValid() { @@ -72,7 +72,7 @@ func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) { ipv6 := header.IPv6(b) if !ipv6.IsValid(len(b)) { - t.Error("Not a valid IPv6 packet") + t.Fatalf("Not a valid IPv6 packet: %x", ipv6) } for _, f := range checkers { @@ -701,7 +701,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp if !ok { return } - opts := []byte(tcp.Options()) + opts := tcp.Options() limit := len(opts) foundTS := false tsVal := uint32(0) @@ -748,12 +748,6 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp } } -// TCPNoSACKBlockChecker creates a checker that verifies that the segment does -// not contain any SACK blocks in the TCP options. -func TCPNoSACKBlockChecker() TransportChecker { - return TCPSACKBlockChecker(nil) -} - // TCPSACKBlockChecker creates a checker that verifies that the segment does // contain the specified SACK blocks in the TCP options. func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { @@ -765,7 +759,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { } var gotSACKBlocks []header.SACKBlock - opts := []byte(tcp.Options()) + opts := tcp.Options() limit := len(opts) for i := 0; i < limit; { switch opts[i] { diff --git a/pkg/tcpip/faketime/faketime.go b/pkg/tcpip/faketime/faketime.go index fb819d7a8..9f8f51647 100644 --- a/pkg/tcpip/faketime/faketime.go +++ b/pkg/tcpip/faketime/faketime.go @@ -29,14 +29,14 @@ type NullClock struct{} var _ tcpip.Clock = (*NullClock)(nil) -// NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (*NullClock) NowNanoseconds() int64 { - return 0 +// Now implements tcpip.Clock.Now. +func (*NullClock) Now() time.Time { + return time.Time{} } // NowMonotonic implements tcpip.Clock.NowMonotonic. -func (*NullClock) NowMonotonic() int64 { - return 0 +func (*NullClock) NowMonotonic() tcpip.MonotonicTime { + return tcpip.MonotonicTime{} } // AfterFunc implements tcpip.Clock.AfterFunc. @@ -118,16 +118,17 @@ func NewManualClock() *ManualClock { var _ tcpip.Clock = (*ManualClock)(nil) -// NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (mc *ManualClock) NowNanoseconds() int64 { +// Now implements tcpip.Clock.Now. +func (mc *ManualClock) Now() time.Time { mc.mu.RLock() defer mc.mu.RUnlock() - return mc.mu.now.UnixNano() + return mc.mu.now } // NowMonotonic implements tcpip.Clock.NowMonotonic. 
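A minimal usage sketch (not part of this diff) of the reworked faketime clock, assuming only the NewManualClock, AfterFunc, Advance and MonotonicTime.Sub APIs that appear in this change:

package faketime_test

import (
	"testing"
	"time"

	"gvisor.dev/gvisor/pkg/tcpip/faketime"
)

func TestAdvanceFiresTimer(t *testing.T) {
	clock := faketime.NewManualClock()
	start := clock.NowMonotonic()

	fired := false
	clock.AfterFunc(time.Second, func() { fired = true })

	// Nothing runs until the fake clock is advanced explicitly; Advance
	// blocks until the scheduled work has completed.
	clock.Advance(time.Second)

	if !fired {
		t.Error("timer did not fire after Advance(time.Second)")
	}
	if got := clock.NowMonotonic().Sub(start); got != time.Second {
		t.Errorf("got elapsed = %s, want = %s", got, time.Second)
	}
}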
-func (mc *ManualClock) NowMonotonic() int64 { - return mc.NowNanoseconds() +func (mc *ManualClock) NowMonotonic() tcpip.MonotonicTime { + var mt tcpip.MonotonicTime + return mt.Add(mc.Now().Sub(time.Unix(0, 0))) } // AfterFunc implements tcpip.Clock.AfterFunc. @@ -218,6 +219,12 @@ func (mc *ManualClock) stopTimerLocked(mt *manualTimer) { } } +// RunImmediatelyScheduledJobs runs all jobs scheduled to run at the current +// time. +func (mc *ManualClock) RunImmediatelyScheduledJobs() { + mc.Advance(0) +} + // Advance executes all work that have been scheduled to execute within d from // the current time. Blocks until all work has completed execution. func (mc *ManualClock) Advance(d time.Duration) { diff --git a/pkg/tcpip/faketime/faketime_test.go b/pkg/tcpip/faketime/faketime_test.go index c2704df2c..fd2bb470a 100644 --- a/pkg/tcpip/faketime/faketime_test.go +++ b/pkg/tcpip/faketime/faketime_test.go @@ -26,7 +26,7 @@ func TestManualClockAdvance(t *testing.T) { clock := faketime.NewManualClock() start := clock.NowMonotonic() clock.Advance(timeout) - if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, timeout; got != want { + if got, want := clock.NowMonotonic().Sub(start), timeout; got != want { t.Errorf("got = %d, want = %d", got, want) } } @@ -87,7 +87,7 @@ func TestManualClockAfterFunc(t *testing.T) { if got, want := counter2, test.wantCounter2; got != want { t.Errorf("got counter2 = %d, want = %d", got, want) } - if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, test.advance; got != want { + if got, want := clock.NowMonotonic().Sub(start), test.advance; got != want { t.Errorf("got elapsed = %d, want = %d", got, want) } }) diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go index 6aa9acfa8..e2c85e220 100644 --- a/pkg/tcpip/header/checksum.go +++ b/pkg/tcpip/header/checksum.go @@ -18,6 +18,7 @@ package header import ( "encoding/binary" + "fmt" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -234,3 +235,64 @@ func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip. return Checksum([]byte{0, uint8(protocol)}, xsum) } + +// checksumUpdate2ByteAlignedUint16 updates a uint16 value in a calculated +// checksum. +// +// The value MUST begin at a 2-byte boundary in the original buffer. +func checksumUpdate2ByteAlignedUint16(xsum, old, new uint16) uint16 { + // As per RFC 1071 page 4, + // (4) Incremental Update + // + // ... + // + // To update the checksum, simply add the differences of the + // sixteen bit integers that have been changed. To see why this + // works, observe that every 16-bit integer has an additive inverse + // and that addition is associative. From this it follows that + // given the original value m, the new value m', and the old + // checksum C, the new checksum C' is: + // + // C' = C + (-m) + m' = C + (m' - m) + return ChecksumCombine(xsum, ChecksumCombine(new, ^old)) +} + +// checksumUpdate2ByteAlignedAddress updates an address in a calculated +// checksum. +// +// The addresses must have the same length and must contain an even number +// of bytes. The address MUST begin at a 2-byte boundary in the original buffer. 
+func checksumUpdate2ByteAlignedAddress(xsum uint16, old, new tcpip.Address) uint16 { + const uint16Bytes = 2 + + if len(old) != len(new) { + panic(fmt.Sprintf("buffer lengths are different; old = %d, new = %d", len(old), len(new))) + } + + if len(old)%uint16Bytes != 0 { + panic(fmt.Sprintf("buffer has an odd number of bytes; got = %d", len(old))) + } + + // As per RFC 1071 page 4, + // (4) Incremental Update + // + // ... + // + // To update the checksum, simply add the differences of the + // sixteen bit integers that have been changed. To see why this + // works, observe that every 16-bit integer has an additive inverse + // and that addition is associative. From this it follows that + // given the original value m, the new value m', and the old + // checksum C, the new checksum C' is: + // + // C' = C + (-m) + m' = C + (m' - m) + for len(old) != 0 { + // Convert the 2 byte sequences to uint16 values then apply the increment + // update. + xsum = checksumUpdate2ByteAlignedUint16(xsum, (uint16(old[0])<<8)+uint16(old[1]), (uint16(new[0])<<8)+uint16(new[1])) + old = old[uint16Bytes:] + new = new[uint16Bytes:] + } + + return xsum +} diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go index d267dabd0..3445511f4 100644 --- a/pkg/tcpip/header/checksum_test.go +++ b/pkg/tcpip/header/checksum_test.go @@ -23,6 +23,7 @@ import ( "sync" "testing" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -256,3 +257,205 @@ func TestICMPv6Checksum(t *testing.T) { }) }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) } + +func randomAddress(size int) tcpip.Address { + s := make([]byte, size) + for i := 0; i < size; i++ { + s[i] = byte(rand.Uint32()) + } + return tcpip.Address(s) +} + +func TestChecksummableNetworkUpdateAddress(t *testing.T) { + tests := []struct { + name string + update func(header.IPv4, tcpip.Address) + }{ + { + name: "SetSourceAddressWithChecksumUpdate", + update: header.IPv4.SetSourceAddressWithChecksumUpdate, + }, + { + name: "SetDestinationAddressWithChecksumUpdate", + update: header.IPv4.SetDestinationAddressWithChecksumUpdate, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for i := 0; i < 1000; i++ { + var origBytes [header.IPv4MinimumSize]byte + header.IPv4(origBytes[:]).Encode(&header.IPv4Fields{ + TOS: 1, + TotalLength: header.IPv4MinimumSize, + ID: 2, + Flags: 3, + FragmentOffset: 4, + TTL: 5, + Protocol: 6, + Checksum: 0, + SrcAddr: randomAddress(header.IPv4AddressSize), + DstAddr: randomAddress(header.IPv4AddressSize), + }) + + addr := randomAddress(header.IPv4AddressSize) + + bytesCopy := origBytes + h := header.IPv4(bytesCopy[:]) + origXSum := h.CalculateChecksum() + h.SetChecksum(^origXSum) + + test.update(h, addr) + got := ^h.Checksum() + h.SetChecksum(0) + want := h.CalculateChecksum() + if got != want { + t.Errorf("got h.Checksum() = 0x%x, want = 0x%x; originalBytes = 0x%x, new addr = %s", got, want, origBytes, addr) + } + } + }) + } +} + +func TestChecksummableTransportUpdatePort(t *testing.T) { + // The fields in the pseudo header is not tested here so we just use 0. 
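For reference, a standalone sketch of the RFC 1071 incremental-update identity (C' = C + (m' - m) in one's-complement arithmetic) that the two helpers above rely on; the sample words are arbitrary:

package main

import "fmt"

// onesComplementAdd adds two 16-bit values with end-around carry, the
// primitive behind the Internet checksum (RFC 1071).
func onesComplementAdd(a, b uint16) uint16 {
	s := uint32(a) + uint32(b)
	return uint16(s&0xffff) + uint16(s>>16)
}

func sum(words []uint16) uint16 {
	var x uint16
	for _, w := range words {
		x = onesComplementAdd(x, w)
	}
	return x
}

func main() {
	words := []uint16{0x1234, 0xabcd, 0x0f0f}
	oldWord, newWord := words[1], uint16(0x1111)

	full := sum(words) // 0xcd10

	// Incremental update: fold in the new word and the one's complement of
	// the old word instead of recomputing the whole sum.
	incremental := onesComplementAdd(full, onesComplementAdd(newWord, ^oldWord))

	words[1] = newWord
	fmt.Printf("recomputed=%#04x incremental=%#04x\n", sum(words), incremental) // both print 0x3254
}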
+ const pseudoHeaderXSum = 0 + + tests := []struct { + name string + transportHdr func(_, _ uint16) (header.ChecksummableTransport, func(uint16) uint16) + proto tcpip.TransportProtocolNumber + }{ + { + name: "TCP", + transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) { + h := header.TCP(make([]byte, header.TCPMinimumSize)) + h.Encode(&header.TCPFields{ + SrcPort: src, + DstPort: dst, + SeqNum: 1, + AckNum: 2, + DataOffset: header.TCPMinimumSize, + Flags: 3, + WindowSize: 4, + Checksum: 0, + UrgentPointer: 5, + }) + h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum)) + return h, h.CalculateChecksum + }, + proto: header.TCPProtocolNumber, + }, + { + name: "UDP", + transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) { + h := header.UDP(make([]byte, header.UDPMinimumSize)) + h.Encode(&header.UDPFields{ + SrcPort: src, + DstPort: dst, + Length: 0, + Checksum: 0, + }) + h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum)) + return h, h.CalculateChecksum + }, + proto: header.UDPProtocolNumber, + }, + } + + for i := 0; i < 1000; i++ { + origSrcPort := uint16(rand.Uint32()) + origDstPort := uint16(rand.Uint32()) + newPort := uint16(rand.Uint32()) + + t.Run(fmt.Sprintf("OrigSrcPort=%d,OrigDstPort=%d,NewPort=%d", origSrcPort, origDstPort, newPort), func(*testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range []struct { + name string + update func(header.ChecksummableTransport) + }{ + { + name: "Source port", + update: func(h header.ChecksummableTransport) { h.SetSourcePortWithChecksumUpdate(newPort) }, + }, + { + name: "Destination port", + update: func(h header.ChecksummableTransport) { h.SetDestinationPortWithChecksumUpdate(newPort) }, + }, + } { + t.Run(subTest.name, func(t *testing.T) { + h, calcXSum := test.transportHdr(origSrcPort, origDstPort) + subTest.update(h) + // TCP and UDP hold the 1s complement of the fully calculated + // checksum. + got := ^h.Checksum() + h.SetChecksum(0) + + if want := calcXSum(pseudoHeaderXSum); got != want { + h, _ := test.transportHdr(origSrcPort, origDstPort) + t.Errorf("got Checksum() = 0x%x, want = 0x%x; originalBytes = %#v, new port = %d", got, want, h, newPort) + } + }) + } + }) + } + }) + } +} + +func TestChecksummableTransportUpdatePseudoHeaderAddress(t *testing.T) { + const addressSize = 6 + + tests := []struct { + name string + transportHdr func() header.ChecksummableTransport + proto tcpip.TransportProtocolNumber + }{ + { + name: "TCP", + transportHdr: func() header.ChecksummableTransport { return header.TCP(make([]byte, header.TCPMinimumSize)) }, + proto: header.TCPProtocolNumber, + }, + { + name: "UDP", + transportHdr: func() header.ChecksummableTransport { return header.UDP(make([]byte, header.UDPMinimumSize)) }, + proto: header.UDPProtocolNumber, + }, + } + + for i := 0; i < 1000; i++ { + permanent := randomAddress(addressSize) + old := randomAddress(addressSize) + new := randomAddress(addressSize) + + t.Run(fmt.Sprintf("Permanent=%q,Old=%q,New=%q", permanent, old, new), func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, fullChecksum := range []bool{true, false} { + t.Run(fmt.Sprintf("FullChecksum=%t", fullChecksum), func(t *testing.T) { + initialXSum := header.PseudoHeaderChecksum(test.proto, permanent, old, 0) + if fullChecksum { + // TCP and UDP hold the 1s complement of the fully calculated + // checksum. 
+ initialXSum = ^initialXSum + } + + h := test.transportHdr() + h.SetChecksum(initialXSum) + h.UpdateChecksumPseudoHeaderAddress(old, new, fullChecksum) + + got := h.Checksum() + if fullChecksum { + got = ^got + } + if want := header.PseudoHeaderChecksum(test.proto, permanent, new, 0); got != want { + t.Errorf("got Checksum() = 0x%x, want = 0x%x; h = %#v", got, want, h) + } + }) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/header/interfaces.go b/pkg/tcpip/header/interfaces.go index 861cbbb70..3a41adfc4 100644 --- a/pkg/tcpip/header/interfaces.go +++ b/pkg/tcpip/header/interfaces.go @@ -53,6 +53,31 @@ type Transport interface { Payload() []byte } +// ChecksummableTransport is a Transport that supports checksumming. +type ChecksummableTransport interface { + Transport + + // SetSourcePortWithChecksumUpdate sets the source port and updates + // the checksum. + // + // The receiver's checksum must be a fully calculated checksum. + SetSourcePortWithChecksumUpdate(port uint16) + + // SetDestinationPortWithChecksumUpdate sets the destination port and updates + // the checksum. + // + // The receiver's checksum must be a fully calculated checksum. + SetDestinationPortWithChecksumUpdate(port uint16) + + // UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an + // updated address in the pseudo header. + // + // If fullChecksum is true, the receiver's checksum field is assumed to hold a + // fully calculated checksum. Otherwise, it is assumed to hold a partially + // calculated checksum which only reflects the pseudo header. + UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) +} + // Network offers generic methods to query and/or update the fields of the // header of a network protocol buffer. type Network interface { @@ -90,3 +115,16 @@ type Network interface { // SetTOS sets the values of the "type of service" and "flow label" fields. SetTOS(t uint8, l uint32) } + +// ChecksummableNetwork is a Network that supports checksumming. +type ChecksummableNetwork interface { + Network + + // SetSourceAddressAndChecksum sets the source address and updates the + // checksum to reflect the new address. + SetSourceAddressWithChecksumUpdate(tcpip.Address) + + // SetDestinationAddressAndChecksum sets the destination address and + // updates the checksum to reflect the new address. + SetDestinationAddressWithChecksumUpdate(tcpip.Address) +} diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 2be21ec75..dcc549c7b 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -17,6 +17,7 @@ package header import ( "encoding/binary" "fmt" + "time" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -181,7 +182,7 @@ const ( // ipv4LinkLocalUnicastSubnet is the IPv4 link local unicast subnet as defined // by RFC 3927 section 1. var ipv4LinkLocalUnicastSubnet = func() tcpip.Subnet { - subnet, err := tcpip.NewSubnet("\xa9\xfe\x00\x00", tcpip.AddressMask("\xff\xff\x00\x00")) + subnet, err := tcpip.NewSubnet("\xa9\xfe\x00\x00", "\xff\xff\x00\x00") if err != nil { panic(err) } @@ -191,7 +192,7 @@ var ipv4LinkLocalUnicastSubnet = func() tcpip.Subnet { // ipv4LinkLocalMulticastSubnet is the IPv4 link local multicast subnet as // defined by RFC 5771 section 4. 
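To make the intended composition of these interfaces concrete, a hedged sketch of a NAT-style source rewrite; rewriteIPv4Source is a hypothetical helper, and it assumes the IPv4 and TCP implementations added later in this change plus a fully calculated transport checksum:

package example

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/header"
)

// rewriteIPv4Source rewrites the IPv4 source address in place. The IPv4
// header checksum is fixed up incrementally, and because the TCP checksum
// covers a pseudo header that includes the addresses, it is updated as well.
func rewriteIPv4Source(ip header.IPv4, tcp header.TCP, newSrc tcpip.Address) {
	old := ip.SourceAddress()
	ip.SetSourceAddressWithChecksumUpdate(newSrc)
	// true: the TCP header carries a fully calculated checksum here, not one
	// seeded only with the pseudo header.
	tcp.UpdateChecksumPseudoHeaderAddress(old, newSrc, true /* fullChecksum */)
}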
var ipv4LinkLocalMulticastSubnet = func() tcpip.Subnet { - subnet, err := tcpip.NewSubnet("\xe0\x00\x00\x00", tcpip.AddressMask("\xff\xff\xff\x00")) + subnet, err := tcpip.NewSubnet("\xe0\x00\x00\x00", "\xff\xff\xff\x00") if err != nil { panic(err) } @@ -304,6 +305,18 @@ func (b IPv4) DestinationAddress() tcpip.Address { return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize]) } +// SetSourceAddressWithChecksumUpdate implements ChecksummableNetwork. +func (b IPv4) SetSourceAddressWithChecksumUpdate(new tcpip.Address) { + b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.SourceAddress(), new)) + b.SetSourceAddress(new) +} + +// SetDestinationAddressWithChecksumUpdate implements ChecksummableNetwork. +func (b IPv4) SetDestinationAddressWithChecksumUpdate(new tcpip.Address) { + b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.DestinationAddress(), new)) + b.SetDestinationAddress(new) +} + // padIPv4OptionsLength returns the total length for IPv4 options of length l // after applying padding according to RFC 791: // The internet header padding is used to ensure that the internet @@ -572,7 +585,7 @@ func (o *IPv4OptionGeneric) Type() IPv4OptionType { func (o *IPv4OptionGeneric) Size() uint8 { return uint8(len(*o)) } // Contents implements IPv4Option. -func (o *IPv4OptionGeneric) Contents() []byte { return []byte(*o) } +func (o *IPv4OptionGeneric) Contents() []byte { return *o } // IPv4OptionIterator is an iterator pointing to a specific IP option // at any point of time. It also holds information as to a new options buffer @@ -610,7 +623,7 @@ func (i *IPv4OptionIterator) InitReplacement(option IPv4Option) IPv4Options { // RemainingBuffer returns the remaining (unused) part of the new option buffer, // into which a new option may be written. func (i *IPv4OptionIterator) RemainingBuffer() IPv4Options { - return IPv4Options(i.newOptions[i.writePoint:]) + return i.newOptions[i.writePoint:] } // ConsumeBuffer marks a portion of the new buffer as used. @@ -813,9 +826,12 @@ const ( // ipv4TimestampTime provides the current time as specified in RFC 791. func ipv4TimestampTime(clock tcpip.Clock) uint32 { - const millisecondsPerDay = 24 * 3600 * 1000 - const nanoPerMilli = 1000000 - return uint32((clock.NowNanoseconds() / nanoPerMilli) % millisecondsPerDay) + // Per RFC 791 page 21: + // The Timestamp is a right-justified, 32-bit timestamp in + // milliseconds since midnight UT. + now := clock.Now().UTC() + midnight := now.Truncate(24 * time.Hour) + return uint32(now.Sub(midnight).Milliseconds()) } // IP Timestamp option fields. @@ -843,7 +859,7 @@ func (ts *IPv4OptionTimestamp) Type() IPv4OptionType { return IPv4OptionTimestam func (ts *IPv4OptionTimestamp) Size() uint8 { return uint8(len(*ts)) } // Contents implements IPv4Option. -func (ts *IPv4OptionTimestamp) Contents() []byte { return []byte(*ts) } +func (ts *IPv4OptionTimestamp) Contents() []byte { return *ts } // Pointer returns the pointer field in the IP Timestamp option. func (ts *IPv4OptionTimestamp) Pointer() uint8 { @@ -947,7 +963,7 @@ func (rr *IPv4OptionRecordRoute) Type() IPv4OptionType { return IPv4OptionRecord func (rr *IPv4OptionRecordRoute) Size() uint8 { return uint8(len(*rr)) } // Contents implements IPv4Option. -func (rr *IPv4OptionRecordRoute) Contents() []byte { return []byte(*rr) } +func (rr *IPv4OptionRecordRoute) Contents() []byte { return *rr } // Router Alert option specific related constants. 
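As a quick sanity check of the milliseconds-since-midnight-UT computation now used by ipv4TimestampTime, a standalone sketch with an arbitrary example time:

package main

import (
	"fmt"
	"time"
)

func main() {
	// 01:02:03 UT is 3,723 seconds past midnight.
	now := time.Date(2021, time.August, 1, 1, 2, 3, 0, time.UTC)
	// For a UTC time, Truncate(24 * time.Hour) lands on midnight UT, which is
	// the reference point RFC 791 specifies.
	midnight := now.Truncate(24 * time.Hour)
	fmt.Println(uint32(now.Sub(midnight).Milliseconds())) // 3723000
}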
// @@ -992,7 +1008,7 @@ func (*IPv4OptionRouterAlert) Type() IPv4OptionType { return IPv4OptionRouterAle func (ra *IPv4OptionRouterAlert) Size() uint8 { return uint8(len(*ra)) } // Contents implements IPv4Option. -func (ra *IPv4OptionRouterAlert) Contents() []byte { return []byte(*ra) } +func (ra *IPv4OptionRouterAlert) Contents() []byte { return *ra } // Value returns the value of the IPv4OptionRouterAlert. func (ra *IPv4OptionRouterAlert) Value() uint16 { diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index 3d1bccd15..b1f39e6e6 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -77,12 +77,12 @@ const ( // ndpPrefixInformationOnLinkFlagMask is the mask of the On-Link Flag // field in the flags byte within an NDPPrefixInformation. - ndpPrefixInformationOnLinkFlagMask = (1 << 7) + ndpPrefixInformationOnLinkFlagMask = 1 << 7 // ndpPrefixInformationAutoAddrConfFlagMask is the mask of the // Autonomous Address-Configuration flag field in the flags byte within // an NDPPrefixInformation. - ndpPrefixInformationAutoAddrConfFlagMask = (1 << 6) + ndpPrefixInformationAutoAddrConfFlagMask = 1 << 6 // ndpPrefixInformationReserved1FlagsMask is the mask of the Reserved1 // field in the flags byte within an NDPPrefixInformation. @@ -148,15 +148,10 @@ const ( // NDP option. That is, the length field for NDP options is in units of // 8 octets, as per RFC 4861 section 4.6. lengthByteUnits = 8 -) -var ( // NDPInfiniteLifetime is a value that represents infinity for the // 4-byte lifetime fields found in various NDP options. Its value is // (2^32 - 1)s = 4294967295s. - // - // This is a variable instead of a constant so that tests can change - // this value to a smaller value. It should only be modified by tests. NDPInfiniteLifetime = time.Second * math.MaxUint32 ) @@ -451,7 +446,7 @@ func (o NDPNonceOption) String() string { // Nonce returns the nonce value this option holds. func (o NDPNonceOption) Nonce() []byte { - return []byte(o) + return o } // NDPSourceLinkLayerAddressOption is the NDP Source Link Layer Option diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go index bf7610863..7e2f0c797 100644 --- a/pkg/tcpip/header/ndp_router_advert.go +++ b/pkg/tcpip/header/ndp_router_advert.go @@ -19,12 +19,72 @@ import ( "time" ) +// NDPRoutePreference is the preference values for default routers or +// more-specific routes. +// +// As per RFC 4191 section 2.1, +// +// Default router preferences and preferences for more-specific routes +// are encoded the same way. +// +// Preference values are encoded as a two-bit signed integer, as +// follows: +// +// 01 High +// 00 Medium (default) +// 11 Low +// 10 Reserved - MUST NOT be sent +// +// Note that implementations can treat the value as a two-bit signed +// integer. +// +// Having just three values reinforces that they are not metrics and +// more values do not appear to be necessary for reasonable scenarios. +type NDPRoutePreference uint8 + +const ( + // HighRoutePreference indicates a high preference, as per + // RFC 4191 section 2.1. + HighRoutePreference NDPRoutePreference = 0b01 + + // MediumRoutePreference indicates a medium preference, as per + // RFC 4191 section 2.1. + // + // This is the default preference value. + MediumRoutePreference = 0b00 + + // LowRoutePreference indicates a low preference, as per + // RFC 4191 section 2.1. 
+ LowRoutePreference = 0b11 + + // ReservedRoutePreference is a reserved preference value, as per + // RFC 4191 section 2.1. + // + // It MUST NOT be sent. + ReservedRoutePreference = 0b10 +) + // NDPRouterAdvert is an NDP Router Advertisement message. It will only contain // the body of an ICMPv6 packet. // -// See RFC 4861 section 4.2 for more details. +// See RFC 4861 section 4.2 and RFC 4191 section 2.2 for more details. type NDPRouterAdvert []byte +// As per RFC 4191 section 2.2, +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type | Code | Checksum | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cur Hop Limit |M|O|H|Prf|Resvd| Router Lifetime | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Reachable Time | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Retrans Timer | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Options ... +// +-+-+-+-+-+-+-+-+-+-+-+- const ( // NDPRAMinimumSize is the minimum size of a valid NDP Router // Advertisement message (body of an ICMPv6 packet). @@ -47,6 +107,14 @@ const ( // within the bit-field/flags byte of an NDPRouterAdvert. ndpRAOtherConfFlagMask = (1 << 6) + // ndpDefaultRouterPreferenceShift is the shift of the Prf (Default Router + // Preference) field within the flags byte of an NDPRouterAdvert. + ndpDefaultRouterPreferenceShift = 3 + + // ndpDefaultRouterPreferenceMask is the mask of the Prf (Default Router + // Preference) field within the flags byte of an NDPRouterAdvert. + ndpDefaultRouterPreferenceMask = (0b11 << ndpDefaultRouterPreferenceShift) + // ndpRARouterLifetimeOffset is the start of the 2-byte Router Lifetime // field within an NDPRouterAdvert. ndpRARouterLifetimeOffset = 2 @@ -80,6 +148,11 @@ func (b NDPRouterAdvert) OtherConfFlag() bool { return b[ndpRAFlagsOffset]&ndpRAOtherConfFlagMask != 0 } +// DefaultRouterPreference returns the Default Router Preference field. +func (b NDPRouterAdvert) DefaultRouterPreference() NDPRoutePreference { + return NDPRoutePreference((b[ndpRAFlagsOffset] & ndpDefaultRouterPreferenceMask) >> ndpDefaultRouterPreferenceShift) +} + // RouterLifetime returns the lifetime associated with the default router. A // value of 0 means the source of the Router Advertisement is not a default // router and SHOULD NOT appear on the default router list. 
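For illustration, a short sketch of how the Prf bits round-trip through the new accessor; the RA body below follows the layout in the diagram above, with all other fields left at zero:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip/header"
)

func main() {
	// Flags byte layout: |M|O|H|Prf|Resvd|. Place HighRoutePreference (0b01)
	// in bits 4:3 and leave the remaining flags clear.
	flags := uint8(header.HighRoutePreference) << 3

	// Minimal RA body: hop limit, flags, router lifetime (2 bytes),
	// reachable time (4 bytes), retrans timer (4 bytes).
	ra := header.NDPRouterAdvert([]byte{64, flags, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})

	fmt.Println(ra.DefaultRouterPreference() == header.HighRoutePreference) // true
}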
Note, a value of 0 diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index 1b5093e58..8fd1f7d13 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -126,36 +126,83 @@ func TestNDPNeighborAdvert(t *testing.T) { } func TestNDPRouterAdvert(t *testing.T) { - b := []byte{ - 64, 128, 1, 2, - 3, 4, 5, 6, - 7, 8, 9, 10, + tests := []struct { + hopLimit uint8 + managedFlag, otherConfFlag bool + prf NDPRoutePreference + routerLifetimeS uint16 + reachableTimeMS, retransTimerMS uint32 + }{ + { + hopLimit: 1, + managedFlag: false, + otherConfFlag: true, + prf: HighRoutePreference, + routerLifetimeS: 2, + reachableTimeMS: 3, + retransTimerMS: 4, + }, + { + hopLimit: 64, + managedFlag: true, + otherConfFlag: false, + prf: LowRoutePreference, + routerLifetimeS: 258, + reachableTimeMS: 78492, + retransTimerMS: 13213, + }, } - ra := NDPRouterAdvert(b) + for i, test := range tests { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + flags := uint8(0) + if test.managedFlag { + flags |= 1 << 7 + } + if test.otherConfFlag { + flags |= 1 << 6 + } + flags |= uint8(test.prf) << 3 - if got := ra.CurrHopLimit(); got != 64 { - t.Errorf("got ra.CurrHopLimit = %d, want = 64", got) - } + b := []byte{ + test.hopLimit, flags, 1, 2, + 3, 4, 5, 6, + 7, 8, 9, 10, + } + binary.BigEndian.PutUint16(b[2:], test.routerLifetimeS) + binary.BigEndian.PutUint32(b[4:], test.reachableTimeMS) + binary.BigEndian.PutUint32(b[8:], test.retransTimerMS) - if got := ra.ManagedAddrConfFlag(); !got { - t.Errorf("got ManagedAddrConfFlag = false, want = true") - } + ra := NDPRouterAdvert(b) - if got := ra.OtherConfFlag(); got { - t.Errorf("got OtherConfFlag = true, want = false") - } + if got := ra.CurrHopLimit(); got != test.hopLimit { + t.Errorf("got ra.CurrHopLimit() = %d, want = %d", got, test.hopLimit) + } - if got, want := ra.RouterLifetime(), time.Second*258; got != want { - t.Errorf("got ra.RouterLifetime = %d, want = %d", got, want) - } + if got := ra.ManagedAddrConfFlag(); got != test.managedFlag { + t.Errorf("got ManagedAddrConfFlag() = %t, want = %t", got, test.managedFlag) + } - if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want { - t.Errorf("got ra.ReachableTime = %d, want = %d", got, want) - } + if got := ra.OtherConfFlag(); got != test.otherConfFlag { + t.Errorf("got OtherConfFlag() = %t, want = %t", got, test.otherConfFlag) + } + + if got := ra.DefaultRouterPreference(); got != test.prf { + t.Errorf("got DefaultRouterPreference() = %d, want = %d", got, test.prf) + } - if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want { - t.Errorf("got ra.RetransTimer = %d, want = %d", got, want) + if got, want := ra.RouterLifetime(), time.Second*time.Duration(test.routerLifetimeS); got != want { + t.Errorf("got ra.RouterLifetime() = %d, want = %d", got, want) + } + + if got, want := ra.ReachableTime(), time.Millisecond*time.Duration(test.reachableTimeMS); got != want { + t.Errorf("got ra.ReachableTime() = %d, want = %d", got, want) + } + + if got, want := ra.RetransTimer(), time.Millisecond*time.Duration(test.retransTimerMS); got != want { + t.Errorf("got ra.RetransTimer() = %d, want = %d", got, want) + } + }) } } diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 0df517000..a75e51a28 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -48,6 +48,16 @@ const ( // TCPFlags is the dedicated type for TCP flags. type TCPFlags uint8 +// Intersects returns true iff there are flags common to both f and o. 
+func (f TCPFlags) Intersects(o TCPFlags) bool { + return f&o != 0 +} + +// Contains returns true iff all the flags in o are contained within f. +func (f TCPFlags) Contains(o TCPFlags) bool { + return f&o == o +} + // String implements Stringer.String. func (f TCPFlags) String() string { flagsStr := []byte("FSRPAU") @@ -380,6 +390,35 @@ func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32 b.SetChecksum(^checksum) } +// SetSourcePortWithChecksumUpdate implements ChecksummableTransport. +func (b TCP) SetSourcePortWithChecksumUpdate(new uint16) { + old := b.SourcePort() + b.SetSourcePort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport. +func (b TCP) SetDestinationPortWithChecksumUpdate(new uint16) { + old := b.DestinationPort() + b.SetDestinationPort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport. +func (b TCP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) { + xsum := b.Checksum() + if fullChecksum { + xsum = ^xsum + } + + xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new) + if fullChecksum { + xsum = ^xsum + } + + b.SetChecksum(xsum) +} + // ParseSynOptions parses the options received in a SYN segment and returns the // relevant ones. opts should point to the option part of the TCP header. func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions { diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index ae9d167ff..f69d53314 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -130,3 +130,32 @@ func (b UDP) Encode(u *UDPFields) { binary.BigEndian.PutUint16(b[udpLength:], u.Length) binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum) } + +// SetSourcePortWithChecksumUpdate implements ChecksummableTransport. +func (b UDP) SetSourcePortWithChecksumUpdate(new uint16) { + old := b.SourcePort() + b.SetSourcePort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport. +func (b UDP) SetDestinationPortWithChecksumUpdate(new uint16) { + old := b.DestinationPort() + b.SetDestinationPort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport. +func (b UDP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) { + xsum := b.Checksum() + if fullChecksum { + xsum = ^xsum + } + + xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new) + if fullChecksum { + xsum = ^xsum + } + + b.SetChecksum(xsum) +} diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index ef9126deb..f26c857eb 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -288,5 +288,5 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType { } // AddHeader implements stack.LinkEndpoint.AddHeader. 
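A brief illustration of the new TCPFlags helpers, assuming the package's existing TCPFlagSyn, TCPFlagAck and TCPFlagFin constants:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip/header"
)

func main() {
	flags := header.TCPFlagSyn | header.TCPFlagAck // a SYN-ACK segment

	fmt.Println(flags.Contains(header.TCPFlagAck))                      // true: ACK is set
	fmt.Println(flags.Contains(header.TCPFlagAck | header.TCPFlagFin))  // false: FIN is not set
	fmt.Println(flags.Intersects(header.TCPFlagFin | header.TCPFlagAck)) // true: ACK is common to both
}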
-func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +func (*Endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { } diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index d971194e6..1d0163823 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -14,7 +14,6 @@ go_library( ], visibility = ["//visibility:public"], deps = [ - "//pkg/iovec", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index bddb1d0a2..1b56d2b72 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -41,11 +41,9 @@ package fdbased import ( "fmt" - "math" "sync/atomic" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/iovec" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -139,6 +137,20 @@ type endpoint struct { // gsoKind is the supported kind of GSO. gsoKind stack.SupportedGSO + + // maxSyscallHeaderBytes has the same meaning as + // Options.MaxSyscallHeaderBytes. + maxSyscallHeaderBytes uintptr + + // writevMaxIovs is the maximum number of iovecs that may be passed to + // rawfile.NonBlockingWriteIovec, as possibly limited by + // maxSyscallHeaderBytes. (No analogous limit is defined for + // rawfile.NonBlockingSendMMsg, since in that case the maximum number of + // iovecs also depends on the number of mmsghdrs. Instead, if sendBatch + // encounters a packet whose iovec count is limited by + // maxSyscallHeaderBytes, it falls back to writing the packet using writev + // via WritePacket.) + writevMaxIovs int } // Options specify the details about the fd-based endpoint to be created. @@ -187,6 +199,11 @@ type Options struct { // RXChecksumOffload if true, indicates that this endpoints capability // set should include CapabilityRXChecksumOffload. RXChecksumOffload bool + + // If MaxSyscallHeaderBytes is non-zero, it is the maximum number of bytes + // of struct iovec, msghdr, and mmsghdr that may be passed by each host + // system call. + MaxSyscallHeaderBytes int } // fanoutID is used for AF_PACKET based endpoints to enable PACKET_FANOUT @@ -196,8 +213,12 @@ type Options struct { // option for an FD with a fanoutID already in use by another FD for a different // NIC will return an EINVAL. // +// Since fanoutID must be unique within the network namespace, we start with +// the PID to avoid collisions. The only way to be sure of avoiding collisions +// is to run in a new network namespace. +// // Must be accessed using atomic operations. -var fanoutID int32 = 0 +var fanoutID int32 = int32(unix.Getpid()) // New creates a new fd-based endpoint. 
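A hedged usage sketch of the new MaxSyscallHeaderBytes option; the fd is assumed to be an already-configured TAP or AF_PACKET descriptor, and the MTU value is a placeholder:

package example

import (
	"gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

func newLink(fd int) (stack.LinkEndpoint, error) {
	return fdbased.New(&fdbased.Options{
		FDs: []int{fd},
		MTU: 1500,
		// Cap the struct iovec/msghdr/mmsghdr bytes passed per host system
		// call; a non-zero value also limits writev to
		// MaxSyscallHeaderBytes/SizeofIovec iovecs.
		MaxSyscallHeaderBytes: 1 << 16,
	})
}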
// @@ -232,14 +253,25 @@ func New(opts *Options) (stack.LinkEndpoint, error) { return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified") } + if opts.MaxSyscallHeaderBytes < 0 { + return nil, fmt.Errorf("opts.MaxSyscallHeaderBytes is negative") + } + e := &endpoint{ - fds: opts.FDs, - mtu: opts.MTU, - caps: caps, - closed: opts.ClosedFunc, - addr: opts.Address, - hdrSize: hdrSize, - packetDispatchMode: opts.PacketDispatchMode, + fds: opts.FDs, + mtu: opts.MTU, + caps: caps, + closed: opts.ClosedFunc, + addr: opts.Address, + hdrSize: hdrSize, + packetDispatchMode: opts.PacketDispatchMode, + maxSyscallHeaderBytes: uintptr(opts.MaxSyscallHeaderBytes), + writevMaxIovs: rawfile.MaxIovs, + } + if e.maxSyscallHeaderBytes != 0 { + if max := int(e.maxSyscallHeaderBytes / rawfile.SizeofIovec); max < e.writevMaxIovs { + e.writevMaxIovs = max + } } // Increment fanoutID to ensure that we don't re-use the same fanoutID for @@ -292,11 +324,6 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool, fID int32) (lin } switch sa.(type) { case *unix.SockaddrLinklayer: - // See: PACKET_FANOUT_MAX in net/packet/internal.h - const packetFanoutMax = 1 << 16 - if fID > packetFanoutMax { - return nil, fmt.Errorf("host fanoutID limit exceeded, fanoutID must be <= %d", math.MaxUint16) - } // Enable PACKET_FANOUT mode if the underlying socket is of type // AF_PACKET. We do not enable PACKET_FANOUT_FLAG_DEFRAG as that will // prevent gvisor from receiving fragmented packets and the host does the @@ -317,7 +344,7 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool, fID int32) (lin // // See: https://github.com/torvalds/linux/blob/7acac4b3196caee5e21fb5ea53f8bc124e6a16fc/net/packet/af_packet.c#L3881 const fanoutType = unix.PACKET_FANOUT_HASH - fanoutArg := int(fID) | fanoutType<<16 + fanoutArg := (int(fID) & 0xffff) | fanoutType<<16 if err := unix.SetsockoptInt(fd, unix.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil { return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err) } @@ -472,9 +499,8 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) } - var builder iovec.Builder - fd := e.fds[pkt.Hash%uint32(len(e.fds))] + var vnetHdrBuf []byte if e.gsoKind == stack.HWGSOSupported { vnetHdr := virtioNetHdr{} if pkt.GSOOptions.Type != stack.GSONone { @@ -496,71 +522,123 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol vnetHdr.gsoSize = pkt.GSOOptions.MSS } } + vnetHdrBuf = vnetHdr.marshal() + } - vnetHdrBuf := vnetHdr.marshal() - builder.Add(vnetHdrBuf) + views := pkt.Views() + numIovecs := len(views) + if len(vnetHdrBuf) != 0 { + numIovecs++ + } + if numIovecs > e.writevMaxIovs { + numIovecs = e.writevMaxIovs } - for _, v := range pkt.Views() { - builder.Add(v) + // Allocate small iovec arrays on the stack. 
+ var iovecsArr [8]unix.Iovec + iovecs := iovecsArr[:0] + if numIovecs > len(iovecsArr) { + iovecs = make([]unix.Iovec, 0, numIovecs) } - return rawfile.NonBlockingWriteIovec(fd, builder.Build()) + iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs) + for _, v := range views { + iovecs = rawfile.AppendIovecFromBytes(iovecs, v, numIovecs) + } + return rawfile.NonBlockingWriteIovec(fd, iovecs) } -func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcpip.Error) { +func (e *endpoint) sendBatch(batchFD int, pkts []*stack.PacketBuffer) (int, tcpip.Error) { // Send a batch of packets through batchFD. - mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch)) - for _, pkt := range batch { - if e.hdrSize > 0 { - e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt) - } + mmsgHdrsStorage := make([]rawfile.MMsgHdr, 0, len(pkts)) + packets := 0 + for packets < len(pkts) { + mmsgHdrs := mmsgHdrsStorage + batch := pkts[packets:] + syscallHeaderBytes := uintptr(0) + for _, pkt := range batch { + if e.hdrSize > 0 { + e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt) + } - var vnetHdrBuf []byte - if e.gsoKind == stack.HWGSOSupported { - vnetHdr := virtioNetHdr{} - if pkt.GSOOptions.Type != stack.GSONone { - vnetHdr.hdrLen = uint16(pkt.HeaderSize()) - if pkt.GSOOptions.NeedsCsum { - vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM - vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen - vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset - } - if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS { - switch pkt.GSOOptions.Type { - case stack.GSOTCPv4: - vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 - case stack.GSOTCPv6: - vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 - default: - panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type)) + var vnetHdrBuf []byte + if e.gsoKind == stack.HWGSOSupported { + vnetHdr := virtioNetHdr{} + if pkt.GSOOptions.Type != stack.GSONone { + vnetHdr.hdrLen = uint16(pkt.HeaderSize()) + if pkt.GSOOptions.NeedsCsum { + vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM + vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen + vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset + } + if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS { + switch pkt.GSOOptions.Type { + case stack.GSOTCPv4: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 + case stack.GSOTCPv6: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 + default: + panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type)) + } + vnetHdr.gsoSize = pkt.GSOOptions.MSS } - vnetHdr.gsoSize = pkt.GSOOptions.MSS } + vnetHdrBuf = vnetHdr.marshal() } - vnetHdrBuf = vnetHdr.marshal() - } - var builder iovec.Builder - builder.Add(vnetHdrBuf) - for _, v := range pkt.Views() { - builder.Add(v) - } - iovecs := builder.Build() + views := pkt.Views() + numIovecs := len(views) + if len(vnetHdrBuf) != 0 { + numIovecs++ + } + if numIovecs > rawfile.MaxIovs { + numIovecs = rawfile.MaxIovs + } + if e.maxSyscallHeaderBytes != 0 { + syscallHeaderBytes += rawfile.SizeofMMsgHdr + uintptr(numIovecs)*rawfile.SizeofIovec + if syscallHeaderBytes > e.maxSyscallHeaderBytes { + // We can't fit this packet into this call to sendmmsg(). + // We could potentially do so if we reduced numIovecs + // further, but this might incur considerable extra + // copying. Leave it to the next batch instead. 
+ break + } + } - var mmsgHdr rawfile.MMsgHdr - mmsgHdr.Msg.Iov = &iovecs[0] - mmsgHdr.Msg.SetIovlen((len(iovecs))) - mmsgHdrs = append(mmsgHdrs, mmsgHdr) - } + // We can't easily allocate iovec arrays on the stack here since + // they will escape this loop iteration via mmsgHdrs. + iovecs := make([]unix.Iovec, 0, numIovecs) + iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs) + for _, v := range views { + iovecs = rawfile.AppendIovecFromBytes(iovecs, v, numIovecs) + } - packets := 0 - for len(mmsgHdrs) > 0 { - sent, err := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs) - if err != nil { - return packets, err + var mmsgHdr rawfile.MMsgHdr + mmsgHdr.Msg.Iov = &iovecs[0] + mmsgHdr.Msg.SetIovlen(len(iovecs)) + mmsgHdrs = append(mmsgHdrs, mmsgHdr) + } + + if len(mmsgHdrs) == 0 { + // We can't fit batch[0] into a mmsghdr while staying under + // e.maxSyscallHeaderBytes. Use WritePacket, which will avoid the + // mmsghdr (by using writev) and re-buffer iovecs more aggressively + // if necessary (by using e.writevMaxIovs instead of + // rawfile.MaxIovs). + pkt := batch[0] + if err := e.WritePacket(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil { + return packets, err + } + packets++ + } else { + for len(mmsgHdrs) > 0 { + sent, err := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs) + if err != nil { + return packets, err + } + packets += sent + mmsgHdrs = mmsgHdrs[sent:] + } } - packets += sent - mmsgHdrs = mmsgHdrs[sent:] } return packets, nil @@ -678,8 +756,9 @@ func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabiliti unix.SetNonblock(fd, true) return &InjectableEndpoint{endpoint: endpoint{ - fds: []int{fd}, - mtu: mtu, - caps: capabilities, + fds: []int{fd}, + mtu: mtu, + caps: capabilities, + writevMaxIovs: rawfile.MaxIovs, }} } diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index ba92aedbc..43fe57830 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -19,12 +19,66 @@ package rawfile import ( + "reflect" "unsafe" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip" ) +// SizeofIovec is the size of a unix.Iovec in bytes. +const SizeofIovec = unsafe.Sizeof(unix.Iovec{}) + +// MaxIovs is UIO_MAXIOV, the maximum number of iovecs that may be passed to a +// host system call in a single array. +const MaxIovs = 1024 + +// IovecFromBytes returns a unix.Iovec representing bs. +// +// Preconditions: len(bs) > 0. +func IovecFromBytes(bs []byte) unix.Iovec { + iov := unix.Iovec{ + Base: &bs[0], + } + iov.SetLen(len(bs)) + return iov +} + +func bytesFromIovec(iov unix.Iovec) (bs []byte) { + sh := (*reflect.SliceHeader)(unsafe.Pointer(&bs)) + sh.Data = uintptr(unsafe.Pointer(iov.Base)) + sh.Len = int(iov.Len) + sh.Cap = int(iov.Len) + return +} + +// AppendIovecFromBytes returns append(iovs, IovecFromBytes(bs)). If len(bs) == +// 0, AppendIovecFromBytes returns iovs without modification. If len(iovs) >= +// max, AppendIovecFromBytes replaces the final iovec in iovs with one that +// also includes the contents of bs. Note that this implies that +// AppendIovecFromBytes is only usable when the returned iovec slice is used as +// the source of a write. 
+func AppendIovecFromBytes(iovs []unix.Iovec, bs []byte, max int) []unix.Iovec { + if len(bs) == 0 { + return iovs + } + if len(iovs) < max { + return append(iovs, IovecFromBytes(bs)) + } + iovs[len(iovs)-1] = IovecFromBytes(append(bytesFromIovec(iovs[len(iovs)-1]), bs...)) + return iovs +} + +// MMsgHdr represents the mmsg_hdr structure required by recvmmsg() on linux. +type MMsgHdr struct { + Msg unix.Msghdr + Len uint32 + _ [4]byte +} + +// SizeofMMsgHdr is the size of a MMsgHdr in bytes. +const SizeofMMsgHdr = unsafe.Sizeof(MMsgHdr{}) + // GetMTU determines the MTU of a network interface device. func GetMTU(name string) (uint32, error) { fd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_DGRAM, 0) @@ -137,13 +191,6 @@ func BlockingReadv(fd int, iovecs []unix.Iovec) (int, tcpip.Error) { } } -// MMsgHdr represents the mmsg_hdr structure required by recvmmsg() on linux. -type MMsgHdr struct { - Msg unix.Msghdr - Len uint32 - _ [4]byte -} - // BlockingRecvMMsg reads from a file descriptor that is set up as non-blocking // and stores the received messages in a slice of MMsgHdr structures. If no data // is available, it will block in a poll() syscall until the file descriptor diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index 7656cca6a..4758a99ad 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -26,6 +26,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/log", "//pkg/refs", "//pkg/refsvfs2", diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index af9a3d759..59ac5cb8c 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -88,12 +89,12 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags Flags) error { defer d.mu.Unlock() if d.endpoint != nil { - return syserror.EINVAL + return linuxerr.EINVAL } // Input validation. if flags.TAP && flags.TUN || !flags.TAP && !flags.TUN { - return syserror.EINVAL + return linuxerr.EINVAL } prefix := "tun" @@ -108,7 +109,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags Flags) error { endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps) if err != nil { - return syserror.EINVAL + return linuxerr.EINVAL } d.endpoint = endpoint @@ -159,7 +160,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE // Race detected: A NIC has been created in between. 
continue default: - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } } } diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index a72eb1aad..6fa1aee18 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -28,13 +28,13 @@ go_test( ":arp", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/stack", "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/icmp", "@com_github_google_go_cmp//cmp:go_default_library", "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 0efa3a926..6515c31e5 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -278,7 +278,7 @@ func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) { return "", "" } -func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ protocol: p, nic: nic, diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 94209b026..5fcbfeaa2 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -15,15 +15,14 @@ package arp_test import ( - "context" "fmt" "testing" - "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" @@ -31,7 +30,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) const ( @@ -39,15 +37,6 @@ const ( stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06") - - defaultChannelSize = 1 - defaultMTU = 65536 - - // eventChanSize defines the size of event channels used by the neighbor - // cache's event dispatcher. The size chosen here needs to be sufficient to - // queue all the events received during tests before consumption. - // If eventChanSize is too small, the tests may deadlock. 
- eventChanSize = 32 ) var ( @@ -123,24 +112,6 @@ func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.Neighbo d.C <- e } -func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error { - select { - case got := <-d.C: - if diff := cmp.Diff(want, got, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { - return fmt.Errorf("got invalid event (-want +got):\n%s", diff) - } - case <-ctx.Done(): - return fmt.Errorf("%s for %s", ctx.Err(), want) - } - return nil -} - -func (d *arpDispatcher) waitForEventWithTimeout(want eventInfo, timeout time.Duration) error { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - return d.waitForEvent(ctx, want) -} - func (d *arpDispatcher) nextEvent() (eventInfo, bool) { select { case event := <-d.C: @@ -153,55 +124,45 @@ func (d *arpDispatcher) nextEvent() (eventInfo, bool) { type testContext struct { s *stack.Stack linkEP *channel.Endpoint - nudDisp *arpDispatcher + nudDisp arpDispatcher } -func newTestContext(t *testing.T) *testContext { - c := stack.DefaultNUDConfigurations() - // Transition from Reachable to Stale almost immediately to test if receiving - // probes refreshes positive reachability. - c.BaseReachableTime = time.Microsecond - - d := arpDispatcher{ - // Create an event channel large enough so the neighbor cache doesn't block - // while dispatching events. Blocking could interfere with the timing of - // NUD transitions. - C: make(chan eventInfo, eventChanSize), +func makeTestContext(t *testing.T, eventDepth int, packetDepth int) testContext { + t.Helper() + + tc := testContext{ + nudDisp: arpDispatcher{ + C: make(chan eventInfo, eventDepth), + }, } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, - NUDConfigs: c, - NUDDisp: &d, + tc.s = stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, + NUDDisp: &tc.nudDisp, + Clock: &faketime.NullClock{}, }) - ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) - ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - wep := stack.LinkEndpoint(ep) + tc.linkEP = channel.New(packetDepth, header.IPv4MinimumMTU, stackLinkAddr) + tc.linkEP.LinkEPCapabilities |= stack.CapabilityResolutionRequired + wep := stack.LinkEndpoint(tc.linkEP) if testing.Verbose() { - wep = sniffer.New(ep) + wep = sniffer.New(wep) } - if err := s.CreateNIC(nicID, wep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + if err := tc.s.CreateNIC(nicID, wep); err != nil { + t.Fatalf("CreateNIC failed: %s", err) } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %v", err) + if err := tc.s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { + t.Fatalf("AddAddress for ipv4 failed: %s", err) } - s.SetRouteTable([]tcpip.Route{{ + tc.s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, NIC: nicID, }}) - return &testContext{ - s: s, - linkEP: ep, - nudDisp: &d, - } + return tc } func (c *testContext) cleanup() { @@ -209,7 +170,7 @@ func (c *testContext) cleanup() { } func TestMalformedPacket(t *testing.T) { - c := newTestContext(t) + c := makeTestContext(t, 0, 0) defer c.cleanup() v := make(buffer.View, header.ARPSize) @@ -228,7 +189,7 @@ func TestMalformedPacket(t *testing.T) { } func 
TestDisabledEndpoint(t *testing.T) { - c := newTestContext(t) + c := makeTestContext(t, 0, 0) defer c.cleanup() ep, err := c.s.GetNetworkEndpoint(nicID, header.ARPProtocolNumber) @@ -253,7 +214,7 @@ func TestDisabledEndpoint(t *testing.T) { } func TestDirectReply(t *testing.T) { - c := newTestContext(t) + c := makeTestContext(t, 0, 0) defer c.cleanup() const senderMAC = "\x01\x02\x03\x04\x05\x06" @@ -284,7 +245,7 @@ func TestDirectReply(t *testing.T) { } func TestDirectRequest(t *testing.T) { - c := newTestContext(t) + c := makeTestContext(t, 1, 1) defer c.cleanup() tests := []struct { @@ -391,17 +352,21 @@ func TestDirectRequest(t *testing.T) { } // Verify the sender was saved in the neighbor cache. - wantEvent := eventInfo{ - eventType: entryAdded, - nicID: nicID, - entry: stack.NeighborEntry{ - Addr: test.senderAddr, - LinkAddr: tcpip.LinkAddress(test.senderLinkAddr), - State: stack.Stale, - }, - } - if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil { - t.Fatal(err) + if got, ok := c.nudDisp.nextEvent(); ok { + want := eventInfo{ + eventType: entryAdded, + nicID: nicID, + entry: stack.NeighborEntry{ + Addr: test.senderAddr, + LinkAddr: test.senderLinkAddr, + State: stack.Stale, + }, + } + if diff := cmp.Diff(want, got, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" { + t.Errorf("got invalid event (-want +got):\n%s", diff) + } + } else { + t.Fatal("event didn't arrive") } neighbors, err := c.s.Neighbors(nicID, ipv4.ProtocolNumber) @@ -589,7 +554,7 @@ func TestLinkAddressRequest(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, }) - linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) + linkEP := channel.New(1, header.IPv4MinimumMTU, stackLinkAddr) if err := s.CreateNIC(nicID, &testLinkEndpoint{LinkEndpoint: linkEP, writeErr: test.linkErr}); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } @@ -663,15 +628,16 @@ func TestLinkAddressRequest(t *testing.T) { } func TestDADARPRequestPacket(t *testing.T) { + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocolWithOptions(arp.Options{ DADConfigs: stack.DADConfigurations{ DupAddrDetectTransmits: 1, - RetransmitTimer: time.Second, }, }), ipv4.NewProtocol}, + Clock: clock, }) - e := channel.New(1, defaultMTU, stackLinkAddr) + e := channel.New(1, header.IPv4MinimumMTU, stackLinkAddr) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } @@ -682,7 +648,8 @@ func TestDADARPRequestPacket(t *testing.T) { t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, header.IPv4ProtocolNumber, remoteAddr, res, stack.DADStarting) } - pkt, ok := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + pkt, ok := e.Read() if !ok { t.Fatal("expected to send an ARP request") } diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation.go b/pkg/tcpip/network/internal/fragmentation/fragmentation.go index 5168f5361..1ba4d0d36 100644 --- a/pkg/tcpip/network/internal/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation.go @@ -251,7 +251,7 @@ func (f *Fragmentation) releaseReassemblersLocked() { // The list is empty. 
break } - elapsed := time.Duration(now-r.creationTime) * time.Nanosecond + elapsed := now.Sub(r.createdAt) if f.timeout > elapsed { // If the oldest reassembler has not expired, schedule the release // job so that this function is called back when it has expired. diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go index 7daf64b4a..dadfc28cc 100644 --- a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go @@ -275,15 +275,23 @@ func TestMemoryLimits(t *testing.T) { highLimit := 3 * lowLimit // Allow at most 3 such packets. f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. - f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")) + if _, _, _, err := f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { + t.Fatal(err) + } // Send first fragment with id = 1. - f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")) + if _, _, _, err := f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")); err != nil { + t.Fatal(err) + } // Send first fragment with id = 2. - f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")) + if _, _, _, err := f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")); err != nil { + t.Fatal(err) + } // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be // evicted. - f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")) + if _, _, _, err := f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")); err != nil { + t.Fatal(err) + } if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { t.Errorf("Memory limits are not respected: id=0 has not been evicted.") @@ -300,9 +308,13 @@ func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { memSize := pkt(1, "0").MemSize() f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. - f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) + if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { + t.Fatal(err) + } // Send the same packet again. 
- f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) + if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { + t.Fatal(err) + } if got, want := f.memSize, memSize; got != want { t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go index 56b76a284..5b7e4b361 100644 --- a/pkg/tcpip/network/internal/fragmentation/reassembler.go +++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go @@ -35,21 +35,21 @@ type hole struct { type reassembler struct { reassemblerEntry - id FragmentID - memSize int - proto uint8 - mu sync.Mutex - holes []hole - filled int - done bool - creationTime int64 - pkt *stack.PacketBuffer + id FragmentID + memSize int + proto uint8 + mu sync.Mutex + holes []hole + filled int + done bool + createdAt tcpip.MonotonicTime + pkt *stack.PacketBuffer } func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { r := &reassembler{ - id: id, - creationTime: clock.NowMonotonic(), + id: id, + createdAt: clock.NowMonotonic(), } r.holes = append(r.holes, hole{ first: 0, diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go index a22b712c6..24687cf06 100644 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go @@ -133,7 +133,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADDisabled) } // Wait for any initially fired timers to complete. - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check(nil); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -147,7 +147,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr1}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -156,7 +156,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADAlreadyRunning { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADAlreadyRunning) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check(nil); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -170,7 +170,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr2}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -208,7 +208,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr2, 
addr2}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -247,7 +247,7 @@ func TestDADStop(t *testing.T) { if res := dad.checkDuplicateAddress(addr3, handler(ch, addr3)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr1, addr2, addr3}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -272,7 +272,7 @@ func TestDADStop(t *testing.T) { if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr1}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -347,7 +347,7 @@ func TestNonce(t *testing.T) { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() for i, want := range test.expectedResults { if got := dad.extendIfNonceEqual(addr1, test.mockedReceivedNonce); got != want { t.Errorf("(i=%d) got dad.extendIfNonceEqual(%s, _) = %d, want = %d", i, addr1, got, want) diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go index d22974b12..671dfbf32 100644 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go @@ -611,7 +611,7 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr // If a timer for any address is already running, it is reset to the new // random value only if the requested Maximum Response Delay is less than // the remaining value of the running timer. - now := time.Unix(0 /* seconds */, g.opts.Clock.NowNanoseconds()) + now := g.opts.Clock.Now() if info.state == delayingMember { if info.delayedReportJobFiresAt.IsZero() { panic(fmt.Sprintf("delayed report unscheduled while in the delaying member state; group = %s", groupAddress)) diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go index 0c2b62127..40ab21cb6 100644 --- a/pkg/tcpip/network/internal/ip/stats.go +++ b/pkg/tcpip/network/internal/ip/stats.go @@ -42,6 +42,10 @@ type MultiCounterIPForwardingStats struct { // were too big for the outgoing MTU. PacketTooBig tcpip.MultiCounterStat + // HostUnreachable is the number of IP packets received which could not be + // successfully forwarded due to an unresolvable next hop. + HostUnreachable tcpip.MultiCounterStat + // ExtensionHeaderProblem is the number of IP packets which were dropped // because of a problem encountered when processing an IPv6 extension // header. 
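Aside: the stats.go hunk above adds a HostUnreachable counter for forwarded packets whose next hop cannot be resolved, and the Init hunk below ties it to its two backing counters. The following standalone sketch (illustrative only, not part of the change) shows how a tcpip.MultiCounterStat fans a single increment out to both backing tcpip.StatCounters, e.g. the per-endpoint stats and the stack-wide aggregate; all calls used here appear elsewhere in this diff.

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip"
)

func main() {
	// The two backing counters: in the stack these are the per-endpoint
	// stats and the stack-wide tcpip.Stats.
	var perEndpoint, stackWide tcpip.StatCounter

	var hostUnreachable tcpip.MultiCounterStat
	hostUnreachable.Init(&perEndpoint, &stackWide)

	// One increment on the multi-counter is visible through both views.
	hostUnreachable.Increment()
	fmt.Println(perEndpoint.Value(), stackWide.Value()) // 1 1
}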
@@ -61,6 +65,7 @@ func (m *MultiCounterIPForwardingStats) Init(a, b *tcpip.IPForwardingStats) { m.ExtensionHeaderProblem.Init(a.ExtensionHeaderProblem, b.ExtensionHeaderProblem) m.PacketTooBig.Init(a.PacketTooBig, b.PacketTooBig) m.ExhaustedTTL.Init(a.ExhaustedTTL, b.ExhaustedTTL) + m.HostUnreachable.Init(a.HostUnreachable, b.HostUnreachable) } // LINT.ThenChange(:MultiCounterIPForwardingStats, ../../../tcpip.go:IPForwardingStats) diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD index cec3e62c4..a180e5c75 100644 --- a/pkg/tcpip/network/internal/testutil/BUILD +++ b/pkg/tcpip/network/internal/testutil/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/tcpip/network/internal/fragmentation:__pkg__", "//pkg/tcpip/network/ipv4:__pkg__", "//pkg/tcpip/network/ipv6:__pkg__", + "//pkg/tcpip/tests/integration:__pkg__", ], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index bd63e0289..771b9173a 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -88,6 +88,7 @@ type testObject struct { dataCalls int controlCalls int + rawCalls int } // checkValues verifies that the transport protocol, data contents, src & dst @@ -148,6 +149,10 @@ func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpi t.controlCalls++ } +func (t *testObject) DeliverRawPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) { + t.rawCalls++ +} + // Attach is only implemented to satisfy the LinkEndpoint interface. func (*testObject) Attach(stack.NetworkDispatcher) {} @@ -717,7 +722,10 @@ func TestReceive(t *testing.T) { } test.handlePacket(t, ep, &nic) if nic.testObject.dataCalls != 1 { - t.Errorf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) + t.Errorf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls) + } + if nic.testObject.rawCalls != 1 { + t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls) } if got := stat.Value(); got != 1 { t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got) @@ -968,7 +976,10 @@ func TestIPv4FragmentationReceive(t *testing.T) { ep.HandlePacket(pkt) if nic.testObject.dataCalls != 0 { - t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls) + t.Fatalf("Bad number of data calls: got %d, want 0", nic.testObject.dataCalls) + } + if nic.testObject.rawCalls != 0 { + t.Errorf("Bad number of raw calls: got %d, want 0", nic.testObject.rawCalls) } // Send second segment. 
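Aside: the ip_test.go hunks above expect rawCalls to stay at 0 after the first fragment and reach 1 only once the datagram is complete. That matches the delivery rule the IPv4 endpoint adopts later in this change: hand a packet to raw endpoints only when it is not a fragment, and deliver fragmented traffic once, after reassembly. A hypothetical predicate (the function name is mine; the header accessors are the ones used in the diff) capturing that rule:

package example

import "gvisor.dev/gvisor/pkg/tcpip/header"

// deliverToRawEndpoints reports whether an incoming IPv4 packet should be
// dispatched to raw endpoints immediately. Fragments return false here and
// are instead delivered after reassembly, which is why the fragmentation
// test above sees no raw call until the second fragment arrives.
func deliverToRawEndpoints(h header.IPv4) bool {
	return !h.More() && h.FragmentOffset() == 0
}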
@@ -977,7 +988,10 @@ func TestIPv4FragmentationReceive(t *testing.T) { }) ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) + t.Fatalf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls) + } + if nic.testObject.rawCalls != 1 { + t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls) } } @@ -1310,7 +1324,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv4Addr, }) return hdr.View().ToVectorisedView() }, @@ -1351,7 +1365,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv4Addr, }) ip.SetHeaderLength(header.IPv4MinimumSize - 1) return hdr.View().ToVectorisedView() @@ -1370,7 +1384,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv4Addr, }) return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, @@ -1388,7 +1402,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv4Addr, }) return buffer.View(ip).ToVectorisedView() }, @@ -1430,7 +1444,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv4Addr, Options: ipv4Options, }) return hdr.View().ToVectorisedView() @@ -1469,7 +1483,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv4Addr, Options: ipv4Options, }) vv := buffer.View(ip).ToVectorisedView() @@ -1515,7 +1529,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { TransportProtocol: transportProto, HopLimit: ipv6.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv6Addr, }) return hdr.View().ToVectorisedView() }, @@ -1560,7 +1574,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier), HopLimit: ipv6.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv6Addr, }) return hdr.View().ToVectorisedView() }, @@ -1595,7 +1609,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { TransportProtocol: transportProto, HopLimit: ipv6.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv6Addr, }) return buffer.View(ip).ToVectorisedView() }, @@ -1630,7 +1644,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { TransportProtocol: transportProto, HopLimit: ipv6.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv4Addr, }) return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index d1a82b584..2aa38eb98 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -173,9 +173,8 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { received := e.stats.icmp.packetsReceived - // TODO(gvisor.dev/issue/170): ICMP packets don't have their - // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a - // full explanation. 
+ // ICMP packets don't have their TransportHeader fields set. See + // icmp/protocol.go:protocol.Parse for a full explanation. v, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize) if !ok { received.invalid.Increment() @@ -222,7 +221,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { _ = e.protocol.returnError(&icmpReasonParamProblem{ pointer: optProblem.Pointer, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() e.stats.ip.MalformedPacketsReceived.Increment() } return @@ -481,6 +479,22 @@ func (*icmpReasonFragmentationNeeded) isForwarding() bool { return true } +// icmpReasonHostUnreachable is an error in which the host specified in the +// internet destination field of the datagram is unreachable. +type icmpReasonHostUnreachable struct{} + +func (*icmpReasonHostUnreachable) isICMPReason() {} +func (*icmpReasonHostUnreachable) isForwarding() bool { + // If we hit a Host Unreachable error, then we know we are operating as a + // router. As per RFC 792 page 5, Destination Unreachable Message, + // + // In addition, in some networks, the gateway may be able to determine + // if the internet destination host is unreachable. Gateways in these + // networks may send destination unreachable messages to the source host + // when the destination host is unreachable. + return true +} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv4 and sends it back to the remote device that sent // the problematic packet. It incorporates as much of that packet as @@ -537,7 +551,12 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip defer route.Release() p.mu.Lock() - netEP, ok := p.mu.eps[pkt.NICID] + // We retrieve an endpoint using the newly constructed route's NICID rather + // than the packet's NICID. The packet's NICID corresponds to the NIC on + // which it arrived, which isn't necessarily the same as the NIC on which it + // will be transmitted. On the other hand, the route's NIC *is* guaranteed + // to be the NIC on which the packet will be transmitted. + netEP, ok := p.mu.eps[route.NICID()] p.mu.Unlock() if !ok { return &tcpip.ErrNotConnected{} @@ -653,6 +672,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4NetUnreachable) counter = sent.dstUnreachable + case *icmpReasonHostUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4HostUnreachable) + counter = sent.dstUnreachable case *icmpReasonFragmentationNeeded: icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4FragmentationNeeded) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 23178277a..44c85bdb8 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -82,7 +82,7 @@ type endpoint struct { protocol *protocol stats sharedStats - // enabled is set to 1 when the enpoint is enabled and 0 when it is + // enabled is set to 1 when the endpoint is enabled and 0 when it is // disabled. // // Must be accessed using atomic operations. @@ -104,6 +104,16 @@ type endpoint struct { // HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint. func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { + // If we are operating as a router, return an ICMP error to the original + // packet's sender. 
+ if pkt.NetworkPacketInfo.IsForwardedPacket { + // TODO(gvisor.dev/issue/6005): Propagate asynchronously generated ICMP + // errors to local endpoints. + e.protocol.returnError(&icmpReasonHostUnreachable{}, pkt) + e.stats.ip.Forwarding.Errors.Increment() + e.stats.ip.Forwarding.HostUnreachable.Increment() + return + } // handleControl expects the entire offending packet to be in the packet // buffer's data field. pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -322,8 +332,8 @@ func (e *endpoint) DefaultTTL() uint8 { return e.protocol.DefaultTTL() } -// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus -// the network layer max header length. +// MTU implements stack.NetworkEndpoint. It returns the link-layer MTU minus the +// network layer max header length. func (e *endpoint) MTU() uint32 { networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv4MinimumSize) if err != nil { @@ -338,7 +348,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.IPv4MaximumHeaderSize } -// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. +// NetworkProtocolNumber implements stack.NetworkEndpoint. func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } @@ -353,7 +363,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet if hdrLen > header.IPv4MaximumHeaderSize { return &tcpip.ErrMessageTooLong{} } - ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen)) + ipH := header.IPv4(pkt.NetworkHeader().Push(hdrLen)) length := pkt.Size() if length > math.MaxUint16 { return &tcpip.ErrMessageTooLong{} @@ -362,7 +372,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet // datagrams. Since the DF bit is never being set here, all datagrams // are non-atomic and need an ID. id := atomic.AddUint32(&e.protocol.ids[hashRoute(srcAddr, dstAddr, params.Protocol, e.protocol.hashIV)%buckets], 1) - ip.Encode(&header.IPv4Fields{ + ipH.Encode(&header.IPv4Fields{ TotalLength: uint16(length), ID: uint16(id), TTL: params.TTL, @@ -372,7 +382,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet DstAddr: dstAddr, Options: options, }) - ip.SetChecksum(^ip.CalculateChecksum()) + ipH.SetChecksum(^ipH.CalculateChecksum()) pkt.NetworkProtocolNumber = ProtocolNumber return nil } @@ -381,7 +391,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet // fragment. It returns the number of fragments handled and the number of // fragments left to be processed. The IP header must already be present in the // original packet. -func (e *endpoint) handleFragments(r *stack.Route, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) { +func (e *endpoint) handleFragments(_ *stack.Route, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) { // Round the MTU down to align to 8 bytes. fragmentPayloadSize := networkMTU &^ 7 networkHeader := header.IPv4(pkt.NetworkHeader().View()) @@ -419,9 +429,9 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // based on destination address and do not send the packet to link // layer. // - // TODO(gvisor.dev/issue/170): We should do this for every - // packet, rather than only NATted packets, but removing this check - // short circuits broadcasts before they are sent out to other hosts. 
+ // We should do this for every packet, rather than only NATted packets, but + // removing this check short circuits broadcasts before they are sent out to + // other hosts. if pkt.NatDone { netHeader := header.IPv4(pkt.NetworkHeader().View()) if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil { @@ -490,7 +500,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn return nil } -// WritePackets implements stack.NetworkEndpoint.WritePackets. +// WritePackets implements stack.NetworkEndpoint. func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { if r.Loop()&stack.PacketLoop != 0 { panic("multiple packets in local loop") @@ -593,34 +603,30 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu if !ok { return &tcpip.ErrMalformedHeader{} } - ip := header.IPv4(h) + ipH := header.IPv4(h) // Always set the total length. pktSize := pkt.Data().Size() - ip.SetTotalLength(uint16(pktSize)) + ipH.SetTotalLength(uint16(pktSize)) // Set the source address when zero. - if ip.SourceAddress() == header.IPv4Any { - ip.SetSourceAddress(r.LocalAddress()) + if ipH.SourceAddress() == header.IPv4Any { + ipH.SetSourceAddress(r.LocalAddress()) } - // Set the destination. If the packet already included a destination, it will - // be part of the route anyways. - ip.SetDestinationAddress(r.RemoteAddress()) - // Set the packet ID when zero. - if ip.ID() == 0 { + if ipH.ID() == 0 { // RFC 6864 section 4.3 mandates uniqueness of ID values for // non-atomic datagrams, so assign an ID to all such datagrams // according to the definition given in RFC 6864 section 4. - if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 { - ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress(), r.RemoteAddress(), 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) + if ipH.Flags()&header.IPv4FlagDontFragment == 0 || ipH.Flags()&header.IPv4FlagMoreFragments != 0 || ipH.FragmentOffset() > 0 { + ipH.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress(), r.RemoteAddress(), 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) } } // Always set the checksum. - ip.SetChecksum(0) - ip.SetChecksum(^ip.CalculateChecksum()) + ipH.SetChecksum(0) + ipH.SetChecksum(^ipH.CalculateChecksum()) // Populate the packet buffer's network header and don't allow an invalid // packet to be sent. @@ -710,7 +716,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { return nil } - ep.handleValidatedPacket(h, pkt) + // The packet originally arrived on e so provide its NIC as the input NIC. 
+ ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) return nil } @@ -826,7 +833,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } // handleLocalPacket is like HandlePacket except it does not perform the @@ -845,10 +852,17 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum return } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } -func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) { +func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, inNICName string) { + // Raw socket packets are delivered based solely on the transport protocol + // number. We only require that the packet be valid IPv4, and that they not + // be fragmented. + if !h.More() && h.FragmentOffset() == 0 { + e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt) + } + pkt.NICID = e.nic.ID() stats := e.stats stats.ip.ValidPacketsReceived.Increment() @@ -898,7 +912,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) case *ip.ErrNoRoute: stats.ip.Forwarding.Unrouteable.Increment() case *ip.ErrParameterProblem: - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() stats.ip.MalformedPacketsReceived.Increment() case *ip.ErrMessageTooLong: stats.ip.Forwarding.PacketTooBig.Increment() @@ -911,8 +924,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.ip.IPTablesInputDropped.Increment() return @@ -935,7 +947,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) _ = e.protocol.returnError(&icmpReasonParamProblem{ pointer: optProblem.Pointer, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() e.stats.ip.MalformedPacketsReceived.Increment() } return @@ -985,8 +996,11 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) // The reassembler doesn't take care of fixing up the header, so we need // to do it here. - h.SetTotalLength(uint16(pkt.Data().Size() + len((h)))) + h.SetTotalLength(uint16(pkt.Data().Size() + len(h))) h.SetFlagsFragmentOffset(0, 0) + + // Now that the packet is reassembled, it can be sent to raw sockets. + e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt) } stats.ip.PacketsDelivered.Increment() @@ -1008,7 +1022,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) _ = e.protocol.returnError(&icmpReasonParamProblem{ pointer: optProblem.Pointer, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() stats.ip.MalformedPacketsReceived.Increment() } return @@ -1217,13 +1230,13 @@ func (p *protocol) DefaultPrefixLen() int { return header.IPv4AddressSize * 8 } -// ParseAddresses implements NetworkProtocol.ParseAddresses. +// ParseAddresses implements stack.NetworkProtocol. 
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.IPv4(v) return h.SourceAddress(), h.DestinationAddress() } -// SetOption implements NetworkProtocol.SetOption. +// SetOption implements stack.NetworkProtocol. func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: @@ -1234,7 +1247,7 @@ func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.E } } -// Option implements NetworkProtocol.Option. +// Option implements stack.NetworkProtocol. func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: @@ -1255,10 +1268,10 @@ func (p *protocol) DefaultTTL() uint8 { return uint8(atomic.LoadUint32(&p.defaultTTL)) } -// Close implements stack.TransportProtocol.Close. +// Close implements stack.TransportProtocol. func (*protocol) Close() {} -// Wait implements stack.TransportProtocol.Wait. +// Wait implements stack.TransportProtocol. func (*protocol) Wait() {} // parseAndValidate parses the packet (including its transport layer header) and @@ -1296,7 +1309,7 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool) return h, true } -// Parse implements stack.NetworkProtocol.Parse. +// Parse implements stack.NetworkProtocol. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { if ok := parse.IPv4(pkt); !ok { return 0, false, false @@ -1326,7 +1339,7 @@ func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error networkMTU = MaxTotalSize } - return networkMTU - uint32(networkHeaderSize), nil + return networkMTU - networkHeaderSize, nil } func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32) bool { @@ -1735,9 +1748,8 @@ type optionTracker struct { // // If there were no errors during parsing, the new set of options is returned as // a new buffer. 
-func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (header.IPv4Options, optionTracker, *header.IPv4OptParameterProblem) { +func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, opts header.IPv4Options, usage optionsUsage) (header.IPv4Options, optionTracker, *header.IPv4OptParameterProblem) { stats := e.stats.ip - opts := header.IPv4Options(orig) optIter := opts.MakeIterator() // Except NOP, each option must only appear at most once (RFC 791 section 3.1, diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index da9cc0ae8..4a4448cf9 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -16,7 +16,6 @@ package ipv4_test import ( "bytes" - "context" "encoding/hex" "fmt" "io/ioutil" @@ -36,10 +35,10 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" + iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" - tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -112,10 +111,6 @@ func TestExcludeBroadcast(t *testing.T) { }) } -type forwardedPacket struct { - fragments []fragmentInfo -} - func TestForwarding(t *testing.T) { const ( incomingNICID = 1 @@ -134,11 +129,11 @@ func TestForwarding(t *testing.T) { PrefixLen: 8, } outgoingLinkAddr := tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - remoteIPv4Addr1 := tcptestutil.MustParse4("10.0.0.2") - remoteIPv4Addr2 := tcptestutil.MustParse4("11.0.0.2") - unreachableIPv4Addr := tcptestutil.MustParse4("12.0.0.2") - multicastIPv4Addr := tcptestutil.MustParse4("225.0.0.0") - linkLocalIPv4Addr := tcptestutil.MustParse4("169.254.0.0") + remoteIPv4Addr1 := testutil.MustParse4("10.0.0.2") + remoteIPv4Addr2 := testutil.MustParse4("11.0.0.2") + unreachableIPv4Addr := testutil.MustParse4("12.0.0.2") + multicastIPv4Addr := testutil.MustParse4("225.0.0.0") + linkLocalIPv4Addr := testutil.MustParse4("169.254.0.0") tests := []struct { name string @@ -402,13 +397,13 @@ func TestForwarding(t *testing.T) { totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength hdr := buffer.NewPrependable(totalLength) hdr.Prepend(test.payloadLength) - icmp := header.ICMPv4(hdr.Prepend(icmpHeaderLength)) - icmp.SetIdent(randomIdent) - icmp.SetSequence(randomSequence) - icmp.SetType(header.ICMPv4Echo) - icmp.SetCode(header.ICMPv4UnusedCode) - icmp.SetChecksum(0) - icmp.SetChecksum(^header.Checksum(icmp, 0)) + icmpH := header.ICMPv4(hdr.Prepend(icmpHeaderLength)) + icmpH.SetIdent(randomIdent) + icmpH.SetSequence(randomSequence) + icmpH.SetType(header.ICMPv4Echo) + icmpH.SetCode(header.ICMPv4UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(^header.Checksum(icmpH, 0)) ip := header.IPv4(hdr.Prepend(ipHeaderLength)) ip.Encode(&header.IPv4Fields{ TotalLength: uint16(totalLength), @@ -453,7 +448,7 @@ func TestForwarding(t *testing.T) { return len(hdr.View()) } - checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), checker.SrcAddr(incomingIPv4Addr.Address), checker.DstAddr(test.sourceAddr), checker.TTL(ipv4.DefaultTTL), @@ -461,7 +456,7 @@ func 
TestForwarding(t *testing.T) { checker.ICMPv4Checksum(), checker.ICMPv4Type(test.icmpType), checker.ICMPv4Code(test.icmpCode), - checker.ICMPv4Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])), + checker.ICMPv4Payload(hdr.View()[:expectedICMPPayloadLength()]), ), ) } else if ok { @@ -470,7 +465,7 @@ func TestForwarding(t *testing.T) { if test.expectPacketForwarded { if len(test.expectedFragmentsForwarded) != 0 { - fragmentedPackets := []*stack.PacketBuffer{} + var fragmentedPackets []*stack.PacketBuffer for i := 0; i < len(test.expectedFragmentsForwarded); i++ { reply, ok = outgoingEndpoint.Read() if !ok { @@ -487,7 +482,7 @@ func TestForwarding(t *testing.T) { // maximum IP header size and the maximum size allocated for link layer // headers. In this case, no size is allocated for link layer headers. expectedAvailableHeaderBytes := header.IPv4MaximumHeaderSize - if err := compareFragments(fragmentedPackets, requestPkt, uint32(test.mtu), test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil { + if err := compareFragments(fragmentedPackets, requestPkt, test.mtu, test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil { t.Error(err) } } else { @@ -496,7 +491,7 @@ func TestForwarding(t *testing.T) { t.Fatal("expected ICMP Echo packet through outgoing NIC") } - checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), checker.SrcAddr(test.sourceAddr), checker.DstAddr(test.destAddr), checker.TTL(test.TTL-1), @@ -1212,15 +1207,15 @@ func TestIPv4Sanity(t *testing.T) { } totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) hdr := buffer.NewPrependable(int(totalLen)) - icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) // Specify ident/seq to make sure we get the same in the response. 
- icmp.SetIdent(randomIdent) - icmp.SetSequence(randomSequence) - icmp.SetType(header.ICMPv4Echo) - icmp.SetCode(header.ICMPv4UnusedCode) - icmp.SetChecksum(0) - icmp.SetChecksum(^header.Checksum(icmp, 0)) + icmpH.SetIdent(randomIdent) + icmpH.SetSequence(randomSequence) + icmpH.SetType(header.ICMPv4Echo) + icmpH.SetCode(header.ICMPv4UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(^header.Checksum(icmpH, 0)) ip := header.IPv4(hdr.Prepend(ipHeaderLength)) if test.maxTotalLength < totalLen { totalLen = test.maxTotalLength @@ -1315,7 +1310,7 @@ func TestIPv4Sanity(t *testing.T) { checker.ICMPv4Type(test.ICMPType), checker.ICMPv4Code(test.ICMPCode), checker.ICMPv4Pointer(test.paramProblemPointer), - checker.ICMPv4Payload([]byte(hdr.View())), + checker.ICMPv4Payload(hdr.View()), ), ) return @@ -1334,7 +1329,7 @@ func TestIPv4Sanity(t *testing.T) { checker.ICMPv4( checker.ICMPv4Type(test.ICMPType), checker.ICMPv4Code(test.ICMPCode), - checker.ICMPv4Payload([]byte(hdr.View())), + checker.ICMPv4Payload(hdr.View()), ), ) return @@ -1546,9 +1541,9 @@ func TestFragmentationWritePacket(t *testing.T) { for _, ft := range fragmentationTests { t.Run(ft.description, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) source := pkt.Clone() err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -1602,7 +1597,7 @@ func TestFragmentationWritePackets(t *testing.T) { insertAfter: 1, }, } - tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber) + tinyPacket := iptestutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber) for _, test := range writePacketsTests { t.Run(test.description, func(t *testing.T) { @@ -1612,13 +1607,13 @@ func TestFragmentationWritePackets(t *testing.T) { for i := 0; i < test.insertBefore; i++ { pkts.PushBack(tinyPacket.Clone()) } - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) pkts.PushBack(pkt.Clone()) for i := 0; i < test.insertAfter; i++ { pkts.PushBack(tinyPacket.Clone()) } - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter @@ -1726,8 +1721,8 @@ func TestFragmentationErrors(t *testing.T) { for _, ft := range tests { t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) - ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) + pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + ep := 
iptestutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) r := buildRoute(t, ep) err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -2301,7 +2296,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) { checker.ICMPv4Type(header.ICMPv4TimeExceeded), checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout), checker.ICMPv4Checksum(), - checker.ICMPv4Payload([]byte(firstFragmentSent)), + checker.ICMPv4Payload(firstFragmentSent), ), ) }) @@ -2705,7 +2700,7 @@ func TestReceiveFragments(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, RawFactory: raw.EndpointFactory{}, }) - e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00")) + e := channel.New(0, 1280, "\xf0\x00") if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -2948,7 +2943,7 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) + ep := iptestutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) rt := buildRoute(t, ep) var pkts stack.PacketBufferList @@ -3035,7 +3030,7 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) return false, false } -func TestPacketQueing(t *testing.T) { +func TestPacketQueuing(t *testing.T) { const nicID = 1 var ( @@ -3074,7 +3069,7 @@ func TestPacketQueing(t *testing.T) { Length: header.UDPMinimumSize, }) sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize) - sum = header.Checksum(header.UDP([]byte{}), sum) + sum = header.Checksum(nil, sum) u.SetChecksum(^u.CalculateChecksum(sum)) ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ @@ -3090,7 +3085,7 @@ func TestPacketQueing(t *testing.T) { })) }, checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -3133,7 +3128,7 @@ func TestPacketQueing(t *testing.T) { })) }, checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -3157,9 +3152,11 @@ func TestPacketQueing(t *testing.T) { t.Run(test.name, func(t *testing.T) { e := channel.New(1, defaultMTU, host1NICLinkAddr) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -3182,7 +3179,8 @@ func TestPacketQueing(t *testing.T) { // Wait for a ARP request since link address resolution should be // performed. { - p, ok := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -3223,6 +3221,7 @@ func TestPacketQueing(t *testing.T) { } // Expect the response now that the link address has resolved. 
+ clock.RunImmediatelyScheduledJobs() test.checkResp(t, e) // Since link resolution was already performed, it shouldn't be performed @@ -3244,8 +3243,8 @@ func TestCloseLocking(t *testing.T) { ) var ( - src = tcptestutil.MustParse4("16.0.0.1") - dst = tcptestutil.MustParse4("16.0.0.2") + src = testutil.MustParse4("16.0.0.1") + dst = testutil.MustParse4("16.0.0.2") ) s := stack.New(stack.Options{ @@ -3253,7 +3252,7 @@ func TestCloseLocking(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - // Perform NAT so that the endoint tries to search for a sibling endpoint + // Perform NAT so that the endpoint tries to search for a sibling endpoint // which ends up taking the protocol and endpoint lock (in that order). table := stack.Table{ Rules: []stack.Rule{ diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 307e1972d..94caaae6c 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -285,8 +285,8 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, routerAlert *header.IPv6RouterAlertOption) { sent := e.stats.icmp.packetsSent received := e.stats.icmp.packetsReceived - // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader - // fields set. See icmp/protocol.go:protocol.Parse for a full explanation. + // ICMP packets don't have their TransportHeader fields set. See + // icmp/protocol.go:protocol.Parse for a full explanation. v, ok := pkt.Data().PullUp(header.ICMPv6HeaderSize) if !ok { received.invalid.Increment() @@ -1029,6 +1029,26 @@ func (*icmpReasonNetUnreachable) respondsToMulticast() bool { return false } +// icmpReasonHostUnreachable is an error in which the host specified in the +// internet destination field of the datagram is unreachable. +type icmpReasonHostUnreachable struct{} + +func (*icmpReasonHostUnreachable) isICMPReason() {} +func (*icmpReasonHostUnreachable) isForwarding() bool { + // If we hit a Host Unreachable error, then we know we are operating as a + // router. As per RFC 4443 page 8, Destination Unreachable Message, + // + // If the reason for the failure to deliver cannot be mapped to any of + // other codes, the Code field is set to 3. Example of such cases are + // an inability to resolve the IPv6 destination address into a + // corresponding link address, or a link-specific problem of some sort. + return true +} + +func (*icmpReasonHostUnreachable) respondsToMulticast() bool { + return false +} + // icmpReasonFragmentationNeeded is an error where a packet is to big to be sent // out through the outgoing MTU, as per RFC 4443 page 9, Packet Too Big Message. type icmpReasonPacketTooBig struct{} @@ -1143,7 +1163,12 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip defer route.Release() p.mu.Lock() - netEP, ok := p.mu.eps[pkt.NICID] + // We retrieve an endpoint using the newly constructed route's NICID rather + // than the packet's NICID. The packet's NICID corresponds to the NIC on + // which it arrived, which isn't necessarily the same as the NIC on which it + // will be transmitted. On the other hand, the route's NIC *is* guaranteed + // to be the NIC on which the packet will be transmitted. 
+ netEP, ok := p.mu.eps[route.NICID()] p.mu.Unlock() if !ok { return &tcpip.ErrNotConnected{} @@ -1222,6 +1247,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetType(header.ICMPv6DstUnreachable) icmpHdr.SetCode(header.ICMPv6NetworkUnreachable) counter = sent.dstUnreachable + case *icmpReasonHostUnreachable: + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6AddressUnreachable) + counter = sent.dstUnreachable case *icmpReasonPacketTooBig: icmpHdr.SetType(header.ICMPv6PacketTooBig) icmpHdr.SetCode(header.ICMPv6UnusedCode) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 040cd4bc8..7c2a3e56b 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -16,17 +16,16 @@ package ipv6 import ( "bytes" - "context" "net" "reflect" "strings" "testing" - "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" @@ -46,16 +45,12 @@ const ( defaultChannelSize = 1 defaultMTU = 65536 - // Extra time to use when waiting for an async event to occur. - defaultAsyncPositiveEventTimeout = 30 * time.Second - arbitraryHopLimit = 42 ) var ( lladdr0 = header.LinkLocalAddr(linkAddr0) lladdr1 = header.LinkLocalAddr(linkAddr1) - lladdr2 = header.LinkLocalAddr(linkAddr2) ) type stubLinkEndpoint struct { @@ -95,6 +90,10 @@ func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *st return stack.TransportPacketHandled } +func (*stubDispatcher) DeliverRawPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) { + // No-op. 
+} + var _ stack.NetworkInterface = (*testInterface)(nil) type testInterface struct { @@ -371,6 +370,8 @@ type testContext struct { linkEP0 *channel.Endpoint linkEP1 *channel.Endpoint + + clock *faketime.ManualClock } type endpointWithResolutionCapability struct { @@ -382,15 +383,19 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab } func newTestContext(t *testing.T) *testContext { + clock := faketime.NewManualClock() c := &testContext{ s0: stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + Clock: clock, }), s1: stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + Clock: clock, }), + clock: clock, } c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0) @@ -457,10 +462,14 @@ type routeArgs struct { remoteLinkAddr tcpip.LinkAddress } -func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { +func routeICMPv6Packet(t *testing.T, clock *faketime.ManualClock, args routeArgs, fn func(*testing.T, header.ICMPv6)) { t.Helper() - pi, _ := args.src.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + pi, ok := args.src.Read() + if !ok { + t.Fatal("packet didn't arrive") + } { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -533,7 +542,7 @@ func TestLinkResolution(t *testing.T) { {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert}, } { - routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) { + routeICMPv6Packet(t, c.clock, args, func(t *testing.T, icmpv6 header.ICMPv6) { if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want { t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want) } @@ -544,7 +553,7 @@ func TestLinkResolution(t *testing.T) { {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6EchoRequest}, {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6EchoReply}, } { - routeICMPv6Packet(t, args, nil) + routeICMPv6Packet(t, c.clock, args, nil) } } @@ -1309,7 +1318,7 @@ func TestPacketQueing(t *testing.T) { Length: header.UDPMinimumSize, }) sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize) - sum = header.Checksum(header.UDP([]byte{}), sum) + sum = header.Checksum(nil, sum) u.SetChecksum(^u.CalculateChecksum(sum)) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -1325,7 +1334,7 @@ func TestPacketQueing(t *testing.T) { })) }, checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -1371,7 +1380,7 @@ func TestPacketQueing(t *testing.T) { })) }, checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -1396,9 +1405,11 @@ func TestPacketQueing(t *testing.T) { e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: 
[]stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -1421,7 +1432,8 @@ func TestPacketQueing(t *testing.T) { // Wait for a neighbor solicitation since link address resolution should // be performed. { - p, ok := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -1475,6 +1487,7 @@ func TestPacketQueing(t *testing.T) { } // Expect the response now that the link address has resolved. + clock.RunImmediatelyScheduledJobs() test.checkResp(t, e) // Since link resolution was already performed, it shouldn't be performed diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 95e11ac51..b1aec5312 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -184,7 +184,6 @@ type endpoint struct { nic stack.NetworkInterface dispatcher stack.TransportDispatcher protocol *protocol - stack *stack.Stack stats sharedStats // enabled is set to 1 when the endpoint is enabled and 0 when it is @@ -231,11 +230,11 @@ type endpoint struct { // If the NIC was created with a name, it is passed to NICNameFromID. // // NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are -// generated for the same prefix on differnt NICs. +// generated for the same prefix on different NICs. type NICNameFromID func(tcpip.NICID, string) string // OpaqueInterfaceIdentifierOptions holds the options related to the generation -// of opaque interface indentifiers (IIDs) as defined by RFC 7217. +// of opaque interface identifiers (IIDs) as defined by RFC 7217. type OpaqueInterfaceIdentifierOptions struct { // NICNameFromID is a function that returns a stable name for a specified NIC, // even if the NIC ID changes over time. @@ -282,6 +281,16 @@ func (*endpoint) DuplicateAddressProtocol() tcpip.NetworkProtocolNumber { // HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint. func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { + // If we are operating as a router, we should return an ICMP error to the + // original packet's sender. + if pkt.NetworkPacketInfo.IsForwardedPacket { + // TODO(gvisor.dev/issue/6005): Propagate asynchronously generated ICMP + // errors to local endpoints. + e.protocol.returnError(&icmpReasonHostUnreachable{}, pkt) + e.stats.ip.Forwarding.Errors.Increment() + e.stats.ip.Forwarding.HostUnreachable.Increment() + return + } // handleControl expects the entire offending packet to be in the packet // buffer's data field. pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -335,7 +344,10 @@ func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) { func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { e.mu.Lock() defer e.mu.Unlock() - e.mu.ndp.invalidateDefaultRouter(rtr) + + // We represent default routers with a default (off-link) route through the + // router. + e.mu.ndp.invalidateOffLinkRoute(offLinkRoute{dest: header.IPv6EmptySubnet, router: rtr}) } // SetNDPConfigurations implements NDPEndpoint. @@ -536,7 +548,7 @@ func (e *endpoint) Enable() tcpip.Error { // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent // state. 
// - // Addresses may have aleady completed DAD but in the time since the endpoint + // Addresses may have already completed DAD but in the time since the endpoint // was last enabled, other devices may have acquired the same addresses. var err tcpip.Error e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { @@ -641,8 +653,8 @@ func (e *endpoint) DefaultTTL() uint8 { return e.protocol.DefaultTTL() } -// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus -// the network layer max header length. +// MTU implements stack.NetworkEndpoint. It returns the link-layer MTU minus the +// network layer max header length. func (e *endpoint) MTU() uint32 { networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv6MinimumSize) if err != nil { @@ -666,8 +678,7 @@ func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params if length > math.MaxUint16 { return &tcpip.ErrMessageTooLong{} } - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) - ip.Encode(&header.IPv6Fields{ + header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)).Encode(&header.IPv6Fields{ PayloadLength: uint16(length), TransportProtocol: params.Protocol, HopLimit: params.TTL, @@ -747,9 +758,9 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // based on destination address and do not send the packet to link // layer. // - // TODO(gvisor.dev/issue/170): We should do this for every - // packet, rather than only NATted packets, but removing this check - // short circuits broadcasts before they are sent out to other hosts. + // We should do this for every packet, rather than only NATted packets, but + // removing this check short circuits broadcasts before they are sent out to + // other hosts. if pkt.NatDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil { @@ -818,7 +829,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol return nil } -// WritePackets implements stack.NetworkEndpoint.WritePackets. +// WritePackets implements stack.NetworkEndpoint. func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { if r.Loop()&stack.PacketLoop != 0 { panic("not implemented") @@ -909,21 +920,17 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu if !ok { return &tcpip.ErrMalformedHeader{} } - ip := header.IPv6(h) + ipH := header.IPv6(h) // Always set the payload length. pktSize := pkt.Data().Size() - ip.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize)) + ipH.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize)) // Set the source address when zero. - if ip.SourceAddress() == header.IPv6Any { - ip.SetSourceAddress(r.LocalAddress()) + if ipH.SourceAddress() == header.IPv6Any { + ipH.SetSourceAddress(r.LocalAddress()) } - // Set the destination. If the packet already included a destination, it will - // be part of the route anyways. - ip.SetDestinationAddress(r.RemoteAddress()) - // Populate the packet buffer's network header and don't allow an invalid // packet to be sent. // @@ -983,7 +990,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { return nil } - ep.handleValidatedPacket(h, pkt) + // The packet originally arrived on e so provide its NIC as the input NIC. 
+ ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) return nil } @@ -1096,7 +1104,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } // handleLocalPacket is like HandlePacket except it does not perform the @@ -1115,10 +1123,14 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum return } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } -func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) { +func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, inNICName string) { + // Raw socket packets are delivered based solely on the transport protocol + // number. We only require that the packet be valid IPv6. + e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt) + pkt.NICID = e.nic.ID() stats := e.stats.ip stats.ValidPacketsReceived.Increment() @@ -1167,8 +1179,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.IPTablesInputDropped.Increment() return @@ -1610,7 +1621,7 @@ func (e *endpoint) Close() { e.protocol.forgetEndpoint(e.nic.ID()) } -// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. +// NetworkProtocolNumber implements stack.NetworkEndpoint. func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } @@ -1619,8 +1630,8 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { // TODO(b/169350103): add checks here after making sure we no longer receive // an empty address. - e.mu.RLock() - defer e.mu.RUnlock() + e.mu.Lock() + defer e.mu.Unlock() return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated) } @@ -1661,8 +1672,8 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre // RemovePermanentAddress implements stack.AddressableEndpoint. func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { - e.mu.RLock() - defer e.mu.RUnlock() + e.mu.Lock() + defer e.mu.Unlock() addressEndpoint := e.getAddressRLocked(addr) if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() { @@ -2004,7 +2015,7 @@ func (p *protocol) DefaultPrefixLen() int { return header.IPv6AddressSize * 8 } -// ParseAddresses implements NetworkProtocol.ParseAddresses. +// ParseAddresses implements stack.NetworkProtocol. func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.IPv6(v) return h.SourceAddress(), h.DestinationAddress() @@ -2081,7 +2092,7 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { delete(p.mu.eps, nicID) } -// SetOption implements NetworkProtocol.SetOption. +// SetOption implements stack.NetworkProtocol. 
func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: @@ -2092,7 +2103,7 @@ func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.E } } -// Option implements NetworkProtocol.Option. +// Option implements stack.NetworkProtocol. func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: @@ -2113,10 +2124,10 @@ func (p *protocol) DefaultTTL() uint8 { return uint8(atomic.LoadUint32(&p.defaultTTL)) } -// Close implements stack.TransportProtocol.Close. +// Close implements stack.TransportProtocol. func (*protocol) Close() {} -// Wait implements stack.TransportProtocol.Wait. +// Wait implements stack.TransportProtocol. func (*protocol) Wait() {} // parseAndValidate parses the packet (including its transport layer header) and @@ -2150,7 +2161,7 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv6, bool) return h, true } -// Parse implements stack.NetworkProtocol.Parse. +// Parse implements stack.NetworkProtocol. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt) if !ok { @@ -2180,7 +2191,7 @@ func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, tcpip.Error return 0, &tcpip.ErrMalformedHeader{} } - networkMTU := linkMTU - uint32(networkHeadersLen) + networkMTU := linkMTU - networkHeadersLen if networkMTU > maxPayloadSize { networkMTU = maxPayloadSize } @@ -2199,7 +2210,7 @@ type Options struct { // Note, setting this to true does not mean that a link-local address is // assigned right away, or at all. If Duplicate Address Detection is enabled, // an address is only assigned if it successfully resolves. If it fails, no - // further attempts are made to auto-generate a link-local adddress. + // further attempts are made to auto-generate a link-local address. // // The generated link-local address follows RFC 4291 Appendix A guidelines. AutoGenLinkLocal bool @@ -2215,7 +2226,7 @@ type Options struct { // TempIIDSeed is used to seed the initial temporary interface identifier // history value used to generate IIDs for temporary SLAAC addresses. // - // Temporary SLAAC adresses are short-lived addresses which are unpredictable + // Temporary SLAAC addresses are short-lived addresses which are unpredictable // and random from the perspective of other nodes on the network. 
It is // recommended that the seed be a random byte buffer of at least // header.IIDSize bytes to make sure that temporary SLAAC addresses are diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index afc6c3547..d2a23fd4f 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -129,7 +129,7 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize) // UDP checksum - sum = header.Checksum(header.UDP([]byte{}), sum) + sum = header.Checksum(nil, sum) u.SetChecksum(^u.CalculateChecksum(sum)) payloadLength := hdr.UsedLength() @@ -402,7 +402,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }{ { name: "None", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr }, + extHdr: func(nextHdr uint8) ([]byte, uint8) { return nil, nextHdr }, shouldAccept: true, expectICMP: false, }, @@ -612,8 +612,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { { name: "No next header", extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{}, - noNextHdrID + return nil, noNextHdrID }, shouldAccept: false, expectICMP: false, @@ -1160,7 +1159,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2), []buffer.View{ // Fragment extension header. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}, ipv6Payload1Addr1ToAddr2, }, @@ -1180,7 +1179,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2), []buffer.View{ // Fragment extension header. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}, ipv6Payload3Addr1ToAddr2, }, @@ -1202,7 +1201,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1218,7 +1217,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1240,7 +1239,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1256,7 +1255,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1278,7 +1277,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. 
// // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1296,7 +1295,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment offset = 8, More = false, ID = 1 // NextHeader value is different than the one in the first fragment, so // this NextHeader should be ignored. - buffer.View([]byte{uint8(header.IPv6NoNextHeaderIdentifier), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.IPv6NoNextHeaderIdentifier), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1318,7 +1317,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload3Addr1ToAddr2[:64], }, @@ -1334,7 +1333,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload3Addr1ToAddr2[64:], }, @@ -1356,7 +1355,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload3Addr1ToAddr2[:63], }, @@ -1372,7 +1371,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload3Addr1ToAddr2[63:], }, @@ -1394,7 +1393,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1410,7 +1409,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 2 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1432,7 +1431,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15], }, @@ -1448,10 +1447,10 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = udpMaximumSizeMinus15/8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, + []byte{uint8(header.UDPProtocolNumber), 0, udpMaximumSizeMinus15 >> 8, udpMaximumSizeMinus15 & 0xff, - 0, 0, 0, 1}), + 0, 0, 0, 1}, ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:], }, @@ -1473,7 +1472,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. 
// // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15], }, @@ -1489,10 +1488,10 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = udpMaximumSizeMinus15/8, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, + []byte{uint8(header.UDPProtocolNumber), 0, udpMaximumSizeMinus15 >> 8, (udpMaximumSizeMinus15 & 0xff) + 1, - 0, 0, 0, 1}), + 0, 0, 0, 1}, ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:], }, @@ -1514,12 +1513,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Routing extension header. // // Segments left = 0. - buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}), + []byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}, // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1535,12 +1534,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Routing extension header. // // Segments left = 0. - buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}), + []byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}, // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1562,12 +1561,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Routing extension header. // // Segments left = 1. - buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}), + []byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}, // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1583,12 +1582,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Routing extension header. // // Segments left = 1. - buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}), + []byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}, // Fragment extension header. // // Fragment offset = 9, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1610,12 +1609,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}, // Routing extension header. // // Segments left = 0. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}), + []byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1631,7 +1630,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 9, More = false, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1653,12 +1652,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. 
// // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}, // Routing extension header. // // Segments left = 1. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}), + []byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1674,7 +1673,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 9, More = false, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1699,12 +1698,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}, // Routing extension header (part 1) // // Segments left = 0. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5}), + []byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5}, }, ), }, @@ -1721,10 +1720,10 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 1, More = false, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}, // Routing extension header (part 2) - buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}), + []byte{6, 7, 8, 9, 10, 11, 12, 13}, ipv6Payload1Addr1ToAddr2, }, @@ -1749,12 +1748,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}, // Routing extension header (part 1) // // Segments left = 1. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5}), + []byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5}, }, ), }, @@ -1771,10 +1770,10 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 1, More = false, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}, // Routing extension header (part 2) - buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}), + []byte{6, 7, 8, 9, 10, 11, 12, 13}, ipv6Payload1Addr1ToAddr2, }, @@ -1798,7 +1797,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1816,7 +1815,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}, ipv6Payload2Addr1ToAddr2, }, @@ -1832,7 +1831,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1854,7 +1853,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. 
// // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1870,7 +1869,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 2 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}, ipv6Payload2Addr1ToAddr2[:32], }, @@ -1886,7 +1885,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1902,7 +1901,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 4, More = false, ID = 2 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}, ipv6Payload2Addr1ToAddr2[32:], }, @@ -1924,7 +1923,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1940,7 +1939,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr3ToAddr2[:32], }, @@ -1956,7 +1955,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1972,7 +1971,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. 
// // Fragment offset = 4, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1}, ipv6Payload1Addr3ToAddr2[32:], }, @@ -2208,7 +2207,7 @@ func TestInvalidIPv6Fragments(t *testing.T) { checker.ICMPv6Type(test.expectICMPType), checker.ICMPv6Code(test.expectICMPCode), checker.ICMPv6TypeSpecific(test.expectICMPTypeSpecific), - checker.ICMPv6Payload([]byte(expectICMPPayload)), + checker.ICMPv6Payload(expectICMPPayload), ), ) }) @@ -2459,7 +2458,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) { checker.ICMPv6( checker.ICMPv6Type(header.ICMPv6TimeExceeded), checker.ICMPv6Code(header.ICMPv6ReassemblyTimeout), - checker.ICMPv6Payload([]byte(firstFragmentSent)), + checker.ICMPv6Payload(firstFragmentSent), ), ) }) @@ -2795,11 +2794,7 @@ var fragmentationTests = []struct { } func TestFragmentationWritePacket(t *testing.T) { - const ( - ttl = 42 - tos = stack.DefaultTOS - transportProto = tcp.ProtocolNumber - ) + const ttl = 42 for _, ft := range fragmentationTests { t.Run(ft.description, func(t *testing.T) { @@ -3335,7 +3330,7 @@ func TestForwarding(t *testing.T) { } transportProtocol := header.ICMPv6ProtocolNumber - extHdrBytes := []byte{} + var extHdrBytes []byte extHdrChecker := checker.IPv6ExtHdr() if test.extHdr != nil { nextHdrID := hopByHopExtHdrID @@ -3349,15 +3344,15 @@ func TestForwarding(t *testing.T) { totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + extHdrLen hdr := buffer.NewPrependable(totalLength) hdr.Prepend(test.payloadLength) - icmp := header.ICMPv6(hdr.Prepend(icmpHeaderLength)) - - icmp.SetIdent(randomIdent) - icmp.SetSequence(randomSequence) - icmp.SetType(header.ICMPv6EchoRequest) - icmp.SetCode(header.ICMPv6UnusedCode) - icmp.SetChecksum(0) - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp, + icmpH := header.ICMPv6(hdr.Prepend(icmpHeaderLength)) + + icmpH.SetIdent(randomIdent) + icmpH.SetSequence(randomSequence) + icmpH.SetType(header.ICMPv6EchoRequest) + icmpH.SetCode(header.ICMPv6UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpH, Src: test.sourceAddr, Dst: test.destAddr, })) @@ -3395,14 +3390,14 @@ func TestForwarding(t *testing.T) { return len(hdr.View()) } - checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), checker.SrcAddr(incomingIPv6Addr.Address), checker.DstAddr(test.sourceAddr), checker.TTL(DefaultTTL), checker.ICMPv6( checker.ICMPv6Type(test.icmpType), checker.ICMPv6Code(test.icmpCode), - checker.ICMPv6Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])), + checker.ICMPv6Payload(hdr.View()[:expectedICMPPayloadLength()]), ), ) @@ -3419,7 +3414,7 @@ func TestForwarding(t *testing.T) { t.Fatal("expected ICMP Echo Request packet through outgoing NIC") } - checker.IPv6WithExtHdr(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.IPv6WithExtHdr(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), checker.SrcAddr(test.sourceAddr), checker.DstAddr(test.destAddr), checker.TTL(test.TTL-1), diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index 71d1c3e28..bc9cf6999 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -16,6 +16,7 @@ package ipv6_test import ( "bytes" + "math/rand" "testing" "time" @@ -138,7 +139,8 @@ func TestSendQueuedMLDReports(t 
*testing.T) { var secureRNG bytes.Reader secureRNG.Reset(secureRNGBytes[:]) s := stack.New(stack.Options{ - SecureRNG: &secureRNG, + SecureRNG: &secureRNG, + RandSource: rand.NewSource(time.Now().UnixNano()), NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ DupAddrDetectTransmits: test.dadTransmits, diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index f0ff111c5..9cd283eba 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -16,7 +16,6 @@ package ipv6 import ( "fmt" - "math/rand" "time" "gvisor.dev/gvisor/pkg/sync" @@ -79,13 +78,13 @@ const ( // we cannot have a negative delay. minimumMaxRtrSolicitationDelay = 0 - // MaxDiscoveredDefaultRouters is the maximum number of discovered - // default routers. The stack should stop discovering new routers after - // discovering MaxDiscoveredDefaultRouters routers. + // MaxDiscoveredOffLinkRoutes is the maximum number of discovered off-link + // routes. The stack should stop discovering new off-link routes after + // this limit is reached. // // This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and // SHOULD be more. - MaxDiscoveredDefaultRouters = 10 + MaxDiscoveredOffLinkRoutes = 10 // MaxDiscoveredOnLinkPrefixes is the maximum number of discovered // on-link prefixes. The stack should stop discovering new on-link @@ -128,25 +127,17 @@ const ( // maxSLAACAddrLocalRegenAttempts is the maximum number of times to attempt // SLAAC address regenerations in response to an IPv6 endpoint-local conflict. maxSLAACAddrLocalRegenAttempts = 10 -) -var ( // MinPrefixInformationValidLifetimeForUpdate is the minimum Valid // Lifetime to update the valid lifetime of a generated address by // SLAAC. // - // This is exported as a variable (instead of a constant) so tests - // can update it to a smaller value. - // // Min = 2hrs. MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour // MaxDesyncFactor is the upper bound for the preferred lifetime's desync // factor for temporary SLAAC addresses. // - // This is exported as a variable (instead of a constant) so tests - // can update it to a smaller value. - // // Must be greater than 0. // // Max = 10m (from RFC 4941 section 5). @@ -155,9 +146,6 @@ var ( // MinMaxTempAddrPreferredLifetime is the minimum value allowed for the // maximum preferred lifetime for temporary SLAAC addresses. // - // This is exported as a variable (instead of a constant) so tests - // can update it to a smaller value. - // // This value guarantees that a temporary address is preferred for at // least 1hr if the SLAAC prefix is valid for at least that time. MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour @@ -165,9 +153,6 @@ var ( // MinMaxTempAddrValidLifetime is the minimum value allowed for the // maximum valid lifetime for temporary SLAAC addresses. // - // This is exported as a variable (instead of a constant) so tests - // can update it to a smaller value. - // // This value guarantees that a temporary address is valid for at least // 2hrs if the SLAAC prefix is valid for at least that time. MinMaxTempAddrValidLifetime = 2 * time.Hour @@ -215,28 +200,23 @@ type NDPDispatcher interface { // is also not permitted to call into the stack. OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) - // OnDefaultRouterDiscovered is called when a new default router is - // discovered. 
Implementations must return true if the newly discovered - // router should be remembered. + // OnOffLinkRouteUpdated is called when an off-link route is updated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool + OnOffLinkRouteUpdated(tcpip.NICID, tcpip.Subnet, tcpip.Address, header.NDPRoutePreference) - // OnDefaultRouterInvalidated is called when a discovered default router that - // was remembered is invalidated. + // OnOffLinkRouteInvalidated is called when an off-link route is invalidated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) + OnOffLinkRouteInvalidated(tcpip.NICID, tcpip.Subnet, tcpip.Address) // OnOnLinkPrefixDiscovered is called when a new on-link prefix is discovered. - // Implementations must return true if the newly discovered on-link prefix - // should be remembered. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool + OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) // OnOnLinkPrefixInvalidated is called when a discovered on-link prefix that // was remembered is invalidated. @@ -246,13 +226,11 @@ type NDPDispatcher interface { OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) // OnAutoGenAddress is called when a new prefix with its autonomous address- - // configuration flag set is received and SLAAC was performed. Implementations - // may prevent the stack from assigning the address to the NIC by returning - // false. + // configuration flag set is received and SLAAC was performed. // // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. - OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool + OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) // OnAutoGenAddressDeprecated is called when an auto-generated address (SLAAC) // is deprecated, but is still considered valid. Note, if an address is @@ -470,6 +448,11 @@ type timer struct { timer tcpip.Timer } +type offLinkRoute struct { + dest tcpip.Subnet + router tcpip.Address +} + // ndpState is the per-Interface NDP state. type ndpState struct { // Do not allow overwriting this state. @@ -484,8 +467,8 @@ type ndpState struct { // The DAD timers to send the next NS message, or resolve the address. dad ip.DAD - // The default routers discovered through Router Advertisements. - defaultRouters map[tcpip.Address]defaultRouterState + // The off-link routes discovered through Router Advertisements. + offLinkRoutes map[offLinkRoute]offLinkRouteState // rtrSolicitTimer is the timer used to send the next router solicitation // message. @@ -513,10 +496,12 @@ type ndpState struct { temporaryAddressDesyncFactor time.Duration } -// defaultRouterState holds data associated with a default router discovered by +// offLinkRouteState holds data associated with an off-link route discovered by // a Router Advertisement (RA). -type defaultRouterState struct { - // Job to invalidate the default router. +type offLinkRouteState struct { + prf header.NDPRoutePreference + + // Job to invalidate the route. // // Must not be nil. invalidationJob *tcpip.Job @@ -549,7 +534,7 @@ type tempSLAACAddrState struct { // Must not be nil. 
regenJob *tcpip.Job - createdAt time.Time + createdAt tcpip.MonotonicTime // The address's endpoint. // @@ -572,11 +557,11 @@ type slaacPrefixState struct { // Must not be nil. invalidationJob *tcpip.Job - // Nonzero only when the address is not valid forever. - validUntil time.Time + // nil iff the address is valid forever. + validUntil *tcpip.MonotonicTime - // Nonzero only when the address is not preferred forever. - preferredUntil time.Time + // nil iff the address is preferred forever. + preferredUntil *tcpip.MonotonicTime // State associated with the stable address generated for the prefix. stableAddr struct { @@ -734,30 +719,22 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // Is the IPv6 endpoint configured to discover default routers? if ndp.configs.DiscoverDefaultRouters { - rtr, ok := ndp.defaultRouters[ip] - rl := ra.RouterLifetime() - switch { - case !ok && rl != 0: - // This is a new default router we are discovering. + prf := ra.DefaultRouterPreference() + if prf == header.ReservedRoutePreference { + // As per RFC 4191 section 2.2, // - // Only remember it if we currently know about less than - // MaxDiscoveredDefaultRouters routers. - if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters { - ndp.rememberDefaultRouter(ip, rl) - } - - case ok && rl != 0: - // This is an already discovered default router. Update - // the invalidation job. - rtr.invalidationJob.Cancel() - rtr.invalidationJob.Schedule(rl) - ndp.defaultRouters[ip] = rtr - - case ok && rl == 0: - // We know about the router but it is no longer to be - // used as a default router so invalidate it. - ndp.invalidateDefaultRouter(ip) + // Prf (Default Router Preference) + // + // If the Reserved (10) value is received, the receiver MUST treat the + // value as if it were (00). + // + // Note that the value 00 is the medium (default) router preference value. + prf = header.MediumRoutePreference } + + // We represent default routers with a default (off-link) route through the + // router. + ndp.handleOffLinkRouteDiscovery(offLinkRoute{dest: header.IPv6EmptySubnet, router: ip}, ra.RouterLifetime(), prf) } // TODO(b/141556115): Do (RetransTimer, ReachableTime)) Parameter @@ -815,55 +792,75 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { } } -// invalidateDefaultRouter invalidates a discovered default router. +// invalidateOffLinkRoute invalidates a discovered off-link route. // // The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { - rtr, ok := ndp.defaultRouters[ip] - - // Is the router still discovered? +func (ndp *ndpState) invalidateOffLinkRoute(route offLinkRoute) { + state, ok := ndp.offLinkRoutes[route] if !ok { - // ...Nope, do nothing further. return } - rtr.invalidationJob.Cancel() - delete(ndp.defaultRouters, ip) + state.invalidationJob.Cancel() + delete(ndp.offLinkRoutes, route) - // Let the integrator know a discovered default router is invalidated. + // Let the integrator know a discovered off-link route is invalidated. if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { - ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip) + ndpDisp.OnOffLinkRouteInvalidated(ndp.ep.nic.ID(), route.dest, route.router) } } -// rememberDefaultRouter remembers a newly discovered default router with IPv6 -// link-local address ip with lifetime rl. -// -// The router identified by ip MUST NOT already be known by the IPv6 endpoint. 
+// handleOffLinkRouteDiscovery handles the discovery of an off-link route. // -// The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { +// Precondition: ndp.ep.mu must be locked. +func (ndp *ndpState) handleOffLinkRouteDiscovery(route offLinkRoute, lifetime time.Duration, prf header.NDPRoutePreference) { ndpDisp := ndp.ep.protocol.options.NDPDisp if ndpDisp == nil { return } - // Inform the integrator when we discovered a default router. - if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.nic.ID(), ip) { - // Informed by the integrator to not remember the router, do - // nothing further. - return - } + state, ok := ndp.offLinkRoutes[route] + switch { + case !ok && lifetime != 0: + // This is a new route we are discovering. + // + // Only remember it if we currently know about less than + // MaxDiscoveredOffLinkRoutes routers. + if len(ndp.offLinkRoutes) < MaxDiscoveredOffLinkRoutes { + // Inform the integrator when we discovered an off-link route. + ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), route.dest, route.router, prf) + + state := offLinkRouteState{ + prf: prf, + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { + ndp.invalidateOffLinkRoute(route) + }), + } - state := defaultRouterState{ - invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { - ndp.invalidateDefaultRouter(ip) - }), - } + state.invalidationJob.Schedule(lifetime) + + ndp.offLinkRoutes[route] = state + } + + case ok && lifetime != 0: + // This is an already discovered off-link route. Update the lifetime. + state.invalidationJob.Cancel() + state.invalidationJob.Schedule(lifetime) + + if prf != state.prf { + state.prf = prf - state.invalidationJob.Schedule(rl) + // Inform the integrator about route preference updates. + ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), route.dest, route.router, prf) + } + + ndp.offLinkRoutes[route] = state - ndp.defaultRouters[ip] = state + case ok && lifetime == 0: + // The already discovered off-link route is no longer considered valid so we + // invalidate it immediately. + ndp.invalidateOffLinkRoute(route) + } } // rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6 @@ -879,11 +876,7 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) } // Inform the integrator when we discovered an on-link prefix. - if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) { - // Informed by the integrator to not remember the prefix, do - // nothing further. - return - } + ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) state := onLinkPrefixState{ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { @@ -1051,12 +1044,13 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1, } - now := time.Now() + now := ndp.ep.protocol.stack.Clock().NowMonotonic() // The time an address is preferred until is needed to properly generate the // address. if pl < header.NDPInfiniteLifetime { - state.preferredUntil = now.Add(pl) + t := now.Add(pl) + state.preferredUntil = &t } if !ndp.generateSLAACAddr(prefix, &state) { @@ -1074,7 +1068,8 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { if vl < header.NDPInfiniteLifetime { state.invalidationJob.Schedule(vl) - state.validUntil = now.Add(vl) + t := now.Add(vl) + state.validUntil = &t } // If the address is assigned (DAD resolved), generate a temporary address. 
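[Illustrative sketch] The doSLAAC hunk above replaces zero-value time.Time sentinels with nil pointers to monotonic timestamps taken from the stack's injectable clock (Clock().NowMonotonic()). The following is a minimal, standalone sketch of that "nil means forever" lifetime pattern using only the standard library; the Clock interface and lifetimeState type are hypothetical names for illustration, not gVisor's actual types.

package main

import (
	"fmt"
	"time"
)

// Clock abstracts time so tests can inject a fake implementation, mirroring
// the switch from time.Now() to the stack's clock in the diff above.
type Clock interface {
	Now() time.Time
}

type realClock struct{}

func (realClock) Now() time.Time { return time.Now() }

// lifetimeState tracks when an entry stops being valid. A nil validUntil
// means the entry is valid forever, replacing a zero-value sentinel.
type lifetimeState struct {
	validUntil *time.Time
}

// update records a new valid lifetime; an infinite lifetime clears the
// deadline instead of storing a sentinel value.
func (s *lifetimeState) update(c Clock, lifetime time.Duration, infinite bool) {
	if infinite {
		s.validUntil = nil
		return
	}
	t := c.Now().Add(lifetime)
	s.validUntil = &t
}

// expired reports whether the entry's lifetime has elapsed.
func (s *lifetimeState) expired(c Clock) bool {
	return s.validUntil != nil && !s.validUntil.After(c.Now())
}

func main() {
	var s lifetimeState
	c := realClock{}
	s.update(c, 2*time.Hour, false)
	fmt.Println("expired:", s.expired(c)) // false until the deadline passes
	s.update(c, 0, true)
	fmt.Println("expired:", s.expired(c)) // false: nil deadline means forever
}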
@@ -1097,16 +1092,13 @@ func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, config return nil } - if !ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) { - // Informed by the integrator not to add the address. - return nil - } - addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated) if err != nil { panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err)) } + ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) + return addressEndpoint } @@ -1182,7 +1174,8 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt state.stableAddr.localGenerationFailures++ } - if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); addressEndpoint != nil { + deprecated := state.preferredUntil != nil && !state.preferredUntil.After(ndp.ep.protocol.stack.Clock().NowMonotonic()) + if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, deprecated); addressEndpoint != nil { state.stableAddr.addressEndpoint = addressEndpoint state.generationAttempts++ return true @@ -1237,13 +1230,13 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla } stableAddr := prefixState.stableAddr.addressEndpoint.AddressWithPrefix().Address - now := time.Now() + now := ndp.ep.protocol.stack.Clock().NowMonotonic() // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary // address is the lower of the valid lifetime of the stable address or the // maximum temporary address valid lifetime. vl := ndp.configs.MaxTempAddrValidLifetime - if prefixState.validUntil != (time.Time{}) { + if prefixState.validUntil != nil { if prefixVL := prefixState.validUntil.Sub(now); vl > prefixVL { vl = prefixVL } @@ -1259,7 +1252,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla // maximum temporary address preferred lifetime - the temporary address desync // factor. pl := ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor - if prefixState.preferredUntil != (time.Time{}) { + if prefixState.preferredUntil != nil { if prefixPL := prefixState.preferredUntil.Sub(now); pl > prefixPL { // Respect the preferred lifetime of the prefix, as per RFC 4941 section // 3.3 step 4. @@ -1394,16 +1387,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // deprecation job so it can be reset. prefixState.deprecationJob.Cancel() - now := time.Now() + now := ndp.ep.protocol.stack.Clock().NowMonotonic() // Schedule the deprecation job if prefix has a finite preferred lifetime. if pl < header.NDPInfiniteLifetime { if !deprecated { prefixState.deprecationJob.Schedule(pl) } - prefixState.preferredUntil = now.Add(pl) + t := now.Add(pl) + prefixState.preferredUntil = &t } else { - prefixState.preferredUntil = time.Time{} + prefixState.preferredUntil = nil } // As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix: @@ -1421,17 +1415,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // Handle the infinite valid lifetime separately as we do not schedule a // job in this case. 
prefixState.invalidationJob.Cancel() - prefixState.validUntil = time.Time{} + prefixState.validUntil = nil } else { var effectiveVl time.Duration var rl time.Duration // If the prefix was originally set to be valid forever, assume the // remaining time to be the maximum possible value. - if prefixState.validUntil == (time.Time{}) { + if prefixState.validUntil == nil { rl = header.NDPInfiniteLifetime } else { - rl = time.Until(prefixState.validUntil) + rl = prefixState.validUntil.Sub(now) } if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { @@ -1443,7 +1437,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat if effectiveVl != 0 { prefixState.invalidationJob.Cancel() prefixState.invalidationJob.Schedule(effectiveVl) - prefixState.validUntil = now.Add(effectiveVl) + t := now.Add(effectiveVl) + prefixState.validUntil = &t } } @@ -1463,8 +1458,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // maximum temporary address valid lifetime. Note, the valid lifetime of a // temporary address is relative to the address's creation time. validUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrValidLifetime) - if prefixState.validUntil != (time.Time{}) && validUntil.Sub(prefixState.validUntil) > 0 { - validUntil = prefixState.validUntil + if prefixState.validUntil != nil && prefixState.validUntil.Before(validUntil) { + validUntil = *prefixState.validUntil } // If the address is no longer valid, invalidate it immediately. Otherwise, @@ -1483,14 +1478,15 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // desync factor. Note, the preferred lifetime of a temporary address is // relative to the address's creation time. preferredUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor) - if prefixState.preferredUntil != (time.Time{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 { - preferredUntil = prefixState.preferredUntil + if prefixState.preferredUntil != nil && prefixState.preferredUntil.Before(preferredUntil) { + preferredUntil = *prefixState.preferredUntil } // If the address is no longer preferred, deprecate it immediately. // Otherwise, schedule the deprecation job again. newPreferredLifetime := preferredUntil.Sub(now) tempAddrState.deprecationJob.Cancel() + if newPreferredLifetime <= 0 { ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint) } else { @@ -1680,12 +1676,12 @@ func (ndp *ndpState) cleanupState() { panic(fmt.Sprintf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got)) } - for router := range ndp.defaultRouters { - ndp.invalidateDefaultRouter(router) + for route := range ndp.offLinkRoutes { + ndp.invalidateOffLinkRoute(route) } - if got := len(ndp.defaultRouters); got != 0 { - panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got)) + if got := len(ndp.offLinkRoutes); got != 0 { + panic(fmt.Sprintf("ndp: still have discovered off-link routes after cleaning up; found = %d", got)) } ndp.dhcpv6Configuration = 0 @@ -1718,7 +1714,7 @@ func (ndp *ndpState) startSolicitingRouters() { // 4861 section 6.3.7. var delay time.Duration if ndp.configs.MaxRtrSolicitationDelay > 0 { - delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) + delay = time.Duration(ndp.ep.protocol.stack.Rand().Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) } // Protected by ndp.ep.mu. 
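[Illustrative sketch] The router-solicitation hunk above stops reaching for the global math/rand state and instead draws from the stack's injected RNG, which lets tests seed the randomness deterministically. A small standalone sketch of that pattern follows; solicitationDelay and the fixed seed are illustrative only, not gVisor API.

package main

import (
	"fmt"
	"math/rand"
	"time"
)

// solicitationDelay picks a random delay in [0, max), drawing from an
// injected source rather than the global math/rand state so callers (and
// tests) control the seed.
func solicitationDelay(rng *rand.Rand, max time.Duration) time.Duration {
	if max <= 0 {
		return 0
	}
	return time.Duration(rng.Int63n(int64(max)))
}

func main() {
	// Production code would seed from a real entropy or time source; a fixed
	// seed makes the result reproducible in a test.
	rng := rand.New(rand.NewSource(42))
	fmt.Println(solicitationDelay(rng, time.Second))
	fmt.Println(solicitationDelay(rng, 0)) // no delay configured
}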
@@ -1754,7 +1750,7 @@ func (ndp *ndpState) startSolicitingRouters() { header.NDPSourceLinkLayerAddressOption(linkAddress), } } - payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) + payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + optsSerializer.Length() icmpData := header.ICMPv6(buffer.NewView(payloadSize)) icmpData.SetType(header.ICMPv6RouterSolicit) rs := header.NDPRouterSolicit(icmpData.MessageBody()) @@ -1848,21 +1844,19 @@ func (ndp *ndpState) stopSolicitingRouters() { } func (ndp *ndpState) init(ep *endpoint, dadOptions ip.DADOptions) { - if ndp.defaultRouters != nil { + if ndp.offLinkRoutes != nil { panic("attempted to initialize NDP state twice") } ndp.ep = ep ndp.configs = ep.protocol.options.NDPConfigs ndp.dad.Init(&ndp.ep.mu, ep.protocol.options.DADConfigs, dadOptions) - ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState) + ndp.offLinkRoutes = make(map[offLinkRoute]offLinkRouteState) ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState) ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState) header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID()) - if MaxDesyncFactor != 0 { - ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) - } + ndp.temporaryAddressDesyncFactor = time.Duration(ep.protocol.stack.Rand().Int63n(int64(MaxDesyncFactor))) } func (ndp *ndpState) SendDADMessage(addr tcpip.Address, nonce []byte) tcpip.Error { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 234e34952..f0186c64e 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -16,7 +16,7 @@ package ipv6 import ( "bytes" - "context" + "math/rand" "strings" "testing" "time" @@ -32,58 +32,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) -// setupStackAndEndpoint creates a stack with a single NIC with a link-local -// address llladdr and an IPv6 endpoint to a remote with link-local address -// rlladdr -func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { - t.Helper() - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }) - - if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - { - subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}, - ) - } - - netProto := s.NetworkProtocolInstance(ProtocolNumber) - if netProto == nil { - t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) - } - - ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - t.Cleanup(ep.Close) - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") - } - addr := llladdr.WithPrefix() - if addressEP, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - 
addressEP.DecRef() - } - - return s, ep -} - var _ NDPDispatcher = (*testNDPDispatcher)(nil) // testNDPDispatcher is an NDPDispatcher only allows default router discovery. @@ -94,24 +42,21 @@ type testNDPDispatcher struct { func (*testNDPDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) { } -func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool { +func (t *testNDPDispatcher) OnOffLinkRouteUpdated(_ tcpip.NICID, _ tcpip.Subnet, addr tcpip.Address, _ header.NDPRoutePreference) { t.addr = addr - return true } -func (t *testNDPDispatcher) OnDefaultRouterInvalidated(_ tcpip.NICID, addr tcpip.Address) { +func (t *testNDPDispatcher) OnOffLinkRouteInvalidated(_ tcpip.NICID, _ tcpip.Subnet, addr tcpip.Address) { t.addr = addr } -func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool { - return false +func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) { } func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) { } -func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool { - return false +func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) { } func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) { @@ -148,7 +93,7 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) { ipv6EP := ep.(*endpoint) ipv6EP.mu.Lock() - ipv6EP.mu.ndp.rememberDefaultRouter(lladdr1, time.Hour) + ipv6EP.mu.ndp.handleOffLinkRouteDiscovery(offLinkRoute{dest: header.IPv6EmptySubnet, router: lladdr1}, time.Hour, header.MediumRoutePreference) ipv6EP.mu.Unlock() if ndpDisp.addr != lladdr1 { @@ -163,11 +108,6 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) { } } -type linkResolutionResult struct { - linkAddr tcpip.LinkAddress - ok bool -} - // TestNeighborSolicitationWithSourceLinkLayerOption tests that receiving a // valid NDP NS message with the Source Link Layer Address option results in a // new entry in the link address cache for the sender of the message. 
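[Illustrative sketch] In the dispatcher changes above, the discovery callbacks lose their boolean return values: the stack now always records discovered routes, prefixes, and SLAAC addresses and merely notifies the integrator, rather than letting it veto the result. The following standalone sketch contrasts that notification-only shape with the old veto style; Discoverer, Observer, and route are hypothetical names, not the real NDPDispatcher contract.

package main

import "fmt"

// route is a stand-in for a discovered off-link route.
type route struct {
	dest   string
	router string
}

// Observer is a notification-only callback: it cannot veto what the caller
// remembers, mirroring the dropped bool return values above.
type Observer interface {
	OnRouteUpdated(r route)
	OnRouteInvalidated(r route)
}

// Discoverer owns the discovered-route table; observers are only informed.
type Discoverer struct {
	routes map[route]struct{}
	obs    Observer
}

func NewDiscoverer(obs Observer) *Discoverer {
	return &Discoverer{routes: make(map[route]struct{}), obs: obs}
}

// Discover always records the route and then notifies the observer; under a
// veto-style API the observer's return value would decide whether to record it.
func (d *Discoverer) Discover(r route) {
	d.routes[r] = struct{}{}
	if d.obs != nil {
		d.obs.OnRouteUpdated(r)
	}
}

type logObserver struct{}

func (logObserver) OnRouteUpdated(r route)     { fmt.Println("updated:", r) }
func (logObserver) OnRouteInvalidated(r route) { fmt.Println("invalidated:", r) }

func main() {
	d := NewDiscoverer(logObserver{})
	d.Discover(route{dest: "::/0", router: "fe80::1"})
	fmt.Println("known routes:", len(d.routes))
}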
@@ -456,8 +396,10 @@ func TestNeighborSolicitationResponse(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + Clock: clock, }) e := channel.New(1, 1280, nicLinkAddr) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired @@ -527,7 +469,8 @@ func TestNeighborSolicitationResponse(t *testing.T) { } if test.performsLinkResolution { - p, got := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + p, got := e.Read() if !got { t.Fatal("expected an NDP NS response") } @@ -582,7 +525,8 @@ func TestNeighborSolicitationResponse(t *testing.T) { })) } - p, got := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + p, got := e.Read() if !got { t.Fatal("expected an NDP NA response") } @@ -906,12 +850,12 @@ func TestNDPValidation(t *testing.T) { routerOnly := stats.RouterOnlyPacketsDroppedByHost typStat := typ.statCounter(stats) - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - icmp.SetCode(test.code) - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp[:typ.size], + icmpH := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmpH[typ.size:], typ.extraData) + icmpH.SetType(typ.typ) + icmpH.SetCode(test.code) + icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpH[:typ.size], Src: lladdr0, Dst: lladdr1, PayloadCsum: header.Checksum(typ.extraData /* initial */, 0), @@ -937,7 +881,7 @@ func TestNDPValidation(t *testing.T) { t.FailNow() } - handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep) + handleIPv6Payload(buffer.View(icmpH), test.hopLimit, test.atomicFragment, ep) // Rx count of the NDP packet should have increased. if got := typStat.Value(); got != 1 { @@ -1297,8 +1241,9 @@ func TestCheckDuplicateAddress(t *testing.T) { var secureRNG bytes.Reader secureRNG.Reset(secureRNGBytes[:]) s := stack.New(stack.Options{ - SecureRNG: &secureRNG, - Clock: clock, + Clock: clock, + RandSource: rand.NewSource(time.Now().UnixNano()), + SecureRNG: &secureRNG, NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{ DADConfigs: dadConfigs, })}, @@ -1315,7 +1260,8 @@ func TestCheckDuplicateAddress(t *testing.T) { snmc := header.SolicitedNodeAddr(lladdr0) remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc) checkDADMsg := func() { - p, ok := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + p, ok := e.Read() if !ok { t.Fatalf("expected %d-th DAD message", dadPacketsSent) } @@ -1363,7 +1309,7 @@ func TestCheckDuplicateAddress(t *testing.T) { t.Fatalf("RemoveAddress(%d, %s): %s", nicID, lladdr0, err) } // Should not restart DAD since we already requested DAD above - the handler - // should be called when the original request compeletes so we should not send + // should be called when the original request completes so we should not send // an extra DAD message here. 
dadRequestsMade++ if res, err := s.CheckDuplicateAddress(nicID, ProtocolNumber, lladdr0, func(r stack.DADResult) { diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index b5b013b64..854d6a6ba 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -101,7 +101,7 @@ func (dc destToCounter) intersectionFlags(res Reservation) (BitFlags, int) { // Wildcard destinations affect all destinations for TupleOnly. if dest.addr == anyIPAddress || res.Dest.Addr == anyIPAddress { // Only bitwise and the TupleOnlyFlag. - intersection &= ((^TupleOnlyFlag) | counter.SharedFlags()) + intersection &= (^TupleOnlyFlag) | counter.SharedFlags() count++ } } @@ -238,13 +238,13 @@ type PortTester func(port uint16) (good bool, err tcpip.Error) // possible ephemeral ports, allowing the caller to decide whether a given port // is suitable for its needs, and stopping when a port is found or an error // occurs. -func (pm *PortManager) PickEphemeralPort(testPort PortTester) (port uint16, err tcpip.Error) { +func (pm *PortManager) PickEphemeralPort(rng *rand.Rand, testPort PortTester) (port uint16, err tcpip.Error) { pm.ephemeralMu.RLock() firstEphemeral := pm.firstEphemeral numEphemeral := pm.numEphemeral pm.ephemeralMu.RUnlock() - offset := uint32(rand.Int31n(int32(numEphemeral))) + offset := uint32(rng.Int31n(int32(numEphemeral))) return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort) } @@ -303,7 +303,7 @@ func pickEphemeralPort(offset uint32, first, count uint16, testPort PortTester) // An optional PortTester can be passed in which if provided will be used to // test if the picked port can be used. The function should return true if the // port is safe to use, false otherwise. -func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) { +func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) { pm.mu.Lock() defer pm.mu.Unlock() @@ -328,7 +328,7 @@ func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reserv } // A port wasn't specified, so try to find one. 
- return pm.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { + return pm.PickEphemeralPort(rng, func(p uint16) (bool, tcpip.Error) { res.Port = p if !pm.reserveSpecificPortLocked(res) { return false, nil diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 6c4fb8c68..a91b130df 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -18,6 +18,7 @@ import ( "math" "math/rand" "testing" + "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" @@ -331,6 +332,7 @@ func TestPortReservation(t *testing.T) { t.Run(test.tname, func(t *testing.T) { pm := NewPortManager() net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber} + rng := rand.New(rand.NewSource(time.Now().UnixNano())) for _, test := range test.actions { first, _ := pm.PortRange() @@ -356,7 +358,7 @@ func TestPortReservation(t *testing.T) { BindToDevice: test.device, Dest: test.dest, } - gotPort, err := pm.ReservePort(portRes, nil /* testPort */) + gotPort, err := pm.ReservePort(rng, portRes, nil /* testPort */) if diff := cmp.Diff(test.want, err); diff != "" { t.Fatalf("unexpected error from ReservePort(%+v, _), (-want, +got):\n%s", portRes, diff) } @@ -417,10 +419,11 @@ func TestPickEphemeralPort(t *testing.T) { } { t.Run(test.name, func(t *testing.T) { pm := NewPortManager() + rng := rand.New(rand.NewSource(time.Now().UnixNano())) if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil { t.Fatalf("failed to set ephemeral port range: %s", err) } - port, err := pm.PickEphemeralPort(test.f) + port, err := pm.PickEphemeralPort(rng, test.f) if diff := cmp.Diff(test.wantErr, err); diff != "" { t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) } diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index b26936b7f..0ea85f9ed 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -222,7 +222,7 @@ type SocketOptions struct { getReceiveBufferLimits GetReceiveBufferLimits `state:"manual"` // receiveBufferSize determines the receive buffer size for this socket. - receiveBufferSize int64 + receiveBufferSize atomicbitops.AlignedAtomicInt64 // mu protects the access to the below fields. mu sync.Mutex `state:"nosave"` @@ -601,9 +601,10 @@ func (so *SocketOptions) GetBindToDevice() int32 { return atomic.LoadInt32(&so.bindToDevice) } -// SetBindToDevice sets value for SO_BINDTODEVICE option. +// SetBindToDevice sets value for SO_BINDTODEVICE option. If bindToDevice is +// zero, the socket device binding is removed. func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error { - if !so.handler.HasNIC(bindToDevice) { + if bindToDevice != 0 && !so.handler.HasNIC(bindToDevice) { return &ErrUnknownDevice{} } @@ -653,13 +654,13 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) { // GetReceiveBufferSize gets value for SO_RCVBUF option. func (so *SocketOptions) GetReceiveBufferSize() int64 { - return atomic.LoadInt64(&so.receiveBufferSize) + return so.receiveBufferSize.Load() } // SetReceiveBufferSize sets value for SO_RCVBUF option. 
func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bool) { if !notify { - atomic.StoreInt64(&so.receiveBufferSize, receiveBufferSize) + so.receiveBufferSize.Store(receiveBufferSize) return } @@ -684,8 +685,8 @@ func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bo v = math.MaxInt32 } - oldSz := atomic.LoadInt64(&so.receiveBufferSize) + oldSz := so.receiveBufferSize.Load() // Notify endpoint about change in buffer size. newSz := so.handler.OnSetReceiveBufferSize(v, oldSz) - atomic.StoreInt64(&so.receiveBufferSize, newSz) + so.receiveBufferSize.Store(newSz) } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 84aa6a9e4..e0847e58a 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -56,6 +56,7 @@ go_library( "neighbor_entry_list.go", "neighborstate_string.go", "nic.go", + "nic_stats.go", "nud.go", "packet_buffer.go", "packet_buffer_list.go", @@ -94,7 +95,7 @@ go_library( go_test( name = "stack_x_test", - size = "medium", + size = "small", srcs = [ "addressable_endpoint_state_test.go", "ndp_test.go", diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 5720e7543..782e74b24 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -35,7 +35,6 @@ import ( // Currently, only TCP tracking is supported. // Our hash table has 16K buckets. -// TODO(gvisor.dev/issue/170): These should be tunable. const numBuckets = 1 << 14 // Direction of the tuple. @@ -125,6 +124,8 @@ type conn struct { tcb tcpconntrack.TCB // lastUsed is the last time the connection saw a relevant packet, and // is updated by each packet on the connection. It is protected by mu. + // + // TODO(gvisor.dev/issue/5939): do not use the ambient clock. lastUsed time.Time `state:".(unixTime)"` } @@ -163,8 +164,6 @@ func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) { // Update the state of tcb. tcb assumes it's always initialized on the // client. However, we only need to know whether the connection is // established or not, so the client/server distinction isn't important. - // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle - // other tcp states. if cn.tcb.IsEmpty() { cn.tcb.Init(tcpHeader) } else if hook == cn.tcbHook { @@ -244,8 +243,7 @@ func (ct *ConnTrack) init() { // connFor gets the conn for pkt if it exists, or returns nil // if it does not. It returns an error when pkt does not contain a valid TCP // header. -// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support -// other transport protocols. +// TODO(gvisor.dev/issue/6168): Support UDP. func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { tid, err := packetToTupleID(pkt) if err != nil { @@ -383,7 +381,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { return false } - // TODO(gvisor.dev/issue/170): Support other transport protocols. + // TODO(gvisor.dev/issue/6168): Support UDP. if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { return false } @@ -407,16 +405,23 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { // validated if checksum offloading is off. It may require IP defrag if the // packets are fragmented. 
+ var newAddr tcpip.Address + var newPort uint16 + + updateSRCFields := false + switch hook { case Prerouting, Output: if conn.manip == manipDestination { switch dir { case dirOriginal: - tcpHeader.SetDestinationPort(conn.reply.srcPort) - netHeader.SetDestinationAddress(conn.reply.srcAddr) + newPort = conn.reply.srcPort + newAddr = conn.reply.srcAddr case dirReply: - tcpHeader.SetSourcePort(conn.original.dstPort) - netHeader.SetSourceAddress(conn.original.dstAddr) + newPort = conn.original.dstPort + newAddr = conn.original.dstAddr + + updateSRCFields = true } pkt.NatDone = true } @@ -424,11 +429,13 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { if conn.manip == manipSource { switch dir { case dirOriginal: - tcpHeader.SetSourcePort(conn.reply.dstPort) - netHeader.SetSourceAddress(conn.reply.dstAddr) + newPort = conn.reply.dstPort + newAddr = conn.reply.dstAddr + + updateSRCFields = true case dirReply: - tcpHeader.SetDestinationPort(conn.original.srcPort) - netHeader.SetDestinationAddress(conn.original.srcAddr) + newPort = conn.original.srcPort + newAddr = conn.original.srcAddr } pkt.NatDone = true } @@ -439,33 +446,33 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { return false } + fullChecksum := false + updatePseudoHeader := false switch hook { case Prerouting, Input: case Output, Postrouting: // Calculate the TCP checksum and set it. - tcpHeader.SetChecksum(0) - length := uint16(len(tcpHeader) + pkt.Data().Size()) - xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { - tcpHeader.SetChecksum(xsum) + updatePseudoHeader = true } else if r.RequiresTXTransportChecksum() { - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) + fullChecksum = true + updatePseudoHeader = true } default: panic(fmt.Sprintf("unrecognized hook = %s", hook)) } - // After modification, IPv4 packets need a valid checksum. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } + rewritePacket( + netHeader, + tcpHeader, + updateSRCFields, + fullChecksum, + updatePseudoHeader, + newPort, + newAddr, + ) // Update the state of tcb. - // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle - // other tcp states. conn.mu.Lock() defer conn.mu.Unlock() @@ -542,8 +549,6 @@ func (ct *ConnTrack) bucket(id tupleID) int { // reapUnused returns the next bucket that should be checked and the time after // which it should be called again. func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) { - // TODO(gvisor.dev/issue/170): This can be more finely controlled, as - // it is in Linux via sysctl. 
const fractionPerReaping = 128 const maxExpiredPct = 50 const maxFullTraversal = 60 * time.Second diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 7107d598d..72f66441f 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -114,10 +115,6 @@ func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen } -func (*fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -134,7 +131,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, params NetworkHeaderParam } // WritePackets implements LinkEndpoint.WritePackets. -func (*fwdTestNetworkEndpoint) WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { +func (*fwdTestNetworkEndpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } @@ -224,7 +221,7 @@ func (*fwdTestNetworkProtocol) Wait() {} func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { if fn := f.proto.onLinkAddressResolved; fn != nil { - time.AfterFunc(f.proto.addrResolveDelay, func() { + f.proto.stack.clock.AfterFunc(f.proto.addrResolveDelay, func() { fn(f.proto.neigh, addr, remoteLinkAddr) }) } @@ -319,7 +316,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -354,17 +351,19 @@ func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { } // AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { +func (e *fwdTestLinkEndpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *PacketBuffer) { panic("not implemented") } -func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { +func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.ManualClock, *fwdTestLinkEndpoint, *fwdTestLinkEndpoint) { + clock := faketime.NewManualClock() // Create a stack with the network protocol and two NICs. s := New(Options{ NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol { proto.stack = s return proto }}, + Clock: clock, }) protoNum := proto.Number() @@ -373,7 +372,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f } // NIC 1 has the link address "a", and added the network address 1. 
- ep1 = &fwdTestLinkEndpoint{ + ep1 := &fwdTestLinkEndpoint{ C: make(chan fwdTestPacketInfo, 300), mtu: fwdTestNetDefaultMTU, linkAddr: "a", @@ -386,7 +385,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f } // NIC 2 has the link address "b", and added the network address 2. - ep2 = &fwdTestLinkEndpoint{ + ep2 := &fwdTestLinkEndpoint{ C: make(chan fwdTestPacketInfo, 300), mtu: fwdTestNetDefaultMTU, linkAddr: "b", @@ -416,7 +415,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}}) } - return ep1, ep2 + return clock, ep1, ep2 } func TestForwardingWithStaticResolver(t *testing.T) { @@ -432,7 +431,7 @@ func TestForwardingWithStaticResolver(t *testing.T) { }, } - ep1, ep2 := fwdTestNetFactory(t, proto) + clock, ep1, ep2 := fwdTestNetFactory(t, proto) // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -444,6 +443,7 @@ func TestForwardingWithStaticResolver(t *testing.T) { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: default: @@ -475,7 +475,7 @@ func TestForwardingWithFakeResolver(t *testing.T) { }) }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -487,9 +487,10 @@ func TestForwardingWithFakeResolver(t *testing.T) { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } @@ -508,7 +509,7 @@ func TestForwardingWithNoResolver(t *testing.T) { // Whether or not we use the neighbor cache here does not matter since // neither linkAddrCache nor neighborCache will be used. - ep1, ep2 := fwdTestNetFactory(t, proto) + clock, ep1, ep2 := fwdTestNetFactory(t, proto) // inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -518,10 +519,11 @@ func TestForwardingWithNoResolver(t *testing.T) { Data: buf.ToVectorisedView(), })) + clock.Advance(proto.addrResolveDelay) select { case <-ep2.C: t.Fatal("Packet should not be forwarded") - case <-time.After(time.Second): + default: } } @@ -533,7 +535,7 @@ func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { }, } - ep1, ep2 := fwdTestNetFactory(t, proto) + clock, ep1, ep2 := fwdTestNetFactory(t, proto) const numPackets int = 5 // These packets will all be enqueued in the packet queue to wait for link @@ -547,12 +549,12 @@ func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { } // All packets should fail resolution. - // TODO(gvisor.dev/issue/5141): Use a fake clock. for i := 0; i < numPackets; i++ { + clock.Advance(proto.addrResolveDelay) select { case got := <-ep2.C: t.Fatalf("got %#v; packets should have failed resolution and not been forwarded", got) - case <-time.After(100 * time.Millisecond): + default: } } } @@ -576,7 +578,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { } }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) // Inject an inbound packet to address 4 on NIC 1. This packet should // not be forwarded. 
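The forwarding-test hunks above and below all apply the same conversion: instead of sleeping on time.After, the test installs a faketime.ManualClock through stack.Options.Clock (which is why fwdTestNetFactory now returns the clock alongside the two endpoints), advances it past the link-address-resolution delay, and then does a non-blocking receive. A condensed sketch of that pattern, assuming it sits in the same forwarding_test.go file so it can see fwdTestPacketInfo, might be:

// expectForwardedPacket drives the fake clock and asserts that a packet has
// been forwarded once the resolution delay has elapsed.
func expectForwardedPacket(t *testing.T, clock *faketime.ManualClock, delay time.Duration, pktC <-chan fwdTestPacketInfo) fwdTestPacketInfo {
	t.Helper()
	// Nothing scheduled on the stack's clock runs until it is advanced, so
	// the resolver's timers fire deterministically instead of in real time.
	clock.Advance(delay)
	select {
	case p := <-pktC:
		return p
	default:
		t.Fatal("packet not forwarded")
		panic("unreachable") // t.Fatal stops the test; this satisfies the compiler.
	}
}

Because time only moves when the test says so, the select can use default instead of a timeout, removing the flakiness the old one-second waits were papering over.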
@@ -596,9 +598,10 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } @@ -631,7 +634,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { }) }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) // Inject two inbound packets to address 3 on NIC 1. for i := 0; i < 2; i++ { @@ -645,9 +648,10 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { for i := 0; i < 2; i++ { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } @@ -681,7 +685,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { }) }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) for i := 0; i < maxPendingPacketsPerResolution+5; i++ { // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. @@ -697,9 +701,10 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { for i := 0; i < maxPendingPacketsPerResolution; i++ { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } @@ -745,7 +750,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { }) }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) for i := 0; i < maxPendingResolutions+5; i++ { // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. @@ -761,9 +766,10 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { for i := 0; i < maxPendingResolutions; i++ { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 3670d5995..f152c0d83 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -42,7 +42,7 @@ const reaperDelay = 5 * time.Second // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. -func DefaultTables() *IPTables { +func DefaultTables(seed uint32) *IPTables { return &IPTables{ v4Tables: [NumTables]Table{ NATID: { @@ -182,7 +182,7 @@ func DefaultTables() *IPTables { Postrouting: {MangleID, NATID}, }, connections: ConnTrack{ - seed: generateRandUint32(), + seed: seed, }, reaperDone: make(chan struct{}, 1), } @@ -268,10 +268,6 @@ const ( // should continue traversing the network stack and false when it should be // dropped. // -// TODO(gvisor.dev/issue/170): PacketBuffer should hold the route, from -// which address can be gathered. Currently, address is only needed for -// prerouting. -// // Precondition: pkt.NetworkHeader is set. func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { @@ -371,6 +367,7 @@ func (it *IPTables) startReaper(interval time.Duration) { select { case <-it.reaperDone: return + // TODO(gvisor.dev/issue/5939): do not use the ambient clock. 
case <-time.After(interval): bucket, interval = it.connections.reapUnused(bucket, interval) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 2812c89aa..96cc899bb 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -87,9 +87,6 @@ func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Addre // destination port/IP. Outgoing packets are redirected to the loopback device, // and incoming packets are redirected to the incoming interface (rather than // forwarded). -// -// TODO(gvisor.dev/issue/170): Other flags need to be added after we support -// them. type RedirectTarget struct { // Port indicates port used to redirect. It is immutable. Port uint16 @@ -100,9 +97,6 @@ type RedirectTarget struct { } // Action implements Target.Action. -// TODO(gvisor.dev/issue/170): Parse headers without copying. The current -// implementation only works for Prerouting and calls pkt.Clone(), neither -// of which should be the case. func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { // Sanity check. if rt.NetworkProtocol != pkt.NetworkProtocolNumber { @@ -136,34 +130,26 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r panic("redirect target is supported only on output and prerouting hooks") } - // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if - // we need to change dest address (for OUTPUT chain) or ports. switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: udpHeader := header.UDP(pkt.TransportHeader().View()) - udpHeader.SetDestinationPort(rt.Port) - // Calculate UDP checksum and set it. if hook == Output { - udpHeader.SetChecksum(0) - netHeader := pkt.Network() - netHeader.SetDestinationAddress(address) - // Only calculate the checksum if offloading isn't supported. - if r.RequiresTXTransportChecksum() { - length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) - } + requiresChecksum := r.RequiresTXTransportChecksum() + rewritePacket( + pkt.Network(), + udpHeader, + false, /* updateSRCFields */ + requiresChecksum, + requiresChecksum, + rt.Port, + address, + ) + } else { + udpHeader.SetDestinationPort(rt.Port) } - // After modification, IPv4 packets need a valid checksum. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } pkt.NatDone = true case header.TCPProtocolNumber: if ct == nil { @@ -222,26 +208,18 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: - udpHeader := header.UDP(pkt.TransportHeader().View()) - udpHeader.SetChecksum(0) - udpHeader.SetSourcePort(st.Port) - netHeader := pkt.Network() - netHeader.SetSourceAddress(st.Addr) - // Only calculate the checksum if offloading isn't supported. 
- if r.RequiresTXTransportChecksum() { - length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) - } + requiresChecksum := r.RequiresTXTransportChecksum() + rewritePacket( + pkt.Network(), + header.UDP(pkt.TransportHeader().View()), + true, /* updateSRCFields */ + requiresChecksum, + requiresChecksum, + st.Port, + st.Addr, + ) - // After modification, IPv4 packets need a valid checksum. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } pkt.NatDone = true case header.TCPProtocolNumber: if ct == nil { @@ -260,3 +238,42 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou return RuleAccept, 0 } + +func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) { + if updateSRCFields { + if fullChecksum { + t.SetSourcePortWithChecksumUpdate(newPort) + } else { + t.SetSourcePort(newPort) + } + } else { + if fullChecksum { + t.SetDestinationPortWithChecksumUpdate(newPort) + } else { + t.SetDestinationPort(newPort) + } + } + + if updatePseudoHeader { + var oldAddr tcpip.Address + if updateSRCFields { + oldAddr = n.SourceAddress() + } else { + oldAddr = n.DestinationAddress() + } + + t.UpdateChecksumPseudoHeaderAddress(oldAddr, newAddr, fullChecksum) + } + + if checksummableNetHeader, ok := n.(header.ChecksummableNetwork); ok { + if updateSRCFields { + checksummableNetHeader.SetSourceAddressWithChecksumUpdate(newAddr) + } else { + checksummableNetHeader.SetDestinationAddressWithChecksumUpdate(newAddr) + } + } else if updateSRCFields { + n.SetSourceAddress(newAddr) + } else { + n.SetDestinationAddress(newAddr) + } +} diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 93592e7f5..66e5f22ac 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -242,7 +242,6 @@ type IPHeaderFilter struct { func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool { // Extract header fields. var ( - // TODO(gvisor.dev/issue/170): Support other filter fields. transProto tcpip.TransportProtocolNumber dstAddr tcpip.Address srcAddr tcpip.Address @@ -291,7 +290,6 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicNa return true case Postrouting: - // TODO(gvisor.dev/issue/170): Add the check for POSTROUTING. return true default: panic(fmt.Sprintf("unknown hook: %d", hook)) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index ac2fa777e..9623d9c28 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -16,14 +16,14 @@ package stack_test import ( "bytes" - "context" "encoding/binary" "fmt" + "math/rand" "testing" "time" "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/rand" + cryptorand "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -52,17 +52,6 @@ const ( linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") defaultPrefixLen = 128 - - // Extra time to use when waiting for an async event to occur. 
- defaultAsyncPositiveEventTimeout = 10 * time.Second - - // Extra time to use when waiting for an async event to not occur. - // - // Since a negative check is used to make sure an event did not happen, it is - // okay to use a smaller timeout compared to the positive case since execution - // stall in regards to the monotonic clock will not affect the expected - // outcome. - defaultAsyncNegativeEventTimeout = time.Second ) var ( @@ -112,11 +101,13 @@ type ndpDADEvent struct { res stack.DADResult } -type ndpRouterEvent struct { - nicID tcpip.NICID - addr tcpip.Address - // true if router was discovered, false if invalidated. - discovered bool +type ndpOffLinkRouteEvent struct { + nicID tcpip.NICID + subnet tcpip.Subnet + router tcpip.Address + prf header.NDPRoutePreference + // true if route was updated, false if invalidated. + updated bool } type ndpPrefixEvent struct { @@ -140,6 +131,10 @@ type ndpAutoGenAddrEvent struct { eventType ndpAutoGenAddrEventType } +func (e ndpAutoGenAddrEvent) String() string { + return fmt.Sprintf("%T{nicID=%d addr=%s eventType=%d}", e, e.nicID, e.addr, e.eventType) +} + type ndpRDNSS struct { addrs []tcpip.Address lifetime time.Duration @@ -167,10 +162,8 @@ var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) // related events happen for test purposes. type ndpDispatcher struct { dadC chan ndpDADEvent - routerC chan ndpRouterEvent - rememberRouter bool + offLinkRouteC chan ndpOffLinkRouteEvent prefixC chan ndpPrefixEvent - rememberPrefix bool autoGenAddrC chan ndpAutoGenAddrEvent rdnssC chan ndpRDNSSEvent dnsslC chan ndpDNSSLEvent @@ -189,32 +182,35 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionResult(nicID tcpip.NICID, add } } -// Implements ipv6.NDPDispatcher.OnDefaultRouterDiscovered. -func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool { - if c := n.routerC; c != nil { - c <- ndpRouterEvent{ +// Implements ipv6.NDPDispatcher.OnOffLinkRouteUpdated. +func (n *ndpDispatcher) OnOffLinkRouteUpdated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address, prf header.NDPRoutePreference) { + if c := n.offLinkRouteC; c != nil { + c <- ndpOffLinkRouteEvent{ nicID, - addr, + subnet, + router, + prf, true, } } - - return n.rememberRouter } -// Implements ipv6.NDPDispatcher.OnDefaultRouterInvalidated. -func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) { - if c := n.routerC; c != nil { - c <- ndpRouterEvent{ +// Implements ipv6.NDPDispatcher.OnOffLinkRouteInvalidated. +func (n *ndpDispatcher) OnOffLinkRouteInvalidated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address) { + if c := n.offLinkRouteC; c != nil { + var prf header.NDPRoutePreference + c <- ndpOffLinkRouteEvent{ nicID, - addr, + subnet, + router, + prf, false, } } } // Implements ipv6.NDPDispatcher.OnOnLinkPrefixDiscovered. -func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool { +func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) { if c := n.prefixC; c != nil { c <- ndpPrefixEvent{ nicID, @@ -222,8 +218,6 @@ func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip true, } } - - return n.rememberPrefix } // Implements ipv6.NDPDispatcher.OnOnLinkPrefixInvalidated. 
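Stepping back to the NAT hunks earlier in this diff (pkg/tcpip/stack/conntrack.go and pkg/tcpip/stack/iptables_targets.go): the duplicated port/address/checksum fix-ups are now funneled through the new rewritePacket helper, whose three booleans are easy to misread. The following is an annotated restatement of the RedirectTarget UDP case, written as a hypothetical helper inside package stack purely to spell out what each argument controls; it is not part of the change itself.

// redirectUDPOnOutput shows the argument mapping for a DNAT-style rewrite on
// the Output hook, mirroring RedirectTarget.Action above.
func redirectUDPOnOutput(pkt *PacketBuffer, r *Route, newPort uint16, newAddr tcpip.Address) {
	// Recompute checksums in software only when the route cannot offload the
	// transport checksum, exactly as the target code does.
	requiresChecksum := r.RequiresTXTransportChecksum()
	rewritePacket(
		pkt.Network(),                            // IPv4 or IPv6 header accessor.
		header.UDP(pkt.TransportHeader().View()), // Transport header being rewritten.
		false,            /* updateSRCFields: false rewrites the destination port/address */
		requiresChecksum, /* fullChecksum: recompute the transport checksum over the payload */
		requiresChecksum, /* updatePseudoHeader: fold the address change into the pseudo-header */
		newPort,
		newAddr,
	)
	pkt.NatDone = true
}

Conntrack's handlePacket feeds the same helper, choosing updateSRCFields per connection direction and manipulation type, so the header surgery and checksum handling now live in one place.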
@@ -237,7 +231,7 @@ func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpi } } -func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool { +func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { if c := n.autoGenAddrC; c != nil { c <- ndpAutoGenAddrEvent{ nicID, @@ -245,7 +239,6 @@ func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWi newAddr, } } - return true } func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { @@ -497,8 +490,9 @@ func TestDADResolve(t *testing.T) { clock := faketime.NewManualClock() s := stack.New(stack.Options{ - Clock: clock, - SecureRNG: &secureRNG, + Clock: clock, + RandSource: rand.NewSource(time.Now().UnixNano()), + SecureRNG: &secureRNG, NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, DADConfigs: stack.DADConfigurations{ @@ -605,7 +599,10 @@ func TestDADResolve(t *testing.T) { // Validate the sent Neighbor Solicitation messages. for i := uint8(0); i < test.dupAddrDetectTransmits; i++ { - p, _ := e.ReadContext(context.Background()) + p, ok := e.Read() + if !ok { + t.Fatal("packet didn't arrive") + } // Make sure its an IPv6 packet. if p.Proto != header.IPv6ProtocolNumber { @@ -731,11 +728,13 @@ func TestDADFail(t *testing.T) { dadConfigs.RetransmitTimer = time.Second * 2 e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, DADConfigs: dadConfigs, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -761,16 +760,17 @@ func TestDADFail(t *testing.T) { // Wait for DAD to fail and make sure the address did // not get resolved. + clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) select { - case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): - // If we don't get a failure event after the - // expected resolution time + extra 1s buffer, - // something is wrong. - t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr1, &stack.DADDupAddrDetected{HolderLinkAddress: test.expectedHolderLinkAddress}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + // If we don't get a failure event after the + // expected resolution time + extra 1s buffer, + // something is wrong. + t.Fatal("timed out waiting for DAD failure") } if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { t.Fatal(err) @@ -839,11 +839,13 @@ func TestDADStop(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, DADConfigs: dadConfigs, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) @@ -861,15 +863,16 @@ func TestDADStop(t *testing.T) { test.stopFn(t, s) // Wait for DAD to fail (since the address was removed during DAD). 
+ clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) select { - case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): - // If we don't get a failure event after the expected resolution - // time + extra 1s buffer, something is wrong. - t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr1, &stack.DADAborted{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + // If we don't get a failure event after the expected resolution + // time + extra 1s buffer, something is wrong. + t.Fatal("timed out waiting for DAD failure") } if !test.skipFinalAddrCheck { @@ -920,10 +923,12 @@ func TestSetNDPConfigurations(t *testing.T) { dadC: make(chan ndpDADEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, })}, + Clock: clock, }) expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) { @@ -1002,28 +1007,23 @@ func TestSetNDPConfigurations(t *testing.T) { t.Fatal(err) } - // Sleep until right (500ms before) before resolution to - // make sure the address didn't resolve on NIC(1) yet. - const delta = 500 * time.Millisecond - time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) + // Sleep until right before resolution to make sure the address didn't + // resolve on NIC(1) yet. + const delta = 1 + clock.Advance(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { t.Fatal(err) } // Wait for DAD to resolve. + clock.Advance(delta) select { - case <-time.After(2 * delta): - // We should get a resolution event after 500ms - // (delta) since we wait for 500ms less than the - // expected resolution time above to make sure - // that the address did not yet resolve. Waiting - // for 1s (2x delta) without a resolution event - // means something is wrong. - t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID1, addr1, &stack.DADSucceeded{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("timed out waiting for DAD resolution") } if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { t.Fatal(err) @@ -1032,10 +1032,13 @@ func TestSetNDPConfigurations(t *testing.T) { } } -// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options -// and DHCPv6 configurations specified. -func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { - icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) +// raBuf returns a valid NDP Router Advertisement with options, router +// preference and DHCPv6 configurations specified. 
+func raBuf(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, prf header.NDPRoutePreference, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { + const flagsByte = 1 + const routerLifetimeOffset = 2 + + icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + optSer.Length() hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) pkt := header.ICMPv6(hdr.Prepend(icmpSize)) pkt.SetType(header.ICMPv6RouterAdvert) @@ -1043,19 +1046,19 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo raPayload := pkt.MessageBody() ra := header.NDPRouterAdvert(raPayload) // Populate the Router Lifetime. - binary.BigEndian.PutUint16(raPayload[2:], rl) + binary.BigEndian.PutUint16(raPayload[routerLifetimeOffset:], rl) // Populate the Managed Address flag field. if managedAddress { - // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing) - // of the RA payload. - raPayload[1] |= (1 << 7) + // The Managed Addresses flag field is the 7th bit of the flags byte. + raPayload[flagsByte] |= 1 << 7 } // Populate the Other Configurations flag field. if otherConfigurations { - // The Other Configurations flag field is the 6th bit of byte #1 - // (0-indexing) of the RA payload. - raPayload[1] |= (1 << 6) + // The Other Configurations flag field is the 6th bit of the flags byte. + raPayload[flagsByte] |= 1 << 6 } + // The Prf field is held in the flags byte. + raPayload[flagsByte] |= byte(prf) << 3 opts := ra.Options() opts.Serialize(optSer) pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ @@ -1083,7 +1086,7 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo // Note, raBufWithOpts does not populate any of the RA fields other than the // Router Lifetime. func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { - return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer) + return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, 0 /* prf */, optSer) } // raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related @@ -1091,18 +1094,26 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ // // Note, raBufWithDHCPv6 does not populate any of the RA fields other than the // DHCPv6 related ones. -func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) *stack.PacketBuffer { - return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{}) +func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfigurations bool) *stack.PacketBuffer { + return raBuf(ip, 0, managedAddresses, otherConfigurations, 0 /* prf */, header.NDPOptionsSerializer{}) } // raBuf returns a valid NDP Router Advertisement. // // Note, raBuf does not populate any of the RA fields other than the // Router Lifetime. -func raBuf(ip tcpip.Address, rl uint16) *stack.PacketBuffer { +func raBufSimple(ip tcpip.Address, rl uint16) *stack.PacketBuffer { return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) } +// raBufWithPrf returns a valid NDP Router Advertisement with a preference. +// +// Note, raBufWithPrf does not populate any of the RA fields other than the +// Router Lifetime and Default Router Preference fields. 
+func raBufWithPrf(ip tcpip.Address, rl uint16, prf header.NDPRoutePreference) *stack.PacketBuffer { + return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, prf, header.NDPOptionsSerializer{}) +} + // raBufWithPI returns a valid NDP Router Advertisement with a single Prefix // Information option. // @@ -1162,7 +1173,7 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { config: func(enable bool) ipv6.NDPConfigurations { return ipv6.NDPConfigurations{DiscoverDefaultRouters: enable} }, - ra: raBuf(llAddr2, 1000), + ra: raBufSimple(llAddr2, 1000), }, { name: "No Prefix Discovery", @@ -1198,9 +1209,9 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) { ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - prefixC: make(chan ndpPrefixEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), + prefixC: make(chan ndpPrefixEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } ndpConfigs := test.config(enable) ndpConfigs.HandleRAs = handle @@ -1270,8 +1281,8 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want) } select { - case e := <-ndpDisp.routerC: - t.Errorf("unexpectedly discovered a router when configured not to: %#v", e) + case e := <-ndpDisp.offLinkRouteC: + t.Errorf("unexpectedly updated an off-link route when configured not to: %#v", e) default: } select { @@ -1297,10 +1308,8 @@ func boolToUint64(v bool) uint64 { return 0 } -// Check e to make sure that the event is for addr on nic with ID 1, and the -// discovered flag set to discovered. -func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string { - return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e)) +func checkOffLinkRouteEvent(e ndpOffLinkRouteEvent, nicID tcpip.NICID, router tcpip.Address, prf header.NDPRoutePreference, updated bool) string { + return cmp.Diff(ndpOffLinkRouteEvent{nicID: nicID, subnet: header.IPv6EmptySubnet, router: router, prf: prf, updated: updated}, e, cmp.AllowUnexported(e)) } func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) { @@ -1333,56 +1342,15 @@ func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, b } } -// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not -// remember a discovered router when the dispatcher asks it not to. -func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA for a router we should not remember. 
- const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds)) - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr2, true); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected router discovery event") - } - - // Wait for the invalidation time plus some buffer to make sure we do - // not actually receive any invalidation events as we should not have - // remembered the router in the first place. - select { - case <-ndpDisp.routerC: - t.Fatal("should not have received any router events") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): - } -} - func TestRouterDiscovery(t *testing.T) { + const nicID = 1 + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1391,30 +1359,33 @@ func TestRouterDiscovery(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) - expectRouterEvent := func(addr tcpip.Address, discovered bool) { + expectOffLinkRouteEvent := func(addr tcpip.Address, prf header.NDPRoutePreference, updated bool) { t.Helper() select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, discovered); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + if diff := checkOffLinkRouteEvent(e, nicID, addr, prf, updated); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected router discovery event") } } - expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { + expectAsyncOffLinkRouteInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { t.Helper() + clock.Advance(timeout) select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, false); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + var prf header.NDPRoutePreference + if diff := checkOffLinkRouteEvent(e, nicID, addr, prf, false); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): + default: t.Fatal("timed out waiting for router discovery event") } } @@ -1423,37 +1394,44 @@ func TestRouterDiscovery(t *testing.T) { t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } // Rx an RA from lladdr2 with zero lifetime. It should not be // remembered. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 0)) select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router with 0 lifetime") + case <-ndpDisp.offLinkRouteC: + t.Fatal("unexpectedly updated an off-link route with 0 lifetime") default: } - // Rx an RA from lladdr2 with a huge lifetime. 
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + // Rx an RA from lladdr2 with a huge lifetime and reserved preference value + // (which should be interpreted as the default (medium) preference value). + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr2, 1000, header.ReservedRoutePreference)) + expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true) - // Rx an RA from another router (lladdr3) with non-zero lifetime. + // Rx an RA from another router (lladdr3) with non-zero lifetime and + // non-default preference value. const l3LifetimeSeconds = 6 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) - expectRouterEvent(llAddr3, true) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr3, l3LifetimeSeconds, header.HighRoutePreference)) + expectOffLinkRouteEvent(llAddr3, header.HighRoutePreference, true) - // Rx an RA from lladdr2 with lesser lifetime. + // Rx an RA from lladdr2 with lesser lifetime and default (medium) + // preference value. const l2LifetimeSeconds = 2 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, l2LifetimeSeconds)) select { - case <-ndpDisp.routerC: - t.Fatal("Should not receive a router event when updating lifetimes for known routers") + case <-ndpDisp.offLinkRouteC: + t.Fatal("should not receive a off-link route event when updating lifetimes for known routers") default: } + // Rx an RA from lladdr2 with a different preference. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr2, l2LifetimeSeconds, header.LowRoutePreference)) + expectOffLinkRouteEvent(llAddr2, header.LowRoutePreference, true) + // Wait for lladdr2's router invalidation job to execute. The lifetime // of the router should have been updated to the most recent (smaller) // lifetime. @@ -1461,15 +1439,15 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + expectAsyncOffLinkRouteInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second) // Rx an RA from lladdr2 with huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 1000)) + expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true) // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - expectRouterEvent(llAddr2, false) + e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 0)) + expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, false) // Wait for lladdr3's router invalidation job to execute. The lifetime // of the router should have been updated to the most recent (smaller) @@ -1478,16 +1456,17 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. 
- expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + expectAsyncOffLinkRouteInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second) }) } // TestRouterDiscoveryMaxRouters tests that only -// ipv6.MaxDiscoveredDefaultRouters discovered routers are remembered. +// ipv6.MaxDiscoveredOffLinkRoutes discovered routers are remembered. func TestRouterDiscoveryMaxRouters(t *testing.T) { + const nicID = 1 + ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ @@ -1500,23 +1479,23 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { })}, }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } // Receive an RA from 2 more than the max number of discovered routers. - for i := 1; i <= ipv6.MaxDiscoveredDefaultRouters+2; i++ { + for i := 1; i <= ipv6.MaxDiscoveredOffLinkRoutes+2; i++ { linkAddr := []byte{2, 2, 3, 4, 5, 0} linkAddr[5] = byte(i) llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr)) - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr, 5)) - if i <= ipv6.MaxDiscoveredDefaultRouters { + if i <= ipv6.MaxDiscoveredOffLinkRoutes { select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr, true); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + if diff := checkOffLinkRouteEvent(e, nicID, llAddr, header.MediumRoutePreference, true); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected router discovery event") @@ -1524,7 +1503,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } else { select { - case <-ndpDisp.routerC: + case <-ndpDisp.offLinkRouteC: t.Fatal("should not have discovered a new router after we already discovered the max number of routers") default: } @@ -1538,51 +1517,6 @@ func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) st return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e)) } -// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not -// remember a discovered on-link prefix when the dispatcher asks it not to. -func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { - prefix, subnet, _ := prefixSubnetAddr(0, "") - - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA with prefix that we should not remember. 
- const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0)) - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet, true); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected prefix discovery event") - } - - // Wait for the invalidation time plus some buffer to make sure we do - // not actually receive any invalidation events as we should not have - // remembered the prefix in the first place. - select { - case <-ndpDisp.prefixC: - t.Fatal("should not have received any prefix events") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): - } -} - func TestPrefixDiscovery(t *testing.T) { prefix1, subnet1, _ := prefixSubnetAddr(0, "") prefix2, subnet2, _ := prefixSubnetAddr(1, "") @@ -1590,10 +1524,10 @@ func TestPrefixDiscovery(t *testing.T) { testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, + prefixC: make(chan ndpPrefixEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1602,6 +1536,7 @@ func TestPrefixDiscovery(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -1662,12 +1597,13 @@ func TestPrefixDiscovery(t *testing.T) { // Wait for prefix2's most recent invalidation job plus some buffer to // expire. + clock.Advance(time.Duration(lifetime) * time.Second) select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, subnet2, false); diff != "" { t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for prefix discovery event") } @@ -1678,17 +1614,6 @@ func TestPrefixDiscovery(t *testing.T) { } func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { - // Update the infinite lifetime value to a smaller value so we can test - // that when we receive a PI with such a lifetime value, we do not - // invalidate the prefix. 
- const testInfiniteLifetimeSeconds = 2 - const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second - saved := header.NDPInfiniteLifetime - header.NDPInfiniteLifetime = testInfiniteLifetime - defer func() { - header.NDPInfiniteLifetime = saved - }() - prefix := tcpip.AddressWithPrefix{ Address: testutil.MustParse6("102:304:506:708::"), PrefixLen: 64, @@ -1696,10 +1621,10 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { subnet := prefix.Subnet() ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, + prefixC: make(chan ndpPrefixEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1708,6 +1633,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -1729,46 +1655,39 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with prefix in an NDP Prefix Information option (PI) // with infinite valid lifetime which should not get invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds, 0)) expectPrefixEvent(subnet, true) + clock.Advance(header.NDPInfiniteLifetime) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout): + default: } // Receive an RA with finite lifetime. - // The prefix should get invalidated after 1s. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds-1, 0)) + clock.Advance(header.NDPInfiniteLifetime - time.Second) select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, subnet, false); diff != "" { t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - case <-time.After(testInfiniteLifetime): + default: t.Fatal("timed out waiting for prefix discovery event") } // Receive an RA with finite lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds-1, 0)) expectPrefixEvent(subnet, true) // Receive an RA with prefix with an infinite lifetime. // The prefix should not be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout): - } - - // Receive an RA with a prefix with a lifetime value greater than the - // set infinite lifetime value. 
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds, 0)) + clock.Advance(header.NDPInfiniteLifetime) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultAsyncNegativeEventTimeout): + default: } // Receive an RA with 0 lifetime. @@ -1781,8 +1700,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // ipv6.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3), - rememberPrefix: true, + prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3), } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ @@ -1859,17 +1777,12 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e)) } +const minVLSeconds = uint32(ipv6.MinPrefixInformationValidLifetimeForUpdate / time.Second) +const infiniteLifetimeSeconds = uint32(header.NDPInfiniteLifetime / time.Second) + // TestAutoGenAddr tests that an address is properly generated and invalidated // when configured to do so. func TestAutoGenAddr(t *testing.T) { - const newMinVL = 2 - newMinVLDuration := newMinVL * time.Second - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) @@ -1878,6 +1791,7 @@ func TestAutoGenAddr(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1886,6 +1800,7 @@ func TestAutoGenAddr(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { @@ -1935,8 +1850,9 @@ func TestAutoGenAddr(t *testing.T) { default: } - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + // Receive an RA with prefix2 in a PI with a valid lifetime that exceeds + // the minimum. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, minVLSeconds+1, 0)) expectAutoGenAddrEvent(addr2, newAddr) if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { t.Fatalf("Should have %s in the list of addresses", addr1) @@ -1946,7 +1862,7 @@ func TestAutoGenAddr(t *testing.T) { } // Refresh valid lifetime for addr of prefix1. 
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") @@ -1954,12 +1870,13 @@ func TestAutoGenAddr(t *testing.T) { } // Wait for addr of prefix1 to be invalidated. + clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { @@ -1989,20 +1906,7 @@ func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []t // TestAutoGenTempAddr tests that temporary SLAAC addresses are generated when // configured to do so as part of IPv6 Privacy Extensions. func TestAutoGenTempAddr(t *testing.T) { - const ( - nicID = 1 - newMinVL = 5 - newMinVLDuration = newMinVL * time.Second - ) - - savedMinPrefixInformationValidLifetimeForUpdate := ipv6.MinPrefixInformationValidLifetimeForUpdate - savedMaxDesync := ipv6.MaxDesyncFactor - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate - ipv6.MaxDesyncFactor = savedMaxDesync - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - ipv6.MaxDesyncFactor = time.Nanosecond + const nicID = 1 prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) @@ -2022,218 +1926,211 @@ func TestAutoGenTempAddr(t *testing.T) { }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. 
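Note: throughout these hunks, real-time waits (time.After plus the defaultAsyncPositive/NegativeEventTimeout slack) and the parallel "group" subtests that patched package-level lifetimes give way to a faketime.ManualClock injected through stack.Options, so every timer-driven NDP event is forced deterministically. Below is a minimal standalone sketch of that pattern; the test name, package name, and the use of ManualClock.AfterFunc (and the assumption that Advance runs all newly due jobs before returning) are this note's additions, while NewManualClock and Advance appear in the hunks themselves.

package ndptiming_test

import (
    "testing"
    "time"

    "gvisor.dev/gvisor/pkg/tcpip/faketime"
)

func TestManualClockPattern(t *testing.T) {
    clock := faketime.NewManualClock()
    events := make(chan struct{}, 1)

    // Stand-in for a timer-driven NDP job (prefix invalidation, address
    // deprecation, DAD completion, ...) scheduled on the stack's clock.
    clock.AfterFunc(time.Second, func() { events <- struct{}{} })

    // Nothing runs until the test advances the clock, so a non-blocking
    // receive must come up empty; no negative-event timeout is needed.
    select {
    case <-events:
        t.Fatal("job ran before the clock was advanced")
    default:
    }

    // Advance runs every job that became due, so by the time it returns
    // the event is already buffered and a plain default case replaces the
    // old time.After waits.
    clock.Advance(time.Second)
    select {
    case <-events:
    default:
        t.Fatal("job did not run after advancing the clock")
    }
}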
- t.Run("group", func(t *testing.T) { - for i, test := range tests { - i := i - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - seed := []byte{uint8(i)} - var tempIIDHistory [header.IIDSize]byte - header.InitialTempIID(tempIIDHistory[:], seed, nicID) - newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix { - return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr) - } - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 2), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: test.dupAddrTransmits, - RetransmitTimer: test.retransmitTimer, - }, - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - TempIIDSeed: seed, - })}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + for i, test := range tests { + t.Run(test.name, func(t *testing.T) { + seed := []byte{uint8(i)} + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], seed, nicID) + newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix { + return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr) + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 2), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + DADConfigs: stack.DADConfigurations{ + DupAddrDetectTransmits: test.dupAddrTransmits, + RetransmitTimer: test.retransmitTimer, + }, + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + MaxTempAddrValidLifetime: 2 * ipv6.MinPrefixInformationValidLifetimeForUpdate, + MaxTempAddrPreferredLifetime: 2 * ipv6.MinPrefixInformationValidLifetimeForUpdate, + }, + NDPDisp: &ndpDisp, + TempIIDSeed: seed, + })}, + Clock: clock, + }) - expectDADEventAsync := func(addr tcpip.Address) { - t.Helper() + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - case 
<-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD event") - } - } + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) select { case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e) + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } default: + t.Fatal("expected addr auto gen event") } + } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - expectDADEventAsync(addr1.Address) + expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + clock.RunImmediatelyScheduledJobs() select { case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto gen addr event = %+v", e) + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } default: + t.Fatal("timed out waiting for addr auto gen event") } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } + } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid & preferred lifetimes. - tempAddr1 := newTempAddr(addr1.Address) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - expectAutoGenAddrEvent(tempAddr1, newAddr) - expectDADEventAsync(tempAddr1.Address) - if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } + expectDADEventAsync := func(addr tcpip.Address) { + t.Helper() - // Receive an RA with prefix2 in an NDP Prefix Information option (PI) - // with preferred lifetime > valid lifetime - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + clock.Advance(time.Duration(test.dupAddrTransmits) * test.retransmitTimer) select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e) + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) + } default: + t.Fatal("timed out waiting for DAD event") } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } + } - // Receive an RA with prefix2 in a PI w/ non-zero valid and preferred - // lifetimes. 
- tempAddr2 := newTempAddr(addr2.Address) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - expectDADEventAsync(addr2.Address) - expectAutoGenAddrEventAsync(tempAddr2, newAddr) - expectDADEventAsync(tempAddr2.Address) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e) + default: + } - // Deprecate prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + expectDADEventAsync(addr1.Address) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto gen addr event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } - // Refresh lifetimes for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid & preferred lifetimes. + tempAddr1 := newTempAddr(addr1.Address) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + expectAutoGenAddrEvent(tempAddr1, newAddr) + expectDADEventAsync(tempAddr1.Address) + if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } - // Reduce valid lifetime and deprecate addresses of prefix1. 
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } + // Receive an RA with prefix2 in an NDP Prefix Information option (PI) + // with preferred lifetime > valid lifetime + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } - // Wait for addrs of prefix1 to be invalidated. They should be - // invalidated at the same time. - select { - case e := <-ndpDisp.autoGenAddrC: - var nextAddr tcpip.AddressWithPrefix - if e.addr == addr1 { - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - nextAddr = tempAddr1 - } else { - if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - nextAddr = addr1 - } + // Receive an RA with prefix2 in a PI with a valid lifetime that exceeds + // the minimum and won't be reached in this test. + tempAddr2 := newTempAddr(addr2.Address) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 2*minVLSeconds, 2*minVLSeconds)) + expectAutoGenAddrEvent(addr2, newAddr) + expectDADEventAsync(addr2.Address) + expectAutoGenAddrEventAsync(tempAddr2, newAddr) + expectDADEventAsync(tempAddr2.Address) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") + // Deprecate prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Refresh lifetimes for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Reduce valid lifetime and deprecate addresses of prefix1. 
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Wait for addrs of prefix1 to be invalidated. They should be + // invalidated at the same time. + clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) + select { + case e := <-ndpDisp.autoGenAddrC: + var nextAddr tcpip.AddressWithPrefix + if e.addr == addr1 { + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { - t.Fatal(mismatch) + nextAddr = tempAddr1 + } else { + if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + nextAddr = addr1 } - // Receive an RA with prefix2 in a PI w/ 0 lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0)) - expectAutoGenAddrEvent(addr2, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr2, deprecatedAddr) select { case e := <-ndpDisp.autoGenAddrC: - t.Errorf("got unexpected auto gen addr event = %+v", e) + if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } default: + t.Fatal("timed out waiting for addr auto gen event") } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { - t.Fatal(mismatch) - } - }) - } - }) + default: + t.Fatal("timed out waiting for addr auto gen event") + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { + t.Fatal(mismatch) + } + + // Receive an RA with prefix2 in a PI w/ 0 lifetimes. 
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0)) + expectAutoGenAddrEvent(addr2, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr2, deprecatedAddr) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("got unexpected auto gen addr event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { + t.Fatal(mismatch) + } + }) + } } // TestNoAutoGenTempAddrForLinkLocal test that temporary SLAAC addresses are not @@ -2241,12 +2138,6 @@ func TestAutoGenTempAddr(t *testing.T) { func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { const nicID = 1 - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - }() - ipv6.MaxDesyncFactor = time.Nanosecond - tests := []struct { name string dupAddrTransmits uint8 @@ -2262,66 +2153,56 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - AutoGenLinkLocal: true, - })}, - }) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + AutoGenLinkLocal: true, + })}, + Clock: clock, + }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - // The stable link-local address should auto-generate and resolve DAD. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") + // The stable link-local address should auto-generate and resolve DAD. 
+ select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD event") + default: + t.Fatal("expected addr auto gen event") + } + clock.Advance(time.Duration(test.dupAddrTransmits) * test.retransmitTimer) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("timed out waiting for DAD event") + } - // No new addresses should be generated. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Errorf("got unxpected auto gen addr event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): - } - }) - } - }) + // No new addresses should be generated. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("got unxpected auto gen addr event = %+v", e) + default: + } + }) + } } // TestNoAutoGenTempAddrWithoutStableAddr tests that a temporary SLAAC address @@ -2334,12 +2215,6 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { retransmitTimer = 2 * time.Second ) - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - }() - ipv6.MaxDesyncFactor = 0 - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte header.InitialTempIID(tempIIDHistory[:], nil, nicID) @@ -2350,6 +2225,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ @@ -2363,6 +2239,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2392,12 +2269,13 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { // Wait for DAD to complete for the stable address then expect the temporary // address to be generated. + clock.Advance(dadTransmits * retransmitTimer) select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for DAD event") } select { @@ -2405,7 +2283,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } @@ -2414,46 +2292,44 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { // regenerated. 
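Note: the DAD waits in the two tests above become deterministic because duplicate address detection finishes after DupAddrDetectTransmits probes spaced RetransmitTimer apart; advancing the manual clock by exactly that product is enough for the DADSucceeded event to be queued. A tiny worked example of the arithmetic, using the same values the tests configure (1 transmit, 2s retransmit timer); the program itself is illustrative.

package main

import (
    "fmt"
    "time"
)

func main() {
    const (
        dadTransmits    = 1
        retransmitTimer = 2 * time.Second
    )
    dadWait := time.Duration(dadTransmits) * retransmitTimer
    fmt.Println(dadWait) // 2s — the amount passed to clock.Advance in these tests.
}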
func TestAutoGenTempAddrRegen(t *testing.T) { const ( - nicID = 1 - regenAfter = 2 * time.Second - newMinVL = 10 - newMinVLDuration = newMinVL * time.Second - ) + nicID = 1 + regenAdv = 2 * time.Second - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime - }() - ipv6.MaxDesyncFactor = 0 - ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration - ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration + numTempAddrs = 3 + maxTempAddrValidLifetime = numTempAddrs * ipv6.MinPrefixInformationValidLifetimeForUpdate + ) prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte header.InitialTempIID(tempIIDHistory[:], nil, nicID) - tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + var tempAddrs [numTempAddrs]tcpip.AddressWithPrefix + for i := 0; i < len(tempAddrs); i++ { + tempAddrs[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + } ndpDisp := ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - RegenAdvanceDuration: newMinVLDuration - regenAfter, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + RegenAdvanceDuration: regenAdv, + MaxTempAddrValidLifetime: maxTempAddrValidLifetime, + MaxTempAddrPreferredLifetime: ipv6.MinPrefixInformationValidLifetimeForUpdate, + } + clock := faketime.NewManualClock() + randSource := savingRandSource{ + s: rand.NewSource(time.Now().UnixNano()), } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ndpConfigs, NDPDisp: &ndpDisp, })}, + Clock: clock, + RandSource: &randSource, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2476,36 +2352,43 @@ func TestAutoGenTempAddrRegen(t *testing.T) { expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { t.Helper() + clock.Advance(timeout) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } + tempDesyncFactor := time.Duration(randSource.lastInt63) % ipv6.MaxDesyncFactor + effectiveMaxTempAddrPL := ipv6.MinPrefixInformationValidLifetimeForUpdate - tempDesyncFactor + // The time since the last regeneration before a new temporary address is + // generated. + tempAddrRegenenerationTime := effectiveMaxTempAddrPL - regenAdv + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with non-zero valid & preferred lifetimes. 
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds)) expectAutoGenAddrEvent(addr, newAddr) - expectAutoGenAddrEvent(tempAddr1, newAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { + expectAutoGenAddrEvent(tempAddrs[0], newAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0]}, nil); mismatch != "" { t.Fatal(mismatch) } // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" { + expectAutoGenAddrEventAsync(tempAddrs[1], newAddr, tempAddrRegenenerationTime) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds)) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0], tempAddrs[1]}, nil); mismatch != "" { t.Fatal(mismatch) } + expectAutoGenAddrEventAsync(tempAddrs[0], deprecatedAddr, regenAdv) // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" { - t.Fatal(mismatch) - } + expectAutoGenAddrEventAsync(tempAddrs[2], newAddr, tempAddrRegenenerationTime-regenAdv) + expectAutoGenAddrEventAsync(tempAddrs[1], deprecatedAddr, regenAdv) // Stop generating temporary addresses ndpConfigs.AutoGenTempGlobalAddresses = false @@ -2516,45 +2399,24 @@ func TestAutoGenTempAddrRegen(t *testing.T) { ndpEP.SetNDPConfigurations(ndpConfigs) } + // Refresh lifetimes and wait for the last temporary address to be deprecated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds)) + expectAutoGenAddrEventAsync(tempAddrs[2], deprecatedAddr, effectiveMaxTempAddrPL-regenAdv) + + // Refresh lifetimes such that the prefix is valid and preferred forever. + // + // This should not affect the lifetimes of temporary addresses because they + // are capped by the maximum valid and preferred lifetimes for temporary + // addresses. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, infiniteLifetimeSeconds, infiniteLifetimeSeconds)) + // Wait for all the temporary addresses to get invalidated. - tempAddrs := []tcpip.AddressWithPrefix{tempAddr1, tempAddr2, tempAddr3} - invalidateAfter := newMinVLDuration - 2*regenAfter + invalidateAfter := maxTempAddrValidLifetime - clock.NowMonotonic().Sub(tcpip.MonotonicTime{}) for _, addr := range tempAddrs { - // Wait for a deprecation then invalidation event, or just an invalidation - // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation jobs could execute in any - // order. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff == "" { - // If we get a deprecation event first, we should get an invalidation - // event almost immediately after. 
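Note: savingRandSource is used above (and its lastInt63 field is read to recover the desync factor the stack derived from it), but its definition sits outside these hunks. A plausible standalone sketch of such a wrapper follows; the method bodies and the demo are assumptions, only the type name, the s and lastInt63 fields, and the usage pattern come from the diff.

package main

import (
    "fmt"
    "math/rand"
)

// savingRandSource wraps a rand.Source and remembers the last value it
// produced, so a test can recompute whatever the consumer (here, the
// temporary-address desync factor) derived from that value.
type savingRandSource struct {
    s         rand.Source
    lastInt63 int64
}

func (d *savingRandSource) Int63() int64 {
    v := d.s.Int63()
    d.lastInt63 = v
    return v
}

func (d *savingRandSource) Seed(seed int64) {
    d.s.Seed(seed)
}

func main() {
    src := savingRandSource{s: rand.NewSource(42)}
    r := rand.New(&src)
    got := r.Int63()
    fmt.Println(got == src.lastInt63) // true: the test sees the same value the consumer used.
}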
- select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we shouldn't get a deprecation - // event after. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto-generated event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): - } - } else { - t.Fatalf("got unexpected auto-generated event = %+v", e) - } - case <-time.After(invalidateAfter + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - - invalidateAfter = regenAfter + expectAutoGenAddrEventAsync(addr, invalidatedAddr, invalidateAfter) + invalidateAfter = tempAddrRegenenerationTime } - if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs); mismatch != "" { + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs[:]); mismatch != "" { t.Fatal(mismatch) } } @@ -2563,52 +2425,54 @@ func TestAutoGenTempAddrRegen(t *testing.T) { // regeneration job gets updated when refreshing the address's lifetimes. func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { const ( - nicID = 1 - regenAfter = 2 * time.Second - newMinVL = 10 - newMinVLDuration = newMinVL * time.Second - ) + nicID = 1 + regenAdv = 2 * time.Second - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime - }() - ipv6.MaxDesyncFactor = 0 - ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration - ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration + numTempAddrs = 3 + maxTempAddrPreferredLifetime = ipv6.MinPrefixInformationValidLifetimeForUpdate + maxTempAddrPreferredLifetimeSeconds = uint32(maxTempAddrPreferredLifetime / time.Second) + ) prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte header.InitialTempIID(tempIIDHistory[:], nil, nicID) - tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + var tempAddrs [numTempAddrs]tcpip.AddressWithPrefix + for i := 0; i < len(tempAddrs); i++ { + tempAddrs[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + } ndpDisp := ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - RegenAdvanceDuration: newMinVLDuration - regenAfter, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + RegenAdvanceDuration: regenAdv, + MaxTempAddrPreferredLifetime: maxTempAddrPreferredLifetime, + MaxTempAddrValidLifetime: 
maxTempAddrPreferredLifetime * 2, + } + clock := faketime.NewManualClock() + initialTime := clock.NowMonotonic() + randSource := savingRandSource{ + s: rand.NewSource(time.Now().UnixNano()), } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ndpConfigs, NDPDisp: &ndpDisp, })}, + Clock: clock, + RandSource: &randSource, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } + tempDesyncFactor := time.Duration(randSource.lastInt63) % ipv6.MaxDesyncFactor + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() @@ -2625,22 +2489,23 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { t.Helper() + clock.Advance(timeout) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with non-zero valid & preferred lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, maxTempAddrPreferredLifetimeSeconds)) expectAutoGenAddrEvent(addr, newAddr) - expectAutoGenAddrEvent(tempAddr1, newAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { + expectAutoGenAddrEvent(tempAddrs[0], newAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0]}, nil); mismatch != "" { t.Fatal(mismatch) } @@ -2648,13 +2513,27 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { // // A new temporary address should be generated after the regeneration // time has passed since the prefix is deprecated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, 0)) expectAutoGenAddrEvent(addr, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddrs[0], deprecatedAddr) select { case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpected auto gen addr event = %+v", e) - case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): + t.Fatalf("unexpected auto gen addr event = %#v", e) + default: + } + + effectiveMaxTempAddrPL := maxTempAddrPreferredLifetime - tempDesyncFactor + // The time since the last regeneration before a new temporary address is + // generated. + tempAddrRegenenerationTime := effectiveMaxTempAddrPL - regenAdv + + // Advance the clock by the regeneration time but don't expect a new temporary + // address as the prefix is deprecated. + clock.Advance(tempAddrRegenenerationTime) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %#v", e) + default: } // Prefer the prefix again. 
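Note: the deadlines these regeneration tests advance through follow directly from the configuration above: the effective preferred lifetime is the configured maximum minus the desync factor, and a replacement temporary address is generated RegenAdvanceDuration before the current one is deprecated. A worked example with illustrative stand-in durations; only the relationships are taken from the hunks above.

package main

import (
    "fmt"
    "time"
)

func main() {
    const (
        maxTempAddrPreferredLifetime = 10 * time.Second // stand-in for ipv6.MinPrefixInformationValidLifetimeForUpdate
        tempDesyncFactor             = 1 * time.Second  // stand-in for the value recovered from savingRandSource
        regenAdv                     = 2 * time.Second  // RegenAdvanceDuration in these tests
    )
    effectiveMaxTempAddrPL := maxTempAddrPreferredLifetime - tempDesyncFactor
    tempAddrRegenerationTime := effectiveMaxTempAddrPL - regenAdv

    fmt.Printf("replacement temp addr generated after %v\n", tempAddrRegenerationTime) // 7s
    fmt.Printf("previous temp addr deprecated %v later\n", regenAdv)                   // i.e. at 9s
}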
@@ -2662,8 +2541,15 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { // A new temporary address should immediately be generated since the // regeneration time has already passed since the last address was generated // - this regeneration does not depend on a job. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEvent(tempAddr2, newAddr) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, maxTempAddrPreferredLifetimeSeconds)) + expectAutoGenAddrEvent(tempAddrs[1], newAddr) + // Wait for the first temporary address to be deprecated. + expectAutoGenAddrEventAsync(tempAddrs[0], deprecatedAddr, regenAdv) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %s", e) + default: + } // Increase the maximum lifetimes for temporary addresses to large values // then refresh the lifetimes of the prefix. @@ -2674,34 +2560,30 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { // regenerate a new temporary address. Note, new addresses are only // regenerated after the preferred lifetime - the regenerate advance duration // as paased. - ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second - ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second + const largeLifetimeSeconds = minVLSeconds * 2 + const largeLifetime = time.Duration(largeLifetimeSeconds) * time.Second + ndpConfigs.MaxTempAddrValidLifetime = 2 * largeLifetime + ndpConfigs.MaxTempAddrPreferredLifetime = largeLifetime ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber) if err != nil { t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) } ndpEP := ipv6Ep.(ipv6.NDPEndpoint) ndpEP.SetNDPConfigurations(ndpConfigs) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + timeSinceInitialTime := clock.NowMonotonic().Sub(initialTime) + clock.Advance(largeLifetime - timeSinceInitialTime) + expectAutoGenAddrEvent(tempAddrs[0], deprecatedAddr) + // to offset the advement of time to test the first temporary address's + // deprecation after the second was generated + advLess := regenAdv + expectAutoGenAddrEventAsync(tempAddrs[2], newAddr, timeSinceInitialTime-advLess-(tempDesyncFactor+regenAdv)) + expectAutoGenAddrEventAsync(tempAddrs[1], deprecatedAddr, regenAdv) select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpected auto gen addr event = %+v", e) - case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): + default: } - - // Set the maximum lifetimes for temporary addresses such that on the next - // RA, the regeneration job gets scheduled again. - // - // The maximum lifetime is the sum of the minimum lifetimes for temporary - // addresses + the time that has already passed since the last address was - // generated so that the regeneration job is needed to generate the next - // address. 
- newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout - ndpConfigs.MaxTempAddrValidLifetime = newLifetimes - ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes - ndpEP.SetNDPConfigurations(ndpConfigs) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) } // TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response @@ -2929,13 +2811,14 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { // stack.Stack will have a default route through the router (llAddr3) installed // and a static link-address (linkAddr3) added to the link address cache for the // router. -func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { +func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) { t.Helper() ndpDisp := &ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -2945,6 +2828,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd NDPDisp: ndpDisp, })}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -2958,7 +2842,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd if err := s.AddStaticNeighbor(nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3); err != nil { t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3, err) } - return ndpDisp, e, s + return ndpDisp, e, s, clock } // addrForNewConnectionTo returns the local address used when creating a new @@ -3032,7 +2916,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + ndpDisp, e, s, _ := stackAndNdpDispatcherWithDefaultRoute(t, nicID) expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() @@ -3135,19 +3019,11 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { // when its preferred lifetime expires. 
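Note: stackAndNdpDispatcherWithDefaultRoute now returns the manual clock alongside the dispatcher, endpoint and stack, and callers that do not need it discard it with a blank identifier. The fragment below shows the intended use; it reuses this file's helpers and identifiers (nicID, prefix1, llAddr2, minVLSeconds, raBufWithPI), so it is a sketch of a caller, not standalone code.

ndpDisp, e, s, clock := stackAndNdpDispatcherWithDefaultRoute(t, nicID)

// Advertise prefix1 with a preferred lifetime one second shorter than the
// valid lifetime, then jump straight to the deprecation point.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, minVLSeconds-1))
clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate - time.Second)
// The deprecation event is now queued on ndpDisp.autoGenAddrC and can be read
// with a non-blocking select; s still reports the (deprecated) address.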
func TestAutoGenAddrJobDeprecation(t *testing.T) { const nicID = 1 - const newMinVL = 2 - newMinVLDuration := newMinVL * time.Second - - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + ndpDisp, e, s, clock := stackAndNdpDispatcherWithDefaultRoute(t, nicID) expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() @@ -3165,12 +3041,13 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { t.Helper() + clock.Advance(timeout) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } @@ -3188,7 +3065,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { } // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, infiniteLifetimeSeconds, infiniteLifetimeSeconds)) expectAutoGenAddrEvent(addr2, newAddr) if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) @@ -3207,7 +3084,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Refresh lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, minVLSeconds-1)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") @@ -3216,7 +3093,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, ipv6.MinPrefixInformationValidLifetimeForUpdate-time.Second) if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3226,6 +3103,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { // addr2 should be the primary endpoint now since addr1 is deprecated but // addr2 is not. expectPrimaryAddr(addr2) + // addr1 is deprecated but if explicitly requested, it should be used. fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { @@ -3234,7 +3112,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make // sure we do not get a deprecation event again. 
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, 0)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") @@ -3246,7 +3124,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { } // Refresh lifetimes for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, minVLSeconds-1)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") @@ -3256,7 +3134,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, ipv6.MinPrefixInformationValidLifetimeForUpdate-time.Second) if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3270,7 +3148,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { } // Wait for addr of prefix1 to be invalidated. - expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) + expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second) if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3280,7 +3158,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr(addr2) // Refresh both lifetimes for addr of prefix2 to the same value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, minVLSeconds, minVLSeconds)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") @@ -3292,6 +3170,17 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { // cases because the deprecation and invalidation handlers could be handled in // either deprecation then invalidation, or invalidation then deprecation // (which should be cancelled by the invalidation handler). + // + // Since we're about to cause both events to fire, we need the dispatcher + // channel to be able to hold both. 
+ if got, want := len(ndpDisp.autoGenAddrC), 0; got != want { + t.Fatalf("got len(ndpDisp.autoGenAddrC) = %d, want %d", got, want) + } + if got, want := cap(ndpDisp.autoGenAddrC), 1; got != want { + t.Fatalf("got cap(ndpDisp.autoGenAddrC) = %d, want %d", got, want) + } + ndpDisp.autoGenAddrC = make(chan ndpAutoGenAddrEvent, 2) + clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { @@ -3302,21 +3191,21 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we should not get a deprecation + // If we get an invalidation event first, we should not get a deprecation // event after. select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") - case <-time.After(defaultAsyncNegativeEventTimeout): + default: } } else { t.Fatalf("got unexpected auto-generated event") } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { @@ -3353,15 +3242,6 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { // infinite values. func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { const infiniteVLSeconds = 2 - const minVLSeconds = 1 - savedIL := header.NDPInfiniteLifetime - savedMinVL := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinVL - header.NDPInfiniteLifetime = savedIL - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second - header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second prefix, _, addr := prefixSubnetAddr(0, linkAddr1) @@ -3385,68 +3265,58 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. 
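Note: the hunk above swaps in a larger dispatcher channel before advancing to the shared valid-lifetime deadline because a single clock.Advance can deliver a deprecation and an invalidation back to back, and with a manual clock there is no real time in which the first event could be drained before the second is sent. A standalone sketch of the capacity requirement; the channel and event names here are illustrative, not the dispatcher's types.

package main

import "fmt"

func main() {
    autoGenAddrC := make(chan string, 2) // was capacity 1 before the swap
    autoGenAddrC <- "deprecated"
    autoGenAddrC <- "invalidated" // with capacity 1 this second send could not complete
    fmt.Println(<-autoGenAddrC, <-autoGenAddrC)
}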
- t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + Clock: clock, + }) - // Receive an RA with finite prefix. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - default: - t.Fatal("expected addr auto gen event") + // Receive an RA with finite prefix. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - // Receive an new RA with prefix with infinite VL. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0)) + default: + t.Fatal("expected addr auto gen event") + } - // Receive a new RA with prefix with finite VL. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) + // Receive an new RA with prefix with infinite VL. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } + // Receive a new RA with prefix with finite VL. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) - case <-time.After(minVLSeconds*time.Second + defaultAsyncPositiveEventTimeout): - t.Fatal("timeout waiting for addr auto gen event") + clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - }) - } - }) + + default: + t.Fatal("timeout waiting for addr auto gen event") + } + }) + } } // TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an @@ -3454,12 +3324,6 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { // RFC 4862 section 5.5.3.e. 
func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { const infiniteVL = 4294967295 - const newMinVL = 4 - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second prefix, _, addr := prefixSubnetAddr(0, linkAddr1) @@ -3470,137 +3334,129 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { evl uint32 }{ // Should update the VL to the minimum VL for updating if the - // new VL is less than newMinVL but was originally greater than + // new VL is less than minVLSeconds but was originally greater than // it. { "LargeVLToVLLessThanMinVLForUpdate", 9999, 1, - newMinVL, + minVLSeconds, }, { "LargeVLTo0", 9999, 0, - newMinVL, + minVLSeconds, }, { "InfiniteVLToVLLessThanMinVLForUpdate", infiniteVL, 1, - newMinVL, + minVLSeconds, }, { "InfiniteVLTo0", infiniteVL, 0, - newMinVL, + minVLSeconds, }, - // Should not update VL if original VL was less than newMinVL - // and the new VL is also less than newMinVL. + // Should not update VL if original VL was less than minVLSeconds + // and the new VL is also less than minVLSeconds. { "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate", - newMinVL - 1, - newMinVL - 3, - newMinVL - 1, + minVLSeconds - 1, + minVLSeconds - 3, + minVLSeconds - 1, }, // Should take the new VL if the new VL is greater than the - // remaining time or is greater than newMinVL. + // remaining time or is greater than minVLSeconds. { "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate", - newMinVL + 5, - newMinVL + 3, - newMinVL + 3, + minVLSeconds + 5, + minVLSeconds + 3, + minVLSeconds + 3, }, { "SmallVLToGreaterVLButStillLessThanMinVLForUpdate", - newMinVL - 3, - newMinVL - 1, - newMinVL - 1, + minVLSeconds - 3, + minVLSeconds - 1, + minVLSeconds - 1, }, { "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate", - newMinVL - 3, - newMinVL + 1, - newMinVL + 1, + minVLSeconds - 3, + minVLSeconds + 1, + minVLSeconds + 1, }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. 
- t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + Clock: clock, + }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - // Receive an RA with prefix with initial VL, - // test.ovl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") + // Receive an RA with prefix with initial VL, + // test.ovl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("expected addr auto gen event") + } - // Receive an new RA with prefix with new VL, - // test.nvl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) + // Receive an new RA with prefix with new VL, + // test.nvl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) - // - // Validate that the VL for the address got set - // to test.evl. - // + // + // Validate that the VL for the address got set + // to test.evl. + // - // The address should not be invalidated until the effective valid - // lifetime has passed. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncNegativeEventTimeout): - } + // The address should not be invalidated until the effective valid + // lifetime has passed. + const delta = 1 + clock.Advance(time.Duration(test.evl)*time.Second - delta) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event") + default: + } - // Wait for the invalidation event. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timeout waiting for addr auto gen event") + // Wait for the invalidation event. 
+ clock.Advance(delta) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - }) - } - }) + default: + t.Fatal("timeout waiting for addr auto gen event") + } + }) + } } // TestAutoGenAddrRemoval tests that when auto-generated addresses are removed @@ -3613,6 +3469,7 @@ func TestAutoGenAddrRemoval(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -3621,6 +3478,7 @@ func TestAutoGenAddrRemoval(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -3654,10 +3512,11 @@ func TestAutoGenAddrRemoval(t *testing.T) { // Wait for the original valid lifetime to make sure the original job got // cancelled/cleaned up. + clock.Advance(lifetimeSeconds * time.Second) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): + default: } } @@ -3668,7 +3527,7 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) { prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + ndpDisp, e, s, _ := stackAndNdpDispatcherWithDefaultRoute(t, nicID) expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() @@ -3779,6 +3638,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -3787,6 +3647,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -3816,30 +3677,36 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { // Should not get an invalidation event after the PI's invalidation // time. + clock.Advance(lifetimeSeconds * time.Second) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): + default: } if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) } } +func makeSecretKey(t *testing.T) []byte { + secretKey := make([]byte, header.OpaqueIIDSecretKeyMinBytes) + n, err := cryptorand.Read(secretKey) + if err != nil { + t.Fatalf("cryptorand.Read(_): %s", err) + } + if l := len(secretKey); n != l { + t.Fatalf("got cryptorand.Read(_) = (%d, nil), want = (%d, nil)", n, l) + } + return secretKey +} + // TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use // opaque interface identifiers when configured to do so. 
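The new makeSecretKey helper above consolidates several copies of the inline secret-key setup used for opaque IIDs (RFC 7217). A standalone sketch of the same steps, assuming the 16-byte (128-bit) minimum that RFC 7217 recommends; the authoritative value is header.OpaqueIIDSecretKeyMinBytes in gvisor.dev/gvisor/pkg/tcpip/header:

package main

import (
	cryptorand "crypto/rand"
	"fmt"
	"log"
)

// opaqueIIDSecretKeyMinBytes stands in for header.OpaqueIIDSecretKeyMinBytes
// here; 16 bytes is an assumption for the sketch, not the imported constant.
const opaqueIIDSecretKeyMinBytes = 16

func main() {
	// Same shape as the makeSecretKey helper: allocate the minimum-size key
	// and fill it from crypto/rand, treating any short read as fatal.
	secretKey := make([]byte, opaqueIIDSecretKeyMinBytes)
	n, err := cryptorand.Read(secretKey)
	if err != nil {
		log.Fatalf("cryptorand.Read(_): %v", err)
	}
	if n != len(secretKey) {
		log.Fatalf("got cryptorand.Read(_) = (%d, nil), want = (%d, nil)", n, len(secretKey))
	}
	fmt.Printf("generated %d-byte opaque IID secret key\n", len(secretKey))
}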
func TestAutoGenAddrWithOpaqueIID(t *testing.T) { const nicID = 1 const nicName = "nic1" - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte - secretKey := secretKeyBuf[:] - n, err := rand.Read(secretKey) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) - } + + secretKey := makeSecretKey(t) prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1) prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1) @@ -3861,6 +3728,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -3875,6 +3743,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { SecretKey: secretKey, }, })}, + Clock: clock, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -3913,12 +3782,13 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { } // Wait for addr of prefix1 to be invalidated. + clock.Advance(validLifetimeSecondPrefix1 * time.Second) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { @@ -3937,22 +3807,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { const maxMaxRetries = 3 const lifetimeSeconds = 10 - // Needed for the temporary address sub test. 
- savedMaxDesync := ipv6.MaxDesyncFactor - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesync - }() - ipv6.MaxDesyncFactor = time.Nanosecond - - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte - secretKey := secretKeyBuf[:] - n, err := rand.Read(secretKey) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) - } + secretKey := makeSecretKey(t) prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) @@ -3977,22 +3832,24 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } } - expectAutoGenAddrEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + expectAutoGenAddrEventAsync := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() + clock.RunImmediatelyScheduledJobs() select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } - expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { + expectDADEvent := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { t.Helper() + clock.RunImmediatelyScheduledJobs() select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr, res); diff != "" { @@ -4003,15 +3860,16 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } } - expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { + expectDADEventAsync := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { t.Helper() + clock.Advance(dadTransmits * retransmitTimer) select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr, res); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for DAD event") } } @@ -4022,7 +3880,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { name string ndpConfigs ipv6.NDPConfigurations autoGenLinkLocal bool - prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix + prepareFn func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix }{ { @@ -4031,7 +3889,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, - prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { + prepareFn: func(_ *testing.T, _ *faketime.ManualClock, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { // Receive an RA with prefix1 in a PI. 
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) return nil @@ -4045,7 +3903,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { name: "LinkLocal address", ndpConfigs: ipv6.NDPConfigurations{}, autoGenLinkLocal: true, - prepareFn: func(*testing.T, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix { + prepareFn: func(*testing.T, *faketime.ManualClock, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix { return nil }, addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix { @@ -4059,14 +3917,14 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, - prepareFn: func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix { + prepareFn: func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix { header.InitialTempIID(tempIIDHistory, nil, nicID) // Generate a stable SLAAC address so temporary addresses will be // generated. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr) - expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{}) + expectDADEventAsync(t, clock, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{}) // The stable address will be assigned throughout the test. return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest} @@ -4078,14 +3936,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } for _, addrType := range addrTypes { - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the parallel - // tests complete and limit the number of parallel tests running at the same - // time to reduce flakes. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. 
t.Run(addrType.name, func(t *testing.T) { for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ { for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ { @@ -4094,8 +3944,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { addrType := addrType t.Run(fmt.Sprintf("%d max retries and %d failures", maxRetries, numFailures), func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), @@ -4103,6 +3951,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { e := channel.New(0, 1280, linkAddr1) ndpConfigs := addrType.ndpConfigs ndpConfigs.AutoGenAddressConflictRetries = maxRetries + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: addrType.autoGenLinkLocal, @@ -4119,6 +3968,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { SecretKey: secretKey, }, })}, + Clock: clock, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -4126,12 +3976,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } var tempIIDHistory [header.IIDSize]byte - stableAddrs := addrType.prepareFn(t, &ndpDisp, e, tempIIDHistory[:]) + stableAddrs := addrType.prepareFn(t, clock, &ndpDisp, e, tempIIDHistory[:]) // Simulate DAD conflicts so the address is regenerated. for i := uint8(0); i < numFailures; i++ { addr := addrType.addrGenFn(i, tempIIDHistory[:]) - expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) + expectAutoGenAddrEventAsync(t, clock, &ndpDisp, addr, newAddr) // Should not have any new addresses assigned to the NIC. if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" { @@ -4141,7 +3991,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // Simulate a DAD conflict. rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr) - expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{}) + expectDADEvent(t, clock, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{}) // Attempting to add the address manually should not fail if the // address's state was cleaned up when DAD failed. @@ -4151,7 +4001,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { if err := s.RemoveAddress(nicID, addr.Address); err != nil { t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) } - expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADAborted{}) + expectDADEvent(t, clock, &ndpDisp, addr.Address, &stack.DADAborted{}) } // Should not have any new addresses assigned to the NIC. @@ -4163,8 +4013,8 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // an address after DAD resolves. 
if maxRetries+1 > numFailures { addr := addrType.addrGenFn(numFailures, tempIIDHistory[:]) - expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) - expectDADEventAsync(t, &ndpDisp, addr.Address, &stack.DADSucceeded{}) + expectAutoGenAddrEventAsync(t, clock, &ndpDisp, addr, newAddr) + expectDADEventAsync(t, clock, &ndpDisp, addr.Address, &stack.DADSucceeded{}) if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" { t.Fatal(mismatch) } @@ -4174,7 +4024,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): + default: } }) } @@ -4231,13 +4081,12 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { addrType := addrType t.Run(addrType.name, func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: addrType.autoGenLinkLocal, @@ -4248,6 +4097,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { RetransmitTimer: retransmitTimer, }, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -4292,7 +4142,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): + default: } }) } @@ -4309,15 +4159,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { const maxRetries = 1 const lifetimeSeconds = 5 - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte - secretKey := secretKeyBuf[:] - n, err := rand.Read(secretKey) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) - } + secretKey := makeSecretKey(t) prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) @@ -4326,6 +4168,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ @@ -4345,6 +4188,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { SecretKey: secretKey, }, })}, + Clock: clock, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -4375,7 +4219,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { expectAutoGenAddrEvent(addr, newAddr) // Simulate a DAD conflict after some time has passed. - time.Sleep(failureTimer) + clock.Advance(failureTimer) rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(addr, invalidatedAddr) select { @@ -4390,12 +4234,13 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { // Let the next address resolve. 
addr.Address = tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 1, secretKey)) expectAutoGenAddrEvent(addr, newAddr) + clock.Advance(dadTransmits * retransmitTimer) select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for DAD event") } @@ -4409,6 +4254,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { // // We expect either just the invalidation event or the deprecation event // followed by the invalidation event. + clock.Advance(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer) select { case e := <-ndpDisp.autoGenAddrC: if e.eventType == deprecatedAddr { @@ -4421,7 +4267,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation") } } else { @@ -4429,7 +4275,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } } - case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for auto gen addr event") } } @@ -4691,11 +4537,9 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { ) ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), + prefixC: make(chan ndpPrefixEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ @@ -4738,17 +4582,17 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { ), ) select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr3, true /* discovered */); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + if diff := checkOffLinkRouteEvent(e, nicID, llAddr3, header.MediumRoutePreference, true /* discovered */); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID) + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID) } select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID) @@ -4770,8 +4614,8 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } select { - case e := <-ndpDisp.routerC: - t.Errorf("unexpected router event = %#v", e) + case e := <-ndpDisp.offLinkRouteC: + t.Errorf("unexpected off-link route event 
= %#v", e) default: } select { @@ -4857,12 +4701,11 @@ func TestCleanupNDPState(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents), - rememberRouter: true, - prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents), - rememberPrefix: true, - autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), + offLinkRouteC: make(chan ndpOffLinkRouteEvent, maxRouterAndPrefixEvents), + prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), } + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: true, @@ -4874,16 +4717,17 @@ func TestCleanupNDPState(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) - expectRouterEvent := func() (bool, ndpRouterEvent) { + expectOffLinkRouteEvent := func() (bool, ndpOffLinkRouteEvent) { select { - case e := <-ndpDisp.routerC: + case e := <-ndpDisp.offLinkRouteC: return true, e default: } - return false, ndpRouterEvent{} + return false, ndpOffLinkRouteEvent{} } expectPrefixEvent := func() (bool, ndpPrefixEvent) { @@ -4928,8 +4772,8 @@ func TestCleanupNDPState(t *testing.T) { // multiple addresses. e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID1) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) @@ -4939,8 +4783,8 @@ func TestCleanupNDPState(t *testing.T) { } e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID1) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) @@ -4950,8 +4794,8 @@ func TestCleanupNDPState(t *testing.T) { } e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID2) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) @@ -4961,8 +4805,8 @@ func TestCleanupNDPState(t *testing.T) { } e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID2) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) @@ -5003,14 +4847,14 @@ 
func TestCleanupNDPState(t *testing.T) { test.cleanupFn(t, s) // Collect invalidation events after having NDP state cleaned up. - gotRouterEvents := make(map[ndpRouterEvent]int) + gotOffLinkRouteEvents := make(map[ndpOffLinkRouteEvent]int) for i := 0; i < maxRouterAndPrefixEvents; i++ { - ok, e := expectRouterEvent() + ok, e := expectOffLinkRouteEvent() if !ok { - t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) + t.Errorf("expected %d off-link route events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) break } - gotRouterEvents[e]++ + gotOffLinkRouteEvents[e]++ } gotPrefixEvents := make(map[ndpPrefixEvent]int) for i := 0; i < maxRouterAndPrefixEvents; i++ { @@ -5037,14 +4881,14 @@ func TestCleanupNDPState(t *testing.T) { t.FailNow() } - expectedRouterEvents := map[ndpRouterEvent]int{ - {nicID: nicID1, addr: llAddr3, discovered: false}: 1, - {nicID: nicID1, addr: llAddr4, discovered: false}: 1, - {nicID: nicID2, addr: llAddr3, discovered: false}: 1, - {nicID: nicID2, addr: llAddr4, discovered: false}: 1, + expectedOffLinkRouteEvents := map[ndpOffLinkRouteEvent]int{ + {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1, + {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1, + {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1, + {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1, } - if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { - t.Errorf("router events mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(expectedOffLinkRouteEvents, gotOffLinkRouteEvents); diff != "" { + t.Errorf("off-link route events mismatch (-want +got):\n%s", diff) } expectedPrefixEvents := map[ndpPrefixEvent]int{ {nicID: nicID1, prefix: subnet1, discovered: false}: 1, @@ -5106,10 +4950,10 @@ func TestCleanupNDPState(t *testing.T) { // Should not get any more events (invalidation timers should have been // cancelled when the NDP state was cleaned up). - time.Sleep(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout) + clock.Advance(lifetimeSeconds * time.Second) select { - case <-ndpDisp.routerC: - t.Error("unexpected router event") + case <-ndpDisp.offLinkRouteC: + t.Error("unexpected off-link route event") default: } select { @@ -5134,7 +4978,6 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { ndpDisp := ndpDispatcher{ dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1), - rememberRouter: true, } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ @@ -5236,6 +5079,23 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { expectNoDHCPv6Event() } +var _ rand.Source = (*savingRandSource)(nil) + +type savingRandSource struct { + s rand.Source + + lastInt63 int64 +} + +func (d *savingRandSource) Int63() int64 { + i := d.s.Int63() + d.lastInt63 = i + return i +} +func (d *savingRandSource) Seed(seed int64) { + d.s.Seed(seed) +} + // TestRouterSolicitation tests the initial Router Solicitations that are sent // when a NIC newly becomes enabled. 
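savingRandSource above wraps the stack's rand.Source so the test can read back the last Int63 value and recompute the randomized router solicitation delay the stack chose, instead of waiting out the worst case. A minimal sketch of that wrapping, with a hypothetical jitter computation of the same time.Duration(lastInt63) % max shape used in the following hunk:

package main

import (
	"fmt"
	"math/rand"
	"time"
)

// savingSource mirrors the savingRandSource type added above: it delegates to
// a real rand.Source and remembers the last Int63 value it handed out.
type savingSource struct {
	s         rand.Source
	lastInt63 int64
}

func (d *savingSource) Int63() int64 {
	i := d.s.Int63()
	d.lastInt63 = i
	return i
}

func (d *savingSource) Seed(seed int64) { d.s.Seed(seed) }

func main() {
	src := &savingSource{s: rand.NewSource(time.Now().UnixNano())}
	rng := rand.New(src)

	// A stand-in for the stack's first-solicitation jitter (hypothetical, not
	// the stack's code): pick a delay in [0, maxDelay).
	const maxDelay = time.Second
	delay := time.Duration(rng.Int63()) % maxDelay

	// Because the source saved the value it produced, the test can recompute
	// the exact delay and advance a manual clock by precisely that amount.
	recomputed := time.Duration(src.lastInt63) % maxDelay
	fmt.Println(delay == recomputed) // true
}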
func TestRouterSolicitation(t *testing.T) { @@ -5402,6 +5262,9 @@ func TestRouterSolicitation(t *testing.T) { t.Fatalf("unexpectedly got a packet = %#v", p) } } + randSource := savingRandSource{ + s: rand.NewSource(time.Now().UnixNano()), + } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -5411,8 +5274,10 @@ func TestRouterSolicitation(t *testing.T) { MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, }, })}, - Clock: clock, + Clock: clock, + RandSource: &randSource, }) + if err := s.CreateNIC(nicID, &e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -5425,19 +5290,27 @@ func TestRouterSolicitation(t *testing.T) { // Make sure each RS is sent at the right time. remaining := test.maxRtrSolicit - if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay) + if remaining != 0 { + maxRtrSolicitDelay := test.maxRtrSolicitDelay + if maxRtrSolicitDelay < 0 { + maxRtrSolicitDelay = ipv6.DefaultNDPConfigurations().MaxRtrSolicitationDelay + } + var actualRtrSolicitDelay time.Duration + if maxRtrSolicitDelay != 0 { + actualRtrSolicitDelay = time.Duration(randSource.lastInt63) % maxRtrSolicitDelay + } + waitForPkt(actualRtrSolicitDelay) remaining-- } subTest.afterFirstRS(t, s) - for ; remaining > 0; remaining-- { - if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { + for ; remaining != 0; remaining-- { + if test.effectiveRtrSolicitInt != 0 { waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) waitForPkt(time.Nanosecond) } else { - waitForPkt(test.effectiveRtrSolicitInt) + waitForPkt(0) } } @@ -5533,12 +5406,11 @@ func TestStopStartSolicitingRouters(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { e := channel.New(maxRtrSolicitations, 1280, linkAddr1) - waitForPkt := func(timeout time.Duration) { + waitForPkt := func(clock *faketime.ManualClock, timeout time.Duration) { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - p, ok := e.ReadContext(ctx) + clock.Advance(timeout) + p, ok := e.Read() if !ok { t.Fatal("timed out waiting for packet") } @@ -5552,6 +5424,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { checker.TTL(header.NDPHopLimit), checker.NDPRS()) } + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -5561,6 +5434,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { MaxRtrSolicitationDelay: delay, }, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -5568,13 +5442,11 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Stop soliciting routers. test.stopFn(t, s, true /* first */) - ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { + clock.Advance(delay) + if _, ok := e.Read(); ok { // A single RS may have been sent before solicitations were stopped. 
- ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok = e.ReadContext(ctx); ok { + clock.Advance(interval) + if _, ok = e.Read(); ok { t.Fatal("should not have sent more than one RS message") } } @@ -5582,9 +5454,8 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Stopping router solicitations after it has already been stopped should // do nothing. test.stopFn(t, s, false /* first */) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { + clock.Advance(delay) + if _, ok := e.Read(); ok { t.Fatal("unexpectedly got a packet after router solicitation has been stopepd") } @@ -5595,21 +5466,19 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Start soliciting routers. test.startFn(t, s) - waitForPkt(delay + defaultAsyncPositiveEventTimeout) - waitForPkt(interval + defaultAsyncPositiveEventTimeout) - waitForPkt(interval + defaultAsyncPositiveEventTimeout) - ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { + waitForPkt(clock, delay) + waitForPkt(clock, interval) + waitForPkt(clock, interval) + clock.Advance(interval) + if _, ok := e.Read(); ok { t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") } // Starting router solicitations after it has already completed should do // nothing. test.startFn(t, s) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { + clock.Advance(interval) + if _, ok := e.Read(); ok { t.Fatal("unexpectedly got a packet after finishing router solicitations") } }) diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 509f5ce5c..08857e1a9 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -310,7 +310,7 @@ func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) { func (n *neighborCache) init(nic *nic, r LinkAddressResolver) { *n = neighborCache{ nic: nic, - state: NewNUDState(nic.stack.nudConfigs, nic.stack.randomGenerator), + state: NewNUDState(nic.stack.nudConfigs, nic.stack.clock, nic.stack.randomGenerator), linkRes: r, } n.mu.Lock() diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 9821a18d3..7de25fe37 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -15,8 +15,6 @@ package stack import ( - "bytes" - "encoding/binary" "fmt" "math" "math/rand" @@ -48,9 +46,6 @@ const ( // be sent to all nodes. testEntryBroadcastAddr = tcpip.Address("broadcast") - // testEntryLocalAddr is the source address of neighbor probes. - testEntryLocalAddr = tcpip.Address("local_addr") - // testEntryBroadcastLinkAddr is a special link address sent back to // multicast neighbor probes. 
testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast") @@ -95,7 +90,7 @@ func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, cl randomGenerator: rng, }, id: 1, - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), }, linkRes) return linkRes } @@ -106,20 +101,24 @@ type testEntryStore struct { entriesMap map[tcpip.Address]NeighborEntry } -func toAddress(i int) tcpip.Address { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint8(1)) - binary.Write(buf, binary.BigEndian, uint8(0)) - binary.Write(buf, binary.BigEndian, uint16(i)) - return tcpip.Address(buf.String()) +func toAddress(i uint16) tcpip.Address { + return tcpip.Address([]byte{ + 1, + 0, + byte(i >> 8), + byte(i), + }) } -func toLinkAddress(i int) tcpip.LinkAddress { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint8(1)) - binary.Write(buf, binary.BigEndian, uint8(0)) - binary.Write(buf, binary.BigEndian, uint32(i)) - return tcpip.LinkAddress(buf.String()) +func toLinkAddress(i uint16) tcpip.LinkAddress { + return tcpip.LinkAddress([]byte{ + 1, + 0, + 0, + 0, + byte(i >> 8), + byte(i), + }) } // newTestEntryStore returns a testEntryStore pre-populated with entries. @@ -127,7 +126,7 @@ func newTestEntryStore() *testEntryStore { store := &testEntryStore{ entriesMap: make(map[tcpip.Address]NeighborEntry), } - for i := 0; i < entryStoreSize; i++ { + for i := uint16(0); i < entryStoreSize; i++ { addr := toAddress(i) linkAddr := toLinkAddress(i) @@ -140,15 +139,15 @@ func newTestEntryStore() *testEntryStore { } // size returns the number of entries in the store. -func (s *testEntryStore) size() int { +func (s *testEntryStore) size() uint16 { s.mu.RLock() defer s.mu.RUnlock() - return len(s.entriesMap) + return uint16(len(s.entriesMap)) } // entry returns the entry at index i. Returns an empty entry and false if i is // out of bounds. -func (s *testEntryStore) entry(i int) (NeighborEntry, bool) { +func (s *testEntryStore) entry(i uint16) (NeighborEntry, bool) { return s.entryByAddr(toAddress(i)) } @@ -166,7 +165,7 @@ func (s *testEntryStore) entries() []NeighborEntry { entries := make([]NeighborEntry, 0, len(s.entriesMap)) s.mu.RLock() defer s.mu.RUnlock() - for i := 0; i < entryStoreSize; i++ { + for i := uint16(0); i < entryStoreSize; i++ { addr := toAddress(i) if entry, ok := s.entriesMap[addr]; ok { entries = append(entries, entry) @@ -176,7 +175,7 @@ func (s *testEntryStore) entries() []NeighborEntry { } // set modifies the link addresses of an entry. 
-func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) { +func (s *testEntryStore) set(i uint16, linkAddr tcpip.LinkAddress) { addr := toAddress(i) s.mu.Lock() defer s.mu.Unlock() @@ -236,13 +235,6 @@ func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return 0 } -type entryEvent struct { - nicID tcpip.NICID - address tcpip.Address - linkAddr tcpip.LinkAddress - state NeighborState -} - func TestNeighborCacheGetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() @@ -301,10 +293,10 @@ func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.Ma EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }) } @@ -313,10 +305,10 @@ func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.Ma EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Incomplete, + UpdatedAt: clock.Now(), }, }) @@ -347,10 +339,10 @@ func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.Ma EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -419,10 +411,10 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -461,7 +453,7 @@ func newTestContext(c NUDConfigurations) testContext { } type overflowOptions struct { - startAtEntryIndex int + startAtEntryIndex uint16 wantStaticEntries []NeighborEntry } @@ -500,12 +492,12 @@ func (c *testContext) overflowCache(opts overflowOptions) error { if !ok { return fmt.Errorf("got c.linkRes.entries.entry(%d) = _, false, want = true", i) } - durationReachableNanos := int64(c.linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds() + durationReachableNanos := time.Duration(c.linkRes.entries.size()-i-1) * typicalLatency wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: c.clock.NowNanoseconds() - durationReachableNanos, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: c.clock.Now().Add(-durationReachableNanos), } wantUnorderedEntries = append(wantUnorderedEntries, wantEntry) } @@ -571,10 +563,10 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: c.clock.Now(), }, }, } @@ -616,10 +608,10 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: 
c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -663,10 +655,10 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -689,10 +681,10 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -733,10 +725,10 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -758,10 +750,10 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -814,20 +806,20 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: c.clock.Now(), }, }, { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -844,10 +836,10 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -875,10 +867,10 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { t.Errorf("unexpected error from c.linkRes.neigh.entry(%s, \"\", nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), } if diff := cmp.Diff(want, e); diff != "" { t.Errorf("c.linkRes.neigh.entry(%s, \"\", nil) mismatch (-want, +got):\n%s", entry.Addr, diff) @@ -890,10 +882,10 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -910,10 +902,10 @@ func 
TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -947,10 +939,10 @@ func TestNeighborCacheClear(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + UpdatedAt: clock.Now(), }, }, } @@ -973,20 +965,20 @@ func TestNeighborCacheClear(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + UpdatedAt: clock.Now(), }, }, } @@ -1027,10 +1019,10 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: c.clock.Now(), }, }, } @@ -1062,13 +1054,13 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { clock := faketime.NewManualClock() linkRes := newTestNeighborResolver(&nudDisp, config, clock) - startedAt := clock.NowNanoseconds() + startedAt := clock.Now() // The following logic is very similar to overflowCache, but // periodically refreshes the frequently used entry. // Fill the neighbor cache to capacity - for i := 0; i < neighborCacheSize; i++ { + for i := uint16(0); i < neighborCacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) @@ -1084,7 +1076,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { } // Keep adding more entries - for i := neighborCacheSize; i < linkRes.entries.size(); i++ { + for i := uint16(neighborCacheSize); i < linkRes.entries.size(); i++ { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { if _, _, err := linkRes.neigh.entry(frequentlyUsedEntry.Addr, "", nil); err != nil { @@ -1118,7 +1110,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { State: Reachable, // Can be inferred since the frequently used entry is the first to // be created and transitioned to Reachable. 
- UpdatedAtNanos: startedAt + typicalLatency.Nanoseconds(), + UpdatedAt: startedAt.Add(typicalLatency), }, } @@ -1127,12 +1119,12 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) } - durationReachableNanos := int64(linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds() + durationReachableNanos := time.Duration(linkRes.entries.size()-i-1) * typicalLatency wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds() - durationReachableNanos, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now().Add(-durationReachableNanos), }) } @@ -1190,12 +1182,12 @@ func TestNeighborCacheConcurrent(t *testing.T) { if !ok { t.Errorf("got linkRes.entries.entry(%d) = _, false, want = true", i) } - durationReachableNanos := int64(linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds() + durationReachableNanos := time.Duration(linkRes.entries.size()-i-1) * typicalLatency wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds() - durationReachableNanos, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now().Add(-durationReachableNanos), }) } @@ -1244,10 +1236,10 @@ func TestNeighborCacheReplace(t *testing.T) { t.Fatalf("linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: updatedLinkAddr, - State: Delay, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Delay, + UpdatedAt: clock.Now(), } if diff := cmp.Diff(want, e); diff != "" { t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) @@ -1263,10 +1255,10 @@ func TestNeighborCacheReplace(t *testing.T) { t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: updatedLinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), } if diff := cmp.Diff(want, e); diff != "" { t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) @@ -1301,10 +1293,10 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), } if diff := cmp.Diff(want, got); diff != "" { t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) @@ -1405,10 +1397,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Incomplete, + UpdatedAt: clock.Now(), }, }, } @@ -1436,10 +1428,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Unreachable, - UpdatedAtNanos: 
clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Unreachable, + UpdatedAt: clock.Now(), }, }, } @@ -1455,10 +1447,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { { wantEntries := []NeighborEntry{ { - Addr: entry.Addr, - LinkAddr: "", - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Unreachable, + UpdatedAt: clock.Now(), }, } if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, unorderedEntriesDiffOpts()...); diff != "" { @@ -1488,10 +1480,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Incomplete, + UpdatedAt: clock.Now(), }, }, } @@ -1518,10 +1510,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1541,10 +1533,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), } if diff := cmp.Diff(gotEntry, wantEntry); diff != "" { t.Fatalf("neighbor entry mismatch (-got, +want):\n%s", diff) @@ -1561,9 +1553,9 @@ func BenchmarkCacheClear(b *testing.B) { linkRes.delay = 0 // Clear for every possible size of the cache - for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ { + for cacheSize := uint16(0); cacheSize < neighborCacheSize; cacheSize++ { // Fill the neighbor cache to capacity. - for i := 0; i < cacheSize; i++ { + for i := uint16(0); i < cacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { b.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 6d95e1664..0a59eecdd 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -31,10 +31,10 @@ const ( // NeighborEntry describes a neighboring device in the local network. type NeighborEntry struct { - Addr tcpip.Address - LinkAddr tcpip.LinkAddress - State NeighborState - UpdatedAtNanos int64 + Addr tcpip.Address + LinkAddr tcpip.LinkAddress + State NeighborState + UpdatedAt time.Time } // NeighborState defines the state of a NeighborEntry within the Neighbor @@ -138,10 +138,10 @@ func newNeighborEntry(cache *neighborCache, remoteAddr tcpip.Address, nudState * // calling `setStateLocked`. 
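The hunks above migrate NeighborEntry from an UpdatedAtNanos int64 (nanoseconds from Clock.NowNanoseconds) to an UpdatedAt time.Time taken from Clock.Now, so test expectations switch from int64 arithmetic to time.Time arithmetic. A minimal standard-library sketch of how the two representations relate; typicalLatency here is only an illustrative stand-in for the test constant:

package main

import (
	"fmt"
	"time"
)

func main() {
	// Old representation: nanoseconds since the Unix epoch, as returned by
	// Clock.NowNanoseconds().
	updatedAtNanos := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC).UnixNano()

	// New representation: a time.Time, as returned by Clock.Now().
	updatedAt := time.Unix(0, updatedAtNanos)

	// Expected timestamps now compose with time.Time methods instead of
	// int64 math: startedAt + typicalLatency.Nanoseconds() becomes
	// startedAt.Add(typicalLatency).
	typicalLatency := 500 * time.Millisecond
	fmt.Println(updatedAt.Add(typicalLatency).UnixNano() == updatedAtNanos+typicalLatency.Nanoseconds()) // true
}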
func newStaticNeighborEntry(cache *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { entry := NeighborEntry{ - Addr: addr, - LinkAddr: linkAddr, - State: Static, - UpdatedAtNanos: cache.nic.stack.clock.NowNanoseconds(), + Addr: addr, + LinkAddr: linkAddr, + State: Static, + UpdatedAt: cache.nic.stack.clock.Now(), } n := &neighborEntry{ cache: cache, @@ -166,14 +166,20 @@ func (e *neighborEntry) notifyCompletionLocked(err tcpip.Error) { if ch := e.mu.done; ch != nil { close(ch) e.mu.done = nil - // Dequeue the pending packets in a new goroutine to not hold up the current + // Dequeue the pending packets asynchronously to not hold up the current // goroutine as writing packets may be a costly operation. // // At the time of writing, when writing packets, a neighbor's link address // is resolved (which ends up obtaining the entry's lock) while holding the - // link resolution queue's lock. Dequeuing packets in a new goroutine avoids - // a lock ordering violation. - go e.cache.nic.linkResQueue.dequeue(ch, e.mu.neigh.LinkAddr, err) + // link resolution queue's lock. Dequeuing packets asynchronously avoids a + // lock ordering violation. + // + // NB: this is equivalent to spawning a goroutine directly using the go + // keyword but allows tests that use manual clocks to deterministically + // wait for this work to complete. + e.cache.nic.stack.clock.AfterFunc(0, func() { + e.cache.nic.linkResQueue.dequeue(ch, e.mu.neigh.LinkAddr, err) + }) } } @@ -224,7 +230,7 @@ func (e *neighborEntry) cancelTimerLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) removeLocked() { - e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.Now() e.dispatchRemoveEventLocked() e.cancelTimerLocked() // TODO(https://gvisor.dev/issues/5583): test the case where this function is @@ -246,7 +252,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { prev := e.mu.neigh.State e.mu.neigh.State = next - e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.Now() config := e.nudState.Config() switch next { @@ -307,7 +313,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // a shared lock. e.mu.timer = timer{ done: &done, - timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() { var err tcpip.Error = &tcpip.ErrTimeout{} if remaining != 0 { err = e.cache.linkRes.LinkAddressRequest(addr, "" /* localAddr */, linkAddr) @@ -354,14 +360,14 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { case Unknown, Unreachable: prev := e.mu.neigh.State e.mu.neigh.State = Incomplete - e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.Now() switch prev { case Unknown: e.dispatchAddEventLocked() case Unreachable: e.dispatchChangeEventLocked() - e.cache.nic.stats.Neighbor.UnreachableEntryLookups.Increment() + e.cache.nic.stats.neighbor.unreachableEntryLookups.Increment() } config := e.nudState.Config() @@ -378,7 +384,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // a shared lock. 
e.mu.timer = timer{ done: &done, - timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() { var err tcpip.Error = &tcpip.ErrTimeout{} if remaining != 0 { // As per RFC 4861 section 7.2.2: diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 1d39ee73d..59d86d6d4 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -36,11 +36,6 @@ const ( entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01") entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02") - - // entryTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, - // except where another value is explicitly used. It is chosen to match the - // MTU of loopback interfaces on Linux systems. - entryTestNetDefaultMTU = 65536 ) var ( @@ -196,13 +191,13 @@ func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.A // ResolveStaticAddress attempts to resolve address without sending requests. // It either resolves the name immediately or returns the empty LinkAddress. -func (r *entryTestLinkResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { +func (*entryTestLinkResolver) ResolveStaticAddress(tcpip.Address) (tcpip.LinkAddress, bool) { return "", false } // LinkAddressProtocol returns the network protocol of the addresses this // resolver can resolve. -func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return entryTestNetNumber } @@ -219,7 +214,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e nudConfigs: c, randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())), }, - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), } netEP := (&testIPv6Protocol{}).NewEndpoint(&nic, nil) nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ @@ -354,10 +349,10 @@ func unknownToIncomplete(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes * EventType: entryTestAdded, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + UpdatedAt: clock.Now(), }, }, } @@ -415,10 +410,10 @@ func unknownToStale(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entry EventType: entryTestAdded, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -446,7 +441,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { // UpdatedAt should remain the same during address resolution. 
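The notifyCompletionLocked and timer hunks above schedule asynchronous work through the stack clock (AfterFunc with a zero or immediate duration) instead of the go keyword, so tests driving a manual clock can wait for that work deterministically. Under the real clock the two forms are equivalent, as this standard-library sketch illustrates (illustrative only, not the gVisor clock API):

package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	var wg sync.WaitGroup
	wg.Add(1)

	// Equivalent to `go dequeue()` with the real clock: the callback runs
	// asynchronously almost immediately. A manual/fake clock can instead
	// hold the callback until the test explicitly runs scheduled jobs.
	time.AfterFunc(0, func() {
		defer wg.Done()
		fmt.Println("dequeue pending packets")
	})

	wg.Wait()
}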
e.mu.Lock() - startedAt := e.mu.neigh.UpdatedAtNanos + startedAt := e.mu.neigh.UpdatedAt e.mu.Unlock() // Wait for the rest of the reachability probe transmissions, signifying @@ -470,7 +465,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { } e.mu.Lock() - if got, want := e.mu.neigh.UpdatedAtNanos, startedAt; got != want { + if got, want := e.mu.neigh.UpdatedAt, startedAt; got != want { t.Errorf("got e.mu.neigh.UpdatedAt = %q, want = %q", got, want) } e.mu.Unlock() @@ -485,10 +480,10 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Unreachable, + UpdatedAt: clock.Now(), }, }, } @@ -547,10 +542,10 @@ func incompleteToReachableWithFlags(e *neighborEntry, nudDisp *testNUDDispatcher EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -644,10 +639,10 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -678,10 +673,10 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -757,10 +752,10 @@ func incompleteToUnreachable(c NUDConfigurations, e *neighborEntry, nudDisp *tes EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Unreachable, + UpdatedAt: clock.Now(), }, }, } @@ -943,10 +938,10 @@ func reachableToStale(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDis EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -998,10 +993,10 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1050,10 +1045,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: 
clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1102,10 +1097,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1191,10 +1186,10 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1243,10 +1238,10 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1284,10 +1279,10 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1332,10 +1327,10 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1391,10 +1386,10 @@ func staleToDelay(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTe EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + UpdatedAt: clock.Now(), }, }, } @@ -1443,10 +1438,10 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1498,10 +1493,10 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1553,10 +1548,10 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - 
State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1645,10 +1640,10 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1697,10 +1692,10 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1770,10 +1765,10 @@ func delayToProbe(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDispatc EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + UpdatedAt: clock.Now(), }, }, } @@ -1827,10 +1822,10 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1882,10 +1877,10 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -2003,10 +1998,10 @@ func probeToReachableWithFlags(e *neighborEntry, nudDisp *testNUDDispatcher, lin EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: linkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: linkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -2155,10 +2150,10 @@ func probeToUnreachable(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDD EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Unreachable, + UpdatedAt: clock.Now(), }, }, } @@ -2227,10 +2222,10 @@ func unreachableToIncomplete(e *neighborEntry, nudDisp *testNUDDispatcher, linkR EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + UpdatedAt: clock.Now(), }, }, } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index dbba2c79f..b854d868c 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -51,7 +51,7 @@ type nic struct { name 
string context NICContext - stats NICStats + stats sharedStats // The network endpoints themselves may be modified by calling the interface's // methods, but the map reference and entries must be constant. @@ -78,26 +78,13 @@ type nic struct { } } -// NICStats hold statistics for a NIC. -type NICStats struct { - Tx DirectionStats - Rx DirectionStats - - DisabledRx DirectionStats - - Neighbor NeighborStats -} - -func makeNICStats() NICStats { - var s NICStats - tcpip.InitStatCounters(reflect.ValueOf(&s).Elem()) - return s -} - -// DirectionStats includes packet and byte counts. -type DirectionStats struct { - Packets *tcpip.StatCounter - Bytes *tcpip.StatCounter +// makeNICStats initializes the NIC statistics and associates them to the global +// NIC statistics. +func makeNICStats(global tcpip.NICStats) sharedStats { + var stats sharedStats + tcpip.InitStatCounters(reflect.ValueOf(&stats.local).Elem()) + stats.init(&stats.local, &global) + return stats } type packetEndpointList struct { @@ -150,7 +137,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC id: id, name: name, context: ctx, - stats: makeNICStats(), + stats: makeNICStats(stack.Stats().NICs), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]*linkResolver), duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector), @@ -382,8 +369,8 @@ func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt return err } - n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + n.stats.tx.packets.Increment() + n.stats.tx.bytes.IncrementBy(uint64(numBytes)) return nil } @@ -399,13 +386,13 @@ func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pk } writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol) - n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) + n.stats.tx.packets.IncrementBy(uint64(writtenPackets)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { writtenBytes += pb.Size() } - n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) + n.stats.tx.bytes.IncrementBy(uint64(writtenBytes)) return writtenPackets, err } @@ -718,18 +705,18 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp if !enabled { n.mu.RUnlock() - n.stats.DisabledRx.Packets.Increment() - n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data().Size())) + n.stats.disabledRx.packets.Increment() + n.stats.disabledRx.bytes.IncrementBy(uint64(pkt.Data().Size())) return } - n.stats.Rx.Packets.Increment() - n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data().Size())) + n.stats.rx.packets.Increment() + n.stats.rx.bytes.IncrementBy(uint64(pkt.Data().Size())) networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { n.mu.RUnlock() - n.stack.stats.UnknownProtocolRcvdPackets.Increment() + n.stats.unknownL3ProtocolRcvdPackets.Increment() return } @@ -786,41 +773,35 @@ func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { - n.stack.stats.UnknownProtocolRcvdPackets.Increment() + n.stats.unknownL4ProtocolRcvdPackets.Increment() return TransportPacketProtocolUnreachable } transProto := state.proto - // Raw socket packets are delivered based solely on the transport - // 
protocol number. We do not inspect the payload to ensure it's - // validly formed. - n.stack.demux.deliverRawPacket(protocol, pkt) - // TransportHeader is empty only when pkt is an ICMP packet or was reassembled // from fragments. if pkt.TransportHeader().View().IsEmpty() { - // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader - // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a - // full explanation. + // ICMP packets don't have their TransportHeader fields set yet, parse it + // here. See icmp/protocol.go:protocol.Parse for a full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { // ICMP packets may be longer, but until icmp.Parse is implemented, here // we parse it using the minimum size. if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() // We consider a malformed transport packet handled because there is // nothing the caller can do. return TransportPacketHandled } } else if !transProto.Parse(pkt) { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled } } srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View()) if err != nil { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled } @@ -852,7 +833,7 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt // If it doesn't handle it then we should do so. switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res { case UnknownDestinationPacketMalformed: - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled case UnknownDestinationPacketUnhandled: return TransportPacketDestinationPortUnreachable @@ -891,6 +872,17 @@ func (n *nic) DeliverTransportError(local, remote tcpip.Address, net tcpip.Netwo } } +// DeliverRawPacket implements TransportDispatcher. +func (n *nic) DeliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) { + // For ICMPv4 only we validate the header length for compatibility with + // raw(7) ICMP_FILTER. The same check is made in Linux here: + // https://github.com/torvalds/linux/blob/70585216/net/ipv4/raw.c#L189. + if protocol == header.ICMPv4ProtocolNumber && pkt.TransportHeader().View().Size()+pkt.Data().Size() < header.ICMPv4MinimumSize { + return + } + n.stack.demux.deliverRawPacket(protocol, pkt) +} + // ID implements NetworkInterface. func (n *nic) ID() tcpip.NICID { return n.id diff --git a/pkg/tcpip/stack/nic_stats.go b/pkg/tcpip/stack/nic_stats.go new file mode 100644 index 000000000..1773d5e8d --- /dev/null +++ b/pkg/tcpip/stack/nic_stats.go @@ -0,0 +1,74 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package stack + +import ( + "gvisor.dev/gvisor/pkg/tcpip" +) + +type sharedStats struct { + local tcpip.NICStats + multiCounterNICStats +} + +// LINT.IfChange(multiCounterNICPacketStats) + +type multiCounterNICPacketStats struct { + packets tcpip.MultiCounterStat + bytes tcpip.MultiCounterStat +} + +func (m *multiCounterNICPacketStats) init(a, b *tcpip.NICPacketStats) { + m.packets.Init(a.Packets, b.Packets) + m.bytes.Init(a.Bytes, b.Bytes) +} + +// LINT.ThenChange(../../tcpip.go:NICPacketStats) + +// LINT.IfChange(multiCounterNICNeighborStats) + +type multiCounterNICNeighborStats struct { + unreachableEntryLookups tcpip.MultiCounterStat +} + +func (m *multiCounterNICNeighborStats) init(a, b *tcpip.NICNeighborStats) { + m.unreachableEntryLookups.Init(a.UnreachableEntryLookups, b.UnreachableEntryLookups) +} + +// LINT.ThenChange(../../tcpip.go:NICNeighborStats) + +// LINT.IfChange(multiCounterNICStats) + +type multiCounterNICStats struct { + unknownL3ProtocolRcvdPackets tcpip.MultiCounterStat + unknownL4ProtocolRcvdPackets tcpip.MultiCounterStat + malformedL4RcvdPackets tcpip.MultiCounterStat + tx multiCounterNICPacketStats + rx multiCounterNICPacketStats + disabledRx multiCounterNICPacketStats + neighbor multiCounterNICNeighborStats +} + +func (m *multiCounterNICStats) init(a, b *tcpip.NICStats) { + m.unknownL3ProtocolRcvdPackets.Init(a.UnknownL3ProtocolRcvdPackets, b.UnknownL3ProtocolRcvdPackets) + m.unknownL4ProtocolRcvdPackets.Init(a.UnknownL4ProtocolRcvdPackets, b.UnknownL4ProtocolRcvdPackets) + m.malformedL4RcvdPackets.Init(a.MalformedL4RcvdPackets, b.MalformedL4RcvdPackets) + m.tx.init(&a.Tx, &b.Tx) + m.rx.init(&a.Rx, &b.Rx) + m.disabledRx.init(&a.DisabledRx, &b.DisabledRx) + m.neighbor.init(&a.Neighbor, &b.Neighbor) +} + +// LINT.ThenChange(../../tcpip.go:NICStats) diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 8a3005295..5cb342f78 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -15,11 +15,13 @@ package stack import ( + "reflect" "testing" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) var _ AddressableEndpoint = (*testIPv6Endpoint)(nil) @@ -171,19 +173,19 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. 
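The new nic_stats.go above links each per-NIC counter to its stack-global counterpart through tcpip.MultiCounterStat, so a single Increment updates both the NIC-local copy and the global NICStats. A self-contained sketch of the same idea; counter and multiCounter are hypothetical stand-ins for the real tcpip types:

package main

import "fmt"

// counter is a stand-in for tcpip.StatCounter.
type counter struct{ v uint64 }

func (c *counter) Increment() { c.v++ }

// multiCounter mirrors the idea behind tcpip.MultiCounterStat: one Increment
// updates every linked counter (here, a NIC-local and a stack-global one).
type multiCounter struct{ linked []*counter }

func (m *multiCounter) Init(cs ...*counter) { m.linked = cs }

func (m *multiCounter) Increment() {
	for _, c := range m.linked {
		c.Increment()
	}
}

func main() {
	var local, global counter
	var rx multiCounter
	rx.Init(&local, &global)

	rx.Increment()                 // e.g. one received packet
	fmt.Println(local.v, global.v) // 1 1
}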
nic := nic{ - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), } - if got := nic.stats.DisabledRx.Packets.Value(); got != 0 { + if got := nic.stats.local.DisabledRx.Packets.Value(); got != 0 { t.Errorf("got DisabledRx.Packets = %d, want = 0", got) } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 { + if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 0 { t.Errorf("got DisabledRx.Bytes = %d, want = 0", got) } - if got := nic.stats.Rx.Packets.Value(); got != 0 { + if got := nic.stats.local.Rx.Packets.Value(); got != 0 { t.Errorf("got Rx.Packets = %d, want = 0", got) } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { + if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { t.Errorf("got Rx.Bytes = %d, want = 0", got) } @@ -195,16 +197,28 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(), })) - if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { + if got := nic.stats.local.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 { + if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 4 { t.Errorf("got DisabledRx.Bytes = %d, want = 4", got) } - if got := nic.stats.Rx.Packets.Value(); got != 0 { + if got := nic.stats.local.Rx.Packets.Value(); got != 0 { t.Errorf("got Rx.Packets = %d, want = 0", got) } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { + if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { t.Errorf("got Rx.Bytes = %d, want = 0", got) } } + +func TestMultiCounterStatsInitialization(t *testing.T) { + global := tcpip.NICStats{}.FillIn() + nic := nic{ + stats: makeNICStats(global), + } + multi := nic.stats.multiCounterNICStats + local := nic.stats.local + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&multi).Elem(), []reflect.Value{reflect.ValueOf(&local).Elem(), reflect.ValueOf(&global).Elem()}); err != nil { + t.Error(err) + } +} diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index 5a94e9ac6..ca9822bca 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -16,6 +16,7 @@ package stack import ( "math" + "math/rand" "sync" "time" @@ -313,45 +314,36 @@ func calcMaxRandomFactor(minRandomFactor float32) float32 { return defaultMaxRandomFactor } -// A Rand is a source of random numbers. -type Rand interface { - // Float32 returns, as a float32, a pseudo-random number in [0.0,1.0). - Float32() float32 -} - // NUDState stores states needed for calculating reachable time. type NUDState struct { - rng Rand + clock tcpip.Clock + rng *rand.Rand - // mu protects the fields below. - // - // It is necessary for NUDState to handle its own locking since neighbor - // entries may access the NUD state from within the goroutine spawned by - // time.AfterFunc(). This goroutine may run concurrently with the main - // process for controlling the neighbor cache and would otherwise introduce - // race conditions if NUDState was not locked properly. - mu sync.RWMutex + mu struct { + sync.RWMutex - config NUDConfigurations + config NUDConfigurations - // reachableTime is the duration to wait for a REACHABLE entry to - // transition into STALE after inactivity. This value is calculated with - // the algorithm defined in RFC 4861 section 6.3.2. - reachableTime time.Duration + // reachableTime is the duration to wait for a REACHABLE entry to + // transition into STALE after inactivity. 
This value is calculated with + // the algorithm defined in RFC 4861 section 6.3.2. + reachableTime time.Duration - expiration time.Time - prevBaseReachableTime time.Duration - prevMinRandomFactor float32 - prevMaxRandomFactor float32 + expiration time.Time + prevBaseReachableTime time.Duration + prevMinRandomFactor float32 + prevMaxRandomFactor float32 + } } // NewNUDState returns new NUDState using c as configuration and the specified // random number generator for use in recomputing ReachableTime. -func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { +func NewNUDState(c NUDConfigurations, clock tcpip.Clock, rng *rand.Rand) *NUDState { s := &NUDState{ - rng: rng, + clock: clock, + rng: rng, } - s.config = c + s.mu.config = c return s } @@ -359,14 +351,14 @@ func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { func (s *NUDState) Config() NUDConfigurations { s.mu.RLock() defer s.mu.RUnlock() - return s.config + return s.mu.config } // SetConfig replaces the existing NUD configurations with c. func (s *NUDState) SetConfig(c NUDConfigurations) { s.mu.Lock() defer s.mu.Unlock() - s.config = c + s.mu.config = c } // ReachableTime returns the duration to wait for a REACHABLE entry to @@ -377,13 +369,13 @@ func (s *NUDState) ReachableTime() time.Duration { s.mu.Lock() defer s.mu.Unlock() - if time.Now().After(s.expiration) || - s.config.BaseReachableTime != s.prevBaseReachableTime || - s.config.MinRandomFactor != s.prevMinRandomFactor || - s.config.MaxRandomFactor != s.prevMaxRandomFactor { + if s.clock.Now().After(s.mu.expiration) || + s.mu.config.BaseReachableTime != s.mu.prevBaseReachableTime || + s.mu.config.MinRandomFactor != s.mu.prevMinRandomFactor || + s.mu.config.MaxRandomFactor != s.mu.prevMaxRandomFactor { s.recomputeReachableTimeLocked() } - return s.reachableTime + return s.mu.reachableTime } // recomputeReachableTimeLocked forces a recalculation of ReachableTime using @@ -408,23 +400,23 @@ func (s *NUDState) ReachableTime() time.Duration { // // s.mu MUST be locked for writing. func (s *NUDState) recomputeReachableTimeLocked() { - s.prevBaseReachableTime = s.config.BaseReachableTime - s.prevMinRandomFactor = s.config.MinRandomFactor - s.prevMaxRandomFactor = s.config.MaxRandomFactor + s.mu.prevBaseReachableTime = s.mu.config.BaseReachableTime + s.mu.prevMinRandomFactor = s.mu.config.MinRandomFactor + s.mu.prevMaxRandomFactor = s.mu.config.MaxRandomFactor - randomFactor := s.config.MinRandomFactor + s.rng.Float32()*(s.config.MaxRandomFactor-s.config.MinRandomFactor) + randomFactor := s.mu.config.MinRandomFactor + s.rng.Float32()*(s.mu.config.MaxRandomFactor-s.mu.config.MinRandomFactor) // Check for overflow, given that minRandomFactor and maxRandomFactor are // guaranteed to be positive numbers. - if float32(math.MaxInt64)/randomFactor < float32(s.config.BaseReachableTime) { - s.reachableTime = time.Duration(math.MaxInt64) + if math.MaxInt64/randomFactor < float32(s.mu.config.BaseReachableTime) { + s.mu.reachableTime = time.Duration(math.MaxInt64) } else if randomFactor == 1 { // Avoid loss of precision when a large base reachable time is used. 
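recomputeReachableTimeLocked above follows the RFC 4861 section 6.3.2 rule: ReachableTime is BaseReachableTime scaled by a uniform random factor in [MinRandomFactor, MaxRandomFactor), recomputed when the configuration changes or the cached value expires. A simplified sketch that omits the overflow and precision guards shown in the hunk:

package main

import (
	"fmt"
	"math/rand"
	"time"
)

// reachableTime computes the RFC 4861 section 6.3.2 value:
// BaseReachableTime * uniform(minFactor, maxFactor).
func reachableTime(base time.Duration, minFactor, maxFactor float32, rng *rand.Rand) time.Duration {
	f := minFactor + rng.Float32()*(maxFactor-minFactor)
	return time.Duration(float32(base) * f)
}

func main() {
	rng := rand.New(rand.NewSource(1))
	// With the default 30s base and factors in [0.5, 1.5), the result lies
	// in [15s, 45s).
	fmt.Println(reachableTime(30*time.Second, 0.5, 1.5, rng))
}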
- s.reachableTime = s.config.BaseReachableTime + s.mu.reachableTime = s.mu.config.BaseReachableTime } else { - reachableTime := int64(float32(s.config.BaseReachableTime) * randomFactor) - s.reachableTime = time.Duration(reachableTime) + reachableTime := int64(float32(s.mu.config.BaseReachableTime) * randomFactor) + s.mu.reachableTime = time.Duration(reachableTime) } - s.expiration = time.Now().Add(2 * time.Hour) + s.mu.expiration = s.clock.Now().Add(2 * time.Hour) } diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index e1253f310..1aeb2f8a5 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -16,6 +16,7 @@ package stack_test import ( "math" + "math/rand" "testing" "time" @@ -28,17 +29,15 @@ import ( ) const ( - defaultBaseReachableTime = 30 * time.Second - minimumBaseReachableTime = time.Millisecond - defaultMinRandomFactor = 0.5 - defaultMaxRandomFactor = 1.5 - defaultRetransmitTimer = time.Second - minimumRetransmitTimer = time.Millisecond - defaultDelayFirstProbeTime = 5 * time.Second - defaultMaxMulticastProbes = 3 - defaultMaxUnicastProbes = 3 - defaultMaxAnycastDelayTime = time.Second - defaultMaxReachbilityConfirmations = 3 + defaultBaseReachableTime = 30 * time.Second + minimumBaseReachableTime = time.Millisecond + defaultMinRandomFactor = 0.5 + defaultMaxRandomFactor = 1.5 + defaultRetransmitTimer = time.Second + minimumRetransmitTimer = time.Millisecond + defaultDelayFirstProbeTime = 5 * time.Second + defaultMaxMulticastProbes = 3 + defaultMaxUnicastProbes = 3 defaultFakeRandomNum = 0.5 ) @@ -48,12 +47,14 @@ type fakeRand struct { num float32 } -var _ stack.Rand = (*fakeRand)(nil) +var _ rand.Source = (*fakeRand)(nil) -func (f *fakeRand) Float32() float32 { - return f.num +func (f *fakeRand) Int63() int64 { + return int64(f.num * float32(1<<63)) } +func (*fakeRand) Seed(int64) {} + func TestNUDFunctions(t *testing.T) { const nicID = 1 @@ -169,7 +170,7 @@ func TestNUDFunctions(t *testing.T) { t.Errorf("s.Neigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) } else if test.expectedErr == nil { if diff := cmp.Diff( - []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, + []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAt: clock.Now()}}, neighbors, ); diff != "" { t.Errorf("neighbors mismatch (-want +got):\n%s", diff) @@ -710,7 +711,8 @@ func TestNUDStateReachableTime(t *testing.T) { rng := fakeRand{ num: defaultFakeRandomNum, } - s := stack.NewNUDState(c, &rng) + var clock faketime.NullClock + s := stack.NewNUDState(c, &clock, rand.New(&rng)) if got, want := s.ReachableTime(), test.want; got != want { t.Errorf("got ReachableTime = %q, want = %q", got, want) } @@ -782,7 +784,8 @@ func TestNUDStateRecomputeReachableTime(t *testing.T) { rng := fakeRand{ num: defaultFakeRandomNum, } - s := stack.NewNUDState(c, &rng) + var clock faketime.NullClock + s := stack.NewNUDState(c, &clock, rand.New(&rng)) old := s.ReachableTime() if got, want := s.ReachableTime(), old; got != want { diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 01652fbe7..9192d8433 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -134,7 +134,7 @@ type PacketBuffer struct { // https://www.man7.org/linux/man-pages/man7/packet.7.html. PktType tcpip.PacketType - // NICID is the ID of the interface the network packet was received at. 
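With NUDState now taking a *rand.Rand, the NUD test's fakeRand above models a fixed random value as a rand.Source: Int63 returns num scaled by 2^63 so that rand.New(source).Float32() comes back as (approximately) num. A standalone sketch of that trick, with fixedSource as a hypothetical name:

package main

import (
	"fmt"
	"math/rand"
)

// fixedSource always yields the same Int63 value, chosen so that
// rand.Rand.Float32 returns a predictable number.
type fixedSource struct{ f float32 }

func (s fixedSource) Int63() int64 { return int64(s.f * float32(1<<63)) }
func (fixedSource) Seed(int64)     {}

func main() {
	rng := rand.New(fixedSource{f: 0.5})
	fmt.Println(rng.Float32()) // 0.5
}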
+ // NICID is the ID of the last interface the network packet was handled at. NICID tcpip.NICID // RXTransportChecksumValidated indicates that transport checksum verification @@ -245,10 +245,10 @@ func (pk *PacketBuffer) dataOffset() int { func (pk *PacketBuffer) push(typ headerType, size int) tcpipbuffer.View { h := &pk.headers[typ] if h.length > 0 { - panic(fmt.Sprintf("push must not be called twice: type %s", typ)) + panic(fmt.Sprintf("push(%s, %d) called after previous push", typ, size)) } if pk.pushed+size > pk.reserved { - panic("not enough headroom reserved") + panic(fmt.Sprintf("push(%s, %d) overflows; pushed=%d reserved=%d", typ, size, pk.pushed, pk.reserved)) } pk.pushed += size h.offset = -pk.pushed diff --git a/pkg/tcpip/stack/rand.go b/pkg/tcpip/stack/rand.go index 421fb5c15..c8294eb6e 100644 --- a/pkg/tcpip/stack/rand.go +++ b/pkg/tcpip/stack/rand.go @@ -15,7 +15,7 @@ package stack import ( - mathrand "math/rand" + "math/rand" "gvisor.dev/gvisor/pkg/sync" ) @@ -23,7 +23,7 @@ import ( // lockedRandomSource provides a threadsafe rand.Source. type lockedRandomSource struct { mu sync.Mutex - src mathrand.Source + src rand.Source } func (r *lockedRandomSource) Int63() (n int64) { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 85bb87b4b..dfe2c886f 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -265,6 +265,11 @@ type TransportDispatcher interface { // // DeliverTransportError takes ownership of the packet buffer. DeliverTransportError(local, remote tcpip.Address, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, _ TransportError, _ *PacketBuffer) + + // DeliverRawPacket delivers a packet to any subscribed raw sockets. + // + // DeliverRawPacket does NOT take ownership of the packet buffer. + DeliverRawPacket(tcpip.TransportProtocolNumber, *PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -420,7 +425,7 @@ const ( PermanentExpired // Temporary is an endpoint, created on a one-off basis to temporarily - // consider the NIC bound an an address that it is not explictiy bound to + // consider the NIC bound an an address that it is not explicitly bound to // (such as a permanent address). Its reference count must not be biased by 1 // so that the address is removed immediately when references to it are no // longer held. @@ -630,7 +635,7 @@ type NetworkEndpoint interface { // HandlePacket takes ownership of pkt. HandlePacket(pkt *PacketBuffer) - // Close is called when the endpoint is reomved from a stack. + // Close is called when the endpoint is removed from a stack. Close() // NetworkProtocolNumber returns the tcpip.NetworkProtocolNumber for @@ -968,7 +973,7 @@ type DuplicateAddressDetector interface { // called with the result of the original DAD request. CheckDuplicateAddress(tcpip.Address, DADCompletionHandler) DADCheckAddressDisposition - // SetDADConfiguations sets the configurations for DAD. + // SetDADConfigurations sets the configurations for DAD. SetDADConfigurations(c DADConfigurations) // DuplicateAddressProtocol returns the network protocol the receiver can @@ -979,7 +984,7 @@ type DuplicateAddressDetector interface { // LinkAddressResolver handles link address resolution for a network protocol. type LinkAddressResolver interface { // LinkAddressRequest sends a request for the link address of the target - // address. The request is broadcasted on the local network if a remote link + // address. 
The request is broadcast on the local network if a remote link // address is not provided. LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error @@ -1072,4 +1077,4 @@ type GSOEndpoint interface { // SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment. // This isn't a hard limit, because it is never set into packet headers. -const SoftwareGSOMaxSize = (1 << 16) +const SoftwareGSOMaxSize = 1 << 16 diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 8814f45a6..81fabe29a 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -20,17 +20,16 @@ package stack import ( - "bytes" "encoding/binary" "fmt" "io" - mathrand "math/rand" + "math/rand" "sync/atomic" "time" "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/atomicbitops" - "gvisor.dev/gvisor/pkg/rand" + cryptorand "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -40,13 +39,6 @@ import ( ) const ( - // ageLimit is set to the same cache stale time used in Linux. - ageLimit = 1 * time.Minute - // resolutionTimeout is set to the same ARP timeout used in Linux. - resolutionTimeout = 1 * time.Second - // resolutionAttempts is set to the same ARP retries used in Linux. - resolutionAttempts = 3 - // DefaultTOS is the default type of service value for network endpoints. DefaultTOS = 0 ) @@ -116,7 +108,7 @@ type Stack struct { handleLocal bool // tables are the iptables packet filtering and manipulation rules. - // TODO(gvisor.dev/issue/170): S/R this field. + // TODO(gvisor.dev/issue/4595): S/R this field. tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the @@ -145,7 +137,7 @@ type Stack struct { // randomGenerator is an injectable pseudo random generator that can be // used when a random number is required. - randomGenerator *mathrand.Rand + randomGenerator *rand.Rand // secureRNG is a cryptographically secure random number generator. secureRNG io.Reader @@ -196,9 +188,9 @@ type Options struct { // TransportProtocols lists the transport protocols to enable. TransportProtocols []TransportProtocolFactory - // Clock is an optional clock source used for timestampping packets. + // Clock is an optional clock used for timekeeping. // - // If no Clock is specified, the clock source will be time.Now. + // If Clock is nil, tcpip.NewStdClock() will be used. Clock tcpip.Clock // Stats are optional statistic counters. @@ -225,15 +217,21 @@ type Options struct { // RandSource is an optional source to use to generate random // numbers. If omitted it defaults to a Source seeded by the data - // returned by rand.Read(). + // returned by the stack secure RNG. // // RandSource must be thread-safe. - RandSource mathrand.Source + RandSource rand.Source - // IPTables are the initial iptables rules. If nil, iptables will allow + // IPTables are the initial iptables rules. If nil, DefaultIPTables will be + // used to construct the initial iptables rules. // all traffic. IPTables *IPTables + // DefaultIPTables is an optional iptables rules constructor that is called + // if IPTables is nil. If both fields are nil, iptables will allow all + // traffic. + DefaultIPTables func(uint32) *IPTables + // SecureRNG is a cryptographically secure random number generator. 
SecureRNG io.Reader } @@ -331,23 +329,32 @@ func New(opts Options) *Stack { opts.UniqueID = new(uniqueIDGenerator) } + if opts.SecureRNG == nil { + opts.SecureRNG = cryptorand.Reader + } + randSrc := opts.RandSource if randSrc == nil { - // Source provided by mathrand.NewSource is not thread-safe so + var v int64 + if err := binary.Read(opts.SecureRNG, binary.LittleEndian, &v); err != nil { + panic(err) + } + // Source provided by rand.NewSource is not thread-safe so // we wrap it in a simple thread-safe version. - randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())} + randSrc = &lockedRandomSource{src: rand.NewSource(v)} } + randomGenerator := rand.New(randSrc) + seed := randomGenerator.Uint32() if opts.IPTables == nil { - opts.IPTables = DefaultTables() + if opts.DefaultIPTables == nil { + opts.DefaultIPTables = DefaultTables + } + opts.IPTables = opts.DefaultIPTables(seed) } opts.NUDConfigs.resetInvalidFields() - if opts.SecureRNG == nil { - opts.SecureRNG = rand.Reader - } - s := &Stack{ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), @@ -360,11 +367,11 @@ func New(opts Options) *Stack { handleLocal: opts.HandleLocal, tables: opts.IPTables, icmpRateLimiter: NewICMPRateLimiter(), - seed: generateRandUint32(), + seed: seed, nudConfigs: opts.NUDConfigs, uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, - randomGenerator: mathrand.New(randSrc), + randomGenerator: randomGenerator, secureRNG: opts.SecureRNG, sendBufferSize: tcpip.SendBufferSizeOption{ Min: MinBufferSize, @@ -804,7 +811,7 @@ type NICInfo struct { // MTU is the maximum transmission unit. MTU uint32 - Stats NICStats + Stats tcpip.NICStats // NetworkStats holds the stats of each NetworkEndpoint bound to the NIC. NetworkStats map[tcpip.NetworkProtocolNumber]NetworkEndpointStats @@ -856,7 +863,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { ProtocolAddresses: nic.primaryAddresses(), Flags: flags, MTU: nic.LinkEndpoint.MTU(), - Stats: nic.stats, + Stats: nic.stats.local, NetworkStats: netStats, Context: nic.context, ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(), @@ -1819,7 +1826,7 @@ func (s *Stack) Seed() uint32 { // Rand returns a reference to a pseudo random generator that can be used // to generate random numbers as required. -func (s *Stack) Rand() *mathrand.Rand { +func (s *Stack) Rand() *rand.Rand { return s.randomGenerator } @@ -1829,27 +1836,6 @@ func (s *Stack) SecureRNG() io.Reader { return s.secureRNG } -func generateRandUint32() uint32 { - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(err) - } - return binary.LittleEndian.Uint32(b) -} - -func generateRandInt64() int64 { - b := make([]byte, 8) - if _, err := rand.Read(b); err != nil { - panic(err) - } - buf := bytes.NewReader(b) - var v int64 - if err := binary.Read(buf, binary.LittleEndian, &v); err != nil { - panic(err) - } - return v -} - // FindNICNameFromID returns the name of the NIC for the given NICID. func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { s.mu.RLock() @@ -1886,9 +1872,8 @@ const ( // ParsePacketBufferTransport parses the provided packet buffer's transport // header. func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult { - // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader - // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a - // full explanation. 
+ // ICMP packets don't have their TransportHeader fields set yet, parse it + // here. See icmp/protocol.go:protocol.Parse for a full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { return ParsedOK } diff --git a/pkg/tcpip/stack/stack_global_state.go b/pkg/tcpip/stack/stack_global_state.go index 33824afd0..dfec4258a 100644 --- a/pkg/tcpip/stack/stack_global_state.go +++ b/pkg/tcpip/stack/stack_global_state.go @@ -14,78 +14,6 @@ package stack -import "time" - // StackFromEnv is the global stack created in restore run. // FIXME(b/36201077) var StackFromEnv *Stack - -// saveT is invoked by stateify. -func (t *TCPCubicState) saveT() unixTime { - return unixTime{t.T.Unix(), t.T.UnixNano()} -} - -// loadT is invoked by stateify. -func (t *TCPCubicState) loadT(unix unixTime) { - t.T = time.Unix(unix.second, unix.nano) -} - -// saveXmitTime is invoked by stateify. -func (t *TCPRACKState) saveXmitTime() unixTime { - return unixTime{t.XmitTime.Unix(), t.XmitTime.UnixNano()} -} - -// loadXmitTime is invoked by stateify. -func (t *TCPRACKState) loadXmitTime(unix unixTime) { - t.XmitTime = time.Unix(unix.second, unix.nano) -} - -// saveLastSendTime is invoked by stateify. -func (t *TCPSenderState) saveLastSendTime() unixTime { - return unixTime{t.LastSendTime.Unix(), t.LastSendTime.UnixNano()} -} - -// loadLastSendTime is invoked by stateify. -func (t *TCPSenderState) loadLastSendTime(unix unixTime) { - t.LastSendTime = time.Unix(unix.second, unix.nano) -} - -// saveRTTMeasureTime is invoked by stateify. -func (t *TCPSenderState) saveRTTMeasureTime() unixTime { - return unixTime{t.RTTMeasureTime.Unix(), t.RTTMeasureTime.UnixNano()} -} - -// loadRTTMeasureTime is invoked by stateify. -func (t *TCPSenderState) loadRTTMeasureTime(unix unixTime) { - t.RTTMeasureTime = time.Unix(unix.second, unix.nano) -} - -// saveMeasureTime is invoked by stateify. -func (r *RcvBufAutoTuneParams) saveMeasureTime() unixTime { - return unixTime{r.MeasureTime.Unix(), r.MeasureTime.UnixNano()} -} - -// loadMeasureTime is invoked by stateify. -func (r *RcvBufAutoTuneParams) loadMeasureTime(unix unixTime) { - r.MeasureTime = time.Unix(unix.second, unix.nano) -} - -// saveRTTMeasureTime is invoked by stateify. -func (r *RcvBufAutoTuneParams) saveRTTMeasureTime() unixTime { - return unixTime{r.RTTMeasureTime.Unix(), r.RTTMeasureTime.UnixNano()} -} - -// loadRTTMeasureTime is invoked by stateify. -func (r *RcvBufAutoTuneParams) loadRTTMeasureTime(unix unixTime) { - r.RTTMeasureTime = time.Unix(unix.second, unix.nano) -} - -// saveSegTime is invoked by stateify. -func (t *TCPEndpointState) saveSegTime() unixTime { - return unixTime{t.SegTime.Unix(), t.SegTime.UnixNano()} -} - -// loadSegTime is invoked by stateify. 
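In stack.New above, the default RandSource is now seeded by reading an int64 from the stack's secure RNG (replacing generateRandInt64), and the same generator produces the seed passed to DefaultIPTables. A minimal sketch of that seeding pattern, using the standard crypto/rand reader in place of gVisor's pkg/rand:

package main

import (
	cryptorand "crypto/rand"
	"encoding/binary"
	"fmt"
	"math/rand"
)

func main() {
	// Derive a non-cryptographic math/rand source from a secure reader.
	var v int64
	if err := binary.Read(cryptorand.Reader, binary.LittleEndian, &v); err != nil {
		panic(err)
	}
	rng := rand.New(rand.NewSource(v))

	// A value like this is what gets handed to the iptables constructor as
	// its seed in the rewritten New().
	fmt.Println(rng.Uint32())
}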
-func (t *TCPEndpointState) loadSegTime(unix unixTime) { - t.SegTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 02d54d29b..21951d05a 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -166,10 +166,6 @@ func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fakeNetHeaderLen } -func (*fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -197,11 +193,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, params stack.NetworkHe } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { +func (*fakeNetworkEndpoint) WritePackets(*stack.Route, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -463,14 +459,14 @@ func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer } } -func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { +func testFailingSend(t *testing.T, r *stack.Route, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) } } -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := sendTo(s, addr, payload); gotErr != wantErr { t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) @@ -920,15 +916,15 @@ func TestRouteWithDownNIC(t *testing.T) { if err := test.downFn(s, nicID1); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID1, err) } - testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) testSend(t, r2, ep2, buf) // Writes with Routes that use NIC2 after being brought down should fail. 
if err := test.downFn(s, nicID2); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID2, err) } - testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) - testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) if upFn := test.upFn; upFn != nil { // Writes with Routes that use NIC1 after being brought up should @@ -941,7 +937,7 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("test.upFn(_, %d): %s", nicID1, err) } testSend(t, r1, ep1, buf) - testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) } }) } @@ -1066,7 +1062,7 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. err := s.RemoveAddress(1, localAddr) @@ -1118,8 +1114,8 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. { @@ -1140,7 +1136,7 @@ func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.A // No address given, verify that there is no address assigned to the NIC. for _, a := range info.ProtocolAddresses { if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) { - t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{})) + t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, tcpip.AddressWithPrefix{}) } } return @@ -1220,7 +1216,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 2. Add Address, everything should work. @@ -1248,7 +1244,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 4. Add Address back, everything should work again. @@ -1287,8 +1283,8 @@ func TestEndpointExpiration(t *testing.T) { testSend(t, r, ep, nil) testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 7. Add Address back, everything should work again. @@ -1324,7 +1320,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. 
// testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } }) } @@ -1574,7 +1570,7 @@ func TestSpoofingNoAddress(t *testing.T) { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, dstAddr, nil, &tcpip.ErrNoRoute{}) // With address spoofing enabled, FindRoute permits any address to be used // as the source. @@ -1615,7 +1611,7 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } } - protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} + protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}} if err := s.AddProtocolAddress(1, protoAddr); err != nil { t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) } @@ -1641,12 +1637,12 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } func TestOutgoingBroadcastWithRouteTable(t *testing.T) { - defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} + defaultAddr := tcpip.AddressWithPrefix{Address: header.IPv4Any} // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. - nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} + nic1Addr := tcpip.AddressWithPrefix{Address: "\xc0\xa8\x01\x3a", PrefixLen: 24} nic1Gateway := testutil.MustParse4("192.168.1.1") // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. - nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} + nic2Addr := tcpip.AddressWithPrefix{Address: "\x0a\x0a\x0a\x05", PrefixLen: 24} nic2Gateway := testutil.MustParse4("10.10.10.1") // Create a new stack with two NICs. @@ -1660,12 +1656,12 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err := s.CreateNIC(2, ep); err != nil { t.Fatalf("CreateNIC failed: %s", err) } - nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} + nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr} if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) } - nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} + nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr} if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) } @@ -1709,7 +1705,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // 2. Case: Having an explicit route for broadcast will select that one. 
rt = append( []tcpip.Route{ - {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, + {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, }, rt..., ) @@ -2049,7 +2045,7 @@ func TestAddAddress(t *testing.T) { } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, }) } @@ -2113,7 +2109,7 @@ func TestAddAddressWithOptions(t *testing.T) { } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, }) } } @@ -2234,7 +2230,7 @@ func TestCreateNICWithOptions(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { s := stack.New(stack.Options{}) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00") for _, call := range test.calls { if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want { t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want) @@ -2248,46 +2244,87 @@ func TestNICStats(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed: ", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + + nics := []struct { + addr tcpip.Address + txByteCount int + rxByteCount int + }{ + { + addr: "\x01", + txByteCount: 30, + rxByteCount: 10, + }, + { + addr: "\x02", + txByteCount: 50, + rxByteCount: 20, + }, } - // Route all packets for address \x01 to NIC 1. - { - subnet, err := tcpip.NewSubnet("\x01", "\xff") - if err != nil { - t.Fatal(err) + + var txBytesTotal, rxBytesTotal, txPacketsTotal, rxPacketsTotal int + for i, nic := range nics { + nicid := tcpip.NICID(i) + ep := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { + t.Fatal("CreateNIC failed: ", err) + } + if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil { + t.Fatal("AddAddress failed:", err) } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - // Send a packet to address 1. - buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) - } + { + subnet, err := tcpip.NewSubnet(nic.addr, "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicid}}) + } + + nicStats := s.NICInfo()[nicid].Stats - if got, want := s.NICInfo()[1].Stats.Rx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) + // Inbound packet. 
+ rxBuffer := buffer.NewView(nic.rxByteCount) + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: rxBuffer.ToVectorisedView(), + })) + if got, want := nicStats.Rx.Packets.Value(), uint64(1); got != want { + t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) + } + if got, want := nicStats.Rx.Bytes.Value(), uint64(nic.rxByteCount); got != want { + t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) + } + rxPacketsTotal++ + rxBytesTotal += nic.rxByteCount + + // Outbound packet. + txBuffer := buffer.NewView(nic.txByteCount) + actualTxLength := nic.txByteCount + fakeNetHeaderLen + if err := sendTo(s, nic.addr, txBuffer); err != nil { + t.Fatal("sendTo failed: ", err) + } + want := ep.Drain() + if got := nicStats.Tx.Packets.Value(); got != uint64(want) { + t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) + } + if got, want := nicStats.Tx.Bytes.Value(), uint64(actualTxLength); got != want { + t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + } + txPacketsTotal += want + txBytesTotal += actualTxLength } - payload := buffer.NewView(10) - // Write a packet out via the address for NIC 1 - if err := sendTo(s, "\x01", payload); err != nil { - t.Fatal("sendTo failed: ", err) + // Now verify that each NIC stats was correctly aggregated at the stack level. + if got, want := s.Stats().NICs.Rx.Packets.Value(), uint64(rxPacketsTotal); got != want { + t.Errorf("got s.Stats().NIC.Rx.Packets.Value() = %d, want = %d", got, want) } - want := uint64(ep1.Drain()) - if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { - t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) + if got, want := s.Stats().NICs.Rx.Bytes.Value(), uint64(rxBytesTotal); got != want { + t.Errorf("got s.Stats().Rx.Bytes.Value() = %d, want = %d", got, want) } - - if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want { + if got, want := s.Stats().NICs.Tx.Packets.Value(), uint64(txPacketsTotal); got != want { + t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) + } + if got, want := s.Stats().NICs.Tx.Bytes.Value(), uint64(txBytesTotal); got != want { t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } @@ -2316,7 +2353,7 @@ func TestNICContextPreservation(t *testing.T) { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{}) id := tcpip.NICID(1) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00") if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil { t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err) } @@ -2603,15 +2640,17 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { const nicID = 1 ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } dadConfigs := stack.DefaultDADConfigurations() + clock := faketime.NewManualClock() opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: true, NDPDisp: &ndpDisp, DADConfigs: dadConfigs, })}, + Clock: clock, } e := channel.New(int(dadConfigs.DupAddrDetectTransmits), 1280, linkAddr1) @@ -2629,17 +2668,18 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { linkLocalAddr := header.LinkLocalAddr(linkAddr1) // Wait for DAD to resolve. 
+ clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) select { - case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: // We should get a resolution event after 1s (default time to // resolve as per default NDP configurations). Waiting for that // resolution time + an extra 1s without a resolution event // means something is wrong. t.Fatal("timed out waiting for DAD resolution") - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } } if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); err != nil { t.Fatal(err) @@ -3270,8 +3310,9 @@ func TestDoDADWhenNICEnabled(t *testing.T) { const nicID = 1 ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } + clock := faketime.NewManualClock() opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ @@ -3280,6 +3321,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, } e := channel.New(dadTransmits, 1280, linkAddr1) @@ -3324,13 +3366,14 @@ func TestDoDADWhenNICEnabled(t *testing.T) { } // Wait for DAD to resolve. + clock.Advance(dadTransmits * retransmitTimer) select { - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, &stack.DADSucceeded{}); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("timed out waiting for DAD resolution") } if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) @@ -3837,8 +3880,6 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { // TestAddRoute tests Stack.AddRoute func TestAddRoute(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{}) subnet1, err := tcpip.NewSubnet("\x00", "\x00") @@ -3875,8 +3916,6 @@ func TestAddRoute(t *testing.T) { // TestRemoveRoutes tests Stack.RemoveRoutes func TestRemoveRoutes(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{}) addressToRemove := tcpip.Address("\x01") @@ -4223,7 +4262,7 @@ func TestFindRouteWithForwarding(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) - if r != nil { + if err == nil { defer r.Release() } if diff := cmp.Diff(test.findRouteErr, err); diff != "" { @@ -4394,7 +4433,7 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) { if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) } else if diff := cmp.Diff( - []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, + []stack.NeighborEntry{{Addr: 
addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAt: clock.Now()}}, neighbors, ); diff != "" { t.Fatalf("proto=%d neighbors mismatch (-want +got):\n%s", addr.proto, diff) diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go index ddff6e2d6..90a8ba6cf 100644 --- a/pkg/tcpip/stack/tcp.go +++ b/pkg/tcpip/stack/tcp.go @@ -39,7 +39,7 @@ type TCPCubicState struct { WMax float64 // T is the time when the current congestion avoidance was entered. - T time.Time `state:".(unixTime)"` + T tcpip.MonotonicTime // TimeSinceLastCongestion denotes the time since the current // congestion avoidance was entered. @@ -78,7 +78,7 @@ type TCPCubicState struct { type TCPRACKState struct { // XmitTime is the transmission timestamp of the most recent // acknowledged segment. - XmitTime time.Time `state:".(unixTime)"` + XmitTime tcpip.MonotonicTime // EndSequence is the ending TCP sequence number of the most recent // acknowledged segment. @@ -216,7 +216,7 @@ type TCPRTTState struct { // +stateify savable type TCPSenderState struct { // LastSendTime is the timestamp at which we sent the last segment. - LastSendTime time.Time `state:".(unixTime)"` + LastSendTime tcpip.MonotonicTime // DupAckCount is the number of Duplicate ACKs received. It is used for // fast retransmit. @@ -256,7 +256,7 @@ type TCPSenderState struct { RTTMeasureSeqNum seqnum.Value // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent. - RTTMeasureTime time.Time `state:".(unixTime)"` + RTTMeasureTime tcpip.MonotonicTime // Closed indicates that the caller has closed the endpoint for // sending. @@ -313,7 +313,7 @@ type TCPSACKInfo struct { type RcvBufAutoTuneParams struct { // MeasureTime is the time at which the current measurement was // started. - MeasureTime time.Time `state:".(unixTime)"` + MeasureTime tcpip.MonotonicTime // CopiedBytes is the number of bytes copied to user space since this // measure began. @@ -341,7 +341,7 @@ type RcvBufAutoTuneParams struct { // RTTMeasureTime is the absolute time at which the current RTT // measurement period began. - RTTMeasureTime time.Time `state:".(unixTime)"` + RTTMeasureTime tcpip.MonotonicTime // Disabled is true if an explicit receive buffer is set for the // endpoint. @@ -380,9 +380,6 @@ type TCPSndBufState struct { // SndClosed indicates that the endpoint has been closed for sends. SndClosed bool - // SndBufInQueue is the number of bytes in the send queue. - SndBufInQueue seqnum.Size - // PacketTooBigCount is used to notify the main protocol routine how // many times a "packet too big" control packet is received. PacketTooBigCount int @@ -429,7 +426,7 @@ type TCPEndpointState struct { ID TCPEndpointID // SegTime denotes the absolute time when this segment was received. - SegTime time.Time `state:".(unixTime)"` + SegTime tcpip.MonotonicTime // RcvBufState contains information about the state of the endpoint's // receive socket buffer. 
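The tcp.go hunk above swaps every time.Time field tagged state:".(unixTime)" for the new tcpip.MonotonicTime type, which is also why the loadSegTime helper disappears at the top of this section. Below is a minimal sketch, not part of the change itself, of how callers are expected to take and compare such readings; it assumes only the Clock and MonotonicTime APIs introduced later in this diff (NowMonotonic, Before/After/Add/Sub) plus tcpip.NewStdClock from stdclock.go, and the rto value is purely illustrative.

package main

import (
	"fmt"
	"time"

	"gvisor.dev/gvisor/pkg/tcpip"
)

func main() {
	clock := tcpip.NewStdClock()

	// Timestamp a hypothetical segment send with a monotonic reading rather
	// than a wall-clock time.Time.
	lastSendTime := clock.NowMonotonic()

	// Later, check whether a retransmission timeout has elapsed. rto is an
	// illustrative constant, not a value taken from the TCP code.
	const rto = 200 * time.Millisecond
	deadline := lastSendTime.Add(rto)

	if now := clock.NowMonotonic(); now.After(deadline) {
		fmt.Println("RTO expired after", now.Sub(lastSendTime))
	} else {
		fmt.Println("RTO still pending after", now.Sub(lastSendTime))
	}
}

Because MonotonicTime is a plain savable nanosecond reading (marked +stateify savable in this diff), these fields no longer need the custom unixTime save/load hooks across save/restore.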
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 80ad1a9d4..dda57e225 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -16,7 +16,6 @@ package stack import ( "fmt" - "math/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -217,13 +216,20 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t netProto: netProto, transProto: transProto, } - epsByNIC.endpoints[bindToDevice] = multiPortEp } - return multiPortEp.singleRegisterEndpoint(t, flags) + if err := multiPortEp.singleRegisterEndpoint(t, flags); err != nil { + return err + } + // Only add this newly created multiportEndpoint if the singleRegisterEndpoint + // succeeded. + if !ok { + epsByNIC.endpoints[bindToDevice] = multiPortEp + } + return nil } -func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { +func (epsByNIC *endpointsByNIC) checkEndpoint(flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -407,7 +413,6 @@ func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *Packet func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - bits := flags.Bits() & ports.MultiBindFlagMask if len(ep.endpoints) != 0 { @@ -470,17 +475,21 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps.mu.Lock() defer eps.mu.Unlock() - epsByNIC, ok := eps.endpoints[id] if !ok { epsByNIC = &endpointsByNIC{ endpoints: make(map[tcpip.NICID]*multiPortEndpoint), - seed: rand.Uint32(), + seed: d.stack.Seed(), } + } + if err := epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice); err != nil { + return err + } + // Only add this newly created epsByNIC if registerEndpoint succeeded. 
+ if !ok { eps.endpoints[id] = epsByNIC } - - return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice) + return nil } func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { @@ -502,7 +511,7 @@ func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNum return nil } - return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice) + return epsByNIC.checkEndpoint(flags, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 4848495c9..45b09110d 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -18,6 +18,7 @@ import ( "io/ioutil" "math" "math/rand" + "strconv" "testing" "gvisor.dev/gvisor/pkg/tcpip" @@ -84,7 +85,8 @@ func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICI } type headers struct { - srcPort, dstPort uint16 + srcPort uint16 + dstPort uint16 } func newPayload() []byte { @@ -201,6 +203,56 @@ func TestTransportDemuxerRegister(t *testing.T) { } } +func TestTransportDemuxerRegisterMultiple(t *testing.T) { + type test struct { + flags ports.Flags + want tcpip.Error + } + for _, subtest := range []struct { + name string + tests []test + }{ + {"zeroFlags", []test{ + {ports.Flags{}, nil}, + {ports.Flags{}, &tcpip.ErrPortInUse{}}, + }}, + {"multibindFlags", []test{ + // Allow multiple registrations same TransportEndpointID with multibind flags. + {ports.Flags{LoadBalanced: true, MostRecent: true}, nil}, + {ports.Flags{LoadBalanced: true, MostRecent: true}, nil}, + // Disallow registration w/same ID for a non-multibindflag. + {ports.Flags{TupleOnly: true}, &tcpip.ErrPortInUse{}}, + }}, + } { + t.Run(subtest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + var eps []tcpip.Endpoint + for idx, test := range subtest.tests { + var wq waiter.Queue + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatal(err) + } + eps = append(eps, ep) + tEP, ok := ep.(stack.TransportEndpoint) + if !ok { + t.Fatalf("%T does not implement stack.TransportEndpoint", ep) + } + id := stack.TransportEndpointID{LocalPort: 1} + if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber}, udp.ProtocolNumber, id, tEP, test.flags, 0), test.want; got != want { + t.Fatalf("test index: %d, s.RegisterTransportEndpoint(ipv4.ProtocolNumber, udp.ProtocolNumber, _, _, %+v, 0) = %s, want %s", idx, test.flags, got, want) + } + } + for _, ep := range eps { + ep.Close() + } + }) + } +} + // TestBindToDeviceDistribution injects varied packets on input devices and checks that // the distribution of packets received matches expectations. func TestBindToDeviceDistribution(t *testing.T) { @@ -208,7 +260,7 @@ func TestBindToDeviceDistribution(t *testing.T) { reuse bool bindToDevice tcpip.NICID } - for _, test := range []struct { + tcs := []struct { name string // endpoints will received the inject packets. 
endpoints []endpointSockopts @@ -217,29 +269,29 @@ func TestBindToDeviceDistribution(t *testing.T) { wantDistributions map[tcpip.NICID][]float64 }{ { - "BindPortReuse", + name: "BindPortReuse", // 5 endpoints that all have reuse set. - []endpointSockopts{ + endpoints: []endpointSockopts{ {reuse: true, bindToDevice: 0}, {reuse: true, bindToDevice: 0}, {reuse: true, bindToDevice: 0}, {reuse: true, bindToDevice: 0}, {reuse: true, bindToDevice: 0}, }, - map[tcpip.NICID][]float64{ + wantDistributions: map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed evenly. 1: {0.2, 0.2, 0.2, 0.2, 0.2}, }, }, { - "BindToDevice", + name: "BindToDevice", // 3 endpoints with various bindings. - []endpointSockopts{ + endpoints: []endpointSockopts{ {reuse: false, bindToDevice: 1}, {reuse: false, bindToDevice: 2}, {reuse: false, bindToDevice: 3}, }, - map[tcpip.NICID][]float64{ + wantDistributions: map[tcpip.NICID][]float64{ // Injected packets on dev0 go only to the endpoint bound to dev0. 1: {1, 0, 0}, // Injected packets on dev1 go only to the endpoint bound to dev1. @@ -249,9 +301,9 @@ func TestBindToDeviceDistribution(t *testing.T) { }, }, { - "ReuseAndBindToDevice", + name: "ReuseAndBindToDevice", // 6 endpoints with various bindings. - []endpointSockopts{ + endpoints: []endpointSockopts{ {reuse: true, bindToDevice: 1}, {reuse: true, bindToDevice: 1}, {reuse: true, bindToDevice: 2}, @@ -259,7 +311,7 @@ func TestBindToDeviceDistribution(t *testing.T) { {reuse: true, bindToDevice: 2}, {reuse: true, bindToDevice: 0}, }, - map[tcpip.NICID][]float64{ + wantDistributions: map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed among endpoints bound to // dev0. 1: {0.5, 0.5, 0, 0, 0, 0}, @@ -270,35 +322,42 @@ func TestBindToDeviceDistribution(t *testing.T) { 1000: {0, 0, 0, 0, 0, 1}, }, }, - } { - for protoName, netProtoNum := range map[string]tcpip.NetworkProtocolNumber{ - "IPv4": ipv4.ProtocolNumber, - "IPv6": ipv6.ProtocolNumber, - } { + } + protos := map[string]tcpip.NetworkProtocolNumber{ + "IPv4": ipv4.ProtocolNumber, + "IPv6": ipv6.ProtocolNumber, + } + + for _, test := range tcs { + for protoName, protoNum := range protos { for device, wantDistribution := range test.wantDistributions { - t.Run(test.name+protoName+string(device), func(t *testing.T) { + t.Run(test.name+protoName+"-"+strconv.Itoa(int(device)), func(t *testing.T) { + // Create the NICs. var devices []tcpip.NICID for d := range test.wantDistributions { devices = append(devices, d) } c := newDualTestContextMultiNIC(t, defaultMTU, devices) + // Create endpoints and bind each to a NIC, sometimes reusing ports. eps := make(map[tcpip.Endpoint]int) - pollChannel := make(chan tcpip.Endpoint) for i, endpoint := range test.endpoints { // Try to receive the data. 
wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - defer close(ch) + t.Cleanup(func() { + wq.EventUnregister(&we) + close(ch) + }) var err tcpip.Error - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq) + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, protoNum, &wq) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) } + t.Cleanup(ep.Close) eps[ep] = i go func(ep tcpip.Endpoint) { @@ -307,32 +366,34 @@ func TestBindToDeviceDistribution(t *testing.T) { } }(ep) - defer ep.Close() ep.SocketOptions().SetReusePort(endpoint.reuse) if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil { t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err) } var dstAddr tcpip.Address - switch netProtoNum { + switch protoNum { case ipv4.ProtocolNumber: dstAddr = testDstAddrV4 case ipv6.ProtocolNumber: dstAddr = testDstAddrV6 default: - t.Fatalf("unexpected protocol number: %d", netProtoNum) + t.Fatalf("unexpected protocol number: %d", protoNum) } if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil { t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err) } } - npackets := 100000 - nports := 10000 + // Send packets across a range of ports, checking that packets from + // the same source port are always demultiplexed to the same + // destination endpoint. + npackets := 10_000 + nports := 1_000 if got, want := len(test.endpoints), len(wantDistribution); got != want { t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) } - ports := make(map[uint16]tcpip.Endpoint) + endpoints := make(map[uint16]tcpip.Endpoint) stats := make(map[tcpip.Endpoint]int) for i := 0; i < npackets; i++ { // Send a packet. @@ -342,13 +403,13 @@ func TestBindToDeviceDistribution(t *testing.T) { srcPort: testSrcPort + port, dstPort: testDstPort, } - switch netProtoNum { + switch protoNum { case ipv4.ProtocolNumber: c.sendV4Packet(payload, hdrs, device) case ipv6.ProtocolNumber: c.sendV6Packet(payload, hdrs, device) default: - t.Fatalf("unexpected protocol number: %d", netProtoNum) + t.Fatalf("unexpected protocol number: %d", protoNum) } ep := <-pollChannel @@ -357,11 +418,11 @@ func TestBindToDeviceDistribution(t *testing.T) { } stats[ep]++ if i < nports { - ports[uint16(i)] = ep + endpoints[uint16(i)] = ep } else { // Check that all packets from one client are handled by the same // socket. - if want, got := ports[port], ep; want != got { + if want, got := endpoints[port], ep; want != got { t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) } } diff --git a/pkg/tcpip/stdclock.go b/pkg/tcpip/stdclock.go index 7ce43a68e..371da2f40 100644 --- a/pkg/tcpip/stdclock.go +++ b/pkg/tcpip/stdclock.go @@ -60,11 +60,11 @@ type stdClock struct { // monotonicOffset is assigned maxMonotonic after restore so that the // monotonic time will continue from where it "left off" before saving as part // of S/R. - monotonicOffset int64 `state:"nosave"` + monotonicOffset MonotonicTime `state:"nosave"` // monotonicMU protects maxMonotonic. monotonicMU sync.Mutex `state:"nosave"` - maxMonotonic int64 + maxMonotonic MonotonicTime } // NewStdClock returns an instance of a clock that uses the time package. @@ -76,25 +76,25 @@ func NewStdClock() Clock { var _ Clock = (*stdClock)(nil) -// NowNanoseconds implements Clock.NowNanoseconds. 
-func (*stdClock) NowNanoseconds() int64 { - return time.Now().UnixNano() +// Now implements Clock.Now. +func (*stdClock) Now() time.Time { + return time.Now() } // NowMonotonic implements Clock.NowMonotonic. -func (s *stdClock) NowMonotonic() int64 { +func (s *stdClock) NowMonotonic() MonotonicTime { sinceBase := time.Since(s.baseTime) if sinceBase < 0 { panic(fmt.Sprintf("got negative duration = %s since base time = %s", sinceBase, s.baseTime)) } - monotonicValue := sinceBase.Nanoseconds() + s.monotonicOffset + monotonicValue := s.monotonicOffset.Add(sinceBase) s.monotonicMU.Lock() defer s.monotonicMU.Unlock() // Monotonic time values must never decrease. - if monotonicValue > s.maxMonotonic { + if s.maxMonotonic.Before(monotonicValue) { s.maxMonotonic = monotonicValue } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 797778e08..8f2658f64 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -64,17 +64,47 @@ func (e *ErrSaveRejection) Error() string { return "save rejected due to unsupported networking state: " + e.Err.Error() } +// MonotonicTime is a monotonic clock reading. +// +// +stateify savable +type MonotonicTime struct { + nanoseconds int64 +} + +// Before reports whether the monotonic clock reading mt is before u. +func (mt MonotonicTime) Before(u MonotonicTime) bool { + return mt.nanoseconds < u.nanoseconds +} + +// After reports whether the monotonic clock reading mt is after u. +func (mt MonotonicTime) After(u MonotonicTime) bool { + return mt.nanoseconds > u.nanoseconds +} + +// Add returns the monotonic clock reading mt+d. +func (mt MonotonicTime) Add(d time.Duration) MonotonicTime { + return MonotonicTime{ + nanoseconds: time.Unix(0, mt.nanoseconds).Add(d).Sub(time.Unix(0, 0)).Nanoseconds(), + } +} + +// Sub returns the duration mt-u. If the result exceeds the maximum (or minimum) +// value that can be stored in a Duration, the maximum (or minimum) duration +// will be returned. To compute t-d for a duration d, use t.Add(-d). +func (mt MonotonicTime) Sub(u MonotonicTime) time.Duration { + return time.Unix(0, mt.nanoseconds).Sub(time.Unix(0, u.nanoseconds)) +} + // A Clock provides the current time and schedules work for execution. // // Times returned by a Clock should always be used for application-visible // time. Only monotonic times should be used for netstack internal timekeeping. type Clock interface { - // NowNanoseconds returns the current real time as a number of - // nanoseconds since the Unix epoch. - NowNanoseconds() int64 + // Now returns the current local time. + Now() time.Time - // NowMonotonic returns a monotonic time value at nanosecond resolution. - NowMonotonic() int64 + // NowMonotonic returns the current monotonic clock reading. + NowMonotonic() MonotonicTime // AfterFunc waits for the duration to elapse and then calls f in its own // goroutine. It returns a Timer that can be used to cancel the call using @@ -435,11 +465,11 @@ type ControlMessages struct { // PacketOwner is used to get UID and GID of the packet. type PacketOwner interface { - // UID returns UID of the packet. - UID() uint32 + // UID returns KUID of the packet. + KUID() uint32 - // GID returns GID of the packet. - GID() uint32 + // GID returns KGID of the packet. + KGID() uint32 } // ReadOptions contains options for Endpoint.Read. @@ -861,6 +891,9 @@ type SettableSocketOption interface { isSettableSocketOption() } +// EndpointState represents the state of an endpoint. 
+type EndpointState uint8 + // CongestionControlState indicates the current congestion control state for // TCP sender. type CongestionControlState int @@ -897,6 +930,9 @@ type TCPInfoOption struct { // RTO is the retransmission timeout for the endpoint. RTO time.Duration + // State is the current endpoint protocol state. + State EndpointState + // CcState is the congestion control state. CcState CongestionControlState @@ -1552,6 +1588,10 @@ type IPForwardingStats struct { // were too big for the outgoing MTU. PacketTooBig *StatCounter + // HostUnreachable is the number of IP packets received which could not be + // successfully forwarded due to an unresolvable next hop. + HostUnreachable *StatCounter + // ExtensionHeaderProblem is the number of IP packets which were dropped // because of a problem encountered when processing an IPv6 extension // header. @@ -1835,37 +1875,104 @@ type UDPStats struct { ChecksumErrors *StatCounter } +// NICNeighborStats holds metrics for the neighbor table. +type NICNeighborStats struct { + // LINT.IfChange(NICNeighborStats) + + // UnreachableEntryLookups counts the number of lookups performed on an + // entry in Unreachable state. + UnreachableEntryLookups *StatCounter + + // LINT.ThenChange(stack/nic_stats.go:multiCounterNICNeighborStats) +} + +// NICPacketStats holds basic packet statistics. +type NICPacketStats struct { + // LINT.IfChange(NICPacketStats) + + // Packets is the number of packets counted. + Packets *StatCounter + + // Bytes is the number of bytes counted. + Bytes *StatCounter + + // LINT.ThenChange(stack/nic_stats.go:multiCounterNICPacketStats) +} + +// NICStats holds NIC statistics. +type NICStats struct { + // LINT.IfChange(NICStats) + + // UnknownL3ProtocolRcvdPackets is the number of packets received that were + // for an unknown or unsupported network protocol. + UnknownL3ProtocolRcvdPackets *StatCounter + + // UnknownL4ProtocolRcvdPackets is the number of packets received that were + // for an unknown or unsupported transport protocol. + UnknownL4ProtocolRcvdPackets *StatCounter + + // MalformedL4RcvdPackets is the number of packets received by a NIC that + // could not be delivered to a transport endpoint because the L4 header could + // not be parsed. + MalformedL4RcvdPackets *StatCounter + + // Tx contains statistics about transmitted packets. + Tx NICPacketStats + + // Rx contains statistics about received packets. + Rx NICPacketStats + + // DisabledRx contains statistics about received packets on disabled NICs. + DisabledRx NICPacketStats + + // Neighbor contains statistics about neighbor entries. + Neighbor NICNeighborStats + + // LINT.ThenChange(stack/nic_stats.go:multiCounterNICStats) +} + +// FillIn returns a copy of s with nil fields initialized to new StatCounters. +func (s NICStats) FillIn() NICStats { + InitStatCounters(reflect.ValueOf(&s).Elem()) + return s +} + // Stats holds statistics about the networking stack. -// -// All fields are optional. type Stats struct { - // UnknownProtocolRcvdPackets is the number of packets received by the - // stack that were for an unknown or unsupported protocol. - UnknownProtocolRcvdPackets *StatCounter + // TODO(https://gvisor.dev/issues/5986): Make the DroppedPackets stat less + // ambiguous. - // MalformedRcvdPackets is the number of packets received by the stack - // that were deemed malformed. - MalformedRcvdPackets *StatCounter - - // DroppedPackets is the number of packets dropped due to full queues. + // DroppedPackets is the number of packets dropped at the transport layer. 
DroppedPackets *StatCounter - // ICMP breaks out ICMP-specific stats (both v4 and v6). + // NICs is an aggregation of every NIC's statistics. These should not be + // incremented using this field, but using the relevant NIC multicounters. + NICs NICStats + + // ICMP is an aggregation of every NetworkEndpoint's ICMP statistics (both v4 + // and v6). These should not be incremented using this field, but using the + // relevant NetworkEndpoint ICMP multicounters. ICMP ICMPStats - // IGMP breaks out IGMP-specific stats. + // IGMP is an aggregation of every NetworkEndpoint's IGMP statistics. These + // should not be incremented using this field, but using the relevant + // NetworkEndpoint IGMP multicounters. IGMP IGMPStats - // IP breaks out IP-specific stats (both v4 and v6). + // IP is an aggregation of every NetworkEndpoint's IP statistics. These should + // not be incremented using this field, but using the relevant NetworkEndpoint + // IP multicounters. IP IPStats - // ARP breaks out ARP-specific stats. + // ARP is an aggregation of every NetworkEndpoint's ARP statistics. These + // should not be incremented using this field, but using the relevant + // NetworkEndpoint ARP multicounters. ARP ARPStats - // TCP breaks out TCP-specific stats. + // TCP holds TCP-specific stats. TCP TCPStats - // UDP breaks out UDP-specific stats. + // UDP holds UDP-specific stats. UDP UDPStats } diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index 269081ff8..c96ae2f02 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -19,7 +19,6 @@ import ( "fmt" "io" "net" - "strings" "testing" "github.com/google/go-cmp/cmp" @@ -210,26 +209,6 @@ func TestAddressString(t *testing.T) { } } -func TestStatsString(t *testing.T) { - got := fmt.Sprintf("%+v", Stats{}.FillIn()) - - matchers := []string{ - // Print root-level stats correctly. - "UnknownProtocolRcvdPackets:0", - // Print protocol-specific stats correctly. - "TCP:{ActiveConnectionOpenings:0", - } - - for _, m := range matchers { - if !strings.Contains(got, m) { - t.Errorf("string.Contains(got, %q) = false", m) - } - } - if t.Failed() { - t.Logf(`got = fmt.Sprintf("%%+v", Stats{}.FillIn()) = %q`, got) - } -} - func TestAddressWithPrefixSubnet(t *testing.T) { tests := []struct { addr Address diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index ab2dab60c..181ef799e 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -53,12 +53,14 @@ go_test( "//pkg/tcpip/checker", "//pkg/tcpip/faketime", "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/link/pipe", "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index 07ba2b837..f9ab7d0af 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -166,7 +166,7 @@ func TestIPTablesStatsForInput(t *testing.T) { // Make sure the packet is not dropped by the next rule. 
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -187,7 +187,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -207,7 +207,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -227,7 +227,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -250,7 +250,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -273,7 +273,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -293,7 +293,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -313,7 +313,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -465,7 +465,7 @@ func TestIPTableWritePackets(t *testing.T) { } if err := s.IPTables().ReplaceTable(stack.FilterID, table, 
false /* ipv4 */); err != nil { - t.Fatalf("RelaceTable(%d, _, false): %s", stack.FilterID, err) + t.Fatalf("ReplaceTable(%d, _, false): %s", stack.FilterID, err) } }, genPacket: func(r *stack.Route) stack.PacketBufferList { @@ -556,7 +556,7 @@ func TestIPTableWritePackets(t *testing.T) { } if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil { - t.Fatalf("RelaceTable(%d, _, true): %s", stack.FilterID, err) + t.Fatalf("ReplaceTable(%d, _, true): %s", stack.FilterID, err) } }, genPacket: func(r *stack.Route) stack.PacketBufferList { @@ -681,6 +681,32 @@ func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Addr checker.ICMPv6Type(header.ICMPv6EchoReply))) } +func boolToInt(v bool) uint64 { + if v { + return 1 + } + return 0 +} + +func setupDropFilter(hook stack.Hook, f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { + return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) { + t.Helper() + + ipv6 := netProto == ipv6.ProtocolNumber + + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, ipv6) + ruleIdx := filter.BuiltinChains[hook] + filter.Rules[ruleIdx].Filter = f + filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto} + // Make sure the packet is not dropped by the next rule. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto} + if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err) + } + } +} + func TestForwardingHook(t *testing.T) { const ( nicID1 = 1 @@ -740,32 +766,6 @@ func TestForwardingHook(t *testing.T) { }, } - setupDropFilter := func(f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { - return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) { - t.Helper() - - ipv6 := netProto == ipv6.ProtocolNumber - - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, ipv6) - ruleIdx := filter.BuiltinChains[stack.Forward] - filter.Rules[ruleIdx].Filter = f - filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto} - // Make sure the packet is not dropped by the next rule. 
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto} - if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err) - } - } - } - - boolToInt := func(v bool) uint64 { - if v { - return 1 - } - return 0 - } - subTests := []struct { name string setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) @@ -779,59 +779,59 @@ func TestForwardingHook(t *testing.T) { { name: "Drop", - setupFilter: setupDropFilter(stack.IPHeaderFilter{}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{}), expectForward: false, }, { name: "Drop with input NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name}), expectForward: false, }, { name: "Drop with output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name}), expectForward: false, }, { name: "Drop with input and output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}), expectForward: false, }, { name: "Drop with other input NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName}), expectForward: true, }, { name: "Drop with other output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: otherNICName}), expectForward: true, }, { name: "Drop with other input and output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}), expectForward: true, }, { name: "Drop with input and other output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}), expectForward: true, }, { name: "Drop with other input and other output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}), expectForward: true, }, { name: "Drop with inverted input NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}), expectForward: true, }, { name: "Drop with inverted output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}), expectForward: true, }, } @@ 
-941,3 +941,194 @@ func TestForwardingHook(t *testing.T) { }) } } + +func TestInputHookWithLocalForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + nic1Name = "nic1" + nic2Name = "nic2" + + otherNICName = "otherNIC" + ) + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + rx func(*channel.Endpoint) + checker func(*testing.T, []byte) + }{ + { + name: "IPv4", + netProto: ipv4.ProtocolNumber, + rx: func(e *channel.Endpoint) { + utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address, ttl) + }, + checker: func(t *testing.T, b []byte) { + checker.IPv4(t, b, + checker.SrcAddr(utils.Ipv4Addr2.AddressWithPrefix.Address), + checker.DstAddr(utils.RemoteIPv4Addr), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4EchoReply))) + }, + }, + { + name: "IPv6", + netProto: ipv6.ProtocolNumber, + rx: func(e *channel.Endpoint) { + utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address, ttl) + }, + checker: func(t *testing.T, b []byte) { + checker.IPv6(t, b, + checker.SrcAddr(utils.Ipv6Addr2.AddressWithPrefix.Address), + checker.DstAddr(utils.RemoteIPv6Addr), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoReply))) + }, + }, + } + + subTests := []struct { + name string + setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) + expectDrop bool + }{ + { + name: "Accept", + setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ }, + expectDrop: false, + }, + + { + name: "Drop", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{}), + expectDrop: true, + }, + { + name: "Drop with input NIC filtering on arrival NIC", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic1Name}), + expectDrop: true, + }, + { + name: "Drop with input NIC filtering on delivered NIC", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic2Name}), + expectDrop: false, + }, + + { + name: "Drop with input NIC filtering on other NIC", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: otherNICName}), + expectDrop: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + }) + + subTest.setupFilter(t, s, test.netProto) + + e1 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil { + t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err) + } + if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err) + } + if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err) + } + + e2 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil { + t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) + } + if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err) + } + if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, 
utils.Ipv6Addr2, err) + } + + if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) + } + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID1, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: nicID1, + }, + }) + + test.rx(e1) + + ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err) + } + ep1Stats := ep1.Stats() + ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats) + if !ok { + t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats) + } + ip1Stats := ipEP1Stats.IPStats() + + if got := ip1Stats.PacketsReceived.Value(); got != 1 { + t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got) + } + if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 { + t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got) + } + if got, want := ip1Stats.PacketsSent.Value(), boolToInt(!subTest.expectDrop); got != want { + t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = %d", got, want) + } + + ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err) + } + ep2Stats := ep2.Stats() + ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats) + if !ok { + t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats) + } + ip2Stats := ipEP2Stats.IPStats() + if got := ip2Stats.PacketsReceived.Value(); got != 0 { + t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got) + } + if got := ip2Stats.ValidPacketsReceived.Value(); got != 1 { + t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = 1", got) + } + if got, want := ip2Stats.IPTablesInputDropped.Value(), boolToInt(subTest.expectDrop); got != want { + t.Errorf("got ip2Stats.IPTablesInputDropped.Value() = %d, want = %d", got, want) + } + if got := ip2Stats.PacketsSent.Value(); got != 0 { + t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = 0", got) + } + + if p, ok := e1.Read(); ok == subTest.expectDrop { + t.Errorf("got e1.Read() = (%#v, %t), want = (_, %t)", p, ok, !subTest.expectDrop) + } else if !subTest.expectDrop { + test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) + } + if p, ok := e2.Read(); ok { + t.Errorf("got e1.Read() = (%#v, true), want = (_, false)", p) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index c657714ba..27caa0c28 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -17,6 +17,8 @@ package link_resolution_test import ( "bytes" "fmt" + "net" + "runtime" "testing" "time" @@ -27,12 +29,14 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + tcptestutil 
"gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -283,9 +287,11 @@ func TestTCPLinkResolutionFailure(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() stackOpts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + Clock: clock, } host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID) @@ -329,7 +335,17 @@ func TestTCPLinkResolutionFailure(t *testing.T) { // Wait for an error due to link resolution failing, or the endpoint to be // writable. + if test.expectedWriteErr != nil { + nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) + if err != nil { + t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) + } + clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer) + } else { + clock.RunImmediatelyScheduledJobs() + } <-ch + { var r bytes.Reader r.Reset([]byte{0}) @@ -395,6 +411,242 @@ func TestTCPLinkResolutionFailure(t *testing.T) { } } +func TestForwardingWithLinkResolutionFailure(t *testing.T) { + const ( + incomingNICID = 1 + outgoingNICID = 2 + ttl = 2 + expectedHostUnreachableErrorCount = 1 + ) + outgoingLinkAddr := tcptestutil.MustParseLink("02:03:03:04:05:06") + + rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv4EchoRequest(e, src, dst, ttl) + } + + rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv6EchoRequest(e, src, dst, ttl) + } + + arpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) { + if request.Proto != arp.ProtocolNumber { + t.Errorf("got request.Proto = %d, want = %d", request.Proto, arp.ProtocolNumber) + } + if request.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { + t.Errorf("got request.Route.RemoteLinkAddress = %s, want = %s", request.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) + } + rep := header.ARP(request.Pkt.NetworkHeader().View()) + if got := rep.Op(); got != header.ARPRequest { + t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest) + } + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != outgoingLinkAddr { + t.Errorf("got HardwareAddressSender = %s, want = %s", got, outgoingLinkAddr) + } + if got := tcpip.Address(rep.ProtocolAddressSender()); got != src { + t.Errorf("got ProtocolAddressSender = %s, want = %s", got, src) + } + if got := tcpip.Address(rep.ProtocolAddressTarget()); got != dst { + t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, dst) + } + } + + ndpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) { + if request.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", request.Proto, header.IPv6ProtocolNumber) + } + + snmc := header.SolicitedNodeAddr(dst) + if want := header.EthernetAddressFromMulticastIPv6Address(snmc); request.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", request.Route.RemoteLinkAddress, want) + } + + checker.IPv6(t, stack.PayloadSince(request.Pkt.NetworkHeader()), + checker.SrcAddr(src), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(dst), + )) + } + + icmpv4Checker := func(t 
*testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv4(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ipv4.DefaultTTL), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4HostUnreachable), + ), + ) + } + + icmpv6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv6(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ipv6.DefaultTTL), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + checker.ICMPv6Code(header.ICMPv6AddressUnreachable), + ), + ) + } + + tests := []struct { + name string + networkProtocolFactory []stack.NetworkProtocolFactory + networkProtocolNumber tcpip.NetworkProtocolNumber + sourceAddr tcpip.Address + destAddr tcpip.Address + incomingAddr tcpip.AddressWithPrefix + outgoingAddr tcpip.AddressWithPrefix + transportProtocol func(*stack.Stack) stack.TransportProtocol + rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) + linkResolutionRequestChecker func(*testing.T, channel.PacketInfo, tcpip.Address, tcpip.Address) + icmpReplyChecker func(*testing.T, []byte, tcpip.Address, tcpip.Address) + mtu uint32 + }{ + { + name: "IPv4 Host unreachable", + networkProtocolFactory: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + networkProtocolNumber: header.IPv4ProtocolNumber, + sourceAddr: tcptestutil.MustParse4("10.0.0.2"), + destAddr: tcptestutil.MustParse4("11.0.0.2"), + incomingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), + PrefixLen: 8, + }, + outgoingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), + PrefixLen: 8, + }, + transportProtocol: icmp.NewProtocol4, + linkResolutionRequestChecker: arpChecker, + icmpReplyChecker: icmpv4Checker, + rx: rxICMPv4EchoRequest, + mtu: ipv4.MaxTotalSize, + }, + { + name: "IPv6 Host unreachable", + networkProtocolFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + networkProtocolNumber: header.IPv6ProtocolNumber, + sourceAddr: tcptestutil.MustParse6("10::2"), + destAddr: tcptestutil.MustParse6("11::2"), + incomingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10::1").To16()), + PrefixLen: 64, + }, + outgoingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11::1").To16()), + PrefixLen: 64, + }, + transportProtocol: icmp.NewProtocol6, + linkResolutionRequestChecker: ndpChecker, + icmpReplyChecker: icmpv6Checker, + rx: rxICMPv6EchoRequest, + mtu: header.IPv6MinimumMTU, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + + s := stack.New(stack.Options{ + NetworkProtocols: test.networkProtocolFactory, + TransportProtocols: []stack.TransportProtocolFactory{test.transportProtocol}, + Clock: clock, + }) + + // Set up endpoint through which we will receive packets. + incomingEndpoint := channel.New(1, test.mtu, "") + if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) + } + incomingProtoAddr := tcpip.ProtocolAddress{ + Protocol: test.networkProtocolNumber, + AddressWithPrefix: test.incomingAddr, + } + if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err) + } + + // Set up endpoint through which we will attempt to forward packets. 
+ outgoingEndpoint := channel.New(1, test.mtu, outgoingLinkAddr) + outgoingEndpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) + } + outgoingProtoAddr := tcpip.ProtocolAddress{ + Protocol: test.networkProtocolNumber, + AddressWithPrefix: test.outgoingAddr, + } + if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: test.incomingAddr.Subnet(), + NIC: incomingNICID, + }, + { + Destination: test.outgoingAddr.Subnet(), + NIC: outgoingNICID, + }, + }) + + if err := s.SetForwardingDefaultAndAllNICs(test.networkProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", test.networkProtocolNumber, err) + } + + test.rx(incomingEndpoint, test.sourceAddr, test.destAddr) + + nudConfigs, err := s.NUDConfigurations(outgoingNICID, test.networkProtocolNumber) + if err != nil { + t.Fatalf("s.NUDConfigurations(%d, %d): %s", outgoingNICID, test.networkProtocolNumber, err) + } + // Trigger the first packet on the endpoint. + clock.RunImmediatelyScheduledJobs() + + for i := 0; i < int(nudConfigs.MaxMulticastProbes); i++ { + request, ok := outgoingEndpoint.Read() + if !ok { + t.Fatal("expected ARP packet through outgoing NIC") + } + + test.linkResolutionRequestChecker(t, request, test.outgoingAddr.Address, test.destAddr) + + // Advance the clock the span of one request timeout. + clock.Advance(nudConfigs.RetransmitTimer) + } + + // Next, we make a blocking read to retrieve the error packet. This is + // necessary because outgoing packets are dequeued asynchronously when + // link resolution fails, and this dequeue is what triggers the ICMP + // error. + reply, ok := incomingEndpoint.Read() + if !ok { + t.Fatal("expected ICMP packet through incoming NIC") + } + + test.icmpReplyChecker(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), test.incomingAddr.Address, test.sourceAddr) + + // Since link resolution failed, we don't expect the packet to be + // forwarded. 
+ forwardedPacket, ok := outgoingEndpoint.Read() + if ok { + t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", forwardedPacket) + } + + if got, want := s.Stats().IP.Forwarding.HostUnreachable.Value(), expectedHostUnreachableErrorCount; int(got) != want { + t.Errorf("got rt.Stats().IP.Forwarding.HostUnreachable.Value() = %d, want = %d", got, want) + } + }) + } +} + func TestGetLinkAddress(t *testing.T) { const ( host1NICID = 1 @@ -449,8 +701,10 @@ func TestGetLinkAddress(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() stackOpts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + Clock: clock, } host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) @@ -466,8 +720,20 @@ func TestGetLinkAddress(t *testing.T) { if test.expectedErr == nil { wantRes.LinkAddress = utils.LinkAddr2 } - if diff := cmp.Diff(wantRes, <-ch); diff != "" { - t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) + + nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) + if err != nil { + t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) + } + + clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer) + select { + case got := <-ch: + if diff := cmp.Diff(wantRes, got); diff != "" { + t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("event didn't arrive") } }) } @@ -544,8 +810,10 @@ func TestRouteResolvedFields(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() stackOpts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + Clock: clock, } host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) @@ -575,8 +843,20 @@ func TestRouteResolvedFields(t *testing.T) { if _, ok := err.(*tcpip.ErrWouldBlock); !ok { t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) } - if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: test.expectedErr}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { - t.Errorf("route resolve result mismatch (-want +got):\n%s", diff) + + nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) + if err != nil { + t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) + } + clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer) + + select { + case got := <-ch: + if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: test.expectedErr}, got, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { + t.Errorf("route resolve result mismatch (-want +got):\n%s", diff) + } + default: + t.Fatalf("event didn't arrive") } if test.expectedErr != nil { @@ -789,11 +1069,16 @@ func (d *nudDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.Neighbo d.c <- e } -func (d *nudDispatcher) waitForEvent(want eventInfo) error { - if diff := cmp.Diff(want, <-d.c, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { - return fmt.Errorf("got invalid event (-want +got):\n%s", diff) +func (d *nudDispatcher) expectEvent(want eventInfo) error { + select { + case got := <-d.c: + if diff := cmp.Diff(want, got, cmp.AllowUnexported(eventInfo{}), 
cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" { + return fmt.Errorf("got invalid event (-want +got):\n%s", diff) + } + return nil + default: + return fmt.Errorf("event didn't arrive") } - return nil } // TestTCPConfirmNeighborReachability tests that TCP informs layers beneath it @@ -804,7 +1089,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto tcpip.NetworkProtocolNumber remoteAddr tcpip.Address neighborAddr tcpip.Address - getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) + getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) isHost1Listener bool }{ { @@ -812,23 +1097,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv4.ProtocolNumber, remoteAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, }, { @@ -836,23 +1123,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv6.ProtocolNumber, remoteAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } - return listenerEP, clientEP, 
clientCH + return listenerEP, listenerCH, clientEP, clientCH }, }, { @@ -860,23 +1149,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv4.ProtocolNumber, remoteAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, }, { @@ -884,23 +1175,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv6.ProtocolNumber, remoteAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, }, { @@ -908,23 +1201,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv4.ProtocolNumber, remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + 
listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, isHost1Listener: true, }, @@ -933,23 +1228,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv6.ProtocolNumber, remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, isHost1Listener: true, }, @@ -958,23 +1255,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv4.ProtocolNumber, remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() 
t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, isHost1Listener: true, }, @@ -983,23 +1282,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv6.ProtocolNumber, remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, isHost1Listener: true, }, @@ -1037,14 +1338,14 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { t.Fatalf("link resolution mismatch (-want +got):\n%s", diff) } } - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryAdded, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Incomplete, Addr: test.neighborAddr}, }); err != nil { t.Fatalf("error waiting for initial NUD event: %s", err) } - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, @@ -1064,7 +1365,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { // See NUDConfigurations.BaseReachableTime for more information. maxReachableTime := time.Duration(float32(nudConfigs.BaseReachableTime) * nudConfigs.MaxRandomFactor) clock.Advance(maxReachableTime) - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, @@ -1072,8 +1373,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { t.Fatalf("error waiting for stale NUD event: %s", err) } - listenerEP, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack) - defer listenerEP.Close() + listenerEP, listenerCH, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack) defer clientEP.Close() listenerAddr := tcpip.FullAddress{Addr: test.remoteAddr, Port: 1234} if err := listenerEP.Bind(listenerAddr); err != nil { @@ -1094,14 +1394,15 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { // with confirmation that the neighbor is reachable (indicated by a // successful 3-way handshake). 
<-clientCH - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, }); err != nil { t.Fatalf("error waiting for delay NUD event: %s", err) } - if err := nudDisp.waitForEvent(eventInfo{ + <-listenerCH + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, @@ -1109,26 +1410,55 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { t.Fatalf("error waiting for reachable NUD event: %s", err) } + peerEP, peerWQ, err := listenerEP.Accept(nil) + if err != nil { + t.Fatalf("listenerEP.Accept(): %s", err) + } + defer peerEP.Close() + peerWE, peerCH := waiter.NewChannelEntry(nil) + peerWQ.EventRegister(&peerWE, waiter.ReadableEvents) + // Wait for the neighbor to be stale again then send data to the remote. // // On successful transmission, the neighbor should become reachable // without probing the neighbor as a TCP ACK would be received which is an // indication of the neighbor being reachable. clock.Advance(maxReachableTime) - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, }); err != nil { t.Fatalf("error waiting for stale NUD event: %s", err) } - var r bytes.Reader - r.Reset([]byte{0}) - var wOpts tcpip.WriteOptions - if _, err := clientEP.Write(&r, wOpts); err != nil { - t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err) + { + var r bytes.Reader + r.Reset([]byte{0}) + var wOpts tcpip.WriteOptions + if _, err := clientEP.Write(&r, wOpts); err != nil { + t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err) + } + } + // Heads up, there is a race here. + // + // Incoming TCP segments are handled in + // tcp.(*endpoint).handleSegmentLocked: + // + // - tcp.(*endpoint).rcv.handleRcvdSegment puts the segment on the + // segment queue and notifies waiting readers (such as this channel) + // + // - tcp.(*endpoint).snd.handleRcvdSegment sends an ACK for the segment + // and notifies the NUD machinery that the peer is reachable + // + // Thus we must permit a delay between the readable signal and the + // expected NUD event. + // + // At the time of writing, this race is reliably hit with gotsan. + <-peerCH + for len(nudDisp.c) == 0 { + runtime.Gosched() } - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, @@ -1141,7 +1471,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { // TCP should not mark the route reachable and NUD should go through the // probe state. 
clock.Advance(nudConfigs.DelayFirstProbeTime) - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Probe, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, @@ -1149,7 +1479,16 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { t.Fatalf("error waiting for probe NUD event: %s", err) } } - if err := nudDisp.waitForEvent(eventInfo{ + { + var r bytes.Reader + r.Reset([]byte{0}) + var wOpts tcpip.WriteOptions + if _, err := peerEP.Write(&r, wOpts); err != nil { + t.Errorf("peerEP.Write(_, %#v): %s", wOpts, err) + } + } + <-clientCH + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 87d36e1dd..b2008f0b2 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -44,20 +44,17 @@ type ndpDispatcher struct{} func (*ndpDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) { } -func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool { - return false +func (*ndpDispatcher) OnOffLinkRouteUpdated(tcpip.NICID, tcpip.Subnet, tcpip.Address, header.NDPRoutePreference) { } -func (*ndpDispatcher) OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) {} +func (*ndpDispatcher) OnOffLinkRouteInvalidated(tcpip.NICID, tcpip.Subnet, tcpip.Address) {} -func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool { - return false +func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) { } func (*ndpDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {} -func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool { - return true +func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) { } func (*ndpDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {} diff --git a/pkg/tcpip/testutil/testutil.go b/pkg/tcpip/testutil/testutil.go index f84d399fb..94b580a70 100644 --- a/pkg/tcpip/testutil/testutil.go +++ b/pkg/tcpip/testutil/testutil.go @@ -109,3 +109,15 @@ func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) er return nil } + +// MustParseLink parses a Link string into a tcpip.LinkAddress, panicking on +// error. +// +// The string must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff. 
+func MustParseLink(addr string) tcpip.LinkAddress { + parsed, err := tcpip.ParseMACAddress(addr) + if err != nil { + panic(fmt.Sprintf("tcpip.ParseMACAddress(%s): %s", addr, err)) + } + return parsed +} diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go index 1633d0aeb..ed1ed8ac6 100644 --- a/pkg/tcpip/timer_test.go +++ b/pkg/tcpip/timer_test.go @@ -15,6 +15,7 @@ package tcpip_test import ( + "math" "sync" "testing" "time" @@ -22,10 +23,85 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) +func TestMonotonicTimeBefore(t *testing.T) { + var mt tcpip.MonotonicTime + if mt.Before(mt) { + t.Errorf("%#v.Before(%#v)", mt, mt) + } + + one := mt.Add(1) + if one.Before(mt) { + t.Errorf("%#v.Before(%#v)", one, mt) + } + if !mt.Before(one) { + t.Errorf("!%#v.Before(%#v)", mt, one) + } +} + +func TestMonotonicTimeAfter(t *testing.T) { + var mt tcpip.MonotonicTime + if mt.After(mt) { + t.Errorf("%#v.After(%#v)", mt, mt) + } + + one := mt.Add(1) + if mt.After(one) { + t.Errorf("%#v.After(%#v)", mt, one) + } + if !one.After(mt) { + t.Errorf("!%#v.After(%#v)", one, mt) + } +} + +func TestMonotonicTimeAddSub(t *testing.T) { + var mt tcpip.MonotonicTime + if one, two := mt.Add(2), mt.Add(1).Add(1); one != two { + t.Errorf("mt.Add(2) != mt.Add(1).Add(1) (%#v != %#v)", one, two) + } + + min := mt.Add(math.MinInt64) + max := mt.Add(math.MaxInt64) + + if overflow := mt.Add(1).Add(math.MaxInt64); overflow != max { + t.Errorf("mt.Add(math.MaxInt64) != mt.Add(1).Add(math.MaxInt64) (%#v != %#v)", max, overflow) + } + if underflow := mt.Add(-1).Add(math.MinInt64); underflow != min { + t.Errorf("mt.Add(math.MinInt64) != mt.Add(-1).Add(math.MinInt64) (%#v != %#v)", min, underflow) + } + + if got, want := min.Sub(min), time.Duration(0); want != got { + t.Errorf("got min.Sub(min) = %d, want %d", got, want) + } + if got, want := max.Sub(max), time.Duration(0); want != got { + t.Errorf("got max.Sub(max) = %d, want %d", got, want) + } + + if overflow, want := max.Sub(min), time.Duration(math.MaxInt64); overflow != want { + t.Errorf("mt.Add(math.MaxInt64).Sub(mt.Add(math.MinInt64) != %s (%#v)", want, overflow) + } + if underflow, want := min.Sub(max), time.Duration(math.MinInt64); underflow != want { + t.Errorf("mt.Add(math.MinInt64).Sub(mt.Add(math.MaxInt64) != %s (%#v)", want, underflow) + } +} + +func TestMonotonicTimeSub(t *testing.T) { + var mt tcpip.MonotonicTime + + if one, two := mt.Add(2), mt.Add(1).Add(1); one != two { + t.Errorf("mt.Add(2) != mt.Add(1).Add(1) (%#v != %#v)", one, two) + } + + if max, overflow := mt.Add(math.MaxInt64), mt.Add(1).Add(math.MaxInt64); max != overflow { + t.Errorf("mt.Add(math.MaxInt64) != mt.Add(1).Add(math.MaxInt64) (%#v != %#v)", max, overflow) + } + if max, underflow := mt.Add(math.MinInt64), mt.Add(-1).Add(math.MinInt64); max != underflow { + t.Errorf("mt.Add(math.MinInt64) != mt.Add(-1).Add(math.MinInt64) (%#v != %#v)", max, underflow) + } +} + const ( shortDuration = 1 * time.Nanosecond middleDuration = 100 * time.Millisecond - longDuration = 1 * time.Second ) func TestJobReschedule(t *testing.T) { @@ -53,10 +129,14 @@ func TestJobReschedule(t *testing.T) { wg.Wait() } +func stdClockWithAfter() (tcpip.Clock, func(time.Duration) <-chan time.Time) { + return tcpip.NewStdClock(), time.After +} + func TestJobExecution(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after := stdClockWithAfter() var lock sync.Mutex ch := make(chan struct{}) @@ -68,7 +148,7 @@ func TestJobExecution(t *testing.T) { // Wait for timer to fire. 
select { case <-ch: - case <-time.After(middleDuration): + case <-after(middleDuration): t.Fatal("timed out waiting for timer to fire") } @@ -76,14 +156,14 @@ func TestJobExecution(t *testing.T) { select { case <-ch: t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): + case <-after(middleDuration): } } func TestCancellableTimerResetFromLongDuration(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after := stdClockWithAfter() var lock sync.Mutex ch := make(chan struct{}) @@ -99,7 +179,7 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) { // Wait for timer to fire. select { case <-ch: - case <-time.After(middleDuration): + case <-after(middleDuration): t.Fatal("timed out waiting for timer to fire") } @@ -107,14 +187,14 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) { select { case <-ch: t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): + case <-after(middleDuration): } } func TestJobRescheduleFromShortDuration(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after := stdClockWithAfter() var lock sync.Mutex ch := make(chan struct{}) @@ -128,7 +208,7 @@ func TestJobRescheduleFromShortDuration(t *testing.T) { select { case <-ch: t.Fatal("timer fired after being stopped") - case <-time.After(middleDuration): + case <-after(middleDuration): } job.Schedule(shortDuration) @@ -136,7 +216,7 @@ func TestJobRescheduleFromShortDuration(t *testing.T) { // Wait for timer to fire. select { case <-ch: - case <-time.After(middleDuration): + case <-after(middleDuration): t.Fatal("timed out waiting for timer to fire") } @@ -144,14 +224,14 @@ func TestJobRescheduleFromShortDuration(t *testing.T) { select { case <-ch: t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): + case <-after(middleDuration): } } func TestJobImmediatelyCancel(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after := stdClockWithAfter() var lock sync.Mutex ch := make(chan struct{}) @@ -167,14 +247,19 @@ func TestJobImmediatelyCancel(t *testing.T) { select { case <-ch: t.Fatal("timer fired after being stopped") - case <-time.After(middleDuration): + case <-after(middleDuration): } } +func stdClockWithAfterAndSleep() (tcpip.Clock, func(time.Duration) <-chan time.Time, func(time.Duration)) { + clock, after := stdClockWithAfter() + return clock, after, time.Sleep +} + func TestJobCancelledRescheduleWithoutLock(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after, sleep := stdClockWithAfterAndSleep() var lock sync.Mutex ch := make(chan struct{}) @@ -189,7 +274,7 @@ func TestJobCancelledRescheduleWithoutLock(t *testing.T) { lock.Lock() // Sleep until the timer fires and gets blocked trying to take the lock. - time.Sleep(middleDuration * 2) + sleep(middleDuration * 2) job.Cancel() lock.Unlock() } @@ -199,14 +284,14 @@ func TestJobCancelledRescheduleWithoutLock(t *testing.T) { select { case <-ch: t.Fatal("timer fired after being stopped") - case <-time.After(middleDuration * 2): + case <-after(middleDuration * 2): } } func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after, sleep := stdClockWithAfterAndSleep() var lock sync.Mutex ch := make(chan struct{}) @@ -215,7 +300,7 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { job.Schedule(shortDuration) for i := 0; i < 10; i++ { // Sleep until the timer fires and gets blocked trying to take the lock. 
- time.Sleep(middleDuration) + sleep(middleDuration) job.Cancel() job.Schedule(shortDuration) } @@ -224,7 +309,7 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { // Wait for double the duration for the last timer to fire. select { case <-ch: - case <-time.After(middleDuration): + case <-after(middleDuration): t.Fatal("timed out waiting for timer to fire") } @@ -232,14 +317,14 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { select { case <-ch: t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): + case <-after(middleDuration): } } func TestManyJobReschedulesUnderLock(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after := stdClockWithAfter() var lock sync.Mutex ch := make(chan struct{}) @@ -255,7 +340,7 @@ func TestManyJobReschedulesUnderLock(t *testing.T) { // Wait for double the duration for the last timer to fire. select { case <-ch: - case <-time.After(middleDuration): + case <-after(middleDuration): t.Fatal("timed out waiting for timer to fire") } @@ -263,6 +348,6 @@ func TestManyJobReschedulesUnderLock(t *testing.T) { select { case <-ch: t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): + case <-after(middleDuration): } } diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 7e5c79776..bbc0e3ecc 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -38,3 +38,22 @@ go_library( "//pkg/waiter", ], ) + +go_test( + name = "icmp_x_test", + size = "small", + srcs = ["icmp_test.go"], + deps = [ + ":icmp", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", + "//pkg/waiter", + ], +) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 8afde7fca..cb316d27a 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -16,6 +16,7 @@ package icmp import ( "io" + "time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -26,14 +27,12 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// TODO(https://gvisor.dev/issues/5623): Unit test this package. - // +stateify savable type icmpPacket struct { icmpPacketEntry senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - timestamp int64 + receivedAt time.Time `state:".(int64)"` } type endpointState int @@ -133,7 +132,8 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */) + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) + e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, bindToDevice) } // Close the receive list and drain it. @@ -160,7 +160,7 @@ func (e *endpoint) Close() { } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (e *endpoint) ModerateRecvBuf(copied int) {} +func (*endpoint) ModerateRecvBuf(int) {} // SetOwner implements tcpip.Endpoint.SetOwner. 
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { @@ -193,7 +193,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult Total: p.data.Size(), ControlMessages: tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: p.timestamp, + Timestamp: p.receivedAt.UnixNano(), }, } if opts.NeedRemoteAddr { @@ -304,6 +304,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp // Reject destination address if it goes through a different // NIC than the endpoint was bound to. nicID := to.NIC + if nicID == 0 { + nicID = tcpip.NICID(e.ops.GetBindToDevice()) + } if e.BindNICID != 0 { if nicID != 0 && nicID != e.BindNICID { return 0, &tcpip.ErrNoRoute{} @@ -348,8 +351,15 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return int64(len(v)), nil } +var _ tcpip.SocketOptionsHandler = (*endpoint)(nil) + +// HasNIC implements tcpip.SocketOptionsHandler. +func (e *endpoint) HasNIC(id int32) bool { + return e.stack.HasNIC(tcpip.NICID(id)) +} + // SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { +func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error { return nil } @@ -390,7 +400,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { +func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } @@ -606,18 +616,19 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi return nil, nil, &tcpip.ErrNotSupported{} } -func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { +func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */) + err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice) return id, err } // We need to find a port for the endpoint. - _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { + _, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */) + err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice) switch err.(type) { case nil: return true, nil @@ -747,8 +758,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB switch e.NetProto { case header.IPv4ProtocolNumber: h := header.ICMPv4(pkt.TransportHeader().View()) - // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed - // after early parsing. 
if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -756,8 +765,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } case header.IPv6ProtocolNumber: h := header.ICMPv6(pkt.TransportHeader().View()) - // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed - // after early parsing. if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -800,7 +807,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() - packet.timestamp = e.stack.Clock().NowNanoseconds() + packet.receivedAt = e.stack.Clock().Now() e.rcvMu.Unlock() e.stats.PacketsReceived.Increment() diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index 28a56a2d5..b8b839e4a 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -15,11 +15,23 @@ package icmp import ( + "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// saveReceivedAt is invoked by stateify. +func (p *icmpPacket) saveReceivedAt() int64 { + return p.receivedAt.UnixNano() +} + +// loadReceivedAt is invoked by stateify. +func (p *icmpPacket) loadReceivedAt(nsec int64) { + p.receivedAt = time.Unix(0, nsec) +} + // saveData saves icmpPacket.data field. func (p *icmpPacket) saveData() buffer.VectorisedView { // We cannot save p.data directly as p.data.views may alias to p.views, diff --git a/pkg/tcpip/transport/icmp/icmp_test.go b/pkg/tcpip/transport/icmp/icmp_test.go new file mode 100644 index 000000000..cc950cbde --- /dev/null +++ b/pkg/tcpip/transport/icmp/icmp_test.go @@ -0,0 +1,235 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package icmp_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/waiter" +) + +// TODO(https://gvisor.dev/issues/5623): Finish unit testing the icmp package. +// See the issue for remaining areas of work. 
+ +var ( + localV4Addr1 = testutil.MustParse4("10.0.0.1") + localV4Addr2 = testutil.MustParse4("10.0.0.2") + remoteV4Addr = testutil.MustParse4("10.0.0.3") +) + +func addNICWithDefaultRoute(t *testing.T, s *stack.Stack, id tcpip.NICID, name string, addrV4 tcpip.Address) *channel.Endpoint { + t.Helper() + + ep := channel.New(1 /* size */, header.IPv4MinimumMTU, "" /* linkAddr */) + t.Cleanup(ep.Close) + + wep := stack.LinkEndpoint(ep) + if testing.Verbose() { + wep = sniffer.New(ep) + } + + opts := stack.NICOptions{Name: name} + if err := s.CreateNICWithOptions(id, wep, opts); err != nil { + t.Fatalf("s.CreateNIC(%d, _) = %s", id, err) + } + + if err := s.AddAddress(id, ipv4.ProtocolNumber, addrV4); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s) = %s", id, ipv4.ProtocolNumber, addrV4, err) + } + + s.AddRoute(tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: id, + }) + + return ep +} + +func writePayload(buf []byte) { + for i := range buf { + buf[i] = byte(i) + } +} + +func newICMPv4EchoRequest(payloadSize uint32) buffer.View { + buf := buffer.NewView(header.ICMPv4MinimumSize + int(payloadSize)) + writePayload(buf[header.ICMPv4MinimumSize:]) + + icmp := header.ICMPv4(buf) + icmp.SetType(header.ICMPv4Echo) + // No need to set the checksum; it is reset by the socket before the packet + // is sent. + + return buf +} + +// TestWriteUnboundWithBindToDevice exercises writing to an unbound ICMP socket +// when SO_BINDTODEVICE is set to the non-default NIC for that subnet. +// +// Only IPv4 is tested. The logic to determine which NIC to use is agnostic to +// the version of IP. +func TestWriteUnboundWithBindToDevice(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + HandleLocal: true, + }) + + // Add two NICs, both with default routes on the same subnet. The first NIC + // added will be the default NIC for that subnet. + defaultEP := addNICWithDefaultRoute(t, s, 1, "default", localV4Addr1) + alternateEP := addNICWithDefaultRoute(t, s, 2, "alternate", localV4Addr2) + + socket, err := s.NewEndpoint(icmp.ProtocolNumber4, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _) = %s", icmp.ProtocolNumber4, ipv4.ProtocolNumber, err) + } + defer socket.Close() + + echoPayloadSize := defaultEP.MTU() - header.IPv4MinimumSize - header.ICMPv4MinimumSize + + // Send a packet without SO_BINDTODEVICE. This verifies that the first NIC + // to be added is the default NIC to send packets when not explicitly bound. + { + buf := newICMPv4EchoRequest(echoPayloadSize) + r := buf.Reader() + n, err := socket.Write(&r, tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: remoteV4Addr}, + }) + if err != nil { + t.Fatalf("socket.Write(_, {To:%s}) = %s", remoteV4Addr, err) + } + if n != int64(len(buf)) { + t.Fatalf("got n = %d, want n = %d", n, len(buf)) + } + + // Verify the packet was sent out the default NIC. + p, ok := defaultEP.Read() + if !ok { + t.Fatalf("got defaultEP.Read(_) = _, false; want = _, true (packet wasn't written out)") + } + + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() + + checker.IPv4(t, b, []checker.NetworkChecker{ + checker.SrcAddr(localV4Addr1), + checker.DstAddr(remoteV4Addr), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4Echo), + checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]), + ), + }...) + + // Verify the packet was not sent out the alternate NIC. 
+ if p, ok := alternateEP.Read(); ok { + t.Fatalf("got alternateEP.Read(_) = %+v, true; want = _, false", p) + } + } + + // Send a packet with SO_BINDTODEVICE. This exercises reliance on + // SO_BINDTODEVICE to route the packet to the alternate NIC. + { + // Use SO_BINDTODEVICE to send over the alternate NIC by default. + socket.SocketOptions().SetBindToDevice(2) + + buf := newICMPv4EchoRequest(echoPayloadSize) + r := buf.Reader() + n, err := socket.Write(&r, tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: remoteV4Addr}, + }) + if err != nil { + t.Fatalf("socket.Write(_, {To:%s}) = %s", tcpip.Address(remoteV4Addr), err) + } + if n != int64(len(buf)) { + t.Fatalf("got n = %d, want n = %d", n, len(buf)) + } + + // Verify the packet was not sent out the default NIC. + if p, ok := defaultEP.Read(); ok { + t.Fatalf("got defaultEP.Read(_) = %+v, true; want = _, false", p) + } + + // Verify the packet was sent out the alternate NIC. + p, ok := alternateEP.Read() + if !ok { + t.Fatalf("got alternateEP.Read(_) = _, false; want = _, true (packet wasn't written out)") + } + + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() + + checker.IPv4(t, b, []checker.NetworkChecker{ + checker.SrcAddr(localV4Addr2), + checker.DstAddr(remoteV4Addr), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4Echo), + checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]), + ), + }...) + } + + // Send a packet with SO_BINDTODEVICE cleared. This verifies that clearing + // the device binding will fallback to using the default NIC to send + // packets. + { + socket.SocketOptions().SetBindToDevice(0) + + buf := newICMPv4EchoRequest(echoPayloadSize) + r := buf.Reader() + n, err := socket.Write(&r, tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: remoteV4Addr}, + }) + if err != nil { + t.Fatalf("socket.Write(_, {To:%s}) = %s", tcpip.Address(remoteV4Addr), err) + } + if n != int64(len(buf)) { + t.Fatalf("got n = %d, want n = %d", n, len(buf)) + } + + // Verify the packet was sent out the default NIC. + p, ok := defaultEP.Read() + if !ok { + t.Fatalf("got defaultEP.Read(_) = _, false; want = _, true (packet wasn't written out)") + } + + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() + + checker.IPv4(t, b, []checker.NetworkChecker{ + checker.SrcAddr(localV4Addr1), + checker.DstAddr(remoteV4Addr), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4Echo), + checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]), + ), + }...) + + // Verify the packet was not sent out the alternate NIC. + if p, ok := alternateEP.Read(); ok { + t.Fatalf("got alternateEP.Read(_) = %+v, true; want = _, false", p) + } + } +} diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 47f7dd1cb..fa82affc1 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -123,8 +123,6 @@ func (*protocol) Wait() {} // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - // TODO(gvisor.dev/issue/170): Implement parsing of ICMP. - // // Right now, the Parse() method is tied to enabled protocols passed into // stack.New. This works for UDP and TCP, but we handle ICMP traffic even // when netstack users don't pass ICMP as a supported protocol. 
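
Editor's note: the new icmp_test.go above (TestWriteUnboundWithBindToDevice) exercises the bind-to-device fallback added to the ICMP write path, where a destination with no explicit NIC now falls back to the SO_BINDTODEVICE value via e.ops.GetBindToDevice() before the bound-NIC check. Below is a minimal standalone sketch of that precedence; pickNIC is a hypothetical helper written only for illustration (not gVisor API), and the tail of the bound-NIC branch, which is outside the quoted hunk, is assumed to prefer the NIC the endpoint was bound to.

// Sketch of the NIC-selection precedence exercised by
// TestWriteUnboundWithBindToDevice. pickNIC is hypothetical; inputs mirror
// the diff: to.NIC from the caller, the endpoint's SO_BINDTODEVICE value,
// and the NIC the endpoint was bound to (0 when unbound).
package main

import "fmt"

type NICID int32

// pickNIC returns the NIC a write should use, or an error when the
// destination goes through a different NIC than the endpoint was bound to.
func pickNIC(toNIC, bindToDevice, boundNIC NICID) (NICID, error) {
	nic := toNIC
	if nic == 0 {
		// No explicit NIC on the destination address: fall back to
		// SO_BINDTODEVICE (0 means "not set").
		nic = bindToDevice
	}
	if boundNIC != 0 {
		if nic != 0 && nic != boundNIC {
			// Reject destinations routed through a different NIC than the
			// one the endpoint was bound to.
			return 0, fmt.Errorf("no route")
		}
		// Assumed behavior (not visible in the quoted hunk): prefer the
		// bound NIC when one is set.
		nic = boundNIC
	}
	return nic, nil
}

func main() {
	// Unbound socket, SO_BINDTODEVICE = 2: the write goes out NIC 2,
	// matching the "alternate NIC" case in the test.
	fmt.Println(pickNIC(0, 2, 0))
	// SO_BINDTODEVICE cleared: NIC selection is left to the route lookup,
	// matching the final case in the test.
	fmt.Println(pickNIC(0, 0, 0))
}
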
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 496eca581..8e7bb6c6e 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -27,6 +27,7 @@ package packet import ( "fmt" "io" + "time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -41,9 +42,8 @@ type packet struct { packetEntry // data holds the actual packet data, including any headers and // payload. - data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - // timestampNS is the unix time at which the packet was received. - timestampNS int64 + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + receivedAt time.Time `state:".(int64)"` // senderAddr is the network address of the sender. senderAddr tcpip.FullAddress // packetInfo holds additional information like the protocol @@ -159,7 +159,7 @@ func (ep *endpoint) Close() { } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (ep *endpoint) ModerateRecvBuf(copied int) {} +func (*endpoint) ModerateRecvBuf(int) {} // Read implements tcpip.Endpoint.Read. func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { @@ -189,7 +189,7 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul Total: packet.data.Size(), ControlMessages: tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: packet.timestampNS, + Timestamp: packet.receivedAt.UnixNano(), }, } if opts.NeedRemoteAddr { @@ -208,7 +208,6 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul } func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error) { - // TODO(gvisor.dev/issue/173): Implement. return 0, &tcpip.ErrInvalidOptionValue{} } @@ -220,19 +219,19 @@ func (*endpoint) Disconnect() tcpip.Error { // Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be // connected, and this function always returnes *tcpip.ErrNotSupported. -func (*endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { +func (*endpoint) Connect(tcpip.FullAddress) tcpip.Error { return &tcpip.ErrNotSupported{} } // Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used // with Shutdown, and this function always returns *tcpip.ErrNotSupported. -func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { +func (*endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { return &tcpip.ErrNotSupported{} } // Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with // Listen, and this function always returns *tcpip.ErrNotSupported. -func (*endpoint) Listen(backlog int) tcpip.Error { +func (*endpoint) Listen(int) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -244,8 +243,6 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi // Bind implements tcpip.Endpoint.Bind. func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { - // TODO(gvisor.dev/issue/173): Add Bind support. - // "By default, all packets of the specified protocol type are passed // to a packet socket. To get packets only from a specific interface // use bind(2) specifying an address in a struct sockaddr_ll to bind @@ -318,7 +315,7 @@ func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. 
-func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { +func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } @@ -339,7 +336,7 @@ func (ep *endpoint) UpdateLastError(err tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { +func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -385,7 +382,6 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, // Push new packet into receive list and increment the buffer size. var packet packet - // TODO(gvisor.dev/issue/173): Return network protocol. if !pkt.LinkHeader().View().IsEmpty() { // Get info directly from the ethernet header. hdr := header.Ethernet(pkt.LinkHeader().View()) @@ -424,7 +420,6 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, default: panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt)) } - } else { // Raw packets need their ethernet headers prepended before // queueing. @@ -451,7 +446,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views()) } } - packet.timestampNS = ep.stack.Clock().NowNanoseconds() + packet.receivedAt = ep.stack.Clock().Now() ep.rcvList.PushBack(&packet) ep.rcvBufSize += packet.data.Size() @@ -484,7 +479,7 @@ func (ep *endpoint) Stats() tcpip.EndpointStats { } // SetOwner implements tcpip.Endpoint.SetOwner. -func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {} +func (*endpoint) SetOwner(tcpip.PacketOwner) {} // SocketOptions implements tcpip.Endpoint.SocketOptions. func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index 5bd860d20..e729921db 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -15,11 +15,23 @@ package packet import ( + "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// saveReceivedAt is invoked by stateify. +func (p *packet) saveReceivedAt() int64 { + return p.receivedAt.UnixNano() +} + +// loadReceivedAt is invoked by stateify. +func (p *packet) loadReceivedAt(nsec int64) { + p.receivedAt = time.Unix(0, nsec) +} + // saveData saves packet.data field. func (p *packet) saveData() buffer.VectorisedView { // We cannot save p.data directly as p.data.views may alias to p.views, diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index bcec3d2e7..b6687911a 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -27,6 +27,7 @@ package raw import ( "io" + "time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -41,9 +42,8 @@ type rawPacket struct { rawPacketEntry // data holds the actual packet data, including any headers and // payload. - data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - // timestampNS is the unix time at which the packet was received. - timestampNS int64 + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + receivedAt time.Time `state:".(int64)"` // senderAddr is the network address of the sender. senderAddr tcpip.FullAddress } @@ -183,7 +183,7 @@ func (e *endpoint) Close() { } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. 
-func (e *endpoint) ModerateRecvBuf(copied int) {} +func (*endpoint) ModerateRecvBuf(int) {} func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.mu.Lock() @@ -219,7 +219,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult Total: pkt.data.Size(), ControlMessages: tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: pkt.timestampNS, + Timestamp: pkt.receivedAt.UnixNano(), }, } if opts.NeedRemoteAddr { @@ -286,26 +286,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return nil, nil, nil, &tcpip.ErrBadBuffer{} } - // If this is an unassociated socket and callee provided a nonzero - // destination address, route using that address. - if e.ops.GetHeaderIncluded() { - ip := header.IPv4(payloadBytes) - if !ip.IsValid(len(payloadBytes)) { - return nil, nil, nil, &tcpip.ErrInvalidOptionValue{} - } - dstAddr := ip.DestinationAddress() - // Update dstAddr with the address in the IP header, unless - // opts.To is set (e.g. if sendto specifies a specific - // address). - if dstAddr != tcpip.Address([]byte{0, 0, 0, 0}) && opts.To == nil { - opts.To = &tcpip.FullAddress{ - NIC: 0, // NIC is unset. - Addr: dstAddr, // The address from the payload. - Port: 0, // There are no ports here. - } - } - } - // Did the user caller provide a destination? If not, use the connected // destination. if opts.To == nil { @@ -402,7 +382,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { } // Find a route to the destination. - route, err := e.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, e.NetProto, false) + route, err := e.stack.FindRoute(nic, "", addr.Addr, e.NetProto, false) if err != nil { return err } @@ -428,7 +408,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { } // Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets. -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { +func (e *endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -439,7 +419,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { } // Listen implements tcpip.Endpoint.Listen. -func (*endpoint) Listen(backlog int) tcpip.Error { +func (*endpoint) Listen(int) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -513,12 +493,12 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } } -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { +func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. 
-func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { +func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } @@ -621,7 +601,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } combinedVV.Append(pkt.Data().ExtractVV()) packet.data = combinedVV - packet.timestampNS = e.stack.Clock().NowNanoseconds() + packet.receivedAt = e.stack.Clock().Now() e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 5d6f2709c..39669b445 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -15,11 +15,23 @@ package raw import ( + "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// saveReceivedAt is invoked by stateify. +func (p *rawPacket) saveReceivedAt() int64 { + return p.receivedAt.UnixNano() +} + +// loadReceivedAt is invoked by stateify. +func (p *rawPacket) loadReceivedAt(nsec int64) { + p.receivedAt = time.Unix(0, nsec) +} + // saveData saves rawPacket.data field. func (p *rawPacket) saveData() buffer.VectorisedView { // We cannot save p.data directly as p.data.views may alias to p.views, diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 0f20d3856..8436d2cf0 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -41,7 +41,6 @@ go_library( "protocol.go", "rack.go", "rcv.go", - "rcv_state.go", "reno.go", "reno_recovery.go", "sack.go", @@ -53,7 +52,6 @@ go_library( "segment_state.go", "segment_unsafe.go", "snd.go", - "snd_state.go", "tcp_endpoint_list.go", "tcp_segment_list.go", "timer.go", @@ -134,6 +132,7 @@ go_test( deps = [ "//pkg/sleep", "//pkg/tcpip/buffer", + "//pkg/tcpip/faketime", "//pkg/tcpip/stack", "@com_github_google_go_cmp//cmp:go_default_library", ], diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index d4bd4e80e..d807b13b7 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -114,8 +114,8 @@ type listenContext struct { } // timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. -func timeStamp() uint32 { - return uint32(time.Now().Unix()>>6) & tsMask +func timeStamp(clock tcpip.Clock) uint32 { + return uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Seconds()) >> 6 & tsMask } // newListenContext creates a new listen context. @@ -171,7 +171,7 @@ func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonc // createCookie creates a SYN cookie for the given id and incoming sequence // number. func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Value, data uint32) seqnum.Value { - ts := timeStamp() + ts := timeStamp(l.stack.Clock()) v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset) v += (l.cookieHash(id, ts, 1) + data) & hashMask return seqnum.Value(v) @@ -181,7 +181,7 @@ func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Va // sequence number. If it is, it also returns the data originally encoded in the // cookie when createCookie was called. 
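timeStamp above now reads the stack's clock instead of time.Now(): it takes the seconds elapsed since the monotonic epoch, shifts right by 6 (64-second buckets), and masks the result per the "8-bit timestamp" comment. A hedged sketch of the same bucketing, assuming the standard time import and using 0xff as a stand-in for the package's tsMask constant:

// cookieTS buckets an elapsed duration into 64-second units and keeps the
// low 8 bits, mirroring the SYN-cookie timestamp computed above.
func cookieTS(elapsed time.Duration) uint32 {
	return uint32(elapsed.Seconds()) >> 6 & 0xff
}

At 64 seconds per unit an 8-bit value wraps after roughly 4.5 hours, which is why isCookieValid below only accepts cookies whose timestamp difference stays within maxTSDiff.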
func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnum.Value, seq seqnum.Value) (uint32, bool) { - ts := timeStamp() + ts := timeStamp(l.stack.Clock()) v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq) cookieTS := v >> tsOffset if ((ts - cookieTS) & tsMask) > maxTSDiff { @@ -247,7 +247,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber - isn := generateSecureISN(s.id, l.stack.Seed()) + isn := generateSecureISN(s.id, l.stack.Clock(), l.stack.Seed()) ep, err := l.createConnectingEndpoint(s, opts, queue) if err != nil { return nil, err @@ -550,7 +550,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err e.rcvQueueInfo.rcvQueueMu.Lock() rcvClosed := e.rcvQueueInfo.RcvClosed e.rcvQueueInfo.rcvQueueMu.Unlock() - if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) { + if rcvClosed || s.flags.Contains(header.TCPFlagSyn|header.TCPFlagAck) { // If the endpoint is shutdown, reply with reset. // // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST @@ -560,6 +560,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err } switch { + case s.flags.Contains(header.TCPFlagRst): + e.stack.Stats().DroppedPackets.Increment() + return nil + case s.flags == header.TCPFlagSyn: if e.acceptQueueIsFull() { e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() @@ -591,7 +595,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, - TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), + TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset(e.stack.Rand())), TSEcr: opts.TSVal, MSS: calculateAdvertisedMSS(e.userMSS, route), } @@ -611,7 +615,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() return nil - case (s.flags & header.TCPFlagAck) != 0: + case s.flags.Contains(header.TCPFlagAck): if e.acceptQueueIsFull() { // Silently drop the ack as the application can't accept // the connection at this point. The ack will be @@ -736,6 +740,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err mss: rcvdSynOptions.MSS, }) + // Requeue the segment if the ACK completing the handshake has more info + // to be processed by the newly established endpoint.
+ if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && n.enqueueSegment(s) { + s.incRef() + n.newSegmentWaker.Assert() + } + // Do the delivery in a separate goroutine so // that we don't block the listen loop in case // the application is slow to accept or stops @@ -753,6 +764,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err return nil default: + e.stack.Stats().DroppedPackets.Increment() return nil } } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 5e03e7715..e39d1623d 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -19,7 +19,6 @@ import ( "math" "time" - "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -92,7 +91,7 @@ type handshake struct { rcvWndScale int // startTime is the time at which the first SYN/SYN-ACK was sent. - startTime time.Time + startTime tcpip.MonotonicTime // deferAccept if non-zero will drop the final ACK for a passive // handshake till an ACK segment with data is received or the timeout is @@ -147,21 +146,16 @@ func FindWndScale(wnd seqnum.Size) int { // resetState resets the state of the handshake object such that it becomes // ready for a new 3-way handshake. func (h *handshake) resetState() { - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(err) - } - h.state = handshakeSynSent h.flags = header.TCPFlagSyn h.ackNum = 0 h.mss = 0 - h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Seed()) + h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.stack.Seed()) } // generateSecureISN generates a secure Initial Sequence number based on the // recommendation here https://tools.ietf.org/html/rfc6528#page-3. -func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value { +func generateSecureISN(id stack.TransportEndpointID, clock tcpip.Clock, seed uint32) seqnum.Value { isnHasher := jenkins.Sum32(seed) isnHasher.Write([]byte(id.LocalAddress)) isnHasher.Write([]byte(id.RemoteAddress)) @@ -180,7 +174,7 @@ func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value { // // Which sort of guarantees that we won't reuse the ISN for a new // connection for the same tuple for at least 274s. - isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6) + isn := isnHasher.Sum32() + uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Nanoseconds()>>6) return seqnum.Value(isn) } @@ -212,7 +206,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea // a TCP 3-way handshake is valid. If it's not, a RST segment is sent back in // response. func (h *handshake) checkAck(s *segment) bool { - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber != h.iss+1 { + if s.flags.Contains(header.TCPFlagAck) && s.ackNumber != h.iss+1 { // RFC 793, page 36, states that a reset must be generated when // the connection is in any non-synchronized state and an // incoming segment acknowledges something not yet sent. The @@ -230,8 +224,8 @@ func (h *handshake) checkAck(s *segment) bool { func (h *handshake) synSentState(s *segment) tcpip.Error { // RFC 793, page 37, states that in the SYN-SENT state, a reset is // acceptable if the ack field acknowledges the SYN. 
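generateSecureISN above keeps the RFC 6528 shape, ISN = M + F(localip, localport, remoteip, remoteport, secretkey): F is the stack-seeded Jenkins hash of the connection 4-tuple and M now ticks off the injected monotonic clock in 64 ns units, which is where the "at least 274s" figure in the surrounding comment comes from (2^32 * 64 ns is about 274 s). A hedged, self-contained sketch of the same construction; hash/fnv stands in for the package's jenkins hash and all names here are illustrative:

package main

import (
	"encoding/binary"
	"hash/fnv"
	"time"
)

// secureISN mixes a keyed hash of the 4-tuple with a coarse time component so
// initial sequence numbers are unpredictable yet still advance over time.
func secureISN(localAddr, remoteAddr string, localPort, remotePort uint16, seed uint32, elapsed time.Duration) uint32 {
	h := fnv.New32a()
	var buf [4]byte
	binary.LittleEndian.PutUint32(buf[:], seed)
	h.Write(buf[:])
	h.Write([]byte(localAddr))
	h.Write([]byte(remoteAddr))
	binary.LittleEndian.PutUint16(buf[:2], localPort)
	h.Write(buf[:2])
	binary.LittleEndian.PutUint16(buf[:2], remotePort)
	h.Write(buf[:2])
	// One tick per 64 ns keeps the same tuple from reusing an ISN quickly.
	return h.Sum32() + uint32(elapsed.Nanoseconds()>>6)
}

func main() { _ = secureISN("192.0.2.1", "192.0.2.2", 49152, 80, 42, 3*time.Second) }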
- if s.flagIsSet(header.TCPFlagRst) { - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == h.iss+1 { + if s.flags.Contains(header.TCPFlagRst) { + if s.flags.Contains(header.TCPFlagAck) && s.ackNumber == h.iss+1 { // RFC 793, page 67, states that "If the RST bit is set [and] If the ACK // was acceptable then signal the user "error: connection reset", drop // the segment, enter CLOSED state, delete TCB, and return." @@ -249,7 +243,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error { // We are in the SYN-SENT state. We only care about segments that have // the SYN flag. - if !s.flagIsSet(header.TCPFlagSyn) { + if !s.flags.Contains(header.TCPFlagSyn) { return nil } @@ -270,7 +264,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error { // If this is a SYN ACK response, we only need to acknowledge the SYN // and the handshake is completed. - if s.flagIsSet(header.TCPFlagAck) { + if s.flags.Contains(header.TCPFlagAck) { h.state = handshakeCompleted h.ep.transitionToStateEstablishedLocked(h) @@ -316,7 +310,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error { // synRcvdState handles a segment received when the TCP 3-way handshake is in // the SYN-RCVD state. func (h *handshake) synRcvdState(s *segment) tcpip.Error { - if s.flagIsSet(header.TCPFlagRst) { + if s.flags.Contains(header.TCPFlagRst) { // RFC 793, page 37, states that in the SYN-RCVD state, a reset // is acceptable if the sequence number is in the window. if s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) { @@ -340,13 +334,13 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { return nil } - if s.flagIsSet(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 { + if s.flags.Contains(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 { // We received two SYN segments with different sequence // numbers, so we reset this and restart the whole // process, except that we don't reset the timer. ack := s.sequenceNumber.Add(s.logicalLen()) seq := seqnum.Value(0) - if s.flagIsSet(header.TCPFlagAck) { + if s.flags.Contains(header.TCPFlagAck) { seq = s.ackNumber } h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0) @@ -378,10 +372,10 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { // We have previously received (and acknowledged) the peer's SYN. If the // peer acknowledges our SYN, the handshake is completed. - if s.flagIsSet(header.TCPFlagAck) { + if s.flags.Contains(header.TCPFlagAck) { // If deferAccept is not zero and this is a bare ACK and the // timeout is not hit then drop the ACK. - if h.deferAccept != 0 && s.data.Size() == 0 && time.Since(h.startTime) < h.deferAccept { + if h.deferAccept != 0 && s.data.Size() == 0 && h.ep.stack.Clock().NowMonotonic().Sub(h.startTime) < h.deferAccept { h.acked = true h.ep.stack.Stats().DroppedPackets.Increment() return nil @@ -412,11 +406,11 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { h.ep.transitionToStateEstablishedLocked(h) - // If the segment has data then requeue it for the receiver - // to process it again once main loop is started. - if s.data.Size() > 0 { + // Requeue the segment if the ACK completing the handshake has more info + // to be processed by the newly established endpoint.
+ if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && h.ep.enqueueSegment(s) { s.incRef() - h.ep.enqueueSegment(s) + h.ep.newSegmentWaker.Assert() } return nil } @@ -426,7 +420,7 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { func (h *handshake) handleSegment(s *segment) tcpip.Error { h.sndWnd = s.window - if !s.flagIsSet(header.TCPFlagSyn) && h.sndWndScale > 0 { + if !s.flags.Contains(header.TCPFlagSyn) && h.sndWndScale > 0 { h.sndWnd <<= uint8(h.sndWndScale) } @@ -474,7 +468,7 @@ func (h *handshake) processSegments() tcpip.Error { // start sends the first SYN/SYN-ACK. It does not block, even if link address // resolution is required. func (h *handshake) start() { - h.startTime = time.Now() + h.startTime = h.ep.stack.Clock().NowMonotonic() h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) var sackEnabled tcpip.TCPSACKEnabled if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { @@ -527,7 +521,7 @@ func (h *handshake) complete() tcpip.Error { defer s.Done() // Initialize the resend timer. - timer, err := newBackoffTimer(time.Second, MaxRTO, resendWaker.Assert) + timer, err := newBackoffTimer(h.ep.stack.Clock(), time.Second, MaxRTO, resendWaker.Assert) if err != nil { return err } @@ -552,7 +546,7 @@ func (h *handshake) complete() tcpip.Error { // The last is required to provide a way for the peer to complete // the connection with another ACK or data (as ACKs are never // retransmitted on their own). - if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept { + if h.active || !h.acked || h.deferAccept != 0 && h.ep.stack.Clock().NowMonotonic().Sub(h.startTime) > h.deferAccept { h.ep.sendSynTCP(h.ep.route, tcpFields{ id: h.ep.TransportEndpointInfo.ID, ttl: h.ep.ttl, @@ -608,15 +602,15 @@ func (h *handshake) complete() tcpip.Error { type backoffTimer struct { timeout time.Duration maxTimeout time.Duration - t *time.Timer + t tcpip.Timer } -func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, tcpip.Error) { +func newBackoffTimer(clock tcpip.Clock, timeout, maxTimeout time.Duration, f func()) (*backoffTimer, tcpip.Error) { if timeout > maxTimeout { return nil, &tcpip.ErrTimeout{} } bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout} - bt.t = time.AfterFunc(timeout, f) + bt.t = clock.AfterFunc(timeout, f) return bt, nil } @@ -634,7 +628,7 @@ func (bt *backoffTimer) stop() { } func parseSynSegmentOptions(s *segment) header.TCPSynOptions { - synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck)) + synOpts := header.ParseSynOptions(s.options, s.flags.Contains(header.TCPFlagAck)) if synOpts.TS { s.parsedOptions.TSVal = synOpts.TSVal s.parsedOptions.TSEcr = synOpts.TSEcr @@ -915,30 +909,13 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, se return err } -func (e *endpoint) handleWrite() { - e.sndQueueInfo.sndQueueMu.Lock() - next := e.drainSendQueueLocked() - e.sndQueueInfo.sndQueueMu.Unlock() - - e.sendData(next) -} - -// Move packets from send queue to send list. -// -// Precondition: e.sndBufMu must be locked. -func (e *endpoint) drainSendQueueLocked() *segment { - first := e.sndQueueInfo.sndQueue.Front() - if first != nil { - e.snd.writeList.PushBackList(&e.sndQueueInfo.sndQueue) - e.sndQueueInfo.SndBufInQueue = 0 - } - return first -} - // Precondition: e.mu must be locked. func (e *endpoint) sendData(next *segment) { // Initialize the next segment to write if it's currently nil. 
if e.snd.writeNext == nil { + if next == nil { + return + } e.snd.writeNext = next } @@ -946,17 +923,6 @@ func (e *endpoint) sendData(next *segment) { e.snd.sendData() } -func (e *endpoint) handleClose() { - if !e.EndpointState().connected() { - return - } - // Drain the send queue. - e.handleWrite() - - // Mark send side as closed. - e.snd.Closed = true -} - // resetConnectionLocked puts the endpoint in an error state with the given // error code and sends a RST if and only if the error is not ErrConnectionReset // indicating that the connection is being reset due to receiving a RST. This @@ -1136,7 +1102,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err tcpip.Error) { func (e *endpoint) handleSegmentsLocked(fastPath bool) tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { - if e.EndpointState().closed() { + if state := e.EndpointState(); state.closed() || state == StateTimeWait { return nil } s := e.segmentQueue.dequeue() @@ -1188,11 +1154,11 @@ func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error) // the TCPEndpointState after the segment is processed. defer e.probeSegmentLocked() - if s.flagIsSet(header.TCPFlagRst) { + if s.flags.Contains(header.TCPFlagRst) { if ok, err := e.handleReset(s); !ok { return false, err } - } else if s.flagIsSet(header.TCPFlagSyn) { + } else if s.flags.Contains(header.TCPFlagSyn) { // See: https://tools.ietf.org/html/rfc5961#section-4.1 // 1) If the SYN bit is set, irrespective of the sequence number, TCP // MUST send an ACK (also referred to as challenge ACK) to the remote @@ -1216,7 +1182,7 @@ func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error) // should then rely on SYN retransmission from the remote end to // re-establish the connection. e.snd.maybeSendOutOfWindowAck(s) - } else if s.flagIsSet(header.TCPFlagAck) { + } else if s.flags.Contains(header.TCPFlagAck) { // Patch the window size in the segment according to the // send window scale. s.window <<= e.snd.SndWndScale @@ -1235,7 +1201,7 @@ func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error) // Now check if the received segment has caused us to transition // to a CLOSED state, if yes then terminate processing and do // not invoke the sender. - state := e.state + state := e.EndpointState() if state == StateClose { // When we get into StateClose while processing from the queue, // return immediately and let the protocolMainloop handle it. @@ -1267,7 +1233,7 @@ func (e *endpoint) keepaliveTimerExpired() tcpip.Error { // If a userTimeout is set then abort the connection if it is // exceeded. - if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 { + if userTimeout != 0 && e.stack.Clock().NowMonotonic().Sub(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 { e.keepalive.Unlock() e.stack.Stats().TCP.EstablishedTimedout.Increment() return &tcpip.ErrTimeout{} @@ -1322,7 +1288,7 @@ func (e *endpoint) disableKeepaliveTimer() { // segments. 
func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error { e.mu.Lock() - var closeTimer *time.Timer + var closeTimer tcpip.Timer var closeWaker sleep.Waker epilogue := func() { @@ -1408,14 +1374,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ { w: &e.sndQueueInfo.sndWaker, f: func() tcpip.Error { - e.handleWrite() - return nil - }, - }, - { - w: &e.sndQueueInfo.sndCloseWaker, - f: func() tcpip.Error { - e.handleClose() + e.sendData(nil /* next */) return nil }, }, @@ -1480,11 +1439,19 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ return &tcpip.ErrConnectionReset{} } - if n&notifyClose != 0 && closeTimer == nil { - if e.EndpointState() == StateFinWait2 && e.closed { + if n&notifyClose != 0 && e.closed { + switch e.EndpointState() { + case StateEstablished: + // Perform full shutdown if the endpoint is still + // established. This can occur when notifyClose + // was asserted just before becoming established. + e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead) + case StateFinWait2: // The socket has been closed and we are in FIN_WAIT2 // so start the FIN_WAIT2 timer. - closeTimer = time.AfterFunc(e.tcpLingerTimeout, closeWaker.Assert) + if closeTimer == nil { + closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert) + } } } @@ -1721,7 +1688,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { var timeWaitWaker sleep.Waker s.AddWaker(&timeWaitWaker, timeWaitDone) - timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert) + timeWaitTimer := e.stack.Clock().AfterFunc(timeWaitDuration, timeWaitWaker.Assert) defer timeWaitTimer.Stop() for { diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go index 962f1d687..6985194bb 100644 --- a/pkg/tcpip/transport/tcp/cubic.go +++ b/pkg/tcpip/transport/tcp/cubic.go @@ -41,7 +41,7 @@ type cubicState struct { func newCubicCC(s *sender) *cubicState { return &cubicState{ TCPCubicState: stack.TCPCubicState{ - T: time.Now(), + T: s.ep.stack.Clock().NowMonotonic(), Beta: 0.7, C: 0.4, }, @@ -60,7 +60,7 @@ func (c *cubicState) enterCongestionAvoidance() { // https://tools.ietf.org/html/rfc8312#section-4.8 if c.numCongestionEvents == 0 { c.K = 0 - c.T = time.Now() + c.T = c.s.ep.stack.Clock().NowMonotonic() c.WLastMax = c.WMax c.WMax = float64(c.s.SndCwnd) } @@ -115,14 +115,15 @@ func (c *cubicState) cubicCwnd(t float64) float64 { // getCwnd returns the current congestion window as computed by CUBIC. // Refer: https://tools.ietf.org/html/rfc8312#section-4 func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int { - elapsed := time.Since(c.T).Seconds() + elapsed := c.s.ep.stack.Clock().NowMonotonic().Sub(c.T) + elapsedSeconds := elapsed.Seconds() // Compute the window as per Cubic after 'elapsed' time // since last congestion event. - c.WC = c.cubicCwnd(elapsed - c.K) + c.WC = c.cubicCwnd(elapsedSeconds - c.K) // Compute the TCP friendly estimate of the congestion window. - c.WEst = c.WMax*c.Beta + (3.0*((1.0-c.Beta)/(1.0+c.Beta)))*(elapsed/srtt.Seconds()) + c.WEst = c.WMax*c.Beta + (3.0*((1.0-c.Beta)/(1.0+c.Beta)))*(elapsedSeconds/srtt.Seconds()) // Make sure in the TCP friendly region CUBIC performs at least // as well as Reno.
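For reference, the window getCwnd computes above follows RFC 8312: W_cubic(t) = C*(t - K)^3 + W_max with K = cbrt(W_max*(1 - beta)/C), so elapsedSeconds - c.K is the argument to that cubic. A hedged standalone sketch of the core curve only, using the constants set in newCubicCC above and assuming the standard math import; it omits the TCP-friendly estimate (WEst) and the Reno comparison applied alongside it:

// cubicWindow returns W_cubic(t) in packets for C = 0.4 and beta = 0.7, with
// t in seconds since the last congestion event and wMax the window size just
// before that event.
func cubicWindow(t, wMax float64) float64 {
	const c, beta = 0.4, 0.7
	k := math.Cbrt(wMax * (1 - beta) / c)
	return c*math.Pow(t-k, 3) + wMax
}

Switching from time.Since(c.T) to the stack clock's NowMonotonic keeps this arithmetic identical while letting tests control how fast t advances.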
@@ -134,7 +135,7 @@ func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int // In Concave/Convex region of CUBIC, calculate what CUBIC window // will be after 1 RTT and use that to grow congestion window // for every ack. - tEst := (time.Since(c.T) + srtt).Seconds() + tEst := (elapsed + srtt).Seconds() wtRtt := c.cubicCwnd(tEst - c.K) // As per 4.3 for each received ACK cwnd must be incremented // by (w_cubic(t+RTT) - cwnd/cwnd. @@ -151,7 +152,7 @@ func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int func (c *cubicState) HandleLossDetected() { // See: https://tools.ietf.org/html/rfc8312#section-4.5 c.numCongestionEvents++ - c.T = time.Now() + c.T = c.s.ep.stack.Clock().NowMonotonic() c.WLastMax = c.WMax c.WMax = float64(c.s.SndCwnd) @@ -162,7 +163,7 @@ func (c *cubicState) HandleLossDetected() { // HandleRTOExpired implements congestionContrl.HandleRTOExpired. func (c *cubicState) HandleRTOExpired() { // See: https://tools.ietf.org/html/rfc8312#section-4.6 - c.T = time.Now() + c.T = c.s.ep.stack.Clock().NowMonotonic() c.numCongestionEvents = 0 c.WLastMax = c.WMax c.WMax = float64(c.s.SndCwnd) @@ -193,7 +194,7 @@ func (c *cubicState) fastConvergence() { // PostRecovery implemements congestionControl.PostRecovery. func (c *cubicState) PostRecovery() { - c.T = time.Now() + c.T = c.s.ep.stack.Clock().NowMonotonic() } // reduceSlowStartThreshold returns new SsThresh as described in diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 512053a04..dff7cb89c 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -16,10 +16,11 @@ package tcp import ( "encoding/binary" + "math/rand" - "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -141,15 +142,16 @@ func (p *processor) start(wg *sync.WaitGroup) { // in-order. type dispatcher struct { processors []processor - seed uint32 - wg sync.WaitGroup + // seed is a random secret for a jenkins hash. 
+ seed uint32 + wg sync.WaitGroup } -func (d *dispatcher) init(nProcessors int) { +func (d *dispatcher) init(rng *rand.Rand, nProcessors int) { d.close() d.wait() d.processors = make([]processor, nProcessors) - d.seed = generateRandUint32() + d.seed = rng.Uint32() for i := range d.processors { p := &d.processors[i] p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker) @@ -172,12 +174,11 @@ func (d *dispatcher) wait() { d.wg.Wait() } -func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, clock tcpip.Clock, pkt *stack.PacketBuffer) { ep := stackEP.(*endpoint) - s := newIncomingSegment(id, pkt) + s := newIncomingSegment(id, clock, pkt) if !s.parse(pkt.RXTransportChecksumValidated) { - ep.stack.Stats().MalformedRcvdPackets.Increment() ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment() ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment() s.decRef() @@ -185,7 +186,6 @@ func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.Trans } if !s.csumValid { - ep.stack.Stats().MalformedRcvdPackets.Increment() ep.stack.Stats().TCP.ChecksumErrors.Increment() ep.stats.ReceiveErrors.ChecksumErrors.Increment() s.decRef() @@ -213,14 +213,6 @@ func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.Trans d.selectProcessor(id).queueEndpoint(ep) } -func generateRandUint32() uint32 { - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(err) - } - return binary.LittleEndian.Uint32(b) -} - func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor { var payload [4]byte binary.LittleEndian.PutUint16(payload[0:], id.LocalPort) diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index f148d505d..5342aacfd 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -421,7 +421,7 @@ func testV4Accept(t *testing.T, c *context.Context) { r.Reset(data) nep.Write(&r, tcpip.WriteOptions{}) b = c.GetPacket() - tcp = header.TCP(header.IPv4(b).Payload()) + tcp = header.IPv4(b).Payload() if string(tcp.Payload()) != data { t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 50d39cbad..4acddc959 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -20,12 +20,12 @@ import ( "fmt" "io" "math" + "math/rand" "runtime" "strings" "sync/atomic" "time" - "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -38,19 +38,15 @@ import ( ) // EndpointState represents the state of a TCP endpoint. -type EndpointState uint32 +type EndpointState tcpip.EndpointState // Endpoint states. Note that are represented in a netstack-specific manner and // may not be meaningful externally. Specifically, they need to be translated to // Linux's representation for these states if presented to userspace. const ( - // Endpoint states internal to netstack. These map to the TCP state CLOSED. - StateInitial EndpointState = iota - StateBound - StateConnecting // Connect() called, but the initial SYN hasn't been sent. - StateError - - // TCP protocol states. 
+ _ EndpointState = iota + // TCP protocol states in sync with the definitions in + // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/tcp_states.h#L13 StateEstablished StateSynSent StateSynRecv @@ -62,6 +58,12 @@ const ( StateLastAck StateListen StateClosing + + // Endpoint states internal to netstack. + StateInitial + StateBound + StateConnecting // Connect() called, but the initial SYN hasn't been sent. + StateError ) const ( @@ -97,6 +99,16 @@ func (s EndpointState) connecting() bool { } } +// internal returns true when the state is netstack internal. +func (s EndpointState) internal() bool { + switch s { + case StateInitial, StateBound, StateConnecting, StateError: + return true + default: + return false + } +} + // handshake returns true when s is one of the states representing an endpoint // in the middle of a TCP handshake. func (s EndpointState) handshake() bool { @@ -281,16 +293,9 @@ type sndQueueInfo struct { sndQueueMu sync.Mutex `state:"nosave"` stack.TCPSndBufState - // sndQueue holds segments that are ready to be sent. - sndQueue segmentList `state:"wait"` - - // sndWaker is used to signal the protocol goroutine when segments are - // added to the `sndQueue`. + // sndWaker is used to signal the protocol goroutine when there may be + // segments that need to be sent. sndWaker sleep.Waker `state:"manual"` - - // sndCloseWaker is used to notify the protocol goroutine when the send - // side is closed. - sndCloseWaker sleep.Waker `state:"manual"` } // rcvQueueInfo contains the endpoint's rcvQueue and associated metadata. @@ -422,12 +427,12 @@ type endpoint struct { // state must be read/set using the EndpointState()/setEndpointState() // methods. - state EndpointState `state:".(EndpointState)"` + state uint32 `state:".(EndpointState)"` // origEndpointState is only used during a restore phase to save the // endpoint state at restore time as the socket is moved to it's correct // state. - origEndpointState EndpointState `state:"nosave"` + origEndpointState uint32 `state:"nosave"` isPortReserved bool `state:"manual"` isRegistered bool `state:"manual"` @@ -468,7 +473,7 @@ type endpoint struct { // recentTSTime is the unix time when we last updated // TCPEndpointStateInner.RecentTS. - recentTSTime time.Time `state:".(unixTime)"` + recentTSTime tcpip.MonotonicTime // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -626,7 +631,7 @@ type endpoint struct { // lastOutOfWindowAckTime is the time at which the an ACK was sent in response // to an out of window segment being received by this endpoint. - lastOutOfWindowAckTime time.Time `state:".(unixTime)"` + lastOutOfWindowAckTime tcpip.MonotonicTime } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -747,7 +752,7 @@ func (e *endpoint) ResumeWork() { // // Precondition: e.mu must be held to call this method. func (e *endpoint) setEndpointState(state EndpointState) { - oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + oldstate := EndpointState(atomic.LoadUint32(&e.state)) switch state { case StateEstablished: e.stack.Stats().TCP.CurrentEstablished.Increment() @@ -764,18 +769,18 @@ func (e *endpoint) setEndpointState(state EndpointState) { e.stack.Stats().TCP.CurrentEstablished.Decrement() } } - atomic.StoreUint32((*uint32)(&e.state), uint32(state)) + atomic.StoreUint32(&e.state, uint32(state)) } // EndpointState returns the current state of the endpoint. 
func (e *endpoint) EndpointState() EndpointState { - return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + return EndpointState(atomic.LoadUint32(&e.state)) } // setRecentTimestamp sets the recentTS field to the provided value. func (e *endpoint) setRecentTimestamp(recentTS uint32) { e.RecentTS = recentTS - e.recentTSTime = time.Now() + e.recentTSTime = e.stack.Clock().NowMonotonic() } // recentTimestamp returns the value of the recentTS field. @@ -806,11 +811,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue }, sndQueueInfo: sndQueueInfo{ TCPSndBufState: stack.TCPSndBufState{ - SndMTU: int(math.MaxInt32), + SndMTU: math.MaxInt32, }, }, waiterQueue: waiterQueue, - state: StateInitial, + state: uint32(StateInitial), keepalive: keepalive{ // Linux defaults. idle: 2 * time.Hour, @@ -870,9 +875,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue } e.segmentQueue.ep = e - e.TSOffset = timeStampOffset() + e.TSOffset = timeStampOffset(e.stack.Rand()) e.acceptCond = sync.NewCond(&e.acceptMu) - e.keepalive.timer.init(&e.keepalive.waker) + e.keepalive.timer.init(e.stack.Clock(), &e.keepalive.waker) return e } @@ -1189,7 +1194,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvQueueInfo.rcvQueueMu.Unlock() return } - now := time.Now() + now := e.stack.Clock().NowMonotonic() if rtt := e.rcvQueueInfo.RcvAutoParams.RTT; rtt == 0 || now.Sub(e.rcvQueueInfo.RcvAutoParams.MeasureTime) < rtt { e.rcvQueueInfo.RcvAutoParams.CopiedBytes += copied e.rcvQueueInfo.rcvQueueMu.Unlock() @@ -1544,12 +1549,11 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } // Add data to the send queue. - s := newOutgoingSegment(e.TransportEndpointInfo.ID, v) + s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), v) e.sndQueueInfo.SndBufUsed += len(v) - e.sndQueueInfo.SndBufInQueue += seqnum.Size(len(v)) - e.sndQueueInfo.sndQueue.PushBack(s) + e.snd.writeList.PushBack(s) - return e.drainSendQueueLocked(), len(v), nil + return s, len(v), nil }() // Return if either we didn't queue anything or if an error occurred while // attempting to queue data. @@ -1956,6 +1960,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { func (e *endpoint) getTCPInfo() tcpip.TCPInfoOption { info := tcpip.TCPInfoOption{} e.LockUser() + if state := e.EndpointState(); state.internal() { + info.State = tcpip.EndpointState(StateClose) + } else { + info.State = tcpip.EndpointState(state) + } snd := e.snd if snd != nil { // We do not calculate RTT before sending the data packets. If @@ -2198,7 +2207,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp BindToDevice: bindToDevice, Dest: addr, } - if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil { + if _, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */); err != nil { if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse { return false, nil } @@ -2224,7 +2233,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp // If the endpoint is not in TIME-WAIT or if it is in TIME-WAIT but // less than 1 second has elapsed since its recentTS was updated then // we cannot reuse the port. 
- if tcpEP.EndpointState() != StateTimeWait || time.Since(tcpEP.recentTSTime) < 1*time.Second { + if tcpEP.EndpointState() != StateTimeWait || e.stack.Clock().NowMonotonic().Sub(tcpEP.recentTSTime) < 1*time.Second { tcpEP.UnlockUser() return false, nil } @@ -2245,7 +2254,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp BindToDevice: bindToDevice, Dest: addr, } - if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil { + if _, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */); err != nil { return false, nil } } @@ -2297,7 +2306,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp // connection setting here. if !handshake { e.segmentQueue.mu.Lock() - for _, l := range []segmentList{e.segmentQueue.list, e.sndQueueInfo.sndQueue, e.snd.writeList} { + for _, l := range []segmentList{e.segmentQueue.list, e.snd.writeList} { for s := l.Front(); s != nil; s = s.Next() { s.id = e.TransportEndpointInfo.ID e.sndQueueInfo.sndWaker.Assert() @@ -2355,6 +2364,9 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { e.notifyProtocolGoroutine(notifyTickleWorker) return nil } + // Wake up any readers that may be waiting for the stream to become + // readable. + e.waiterQueue.Notify(waiter.ReadableEvents) } // Close for write. @@ -2370,13 +2382,21 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { } // Queue fin segment. - s := newOutgoingSegment(e.TransportEndpointInfo.ID, nil) - e.sndQueueInfo.sndQueue.PushBack(s) - e.sndQueueInfo.SndBufInQueue++ + s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), nil) + e.snd.writeList.PushBack(s) // Mark endpoint as closed. e.sndQueueInfo.SndClosed = true e.sndQueueInfo.sndQueueMu.Unlock() - e.handleClose() + + // Drain the send queue. + e.sendData(s) + + // Mark send side as closed. + e.snd.Closed = true + + // Wake up any writers that may be waiting for the stream to become + // writable. + e.waiterQueue.Notify(waiter.WritableEvents) } return nil @@ -2581,7 +2601,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { BindToDevice: bindToDevice, Dest: tcpip.FullAddress{}, } - port, err := e.stack.ReservePort(portRes, func(p uint16) (bool, tcpip.Error) { + port, err := e.stack.ReservePort(e.stack.Rand(), portRes, func(p uint16) (bool, tcpip.Error) { id := e.TransportEndpointInfo.ID id.LocalPort = p // CheckRegisterTransportEndpoint should only return an error if there is a @@ -2731,7 +2751,7 @@ func (e *endpoint) updateSndBufferUsage(v int) { // We only notify when there is half the sendBufferSize available after // a full buffer event occurs. This ensures that we don't wake up // writers to queue just 1-2 segments and go back to sleep. - notify = notify && e.sndQueueInfo.SndBufUsed < int(sendBufferSize)>>1 + notify = notify && e.sndQueueInfo.SndBufUsed < sendBufferSize>>1 e.sndQueueInfo.sndQueueMu.Unlock() if notify { @@ -2848,23 +2868,20 @@ func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { // timestamp returns the timestamp value to be used in the TSVal field of the // timestamp option for outgoing TCP segments for a given endpoint. func (e *endpoint) timestamp() uint32 { - return tcpTimeStamp(time.Now(), e.TSOffset) + return tcpTimeStamp(e.stack.Clock().NowMonotonic(), e.TSOffset) } // tcpTimeStamp returns a timestamp offset by the provided offset.
This is // not inlined above as it's used when SYN cookies are in use and endpoint // is not created at the time when the SYN cookie is sent. -func tcpTimeStamp(curTime time.Time, offset uint32) uint32 { - return uint32(curTime.Unix()*1000+int64(curTime.Nanosecond()/1e6)) + offset +func tcpTimeStamp(curTime tcpip.MonotonicTime, offset uint32) uint32 { + d := curTime.Sub(tcpip.MonotonicTime{}) + return uint32(d.Milliseconds()) + offset } // timeStampOffset returns a randomized timestamp offset to be used when sending // timestamp values in a timestamp option for a TCP segment. -func timeStampOffset() uint32 { - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(err) - } +func timeStampOffset(rng *rand.Rand) uint32 { // Initialize a random tsOffset that will be added to the recentTS // everytime the timestamp is sent when the Timestamp option is enabled. // @@ -2874,7 +2891,7 @@ func timeStampOffset() uint32 { // NOTE: This is not completely to spec as normally this should be // initialized in a manner analogous to how sequence numbers are // randomized per connection basis. But for now this is sufficient. - return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 + return rng.Uint32() } // maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint @@ -2909,7 +2926,7 @@ func (e *endpoint) completeStateLocked() stack.TCPEndpointState { s := stack.TCPEndpointState{ TCPEndpointStateInner: e.TCPEndpointStateInner, ID: stack.TCPEndpointID(e.TransportEndpointInfo.ID), - SegTime: time.Now(), + SegTime: e.stack.Clock().NowMonotonic(), Receiver: e.rcv.TCPReceiverState, Sender: e.snd.TCPSenderState, } @@ -2937,7 +2954,7 @@ func (e *endpoint) completeStateLocked() stack.TCPEndpointState { if cubic, ok := e.snd.cc.(*cubicState); ok { s.Sender.Cubic = cubic.TCPCubicState - s.Sender.Cubic.TimeSinceLastCongestion = time.Since(s.Sender.Cubic.T) + s.Sender.Cubic.TimeSinceLastCongestion = e.stack.Clock().NowMonotonic().Sub(s.Sender.Cubic.T) } s.Sender.RACKState = e.snd.rc.TCPRACKState @@ -3029,14 +3046,16 @@ func GetTCPSendBufferLimits(s tcpip.StackHandler) tcpip.SendBufferSizeOption { // allowOutOfWindowAck returns true if an out-of-window ACK can be sent now. func (e *endpoint) allowOutOfWindowAck() bool { - var limit stack.TCPInvalidRateLimitOption - if err := e.stack.Option(&limit); err != nil { - panic(fmt.Sprintf("e.stack.Option(%+v) failed with error: %s", limit, err)) - } + now := e.stack.Clock().NowMonotonic() - now := time.Now() - if now.Sub(e.lastOutOfWindowAckTime) < time.Duration(limit) { - return false + if e.lastOutOfWindowAckTime != (tcpip.MonotonicTime{}) { + var limit stack.TCPInvalidRateLimitOption + if err := e.stack.Option(&limit); err != nil { + panic(fmt.Sprintf("e.stack.Option(%+v) failed with error: %s", limit, err)) + } + if now.Sub(e.lastOutOfWindowAckTime) < time.Duration(limit) { + return false + } } e.lastOutOfWindowAckTime = now diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 6e9777fe4..952ccacdd 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -154,20 +154,25 @@ func (e *endpoint) afterLoad() { e.origEndpointState = e.state // Restore the endpoint to InitialState as it will be moved to // its origEndpointState during Resume. - e.state = StateInitial + e.state = uint32(StateInitial) // Condition variables and mutexs are not S/R'ed so reinitialize // acceptCond with e.acceptMu. 
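tcpTimeStamp and timeStampOffset above implement the usual timestamp-option scheme: TSVal is the number of milliseconds since the monotonic epoch plus a per-endpoint random offset (now drawn from the stack's seeded rng) so peers cannot correlate clocks across connections. A minimal sketch of the resulting wire value, assuming the standard time import; the offset parameter stands in for e.TSOffset:

// tsVal mirrors tcpTimeStamp above: elapsed milliseconds since the clock's
// epoch, offset by a per-connection random constant.
func tsVal(elapsed time.Duration, offset uint32) uint32 {
	return uint32(elapsed.Milliseconds()) + offset
}

At one tick per millisecond the 32-bit value wraps after about 49.7 days, well within what the timestamp option (RFC 7323) is designed to tolerate.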
e.acceptCond = sync.NewCond(&e.acceptMu) - e.keepalive.timer.init(&e.keepalive.waker) stack.StackFromEnv.RegisterRestoredEndpoint(e) } // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.keepalive.timer.init(s.Clock(), &e.keepalive.waker) + if snd := e.snd; snd != nil { + snd.resendTimer.init(s.Clock(), &snd.resendWaker) + snd.reorderTimer.init(s.Clock(), &snd.reorderWaker) + snd.probeTimer.init(s.Clock(), &snd.probeWaker) + } e.stack = s e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) e.segmentQueue.thaw() - epState := e.origEndpointState + epState := EndpointState(e.origEndpointState) switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss tcpip.TCPSendBufferSizeRangeOption @@ -281,32 +286,12 @@ func (e *endpoint) Resume(s *stack.Stack) { }() case epState == StateClose: e.isPortReserved = false - e.state = StateClose + e.state = uint32(StateClose) e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) case epState == StateError: - e.state = StateError + e.state = uint32(StateError) e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } } - -// saveRecentTSTime is invoked by stateify. -func (e *endpoint) saveRecentTSTime() unixTime { - return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()} -} - -// loadRecentTSTime is invoked by stateify. -func (e *endpoint) loadRecentTSTime(unix unixTime) { - e.recentTSTime = time.Unix(unix.second, unix.nano) -} - -// saveLastOutOfWindowAckTime is invoked by stateify. -func (e *endpoint) saveLastOutOfWindowAckTime() unixTime { - return unixTime{e.lastOutOfWindowAckTime.Unix(), e.lastOutOfWindowAckTime.UnixNano()} -} - -// loadLastOutOfWindowAckTime is invoked by stateify. -func (e *endpoint) loadLastOutOfWindowAckTime(unix unixTime) { - e.lastOutOfWindowAckTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 2f9fe7ee0..65c86823a 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -65,7 +65,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { - s := newIncomingSegment(id, pkt) + s := newIncomingSegment(id, f.stack.Clock(), pkt) defer s.decRef() // We only care about well-formed SYN packets. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index a3d1aa1a3..2fc282e73 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -131,7 +131,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) { // goroutine which is responsible for dequeuing and doing full TCP dispatch of // the packet. func (p *protocol) QueuePacket(ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { - p.dispatcher.queuePacket(ep, id, pkt) + p.dispatcher.queuePacket(ep, id, p.stack.Clock(), pkt) } // HandleUnknownDestinationPacket handles packets targeted at this protocol but @@ -142,14 +142,14 @@ func (p *protocol) QueuePacket(ep stack.TransportEndpoint, id stack.TransportEnd // particular, SYNs addressed to a non-existent connection are rejected by this // means." 
func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { - s := newIncomingSegment(id, pkt) + s := newIncomingSegment(id, p.stack.Clock(), pkt) defer s.decRef() if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid { return stack.UnknownDestinationPacketMalformed } - if !s.flagIsSet(header.TCPFlagRst) { + if !s.flags.Contains(header.TCPFlagRst) { replyWithReset(p.stack, s, stack.DefaultTOS, 0) } @@ -181,7 +181,7 @@ func replyWithReset(st *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error { // reset has sequence number zero and the ACK field is set to the sum // of the sequence number and segment length of the incoming segment. // The connection remains in the CLOSED state. - if s.flagIsSet(header.TCPFlagAck) { + if s.flags.Contains(header.TCPFlagAck) { seq = s.ackNumber } else { flags |= header.TCPFlagAck @@ -401,7 +401,7 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Er case *tcpip.TCPTimeWaitReuseOption: p.mu.RLock() - *v = tcpip.TCPTimeWaitReuseOption(p.timeWaitReuse) + *v = p.timeWaitReuse p.mu.RUnlock() return nil @@ -481,6 +481,6 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol { // TODO(gvisor.dev/issue/5243): Set recovery to tcpip.TCPRACKLossDetection. recovery: 0, } - p.dispatcher.init(runtime.GOMAXPROCS(0)) + p.dispatcher.init(s.Rand(), runtime.GOMAXPROCS(0)) return &p } diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index 9e332dcf7..0da4eafaa 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -79,7 +79,7 @@ func (rc *rackControl) init(snd *sender, iss seqnum.Value) { // update will update the RACK related fields when an ACK has been received. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2 func (rc *rackControl) update(seg *segment, ackSeg *segment) { - rtt := time.Now().Sub(seg.xmitTime) + rtt := rc.snd.ep.stack.Clock().NowMonotonic().Sub(seg.xmitTime) tsOffset := rc.snd.ep.TSOffset // If the ACK is for a retransmitted packet, do not update if it is a @@ -115,7 +115,7 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) { // ending sequence number of the packet which has been acknowledged // most recently. endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - if rc.XmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) { + if rc.XmitTime.Before(seg.xmitTime) || (seg.xmitTime == rc.XmitTime && rc.EndSequence.LessThan(endSeq)) { rc.XmitTime = seg.xmitTime rc.EndSequence = endSeq } @@ -174,7 +174,7 @@ func (s *sender) schedulePTO() { } s.rtt.Unlock() - now := time.Now() + now := s.ep.stack.Clock().NowMonotonic() if s.resendTimer.enabled() { if now.Add(pto).After(s.resendTimer.target) { pto = s.resendTimer.target.Sub(now) @@ -279,7 +279,7 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) { // been observed RACK uses reo_wnd of zero during loss recovery, in order to // retransmit quickly, or when the number of DUPACKs exceeds the classic // DUPACKthreshold. -func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { +func (rc *rackControl) updateRACKReorderWindow() { dsackSeen := rc.DSACKSeen snd := rc.snd @@ -352,7 +352,7 @@ func (rc *rackControl) exitRecovery() { // detectLoss marks the segment as lost if the reordering window has elapsed // and the ACK is not received. It will also arm the reorder timer. 
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 Step 5. -func (rc *rackControl) detectLoss(rcvTime time.Time) int { +func (rc *rackControl) detectLoss(rcvTime tcpip.MonotonicTime) int { var timeout time.Duration numLost := 0 for seg := rc.snd.writeList.Front(); seg != nil && seg.xmitCount != 0; seg = seg.Next() { @@ -366,7 +366,7 @@ func (rc *rackControl) detectLoss(rcvTime time.Time) int { } endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - if seg.xmitTime.Before(rc.XmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) { + if seg.xmitTime.Before(rc.XmitTime) || (seg.xmitTime == rc.XmitTime && rc.EndSequence.LessThan(endSeq)) { timeRemaining := seg.xmitTime.Sub(rcvTime) + rc.RTT + rc.ReoWnd if timeRemaining <= 0 { seg.lost = true @@ -392,7 +392,7 @@ func (rc *rackControl) reorderTimerExpired() tcpip.Error { return nil } - numLost := rc.detectLoss(time.Now()) + numLost := rc.detectLoss(rc.snd.ep.stack.Clock().NowMonotonic()) if numLost == 0 { return nil } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 133371455..9ce8fcae9 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -17,7 +17,6 @@ package tcp import ( "container/heap" "math" - "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -50,7 +49,7 @@ type receiver struct { pendingRcvdSegments segmentHeap // Time when the last ack was received. - lastRcvdAckTime time.Time `state:".(unixTime)"` + lastRcvdAckTime tcpip.MonotonicTime } func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { @@ -63,7 +62,7 @@ func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale }, rcvWnd: rcvWnd, rcvWUP: irs + 1, - lastRcvdAckTime: time.Now(), + lastRcvdAckTime: ep.stack.Clock().NowMonotonic(), } } @@ -137,9 +136,9 @@ func (r *receiver) getSendParams() (RcvNxt seqnum.Value, rcvWnd seqnum.Size) { // rcvWUP RcvNxt RcvAcc new RcvAcc // <=====curWnd ===> // <========= newWnd > curWnd ========= > - if r.RcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.RcvNxt.Add(seqnum.Size(newWnd))) && toGrow { + if r.RcvNxt.Add(curWnd).LessThan(r.RcvNxt.Add(newWnd)) && toGrow { // If the new window moves the right edge, then update RcvAcc. - r.RcvAcc = r.RcvNxt.Add(seqnum.Size(newWnd)) + r.RcvAcc = r.RcvNxt.Add(newWnd) } else { if newWnd == 0 { // newWnd is zero but we can't advertise a zero as it would cause window @@ -245,7 +244,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum TrimSACKBlockList(&r.ep.sack, r.RcvNxt) // Handle FIN or FIN-ACK. - if s.flagIsSet(header.TCPFlagFin) { + if s.flags.Contains(header.TCPFlagFin) { r.RcvNxt++ // Send ACK immediately. @@ -261,7 +260,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum case StateEstablished: r.ep.setEndpointState(StateCloseWait) case StateFinWait1: - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { + if s.flags.Contains(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { // FIN-ACK, transition to TIME-WAIT. r.ep.setEndpointState(StateTimeWait) } else { @@ -296,7 +295,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Handle ACK (not FIN-ACK, which we handled above) during one of the // shutdown states. 
- if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { + if s.flags.Contains(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { switch r.ep.EndpointState() { case StateFinWait1: r.ep.setEndpointState(StateFinWait2) @@ -325,9 +324,9 @@ func (r *receiver) updateRTT() { // is first acknowledged and the receipt of data that is at least one // window beyond the sequence number that was acknowledged. r.ep.rcvQueueInfo.rcvQueueMu.Lock() - if r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime.IsZero() { + if r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime == (tcpip.MonotonicTime{}) { // New measurement. - r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = r.ep.stack.Clock().NowMonotonic() r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd) r.ep.rcvQueueInfo.rcvQueueMu.Unlock() return @@ -336,14 +335,14 @@ func (r *receiver) updateRTT() { r.ep.rcvQueueInfo.rcvQueueMu.Unlock() return } - rtt := time.Since(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime) + rtt := r.ep.stack.Clock().NowMonotonic().Sub(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime) // We only store the minimum observed RTT here as this is only used in // absence of a SRTT available from either timestamps or a sender // measurement of RTT. if r.ep.rcvQueueInfo.RcvAutoParams.RTT == 0 || rtt < r.ep.rcvQueueInfo.RcvAutoParams.RTT { r.ep.rcvQueueInfo.RcvAutoParams.RTT = rtt } - r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = r.ep.stack.Clock().NowMonotonic() r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd) r.ep.rcvQueueInfo.rcvQueueMu.Unlock() } @@ -424,7 +423,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // while the FIN is considered to occur after // the last actual data octet in a segment in // which it occurs. - if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.RcvNxt+1) { + if closed && (!s.flags.Contains(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.RcvNxt+1) { return true, &tcpip.ErrConnectionAborted{} } } @@ -467,11 +466,11 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { } // Store the time of the last ack. - r.lastRcvdAckTime = time.Now() + r.lastRcvdAckTime = r.ep.stack.Clock().NowMonotonic() // Defer segment processing if it can't be consumed now. if !r.consumeSegment(s, segSeq, segLen) { - if segLen > 0 || s.flagIsSet(header.TCPFlagFin) { + if segLen > 0 || s.flags.Contains(header.TCPFlagFin) { // We only store the segment if it's within our buffer size limit. // // Only use 75% of the receive buffer queue for out-of-order @@ -539,7 +538,7 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // // As we do not yet support PAWS, we are being conservative in ignoring // RSTs by default. - if s.flagIsSet(header.TCPFlagRst) { + if s.flags.Contains(header.TCPFlagRst) { return false, false } @@ -559,13 +558,12 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // (2) returns to TIME-WAIT state if the SYN turns out // to be an old duplicate". - if s.flagIsSet(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) { - + if s.flags.Contains(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) { return false, true } // Drop the segment if it does not contain an ACK. 
- if !s.flagIsSet(header.TCPFlagAck) { + if !s.flags.Contains(header.TCPFlagAck) { return false, false } @@ -574,7 +572,7 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.MaxSentAck, segSeq) } - if segSeq.Add(1) == r.RcvNxt && s.flagIsSet(header.TCPFlagFin) { + if segSeq.Add(1) == r.RcvNxt && s.flags.Contains(header.TCPFlagFin) { // If it's a FIN-ACK then resetTimeWait and send an ACK, as it // indicates our final ACK could have been lost. r.ep.snd.sendAck() diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 7e5ba6ef7..ca78c96f2 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -17,7 +17,6 @@ package tcp import ( "fmt" "sync/atomic" - "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -73,9 +72,9 @@ type segment struct { parsedOptions header.TCPOptions options []byte `state:".([]byte)"` hasNewSACKInfo bool - rcvdTime time.Time `state:".(unixTime)"` + rcvdTime tcpip.MonotonicTime // xmitTime is the last transmit time of this segment. - xmitTime time.Time `state:".(unixTime)"` + xmitTime tcpip.MonotonicTime xmitCount uint32 // acked indicates if the segment has already been SACKed. @@ -88,7 +87,7 @@ type segment struct { lost bool } -func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { +func newIncomingSegment(id stack.TransportEndpointID, clock tcpip.Clock, pkt *stack.PacketBuffer) *segment { netHdr := pkt.Network() s := &segment{ refCnt: 1, @@ -100,17 +99,17 @@ func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) * } s.data = pkt.Data().ExtractVV().Clone(s.views[:]) s.hdr = header.TCP(pkt.TransportHeader().View()) - s.rcvdTime = time.Now() + s.rcvdTime = clock.NowMonotonic() s.dataMemSize = s.data.Size() return s } -func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment { +func newOutgoingSegment(id stack.TransportEndpointID, clock tcpip.Clock, v buffer.View) *segment { s := &segment{ refCnt: 1, id: id, } - s.rcvdTime = time.Now() + s.rcvdTime = clock.NowMonotonic() if len(v) != 0 { s.views[0] = v s.data = buffer.NewVectorisedView(len(v), s.views[:1]) @@ -149,16 +148,6 @@ func (s *segment) merge(oth *segment) { oth.dataMemSize = oth.data.Size() } -// flagIsSet checks if at least one flag in flags is set in s.flags. -func (s *segment) flagIsSet(flags header.TCPFlags) bool { - return s.flags&flags != 0 -} - -// flagsAreSet checks if all flags in flags are set in s.flags. -func (s *segment) flagsAreSet(flags header.TCPFlags) bool { - return s.flags&flags == flags -} - // setOwner sets the owning endpoint for this segment. Its required // to be called to ensure memory accounting for receive/send buffer // queues is done properly. @@ -198,10 +187,10 @@ func (s *segment) incRef() { // as the data length plus one for each of the SYN and FIN bits set. 
func (s *segment) logicalLen() seqnum.Size { l := seqnum.Size(s.data.Size()) - if s.flagIsSet(header.TCPFlagSyn) { + if s.flags.Contains(header.TCPFlagSyn) { l++ } - if s.flagIsSet(header.TCPFlagFin) { + if s.flags.Contains(header.TCPFlagFin) { l++ } return l @@ -243,7 +232,7 @@ func (s *segment) parse(skipChecksumValidation bool) bool { return false } - s.options = []byte(s.hdr[header.TCPMinimumSize:]) + s.options = s.hdr[header.TCPMinimumSize:] s.parsedOptions = header.ParseTCPOptions(s.options) if skipChecksumValidation { s.csumValid = true @@ -262,5 +251,5 @@ func (s *segment) parse(skipChecksumValidation bool) bool { // sackBlock returns a header.SACKBlock that represents this segment. func (s *segment) sackBlock() header.SACKBlock { - return header.SACKBlock{s.sequenceNumber, s.sequenceNumber.Add(s.logicalLen())} + return header.SACKBlock{Start: s.sequenceNumber, End: s.sequenceNumber.Add(s.logicalLen())} } diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go index 7422d8c02..dcfa80f95 100644 --- a/pkg/tcpip/transport/tcp/segment_state.go +++ b/pkg/tcpip/transport/tcp/segment_state.go @@ -15,8 +15,6 @@ package tcp import ( - "time" - "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -55,23 +53,3 @@ func (s *segment) loadOptions(options []byte) { // allocated so there is no cost here. s.options = options } - -// saveRcvdTime is invoked by stateify. -func (s *segment) saveRcvdTime() unixTime { - return unixTime{s.rcvdTime.Unix(), s.rcvdTime.UnixNano()} -} - -// loadRcvdTime is invoked by stateify. -func (s *segment) loadRcvdTime(unix unixTime) { - s.rcvdTime = time.Unix(unix.second, unix.nano) -} - -// saveXmitTime is invoked by stateify. -func (s *segment) saveXmitTime() unixTime { - return unixTime{s.rcvdTime.Unix(), s.rcvdTime.UnixNano()} -} - -// loadXmitTime is invoked by stateify. -func (s *segment) loadXmitTime(unix unixTime) { - s.rcvdTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go index 486016fc0..2e6ea06f5 100644 --- a/pkg/tcpip/transport/tcp/segment_test.go +++ b/pkg/tcpip/transport/tcp/segment_test.go @@ -19,6 +19,7 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -39,10 +40,11 @@ func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeW } func TestSegmentMerge(t *testing.T) { + var clock faketime.NullClock id := stack.TransportEndpointID{} - seg1 := newOutgoingSegment(id, buffer.NewView(10)) + seg1 := newOutgoingSegment(id, &clock, buffer.NewView(10)) defer seg1.decRef() - seg2 := newOutgoingSegment(id, buffer.NewView(20)) + seg2 := newOutgoingSegment(id, &clock, buffer.NewView(20)) defer seg2.decRef() checkSegmentSize(t, "seg1", seg1, segmentSizeWants{ diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index f43e86677..72d58dcff 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -94,7 +94,7 @@ type sender struct { // firstRetransmittedSegXmitTime is the original transmit time of // the first segment that was retransmitted due to RTO expiration. - firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"` + firstRetransmittedSegXmitTime tcpip.MonotonicTime // zeroWindowProbing is set if the sender is currently probing // for zero receive window. 
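The flag checks in these files now go through methods on header.TCPFlags instead of the removed segment helpers. Judging from the rewritten call sites (single-flag flagIsSet calls move to Contains, and the combined !Fin && !Syn check in snd.go becomes one Intersects call), Contains appears to require every given bit while Intersects requires at least one. The sketch below is illustrative and its names are not part of the change:

package example

import "gvisor.dev/gvisor/pkg/tcpip/header"

// describeFlags mirrors the two removed helpers: Contains plays the role of
// flagsAreSet (all bits present), Intersects the role of flagIsSet (any bit
// present).
func describeFlags(flags header.TCPFlags) (finAck, synOrFin bool) {
	finAck = flags.Contains(header.TCPFlagFin | header.TCPFlagAck)
	synOrFin = flags.Intersects(header.TCPFlagSyn | header.TCPFlagFin)
	return finAck, synOrFin
}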
@@ -169,7 +169,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint SndUna: iss + 1, SndNxt: iss + 1, RTTMeasureSeqNum: iss + 1, - LastSendTime: time.Now(), + LastSendTime: ep.stack.Clock().NowMonotonic(), MaxPayloadSize: maxPayloadSize, MaxSentAck: irs + 1, FastRecovery: stack.TCPFastRecoveryState{ @@ -197,9 +197,9 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint s.SndWndScale = uint8(sndWndScale) } - s.resendTimer.init(&s.resendWaker) - s.reorderTimer.init(&s.reorderWaker) - s.probeTimer.init(&s.probeWaker) + s.resendTimer.init(s.ep.stack.Clock(), &s.resendWaker) + s.reorderTimer.init(s.ep.stack.Clock(), &s.reorderWaker) + s.probeTimer.init(s.ep.stack.Clock(), &s.probeWaker) s.updateMaxPayloadSize(int(ep.route.MTU()), 0) @@ -441,7 +441,7 @@ func (s *sender) retransmitTimerExpired() bool { // timeout since the first retransmission. uto := s.ep.userTimeout - if s.firstRetransmittedSegXmitTime.IsZero() { + if s.firstRetransmittedSegXmitTime == (tcpip.MonotonicTime{}) { // We store the original xmitTime of the segment that we are // about to retransmit as the retransmission time. This is // required as by the time the retransmitTimer has expired the @@ -450,7 +450,7 @@ func (s *sender) retransmitTimerExpired() bool { s.firstRetransmittedSegXmitTime = s.writeList.Front().xmitTime } - elapsed := time.Since(s.firstRetransmittedSegXmitTime) + elapsed := s.ep.stack.Clock().NowMonotonic().Sub(s.firstRetransmittedSegXmitTime) remaining := s.maxRTO if uto != 0 { // Cap to the user specified timeout if one is specified. @@ -616,7 +616,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // 'S2' that meets the following 3 criteria for determinig // loss, the sequence range of one segment of up to SMSS // octects starting with S2 MUST be returned. - if !s.ep.scoreboard.IsSACKED(header.SACKBlock{segSeq, segSeq.Add(1)}) { + if !s.ep.scoreboard.IsSACKED(header.SACKBlock{Start: segSeq, End: segSeq.Add(1)}) { // NextSeg(): // // (1.a) S2 is greater than HighRxt @@ -866,8 +866,8 @@ func (s *sender) enableZeroWindowProbing() { // We piggyback the probing on the retransmit timer with the // current retranmission interval, as we may start probing while // segment retransmissions. - if s.firstRetransmittedSegXmitTime.IsZero() { - s.firstRetransmittedSegXmitTime = time.Now() + if s.firstRetransmittedSegXmitTime == (tcpip.MonotonicTime{}) { + s.firstRetransmittedSegXmitTime = s.ep.stack.Clock().NowMonotonic() } s.resendTimer.enable(s.RTO) } @@ -875,7 +875,7 @@ func (s *sender) enableZeroWindowProbing() { func (s *sender) disableZeroWindowProbing() { s.zeroWindowProbing = false s.unackZeroWindowProbes = 0 - s.firstRetransmittedSegXmitTime = time.Time{} + s.firstRetransmittedSegXmitTime = tcpip.MonotonicTime{} s.resendTimer.disable() } @@ -925,7 +925,7 @@ func (s *sender) sendData() { // "A TCP SHOULD set cwnd to no more than RW before beginning // transmission if the TCP has not sent data in the interval exceeding // the retrasmission timeout." 
- if !s.FastRecovery.Active && s.state != tcpip.RTORecovery && time.Now().Sub(s.LastSendTime) > s.RTO { + if !s.FastRecovery.Active && s.state != tcpip.RTORecovery && s.ep.stack.Clock().NowMonotonic().Sub(s.LastSendTime) > s.RTO { if s.SndCwnd > InitialCwnd { s.SndCwnd = InitialCwnd } @@ -1024,7 +1024,7 @@ func (s *sender) SetPipe() { if segEnd.LessThan(endSeq) { endSeq = segEnd } - sb := header.SACKBlock{startSeq, endSeq} + sb := header.SACKBlock{Start: startSeq, End: endSeq} // SetPipe(): // // After initializing pipe to zero, the following steps are @@ -1132,7 +1132,7 @@ func (s *sender) isDupAck(seg *segment) bool { // (b) The incoming acknowledgment carries no data. seg.logicalLen() == 0 && // (c) The SYN and FIN bits are both off. - !seg.flagIsSet(header.TCPFlagFin) && !seg.flagIsSet(header.TCPFlagSyn) && + !seg.flags.Intersects(header.TCPFlagFin|header.TCPFlagSyn) && // (d) the ACK number is equal to the greatest acknowledgment received on // the given connection (TCP.UNA from RFC793). seg.ackNumber == s.SndUna && @@ -1234,7 +1234,7 @@ func checkDSACK(rcvdSeg *segment) bool { func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Check if we can extract an RTT measurement from this ack. if !rcvdSeg.parsedOptions.TS && s.RTTMeasureSeqNum.LessThan(rcvdSeg.ackNumber) { - s.updateRTO(time.Now().Sub(s.RTTMeasureTime)) + s.updateRTO(s.ep.stack.Clock().NowMonotonic().Sub(s.RTTMeasureTime)) s.RTTMeasureSeqNum = s.SndNxt } @@ -1444,7 +1444,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { if s.SndUna == s.SndNxt { s.Outstanding = 0 // Reset firstRetransmittedSegXmitTime to the zero value. - s.firstRetransmittedSegXmitTime = time.Time{} + s.firstRetransmittedSegXmitTime = tcpip.MonotonicTime{} s.resendTimer.disable() s.probeTimer.disable() } @@ -1455,7 +1455,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // * Upon receiving an ACK: // * Step 4: Update RACK reordering window - s.rc.updateRACKReorderWindow(rcvdSeg) + s.rc.updateRACKReorderWindow() // After the reorder window is calculated, detect any loss by checking // if the time elapsed after the segments are sent is greater than the @@ -1502,7 +1502,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment() } } - seg.xmitTime = time.Now() + seg.xmitTime = s.ep.stack.Clock().NowMonotonic() seg.xmitCount++ seg.lost = false err := s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber) @@ -1527,7 +1527,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { // sendSegmentFromView sends a new segment containing the given payload, flags // and sequence number. func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags header.TCPFlags, seq seqnum.Value) tcpip.Error { - s.LastSendTime = time.Now() + s.LastSendTime = s.ep.stack.Clock().NowMonotonic() if seq == s.RTTMeasureSeqNum { s.RTTMeasureTime = s.LastSendTime } diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go deleted file mode 100644 index 2f805d8ce..000000000 --- a/pkg/tcpip/transport/tcp/snd_state.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp - -import ( - "time" -) - -// +stateify savable -type unixTime struct { - second int64 - nano int64 -} - -// afterLoad is invoked by stateify. -func (s *sender) afterLoad() { - s.resendTimer.init(&s.resendWaker) - s.reorderTimer.init(&s.reorderWaker) - s.probeTimer.init(&s.probeWaker) -} - -// saveFirstRetransmittedSegXmitTime is invoked by stateify. -func (s *sender) saveFirstRetransmittedSegXmitTime() unixTime { - return unixTime{s.firstRetransmittedSegXmitTime.Unix(), s.firstRetransmittedSegXmitTime.UnixNano()} -} - -// loadFirstRetransmittedSegXmitTime is invoked by stateify. -func (s *sender) loadFirstRetransmittedSegXmitTime(unix unixTime) { - s.firstRetransmittedSegXmitTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index c58361bc1..d6cf786a1 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -38,7 +38,7 @@ const ( func setStackRACKPermitted(t *testing.T, c *context.Context) { t.Helper() - opt := tcpip.TCPRecovery(tcpip.TCPRACKLossDetection) + opt := tcpip.TCPRACKLossDetection if err := c.Stack().SetTransportProtocolOption(header.TCPProtocolNumber, &opt); err != nil { t.Fatalf("c.s.SetTransportProtocolOption(%d, &%v(%v)): %s", header.TCPProtocolNumber, opt, opt, err) } @@ -50,7 +50,7 @@ func TestRACKUpdate(t *testing.T) { c := context.New(t, uint32(mtu)) defer c.Cleanup() - var xmitTime time.Time + var xmitTime tcpip.MonotonicTime probeDone := make(chan struct{}) c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { // Validate that the endpoint Sender.RACKState is what we expect. @@ -79,7 +79,7 @@ func TestRACKUpdate(t *testing.T) { } // Write the data. - xmitTime = time.Now() + xmitTime = c.Stack().Clock().NowMonotonic() var r bytes.Reader r.Reset(data) if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 9916182e3..71c4aa85d 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3451,17 +3451,13 @@ loop: for { switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) { case *tcpip.ErrWouldBlock: - select { - case <-ch: - // Expect the state to be StateError and subsequent Reads to fail with HardError. - _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { - t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d) - } - break loop - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for reset to arrive") + <-ch + // Expect the state to be StateError and subsequent Reads to fail with HardError. 
+ _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) + if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { + t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d) } + break loop case *tcpip.ErrConnectionReset: break loop default: @@ -3472,14 +3468,27 @@ loop: if tcp.EndpointState(c.EP.State()) != tcp.StateError { t.Fatalf("got EP state is not StateError") } - if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { - t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got) + + checkValid := func() []error { + var errors []error + if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { + errors = append(errors, fmt.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + errors = append(errors, fmt.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + errors = append(errors, fmt.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)) + } + return errors } - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) + + start := time.Now() + for time.Since(start) < time.Minute && len(checkValid()) > 0 { + time.Sleep(50 * time.Millisecond) } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + for _, err := range checkValid() { + t.Error(err) } } @@ -4259,7 +4268,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(iss), + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -5335,7 +5344,7 @@ func TestKeepalive(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(next-1)), + checker.TCPSeqNum(next-1), checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), @@ -5360,12 +5369,7 @@ func TestKeepalive(t *testing.T) { }) checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(next)), - checker.TCPAckNum(uint32(0)), - checker.TCPFlags(header.TCPFlagRst), - ), + checker.TCP(checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst)), ) if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { @@ -5507,7 +5511,7 @@ func TestListenBacklogFull(t *testing.T) { // Now execute send one more SYN. The stack should not respond as the backlog // is full at this point. c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + uint16(lastPortOffset), + SrcPort: context.TestPort + lastPortOffset, DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: seqnum.Value(context.TestInitialSequenceNumber), @@ -5884,7 +5888,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { r.Reset(data) newEP.Write(&r, tcpip.WriteOptions{}) pkt := c.GetPacket() - tcp = header.TCP(header.IPv4(pkt).Payload()) + tcp = header.IPv4(pkt).Payload() if string(tcp.Payload()) != data { t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) } @@ -6073,6 +6077,11 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { // complete the connection to test that the large SEQ num // did not change the state from SYN-RCVD. 
+ // Get setup to be notified about connection establishment. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.ReadableEvents) + defer c.WQ.EventUnregister(&we) + // Send ACK to move to ESTABLISHED state. c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -6083,32 +6092,12 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { RcvWnd: 30000, }) + <-ch newEP, _, err := c.EP.Accept(nil) - switch err.(type) { - case nil, *tcpip.ErrWouldBlock: - default: + if err != nil { t.Fatalf("Accept failed: %s", err) } - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Try to accept the connections in the backlog. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - // Now verify that the TCP socket is usable and in a connected state. data := "Don't panic" var r strings.Reader @@ -6118,7 +6107,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { } pkt := c.GetPacket() - tcpHdr = header.TCP(header.IPv4(pkt).Payload()) + tcpHdr = header.IPv4(pkt).Payload() if string(tcpHdr.Payload()) != data { t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data) } @@ -6214,12 +6203,26 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { RcvWnd: 30000, }) - time.Sleep(50 * time.Millisecond) - if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { - t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want) + checkValid := func() []error { + var errors []error + if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { + errors = append(errors, fmt.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want)) + } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { + errors = append(errors, fmt.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want)) + } + return errors } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { - t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want) + + start := time.Now() + for time.Since(start) < time.Minute && len(checkValid()) > 0 { + time.Sleep(50 * time.Millisecond) + } + for _, err := range checkValid() { + t.Error(err) + } + if t.Failed() { + t.FailNow() } we, ch := waiter.NewChannelEntry(nil) @@ -6230,19 +6233,62 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { _, _, err = c.EP.Accept(nil) if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. 
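The tcp_test.go hunks above replace single fixed sleeps with a poll-until-settled loop (checkValid) bounded by a deadline before reporting any remaining stat mismatches. A generalized sketch of that pattern, assuming nothing beyond the standard testing and time packages; the helper name is illustrative:

package example

import (
	"testing"
	"time"
)

// waitFor polls check until it returns no errors or the deadline passes,
// then reports whatever errors remain, matching the shape of the checkValid
// loops added in the tests above.
func waitFor(t *testing.T, timeout time.Duration, check func() []error) {
	t.Helper()
	start := time.Now()
	for time.Since(start) < timeout && len(check()) > 0 {
		time.Sleep(50 * time.Millisecond)
	}
	for _, err := range check() {
		t.Error(err)
	}
}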
- select { - case <-ch: - _, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") + <-ch + _, _, err = c.EP.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) } } } +func TestListenDropIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + c.Create(-1 /*epRcvBuf*/) + + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if err := c.EP.Listen(1 /*backlog*/); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + initialDropped := stats.DroppedPackets.Value() + + // Send RST, FIN segments, that are expected to be dropped by the listener. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagRst, + }) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin, + }) + + // To ensure that the RST, FIN sent earlier are indeed received and ignored + // by the listener, send a SYN and wait for the SYN to be ACKd. + irs := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + }) + checker.IPv4(t, c.GetPacket(), checker.TCP(checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.TCPAckNum(uint32(irs)+1), + )) + + if got, want := stats.DroppedPackets.Value(), initialDropped+2; got != want { + t.Fatalf("got stats.DroppedPackets.Value() = %d, want = %d", got, want) + } +} + func TestEndpointBindListenAcceptState(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -6375,7 +6421,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Allocate a large enough payload for the test. payloadSize := receiveBufferSize * 2 - b := make([]byte, int(payloadSize)) + b := make([]byte, payloadSize) worker := (c.EP).(interface { StopWork() @@ -6429,7 +6475,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // ack, 1 for the non-zero window p := c.GetPacket() checker.IPv4(t, p, checker.TCP( - checker.TCPAckNum(uint32(wantAckNum)), + checker.TCPAckNum(wantAckNum), func(t *testing.T, h header.Transport) { tcp, ok := h.(header.TCP) if !ok { @@ -6484,14 +6530,14 @@ func TestReceiveBufferAutoTuning(t *testing.T) { c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) - tsVal := uint32(rawEP.TSVal) + tsVal := rawEP.TSVal rawEP.NextSeqNum-- rawEP.SendPacketWithTS(nil, tsVal) rawEP.NextSeqNum++ pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale scaleRcvWnd := func(rcvWnd int) uint16 { - return uint16(rcvWnd >> uint16(c.WindowScale)) + return uint16(rcvWnd >> c.WindowScale) } // Allocate a large array to send to the endpoint. b := make([]byte, receiveBufferSize*48) @@ -6619,19 +6665,16 @@ func TestDelayEnabled(t *testing.T) { defer c.Cleanup() checkDelayOption(t, c, false, false) // Delay is disabled by default. 
- for _, v := range []struct { - delayEnabled tcpip.TCPDelayEnabled - wantDelayOption bool - }{ - {delayEnabled: false, wantDelayOption: false}, - {delayEnabled: true, wantDelayOption: true}, - } { - c := context.New(t, defaultMTU) - defer c.Cleanup() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &v.delayEnabled); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, v.delayEnabled, v.delayEnabled, err) - } - checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) + for _, delayEnabled := range []bool{false, true} { + t.Run(fmt.Sprintf("delayEnabled=%t", delayEnabled), func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + opt := tcpip.TCPDelayEnabled(delayEnabled) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, delayEnabled, err) + } + checkDelayOption(t, c, opt, delayEnabled) + }) } } @@ -7042,7 +7085,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { // Receive the SYN-ACK reply. b = c.GetPacket() - tcpHdr = header.TCP(header.IPv4(b).Payload()) + tcpHdr = header.IPv4(b).Payload() c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) ackHeaders = &context.Headers{ @@ -7443,7 +7486,7 @@ func TestTCPUserTimeout(t *testing.T) { select { case <-notifyCh: case <-time.After(2 * initRTO): - t.Fatalf("connection still alive after %s, should have been closed after :%s", 2*initRTO, userTimeout) + t.Fatalf("connection still alive after %s, should have been closed after %s", 2*initRTO, userTimeout) } // No packet should be received as the connection should be silently @@ -7467,7 +7510,7 @@ func TestTCPUserTimeout(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(next)), + checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst), ), @@ -7545,7 +7588,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { DstPort: c.Port, Flags: header.TCPFlagAck, SeqNum: iss, - AckNum: seqnum.Value(c.IRS + 1), + AckNum: c.IRS + 1, RcvWnd: 30000, }) diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go index 38a335840..5645c772e 100644 --- a/pkg/tcpip/transport/tcp/timer.go +++ b/pkg/tcpip/transport/tcp/timer.go @@ -15,21 +15,29 @@ package tcp import ( + "math" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" ) type timerState int const ( + // The timer is disabled. timerStateDisabled timerState = iota + // The timer is enabled, but the clock timer may be set to an earlier + // expiration time due to a previous orphaned state. timerStateEnabled + // The timer is disabled, but the clock timer is enabled, which means that + // it will cause a spurious wakeup unless the timer is enabled before the + // clock timer fires. timerStateOrphaned ) // timer is a timer implementation that reduces the interactions with the -// runtime timer infrastructure by letting timers run (and potentially +// clock timer infrastructure by letting timers run (and potentially // eventually expire) even if they are stopped. It makes it cheaper to // disable/reenable timers at the expense of spurious wakes. This is useful for // cases when the same timer is disabled/reenabled repeatedly with relatively @@ -39,44 +47,37 @@ const ( // (currently at least 200ms), and get disabled when acks are received, and // reenabled when new pending segments are sent. 
// -// It is advantageous to avoid interacting with the runtime because it acquires +// It is advantageous to avoid interacting with the clock because it acquires // a global mutex and performs O(log n) operations, where n is the global number // of timers, whenever a timer is enabled or disabled, and may make a syscall. // // This struct is thread-compatible. type timer struct { - // state is the current state of the timer, it can be one of the - // following values: - // disabled - the timer is disabled. - // orphaned - the timer is disabled, but the runtime timer is - // enabled, which means that it will evetually cause a - // spurious wake (unless it gets enabled again before - // then). - // enabled - the timer is enabled, but the runtime timer may be set - // to an earlier expiration time due to a previous - // orphaned state. state timerState + clock tcpip.Clock + // target is the expiration time of the current timer. It is only // meaningful in the enabled state. - target time.Time + target tcpip.MonotonicTime - // runtimeTarget is the expiration time of the runtime timer. It is + // clockTarget is the expiration time of the clock timer. It is // meaningful in the enabled and orphaned states. - runtimeTarget time.Time + clockTarget tcpip.MonotonicTime - // timer is the runtime timer used to wait on. - timer *time.Timer + // timer is the clock timer used to wait on. + timer tcpip.Timer } // init initializes the timer. Once it expires, it the given waker will be // asserted. -func (t *timer) init(w *sleep.Waker) { +func (t *timer) init(clock tcpip.Clock, w *sleep.Waker) { t.state = timerStateDisabled + t.clock = clock - // Initialize a runtime timer that will assert the waker, then + // Initialize a clock timer that will assert the waker, then // immediately stop it. - t.timer = time.AfterFunc(time.Hour, func() { + t.timer = t.clock.AfterFunc(math.MaxInt64, func() { w.Assert() }) t.timer.Stop() @@ -106,9 +107,9 @@ func (t *timer) checkExpiration() bool { // The timer is enabled, but it may have expired early. Check if that's // the case, and if so, reset the runtime timer to the correct time. - now := time.Now() + now := t.clock.NowMonotonic() if now.Before(t.target) { - t.runtimeTarget = t.target + t.clockTarget = t.target t.timer.Reset(t.target.Sub(now)) return false } @@ -134,11 +135,11 @@ func (t *timer) enabled() bool { // enable enables the timer, programming the runtime timer if necessary. func (t *timer) enable(d time.Duration) { - t.target = time.Now().Add(d) + t.target = t.clock.NowMonotonic().Add(d) // Check if we need to set the runtime timer. - if t.state == timerStateDisabled || t.target.Before(t.runtimeTarget) { - t.runtimeTarget = t.target + if t.state == timerStateDisabled || t.target.Before(t.clockTarget) { + t.clockTarget = t.target t.timer.Reset(d) } diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go index dbd6dff54..479752de7 100644 --- a/pkg/tcpip/transport/tcp/timer_test.go +++ b/pkg/tcpip/transport/tcp/timer_test.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip/faketime" ) func TestCleanup(t *testing.T) { @@ -27,9 +28,11 @@ func TestCleanup(t *testing.T) { isAssertedTimeoutSeconds = timerDurationSeconds + 1 ) + clock := faketime.NewManualClock() + tmr := timer{} w := sleep.Waker{} - tmr.init(&w) + tmr.init(clock, &w) tmr.enable(timerDurationSeconds * time.Second) tmr.cleanup() @@ -39,7 +42,7 @@ func TestCleanup(t *testing.T) { // The waker should not be asserted. 
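The timer rewrite above parameterizes the TCP timer on tcpip.Clock, and timer_test.go swaps real sleeps for a faketime.ManualClock that the test advances explicitly. The sketch below exercises the same pieces (Clock.AfterFunc, sleep.Waker, ManualClock.Advance) outside the unexported timer type; it assumes ManualClock.Advance runs due timer callbacks before returning, and the test name and package are illustrative:

package example

import (
	"testing"
	"time"

	"gvisor.dev/gvisor/pkg/sleep"
	"gvisor.dev/gvisor/pkg/tcpip/faketime"
)

// TestManualClockTimer drives a clock timer deterministically: no wall-clock
// sleeping, only explicit Advance calls.
func TestManualClockTimer(t *testing.T) {
	clock := faketime.NewManualClock()
	var w sleep.Waker
	tmr := clock.AfterFunc(2*time.Second, func() { w.Assert() })
	defer tmr.Stop()

	clock.Advance(time.Second)
	if w.IsAsserted() {
		t.Fatal("waker asserted before the timer's deadline")
	}
	clock.Advance(time.Second)
	if !w.IsAsserted() {
		t.Fatal("waker not asserted after the timer's deadline")
	}
}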
for i := 0; i < isAssertedTimeoutSeconds; i++ { - time.Sleep(time.Second) + clock.Advance(time.Second) if w.IsAsserted() { t.Fatalf("waker asserted unexpectedly") } diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index dd5c910ae..cdc344ab7 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -49,6 +49,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index f7dd50d35..def9d7186 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -17,6 +17,7 @@ package udp import ( "io" "sync/atomic" + "time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -34,25 +35,26 @@ type udpPacket struct { destinationAddress tcpip.FullAddress packetInfo tcpip.IPPacketInfo data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - timestamp int64 + receivedAt time.Time `state:".(int64)"` // tos stores either the receiveTOS or receiveTClass value. tos uint8 } // EndpointState represents the state of a UDP endpoint. -type EndpointState uint32 +type EndpointState tcpip.EndpointState // Endpoint states. Note that are represented in a netstack-specific manner and // may not be meaningful externally. Specifically, they need to be translated to // Linux's representation for these states if presented to userspace. const ( - StateInitial EndpointState = iota + _ EndpointState = iota + StateInitial StateBound StateConnected StateClosed ) -// String implements fmt.Stringer.String. +// String implements fmt.Stringer. func (s EndpointState) String() string { switch s { case StateInitial: @@ -98,7 +100,7 @@ type endpoint struct { mu sync.RWMutex `state:"nosave"` // state must be read/set using the EndpointState()/setEndpointState() // methods. - state EndpointState + state uint32 route *stack.Route `state:"manual"` dstPort uint16 ttl uint8 @@ -176,7 +178,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // Linux defaults to TTL=1. multicastTTL: 1, multicastMemberships: make(map[multicastMembership]struct{}), - state: StateInitial, + state: uint32(StateInitial), uniqueID: s.UniqueID(), } e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) @@ -204,15 +206,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // // Precondition: e.mu must be held to call this method. func (e *endpoint) setEndpointState(state EndpointState) { - atomic.StoreUint32((*uint32)(&e.state), uint32(state)) + atomic.StoreUint32(&e.state, uint32(state)) } // EndpointState() returns the current state of the endpoint. func (e *endpoint) EndpointState() EndpointState { - return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + return EndpointState(atomic.LoadUint32(&e.state)) } -// UniqueID implements stack.TransportEndpoint.UniqueID. +// UniqueID implements stack.TransportEndpoint. func (e *endpoint) UniqueID() uint64 { return e.uniqueID } @@ -226,14 +228,14 @@ func (e *endpoint) LastError() tcpip.Error { return err } -// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. +// UpdateLastError implements tcpip.SocketOptionsHandler. func (e *endpoint) UpdateLastError(err tcpip.Error) { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() } -// Abort implements stack.TransportEndpoint.Abort. 
+// Abort implements stack.TransportEndpoint. func (e *endpoint) Abort() { e.Close() } @@ -289,10 +291,10 @@ func (e *endpoint) Close() { e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) } -// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (e *endpoint) ModerateRecvBuf(copied int) {} +// ModerateRecvBuf implements tcpip.Endpoint. +func (*endpoint) ModerateRecvBuf(int) {} -// Read implements tcpip.Endpoint.Read. +// Read implements tcpip.Endpoint. func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { if err := e.LastError(); err != nil { return tcpip.ReadResult{}, err @@ -320,7 +322,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // Control Messages cm := tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: p.timestamp, + Timestamp: p.receivedAt.UnixNano(), } if e.ops.GetReceiveTOS() { cm.HasTOS = true @@ -581,21 +583,21 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return int64(len(v)), nil } -// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet. +// OnReuseAddressSet implements tcpip.SocketOptionsHandler. func (e *endpoint) OnReuseAddressSet(v bool) { e.mu.Lock() e.portFlags.MostRecent = v e.mu.Unlock() } -// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet. +// OnReusePortSet implements tcpip.SocketOptionsHandler. func (e *endpoint) OnReusePortSet(v bool) { e.mu.Lock() e.portFlags.LoadBalanced = v e.mu.Unlock() } -// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +// SetSockOptInt implements tcpip.Endpoint. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { case tcpip.MTUDiscoverOption: @@ -629,11 +631,14 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { return nil } +var _ tcpip.SocketOptionsHandler = (*endpoint)(nil) + +// HasNIC implements tcpip.SocketOptionsHandler. func (e *endpoint) HasNIC(id int32) bool { - return id == 0 || e.stack.HasNIC(tcpip.NICID(id)) + return e.stack.HasNIC(tcpip.NICID(id)) } -// SetSockOpt implements tcpip.Endpoint.SetSockOpt. +// SetSockOpt implements tcpip.Endpoint. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { switch v := opt.(type) { case *tcpip.MulticastInterfaceOption: @@ -749,7 +754,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { return nil } -// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +// GetSockOptInt implements tcpip.Endpoint. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.IPv4TOSOption: @@ -795,14 +800,14 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { } } -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +// GetSockOpt implements tcpip.Endpoint. func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { switch o := opt.(type) { case *tcpip.MulticastInterfaceOption: e.mu.Lock() *o = tcpip.MulticastInterfaceOption{ - e.multicastNICID, - e.multicastAddr, + NIC: e.multicastNICID, + InterfaceAddr: e.multicastAddr, } e.mu.Unlock() @@ -872,7 +877,7 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres return unwrapped, netProto, nil } -// Disconnect implements tcpip.Endpoint.Disconnect. +// Disconnect implements tcpip.Endpoint. 
func (e *endpoint) Disconnect() tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -1075,7 +1080,7 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id BindToDevice: bindToDevice, Dest: tcpip.FullAddress{}, } - port, err := e.stack.ReservePort(portRes, nil /* testPort */) + port, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */) if err != nil { return id, bindToDevice, err } @@ -1301,12 +1306,12 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB senderAddress: tcpip.FullAddress{ NIC: pkt.NICID, Addr: id.RemoteAddress, - Port: header.UDP(hdr).SourcePort(), + Port: hdr.SourcePort(), }, destinationAddress: tcpip.FullAddress{ NIC: pkt.NICID, Addr: id.LocalAddress, - Port: header.UDP(hdr).DestinationPort(), + Port: hdr.DestinationPort(), }, data: pkt.Data().ExtractVV(), } @@ -1328,7 +1333,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB packet.packetInfo.LocalAddr = localAddr packet.packetInfo.DestinationAddr = localAddr packet.packetInfo.NIC = pkt.NICID - packet.timestamp = e.stack.Clock().NowNanoseconds() + packet.receivedAt = e.stack.Clock().Now() e.rcvMu.Unlock() @@ -1386,7 +1391,7 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB } } -// State implements tcpip.Endpoint.State. +// State implements tcpip.Endpoint. func (e *endpoint) State() uint32 { return uint32(e.EndpointState()) } @@ -1405,19 +1410,19 @@ func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } -// Wait implements tcpip.Endpoint.Wait. +// Wait implements tcpip.Endpoint. func (*endpoint) Wait() {} func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr) } -// SetOwner implements tcpip.Endpoint.SetOwner. +// SetOwner implements tcpip.Endpoint. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// SocketOptions implements tcpip.Endpoint.SocketOptions. +// SocketOptions implements tcpip.Endpoint. func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 4aba68b21..1f638c3f6 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -15,26 +15,38 @@ package udp import ( + "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// saveReceivedAt is invoked by stateify. +func (p *udpPacket) saveReceivedAt() int64 { + return p.receivedAt.UnixNano() +} + +// loadReceivedAt is invoked by stateify. +func (p *udpPacket) loadReceivedAt(nsec int64) { + p.receivedAt = time.Unix(0, nsec) +} + // saveData saves udpPacket.data field. -func (u *udpPacket) saveData() buffer.VectorisedView { - // We cannot save u.data directly as u.data.views may alias to u.views, +func (p *udpPacket) saveData() buffer.VectorisedView { + // We cannot save p.data directly as p.data.views may alias to p.views, // which is not allowed by state framework (in-struct pointer). - return u.data.Clone(nil) + return p.data.Clone(nil) } // loadData loads udpPacket.data field. 
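In udp/endpoint.go above, EndpointState is now defined in terms of tcpip.EndpointState, the constants start at 1 so the zero value stays unused, and the endpoint stores its state as a bare uint32 so it can be read and written atomically without holding the endpoint mutex; the typed conversion happens only in the accessors. A self-contained sketch of that pattern (names are illustrative, not the actual udp types):

package example

import "sync/atomic"

type exampleState uint32

const (
	stateInitial exampleState = iota + 1 // zero value left unused, as above
	stateBound
	stateConnected
	stateClosed
)

// exampleEndpoint stores state as a plain uint32, mirroring the udp endpoint.
type exampleEndpoint struct {
	state uint32
}

func (e *exampleEndpoint) setState(s exampleState) {
	atomic.StoreUint32(&e.state, uint32(s))
}

func (e *exampleEndpoint) State() exampleState {
	return exampleState(atomic.LoadUint32(&e.state))
}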
-func (u *udpPacket) loadData(data buffer.VectorisedView) { - // NOTE: We cannot do the u.data = data.Clone(u.views[:]) optimization +func (p *udpPacket) loadData(data buffer.VectorisedView) { + // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization // here because data.views is not guaranteed to be loaded by now. Plus, // data.views will be allocated anyway so there really is little point - // of utilizing u.views for data.views. - u.data = data + // of utilizing p.views for data.views. + p.data = data } // afterLoad is invoked by stateify. diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 705ad1f64..7c357cb09 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -90,7 +90,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.RegisterNICID = r.pkt.NICID ep.boundPortFlags = ep.portFlags - ep.state = StateConnected + ep.state = uint32(StateConnected) ep.rcvMu.Lock() ep.rcvReady = true diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index dc2e3f493..4008cacf2 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -16,17 +16,16 @@ package udp_test import ( "bytes" - "context" "fmt" "io/ioutil" "math/rand" "testing" - "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -298,16 +297,18 @@ type testContext struct { func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() - return newDualTestContextWithOptions(t, mtu, stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, - HandleLocal: true, - }) + return newDualTestContextWithHandleLocal(t, mtu, true) } -func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Options) *testContext { +func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal bool) *testContext { t.Helper() + options := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, + HandleLocal: handleLocal, + Clock: &faketime.NullClock{}, + } s := stack.New(options) ep := channel.New(256, mtu, "") wep := stack.LinkEndpoint(ep) @@ -378,9 +379,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { c.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) + p, ok := c.linkEP.Read() if !ok { c.t.Fatalf("Packet wasn't written out") return nil @@ -534,7 +533,9 @@ func newMinPayload(minSize int) []byte { func TestBindToDeviceOption(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}}) + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: &faketime.NullClock{}, + }) ep, err := s.NewEndpoint(udp.ProtocolNumber, 
ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { @@ -606,7 +607,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe case <-ch: res, err = c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true}) - case <-time.After(300 * time.Millisecond): + default: if packetShouldBeDropped { return // expected to time out } @@ -820,11 +821,7 @@ func TestV4ReadSelfSource(t *testing.T) { {"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1}, } { t.Run(tt.name, func(t *testing.T) { - c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - HandleLocal: tt.handleLocal, - }) + c := newDualTestContextWithHandleLocal(t, defaultMTU, tt.handleLocal) defer c.cleanup() c.createEndpointForFlow(unicastV4) @@ -1034,17 +1031,17 @@ func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, che payload := testWriteNoVerify(c, flow, setDest) // Received the packet and check the payload. b := c.getPacketAndVerify(flow, checkers...) - var udp header.UDP + var udpH header.UDP if flow.isV4() { - udp = header.UDP(header.IPv4(b).Payload()) + udpH = header.IPv4(b).Payload() } else { - udp = header.UDP(header.IPv6(b).Payload()) + udpH = header.IPv6(b).Payload() } - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + if !bytes.Equal(payload, udpH.Payload()) { + c.t.Fatalf("Bad payload: got %x, want %x", udpH.Payload(), payload) } - return udp.SourcePort() + return udpH.SourcePort() } func testDualWrite(c *testContext) uint16 { @@ -1198,7 +1195,7 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) { r.Reset(payload) n, err := c.ep.Write(&r, writeOpts) if err != nil { - c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err) + c.t.Fatalf("c.ep.Write(...) = %s, want nil", err) } if got, want := n, int64(len(payload)); got != want { c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want) @@ -1462,7 +1459,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { name: "IPv4 unicast", proto: header.IPv4ProtocolNumber, flow: unicastV4, - expectedOriginalDstAddr: tcpip.FullAddress{1, stackAddr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackAddr, Port: stackPort}, }, { name: "IPv4 multicast", @@ -1474,7 +1471,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { // behaviour. We still include the test so that once the bug is // resolved, this test will start to fail and the individual tasked // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{1, multicastAddr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastAddr, Port: stackPort}, }, { name: "IPv4 broadcast", @@ -1486,13 +1483,13 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { // behaviour. We still include the test so that once the bug is // resolved, this test will start to fail and the individual tasked // with fixing this bug knows to also fix this test :). 
- expectedOriginalDstAddr: tcpip.FullAddress{1, broadcastAddr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: broadcastAddr, Port: stackPort}, }, { name: "IPv6 unicast", proto: header.IPv6ProtocolNumber, flow: unicastV6, - expectedOriginalDstAddr: tcpip.FullAddress{1, stackV6Addr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackV6Addr, Port: stackPort}, }, { name: "IPv6 multicast", @@ -1504,7 +1501,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { // behaviour. We still include the test so that once the bug is // resolved, this test will start to fail and the individual tasked // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{1, multicastV6Addr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastV6Addr, Port: stackPort}, }, } @@ -1614,6 +1611,7 @@ func TestTTL(t *testing.T) { } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{p}, + Clock: &faketime.NullClock{}, }) ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil) wantTTL = ep.DefaultTTL() @@ -1759,7 +1757,7 @@ func TestReceiveTosTClass(t *testing.T) { c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false) } - want := true + const want = true optionSetter(want) got := optionGetter() @@ -1889,18 +1887,14 @@ func TestV4UnknownDestination(t *testing.T) { } } if !tc.icmpRequired { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if p, ok := c.linkEP.ReadContext(ctx); ok { + if p, ok := c.linkEP.Read(); ok { t.Fatalf("unexpected packet received: %+v", p) } return } // ICMP required. - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) + p, ok := c.linkEP.Read() if !ok { t.Fatalf("packet wasn't written out") return @@ -1987,18 +1981,14 @@ func TestV6UnknownDestination(t *testing.T) { } } if !tc.icmpRequired { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if p, ok := c.linkEP.ReadContext(ctx); ok { + if p, ok := c.linkEP.Read(); ok { t.Fatalf("unexpected packet received: %+v", p) } return } // ICMP required. - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) + p, ok := c.linkEP.Read() if !ok { t.Fatalf("packet wasn't written out") return @@ -2115,8 +2105,8 @@ func TestShortHeader(t *testing.T) { Data: buf.ToVectorisedView(), })) - if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want { - t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want) + if got, want := c.s.Stats().NICs.MalformedL4RcvdPackets.Value(), uint64(1); got != want { + t.Errorf("got c.s.Stats().NIC.MalformedL4RcvdPackets.Value() = %d, want = %d", got, want) } } @@ -2124,25 +2114,27 @@ func TestShortHeader(t *testing.T) { // global and endpoint stats are incremented. func TestBadChecksumErrors(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV6} { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() + t.Run(flow.String(), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() - c.createEndpoint(flow.sockProto()) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } + c.createEndpoint(flow.sockProto()) + // Bind to wildcard. 
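The UDP test hunks above construct the stack with a faketime.NullClock and drop the context-based read timeouts: with a clock that, going by its name and use here, never advances on its own, waiting wall-clock time buys nothing, so linkEP.ReadContext(ctx) becomes a plain linkEP.Read() and the 300ms timeout case becomes a default branch. A sketch of the construction pattern, built only from calls that appear in these hunks (stack.New options, channel.New, CreateNIC); the helper name is illustrative:

package example

import (
	"gvisor.dev/gvisor/pkg/tcpip/faketime"
	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)

// newTestStack builds a stack whose time source is a NullClock, backed by a
// channel link endpoint, as the rewritten tests do.
func newTestStack(mtu uint32) *stack.Stack {
	s := stack.New(stack.Options{
		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol},
		TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
		Clock:              &faketime.NullClock{},
	})
	if err := s.CreateNIC(1, channel.New(256, mtu, "")); err != nil {
		panic(err)
	}
	return s
}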
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } - payload := newPayload() - c.injectPacket(flow, payload, true /* badChecksum */) + payload := newPayload() + c.injectPacket(flow, payload, true /* badChecksum */) - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } + }) } } @@ -2484,6 +2476,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: &faketime.NullClock{}, }) e := channel.New(0, defaultMTU, "") if err := s.CreateNIC(nicID1, e); err != nil { diff --git a/pkg/test/testutil/BUILD b/pkg/test/testutil/BUILD index a789c246e..7ff13cf12 100644 --- a/pkg/test/testutil/BUILD +++ b/pkg/test/testutil/BUILD @@ -12,6 +12,7 @@ go_library( ], visibility = ["//:sandbox"], deps = [ + "//pkg/sentry/watchdog", "//pkg/sync", "//runsc/config", "//runsc/specutils", diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go index 663c83679..f6a3e34c7 100644 --- a/pkg/test/testutil/testutil.go +++ b/pkg/test/testutil/testutil.go @@ -42,6 +42,7 @@ import ( "github.com/cenkalti/backoff" specs "github.com/opencontainers/runtime-spec/specs-go" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/sentry/watchdog" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/specutils" @@ -184,6 +185,7 @@ func TestConfig(t *testing.T) *config.Config { conf.Network = config.NetworkNone conf.Strace = true conf.TestOnlyAllowRunAsCurrentUserWithoutChroot = true + conf.WatchdogAction = watchdog.Panic return conf } diff --git a/pkg/urpc/urpc.go b/pkg/urpc/urpc.go index 0e9a829f6..0ef635a2f 100644 --- a/pkg/urpc/urpc.go +++ b/pkg/urpc/urpc.go @@ -27,6 +27,7 @@ import ( "os" "reflect" "runtime" + "time" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" @@ -458,29 +459,45 @@ func (s *Server) StartHandling(client *unet.Socket) { // No new requests should be initiated after calling Stop. Existing clients // will be closed after completing any pending RPCs. This method will block // until all clients have disconnected. -func (s *Server) Stop() { - // Wait for all outstanding requests. - defer s.wg.Wait() - +// +// timeout is the time for clients to complete ongoing RPCs. +func (s *Server) Stop(timeout time.Duration) { // Call any Stop callbacks. for _, stopper := range s.stoppers { stopper.Stop() } - // Close all known clients. - s.mu.Lock() - defer s.mu.Unlock() - for client, state := range s.clients { - switch state { - case idle: - // Close connection now. - client.Close() - s.clients[client] = closed - case processing: - // Request close when done. 
- s.clients[client] = closeRequested + done := make(chan bool, 1) + go func() { + if timeout != 0 { + timer := time.NewTicker(timeout) + defer timer.Stop() + select { + case <-done: + return + case <-timer.C: + } } - } + + // Close all known clients. + s.mu.Lock() + defer s.mu.Unlock() + for client, state := range s.clients { + switch state { + case idle: + // Close connection now. + client.Close() + s.clients[client] = closed + case processing: + // Request close when done. + s.clients[client] = closeRequested + } + } + }() + + // Wait for all outstanding requests. + s.wg.Wait() + done <- true } // Client is a urpc client. diff --git a/pkg/usermem/BUILD b/pkg/usermem/BUILD index 3dba36f12..54674ee88 100644 --- a/pkg/usermem/BUILD +++ b/pkg/usermem/BUILD @@ -7,12 +7,14 @@ go_library( srcs = [ "bytes_io.go", "bytes_io_unsafe.go", + "marshal.go", "usermem.go", ], visibility = ["//:sandbox"], deps = [ "//pkg/atomicbitops", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/gohacks", "//pkg/hostarch", "//pkg/safemem", @@ -29,6 +31,7 @@ go_test( library = ":usermem", deps = [ "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/safemem", "//pkg/syserror", diff --git a/pkg/usermem/bytes_io.go b/pkg/usermem/bytes_io.go index 3da3c0294..4c97b9136 100644 --- a/pkg/usermem/bytes_io.go +++ b/pkg/usermem/bytes_io.go @@ -16,6 +16,7 @@ package usermem import ( "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/syserror" @@ -51,7 +52,7 @@ func (b *BytesIO) CopyIn(ctx context.Context, addr hostarch.Addr, dst []byte, op // ZeroOut implements IO.ZeroOut. func (b *BytesIO) ZeroOut(ctx context.Context, addr hostarch.Addr, toZero int64, opts IOOpts) (int64, error) { if toZero > int64(maxInt) { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } rngN, rngErr := b.rangeCheck(addr, int(toZero)) if rngN == 0 { @@ -89,7 +90,7 @@ func (b *BytesIO) rangeCheck(addr hostarch.Addr, length int) (int, error) { return 0, nil } if length < 0 { - return 0, syserror.EINVAL + return 0, linuxerr.EINVAL } max := hostarch.Addr(len(b.Bytes)) if addr >= max { diff --git a/pkg/usermem/marshal.go b/pkg/usermem/marshal.go new file mode 100644 index 000000000..5b5a662dc --- /dev/null +++ b/pkg/usermem/marshal.go @@ -0,0 +1,43 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package usermem + +import ( + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" +) + +// IOCopyContext wraps an object implementing hostarch.IO to implement +// marshal.CopyContext. +type IOCopyContext struct { + Ctx context.Context + IO IO + Opts IOOpts +} + +// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer. +func (i *IOCopyContext) CopyScratchBuffer(size int) []byte { + return make([]byte, size) +} + +// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. 
+func (i *IOCopyContext) CopyOutBytes(addr hostarch.Addr, b []byte) (int, error) { + return i.IO.CopyOut(i.Ctx, addr, b, i.Opts) +} + +// CopyInBytes implements marshal.CopyContext.CopyInBytes. +func (i *IOCopyContext) CopyInBytes(addr hostarch.Addr, b []byte) (int, error) { + return i.IO.CopyIn(i.Ctx, addr, b, i.Opts) +} diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go index 0d6d25e50..483e64c1e 100644 --- a/pkg/usermem/usermem.go +++ b/pkg/usermem/usermem.go @@ -22,11 +22,11 @@ import ( "strconv" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/syserror" - - "gvisor.dev/gvisor/pkg/hostarch" ) // IO provides access to the contents of a virtual memory space. @@ -382,7 +382,7 @@ func CopyInt32StringsInVec(ctx context.Context, uio IO, ars hostarch.AddrRangeSe // Parse a single value. val, err := strconv.ParseInt(string(buf[i:nextI]), 10, 32) if err != nil { - return int64(i), syserror.EINVAL + return int64(i), linuxerr.EINVAL } dsts[j] = int32(val) @@ -398,7 +398,7 @@ func CopyInt32StringsInVec(ctx context.Context, uio IO, ars hostarch.AddrRangeSe return int64(i), cperr } if j == 0 { - return int64(i), syserror.EINVAL + return int64(i), linuxerr.EINVAL } return int64(i), nil } diff --git a/pkg/usermem/usermem_test.go b/pkg/usermem/usermem_test.go index 9b697b593..4fedb57a5 100644 --- a/pkg/usermem/usermem_test.go +++ b/pkg/usermem/usermem_test.go @@ -22,6 +22,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/syserror" @@ -272,8 +273,8 @@ func TestCopyInt32StringsInVecRequiresOneValidValue(t *testing.T) { src := BytesIOSequence([]byte(s)) initial := []int32{1, 2} dsts := append([]int32(nil), initial...) - if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); err != syserror.EINVAL { - t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (_, %v)", n, err, syserror.EINVAL) + if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); !linuxerr.Equals(linuxerr.EINVAL, err) { + t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (_, %v)", n, err, linuxerr.EINVAL) } if !reflect.DeepEqual(dsts, initial) { t.Errorf("dsts: got %v, wanted %v", dsts, initial) |