diff options
Diffstat (limited to 'pkg')
190 files changed, 3868 insertions, 3167 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 29ead20d0..eb004a7f6 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -21,7 +21,6 @@ go_library( "epoll.go", "epoll_amd64.go", "epoll_arm64.go", - "errors.go", "errqueue.go", "eventfd.go", "exec.go", @@ -79,6 +78,7 @@ go_library( "//pkg/abi", "//pkg/bits", "//pkg/context", + "//pkg/hostarch", "//pkg/marshal", "//pkg/marshal/primitive", ], diff --git a/pkg/abi/linux/errno/BUILD b/pkg/abi/linux/errno/BUILD new file mode 100644 index 000000000..d003342d5 --- /dev/null +++ b/pkg/abi/linux/errno/BUILD @@ -0,0 +1,9 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "errno", + srcs = ["errno.go"], + visibility = ["//visibility:public"], +) diff --git a/pkg/abi/linux/errors.go b/pkg/abi/linux/errno/errno.go index b08b2687e..5a09c6605 100644 --- a/pkg/abi/linux/errors.go +++ b/pkg/abi/linux/errno/errno.go @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -package linux +// Package errno holds errno codes for abi/linux. +package errno // Errno represents a Linux errno value. -type Errno int +type Errno uint32 // Errno values from include/uapi/asm-generic/errno-base.h. const ( @@ -60,8 +61,8 @@ const ( ENOLCK ENOSYS ENOTEMPTY - ELOOP //40 - _ // Skip for EWOULDBLOCK = EAGAIN + ELOOP // 40 + _ // Skip for EWOULDBLOCK = EAGAIN. ENOMSG //42 EIDRM ECHRNG @@ -78,13 +79,13 @@ const ( ENOANO EBADRQC EBADSLT - _ // Skip for EDEADLOCK = EDEADLK + _ // Skip for EDEADLOCK = EDEADLK. EBFONT ENOSTR // 60 ENODATA ETIME ENOSR - ENONET + _ // Skip for ENOENT = ENONET. ENOPKG EREMOTE ENOLINK @@ -160,4 +161,5 @@ const ( const ( EWOULDBLOCK = EAGAIN EDEADLOCK = EDEADLK + ENONET = ENOENT ) diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go index b088b207c..f8c0e891e 100644 --- a/pkg/abi/linux/netfilter_ipv6.go +++ b/pkg/abi/linux/netfilter_ipv6.go @@ -41,7 +41,6 @@ const ( // IP6T_ORIGINAL_DST is the ip6tables SOL_IPV6 socket option. Corresponds to // the value in include/uapi/linux/netfilter_ipv6/ip6_tables.h. -// TODO(gvisor.dev/issue/3549): Support IPv6 original destination. const IP6T_ORIGINAL_DST = 80 // IP6TReplace is the argument for the IP6T_SO_SET_REPLACE sockopt. It diff --git a/pkg/abi/linux/signal.go b/pkg/abi/linux/signal.go index 6ca57ffbb..06a4c6401 100644 --- a/pkg/abi/linux/signal.go +++ b/pkg/abi/linux/signal.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/bits" + "gvisor.dev/gvisor/pkg/hostarch" ) const ( @@ -165,7 +166,7 @@ const ( SIG_IGN = 1 ) -// Signal action flags for rt_sigaction(2), from uapi/asm-generic/signal.h +// Signal action flags for rt_sigaction(2), from uapi/asm-generic/signal.h. const ( SA_NOCLDSTOP = 0x00000001 SA_NOCLDWAIT = 0x00000002 @@ -179,21 +180,17 @@ const ( SA_ONESHOT = SA_RESETHAND ) -// Signal info types. +// Signal stack flags for signalstack(2), from include/uapi/linux/signal.h. const ( - SI_MASK = 0xffff0000 - SI_KILL = 0 << 16 - SI_TIMER = 1 << 16 - SI_POLL = 2 << 16 - SI_FAULT = 3 << 16 - SI_CHLD = 4 << 16 - SI_RT = 5 << 16 - SI_MESGQ = 6 << 16 - SI_SYS = 7 << 16 + SS_ONSTACK = 1 + SS_DISABLE = 2 ) // SIGPOLL si_codes. const ( + // SI_POLL is defined as __SI_POLL in Linux 2.6. + SI_POLL = 2 << 16 + // POLL_IN indicates that data input available. POLL_IN = SI_POLL | 1 @@ -213,6 +210,75 @@ const ( POLL_HUP = SI_POLL | 6 ) +// Possible values for si_code. +const ( + // SI_USER is sent by kill, sigsend, raise. + SI_USER = 0 + + // SI_KERNEL is sent by the kernel from somewhere. + SI_KERNEL = 0x80 + + // SI_QUEUE is sent by sigqueue. + SI_QUEUE = -1 + + // SI_TIMER is sent by timer expiration. + SI_TIMER = -2 + + // SI_MESGQ is sent by real time mesq state change. + SI_MESGQ = -3 + + // SI_ASYNCIO is sent by AIO completion. + SI_ASYNCIO = -4 + + // SI_SIGIO is sent by queued SIGIO. + SI_SIGIO = -5 + + // SI_TKILL is sent by tkill system call. + SI_TKILL = -6 + + // SI_DETHREAD is sent by execve() killing subsidiary threads. + SI_DETHREAD = -7 + + // SI_ASYNCNL is sent by glibc async name lookup completion. + SI_ASYNCNL = -60 +) + +// CLD_* codes are only meaningful for SIGCHLD. +const ( + // CLD_EXITED indicates that a task exited. + CLD_EXITED = 1 + + // CLD_KILLED indicates that a task was killed by a signal. + CLD_KILLED = 2 + + // CLD_DUMPED indicates that a task was killed by a signal and then dumped + // core. + CLD_DUMPED = 3 + + // CLD_TRAPPED indicates that a task was stopped by ptrace. + CLD_TRAPPED = 4 + + // CLD_STOPPED indicates that a thread group completed a group stop. + CLD_STOPPED = 5 + + // CLD_CONTINUED indicates that a group-stopped thread group was continued. + CLD_CONTINUED = 6 +) + +// SYS_* codes are only meaningful for SIGSYS. +const ( + // SYS_SECCOMP indicates that a signal originates from seccomp. + SYS_SECCOMP = 1 +) + +// Possible values for Sigevent.Notify, aka struct sigevent::sigev_notify. +const ( + SIGEV_SIGNAL = 0 + SIGEV_NONE = 1 + SIGEV_THREAD = 2 + SIGEV_THREAD_ID = 4 +) + // Sigevent represents struct sigevent. // // +marshal @@ -227,10 +293,252 @@ type Sigevent struct { UnRemainder [44]byte } -// Possible values for Sigevent.Notify, aka struct sigevent::sigev_notify. -const ( - SIGEV_SIGNAL = 0 - SIGEV_NONE = 1 - SIGEV_THREAD = 2 - SIGEV_THREAD_ID = 4 -) +// SigAction represents struct sigaction. +// +// +marshal +// +stateify savable +type SigAction struct { + Handler uint64 + Flags uint64 + Restorer uint64 + Mask SignalSet +} + +// SignalStack represents information about a user stack, and is equivalent to +// stack_t. +// +// +marshal +// +stateify savable +type SignalStack struct { + Addr uint64 + Flags uint32 + _ uint32 + Size uint64 +} + +// Contains checks if the stack pointer is within this stack. +func (s *SignalStack) Contains(sp hostarch.Addr) bool { + return hostarch.Addr(s.Addr) < sp && sp <= hostarch.Addr(s.Addr+s.Size) +} + +// Top returns the stack's top address. +func (s *SignalStack) Top() hostarch.Addr { + return hostarch.Addr(s.Addr + s.Size) +} + +// IsEnabled returns true iff this signal stack is marked as enabled. +func (s *SignalStack) IsEnabled() bool { + return s.Flags&SS_DISABLE == 0 +} + +// SignalInfo represents information about a signal being delivered, and is +// equivalent to struct siginfo in linux kernel(linux/include/uapi/asm-generic/siginfo.h). +// +// +marshal +// +stateify savable +type SignalInfo struct { + Signo int32 // Signal number + Errno int32 // Errno value + Code int32 // Signal code + _ uint32 + + // struct siginfo::_sifields is a union. In SignalInfo, fields in the union + // are accessed through methods. + // + // For reference, here is the definition of _sifields: (_sigfault._trapno, + // which does not exist on x86, omitted for clarity) + // + // union { + // int _pad[SI_PAD_SIZE]; + // + // /* kill() */ + // struct { + // __kernel_pid_t _pid; /* sender's pid */ + // __ARCH_SI_UID_T _uid; /* sender's uid */ + // } _kill; + // + // /* POSIX.1b timers */ + // struct { + // __kernel_timer_t _tid; /* timer id */ + // int _overrun; /* overrun count */ + // char _pad[sizeof( __ARCH_SI_UID_T) - sizeof(int)]; + // sigval_t _sigval; /* same as below */ + // int _sys_private; /* not to be passed to user */ + // } _timer; + // + // /* POSIX.1b signals */ + // struct { + // __kernel_pid_t _pid; /* sender's pid */ + // __ARCH_SI_UID_T _uid; /* sender's uid */ + // sigval_t _sigval; + // } _rt; + // + // /* SIGCHLD */ + // struct { + // __kernel_pid_t _pid; /* which child */ + // __ARCH_SI_UID_T _uid; /* sender's uid */ + // int _status; /* exit code */ + // __ARCH_SI_CLOCK_T _utime; + // __ARCH_SI_CLOCK_T _stime; + // } _sigchld; + // + // /* SIGILL, SIGFPE, SIGSEGV, SIGBUS */ + // struct { + // void *_addr; /* faulting insn/memory ref. */ + // short _addr_lsb; /* LSB of the reported address */ + // } _sigfault; + // + // /* SIGPOLL */ + // struct { + // __ARCH_SI_BAND_T _band; /* POLL_IN, POLL_OUT, POLL_MSG */ + // int _fd; + // } _sigpoll; + // + // /* SIGSYS */ + // struct { + // void *_call_addr; /* calling user insn */ + // int _syscall; /* triggering system call number */ + // unsigned int _arch; /* AUDIT_ARCH_* of syscall */ + // } _sigsys; + // } _sifields; + // + // _sifields is padded so that the size of siginfo is SI_MAX_SIZE = 128 + // bytes. + Fields [128 - 16]byte +} + +// FixSignalCodeForUser fixes up si_code. +// +// The si_code we get from Linux may contain the kernel-specific code in the +// top 16 bits if it's positive (e.g., from ptrace). Linux's +// copy_siginfo_to_user does +// err |= __put_user((short)from->si_code, &to->si_code); +// to mask out those bits and we need to do the same. +func (s *SignalInfo) FixSignalCodeForUser() { + if s.Code > 0 { + s.Code &= 0x0000ffff + } +} + +// PID returns the si_pid field. +func (s *SignalInfo) PID() int32 { + return int32(hostarch.ByteOrder.Uint32(s.Fields[0:4])) +} + +// SetPID mutates the si_pid field. +func (s *SignalInfo) SetPID(val int32) { + hostarch.ByteOrder.PutUint32(s.Fields[0:4], uint32(val)) +} + +// UID returns the si_uid field. +func (s *SignalInfo) UID() int32 { + return int32(hostarch.ByteOrder.Uint32(s.Fields[4:8])) +} + +// SetUID mutates the si_uid field. +func (s *SignalInfo) SetUID(val int32) { + hostarch.ByteOrder.PutUint32(s.Fields[4:8], uint32(val)) +} + +// Sigval returns the sigval field, which is aliased to both si_int and si_ptr. +func (s *SignalInfo) Sigval() uint64 { + return hostarch.ByteOrder.Uint64(s.Fields[8:16]) +} + +// SetSigval mutates the sigval field. +func (s *SignalInfo) SetSigval(val uint64) { + hostarch.ByteOrder.PutUint64(s.Fields[8:16], val) +} + +// TimerID returns the si_timerid field. +func (s *SignalInfo) TimerID() TimerID { + return TimerID(hostarch.ByteOrder.Uint32(s.Fields[0:4])) +} + +// SetTimerID sets the si_timerid field. +func (s *SignalInfo) SetTimerID(val TimerID) { + hostarch.ByteOrder.PutUint32(s.Fields[0:4], uint32(val)) +} + +// Overrun returns the si_overrun field. +func (s *SignalInfo) Overrun() int32 { + return int32(hostarch.ByteOrder.Uint32(s.Fields[4:8])) +} + +// SetOverrun sets the si_overrun field. +func (s *SignalInfo) SetOverrun(val int32) { + hostarch.ByteOrder.PutUint32(s.Fields[4:8], uint32(val)) +} + +// Addr returns the si_addr field. +func (s *SignalInfo) Addr() uint64 { + return hostarch.ByteOrder.Uint64(s.Fields[0:8]) +} + +// SetAddr sets the si_addr field. +func (s *SignalInfo) SetAddr(val uint64) { + hostarch.ByteOrder.PutUint64(s.Fields[0:8], val) +} + +// Status returns the si_status field. +func (s *SignalInfo) Status() int32 { + return int32(hostarch.ByteOrder.Uint32(s.Fields[8:12])) +} + +// SetStatus mutates the si_status field. +func (s *SignalInfo) SetStatus(val int32) { + hostarch.ByteOrder.PutUint32(s.Fields[8:12], uint32(val)) +} + +// CallAddr returns the si_call_addr field. +func (s *SignalInfo) CallAddr() uint64 { + return hostarch.ByteOrder.Uint64(s.Fields[0:8]) +} + +// SetCallAddr mutates the si_call_addr field. +func (s *SignalInfo) SetCallAddr(val uint64) { + hostarch.ByteOrder.PutUint64(s.Fields[0:8], val) +} + +// Syscall returns the si_syscall field. +func (s *SignalInfo) Syscall() int32 { + return int32(hostarch.ByteOrder.Uint32(s.Fields[8:12])) +} + +// SetSyscall mutates the si_syscall field. +func (s *SignalInfo) SetSyscall(val int32) { + hostarch.ByteOrder.PutUint32(s.Fields[8:12], uint32(val)) +} + +// Arch returns the si_arch field. +func (s *SignalInfo) Arch() uint32 { + return hostarch.ByteOrder.Uint32(s.Fields[12:16]) +} + +// SetArch mutates the si_arch field. +func (s *SignalInfo) SetArch(val uint32) { + hostarch.ByteOrder.PutUint32(s.Fields[12:16], val) +} + +// Band returns the si_band field. +func (s *SignalInfo) Band() int64 { + return int64(hostarch.ByteOrder.Uint64(s.Fields[0:8])) +} + +// SetBand mutates the si_band field. +func (s *SignalInfo) SetBand(val int64) { + // Note: this assumes the platform uses `long` as `__ARCH_SI_BAND_T`. + // On some platforms, which gVisor doesn't support, `__ARCH_SI_BAND_T` is + // `int`. See siginfo.h. + hostarch.ByteOrder.PutUint64(s.Fields[0:8], uint64(val)) +} + +// FD returns the si_fd field. +func (s *SignalInfo) FD() uint32 { + return hostarch.ByteOrder.Uint32(s.Fields[8:12]) +} + +// SetFD mutates the si_fd field. +func (s *SignalInfo) SetFD(val uint32) { + hostarch.ByteOrder.PutUint32(s.Fields[8:12], val) +} diff --git a/pkg/abi/linux/xattr.go b/pkg/abi/linux/xattr.go index 8ef837f27..1fa7a4f4f 100644 --- a/pkg/abi/linux/xattr.go +++ b/pkg/abi/linux/xattr.go @@ -23,6 +23,12 @@ const ( XATTR_CREATE = 1 XATTR_REPLACE = 2 + XATTR_SECURITY_PREFIX = "security." + XATTR_SECURITY_PREFIX_LEN = len(XATTR_SECURITY_PREFIX) + + XATTR_SYSTEM_PREFIX = "system." + XATTR_SYSTEM_PREFIX_LEN = len(XATTR_SYSTEM_PREFIX) + XATTR_TRUSTED_PREFIX = "trusted." XATTR_TRUSTED_PREFIX_LEN = len(XATTR_TRUSTED_PREFIX) diff --git a/pkg/errors/BUILD b/pkg/errors/BUILD new file mode 100644 index 000000000..36ea5da0c --- /dev/null +++ b/pkg/errors/BUILD @@ -0,0 +1,10 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "errors", + srcs = ["errors.go"], + visibility = ["//:sandbox"], + deps = ["//pkg/abi/linux/errno"], +) diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go new file mode 100644 index 000000000..7984a770e --- /dev/null +++ b/pkg/errors/errors.go @@ -0,0 +1,40 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package errors holds the standardized error definition for gVisor. +package errors + +import ( + "gvisor.dev/gvisor/pkg/abi/linux/errno" +) + +// Error represents a syscall errno with a descriptive message. +type Error struct { + errno errno.Errno + message string +} + +// New creates a new *Error. +func New(err errno.Errno, message string) *Error { + return &Error{ + errno: err, + message: message, + } +} + +// Error implements error.Error. +func (e *Error) Error() string { return e.message } + +// Errno returns the underlying errno.Errno value. +func (e *Error) Errno() errno.Errno { return e.errno } diff --git a/pkg/linuxerr/BUILD b/pkg/errors/linuxerr/BUILD index c5abbd34f..8afc9688c 100644 --- a/pkg/linuxerr/BUILD +++ b/pkg/errors/linuxerr/BUILD @@ -6,7 +6,10 @@ go_library( name = "linuxerr", srcs = ["linuxerr.go"], visibility = ["//visibility:public"], - deps = ["//pkg/abi/linux"], + deps = [ + "//pkg/abi/linux/errno", + "//pkg/errors", + ], ) go_test( @@ -14,6 +17,8 @@ go_test( srcs = ["linuxerr_test.go"], deps = [ ":linuxerr", + "//pkg/abi/linux/errno", + "//pkg/errors", "//pkg/syserror", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/errors/linuxerr/linuxerr.go b/pkg/errors/linuxerr/linuxerr.go new file mode 100644 index 000000000..bbdcdecd0 --- /dev/null +++ b/pkg/errors/linuxerr/linuxerr.go @@ -0,0 +1,327 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"),; +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package linuxerr contains syscall error codes exported as an error interface +// pointers. This allows for fast comparison and return operations comperable +// to unix.Errno constants. +package linuxerr + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux/errno" + "gvisor.dev/gvisor/pkg/errors" +) + +const maxErrno uint32 = errno.EHWPOISON + 1 + +var ( + NOERROR = errors.New(errno.NOERRNO, "not an error") + EPERM = errors.New(errno.EPERM, "operation not permitted") + ENOENT = errors.New(errno.ENOENT, "no such file or directory") + ESRCH = errors.New(errno.ESRCH, "no such process") + EINTR = errors.New(errno.EINTR, "interrupted system call") + EIO = errors.New(errno.EIO, "I/O error") + ENXIO = errors.New(errno.ENXIO, "no such device or address") + E2BIG = errors.New(errno.E2BIG, "argument list too long") + ENOEXEC = errors.New(errno.ENOEXEC, "exec format error") + EBADF = errors.New(errno.EBADF, "bad file number") + ECHILD = errors.New(errno.ECHILD, "no child processes") + EAGAIN = errors.New(errno.EAGAIN, "try again") + ENOMEM = errors.New(errno.ENOMEM, "out of memory") + EACCES = errors.New(errno.EACCES, "permission denied") + EFAULT = errors.New(errno.EFAULT, "bad address") + ENOTBLK = errors.New(errno.ENOTBLK, "block device required") + EBUSY = errors.New(errno.EBUSY, "device or resource busy") + EEXIST = errors.New(errno.EEXIST, "file exists") + EXDEV = errors.New(errno.EXDEV, "cross-device link") + ENODEV = errors.New(errno.ENODEV, "no such device") + ENOTDIR = errors.New(errno.ENOTDIR, "not a directory") + EISDIR = errors.New(errno.EISDIR, "is a directory") + EINVAL = errors.New(errno.EINVAL, "invalid argument") + ENFILE = errors.New(errno.ENFILE, "file table overflow") + EMFILE = errors.New(errno.EMFILE, "too many open files") + ENOTTY = errors.New(errno.ENOTTY, "not a typewriter") + ETXTBSY = errors.New(errno.ETXTBSY, "text file busy") + EFBIG = errors.New(errno.EFBIG, "file too large") + ENOSPC = errors.New(errno.ENOSPC, "no space left on device") + ESPIPE = errors.New(errno.ESPIPE, "illegal seek") + EROFS = errors.New(errno.EROFS, "read-only file system") + EMLINK = errors.New(errno.EMLINK, "too many links") + EPIPE = errors.New(errno.EPIPE, "broken pipe") + EDOM = errors.New(errno.EDOM, "math argument out of domain of func") + ERANGE = errors.New(errno.ERANGE, "math result not representable") + + // Errno values from include/uapi/asm-generic/errno.h. + EDEADLK = errors.New(errno.EDEADLK, "resource deadlock would occur") + ENAMETOOLONG = errors.New(errno.ENAMETOOLONG, "file name too long") + ENOLCK = errors.New(errno.ENOLCK, "no record locks available") + ENOSYS = errors.New(errno.ENOSYS, "invalid system call number") + ENOTEMPTY = errors.New(errno.ENOTEMPTY, "directory not empty") + ELOOP = errors.New(errno.ELOOP, "too many symbolic links encountered") + ENOMSG = errors.New(errno.ENOMSG, "no message of desired type") + EIDRM = errors.New(errno.EIDRM, "identifier removed") + ECHRNG = errors.New(errno.ECHRNG, "channel number out of range") + EL2NSYNC = errors.New(errno.EL2NSYNC, "level 2 not synchronized") + EL3HLT = errors.New(errno.EL3HLT, "level 3 halted") + EL3RST = errors.New(errno.EL3RST, "level 3 reset") + ELNRNG = errors.New(errno.ELNRNG, "link number out of range") + EUNATCH = errors.New(errno.EUNATCH, "protocol driver not attached") + ENOCSI = errors.New(errno.ENOCSI, "no CSI structure available") + EL2HLT = errors.New(errno.EL2HLT, "level 2 halted") + EBADE = errors.New(errno.EBADE, "invalid exchange") + EBADR = errors.New(errno.EBADR, "invalid request descriptor") + EXFULL = errors.New(errno.EXFULL, "exchange full") + ENOANO = errors.New(errno.ENOANO, "no anode") + EBADRQC = errors.New(errno.EBADRQC, "invalid request code") + EBADSLT = errors.New(errno.EBADSLT, "invalid slot") + EBFONT = errors.New(errno.EBFONT, "bad font file format") + ENOSTR = errors.New(errno.ENOSTR, "device not a stream") + ENODATA = errors.New(errno.ENODATA, "no data available") + ETIME = errors.New(errno.ETIME, "timer expired") + ENOSR = errors.New(errno.ENOSR, "out of streams resources") + ENOPKG = errors.New(errno.ENOPKG, "package not installed") + EREMOTE = errors.New(errno.EREMOTE, "object is remote") + ENOLINK = errors.New(errno.ENOLINK, "link has been severed") + EADV = errors.New(errno.EADV, "advertise error") + ESRMNT = errors.New(errno.ESRMNT, "srmount error") + ECOMM = errors.New(errno.ECOMM, "communication error on send") + EPROTO = errors.New(errno.EPROTO, "protocol error") + EMULTIHOP = errors.New(errno.EMULTIHOP, "multihop attempted") + EDOTDOT = errors.New(errno.EDOTDOT, "RFS specific error") + EBADMSG = errors.New(errno.EBADMSG, "not a data message") + EOVERFLOW = errors.New(errno.EOVERFLOW, "value too large for defined data type") + ENOTUNIQ = errors.New(errno.ENOTUNIQ, "name not unique on network") + EBADFD = errors.New(errno.EBADFD, "file descriptor in bad state") + EREMCHG = errors.New(errno.EREMCHG, "remote address changed") + ELIBACC = errors.New(errno.ELIBACC, "can not access a needed shared library") + ELIBBAD = errors.New(errno.ELIBBAD, "accessing a corrupted shared library") + ELIBSCN = errors.New(errno.ELIBSCN, ".lib section in a.out corrupted") + ELIBMAX = errors.New(errno.ELIBMAX, "attempting to link in too many shared libraries") + ELIBEXEC = errors.New(errno.ELIBEXEC, "cannot exec a shared library directly") + EILSEQ = errors.New(errno.EILSEQ, "illegal byte sequence") + ERESTART = errors.New(errno.ERESTART, "interrupted system call should be restarted") + ESTRPIPE = errors.New(errno.ESTRPIPE, "streams pipe error") + EUSERS = errors.New(errno.EUSERS, "too many users") + ENOTSOCK = errors.New(errno.ENOTSOCK, "socket operation on non-socket") + EDESTADDRREQ = errors.New(errno.EDESTADDRREQ, "destination address required") + EMSGSIZE = errors.New(errno.EMSGSIZE, "message too long") + EPROTOTYPE = errors.New(errno.EPROTOTYPE, "protocol wrong type for socket") + ENOPROTOOPT = errors.New(errno.ENOPROTOOPT, "protocol not available") + EPROTONOSUPPORT = errors.New(errno.EPROTONOSUPPORT, "protocol not supported") + ESOCKTNOSUPPORT = errors.New(errno.ESOCKTNOSUPPORT, "socket type not supported") + EOPNOTSUPP = errors.New(errno.EOPNOTSUPP, "operation not supported on transport endpoint") + EPFNOSUPPORT = errors.New(errno.EPFNOSUPPORT, "protocol family not supported") + EAFNOSUPPORT = errors.New(errno.EAFNOSUPPORT, "address family not supported by protocol") + EADDRINUSE = errors.New(errno.EADDRINUSE, "address already in use") + EADDRNOTAVAIL = errors.New(errno.EADDRNOTAVAIL, "cannot assign requested address") + ENETDOWN = errors.New(errno.ENETDOWN, "network is down") + ENETUNREACH = errors.New(errno.ENETUNREACH, "network is unreachable") + ENETRESET = errors.New(errno.ENETRESET, "network dropped connection because of reset") + ECONNABORTED = errors.New(errno.ECONNABORTED, "software caused connection abort") + ECONNRESET = errors.New(errno.ECONNRESET, "connection reset by peer") + ENOBUFS = errors.New(errno.ENOBUFS, "no buffer space available") + EISCONN = errors.New(errno.EISCONN, "transport endpoint is already connected") + ENOTCONN = errors.New(errno.ENOTCONN, "transport endpoint is not connected") + ESHUTDOWN = errors.New(errno.ESHUTDOWN, "cannot send after transport endpoint shutdown") + ETOOMANYREFS = errors.New(errno.ETOOMANYREFS, "too many references: cannot splice") + ETIMEDOUT = errors.New(errno.ETIMEDOUT, "connection timed out") + ECONNREFUSED = errors.New(errno.ECONNREFUSED, "connection refused") + EHOSTDOWN = errors.New(errno.EHOSTDOWN, "host is down") + EHOSTUNREACH = errors.New(errno.EHOSTUNREACH, "no route to host") + EALREADY = errors.New(errno.EALREADY, "operation already in progress") + EINPROGRESS = errors.New(errno.EINPROGRESS, "operation now in progress") + ESTALE = errors.New(errno.ESTALE, "stale file handle") + EUCLEAN = errors.New(errno.EUCLEAN, "structure needs cleaning") + ENOTNAM = errors.New(errno.ENOTNAM, "not a XENIX named type file") + ENAVAIL = errors.New(errno.ENAVAIL, "no XENIX semaphores available") + EISNAM = errors.New(errno.EISNAM, "is a named type file") + EREMOTEIO = errors.New(errno.EREMOTEIO, "remote I/O error") + EDQUOT = errors.New(errno.EDQUOT, "quota exceeded") + ENOMEDIUM = errors.New(errno.ENOMEDIUM, "no medium found") + EMEDIUMTYPE = errors.New(errno.EMEDIUMTYPE, "wrong medium type") + ECANCELED = errors.New(errno.ECANCELED, "operation Canceled") + ENOKEY = errors.New(errno.ENOKEY, "required key not available") + EKEYEXPIRED = errors.New(errno.EKEYEXPIRED, "key has expired") + EKEYREVOKED = errors.New(errno.EKEYREVOKED, "key has been revoked") + EKEYREJECTED = errors.New(errno.EKEYREJECTED, "key was rejected by service") + EOWNERDEAD = errors.New(errno.EOWNERDEAD, "owner died") + ENOTRECOVERABLE = errors.New(errno.ENOTRECOVERABLE, "state not recoverable") + ERFKILL = errors.New(errno.ERFKILL, "operation not possible due to RF-kill") + EHWPOISON = errors.New(errno.EHWPOISON, "memory page has hardware error") + + // Errors equivalent to other errors. + EWOULDBLOCK = EAGAIN + EDEADLOCK = EDEADLK + ENONET = ENOENT +) + +// A nil *errors.Error denotes no error and is placed at the 0 index of +// errorSlice. Thus, any other empty index should not be nil or a valid error. +// This marks that index as an invalid error so any comparison to nil or a +// valid linuxerr fails. +var errNotValidError = errors.New(errno.Errno(maxErrno), "not a valid error") + +// The following errorSlice holds errors by errno for fast translation between +// errnos (especially uint32(sycall.Errno)) and *Error. +var errorSlice = []*errors.Error{ + // Errno values from include/uapi/asm-generic/errno-base.h. + errno.NOERRNO: NOERROR, + errno.EPERM: EPERM, + errno.ENOENT: ENOENT, + errno.ESRCH: ESRCH, + errno.EINTR: EINTR, + errno.EIO: EIO, + errno.ENXIO: ENXIO, + errno.E2BIG: E2BIG, + errno.ENOEXEC: ENOEXEC, + errno.EBADF: EBADF, + errno.ECHILD: ECHILD, + errno.EAGAIN: EAGAIN, + errno.ENOMEM: ENOMEM, + errno.EACCES: EACCES, + errno.EFAULT: EFAULT, + errno.ENOTBLK: ENOTBLK, + errno.EBUSY: EBUSY, + errno.EEXIST: EEXIST, + errno.EXDEV: EXDEV, + errno.ENODEV: ENODEV, + errno.ENOTDIR: ENOTDIR, + errno.EISDIR: EISDIR, + errno.EINVAL: EINVAL, + errno.ENFILE: ENFILE, + errno.EMFILE: EMFILE, + errno.ENOTTY: ENOTTY, + errno.ETXTBSY: ETXTBSY, + errno.EFBIG: EFBIG, + errno.ENOSPC: ENOSPC, + errno.ESPIPE: ESPIPE, + errno.EROFS: EROFS, + errno.EMLINK: EMLINK, + errno.EPIPE: EPIPE, + errno.EDOM: EDOM, + errno.ERANGE: ERANGE, + + // Errno values from include/uapi/asm-generic/errno.h. + errno.EDEADLK: EDEADLK, + errno.ENAMETOOLONG: ENAMETOOLONG, + errno.ENOLCK: ENOLCK, + errno.ENOSYS: ENOSYS, + errno.ENOTEMPTY: ENOTEMPTY, + errno.ELOOP: ELOOP, + errno.ELOOP + 1: errNotValidError, // No valid errno between ELOOP and ENOMSG. + errno.ENOMSG: ENOMSG, + errno.EIDRM: EIDRM, + errno.ECHRNG: ECHRNG, + errno.EL2NSYNC: EL2NSYNC, + errno.EL3HLT: EL3HLT, + errno.EL3RST: EL3RST, + errno.ELNRNG: ELNRNG, + errno.EUNATCH: EUNATCH, + errno.ENOCSI: ENOCSI, + errno.EL2HLT: EL2HLT, + errno.EBADE: EBADE, + errno.EBADR: EBADR, + errno.EXFULL: EXFULL, + errno.ENOANO: ENOANO, + errno.EBADRQC: EBADRQC, + errno.EBADSLT: EBADSLT, + errno.EBADSLT + 1: errNotValidError, // No valid errno between EBADSLT and ENOPKG. + errno.EBFONT: EBFONT, + errno.ENOSTR: ENOSTR, + errno.ENODATA: ENODATA, + errno.ETIME: ETIME, + errno.ENOSR: ENOSR, + errno.ENOSR + 1: errNotValidError, // No valid errno betweeen ENOSR and ENOPKG. + errno.ENOPKG: ENOPKG, + errno.EREMOTE: EREMOTE, + errno.ENOLINK: ENOLINK, + errno.EADV: EADV, + errno.ESRMNT: ESRMNT, + errno.ECOMM: ECOMM, + errno.EPROTO: EPROTO, + errno.EMULTIHOP: EMULTIHOP, + errno.EDOTDOT: EDOTDOT, + errno.EBADMSG: EBADMSG, + errno.EOVERFLOW: EOVERFLOW, + errno.ENOTUNIQ: ENOTUNIQ, + errno.EBADFD: EBADFD, + errno.EREMCHG: EREMCHG, + errno.ELIBACC: ELIBACC, + errno.ELIBBAD: ELIBBAD, + errno.ELIBSCN: ELIBSCN, + errno.ELIBMAX: ELIBMAX, + errno.ELIBEXEC: ELIBEXEC, + errno.EILSEQ: EILSEQ, + errno.ERESTART: ERESTART, + errno.ESTRPIPE: ESTRPIPE, + errno.EUSERS: EUSERS, + errno.ENOTSOCK: ENOTSOCK, + errno.EDESTADDRREQ: EDESTADDRREQ, + errno.EMSGSIZE: EMSGSIZE, + errno.EPROTOTYPE: EPROTOTYPE, + errno.ENOPROTOOPT: ENOPROTOOPT, + errno.EPROTONOSUPPORT: EPROTONOSUPPORT, + errno.ESOCKTNOSUPPORT: ESOCKTNOSUPPORT, + errno.EOPNOTSUPP: EOPNOTSUPP, + errno.EPFNOSUPPORT: EPFNOSUPPORT, + errno.EAFNOSUPPORT: EAFNOSUPPORT, + errno.EADDRINUSE: EADDRINUSE, + errno.EADDRNOTAVAIL: EADDRNOTAVAIL, + errno.ENETDOWN: ENETDOWN, + errno.ENETUNREACH: ENETUNREACH, + errno.ENETRESET: ENETRESET, + errno.ECONNABORTED: ECONNABORTED, + errno.ECONNRESET: ECONNRESET, + errno.ENOBUFS: ENOBUFS, + errno.EISCONN: EISCONN, + errno.ENOTCONN: ENOTCONN, + errno.ESHUTDOWN: ESHUTDOWN, + errno.ETOOMANYREFS: ETOOMANYREFS, + errno.ETIMEDOUT: ETIMEDOUT, + errno.ECONNREFUSED: ECONNREFUSED, + errno.EHOSTDOWN: EHOSTDOWN, + errno.EHOSTUNREACH: EHOSTUNREACH, + errno.EALREADY: EALREADY, + errno.EINPROGRESS: EINPROGRESS, + errno.ESTALE: ESTALE, + errno.EUCLEAN: EUCLEAN, + errno.ENOTNAM: ENOTNAM, + errno.ENAVAIL: ENAVAIL, + errno.EISNAM: EISNAM, + errno.EREMOTEIO: EREMOTEIO, + errno.EDQUOT: EDQUOT, + errno.ENOMEDIUM: ENOMEDIUM, + errno.EMEDIUMTYPE: EMEDIUMTYPE, + errno.ECANCELED: ECANCELED, + errno.ENOKEY: ENOKEY, + errno.EKEYEXPIRED: EKEYEXPIRED, + errno.EKEYREVOKED: EKEYREVOKED, + errno.EKEYREJECTED: EKEYREJECTED, + errno.EOWNERDEAD: EOWNERDEAD, + errno.ENOTRECOVERABLE: ENOTRECOVERABLE, + errno.ERFKILL: ERFKILL, + errno.EHWPOISON: EHWPOISON, +} + +// ErrorFromErrno gets an error from the list and panics if an invalid entry is requested. +func ErrorFromErrno(e errno.Errno) *errors.Error { + err := errorSlice[e] + // Done this way because a single comparison in benchmarks is 2-3 faster + // than something like ( if err == nil && err > 0 ). + if err != errNotValidError { + return err + } + panic(fmt.Sprintf("invalid error requested with errno: %d", e)) +} diff --git a/pkg/linuxerr/linuxerr_test.go b/pkg/errors/linuxerr/linuxerr_test.go index d34937e93..a81dd9560 100644 --- a/pkg/linuxerr/linuxerr_test.go +++ b/pkg/errors/linuxerr/linuxerr_test.go @@ -16,34 +16,37 @@ package syserror_test import ( "errors" + "syscall" "testing" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/linuxerr" + "gvisor.dev/gvisor/pkg/abi/linux/errno" + gErrors "gvisor.dev/gvisor/pkg/errors" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/syserror" ) var globalError error -func BenchmarkAssignErrno(b *testing.B) { +func BenchmarkAssignUnix(b *testing.B) { for i := b.N; i > 0; i-- { globalError = unix.EINVAL } } -func BenchmarkLinuxerrAssignError(b *testing.B) { +func BenchmarkAssignLinuxerr(b *testing.B) { for i := b.N; i > 0; i-- { globalError = linuxerr.EINVAL } } -func BenchmarkAssignSyserrorError(b *testing.B) { +func BenchmarkAssignSyserror(b *testing.B) { for i := b.N; i > 0; i-- { globalError = syserror.EINVAL } } -func BenchmarkCompareErrno(b *testing.B) { +func BenchmarkCompareUnix(b *testing.B) { globalError = unix.EAGAIN j := 0 for i := b.N; i > 0; i-- { @@ -53,7 +56,7 @@ func BenchmarkCompareErrno(b *testing.B) { } } -func BenchmarkCompareLinuxerrError(b *testing.B) { +func BenchmarkCompareLinuxerr(b *testing.B) { globalError = linuxerr.E2BIG j := 0 for i := b.N; i > 0; i-- { @@ -63,7 +66,7 @@ func BenchmarkCompareLinuxerrError(b *testing.B) { } } -func BenchmarkCompareSyserrorError(b *testing.B) { +func BenchmarkCompareSyserror(b *testing.B) { globalError = syserror.EAGAIN j := 0 for i := b.N; i > 0; i-- { @@ -73,7 +76,7 @@ func BenchmarkCompareSyserrorError(b *testing.B) { } } -func BenchmarkSwitchErrno(b *testing.B) { +func BenchmarkSwitchUnix(b *testing.B) { globalError = unix.EPERM j := 0 for i := b.N; i > 0; i-- { @@ -88,7 +91,7 @@ func BenchmarkSwitchErrno(b *testing.B) { } } -func BenchmarkSwitchLinuxerrError(b *testing.B) { +func BenchmarkSwitchLinuxerr(b *testing.B) { globalError = linuxerr.EPERM j := 0 for i := b.N; i > 0; i-- { @@ -103,7 +106,7 @@ func BenchmarkSwitchLinuxerrError(b *testing.B) { } } -func BenchmarkSwitchSyserrorError(b *testing.B) { +func BenchmarkSwitchSyserror(b *testing.B) { globalError = syserror.EPERM j := 0 for i := b.N; i > 0; i-- { @@ -118,6 +121,52 @@ func BenchmarkSwitchSyserrorError(b *testing.B) { } } +func BenchmarkReturnUnix(b *testing.B) { + var localError error + f := func() error { + return unix.EINVAL + } + for i := b.N; i > 0; i-- { + localError = f() + } + if localError != nil { + return + } +} + +func BenchmarkReturnLinuxerr(b *testing.B) { + var localError error + f := func() error { + return linuxerr.EINVAL + } + for i := b.N; i > 0; i-- { + localError = f() + } + if localError != nil { + return + } +} + +func BenchmarkConvertUnixLinuxerr(b *testing.B) { + var localError error + for i := b.N; i > 0; i-- { + localError = linuxerr.ErrorFromErrno(errno.Errno(unix.EINVAL)) + } + if localError != nil { + return + } +} + +func BenchmarkConvertUnixLinuxerrZero(b *testing.B) { + var localError error + for i := b.N; i > 0; i-- { + localError = linuxerr.ErrorFromErrno(errno.Errno(0)) + } + if localError != nil { + return + } +} + type translationTestTable struct { fn string errIn error @@ -162,3 +211,35 @@ func TestErrorTranslation(t *testing.T) { } } } + +func TestSyscallErrnoToErrors(t *testing.T) { + for _, tc := range []struct { + errno syscall.Errno + err *gErrors.Error + }{ + {errno: syscall.EACCES, err: linuxerr.EACCES}, + {errno: syscall.EAGAIN, err: linuxerr.EAGAIN}, + {errno: syscall.EBADF, err: linuxerr.EBADF}, + {errno: syscall.EBUSY, err: linuxerr.EBUSY}, + {errno: syscall.EDOM, err: linuxerr.EDOM}, + {errno: syscall.EEXIST, err: linuxerr.EEXIST}, + {errno: syscall.EFAULT, err: linuxerr.EFAULT}, + {errno: syscall.EFBIG, err: linuxerr.EFBIG}, + {errno: syscall.EINTR, err: linuxerr.EINTR}, + {errno: syscall.EINVAL, err: linuxerr.EINVAL}, + {errno: syscall.EIO, err: linuxerr.EIO}, + {errno: syscall.ENOTDIR, err: linuxerr.ENOTDIR}, + {errno: syscall.ENOTTY, err: linuxerr.ENOTTY}, + {errno: syscall.EPERM, err: linuxerr.EPERM}, + {errno: syscall.EPIPE, err: linuxerr.EPIPE}, + {errno: syscall.ESPIPE, err: linuxerr.ESPIPE}, + {errno: syscall.EWOULDBLOCK, err: linuxerr.EAGAIN}, + } { + t.Run(tc.errno.Error(), func(t *testing.T) { + e := linuxerr.ErrorFromErrno(errno.Errno(tc.errno)) + if e != tc.err { + t.Fatalf("Mismatch errors: want: %+v (%d) got: %+v %d", tc.err, tc.err.Errno(), e, e.Errno()) + } + }) + } +} diff --git a/pkg/linuxerr/linuxerr.go b/pkg/linuxerr/linuxerr.go deleted file mode 100644 index f45caaadf..000000000 --- a/pkg/linuxerr/linuxerr.go +++ /dev/null @@ -1,184 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package linuxerr contains syscall error codes exported as an error interface -// pointers. This allows for fast comparison and return operations comperable -// to unix.Errno constants. -package linuxerr - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" -) - -// Error represents a syscall errno with a descriptive message. -type Error struct { - errno linux.Errno - message string -} - -func new(err linux.Errno, message string) *Error { - return &Error{ - errno: err, - message: message, - } -} - -// Error implements error.Error. -func (e *Error) Error() string { return e.message } - -// Errno returns the underlying linux.Errno value. -func (e *Error) Errno() linux.Errno { return e.errno } - -// The following varables have the same meaning as their errno equivalent. - -// Errno values from include/uapi/asm-generic/errno-base.h. -var ( - EPERM = new(linux.EPERM, "operation not permitted") - ENOENT = new(linux.ENOENT, "no such file or directory") - ESRCH = new(linux.ESRCH, "no such process") - EINTR = new(linux.EINTR, "interrupted system call") - EIO = new(linux.EIO, "I/O error") - ENXIO = new(linux.ENXIO, "no such device or address") - E2BIG = new(linux.E2BIG, "argument list too long") - ENOEXEC = new(linux.ENOEXEC, "exec format error") - EBADF = new(linux.EBADF, "bad file number") - ECHILD = new(linux.ECHILD, "no child processes") - EAGAIN = new(linux.EAGAIN, "try again") - ENOMEM = new(linux.ENOMEM, "out of memory") - EACCES = new(linux.EACCES, "permission denied") - EFAULT = new(linux.EFAULT, "bad address") - ENOTBLK = new(linux.ENOTBLK, "block device required") - EBUSY = new(linux.EBUSY, "device or resource busy") - EEXIST = new(linux.EEXIST, "file exists") - EXDEV = new(linux.EXDEV, "cross-device link") - ENODEV = new(linux.ENODEV, "no such device") - ENOTDIR = new(linux.ENOTDIR, "not a directory") - EISDIR = new(linux.EISDIR, "is a directory") - EINVAL = new(linux.EINVAL, "invalid argument") - ENFILE = new(linux.ENFILE, "file table overflow") - EMFILE = new(linux.EMFILE, "too many open files") - ENOTTY = new(linux.ENOTTY, "not a typewriter") - ETXTBSY = new(linux.ETXTBSY, "text file busy") - EFBIG = new(linux.EFBIG, "file too large") - ENOSPC = new(linux.ENOSPC, "no space left on device") - ESPIPE = new(linux.ESPIPE, "illegal seek") - EROFS = new(linux.EROFS, "read-only file system") - EMLINK = new(linux.EMLINK, "too many links") - EPIPE = new(linux.EPIPE, "broken pipe") - EDOM = new(linux.EDOM, "math argument out of domain of func") - ERANGE = new(linux.ERANGE, "math result not representable") -) - -// Errno values from include/uapi/asm-generic/errno.h. -var ( - EDEADLK = new(linux.EDEADLK, "resource deadlock would occur") - ENAMETOOLONG = new(linux.ENAMETOOLONG, "file name too long") - ENOLCK = new(linux.ENOLCK, "no record locks available") - ENOSYS = new(linux.ENOSYS, "invalid system call number") - ENOTEMPTY = new(linux.ENOTEMPTY, "directory not empty") - ELOOP = new(linux.ELOOP, "too many symbolic links encountered") - EWOULDBLOCK = new(linux.EWOULDBLOCK, "operation would block") - ENOMSG = new(linux.ENOMSG, "no message of desired type") - EIDRM = new(linux.EIDRM, "identifier removed") - ECHRNG = new(linux.ECHRNG, "channel number out of range") - EL2NSYNC = new(linux.EL2NSYNC, "level 2 not synchronized") - EL3HLT = new(linux.EL3HLT, "level 3 halted") - EL3RST = new(linux.EL3RST, "level 3 reset") - ELNRNG = new(linux.ELNRNG, "link number out of range") - EUNATCH = new(linux.EUNATCH, "protocol driver not attached") - ENOCSI = new(linux.ENOCSI, "no CSI structure available") - EL2HLT = new(linux.EL2HLT, "level 2 halted") - EBADE = new(linux.EBADE, "invalid exchange") - EBADR = new(linux.EBADR, "invalid request descriptor") - EXFULL = new(linux.EXFULL, "exchange full") - ENOANO = new(linux.ENOANO, "no anode") - EBADRQC = new(linux.EBADRQC, "invalid request code") - EBADSLT = new(linux.EBADSLT, "invalid slot") - EDEADLOCK = new(linux.EDEADLOCK, EDEADLK.message) - EBFONT = new(linux.EBFONT, "bad font file format") - ENOSTR = new(linux.ENOSTR, "device not a stream") - ENODATA = new(linux.ENODATA, "no data available") - ETIME = new(linux.ETIME, "timer expired") - ENOSR = new(linux.ENOSR, "out of streams resources") - ENONET = new(linux.ENOENT, "machine is not on the network") - ENOPKG = new(linux.ENOPKG, "package not installed") - EREMOTE = new(linux.EREMOTE, "object is remote") - ENOLINK = new(linux.ENOLINK, "link has been severed") - EADV = new(linux.EADV, "advertise error") - ESRMNT = new(linux.ESRMNT, "srmount error") - ECOMM = new(linux.ECOMM, "communication error on send") - EPROTO = new(linux.EPROTO, "protocol error") - EMULTIHOP = new(linux.EMULTIHOP, "multihop attempted") - EDOTDOT = new(linux.EDOTDOT, "RFS specific error") - EBADMSG = new(linux.EBADMSG, "not a data message") - EOVERFLOW = new(linux.EOVERFLOW, "value too large for defined data type") - ENOTUNIQ = new(linux.ENOTUNIQ, "name not unique on network") - EBADFD = new(linux.EBADFD, "file descriptor in bad state") - EREMCHG = new(linux.EREMCHG, "remote address changed") - ELIBACC = new(linux.ELIBACC, "can not access a needed shared library") - ELIBBAD = new(linux.ELIBBAD, "accessing a corrupted shared library") - ELIBSCN = new(linux.ELIBSCN, ".lib section in a.out corrupted") - ELIBMAX = new(linux.ELIBMAX, "attempting to link in too many shared libraries") - ELIBEXEC = new(linux.ELIBEXEC, "cannot exec a shared library directly") - EILSEQ = new(linux.EILSEQ, "illegal byte sequence") - ERESTART = new(linux.ERESTART, "interrupted system call should be restarted") - ESTRPIPE = new(linux.ESTRPIPE, "streams pipe error") - EUSERS = new(linux.EUSERS, "too many users") - ENOTSOCK = new(linux.ENOTSOCK, "socket operation on non-socket") - EDESTADDRREQ = new(linux.EDESTADDRREQ, "destination address required") - EMSGSIZE = new(linux.EMSGSIZE, "message too long") - EPROTOTYPE = new(linux.EPROTOTYPE, "protocol wrong type for socket") - ENOPROTOOPT = new(linux.ENOPROTOOPT, "protocol not available") - EPROTONOSUPPORT = new(linux.EPROTONOSUPPORT, "protocol not supported") - ESOCKTNOSUPPORT = new(linux.ESOCKTNOSUPPORT, "socket type not supported") - EOPNOTSUPP = new(linux.EOPNOTSUPP, "operation not supported on transport endpoint") - EPFNOSUPPORT = new(linux.EPFNOSUPPORT, "protocol family not supported") - EAFNOSUPPORT = new(linux.EAFNOSUPPORT, "address family not supported by protocol") - EADDRINUSE = new(linux.EADDRINUSE, "address already in use") - EADDRNOTAVAIL = new(linux.EADDRNOTAVAIL, "cannot assign requested address") - ENETDOWN = new(linux.ENETDOWN, "network is down") - ENETUNREACH = new(linux.ENETUNREACH, "network is unreachable") - ENETRESET = new(linux.ENETRESET, "network dropped connection because of reset") - ECONNABORTED = new(linux.ECONNABORTED, "software caused connection abort") - ECONNRESET = new(linux.ECONNRESET, "connection reset by peer") - ENOBUFS = new(linux.ENOBUFS, "no buffer space available") - EISCONN = new(linux.EISCONN, "transport endpoint is already connected") - ENOTCONN = new(linux.ENOTCONN, "transport endpoint is not connected") - ESHUTDOWN = new(linux.ESHUTDOWN, "cannot send after transport endpoint shutdown") - ETOOMANYREFS = new(linux.ETOOMANYREFS, "too many references: cannot splice") - ETIMEDOUT = new(linux.ETIMEDOUT, "connection timed out") - ECONNREFUSED = new(linux.ECONNREFUSED, "connection refused") - EHOSTDOWN = new(linux.EHOSTDOWN, "host is down") - EHOSTUNREACH = new(linux.EHOSTUNREACH, "no route to host") - EALREADY = new(linux.EALREADY, "operation already in progress") - EINPROGRESS = new(linux.EINPROGRESS, "operation now in progress") - ESTALE = new(linux.ESTALE, "stale file handle") - EUCLEAN = new(linux.EUCLEAN, "structure needs cleaning") - ENOTNAM = new(linux.ENOTNAM, "not a XENIX named type file") - ENAVAIL = new(linux.ENAVAIL, "no XENIX semaphores available") - EISNAM = new(linux.EISNAM, "is a named type file") - EREMOTEIO = new(linux.EREMOTEIO, "remote I/O error") - EDQUOT = new(linux.EDQUOT, "quota exceeded") - ENOMEDIUM = new(linux.ENOMEDIUM, "no medium found") - EMEDIUMTYPE = new(linux.EMEDIUMTYPE, "wrong medium type") - ECANCELED = new(linux.ECANCELED, "operation Canceled") - ENOKEY = new(linux.ENOKEY, "required key not available") - EKEYEXPIRED = new(linux.EKEYEXPIRED, "key has expired") - EKEYREVOKED = new(linux.EKEYREVOKED, "key has been revoked") - EKEYREJECTED = new(linux.EKEYREJECTED, "key was rejected by service") - EOWNERDEAD = new(linux.EOWNERDEAD, "owner died") - ENOTRECOVERABLE = new(linux.ENOTRECOVERABLE, "state not recoverable") - ERFKILL = new(linux.ERFKILL, "operation not possible due to RF-kill") - EHWPOISON = new(linux.EHWPOISON, "memory page has hardware error") -) diff --git a/pkg/marshal/primitive/BUILD b/pkg/marshal/primitive/BUILD index 190b57c29..6e5ce136d 100644 --- a/pkg/marshal/primitive/BUILD +++ b/pkg/marshal/primitive/BUILD @@ -12,9 +12,7 @@ go_library( "//:sandbox", ], deps = [ - "//pkg/context", "//pkg/hostarch", "//pkg/marshal", - "//pkg/usermem", ], ) diff --git a/pkg/marshal/primitive/primitive.go b/pkg/marshal/primitive/primitive.go index 6f38992b7..1c49cf082 100644 --- a/pkg/marshal/primitive/primitive.go +++ b/pkg/marshal/primitive/primitive.go @@ -19,10 +19,8 @@ package primitive import ( "io" - "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" - "gvisor.dev/gvisor/pkg/usermem" ) // Int8 is a marshal.Marshallable implementation for int8. @@ -400,26 +398,3 @@ func CopyStringOut(cc marshal.CopyContext, addr hostarch.Addr, src string) (int, srcP := ByteSlice(src) return srcP.CopyOut(cc, addr) } - -// IOCopyContext wraps an object implementing hostarch.IO to implement -// marshal.CopyContext. -type IOCopyContext struct { - Ctx context.Context - IO usermem.IO - Opts usermem.IOOpts -} - -// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer. -func (i *IOCopyContext) CopyScratchBuffer(size int) []byte { - return make([]byte, size) -} - -// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. -func (i *IOCopyContext) CopyOutBytes(addr hostarch.Addr, b []byte) (int, error) { - return i.IO.CopyOut(i.Ctx, addr, b, i.Opts) -} - -// CopyInBytes implements marshal.CopyContext.CopyInBytes. -func (i *IOCopyContext) CopyInBytes(addr hostarch.Addr, b []byte) (int, error) { - return i.IO.CopyIn(i.Ctx, addr, b, i.Opts) -} diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go index fdeee3a5f..4829ae7ce 100644 --- a/pkg/metric/metric.go +++ b/pkg/metric/metric.go @@ -36,17 +36,22 @@ var ( // new metric after initialization. ErrInitializationDone = errors.New("metric cannot be created after initialization is complete") - // createdSentryMetrics indicates that the sentry metrics are created. - createdSentryMetrics = false - // WeirdnessMetric is a metric with fields created to track the number // of weird occurrences such as time fallback, partial_result, vsyscall // count, watchdog startup timeouts and stuck tasks. - WeirdnessMetric *Uint64Metric + WeirdnessMetric = MustCreateNewUint64Metric("/weirdness", true /* sync */, "Increment for weird occurrences of problems such as time fallback, partial result, vsyscalls invoked in the sandbox, watchdog startup timeouts and stuck tasks.", + Field{ + name: "weirdness_type", + allowedValues: []string{"time_fallback", "partial_result", "vsyscall_count", "watchdog_stuck_startup", "watchdog_stuck_tasks"}, + }) // SuspiciousOperationsMetric is a metric with fields created to detect // operations such as opening an executable file to write from a gofer. - SuspiciousOperationsMetric *Uint64Metric + SuspiciousOperationsMetric = MustCreateNewUint64Metric("/suspicious_operations", true /* sync */, "Increment for suspicious operations such as opening an executable file to write from a gofer.", + Field{ + name: "operation_type", + allowedValues: []string{"opened_write_execute_file"}, + }) ) // Uint64Metric encapsulates a uint64 that represents some kind of metric to be @@ -84,17 +89,21 @@ var ( // Precondition: // * All metrics are registered. // * Initialize/Disable has not been called. -func Initialize() { +func Initialize() error { if initialized { - panic("Initialize/Disable called more than once") + return errors.New("metric.Initialize called after metric.Initialize or metric.Disable") } - initialized = true m := pb.MetricRegistration{} for _, v := range allMetrics.m { m.Metrics = append(m.Metrics, v.metadata) } - eventchannel.Emit(&m) + if err := eventchannel.Emit(&m); err != nil { + return fmt.Errorf("unable to emit metric initialize event: %w", err) + } + + initialized = true + return nil } // Disable sends an empty metric registration event over the event channel, @@ -103,16 +112,18 @@ func Initialize() { // Precondition: // * All metrics are registered. // * Initialize/Disable has not been called. -func Disable() { +func Disable() error { if initialized { - panic("Initialize/Disable called more than once") + return errors.New("metric.Disable called after metric.Initialize or metric.Disable") } - initialized = true m := pb.MetricRegistration{} if err := eventchannel.Emit(&m); err != nil { - panic("unable to emit metric disable event: " + err.Error()) + return fmt.Errorf("unable to emit metric disable event: %w", err) } + + initialized = true + return nil } type customUint64Metric struct { @@ -165,8 +176,8 @@ func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.Met } // Metrics can exist without fields. - if len(fields) > 1 { - panic("Sentry metrics support at most one field") + if l := len(fields); l > 1 { + return fmt.Errorf("%d fields provided, must be <= 1", l) } for _, field := range fields { @@ -182,7 +193,7 @@ func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.Met // without fields and panics if it returns an error. func MustRegisterCustomUint64Metric(name string, cumulative, sync bool, description string, value func(...string) uint64, fields ...Field) { if err := RegisterCustomUint64Metric(name, cumulative, sync, pb.MetricMetadata_UNITS_NONE, description, value, fields...); err != nil { - panic(fmt.Sprintf("Unable to register metric %q: %v", name, err)) + panic(fmt.Sprintf("Unable to register metric %q: %s", name, err)) } } @@ -209,7 +220,7 @@ func NewUint64Metric(name string, sync bool, units pb.MetricMetadata_Units, desc func MustCreateNewUint64Metric(name string, sync bool, description string, fields ...Field) *Uint64Metric { m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NONE, description, fields...) if err != nil { - panic(fmt.Sprintf("Unable to create metric %q: %v", name, err)) + panic(fmt.Sprintf("Unable to create metric %q: %s", name, err)) } return m } @@ -219,7 +230,7 @@ func MustCreateNewUint64Metric(name string, sync bool, description string, field func MustCreateNewUint64NanosecondsMetric(name string, sync bool, description string) *Uint64Metric { m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NANOSECONDS, description) if err != nil { - panic(fmt.Sprintf("Unable to create metric %q: %v", name, err)) + panic(fmt.Sprintf("Unable to create metric %q: %s", name, err)) } return m } @@ -354,7 +365,7 @@ func EmitMetricUpdate() { m.Metrics = append(m.Metrics, &pb.MetricValue{ Name: k, - Value: &pb.MetricValue_Uint64Value{t}, + Value: &pb.MetricValue_Uint64Value{Uint64Value: t}, }) case map[string]uint64: for fieldValue, metricValue := range t { @@ -369,7 +380,7 @@ func EmitMetricUpdate() { m.Metrics = append(m.Metrics, &pb.MetricValue{ Name: k, FieldValues: []string{fieldValue}, - Value: &pb.MetricValue_Uint64Value{metricValue}, + Value: &pb.MetricValue_Uint64Value{Uint64Value: metricValue}, }) } } @@ -390,26 +401,7 @@ func EmitMetricUpdate() { } } - eventchannel.Emit(&m) -} - -// CreateSentryMetrics creates the sentry metrics during kernel initialization. -func CreateSentryMetrics() { - if createdSentryMetrics { - return + if err := eventchannel.Emit(&m); err != nil { + log.Warningf("Unable to emit metrics: %s", err) } - - createdSentryMetrics = true - - WeirdnessMetric = MustCreateNewUint64Metric("/weirdness", true /* sync */, "Increment for weird occurrences of problems such as time fallback, partial result, vsyscalls invoked in the sandbox, watchdog startup timeouts and stuck tasks.", - Field{ - name: "weirdness_type", - allowedValues: []string{"time_fallback", "partial_result", "vsyscall_count", "watchdog_stuck_startup", "watchdog_stuck_tasks"}, - }) - - SuspiciousOperationsMetric = MustCreateNewUint64Metric("/suspicious_operations", true /* sync */, "Increment for suspicious operations such as opening an executable file to write from a gofer.", - Field{ - name: "operation_type", - allowedValues: []string{"opened_write_execute_file"}, - }) } diff --git a/pkg/metric/metric_test.go b/pkg/metric/metric_test.go index c71dfd460..1b4a9e73a 100644 --- a/pkg/metric/metric_test.go +++ b/pkg/metric/metric_test.go @@ -48,6 +48,8 @@ func (s *sliceEmitter) Reset() { var emitter sliceEmitter func init() { + reset() + eventchannel.AddEmitter(&emitter) } @@ -77,7 +79,9 @@ func TestInitialize(t *testing.T) { t.Fatalf("NewUint64Metric got err %v want nil", err) } - Initialize() + if err := Initialize(); err != nil { + t.Fatalf("Initialize(): %s", err) + } if len(emitter) != 1 { t.Fatalf("Initialize emitted %d events want 1", len(emitter)) @@ -149,7 +153,9 @@ func TestDisable(t *testing.T) { t.Fatalf("NewUint64Metric got err %v want nil", err) } - Disable() + if err := Disable(); err != nil { + t.Fatalf("Disable(): %s", err) + } if len(emitter) != 1 { t.Fatalf("Initialize emitted %d events want 1", len(emitter)) @@ -178,7 +184,9 @@ func TestEmitMetricUpdate(t *testing.T) { t.Fatalf("NewUint64Metric got err %v want nil", err) } - Initialize() + if err := Initialize(); err != nil { + t.Fatalf("Initialize(): %s", err) + } // Don't care about the registration metrics. emitter.Reset() @@ -270,7 +278,9 @@ func TestEmitMetricUpdateWithFields(t *testing.T) { t.Fatalf("NewUint64Metric got err %v want nil", err) } - Initialize() + if err := Initialize(); err != nil { + t.Fatalf("Initialize(): %s", err) + } // Don't care about the registration metrics. emitter.Reset() diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD index b2291ef97..2b22b2203 100644 --- a/pkg/p9/BUILD +++ b/pkg/p9/BUILD @@ -22,6 +22,9 @@ go_library( "version.go", ], deps = [ + "//pkg/abi/linux/errno", + "//pkg/errors", + "//pkg/errors/linuxerr", "//pkg/fd", "//pkg/fdchannel", "//pkg/flipcall", diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go index 758e11b13..161b451cc 100644 --- a/pkg/p9/handlers.go +++ b/pkg/p9/handlers.go @@ -23,6 +23,9 @@ import ( "sync/atomic" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux/errno" + "gvisor.dev/gvisor/pkg/errors" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" ) @@ -62,6 +65,45 @@ func newErr(err error) *Rlerror { return &Rlerror{Error: uint32(ExtractErrno(err))} } +// ExtractLinuxerrErrno extracts a *errors.Error from a error, best effort. +// TODO(b/34162363): Merge this with ExtractErrno. +func ExtractLinuxerrErrno(err error) *errors.Error { + switch err { + case os.ErrNotExist: + return linuxerr.ENOENT + case os.ErrExist: + return linuxerr.EEXIST + case os.ErrPermission: + return linuxerr.EACCES + case os.ErrInvalid: + return linuxerr.EINVAL + } + + // Attempt to unwrap. + switch e := err.(type) { + case *errors.Error: + return e + case unix.Errno: + return linuxerr.ErrorFromErrno(errno.Errno(e)) + case *os.PathError: + return ExtractLinuxerrErrno(e.Err) + case *os.SyscallError: + return ExtractLinuxerrErrno(e.Err) + case *os.LinkError: + return ExtractLinuxerrErrno(e.Err) + } + + // Default case. + log.Warningf("unknown error: %v", err) + return linuxerr.EIO +} + +// newErrFromLinuxerr returns an Rlerror from the linuxerr list. +// TODO(b/34162363): Merge this with newErr. +func newErrFromLinuxerr(err error) *Rlerror { + return &Rlerror{Error: uint32(ExtractLinuxerrErrno(err).Errno())} +} + // handler is implemented for server-handled messages. // // See server.go for call information. diff --git a/pkg/p9/server.go b/pkg/p9/server.go index ff1172ed6..241ab44ef 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -19,7 +19,8 @@ import ( "runtime/debug" "sync/atomic" - "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux/errno" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/fdchannel" "gvisor.dev/gvisor/pkg/flipcall" @@ -483,7 +484,7 @@ func (cs *connState) lookupChannel(id uint32) *channel { func (cs *connState) handle(m message) (r message) { if !cs.reqGate.Enter() { // connState.stop() has been called; the connection is shutting down. - r = newErr(unix.ECONNRESET) + r = newErrFromLinuxerr(linuxerr.ECONNRESET) return } defer func() { @@ -498,15 +499,23 @@ func (cs *connState) handle(m message) (r message) { // Wrap in an EFAULT error; we don't really have a // better way to describe this kind of error. It will // usually manifest as a result of the test framework. - r = newErr(unix.EFAULT) + r = newErrFromLinuxerr(linuxerr.EFAULT) } }() if handler, ok := m.(handler); ok { // Call the message handler. r = handler.handle(cs) + // TODO(b/34162363):This is only here to make sure the server works with + // only linuxerr Errors, as the handlers work with both client and server. + // It will be removed a followup, when all the unix.Errno errors are + // replaced with linuxerr. + if rlError, ok := r.(*Rlerror); ok { + e := linuxerr.ErrorFromErrno(errno.Errno(rlError.Error)) + r = newErrFromLinuxerr(e) + } } else { // Produce an ENOSYS error. - r = newErr(unix.ENOSYS) + r = newErrFromLinuxerr(linuxerr.ENOSYS) } return } @@ -553,7 +562,7 @@ func (cs *connState) handleRequest() bool { // If it's not a connection error, but some other protocol error, // we can send a response immediately. cs.sendMu.Lock() - err := send(cs.conn, tag, newErr(err)) + err := send(cs.conn, tag, newErrFromLinuxerr(err)) cs.sendMu.Unlock() if err != nil { log.Debugf("p9.send: %v", err) diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go index 4aecb8007..1bbcae045 100644 --- a/pkg/refs/refcounter.go +++ b/pkg/refs/refcounter.go @@ -261,8 +261,8 @@ func (l *LeakMode) Get() interface{} { } // String implements flag.Value. -func (l *LeakMode) String() string { - switch *l { +func (l LeakMode) String() string { + switch l { case UninitializedLeakChecking: return "uninitialized" case NoLeakChecking: @@ -272,7 +272,7 @@ func (l *LeakMode) String() string { case LeaksLogTraces: return "log-traces" } - panic(fmt.Sprintf("invalid ref leak mode %d", *l)) + panic(fmt.Sprintf("invalid ref leak mode %d", l)) } // leakMode stores the current mode for the reference leak checker. diff --git a/pkg/ring0/kernel_amd64.go b/pkg/ring0/kernel_amd64.go index 41dfd0bf9..f63af8b76 100644 --- a/pkg/ring0/kernel_amd64.go +++ b/pkg/ring0/kernel_amd64.go @@ -254,6 +254,8 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { return } +var sentryXCR0 = xgetbv(0) + // start is the CPU entrypoint. // // This is called from the Start asm stub (see entry_amd64.go); on return the @@ -265,24 +267,10 @@ func start(c *CPU) { WriteGS(kernelAddr(c.kernelEntry)) WriteFS(uintptr(c.registers.Fs_base)) - // Initialize floating point. - // - // Note that on skylake, the valid XCR0 mask reported seems to be 0xff. - // This breaks down as: - // - // bit0 - x87 - // bit1 - SSE - // bit2 - AVX - // bit3-4 - MPX - // bit5-7 - AVX512 - // - // For some reason, enabled MPX & AVX512 on platforms that report them - // seems to be cause a general protection fault. (Maybe there are some - // virtualization issues and these aren't exported to the guest cpuid.) - // This needs further investigation, but we can limit the floating - // point operations to x87, SSE & AVX for now. fninit() - xsetbv(0, validXCR0Mask&0x7) + // Need to sync XCR0 with the host, because xsave and xrstor can be + // called from different contexts. + xsetbv(0, sentryXCR0) // Set the syscall target. wrmsr(_MSR_LSTAR, kernelFunc(sysenter)) diff --git a/pkg/safecopy/BUILD b/pkg/safecopy/BUILD index b77c40279..db5787302 100644 --- a/pkg/safecopy/BUILD +++ b/pkg/safecopy/BUILD @@ -18,6 +18,7 @@ go_library( ], visibility = ["//:sandbox"], deps = [ + "//pkg/abi/linux", "//pkg/syserror", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go index efbc2ddc1..2365b2c0d 100644 --- a/pkg/safecopy/safecopy_unsafe.go +++ b/pkg/safecopy/safecopy_unsafe.go @@ -20,6 +20,7 @@ import ( "unsafe" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" ) // maxRegisterSize is the maximum register size used in memcpy and memclr. It @@ -342,12 +343,7 @@ func errorFromFaultSignal(addr uintptr, sig int32) error { // handler however, and if this is function is being used externally then the // same courtesy is expected. func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) error { - var sa struct { - handler uintptr - flags uint64 - restorer uintptr - mask uint64 - } + var sa linux.SigAction const maskLen = 8 // Get the existing signal handler information, and save the current @@ -358,14 +354,14 @@ func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) e } // Fail if there isn't a previous handler. - if sa.handler == 0 { + if sa.Handler == 0 { return fmt.Errorf("previous handler for signal %x isn't set", sig) } - *previous = sa.handler + *previous = uintptr(sa.Handler) // Install our own handler. - sa.handler = handler + sa.Handler = uint64(handler) if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 { return e } diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go index daea51c4d..8ffa1db37 100644 --- a/pkg/seccomp/seccomp.go +++ b/pkg/seccomp/seccomp.go @@ -36,14 +36,10 @@ const ( // Install generates BPF code based on the set of syscalls provided. It only // allows syscalls that conform to the specification. Syscalls that violate the -// specification will trigger RET_KILL_PROCESS, except for the cases below. -// -// RET_TRAP is used in violations, instead of RET_KILL_PROCESS, in the -// following cases: -// 1. Kernel doesn't support RET_KILL_PROCESS: RET_KILL_THREAD only kills the -// offending thread and often keeps the sentry hanging. -// 2. Debug: RET_TRAP generates a panic followed by a stack trace which is -// much easier to debug then RET_KILL_PROCESS which can't be caught. +// specification will trigger RET_KILL_PROCESS. If RET_KILL_PROCESS is not +// supported, violations will trigger RET_TRAP instead. RET_KILL_THREAD is not +// used because it only kills the offending thread and often keeps the sentry +// hanging. // // Be aware that RET_TRAP sends SIGSYS to the process and it may be ignored, // making it possible for the process to continue running after a violation. diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index c9c52530d..61dacd2fb 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -14,12 +14,8 @@ go_library( "arch_x86.go", "arch_x86_impl.go", "auxv.go", - "signal.go", - "signal_act.go", "signal_amd64.go", "signal_arm64.go", - "signal_info.go", - "signal_stack.go", "stack.go", "stack_unsafe.go", "syscalls_amd64.go", diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go index 290863ee6..c9393b091 100644 --- a/pkg/sentry/arch/arch.go +++ b/pkg/sentry/arch/arch.go @@ -134,21 +134,13 @@ type Context interface { // RegisterMap returns a map of all registers. RegisterMap() (map[string]uintptr, error) - // NewSignalAct returns a new object that is equivalent to struct sigaction - // in the guest architecture. - NewSignalAct() NativeSignalAct - - // NewSignalStack returns a new object that is equivalent to stack_t in the - // guest architecture. - NewSignalStack() NativeSignalStack - // SignalSetup modifies the context in preparation for handling the // given signal. // // st is the stack where the signal handler frame should be // constructed. // - // act is the SignalAct that specifies how this signal is being + // act is the SigAction that specifies how this signal is being // handled. // // info is the SignalInfo of the signal being delivered. @@ -157,7 +149,7 @@ type Context interface { // stack is not going to be used). // // sigset is the signal mask before entering the signal handler. - SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error + SignalSetup(st *Stack, act *linux.SigAction, info *linux.SignalInfo, alt *linux.SignalStack, sigset linux.SignalSet) error // SignalRestore restores context after returning from a signal // handler. @@ -167,7 +159,7 @@ type Context interface { // rt is true if SignalRestore is being entered from rt_sigreturn and // false if SignalRestore is being entered from sigreturn. // SignalRestore returns the thread's new signal mask. - SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) + SignalRestore(st *Stack, rt bool) (linux.SignalSet, linux.SignalStack, error) // CPUIDEmulate emulates a CPUID instruction according to current register state. CPUIDEmulate(l log.Logger) diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go deleted file mode 100644 index 67d7edf68..000000000 --- a/pkg/sentry/arch/signal.go +++ /dev/null @@ -1,276 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package arch - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/hostarch" -) - -// SignalAct represents the action that should be taken when a signal is -// delivered, and is equivalent to struct sigaction. -// -// +marshal -// +stateify savable -type SignalAct struct { - Handler uint64 - Flags uint64 - Restorer uint64 // Only used on amd64. - Mask linux.SignalSet -} - -// SerializeFrom implements NativeSignalAct.SerializeFrom. -func (s *SignalAct) SerializeFrom(other *SignalAct) { - *s = *other -} - -// DeserializeTo implements NativeSignalAct.DeserializeTo. -func (s *SignalAct) DeserializeTo(other *SignalAct) { - *other = *s -} - -// SignalStack represents information about a user stack, and is equivalent to -// stack_t. -// -// +marshal -// +stateify savable -type SignalStack struct { - Addr uint64 - Flags uint32 - _ uint32 - Size uint64 -} - -// SerializeFrom implements NativeSignalStack.SerializeFrom. -func (s *SignalStack) SerializeFrom(other *SignalStack) { - *s = *other -} - -// DeserializeTo implements NativeSignalStack.DeserializeTo. -func (s *SignalStack) DeserializeTo(other *SignalStack) { - *other = *s -} - -// SignalInfo represents information about a signal being delivered, and is -// equivalent to struct siginfo in linux kernel(linux/include/uapi/asm-generic/siginfo.h). -// -// +marshal -// +stateify savable -type SignalInfo struct { - Signo int32 // Signal number - Errno int32 // Errno value - Code int32 // Signal code - _ uint32 - - // struct siginfo::_sifields is a union. In SignalInfo, fields in the union - // are accessed through methods. - // - // For reference, here is the definition of _sifields: (_sigfault._trapno, - // which does not exist on x86, omitted for clarity) - // - // union { - // int _pad[SI_PAD_SIZE]; - // - // /* kill() */ - // struct { - // __kernel_pid_t _pid; /* sender's pid */ - // __ARCH_SI_UID_T _uid; /* sender's uid */ - // } _kill; - // - // /* POSIX.1b timers */ - // struct { - // __kernel_timer_t _tid; /* timer id */ - // int _overrun; /* overrun count */ - // char _pad[sizeof( __ARCH_SI_UID_T) - sizeof(int)]; - // sigval_t _sigval; /* same as below */ - // int _sys_private; /* not to be passed to user */ - // } _timer; - // - // /* POSIX.1b signals */ - // struct { - // __kernel_pid_t _pid; /* sender's pid */ - // __ARCH_SI_UID_T _uid; /* sender's uid */ - // sigval_t _sigval; - // } _rt; - // - // /* SIGCHLD */ - // struct { - // __kernel_pid_t _pid; /* which child */ - // __ARCH_SI_UID_T _uid; /* sender's uid */ - // int _status; /* exit code */ - // __ARCH_SI_CLOCK_T _utime; - // __ARCH_SI_CLOCK_T _stime; - // } _sigchld; - // - // /* SIGILL, SIGFPE, SIGSEGV, SIGBUS */ - // struct { - // void *_addr; /* faulting insn/memory ref. */ - // short _addr_lsb; /* LSB of the reported address */ - // } _sigfault; - // - // /* SIGPOLL */ - // struct { - // __ARCH_SI_BAND_T _band; /* POLL_IN, POLL_OUT, POLL_MSG */ - // int _fd; - // } _sigpoll; - // - // /* SIGSYS */ - // struct { - // void *_call_addr; /* calling user insn */ - // int _syscall; /* triggering system call number */ - // unsigned int _arch; /* AUDIT_ARCH_* of syscall */ - // } _sigsys; - // } _sifields; - // - // _sifields is padded so that the size of siginfo is SI_MAX_SIZE = 128 - // bytes. - Fields [128 - 16]byte -} - -// FixSignalCodeForUser fixes up si_code. -// -// The si_code we get from Linux may contain the kernel-specific code in the -// top 16 bits if it's positive (e.g., from ptrace). Linux's -// copy_siginfo_to_user does -// err |= __put_user((short)from->si_code, &to->si_code); -// to mask out those bits and we need to do the same. -func (s *SignalInfo) FixSignalCodeForUser() { - if s.Code > 0 { - s.Code &= 0x0000ffff - } -} - -// PID returns the si_pid field. -func (s *SignalInfo) PID() int32 { - return int32(hostarch.ByteOrder.Uint32(s.Fields[0:4])) -} - -// SetPID mutates the si_pid field. -func (s *SignalInfo) SetPID(val int32) { - hostarch.ByteOrder.PutUint32(s.Fields[0:4], uint32(val)) -} - -// UID returns the si_uid field. -func (s *SignalInfo) UID() int32 { - return int32(hostarch.ByteOrder.Uint32(s.Fields[4:8])) -} - -// SetUID mutates the si_uid field. -func (s *SignalInfo) SetUID(val int32) { - hostarch.ByteOrder.PutUint32(s.Fields[4:8], uint32(val)) -} - -// Sigval returns the sigval field, which is aliased to both si_int and si_ptr. -func (s *SignalInfo) Sigval() uint64 { - return hostarch.ByteOrder.Uint64(s.Fields[8:16]) -} - -// SetSigval mutates the sigval field. -func (s *SignalInfo) SetSigval(val uint64) { - hostarch.ByteOrder.PutUint64(s.Fields[8:16], val) -} - -// TimerID returns the si_timerid field. -func (s *SignalInfo) TimerID() linux.TimerID { - return linux.TimerID(hostarch.ByteOrder.Uint32(s.Fields[0:4])) -} - -// SetTimerID sets the si_timerid field. -func (s *SignalInfo) SetTimerID(val linux.TimerID) { - hostarch.ByteOrder.PutUint32(s.Fields[0:4], uint32(val)) -} - -// Overrun returns the si_overrun field. -func (s *SignalInfo) Overrun() int32 { - return int32(hostarch.ByteOrder.Uint32(s.Fields[4:8])) -} - -// SetOverrun sets the si_overrun field. -func (s *SignalInfo) SetOverrun(val int32) { - hostarch.ByteOrder.PutUint32(s.Fields[4:8], uint32(val)) -} - -// Addr returns the si_addr field. -func (s *SignalInfo) Addr() uint64 { - return hostarch.ByteOrder.Uint64(s.Fields[0:8]) -} - -// SetAddr sets the si_addr field. -func (s *SignalInfo) SetAddr(val uint64) { - hostarch.ByteOrder.PutUint64(s.Fields[0:8], val) -} - -// Status returns the si_status field. -func (s *SignalInfo) Status() int32 { - return int32(hostarch.ByteOrder.Uint32(s.Fields[8:12])) -} - -// SetStatus mutates the si_status field. -func (s *SignalInfo) SetStatus(val int32) { - hostarch.ByteOrder.PutUint32(s.Fields[8:12], uint32(val)) -} - -// CallAddr returns the si_call_addr field. -func (s *SignalInfo) CallAddr() uint64 { - return hostarch.ByteOrder.Uint64(s.Fields[0:8]) -} - -// SetCallAddr mutates the si_call_addr field. -func (s *SignalInfo) SetCallAddr(val uint64) { - hostarch.ByteOrder.PutUint64(s.Fields[0:8], val) -} - -// Syscall returns the si_syscall field. -func (s *SignalInfo) Syscall() int32 { - return int32(hostarch.ByteOrder.Uint32(s.Fields[8:12])) -} - -// SetSyscall mutates the si_syscall field. -func (s *SignalInfo) SetSyscall(val int32) { - hostarch.ByteOrder.PutUint32(s.Fields[8:12], uint32(val)) -} - -// Arch returns the si_arch field. -func (s *SignalInfo) Arch() uint32 { - return hostarch.ByteOrder.Uint32(s.Fields[12:16]) -} - -// SetArch mutates the si_arch field. -func (s *SignalInfo) SetArch(val uint32) { - hostarch.ByteOrder.PutUint32(s.Fields[12:16], val) -} - -// Band returns the si_band field. -func (s *SignalInfo) Band() int64 { - return int64(hostarch.ByteOrder.Uint64(s.Fields[0:8])) -} - -// SetBand mutates the si_band field. -func (s *SignalInfo) SetBand(val int64) { - // Note: this assumes the platform uses `long` as `__ARCH_SI_BAND_T`. - // On some platforms, which gVisor doesn't support, `__ARCH_SI_BAND_T` is - // `int`. See siginfo.h. - hostarch.ByteOrder.PutUint64(s.Fields[0:8], uint64(val)) -} - -// FD returns the si_fd field. -func (s *SignalInfo) FD() uint32 { - return hostarch.ByteOrder.Uint32(s.Fields[8:12]) -} - -// SetFD mutates the si_fd field. -func (s *SignalInfo) SetFD(val uint32) { - hostarch.ByteOrder.PutUint32(s.Fields[8:12], val) -} diff --git a/pkg/sentry/arch/signal_act.go b/pkg/sentry/arch/signal_act.go deleted file mode 100644 index d3e2324a8..000000000 --- a/pkg/sentry/arch/signal_act.go +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package arch - -import "gvisor.dev/gvisor/pkg/marshal" - -// Special values for SignalAct.Handler. -const ( - // SignalActDefault is SIG_DFL and specifies that the default behavior for - // a signal should be taken. - SignalActDefault = 0 - - // SignalActIgnore is SIG_IGN and specifies that a signal should be - // ignored. - SignalActIgnore = 1 -) - -// Available signal flags. -const ( - SignalFlagNoCldStop = 0x00000001 - SignalFlagNoCldWait = 0x00000002 - SignalFlagSigInfo = 0x00000004 - SignalFlagRestorer = 0x04000000 - SignalFlagOnStack = 0x08000000 - SignalFlagRestart = 0x10000000 - SignalFlagInterrupt = 0x20000000 - SignalFlagNoDefer = 0x40000000 - SignalFlagResetHandler = 0x80000000 -) - -// IsSigInfo returns true iff this handle expects siginfo. -func (s SignalAct) IsSigInfo() bool { - return s.Flags&SignalFlagSigInfo != 0 -} - -// IsNoDefer returns true iff this SignalAct has the NoDefer flag set. -func (s SignalAct) IsNoDefer() bool { - return s.Flags&SignalFlagNoDefer != 0 -} - -// IsRestart returns true iff this SignalAct has the Restart flag set. -func (s SignalAct) IsRestart() bool { - return s.Flags&SignalFlagRestart != 0 -} - -// IsResetHandler returns true iff this SignalAct has the ResetHandler flag set. -func (s SignalAct) IsResetHandler() bool { - return s.Flags&SignalFlagResetHandler != 0 -} - -// IsOnStack returns true iff this SignalAct has the OnStack flag set. -func (s SignalAct) IsOnStack() bool { - return s.Flags&SignalFlagOnStack != 0 -} - -// HasRestorer returns true iff this SignalAct has the Restorer flag set. -func (s SignalAct) HasRestorer() bool { - return s.Flags&SignalFlagRestorer != 0 -} - -// NativeSignalAct is a type that is equivalent to struct sigaction in the -// guest architecture. -type NativeSignalAct interface { - marshal.Marshallable - - // SerializeFrom copies the data in the host SignalAct s into this object. - SerializeFrom(s *SignalAct) - - // DeserializeTo copies the data in this object into the host SignalAct s. - DeserializeTo(s *SignalAct) -} diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go index 082ed92b1..58e28dbba 100644 --- a/pkg/sentry/arch/signal_amd64.go +++ b/pkg/sentry/arch/signal_amd64.go @@ -76,21 +76,11 @@ const ( type UContext64 struct { Flags uint64 Link uint64 - Stack SignalStack + Stack linux.SignalStack MContext SignalContext64 Sigset linux.SignalSet } -// NewSignalAct implements Context.NewSignalAct. -func (c *context64) NewSignalAct() NativeSignalAct { - return &SignalAct{} -} - -// NewSignalStack implements Context.NewSignalStack. -func (c *context64) NewSignalStack() NativeSignalStack { - return &SignalStack{} -} - // From Linux 'arch/x86/include/uapi/asm/sigcontext.h' the following is the // size of the magic cookie at the end of the xsave frame. // @@ -110,7 +100,7 @@ func (c *context64) fpuFrameSize() (size int, useXsave bool) { // SignalSetup implements Context.SignalSetup. (Compare to Linux's // arch/x86/kernel/signal.c:__setup_rt_frame().) -func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error { +func (c *context64) SignalSetup(st *Stack, act *linux.SigAction, info *linux.SignalInfo, alt *linux.SignalStack, sigset linux.SignalSet) error { sp := st.Bottom // "The 128-byte area beyond the location pointed to by %rsp is considered @@ -187,7 +177,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt // Prior to proceeding, figure out if the frame will exhaust the range // for the signal stack. This is not allowed, and should immediately // force signal delivery (reverting to the default handler). - if act.IsOnStack() && alt.IsEnabled() && !alt.Contains(frameBottom) { + if act.Flags&linux.SA_ONSTACK != 0 && alt.IsEnabled() && !alt.Contains(frameBottom) { return unix.EFAULT } @@ -203,7 +193,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt return err } ucAddr := st.Bottom - if act.HasRestorer() { + if act.Flags&linux.SA_RESTORER != 0 { // Push the restorer return address. // Note that this doesn't need to be popped. if _, err := primitive.CopyUint64Out(st, StackBottomMagic, act.Restorer); err != nil { @@ -237,15 +227,15 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt // SignalRestore implements Context.SignalRestore. (Compare to Linux's // arch/x86/kernel/signal.c:sys_rt_sigreturn().) -func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) { +func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, linux.SignalStack, error) { // Copy out the stack frame. var uc UContext64 if _, err := uc.CopyIn(st, StackBottomMagic); err != nil { - return 0, SignalStack{}, err + return 0, linux.SignalStack{}, err } - var info SignalInfo + var info linux.SignalInfo if _, err := info.CopyIn(st, StackBottomMagic); err != nil { - return 0, SignalStack{}, err + return 0, linux.SignalStack{}, err } // Restore registers. diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go index da71fb873..80df90076 100644 --- a/pkg/sentry/arch/signal_arm64.go +++ b/pkg/sentry/arch/signal_arm64.go @@ -61,7 +61,7 @@ type FpsimdContext struct { type UContext64 struct { Flags uint64 Link uint64 - Stack SignalStack + Stack linux.SignalStack Sigset linux.SignalSet // glibc uses a 1024-bit sigset_t _pad [120]byte // (1024 - 64) / 8 = 120 @@ -71,18 +71,8 @@ type UContext64 struct { MContext SignalContext64 } -// NewSignalAct implements Context.NewSignalAct. -func (c *context64) NewSignalAct() NativeSignalAct { - return &SignalAct{} -} - -// NewSignalStack implements Context.NewSignalStack. -func (c *context64) NewSignalStack() NativeSignalStack { - return &SignalStack{} -} - // SignalSetup implements Context.SignalSetup. -func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error { +func (c *context64) SignalSetup(st *Stack, act *linux.SigAction, info *linux.SignalInfo, alt *linux.SignalStack, sigset linux.SignalSet) error { sp := st.Bottom // Construct the UContext64 now since we need its size. @@ -114,7 +104,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt // Prior to proceeding, figure out if the frame will exhaust the range // for the signal stack. This is not allowed, and should immediately // force signal delivery (reverting to the default handler). - if act.IsOnStack() && alt.IsEnabled() && !alt.Contains(frameBottom) { + if act.Flags&linux.SA_ONSTACK != 0 && alt.IsEnabled() && !alt.Contains(frameBottom) { return unix.EFAULT } @@ -137,7 +127,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt c.Regs.Regs[0] = uint64(info.Signo) c.Regs.Regs[1] = uint64(infoAddr) c.Regs.Regs[2] = uint64(ucAddr) - c.Regs.Regs[30] = uint64(act.Restorer) + c.Regs.Regs[30] = act.Restorer // Save the thread's floating point state. c.sigFPState = append(c.sigFPState, c.fpState) @@ -147,15 +137,15 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt } // SignalRestore implements Context.SignalRestore. -func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) { +func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, linux.SignalStack, error) { // Copy out the stack frame. var uc UContext64 if _, err := uc.CopyIn(st, StackBottomMagic); err != nil { - return 0, SignalStack{}, err + return 0, linux.SignalStack{}, err } - var info SignalInfo + var info linux.SignalInfo if _, err := info.CopyIn(st, StackBottomMagic); err != nil { - return 0, SignalStack{}, err + return 0, linux.SignalStack{}, err } // Restore registers. diff --git a/pkg/sentry/arch/signal_info.go b/pkg/sentry/arch/signal_info.go deleted file mode 100644 index f93ee8b46..000000000 --- a/pkg/sentry/arch/signal_info.go +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package arch - -// Possible values for SignalInfo.Code. These values originate from the Linux -// kernel's include/uapi/asm-generic/siginfo.h. -const ( - // SignalInfoUser (properly SI_USER) indicates that a signal was sent from - // a kill() or raise() syscall. - SignalInfoUser = 0 - - // SignalInfoKernel (properly SI_KERNEL) indicates that the signal was sent - // by the kernel. - SignalInfoKernel = 0x80 - - // SignalInfoTimer (properly SI_TIMER) indicates that the signal was sent - // by an expired timer. - SignalInfoTimer = -2 - - // SignalInfoTkill (properly SI_TKILL) indicates that the signal was sent - // from a tkill() or tgkill() syscall. - SignalInfoTkill = -6 - - // CLD_* codes are only meaningful for SIGCHLD. - - // CLD_EXITED indicates that a task exited. - CLD_EXITED = 1 - - // CLD_KILLED indicates that a task was killed by a signal. - CLD_KILLED = 2 - - // CLD_DUMPED indicates that a task was killed by a signal and then dumped - // core. - CLD_DUMPED = 3 - - // CLD_TRAPPED indicates that a task was stopped by ptrace. - CLD_TRAPPED = 4 - - // CLD_STOPPED indicates that a thread group completed a group stop. - CLD_STOPPED = 5 - - // CLD_CONTINUED indicates that a group-stopped thread group was continued. - CLD_CONTINUED = 6 - - // SYS_* codes are only meaningful for SIGSYS. - - // SYS_SECCOMP indicates that a signal originates from seccomp. - SYS_SECCOMP = 1 - - // TRAP_* codes are only meaningful for SIGTRAP. - - // TRAP_BRKPT indicates a breakpoint trap. - TRAP_BRKPT = 1 -) diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go deleted file mode 100644 index c732c7503..000000000 --- a/pkg/sentry/arch/signal_stack.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build 386 amd64 arm64 - -package arch - -import ( - "gvisor.dev/gvisor/pkg/hostarch" - "gvisor.dev/gvisor/pkg/marshal" -) - -const ( - // SignalStackFlagOnStack is possible set on return from getaltstack, - // in order to indicate that the thread is currently on the alt stack. - SignalStackFlagOnStack = 1 - - // SignalStackFlagDisable is a flag to indicate the stack is disabled. - SignalStackFlagDisable = 2 -) - -// IsEnabled returns true iff this signal stack is marked as enabled. -func (s SignalStack) IsEnabled() bool { - return s.Flags&SignalStackFlagDisable == 0 -} - -// Top returns the stack's top address. -func (s SignalStack) Top() hostarch.Addr { - return hostarch.Addr(s.Addr + s.Size) -} - -// SetOnStack marks this signal stack as in use. -// -// Note that there is no corresponding ClearOnStack, and that this should only -// be called on copies that are serialized to userspace. -func (s *SignalStack) SetOnStack() { - s.Flags |= SignalStackFlagOnStack -} - -// Contains checks if the stack pointer is within this stack. -func (s *SignalStack) Contains(sp hostarch.Addr) bool { - return hostarch.Addr(s.Addr) < sp && sp <= hostarch.Addr(s.Addr+s.Size) -} - -// NativeSignalStack is a type that is equivalent to stack_t in the guest -// architecture. -type NativeSignalStack interface { - marshal.Marshallable - - // SerializeFrom copies the data in the host SignalStack s into this - // object. - SerializeFrom(s *SignalStack) - - // DeserializeTo copies the data in this object into the host SignalStack - // s. - DeserializeTo(s *SignalStack) -} diff --git a/pkg/sentry/arch/stack.go b/pkg/sentry/arch/stack.go index 65a794c7c..85e3515af 100644 --- a/pkg/sentry/arch/stack.go +++ b/pkg/sentry/arch/stack.go @@ -45,7 +45,7 @@ type Stack struct { } // scratchBufLen is the default length of Stack.scratchBuf. The -// largest structs the stack regularly serializes are arch.SignalInfo +// largest structs the stack regularly serializes are linux.SignalInfo // and arch.UContext64. We'll set the default size as the larger of // the two, arch.UContext64. var scratchBufLen = (*UContext64)(nil).SizeBytes() diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go index 367849e75..221e98a01 100644 --- a/pkg/sentry/control/proc.go +++ b/pkg/sentry/control/proc.go @@ -99,6 +99,9 @@ type ExecArgs struct { // PIDNamespace is the pid namespace for the process being executed. PIDNamespace *kernel.PIDNamespace + + // Limits is the limit set for the process being executed. + Limits *limits.LimitSet } // String prints the arguments as a string. @@ -151,6 +154,10 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI if pidns == nil { pidns = proc.Kernel.RootPIDNamespace() } + limitSet := args.Limits + if limitSet == nil { + limitSet = limits.NewLimitSet() + } initArgs := kernel.CreateProcessArgs{ Filename: args.Filename, Argv: args.Argv, @@ -161,7 +168,7 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI Credentials: creds, FDTable: fdTable, Umask: 0022, - Limits: limits.NewLimitSet(), + Limits: limitSet, MaxSymlinkTraversals: linux.MaxSymlinkTraversals, UTSNamespace: proc.Kernel.RootUTSNamespace(), IPCNamespace: proc.Kernel.RootIPCNamespace(), diff --git a/pkg/sentry/fs/gofer/cache_policy.go b/pkg/sentry/fs/gofer/cache_policy.go index 07a564e92..f8b7a60fc 100644 --- a/pkg/sentry/fs/gofer/cache_policy.go +++ b/pkg/sentry/fs/gofer/cache_policy.go @@ -139,7 +139,7 @@ func (cp cachePolicy) revalidate(ctx context.Context, name string, parent, child // Walk from parent to child again. // - // TODO(b/112031682): If we have a directory FD in the parent + // NOTE(b/112031682): If we have a directory FD in the parent // inodeOperations, then we can use fstatat(2) to get the inode // attributes instead of making this RPC. qids, f, mask, attr, err := parentIops.fileState.file.walkGetAttr(ctx, []string{name}) diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go index b998fb75d..085aa6d61 100644 --- a/pkg/sentry/fs/proc/sys.go +++ b/pkg/sentry/fs/proc/sys.go @@ -77,6 +77,27 @@ func (*overcommitMemory) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandl }, 0 } +// +stateify savable +type maxMapCount struct{} + +// NeedsUpdate implements seqfile.SeqSource. +func (*maxMapCount) NeedsUpdate(int64) bool { + return true +} + +// ReadSeqFileData implements seqfile.SeqSource. +func (*maxMapCount) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + if h != nil { + return nil, 0 + } + return []seqfile.SeqData{ + { + Buf: []byte("2147483647\n"), + Handle: (*maxMapCount)(nil), + }, + }, 0 +} + func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode { h := hostname{ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), @@ -96,6 +117,7 @@ func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode func (p *proc) newVMDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode { children := map[string]*fs.Inode{ + "max_map_count": seqfile.NewSeqFileInode(ctx, &maxMapCount{}, msrc), "mmap_min_addr": seqfile.NewSeqFileInode(ctx, &mmapMinAddrData{p.k}, msrc), "overcommit_memory": seqfile.NewSeqFileInode(ctx, &overcommitMemory{}, msrc), } diff --git a/pkg/sentry/fsimpl/cgroupfs/base.go b/pkg/sentry/fsimpl/cgroupfs/base.go index 6512e9cdb..fe9871bdd 100644 --- a/pkg/sentry/fsimpl/cgroupfs/base.go +++ b/pkg/sentry/fsimpl/cgroupfs/base.go @@ -133,6 +133,17 @@ func (c *cgroupInode) Controllers() []kernel.CgroupController { return c.fs.kcontrollers } +// tasks returns a snapshot of the tasks inside the cgroup. +func (c *cgroupInode) tasks() []*kernel.Task { + c.fs.tasksMu.RLock() + defer c.fs.tasksMu.RUnlock() + ts := make([]*kernel.Task, 0, len(c.ts)) + for t := range c.ts { + ts = append(ts, t) + } + return ts +} + // Enter implements kernel.CgroupImpl.Enter. func (c *cgroupInode) Enter(t *kernel.Task) { c.fs.tasksMu.Lock() @@ -163,10 +174,7 @@ func (d *cgroupProcsData) Generate(ctx context.Context, buf *bytes.Buffer) error pgids := make(map[kernel.ThreadID]struct{}) - d.fs.tasksMu.RLock() - defer d.fs.tasksMu.RUnlock() - - for task := range d.ts { + for _, task := range d.tasks() { // Map dedups pgid, since iterating over all tasks produces multiple // entries for the group leaders. if pgid := currPidns.IDOfThreadGroup(task.ThreadGroup()); pgid != 0 { @@ -205,10 +213,7 @@ func (d *tasksData) Generate(ctx context.Context, buf *bytes.Buffer) error { var pids []kernel.ThreadID - d.fs.tasksMu.RLock() - defer d.fs.tasksMu.RUnlock() - - for task := range d.ts { + for _, task := range d.tasks() { if pid := currPidns.IDOfTask(task); pid != 0 { pids = append(pids, pid) } diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go index 54050de3c..05d7eb4ce 100644 --- a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go +++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go @@ -49,8 +49,9 @@ // // kernel.CgroupRegistry.mu // cgroupfs.filesystem.mu -// Task.mu -// cgroupfs.filesystem.tasksMu. +// kernel.TaskSet.mu +// kernel.Task.mu +// cgroupfs.filesystem.tasksMu. package cgroupfs import ( diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 91ec4a142..eb09d54c3 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -1194,11 +1194,7 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st // RenameAt implements vfs.FilesystemImpl.RenameAt. func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error { - if opts.Flags != 0 { - // Requires 9P support. - return syserror.EINVAL - } - + // Resolve newParent first to verify that it's on this Mount. var ds *[]*dentry fs.renameMu.Lock() defer fs.renameMuUnlockAndCheckCaching(ctx, &ds) @@ -1206,8 +1202,21 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if err != nil { return err } + + if opts.Flags&^linux.RENAME_NOREPLACE != 0 { + return syserror.EINVAL + } + if fs.opts.interop == InteropModeShared && opts.Flags&linux.RENAME_NOREPLACE != 0 { + // Requires 9P support to synchronize with other remote filesystem + // users. + return syserror.EINVAL + } + newName := rp.Component() if newName == "." || newName == ".." { + if opts.Flags&linux.RENAME_NOREPLACE != 0 { + return syserror.EEXIST + } return syserror.EBUSY } mnt := rp.Mount() @@ -1280,6 +1289,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } var replacedVFSD *vfs.Dentry if replaced != nil { + if opts.Flags&linux.RENAME_NOREPLACE != 0 { + return syserror.EEXIST + } replacedVFSD = &replaced.vfsd if replaced.isDir() { if !renamed.isDir() { diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 21692d2ac..cf69e1b7a 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -1282,9 +1282,12 @@ func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) } func (d *dentry) checkXattrPermissions(creds *auth.Credentials, name string, ats vfs.AccessTypes) error { - // We only support xattrs prefixed with "user." (see b/148380782). Currently, - // there is no need to expose any other xattrs through a gofer. - if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { + // Deny access to the "security" and "system" namespaces since applications + // may expect these to affect kernel behavior in unimplemented ways + // (b/148380782). Allow all other extended attributes to be passed through + // to the remote filesystem. This is inconsistent with Linux's 9p client, + // but consistent with other filesystems (e.g. FUSE). + if strings.HasPrefix(name, linux.XATTR_SECURITY_PREFIX) || strings.HasPrefix(name, linux.XATTR_SYSTEM_PREFIX) { return syserror.EOPNOTSUPP } mode := linux.FileMode(atomic.LoadUint32(&d.mode)) @@ -1684,7 +1687,7 @@ func (d *dentry) setDeleted() { } func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) { - if d.file.isNil() || !d.userXattrSupported() { + if d.file.isNil() { return nil, nil } xattrMap, err := d.file.listXattr(ctx, size) @@ -1693,10 +1696,7 @@ func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size ui } xattrs := make([]string, 0, len(xattrMap)) for x := range xattrMap { - // We only support xattrs in the user.* namespace. - if strings.HasPrefix(x, linux.XATTR_USER_PREFIX) { - xattrs = append(xattrs, x) - } + xattrs = append(xattrs, x) } return xattrs, nil } @@ -1731,13 +1731,6 @@ func (d *dentry) removeXattr(ctx context.Context, creds *auth.Credentials, name return d.file.removeXattr(ctx, name) } -// Extended attributes in the user.* namespace are only supported for regular -// files and directories. -func (d *dentry) userXattrSupported() bool { - filetype := linux.FileMode(atomic.LoadUint32(&d.mode)).FileType() - return filetype == linux.ModeRegular || filetype == linux.ModeDirectory -} - // Preconditions: // * !d.isSynthetic(). // * d.isRegularFile() || d.isDir(). diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index f50b0fb08..8fac53c60 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -635,12 +635,6 @@ func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st // RenameAt implements vfs.FilesystemImpl.RenameAt. func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error { - // Only RENAME_NOREPLACE is supported. - if opts.Flags&^linux.RENAME_NOREPLACE != 0 { - return syserror.EINVAL - } - noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0 - fs.mu.Lock() defer fs.processDeferredDecRefs(ctx) defer fs.mu.Unlock() @@ -651,6 +645,13 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if err != nil { return err } + + // Only RENAME_NOREPLACE is supported. + if opts.Flags&^linux.RENAME_NOREPLACE != 0 { + return syserror.EINVAL + } + noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0 + mnt := rp.Mount() if mnt != oldParentVD.Mount() { return syserror.EXDEV diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index 46c500427..6b6fa0bd5 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -1017,10 +1017,7 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st // RenameAt implements vfs.FilesystemImpl.RenameAt. func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error { - if opts.Flags != 0 { - return syserror.EINVAL - } - + // Resolve newParent first to verify that it's on this Mount. var ds *[]*dentry fs.renameMu.Lock() defer fs.renameMuUnlockAndCheckDrop(ctx, &ds) @@ -1028,8 +1025,16 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if err != nil { return err } + + if opts.Flags&^linux.RENAME_NOREPLACE != 0 { + return syserror.EINVAL + } + newName := rp.Component() if newName == "." || newName == ".." { + if opts.Flags&linux.RENAME_NOREPLACE != 0 { + return syserror.EEXIST + } return syserror.EBUSY } mnt := rp.Mount() @@ -1093,6 +1098,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa return err } if replaced != nil { + if opts.Flags&linux.RENAME_NOREPLACE != 0 { + return syserror.EEXIST + } replacedVFSD = &replaced.vfsd if replaced.isDir() { if !renamed.isDir() { diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index 88ab49048..2bc98a94f 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -55,6 +55,7 @@ func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k * }), }), "vm": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "max_map_count": fs.newInode(ctx, root, 0444, newStaticFile("2147483647\n")), "mmap_min_addr": fs.newInode(ctx, root, 0444, &mmapMinAddrData{k: k}), "overcommit_memory": fs.newInode(ctx, root, 0444, newStaticFile("0\n")), }), diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD index c766164c7..b3f9d1010 100644 --- a/pkg/sentry/fsimpl/testutil/BUILD +++ b/pkg/sentry/fsimpl/testutil/BUILD @@ -17,7 +17,6 @@ go_library( "//pkg/fspath", "//pkg/hostarch", "//pkg/memutil", - "//pkg/metric", "//pkg/sentry/fsbridge", "//pkg/sentry/fsimpl/tmpfs", "//pkg/sentry/kernel", diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go index 438840ae2..473b41cff 100644 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ b/pkg/sentry/fsimpl/testutil/kernel.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/memutil" - "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -63,8 +62,6 @@ func Boot() (*kernel.Kernel, error) { return nil, fmt.Errorf("creating platform: %v", err) } - metric.CreateSentryMetrics() - kernel.VFS2Enabled = true k := &kernel.Kernel{ Platform: plat, @@ -83,12 +80,8 @@ func Boot() (*kernel.Kernel, error) { } // Create timekeeper. - tk, err := kernel.NewTimekeeper(k, vdso.ParamPage.FileRange()) - if err != nil { - return nil, fmt.Errorf("creating timekeeper: %v", err) - } + tk := kernel.NewTimekeeper(k, vdso.ParamPage.FileRange()) tk.SetClocks(time.NewCalibratedClocks()) - k.SetTimekeeper(tk) creds := auth.NewRootCredentials(auth.NewRootUserNamespace()) @@ -97,6 +90,7 @@ func Boot() (*kernel.Kernel, error) { if err = k.Init(kernel.InitKernelArgs{ ApplicationCores: uint(runtime.GOMAXPROCS(-1)), FeatureSet: cpuid.HostFeatureSet(), + Timekeeper: tk, RootUserNamespace: creds.UserNamespace, Vdso: vdso, RootUTSNamespace: kernel.NewUTSNamespace("hostname", "domain", creds.UserNamespace), diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index 766289e60..f0f4297ef 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -496,20 +496,24 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st // RenameAt implements vfs.FilesystemImpl.RenameAt. func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error { - if opts.Flags != 0 { - // TODO(b/145974740): Support renameat2 flags. - return syserror.EINVAL - } - - // Resolve newParent first to verify that it's on this Mount. + // Resolve newParentDir first to verify that it's on this Mount. fs.mu.Lock() defer fs.mu.Unlock() newParentDir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry)) if err != nil { return err } + + if opts.Flags&^linux.RENAME_NOREPLACE != 0 { + // TODO(b/145974740): Support other renameat2 flags. + return syserror.EINVAL + } + newName := rp.Component() if newName == "." || newName == ".." { + if opts.Flags&linux.RENAME_NOREPLACE != 0 { + return syserror.EEXIST + } return syserror.EBUSY } mnt := rp.Mount() @@ -556,6 +560,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } replaced, ok := newParentDir.childMap[newName] if ok { + if opts.Flags&linux.RENAME_NOREPLACE != 0 { + return syserror.EEXIST + } replacedDir, ok := replaced.inode.impl.(*directory) if ok { if !renamed.inode.isDir() { @@ -815,7 +822,7 @@ func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, si if err != nil { return nil, err } - return d.inode.listXattr(size) + return d.inode.listXattr(rp.Credentials(), size) } // GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 9ae25ce9e..6b4367c42 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -717,44 +717,63 @@ func (i *inode) touchCMtimeLocked() { atomic.StoreInt64(&i.ctime, now) } -func (i *inode) listXattr(size uint64) ([]string, error) { - return i.xattrs.ListXattr(size) +func checkXattrName(name string) error { + // Linux's tmpfs supports "security" and "trusted" xattr namespaces, and + // (depending on build configuration) POSIX ACL xattr namespaces + // ("system.posix_acl_access" and "system.posix_acl_default"). We don't + // support POSIX ACLs or the "security" namespace (b/148380782). + if strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX) { + return nil + } + // We support the "user" namespace because we have tests that depend on + // this feature. + if strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { + return nil + } + return syserror.EOPNOTSUPP +} + +func (i *inode) listXattr(creds *auth.Credentials, size uint64) ([]string, error) { + return i.xattrs.ListXattr(creds, size) } func (i *inode) getXattr(creds *auth.Credentials, opts *vfs.GetXattrOptions) (string, error) { - if err := i.checkXattrPermissions(creds, opts.Name, vfs.MayRead); err != nil { + if err := checkXattrName(opts.Name); err != nil { return "", err } - return i.xattrs.GetXattr(opts) + mode := linux.FileMode(atomic.LoadUint32(&i.mode)) + kuid := auth.KUID(atomic.LoadUint32(&i.uid)) + kgid := auth.KGID(atomic.LoadUint32(&i.gid)) + if err := vfs.GenericCheckPermissions(creds, vfs.MayRead, mode, kuid, kgid); err != nil { + return "", err + } + return i.xattrs.GetXattr(creds, mode, kuid, opts) } func (i *inode) setXattr(creds *auth.Credentials, opts *vfs.SetXattrOptions) error { - if err := i.checkXattrPermissions(creds, opts.Name, vfs.MayWrite); err != nil { + if err := checkXattrName(opts.Name); err != nil { return err } - return i.xattrs.SetXattr(opts) -} - -func (i *inode) removeXattr(creds *auth.Credentials, name string) error { - if err := i.checkXattrPermissions(creds, name, vfs.MayWrite); err != nil { + mode := linux.FileMode(atomic.LoadUint32(&i.mode)) + kuid := auth.KUID(atomic.LoadUint32(&i.uid)) + kgid := auth.KGID(atomic.LoadUint32(&i.gid)) + if err := vfs.GenericCheckPermissions(creds, vfs.MayWrite, mode, kuid, kgid); err != nil { return err } - return i.xattrs.RemoveXattr(name) + return i.xattrs.SetXattr(creds, mode, kuid, opts) } -func (i *inode) checkXattrPermissions(creds *auth.Credentials, name string, ats vfs.AccessTypes) error { - // We currently only support extended attributes in the user.* and - // trusted.* namespaces. See b/148380782. - if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) && !strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX) { - return syserror.EOPNOTSUPP +func (i *inode) removeXattr(creds *auth.Credentials, name string) error { + if err := checkXattrName(name); err != nil { + return err } mode := linux.FileMode(atomic.LoadUint32(&i.mode)) kuid := auth.KUID(atomic.LoadUint32(&i.uid)) kgid := auth.KGID(atomic.LoadUint32(&i.gid)) - if err := vfs.GenericCheckPermissions(creds, ats, mode, kuid, kgid); err != nil { + if err := vfs.GenericCheckPermissions(creds, vfs.MayWrite, mode, kuid, kgid); err != nil { return err } - return vfs.CheckXattrPermissions(creds, ats, mode, kuid, name) + return i.xattrs.RemoveXattr(creds, mode, kuid, name) } // fileDescription is embedded by tmpfs implementations of @@ -807,7 +826,7 @@ func (fd *fileDescription) StatFS(ctx context.Context) (linux.Statfs, error) { // ListXattr implements vfs.FileDescriptionImpl.ListXattr. func (fd *fileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) { - return fd.inode().listXattr(size) + return fd.inode().listXattr(auth.CredentialsFromContext(ctx), size) } // GetXattr implements vfs.FileDescriptionImpl.GetXattr. diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index a1ec6daab..a82d641da 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -32,7 +32,7 @@ go_template_instance( out = "seqatomic_taskgoroutineschedinfo_unsafe.go", package = "kernel", suffix = "TaskGoroutineSchedInfo", - template = "//pkg/sync:generic_seqatomic", + template = "//pkg/sync/seqatomic:generic_seqatomic", types = { "Value": "TaskGoroutineSchedInfo", }, @@ -218,6 +218,7 @@ go_library( ":uncaught_signal_go_proto", "//pkg/abi", "//pkg/abi/linux", + "//pkg/abi/linux/errno", "//pkg/amutex", "//pkg/bits", "//pkg/bpf", diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD index 869e49ebc..12180351d 100644 --- a/pkg/sentry/kernel/auth/BUILD +++ b/pkg/sentry/kernel/auth/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "atomicptr_credentials_unsafe.go", package = "auth", suffix = "Credentials", - template = "//pkg/sync:generic_atomicptr", + template = "//pkg/sync/atomicptr:generic_atomicptr", types = { "Value": "Credentials", }, diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD index f855f038b..6224a0cbd 100644 --- a/pkg/sentry/kernel/fasync/BUILD +++ b/pkg/sentry/kernel/fasync/BUILD @@ -8,7 +8,6 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", - "//pkg/sentry/arch", "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go index dbbbaeeb0..5d584dc45 100644 --- a/pkg/sentry/kernel/fasync/fasync.go +++ b/pkg/sentry/kernel/fasync/fasync.go @@ -17,7 +17,6 @@ package fasync import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -125,9 +124,9 @@ func (a *FileAsync) Callback(e *waiter.Entry, mask waiter.EventMask) { if !permCheck { return } - signalInfo := &arch.SignalInfo{ + signalInfo := &linux.SignalInfo{ Signo: int32(linux.SIGIO), - Code: arch.SignalInfoKernel, + Code: linux.SI_KERNEL, } if a.signal != 0 { signalInfo.Signo = int32(a.signal) diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index a75686cf3..6c31e082c 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "atomicptr_bucket_unsafe.go", package = "futex", suffix = "Bucket", - template = "//pkg/sync:generic_atomicptr", + template = "//pkg/sync/atomicptr:generic_atomicptr", types = { "Value": "bucket", }, diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index febe7fe50..352c36ba9 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -143,12 +143,6 @@ type Kernel struct { // to CreateProcess, and is protected by extMu. globalInit *ThreadGroup - // realtimeClock is a ktime.Clock based on timekeeper's Realtime. - realtimeClock *timekeeperClock - - // monotonicClock is a ktime.Clock based on timekeeper's Monotonic. - monotonicClock *timekeeperClock - // syslog is the kernel log. syslog syslog @@ -306,6 +300,9 @@ type InitKernelArgs struct { // FeatureSet is the emulated CPU feature set. FeatureSet *cpuid.FeatureSet + // Timekeeper manages time for all tasks in the system. + Timekeeper *Timekeeper + // RootUserNamespace is the root user namespace. RootUserNamespace *auth.UserNamespace @@ -345,24 +342,18 @@ type InitKernelArgs struct { PIDNamespace *PIDNamespace } -// SetTimekeeper sets Kernel.timekeeper. SetTimekeeper must be called before -// Init. -func (k *Kernel) SetTimekeeper(tk *Timekeeper) { - k.timekeeper = tk -} - // Init initialize the Kernel with no tasks. // // Callers must manually set Kernel.Platform and call Kernel.SetMemoryFile -// and Kernel.SetTimekeeper before calling Init. +// before calling Init. func (k *Kernel) Init(args InitKernelArgs) error { if args.FeatureSet == nil { return fmt.Errorf("args.FeatureSet is nil") } - if k.timekeeper == nil { - return fmt.Errorf("timekeeper is nil") + if args.Timekeeper == nil { + return fmt.Errorf("args.Timekeeper is nil") } - if k.timekeeper.clocks == nil { + if args.Timekeeper.clocks == nil { return fmt.Errorf("must call Timekeeper.SetClocks() before Kernel.Init()") } if args.RootUserNamespace == nil { @@ -373,6 +364,7 @@ func (k *Kernel) Init(args InitKernelArgs) error { } k.featureSet = args.FeatureSet + k.timekeeper = args.Timekeeper k.tasks = newTaskSet(args.PIDNamespace) k.rootUserNamespace = args.RootUserNamespace k.rootUTSNamespace = args.RootUTSNamespace @@ -397,8 +389,6 @@ func (k *Kernel) Init(args InitKernelArgs) error { } k.extraAuxv = args.ExtraAuxv k.vdso = args.Vdso - k.realtimeClock = &timekeeperClock{tk: k.timekeeper, c: sentrytime.Realtime} - k.monotonicClock = &timekeeperClock{tk: k.timekeeper, c: sentrytime.Monotonic} k.futexes = futex.NewManager() k.netlinkPorts = port.New() k.ptraceExceptions = make(map[*Task]*Task) @@ -531,6 +521,8 @@ func (k *Kernel) SaveTo(ctx context.Context, w wire.Writer) error { } log.Infof("CPUID save took [%s].", time.Since(cpuidStart)) + // Save the timekeeper's state. + // Save the kernel state. kernelStart := time.Now() stats, err := state.Save(ctx, w, k) @@ -675,7 +667,7 @@ func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error { } // LoadFrom returns a new Kernel loaded from args. -func (k *Kernel) LoadFrom(ctx context.Context, r wire.Reader, net inet.Stack, clocks sentrytime.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error { +func (k *Kernel) LoadFrom(ctx context.Context, r wire.Reader, timeReady chan struct{}, net inet.Stack, clocks sentrytime.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error { loadStart := time.Now() initAppCores := k.applicationCores @@ -722,6 +714,11 @@ func (k *Kernel) LoadFrom(ctx context.Context, r wire.Reader, net inet.Stack, cl log.Infof("Overall load took [%s]", time.Since(loadStart)) k.Timekeeper().SetClocks(clocks) + + if timeReady != nil { + close(timeReady) + } + if net != nil { net.Resume() } @@ -1103,7 +1100,7 @@ func (k *Kernel) Start() error { } k.started = true - k.cpuClockTicker = ktime.NewTimer(k.monotonicClock, newKernelCPUClockTicker(k)) + k.cpuClockTicker = ktime.NewTimer(k.timekeeper.monotonicClock, newKernelCPUClockTicker(k)) k.cpuClockTicker.Swap(ktime.Setting{ Enabled: true, Period: linux.ClockTick, @@ -1258,7 +1255,7 @@ func (k *Kernel) incRunningTasks() { // These cause very different value of cpuClock. But again, since // nothing was running while the ticker was disabled, those differences // don't matter. - setting, exp := k.cpuClockTickerSetting.At(k.monotonicClock.Now()) + setting, exp := k.cpuClockTickerSetting.At(k.timekeeper.monotonicClock.Now()) if exp > 0 { atomic.AddUint64(&k.cpuClock, exp) } @@ -1341,7 +1338,7 @@ func (k *Kernel) Unpause() { // context is used only for debugging to describe how the signal was received. // // Preconditions: Kernel must have an init process. -func (k *Kernel) SendExternalSignal(info *arch.SignalInfo, context string) { +func (k *Kernel) SendExternalSignal(info *linux.SignalInfo, context string) { k.extMu.Lock() defer k.extMu.Unlock() k.sendExternalSignal(info, context) @@ -1349,7 +1346,7 @@ func (k *Kernel) SendExternalSignal(info *arch.SignalInfo, context string) { // SendExternalSignalThreadGroup injects a signal into an specific ThreadGroup. // This function doesn't skip signals like SendExternalSignal does. -func (k *Kernel) SendExternalSignalThreadGroup(tg *ThreadGroup, info *arch.SignalInfo) error { +func (k *Kernel) SendExternalSignalThreadGroup(tg *ThreadGroup, info *linux.SignalInfo) error { k.extMu.Lock() defer k.extMu.Unlock() return tg.SendSignal(info) @@ -1357,7 +1354,7 @@ func (k *Kernel) SendExternalSignalThreadGroup(tg *ThreadGroup, info *arch.Signa // SendContainerSignal sends the given signal to all processes inside the // namespace that match the given container ID. -func (k *Kernel) SendContainerSignal(cid string, info *arch.SignalInfo) error { +func (k *Kernel) SendContainerSignal(cid string, info *linux.SignalInfo) error { k.extMu.Lock() defer k.extMu.Unlock() k.tasks.mu.RLock() @@ -1468,12 +1465,12 @@ func (k *Kernel) ApplicationCores() uint { // RealtimeClock returns the application CLOCK_REALTIME clock. func (k *Kernel) RealtimeClock() ktime.Clock { - return k.realtimeClock + return k.timekeeper.realtimeClock } // MonotonicClock returns the application CLOCK_MONOTONIC clock. func (k *Kernel) MonotonicClock() ktime.Clock { - return k.monotonicClock + return k.timekeeper.monotonicClock } // CPUClockNow returns the current value of k.cpuClock. @@ -1553,32 +1550,6 @@ func (k *Kernel) SetSaveError(err error) { } } -var _ tcpip.Clock = (*Kernel)(nil) - -// Now implements tcpip.Clock.NowNanoseconds. -func (k *Kernel) Now() time.Time { - nsec, err := k.timekeeper.GetTime(sentrytime.Realtime) - if err != nil { - panic("timekeeper.GetTime(sentrytime.Realtime): " + err.Error()) - } - return time.Unix(0, nsec) -} - -// NowMonotonic implements tcpip.Clock.NowMonotonic. -func (k *Kernel) NowMonotonic() tcpip.MonotonicTime { - nsec, err := k.timekeeper.GetTime(sentrytime.Monotonic) - if err != nil { - panic("timekeeper.GetTime(sentrytime.Monotonic): " + err.Error()) - } - var mt tcpip.MonotonicTime - return mt.Add(time.Duration(nsec) * time.Nanosecond) -} - -// AfterFunc implements tcpip.Clock.AfterFunc. -func (k *Kernel) AfterFunc(d time.Duration, f func()) tcpip.Timer { - return ktime.TcpipAfterFunc(k.realtimeClock, d, f) -} - // SetMemoryFile sets Kernel.mf. SetMemoryFile must be called before Init or // LoadFrom. func (k *Kernel) SetMemoryFile(mf *pgalloc.MemoryFile) { @@ -1861,7 +1832,9 @@ func (k *Kernel) PopulateNewCgroupHierarchy(root Cgroup) { return } t.mu.Lock() - t.enterCgroupLocked(root) + // A task can be in the cgroup if it has been created after the + // cgroup hierarchy was registered. + t.enterCgroupIfNotYetLocked(root) t.mu.Unlock() }) k.tasks.mu.RUnlock() diff --git a/pkg/sentry/kernel/pending_signals.go b/pkg/sentry/kernel/pending_signals.go index 77a35b788..af455c434 100644 --- a/pkg/sentry/kernel/pending_signals.go +++ b/pkg/sentry/kernel/pending_signals.go @@ -17,7 +17,6 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" - "gvisor.dev/gvisor/pkg/sentry/arch" ) const ( @@ -65,7 +64,7 @@ type pendingSignalQueue struct { type pendingSignal struct { // pendingSignalEntry links into a pendingSignalList. pendingSignalEntry - *arch.SignalInfo + *linux.SignalInfo // If timer is not nil, it is the IntervalTimer which sent this signal. timer *IntervalTimer @@ -75,7 +74,7 @@ type pendingSignal struct { // on failure (if the given signal's queue is full). // // Preconditions: info represents a valid signal. -func (p *pendingSignals) enqueue(info *arch.SignalInfo, timer *IntervalTimer) bool { +func (p *pendingSignals) enqueue(info *linux.SignalInfo, timer *IntervalTimer) bool { sig := linux.Signal(info.Signo) q := &p.signals[sig.Index()] if sig.IsStandard() { @@ -93,7 +92,7 @@ func (p *pendingSignals) enqueue(info *arch.SignalInfo, timer *IntervalTimer) bo // dequeue dequeues and returns any pending signal not masked by mask. If no // unmasked signals are pending, dequeue returns nil. -func (p *pendingSignals) dequeue(mask linux.SignalSet) *arch.SignalInfo { +func (p *pendingSignals) dequeue(mask linux.SignalSet) *linux.SignalInfo { // "Real-time signals are delivered in a guaranteed order. Multiple // real-time signals of the same type are delivered in the order they were // sent. If different real-time signals are sent to a process, they are @@ -111,7 +110,7 @@ func (p *pendingSignals) dequeue(mask linux.SignalSet) *arch.SignalInfo { return p.dequeueSpecific(linux.Signal(lowestPendingUnblockedBit + 1)) } -func (p *pendingSignals) dequeueSpecific(sig linux.Signal) *arch.SignalInfo { +func (p *pendingSignals) dequeueSpecific(sig linux.Signal) *linux.SignalInfo { q := &p.signals[sig.Index()] ps := q.pendingSignalList.Front() if ps == nil { diff --git a/pkg/sentry/kernel/pending_signals_state.go b/pkg/sentry/kernel/pending_signals_state.go index ca8b4e164..e77f1a254 100644 --- a/pkg/sentry/kernel/pending_signals_state.go +++ b/pkg/sentry/kernel/pending_signals_state.go @@ -14,13 +14,11 @@ package kernel -import ( - "gvisor.dev/gvisor/pkg/sentry/arch" -) +import "gvisor.dev/gvisor/pkg/abi/linux" // +stateify savable type savedPendingSignal struct { - si *arch.SignalInfo + si *linux.SignalInfo timer *IntervalTimer } diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index 24e467e93..3fa5d1d2f 100644 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -135,7 +135,7 @@ func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArgume v = math.MaxInt32 // Silently truncate. } // Copy result to userspace. - iocc := primitive.IOCopyContext{ + iocc := usermem.IOCopyContext{ IO: io, Ctx: ctx, Opts: usermem.IOOpts{ diff --git a/pkg/sentry/kernel/posixtimer.go b/pkg/sentry/kernel/posixtimer.go index 2e861a5a8..d801a3d83 100644 --- a/pkg/sentry/kernel/posixtimer.go +++ b/pkg/sentry/kernel/posixtimer.go @@ -18,7 +18,6 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/arch" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/syserror" ) @@ -97,7 +96,7 @@ func (it *IntervalTimer) ResumeTimer() { } // Preconditions: it.target's signal mutex must be locked. -func (it *IntervalTimer) updateDequeuedSignalLocked(si *arch.SignalInfo) { +func (it *IntervalTimer) updateDequeuedSignalLocked(si *linux.SignalInfo) { it.sigpending = false if it.sigorphan { return @@ -138,9 +137,9 @@ func (it *IntervalTimer) Notify(exp uint64, setting ktime.Setting) (ktime.Settin it.sigpending = true it.sigorphan = false it.overrunCur += exp - 1 - si := &arch.SignalInfo{ + si := &linux.SignalInfo{ Signo: int32(it.signo), - Code: arch.SignalInfoTimer, + Code: linux.SI_TIMER, } si.SetTimerID(it.id) si.SetSigval(it.sigval) diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go index 57c7659e7..a6287fd6a 100644 --- a/pkg/sentry/kernel/ptrace.go +++ b/pkg/sentry/kernel/ptrace.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -394,7 +393,7 @@ func (t *Task) ptraceTrapLocked(code int32) { t.trapStopPending = false t.tg.signalHandlers.mu.Unlock() t.ptraceCode = code - t.ptraceSiginfo = &arch.SignalInfo{ + t.ptraceSiginfo = &linux.SignalInfo{ Signo: int32(linux.SIGTRAP), Code: code, } @@ -402,7 +401,7 @@ func (t *Task) ptraceTrapLocked(code int32) { t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) if t.beginPtraceStopLocked() { tracer := t.Tracer() - tracer.signalStop(t, arch.CLD_TRAPPED, int32(linux.SIGTRAP)) + tracer.signalStop(t, linux.CLD_TRAPPED, int32(linux.SIGTRAP)) tracer.tg.eventQueue.Notify(EventTraceeStop) } } @@ -542,9 +541,9 @@ func (t *Task) ptraceAttach(target *Task, seize bool, opts uintptr) error { // "Unlike PTRACE_ATTACH, PTRACE_SEIZE does not stop the process." - // ptrace(2) if !seize { - target.sendSignalLocked(&arch.SignalInfo{ + target.sendSignalLocked(&linux.SignalInfo{ Signo: int32(linux.SIGSTOP), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, }, false /* group */) } // Undocumented Linux feature: If the tracee is already group-stopped (and @@ -586,7 +585,7 @@ func (t *Task) exitPtrace() { for target := range t.ptraceTracees { if target.ptraceOpts.ExitKill { target.tg.signalHandlers.mu.Lock() - target.sendSignalLocked(&arch.SignalInfo{ + target.sendSignalLocked(&linux.SignalInfo{ Signo: int32(linux.SIGKILL), }, false /* group */) target.tg.signalHandlers.mu.Unlock() @@ -652,7 +651,7 @@ func (t *Task) forgetTracerLocked() { // Preconditions: // * The signal mutex must be locked. // * The caller must be running on the task goroutine. -func (t *Task) ptraceSignalLocked(info *arch.SignalInfo) bool { +func (t *Task) ptraceSignalLocked(info *linux.SignalInfo) bool { if linux.Signal(info.Signo) == linux.SIGKILL { return false } @@ -678,7 +677,7 @@ func (t *Task) ptraceSignalLocked(info *arch.SignalInfo) bool { t.ptraceSiginfo = info t.Debugf("Entering signal-delivery-stop for signal %d", info.Signo) if t.beginPtraceStopLocked() { - tracer.signalStop(t, arch.CLD_TRAPPED, info.Signo) + tracer.signalStop(t, linux.CLD_TRAPPED, info.Signo) tracer.tg.eventQueue.Notify(EventTraceeStop) } return true @@ -829,7 +828,7 @@ func (t *Task) ptraceClone(kind ptraceCloneKind, child *Task, opts *CloneOptions if child.ptraceSeized { child.trapStopPending = true } else { - child.pendingSignals.enqueue(&arch.SignalInfo{ + child.pendingSignals.enqueue(&linux.SignalInfo{ Signo: int32(linux.SIGSTOP), }, nil) } @@ -893,9 +892,9 @@ func (t *Task) ptraceExec(oldTID ThreadID) { } t.tg.signalHandlers.mu.Lock() defer t.tg.signalHandlers.mu.Unlock() - t.sendSignalLocked(&arch.SignalInfo{ + t.sendSignalLocked(&linux.SignalInfo{ Signo: int32(linux.SIGTRAP), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, }, false /* group */) } @@ -1228,7 +1227,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data hostarch.Addr) error { return err case linux.PTRACE_SETSIGINFO: - var info arch.SignalInfo + var info linux.SignalInfo if _, err := info.CopyIn(t, data); err != nil { return err } diff --git a/pkg/sentry/kernel/seccomp.go b/pkg/sentry/kernel/seccomp.go index a95e174a2..54ca43c2e 100644 --- a/pkg/sentry/kernel/seccomp.go +++ b/pkg/sentry/kernel/seccomp.go @@ -39,11 +39,11 @@ func dataAsBPFInput(t *Task, d *linux.SeccompData) bpf.Input { } } -func seccompSiginfo(t *Task, errno, sysno int32, ip hostarch.Addr) *arch.SignalInfo { - si := &arch.SignalInfo{ +func seccompSiginfo(t *Task, errno, sysno int32, ip hostarch.Addr) *linux.SignalInfo { + si := &linux.SignalInfo{ Signo: int32(linux.SIGSYS), Errno: errno, - Code: arch.SYS_SECCOMP, + Code: linux.SYS_SECCOMP, } si.SetCallAddr(uint64(ip)) si.SetSyscall(sysno) diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index 1d9edf118..47bb66b42 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -171,10 +171,10 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu // Map semaphores and map indexes in a registry are of the same size, // check map semaphores only here for the system limit. if len(r.semaphores) >= setsMax { - return nil, syserror.EINVAL + return nil, syserror.ENOSPC } if r.totalSems() > int(semsTotalMax-nsems) { - return nil, syserror.EINVAL + return nil, syserror.ENOSPC } // Finally create a new set. diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go index 0cd9e2533..ca9076406 100644 --- a/pkg/sentry/kernel/sessions.go +++ b/pkg/sentry/kernel/sessions.go @@ -16,7 +16,6 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/syserror" ) @@ -233,7 +232,7 @@ func (pg *ProcessGroup) Session() *Session { // SendSignal sends a signal to all processes inside the process group. It is // analagous to kernel/signal.c:kill_pgrp. -func (pg *ProcessGroup) SendSignal(info *arch.SignalInfo) error { +func (pg *ProcessGroup) SendSignal(info *linux.SignalInfo) error { tasks := pg.originator.TaskSet() tasks.mu.RLock() defer tasks.mu.RUnlock() diff --git a/pkg/sentry/kernel/signal.go b/pkg/sentry/kernel/signal.go index 2488ae7d5..e08474d25 100644 --- a/pkg/sentry/kernel/signal.go +++ b/pkg/sentry/kernel/signal.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" ) @@ -36,7 +35,7 @@ const SignalPanic = linux.SIGUSR2 // context is used only for debugging to differentiate these cases. // // Preconditions: Kernel must have an init process. -func (k *Kernel) sendExternalSignal(info *arch.SignalInfo, context string) { +func (k *Kernel) sendExternalSignal(info *linux.SignalInfo, context string) { switch linux.Signal(info.Signo) { case linux.SIGURG: // Sent by the Go 1.14+ runtime for asynchronous goroutine preemption. @@ -60,18 +59,18 @@ func (k *Kernel) sendExternalSignal(info *arch.SignalInfo, context string) { } // SignalInfoPriv returns a SignalInfo equivalent to Linux's SEND_SIG_PRIV. -func SignalInfoPriv(sig linux.Signal) *arch.SignalInfo { - return &arch.SignalInfo{ +func SignalInfoPriv(sig linux.Signal) *linux.SignalInfo { + return &linux.SignalInfo{ Signo: int32(sig), - Code: arch.SignalInfoKernel, + Code: linux.SI_KERNEL, } } // SignalInfoNoInfo returns a SignalInfo equivalent to Linux's SEND_SIG_NOINFO. -func SignalInfoNoInfo(sig linux.Signal, sender, receiver *Task) *arch.SignalInfo { - info := &arch.SignalInfo{ +func SignalInfoNoInfo(sig linux.Signal, sender, receiver *Task) *linux.SignalInfo { + info := &linux.SignalInfo{ Signo: int32(sig), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, } info.SetPID(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg))) info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) diff --git a/pkg/sentry/kernel/signal_handlers.go b/pkg/sentry/kernel/signal_handlers.go index 768fda220..147cc41bb 100644 --- a/pkg/sentry/kernel/signal_handlers.go +++ b/pkg/sentry/kernel/signal_handlers.go @@ -16,7 +16,6 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sync" ) @@ -30,14 +29,14 @@ type SignalHandlers struct { mu sync.Mutex `state:"nosave"` // actions is the action to be taken upon receiving each signal. - actions map[linux.Signal]arch.SignalAct + actions map[linux.Signal]linux.SigAction } // NewSignalHandlers returns a new SignalHandlers specifying all default // actions. func NewSignalHandlers() *SignalHandlers { return &SignalHandlers{ - actions: make(map[linux.Signal]arch.SignalAct), + actions: make(map[linux.Signal]linux.SigAction), } } @@ -59,9 +58,9 @@ func (sh *SignalHandlers) CopyForExec() *SignalHandlers { sh.mu.Lock() defer sh.mu.Unlock() for sig, act := range sh.actions { - if act.Handler == arch.SignalActIgnore { - sh2.actions[sig] = arch.SignalAct{ - Handler: arch.SignalActIgnore, + if act.Handler == linux.SIG_IGN { + sh2.actions[sig] = linux.SigAction{ + Handler: linux.SIG_IGN, } } } @@ -73,15 +72,15 @@ func (sh *SignalHandlers) IsIgnored(sig linux.Signal) bool { sh.mu.Lock() defer sh.mu.Unlock() sa, ok := sh.actions[sig] - return ok && sa.Handler == arch.SignalActIgnore + return ok && sa.Handler == linux.SIG_IGN } // dequeueActionLocked returns the SignalAct that should be used to handle sig. // // Preconditions: sh.mu must be locked. -func (sh *SignalHandlers) dequeueAction(sig linux.Signal) arch.SignalAct { +func (sh *SignalHandlers) dequeueAction(sig linux.Signal) linux.SigAction { act := sh.actions[sig] - if act.IsResetHandler() { + if act.Flags&linux.SA_RESETHAND != 0 { delete(sh.actions, sig) } return act diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index be1371855..2e3b4488a 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" "gvisor.dev/gvisor/pkg/hostarch" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -151,7 +150,7 @@ type Task struct { // which the SA_ONSTACK flag is set. // // signalStack is exclusive to the task goroutine. - signalStack arch.SignalStack + signalStack linux.SignalStack // signalQueue is a set of registered waiters for signal-related events. // @@ -395,7 +394,7 @@ type Task struct { // ptraceSiginfo is analogous to Linux's task_struct::last_siginfo. // // ptraceSiginfo is protected by the TaskSet mutex. - ptraceSiginfo *arch.SignalInfo + ptraceSiginfo *linux.SignalInfo // ptraceEventMsg is the value set by PTRACE_EVENT stops and returned to // the tracer by ptrace(PTRACE_GETEVENTMSG). @@ -853,15 +852,13 @@ func (t *Task) SetOOMScoreAdj(adj int32) error { return nil } -// UID returns t's uid. -// TODO(gvisor.dev/issue/170): This method is not namespaced yet. -func (t *Task) UID() uint32 { +// KUID returns t's kuid. +func (t *Task) KUID() uint32 { return uint32(t.Credentials().EffectiveKUID) } -// GID returns t's gid. -// TODO(gvisor.dev/issue/170): This method is not namespaced yet. -func (t *Task) GID() uint32 { +// KGID returns t's kgid. +func (t *Task) KGID() uint32 { return uint32(t.Credentials().EffectiveKGID) } diff --git a/pkg/sentry/kernel/task_cgroup.go b/pkg/sentry/kernel/task_cgroup.go index 25d2504fa..7c138e80f 100644 --- a/pkg/sentry/kernel/task_cgroup.go +++ b/pkg/sentry/kernel/task_cgroup.go @@ -85,6 +85,14 @@ func (t *Task) enterCgroupLocked(c Cgroup) { c.Enter(t) } +// +checklocks:t.mu +func (t *Task) enterCgroupIfNotYetLocked(c Cgroup) { + if _, ok := t.cgroups[c]; ok { + return + } + t.enterCgroupLocked(c) +} + // LeaveCgroups removes t out from all its cgroups. func (t *Task) LeaveCgroups() { t.mu.Lock() diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index d9897e802..cf8571262 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -66,7 +66,6 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -181,7 +180,7 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { t.tg.signalHandlers = t.tg.signalHandlers.CopyForExec() t.endStopCond.L = &t.tg.signalHandlers.mu // "Any alternate signal stack is not preserved (sigaltstack(2))." - execve(2) - t.signalStack = arch.SignalStack{Flags: arch.SignalStackFlagDisable} + t.signalStack = linux.SignalStack{Flags: linux.SS_DISABLE} // "The termination signal is reset to SIGCHLD (see clone(2))." t.tg.terminationSignal = linux.SIGCHLD // execed indicates that the process can no longer join a process group diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index b1af1a7ef..d115b8783 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -28,9 +28,9 @@ import ( "errors" "fmt" "strconv" + "strings" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" @@ -50,6 +50,23 @@ type ExitStatus struct { Signo int } +func (es ExitStatus) String() string { + var b strings.Builder + if code := es.Code; code != 0 { + if b.Len() != 0 { + b.WriteByte(' ') + } + _, _ = fmt.Fprintf(&b, "Code=%d", code) + } + if signal := es.Signo; signal != 0 { + if b.Len() != 0 { + b.WriteByte(' ') + } + _, _ = fmt.Fprintf(&b, "Signal=%d", signal) + } + return b.String() +} + // Signaled returns true if the ExitStatus indicates that the exiting task or // thread group was killed by a signal. func (es ExitStatus) Signaled() bool { @@ -122,12 +139,12 @@ func (t *Task) killLocked() { if t.stop != nil && t.stop.Killable() { t.endInternalStopLocked() } - t.pendingSignals.enqueue(&arch.SignalInfo{ + t.pendingSignals.enqueue(&linux.SignalInfo{ Signo: int32(linux.SIGKILL), // Linux just sets SIGKILL in the pending signal bitmask without // enqueueing an actual siginfo, such that // kernel/signal.c:collect_signal() initializes si_code to SI_USER. - Code: arch.SignalInfoUser, + Code: linux.SI_USER, }, nil) t.interrupt() } @@ -332,7 +349,7 @@ func (t *Task) exitThreadGroup() bool { // signalStop must be called with t's signal mutex unlocked. t.tg.signalHandlers.mu.Unlock() if notifyParent && t.tg.leader.parent != nil { - t.tg.leader.parent.signalStop(t, arch.CLD_STOPPED, int32(sig)) + t.tg.leader.parent.signalStop(t, linux.CLD_STOPPED, int32(sig)) t.tg.leader.parent.tg.eventQueue.Notify(EventChildGroupStop) } return last @@ -353,7 +370,7 @@ func (t *Task) exitChildren() { continue } other.signalHandlers.mu.Lock() - other.leader.sendSignalLocked(&arch.SignalInfo{ + other.leader.sendSignalLocked(&linux.SignalInfo{ Signo: int32(linux.SIGKILL), }, true /* group */) other.signalHandlers.mu.Unlock() @@ -368,9 +385,9 @@ func (t *Task) exitChildren() { // wait for a parent to reap them.) for c := range t.children { if sig := c.ParentDeathSignal(); sig != 0 { - siginfo := &arch.SignalInfo{ + siginfo := &linux.SignalInfo{ Signo: int32(sig), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, } siginfo.SetPID(int32(c.tg.pidns.tids[t])) siginfo.SetUID(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow())) @@ -652,10 +669,10 @@ func (t *Task) exitNotifyLocked(fromPtraceDetach bool) { t.parent.tg.signalHandlers.mu.Lock() if t.tg.terminationSignal == linux.SIGCHLD || fromPtraceDetach { if act, ok := t.parent.tg.signalHandlers.actions[linux.SIGCHLD]; ok { - if act.Handler == arch.SignalActIgnore { + if act.Handler == linux.SIG_IGN { t.exitParentAcked = true signalParent = false - } else if act.Flags&arch.SignalFlagNoCldWait != 0 { + } else if act.Flags&linux.SA_NOCLDWAIT != 0 { t.exitParentAcked = true } } @@ -705,17 +722,17 @@ func (t *Task) exitNotifyLocked(fromPtraceDetach bool) { } // Preconditions: The TaskSet mutex must be locked. -func (t *Task) exitNotificationSignal(sig linux.Signal, receiver *Task) *arch.SignalInfo { - info := &arch.SignalInfo{ +func (t *Task) exitNotificationSignal(sig linux.Signal, receiver *Task) *linux.SignalInfo { + info := &linux.SignalInfo{ Signo: int32(sig), } info.SetPID(int32(receiver.tg.pidns.tids[t])) info.SetUID(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) if t.exitStatus.Signaled() { - info.Code = arch.CLD_KILLED + info.Code = linux.CLD_KILLED info.SetStatus(int32(t.exitStatus.Signo)) } else { - info.Code = arch.CLD_EXITED + info.Code = linux.CLD_EXITED info.SetStatus(int32(t.exitStatus.Code)) } // TODO(b/72102453): Set utime, stime. diff --git a/pkg/sentry/kernel/task_image.go b/pkg/sentry/kernel/task_image.go index bd5543d4e..c132c27ef 100644 --- a/pkg/sentry/kernel/task_image.go +++ b/pkg/sentry/kernel/task_image.go @@ -17,7 +17,7 @@ package kernel import ( "fmt" - "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/abi/linux/errno" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -27,7 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/syserr" ) -var errNoSyscalls = syserr.New("no syscall table found", linux.ENOEXEC) +var errNoSyscalls = syserr.New("no syscall table found", errno.ENOEXEC) // Auxmap contains miscellaneous data for the task. type Auxmap map[string]interface{} diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go index 9ba5f8d78..f142feab4 100644 --- a/pkg/sentry/kernel/task_sched.go +++ b/pkg/sentry/kernel/task_sched.go @@ -536,7 +536,7 @@ func (tg *ThreadGroup) updateCPUTimersEnabledLocked() { // appropriate for /proc/[pid]/status. func (t *Task) StateStatus() string { switch s := t.TaskGoroutineSchedInfo().State; s { - case TaskGoroutineNonexistent: + case TaskGoroutineNonexistent, TaskGoroutineRunningSys: t.tg.pidns.owner.mu.RLock() defer t.tg.pidns.owner.mu.RUnlock() switch t.exitState { @@ -546,16 +546,16 @@ func (t *Task) StateStatus() string { return "X (dead)" default: // The task goroutine can't exit before passing through - // runExitNotify, so this indicates that the task has been created, - // but the task goroutine hasn't yet started. The Linux equivalent - // is struct task_struct::state == TASK_NEW + // runExitNotify, so if s == TaskGoroutineNonexistent, the task has + // been created but the task goroutine hasn't yet started. The + // Linux equivalent is struct task_struct::state == TASK_NEW // (kernel/fork.c:copy_process() => // kernel/sched/core.c:sched_fork()), but the TASK_NEW bit is // masked out by TASK_REPORT for /proc/[pid]/status, leaving only // TASK_RUNNING. return "R (running)" } - case TaskGoroutineRunningSys, TaskGoroutineRunningApp: + case TaskGoroutineRunningApp: return "R (running)" case TaskGoroutineBlockedInterruptible: return "S (sleeping)" diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index c2b9fc08f..8ca61ed48 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -86,7 +86,7 @@ var defaultActions = map[linux.Signal]SignalAction{ } // computeAction figures out what to do given a signal number -// and an arch.SignalAct. SIGSTOP always results in a SignalActionStop, +// and an linux.SigAction. SIGSTOP always results in a SignalActionStop, // and SIGKILL always results in a SignalActionTerm. // Signal 0 is always ignored as many programs use it for various internal functions // and don't expect it to do anything. @@ -97,7 +97,7 @@ var defaultActions = map[linux.Signal]SignalAction{ // 0, the default action is taken; // 1, the signal is ignored; // anything else, the function returns SignalActionHandler. -func computeAction(sig linux.Signal, act arch.SignalAct) SignalAction { +func computeAction(sig linux.Signal, act linux.SigAction) SignalAction { switch sig { case linux.SIGSTOP: return SignalActionStop @@ -108,9 +108,9 @@ func computeAction(sig linux.Signal, act arch.SignalAct) SignalAction { } switch act.Handler { - case arch.SignalActDefault: + case linux.SIG_DFL: return defaultActions[sig] - case arch.SignalActIgnore: + case linux.SIG_IGN: return SignalActionIgnore default: return SignalActionHandler @@ -127,7 +127,7 @@ var StopSignals = linux.MakeSignalSet(linux.SIGSTOP, linux.SIGTSTP, linux.SIGTTI // If there are no pending unmasked signals, dequeueSignalLocked returns nil. // // Preconditions: t.tg.signalHandlers.mu must be locked. -func (t *Task) dequeueSignalLocked(mask linux.SignalSet) *arch.SignalInfo { +func (t *Task) dequeueSignalLocked(mask linux.SignalSet) *linux.SignalInfo { if info := t.pendingSignals.dequeue(mask); info != nil { return info } @@ -155,7 +155,7 @@ func (t *Task) PendingSignals() linux.SignalSet { } // deliverSignal delivers the given signal and returns the following run state. -func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunState { +func (t *Task) deliverSignal(info *linux.SignalInfo, act linux.SigAction) taskRunState { sigact := computeAction(linux.Signal(info.Signo), act) if t.haveSyscallReturn { @@ -172,7 +172,7 @@ func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunS fallthrough case sre == syserror.ERESTART_RESTARTBLOCK: fallthrough - case (sre == syserror.ERESTARTSYS && !act.IsRestart()): + case (sre == syserror.ERESTARTSYS && act.Flags&linux.SA_RESTART == 0): t.Debugf("Not restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo) t.Arch().SetReturn(uintptr(-ExtractErrno(syserror.EINTR, -1))) default: @@ -236,7 +236,7 @@ func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunS // deliverSignalToHandler changes the task's userspace state to enter the given // user-configured handler for the given signal. -func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct) error { +func (t *Task) deliverSignalToHandler(info *linux.SignalInfo, act linux.SigAction) error { // Signal delivery to an application handler interrupts restartable // sequences. t.rseqInterrupt() @@ -248,8 +248,8 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct) // N.B. This is a *copy* of the alternate stack that the user's signal // handler expects to see in its ucontext (even if it's not in use). alt := t.signalStack - if act.IsOnStack() && alt.IsEnabled() { - alt.SetOnStack() + if act.Flags&linux.SA_ONSTACK != 0 && alt.IsEnabled() { + alt.Flags |= linux.SS_ONSTACK if !alt.Contains(sp) { sp = hostarch.Addr(alt.Top()) } @@ -289,7 +289,7 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct) // Add our signal mask. newMask := t.signalMask | act.Mask - if !act.IsNoDefer() { + if act.Flags&linux.SA_NODEFER == 0 { newMask |= linux.SignalSetOf(linux.Signal(info.Signo)) } t.SetSignalMask(newMask) @@ -326,7 +326,7 @@ func (t *Task) SignalReturn(rt bool) (*SyscallControl, error) { // Preconditions: // * The caller must be running on the task goroutine. // * t.exitState < TaskExitZombie. -func (t *Task) Sigtimedwait(set linux.SignalSet, timeout time.Duration) (*arch.SignalInfo, error) { +func (t *Task) Sigtimedwait(set linux.SignalSet, timeout time.Duration) (*linux.SignalInfo, error) { // set is the set of signals we're interested in; invert it to get the set // of signals to block. mask := ^(set &^ UnblockableSignals) @@ -373,7 +373,7 @@ func (t *Task) Sigtimedwait(set linux.SignalSet, timeout time.Duration) (*arch.S // syserror.EINVAL - The signal is not valid. // syserror.EAGAIN - THe signal is realtime, and cannot be queued. // -func (t *Task) SendSignal(info *arch.SignalInfo) error { +func (t *Task) SendSignal(info *linux.SignalInfo) error { t.tg.pidns.owner.mu.RLock() defer t.tg.pidns.owner.mu.RUnlock() t.tg.signalHandlers.mu.Lock() @@ -382,7 +382,7 @@ func (t *Task) SendSignal(info *arch.SignalInfo) error { } // SendGroupSignal sends the given signal to t's thread group. -func (t *Task) SendGroupSignal(info *arch.SignalInfo) error { +func (t *Task) SendGroupSignal(info *linux.SignalInfo) error { t.tg.pidns.owner.mu.RLock() defer t.tg.pidns.owner.mu.RUnlock() t.tg.signalHandlers.mu.Lock() @@ -392,7 +392,7 @@ func (t *Task) SendGroupSignal(info *arch.SignalInfo) error { // SendSignal sends the given signal to tg, using tg's leader to determine if // the signal is blocked. -func (tg *ThreadGroup) SendSignal(info *arch.SignalInfo) error { +func (tg *ThreadGroup) SendSignal(info *linux.SignalInfo) error { tg.pidns.owner.mu.RLock() defer tg.pidns.owner.mu.RUnlock() tg.signalHandlers.mu.Lock() @@ -400,11 +400,11 @@ func (tg *ThreadGroup) SendSignal(info *arch.SignalInfo) error { return tg.leader.sendSignalLocked(info, true /* group */) } -func (t *Task) sendSignalLocked(info *arch.SignalInfo, group bool) error { +func (t *Task) sendSignalLocked(info *linux.SignalInfo, group bool) error { return t.sendSignalTimerLocked(info, group, nil) } -func (t *Task) sendSignalTimerLocked(info *arch.SignalInfo, group bool, timer *IntervalTimer) error { +func (t *Task) sendSignalTimerLocked(info *linux.SignalInfo, group bool, timer *IntervalTimer) error { if t.exitState == TaskExitDead { return syserror.ESRCH } @@ -572,9 +572,9 @@ func (t *Task) forceSignal(sig linux.Signal, unconditional bool) { func (t *Task) forceSignalLocked(sig linux.Signal, unconditional bool) { blocked := linux.SignalSetOf(sig)&t.signalMask != 0 act := t.tg.signalHandlers.actions[sig] - ignored := act.Handler == arch.SignalActIgnore + ignored := act.Handler == linux.SIG_IGN if blocked || ignored || unconditional { - act.Handler = arch.SignalActDefault + act.Handler = linux.SIG_DFL t.tg.signalHandlers.actions[sig] = act if blocked { t.setSignalMaskLocked(t.signalMask &^ linux.SignalSetOf(sig)) @@ -641,17 +641,17 @@ func (t *Task) SetSavedSignalMask(mask linux.SignalSet) { } // SignalStack returns the task-private signal stack. -func (t *Task) SignalStack() arch.SignalStack { +func (t *Task) SignalStack() linux.SignalStack { t.p.PullFullState(t.MemoryManager().AddressSpace(), t.Arch()) alt := t.signalStack if t.onSignalStack(alt) { - alt.Flags |= arch.SignalStackFlagOnStack + alt.Flags |= linux.SS_ONSTACK } return alt } // onSignalStack returns true if the task is executing on the given signal stack. -func (t *Task) onSignalStack(alt arch.SignalStack) bool { +func (t *Task) onSignalStack(alt linux.SignalStack) bool { sp := hostarch.Addr(t.Arch().Stack()) return alt.Contains(sp) } @@ -661,30 +661,30 @@ func (t *Task) onSignalStack(alt arch.SignalStack) bool { // This value may not be changed if the task is currently executing on the // signal stack, i.e. if t.onSignalStack returns true. In this case, this // function will return false. Otherwise, true is returned. -func (t *Task) SetSignalStack(alt arch.SignalStack) bool { +func (t *Task) SetSignalStack(alt linux.SignalStack) bool { // Check that we're not executing on the stack. if t.onSignalStack(t.signalStack) { return false } - if alt.Flags&arch.SignalStackFlagDisable != 0 { + if alt.Flags&linux.SS_DISABLE != 0 { // Don't record anything beyond the flags. - t.signalStack = arch.SignalStack{ - Flags: arch.SignalStackFlagDisable, + t.signalStack = linux.SignalStack{ + Flags: linux.SS_DISABLE, } } else { // Mask out irrelevant parts: only disable matters. - alt.Flags &= arch.SignalStackFlagDisable + alt.Flags &= linux.SS_DISABLE t.signalStack = alt } return true } -// SetSignalAct atomically sets the thread group's signal action for signal sig +// SetSigAction atomically sets the thread group's signal action for signal sig // to *actptr (if actptr is not nil) and returns the old signal action. -func (tg *ThreadGroup) SetSignalAct(sig linux.Signal, actptr *arch.SignalAct) (arch.SignalAct, error) { +func (tg *ThreadGroup) SetSigAction(sig linux.Signal, actptr *linux.SigAction) (linux.SigAction, error) { if !sig.IsValid() { - return arch.SignalAct{}, syserror.EINVAL + return linux.SigAction{}, syserror.EINVAL } tg.pidns.owner.mu.RLock() @@ -718,48 +718,6 @@ func (tg *ThreadGroup) SetSignalAct(sig linux.Signal, actptr *arch.SignalAct) (a return oldact, nil } -// CopyOutSignalAct converts the given SignalAct into an architecture-specific -// type and then copies it out to task memory. -func (t *Task) CopyOutSignalAct(addr hostarch.Addr, s *arch.SignalAct) error { - n := t.Arch().NewSignalAct() - n.SerializeFrom(s) - _, err := n.CopyOut(t, addr) - return err -} - -// CopyInSignalAct copies an architecture-specific sigaction type from task -// memory and then converts it into a SignalAct. -func (t *Task) CopyInSignalAct(addr hostarch.Addr) (arch.SignalAct, error) { - n := t.Arch().NewSignalAct() - var s arch.SignalAct - if _, err := n.CopyIn(t, addr); err != nil { - return s, err - } - n.DeserializeTo(&s) - return s, nil -} - -// CopyOutSignalStack converts the given SignalStack into an -// architecture-specific type and then copies it out to task memory. -func (t *Task) CopyOutSignalStack(addr hostarch.Addr, s *arch.SignalStack) error { - n := t.Arch().NewSignalStack() - n.SerializeFrom(s) - _, err := n.CopyOut(t, addr) - return err -} - -// CopyInSignalStack copies an architecture-specific stack_t from task memory -// and then converts it into a SignalStack. -func (t *Task) CopyInSignalStack(addr hostarch.Addr) (arch.SignalStack, error) { - n := t.Arch().NewSignalStack() - var s arch.SignalStack - if _, err := n.CopyIn(t, addr); err != nil { - return s, err - } - n.DeserializeTo(&s) - return s, nil -} - // groupStop is a TaskStop placed on tasks that have received a stop signal // (SIGSTOP, SIGTSTP, SIGTTIN, SIGTTOU). (The term "group-stop" originates from // the ptrace man page.) @@ -774,7 +732,7 @@ func (*groupStop) Killable() bool { return true } // previously-dequeued stop signal. // // Preconditions: The caller must be running on the task goroutine. -func (t *Task) initiateGroupStop(info *arch.SignalInfo) { +func (t *Task) initiateGroupStop(info *linux.SignalInfo) { t.tg.pidns.owner.mu.RLock() defer t.tg.pidns.owner.mu.RUnlock() t.tg.signalHandlers.mu.Lock() @@ -909,8 +867,8 @@ func (t *Task) signalStop(target *Task, code int32, status int32) { t.tg.signalHandlers.mu.Lock() defer t.tg.signalHandlers.mu.Unlock() act, ok := t.tg.signalHandlers.actions[linux.SIGCHLD] - if !ok || (act.Handler != arch.SignalActIgnore && act.Flags&arch.SignalFlagNoCldStop == 0) { - sigchld := &arch.SignalInfo{ + if !ok || (act.Handler != linux.SIG_IGN && act.Flags&linux.SA_NOCLDSTOP == 0) { + sigchld := &linux.SignalInfo{ Signo: int32(linux.SIGCHLD), Code: code, } @@ -955,14 +913,14 @@ func (*runInterrupt) execute(t *Task) taskRunState { // notified its tracer accordingly. But it's consistent with // Linux... if intr { - tracer.signalStop(t.tg.leader, arch.CLD_STOPPED, int32(sig)) + tracer.signalStop(t.tg.leader, linux.CLD_STOPPED, int32(sig)) if !notifyParent { tracer.tg.eventQueue.Notify(EventGroupContinue | EventTraceeStop | EventChildGroupStop) } else { tracer.tg.eventQueue.Notify(EventGroupContinue | EventTraceeStop) } } else { - tracer.signalStop(t.tg.leader, arch.CLD_CONTINUED, int32(sig)) + tracer.signalStop(t.tg.leader, linux.CLD_CONTINUED, int32(sig)) tracer.tg.eventQueue.Notify(EventGroupContinue) } } @@ -974,10 +932,10 @@ func (*runInterrupt) execute(t *Task) taskRunState { // SIGCHLD is a standard signal, so the latter would always be // dropped. Hence sending only the former is equivalent. if intr { - t.tg.leader.parent.signalStop(t.tg.leader, arch.CLD_STOPPED, int32(sig)) + t.tg.leader.parent.signalStop(t.tg.leader, linux.CLD_STOPPED, int32(sig)) t.tg.leader.parent.tg.eventQueue.Notify(EventGroupContinue | EventChildGroupStop) } else { - t.tg.leader.parent.signalStop(t.tg.leader, arch.CLD_CONTINUED, int32(sig)) + t.tg.leader.parent.signalStop(t.tg.leader, linux.CLD_CONTINUED, int32(sig)) t.tg.leader.parent.tg.eventQueue.Notify(EventGroupContinue) } } @@ -1018,7 +976,7 @@ func (*runInterrupt) execute(t *Task) taskRunState { // without requiring an extra PTRACE_GETSIGINFO call." - // "Group-stop", ptrace(2) t.ptraceCode = int32(sig) | linux.PTRACE_EVENT_STOP<<8 - t.ptraceSiginfo = &arch.SignalInfo{ + t.ptraceSiginfo = &linux.SignalInfo{ Signo: int32(sig), Code: t.ptraceCode, } @@ -1029,7 +987,7 @@ func (*runInterrupt) execute(t *Task) taskRunState { t.ptraceSiginfo = nil } if t.beginPtraceStopLocked() { - tracer.signalStop(t, arch.CLD_STOPPED, int32(sig)) + tracer.signalStop(t, linux.CLD_STOPPED, int32(sig)) // For consistency with Linux, if the parent and tracer are in the // same thread group, deduplicate notification signals. if notifyParent && tracer.tg == t.tg.leader.parent.tg { @@ -1047,7 +1005,7 @@ func (*runInterrupt) execute(t *Task) taskRunState { t.tg.signalHandlers.mu.Unlock() } if notifyParent { - t.tg.leader.parent.signalStop(t.tg.leader, arch.CLD_STOPPED, int32(sig)) + t.tg.leader.parent.signalStop(t.tg.leader, linux.CLD_STOPPED, int32(sig)) t.tg.leader.parent.tg.eventQueue.Notify(EventChildGroupStop) } t.tg.pidns.owner.mu.RUnlock() @@ -1101,7 +1059,7 @@ func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState { if sig != linux.Signal(info.Signo) { info.Signo = int32(sig) info.Errno = 0 - info.Code = arch.SignalInfoUser + info.Code = linux.SI_USER // pid isn't a valid field for all signal numbers, but Linux // doesn't care (kernel/signal.c:ptrace_signal()). // diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index 32031cd70..41fd2d471 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -18,7 +18,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" @@ -131,7 +130,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { runState: (*runApp)(nil), interruptChan: make(chan struct{}, 1), signalMask: cfg.SignalMask, - signalStack: arch.SignalStack{Flags: arch.SignalStackFlagDisable}, + signalStack: linux.SignalStack{Flags: linux.SS_DISABLE}, image: *image, fsContext: cfg.FSContext, fdTable: cfg.FDTable, diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index b92e98fa1..4566e4c7c 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -279,7 +278,7 @@ func (k *Kernel) NewThreadGroup(mntns *fs.MountNamespace, pidns *PIDNamespace, s limits: limits, mounts: mntns, } - tg.itimerRealTimer = ktime.NewTimer(k.monotonicClock, &itimerRealListener{tg: tg}) + tg.itimerRealTimer = ktime.NewTimer(k.timekeeper.monotonicClock, &itimerRealListener{tg: tg}) tg.timers = make(map[linux.TimerID]*IntervalTimer) tg.oldRSeqCritical.Store(&OldRSeqCriticalRegion{}) return tg @@ -446,10 +445,10 @@ func (tg *ThreadGroup) ReleaseControllingTTY(tty *TTY) error { othertg.signalHandlers.mu.Lock() othertg.tty = nil if othertg.processGroup == tg.processGroup.session.foreground { - if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGHUP)}, true /* group */); err != nil { + if err := othertg.leader.sendSignalLocked(&linux.SignalInfo{Signo: int32(linux.SIGHUP)}, true /* group */); err != nil { lastErr = err } - if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGCONT)}, true /* group */); err != nil { + if err := othertg.leader.sendSignalLocked(&linux.SignalInfo{Signo: int32(linux.SIGCONT)}, true /* group */); err != nil { lastErr = err } } @@ -490,10 +489,10 @@ func (tg *ThreadGroup) SetForegroundProcessGroup(tty *TTY, pgid ProcessGroupID) tg.signalHandlers.mu.Lock() defer tg.signalHandlers.mu.Unlock() - // TODO(b/129283598): "If tcsetpgrp() is called by a member of a - // background process group in its session, and the calling process is - // not blocking or ignoring SIGTTOU, a SIGTTOU signal is sent to all - // members of this background process group." + // TODO(gvisor.dev/issue/6148): "If tcsetpgrp() is called by a member of a + // background process group in its session, and the calling process is not + // blocking or ignoring SIGTTOU, a SIGTTOU signal is sent to all members of + // this background process group." // tty must be the controlling terminal. if tg.tty != tty { diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go index 7c4fefb16..6255bae7a 100644 --- a/pkg/sentry/kernel/timekeeper.go +++ b/pkg/sentry/kernel/timekeeper.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" sentrytime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" ) // Timekeeper manages all of the kernel clocks. @@ -39,6 +40,12 @@ type Timekeeper struct { // It is set only once, by SetClocks. clocks sentrytime.Clocks `state:"nosave"` + // realtimeClock is a ktime.Clock based on timekeeper's Realtime. + realtimeClock *timekeeperClock + + // monotonicClock is a ktime.Clock based on timekeeper's Monotonic. + monotonicClock *timekeeperClock + // bootTime is the realtime when the system "booted". i.e., when // SetClocks was called in the initial (not restored) run. bootTime ktime.Time @@ -90,10 +97,13 @@ type Timekeeper struct { // NewTimekeeper does not take ownership of paramPage. // // SetClocks must be called on the returned Timekeeper before it is usable. -func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage memmap.FileRange) (*Timekeeper, error) { - return &Timekeeper{ +func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage memmap.FileRange) *Timekeeper { + t := Timekeeper{ params: NewVDSOParamPage(mfp, paramPage), - }, nil + } + t.realtimeClock = &timekeeperClock{tk: &t, c: sentrytime.Realtime} + t.monotonicClock = &timekeeperClock{tk: &t, c: sentrytime.Monotonic} + return &t } // SetClocks the backing clock source. @@ -167,6 +177,32 @@ func (t *Timekeeper) SetClocks(c sentrytime.Clocks) { } } +var _ tcpip.Clock = (*Timekeeper)(nil) + +// Now implements tcpip.Clock. +func (t *Timekeeper) Now() time.Time { + nsec, err := t.GetTime(sentrytime.Realtime) + if err != nil { + panic("timekeeper.GetTime(sentrytime.Realtime): " + err.Error()) + } + return time.Unix(0, nsec) +} + +// NowMonotonic implements tcpip.Clock. +func (t *Timekeeper) NowMonotonic() tcpip.MonotonicTime { + nsec, err := t.GetTime(sentrytime.Monotonic) + if err != nil { + panic("timekeeper.GetTime(sentrytime.Monotonic): " + err.Error()) + } + var mt tcpip.MonotonicTime + return mt.Add(time.Duration(nsec) * time.Nanosecond) +} + +// AfterFunc implements tcpip.Clock. +func (t *Timekeeper) AfterFunc(d time.Duration, f func()) tcpip.Timer { + return ktime.TcpipAfterFunc(t.realtimeClock, d, f) +} + // startUpdater starts an update goroutine that keeps the clocks updated. // // mu must be held. diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD index 4c65215fa..80f862628 100644 --- a/pkg/sentry/loader/BUILD +++ b/pkg/sentry/loader/BUILD @@ -17,6 +17,7 @@ go_library( deps = [ "//pkg/abi", "//pkg/abi/linux", + "//pkg/abi/linux/errno", "//pkg/context", "//pkg/cpuid", "//pkg/hostarch", diff --git a/pkg/sentry/loader/interpreter.go b/pkg/sentry/loader/interpreter.go index 3886b4d33..3e302d92c 100644 --- a/pkg/sentry/loader/interpreter.go +++ b/pkg/sentry/loader/interpreter.go @@ -59,7 +59,7 @@ func parseInterpreterScript(ctx context.Context, filename string, f fsbridge.Fil // Linux silently truncates the remainder of the line if it exceeds // interpMaxLineLength. i := bytes.IndexByte(line, '\n') - if i > 0 { + if i >= 0 { line = line[:i] } diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go index 47e3775a3..8240173ae 100644 --- a/pkg/sentry/loader/loader.go +++ b/pkg/sentry/loader/loader.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/abi/linux/errno" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/hostarch" @@ -237,7 +238,7 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V // loaded.end is available for its use. e, ok := loaded.end.RoundUp() if !ok { - return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("brk overflows: %#x", loaded.end), linux.ENOEXEC) + return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("brk overflows: %#x", loaded.end), errno.ENOEXEC) } args.MemoryManager.BrkSetup(ctx, e) diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index b81292c46..d1a883da4 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -1062,10 +1062,20 @@ func (f *MemoryFile) runReclaim() { break } - // If ManualZeroing is in effect, pages will be zeroed on allocation - // and may not be freed by decommitFile, so calling decommitFile is - // unnecessary. - if !f.opts.ManualZeroing { + if f.opts.ManualZeroing { + // If ManualZeroing is in effect, only hugepage-aligned regions may + // be safely passed to decommitFile. Pages will be zeroed on + // reallocation, so we don't need to perform any manual zeroing + // here, whether or not decommitFile succeeds. + if startAddr, ok := hostarch.Addr(fr.Start).HugeRoundUp(); ok { + if endAddr := hostarch.Addr(fr.End).HugeRoundDown(); startAddr < endAddr { + decommitFR := memmap.FileRange{uint64(startAddr), uint64(endAddr)} + if err := f.decommitFile(decommitFR); err != nil { + log.Warningf("Reclaim failed to decommit %v: %v", decommitFR, err) + } + } + } + } else { if err := f.decommitFile(fr); err != nil { log.Warningf("Reclaim failed to decommit %v: %v", fr, err) // Zero the pages manually. This won't reduce memory usage, but at diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 77079cae9..8a490b3de 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -79,6 +79,7 @@ go_test( "requires-kvm", ], deps = [ + "//pkg/abi/linux", "//pkg/hostarch", "//pkg/ring0", "//pkg/ring0/pagetables", diff --git a/pkg/sentry/platform/kvm/context.go b/pkg/sentry/platform/kvm/context.go index f4d4473a8..183e741ea 100644 --- a/pkg/sentry/platform/kvm/context.go +++ b/pkg/sentry/platform/kvm/context.go @@ -17,6 +17,7 @@ package kvm import ( "sync/atomic" + "gvisor.dev/gvisor/pkg/abi/linux" pkgcontext "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" @@ -32,15 +33,15 @@ type context struct { // machine is the parent machine, and is immutable. machine *machine - // info is the arch.SignalInfo cached for this context. - info arch.SignalInfo + // info is the linux.SignalInfo cached for this context. + info linux.SignalInfo // interrupt is the interrupt context. interrupt interrupt.Forwarder } // Switch runs the provided context in the given address space. -func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, _ int32) (*arch.SignalInfo, hostarch.AccessType, error) { +func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, _ int32) (*linux.SignalInfo, hostarch.AccessType, error) { as := mm.AddressSpace() localAS := as.(*addressSpace) diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go index b8dd1e4a5..b1cab89a0 100644 --- a/pkg/sentry/platform/kvm/kvm_amd64_test.go +++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go @@ -19,6 +19,7 @@ package kvm import ( "testing" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -30,7 +31,7 @@ func TestSegments(t *testing.T) { applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { testutil.SetTestSegments(regs) for { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -55,7 +56,7 @@ func stmxcsr(addr *uint32) func TestMXCSR(t *testing.T) { applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - var si arch.SignalInfo + var si linux.SignalInfo switchOpts := ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go index ceff09a60..fe570aff9 100644 --- a/pkg/sentry/platform/kvm/kvm_test.go +++ b/pkg/sentry/platform/kvm/kvm_test.go @@ -22,6 +22,7 @@ import ( "time" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" @@ -157,7 +158,7 @@ func applicationTest(t testHarness, useHostMappings bool, target func(), fn func func TestApplicationSyscall(t *testing.T) { applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -171,7 +172,7 @@ func TestApplicationSyscall(t *testing.T) { return false }) applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -188,7 +189,7 @@ func TestApplicationSyscall(t *testing.T) { func TestApplicationFault(t *testing.T) { applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { testutil.SetTouchTarget(regs, nil) // Cause fault. - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -203,7 +204,7 @@ func TestApplicationFault(t *testing.T) { }) applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { testutil.SetTouchTarget(regs, nil) // Cause fault. - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -221,7 +222,7 @@ func TestRegistersSyscall(t *testing.T) { applicationTest(t, true, testutil.TwiddleRegsSyscall, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { testutil.SetTestRegs(regs) // Fill values for all registers. for { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -244,7 +245,7 @@ func TestRegistersFault(t *testing.T) { applicationTest(t, true, testutil.TwiddleRegsFault, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { testutil.SetTestRegs(regs) // Fill values for all registers. for { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -270,7 +271,7 @@ func TestBounce(t *testing.T) { time.Sleep(time.Millisecond) c.BounceToKernel() }() - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -285,7 +286,7 @@ func TestBounce(t *testing.T) { time.Sleep(time.Millisecond) c.BounceToKernel() }() - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -317,7 +318,7 @@ func TestBounceStress(t *testing.T) { c.BounceToKernel() }() randomSleep() - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -338,7 +339,7 @@ func TestInvalidate(t *testing.T) { applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { testutil.SetTouchTarget(regs, &data) // Read legitimate value. for { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -353,7 +354,7 @@ func TestInvalidate(t *testing.T) { // Unmap the page containing data & invalidate. pt.Unmap(hostarch.Addr(reflect.ValueOf(&data).Pointer() & ^uintptr(hostarch.PageSize-1)), hostarch.PageSize) for { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -371,13 +372,13 @@ func TestInvalidate(t *testing.T) { } // IsFault returns true iff the given signal represents a fault. -func IsFault(err error, si *arch.SignalInfo) bool { +func IsFault(err error, si *linux.SignalInfo) bool { return err == platform.ErrContextSignal && si.Signo == int32(unix.SIGSEGV) } func TestEmptyAddressSpace(t *testing.T) { applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -391,7 +392,7 @@ func TestEmptyAddressSpace(t *testing.T) { return false }) applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -467,7 +468,7 @@ func BenchmarkApplicationSyscall(b *testing.B) { a int // Count for ErrContextInterrupt. ) applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, @@ -504,7 +505,7 @@ func BenchmarkWorldSwitchToUserRoundtrip(b *testing.B) { a int ) applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - var si arch.SignalInfo + var si linux.SignalInfo if _, err := c.SwitchToUser(ring0.SwitchOpts{ Registers: regs, FloatingPointState: &dummyFPState, diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index 9a2337654..7a10fd812 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -23,11 +23,11 @@ import ( "runtime/debug" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" ktime "gvisor.dev/gvisor/pkg/sentry/time" @@ -264,10 +264,10 @@ func (c *vCPU) setSystemTime() error { // nonCanonical generates a canonical address return. // //go:nosplit -func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) { - *info = arch.SignalInfo{ +func nonCanonical(addr uint64, signal int32, info *linux.SignalInfo) (hostarch.AccessType, error) { + *info = linux.SignalInfo{ Signo: signal, - Code: arch.SignalInfoKernel, + Code: linux.SI_KERNEL, } info.SetAddr(addr) // Include address. return hostarch.NoAccess, platform.ErrContextSignal @@ -276,7 +276,7 @@ func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (hostarch.Ac // fault generates an appropriate fault return. // //go:nosplit -func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) { +func (c *vCPU) fault(signal int32, info *linux.SignalInfo) (hostarch.AccessType, error) { bluepill(c) // Probably no-op, but may not be. faultAddr := ring0.ReadCR2() code, user := c.ErrorCode() @@ -287,7 +287,7 @@ func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType, return hostarch.NoAccess, platform.ErrContextInterrupt } // Reset the pointed SignalInfo. - *info = arch.SignalInfo{Signo: signal} + *info = linux.SignalInfo{Signo: signal} info.SetAddr(uint64(faultAddr)) accessType := hostarch.AccessType{ Read: code&(1<<1) == 0, @@ -325,7 +325,7 @@ func prefaultFloatingPointState(data *fpu.State) { } // SwitchToUser unpacks architectural-details. -func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (hostarch.AccessType, error) { +func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) (hostarch.AccessType, error) { // Check for canonical addresses. if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Rip) { return nonCanonical(regs.Rip, int32(unix.SIGSEGV), info) @@ -371,7 +371,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) return c.fault(int32(unix.SIGSEGV), info) case ring0.Debug, ring0.Breakpoint: - *info = arch.SignalInfo{ + *info = linux.SignalInfo{ Signo: int32(unix.SIGTRAP), Code: 1, // TRAP_BRKPT (breakpoint). } @@ -383,9 +383,9 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) ring0.BoundRangeExceeded, ring0.InvalidTSS, ring0.StackSegmentFault: - *info = arch.SignalInfo{ + *info = linux.SignalInfo{ Signo: int32(unix.SIGSEGV), - Code: arch.SignalInfoKernel, + Code: linux.SI_KERNEL, } info.SetAddr(switchOpts.Registers.Rip) // Include address. if vector == ring0.GeneralProtectionFault { @@ -397,7 +397,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) return hostarch.AccessType{}, platform.ErrContextSignal case ring0.InvalidOpcode: - *info = arch.SignalInfo{ + *info = linux.SignalInfo{ Signo: int32(unix.SIGILL), Code: 1, // ILL_ILLOPC (illegal opcode). } @@ -405,7 +405,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) return hostarch.AccessType{}, platform.ErrContextSignal case ring0.DivideByZero: - *info = arch.SignalInfo{ + *info = linux.SignalInfo{ Signo: int32(unix.SIGFPE), Code: 1, // FPE_INTDIV (divide by zero). } @@ -413,7 +413,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) return hostarch.AccessType{}, platform.ErrContextSignal case ring0.Overflow: - *info = arch.SignalInfo{ + *info = linux.SignalInfo{ Signo: int32(unix.SIGFPE), Code: 2, // FPE_INTOVF (integer overflow). } @@ -422,7 +422,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) case ring0.X87FloatingPointException, ring0.SIMDFloatingPointException: - *info = arch.SignalInfo{ + *info = linux.SignalInfo{ Signo: int32(unix.SIGFPE), Code: 7, // FPE_FLTINV (invalid operation). } @@ -433,7 +433,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) return hostarch.NoAccess, platform.ErrContextInterrupt case ring0.AlignmentCheck: - *info = arch.SignalInfo{ + *info = linux.SignalInfo{ Signo: int32(unix.SIGBUS), Code: 2, // BUS_ADRERR (physical address does not exist). } @@ -469,7 +469,7 @@ func availableRegionsForSetMem() (phyRegions []physicalRegion) { } func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { - // Map all the executible regions so that all the entry functions + // Map all the executable regions so that all the entry functions // are mapped in the upper half. applyVirtualRegions(func(vr virtualRegion) { if excludeVirtualRegion(vr) || vr.filename == "[vsyscall]" { @@ -485,7 +485,7 @@ func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { pageTable.Map( hostarch.Addr(ring0.KernelStartAddress|r.virtual), r.length, - pagetables.MapOpts{AccessType: hostarch.Execute}, + pagetables.MapOpts{AccessType: hostarch.Execute, Global: true}, physical) } }) @@ -498,7 +498,7 @@ func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { pageTable.Map( hostarch.Addr(ring0.KernelStartAddress|start), regionLen, - pagetables.MapOpts{AccessType: hostarch.ReadWrite}, + pagetables.MapOpts{AccessType: hostarch.ReadWrite, Global: true}, physical) } } diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index 8926b1d9f..edaccf9bc 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -21,10 +21,10 @@ import ( "sync/atomic" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" ) @@ -126,10 +126,10 @@ func availableRegionsForSetMem() (phyRegions []physicalRegion) { // nonCanonical generates a canonical address return. // //go:nosplit -func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) { - *info = arch.SignalInfo{ +func nonCanonical(addr uint64, signal int32, info *linux.SignalInfo) (hostarch.AccessType, error) { + *info = linux.SignalInfo{ Signo: signal, - Code: arch.SignalInfoKernel, + Code: linux.SI_KERNEL, } info.SetAddr(addr) // Include address. return hostarch.NoAccess, platform.ErrContextSignal @@ -157,7 +157,7 @@ func isWriteFault(code uint64) bool { // fault generates an appropriate fault return. // //go:nosplit -func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) { +func (c *vCPU) fault(signal int32, info *linux.SignalInfo) (hostarch.AccessType, error) { bluepill(c) // Probably no-op, but may not be. faultAddr := c.GetFaultAddr() code, user := c.ErrorCode() @@ -170,7 +170,7 @@ func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType, } // Reset the pointed SignalInfo. - *info = arch.SignalInfo{Signo: signal} + *info = linux.SignalInfo{Signo: signal} info.SetAddr(uint64(faultAddr)) ret := code & _ESR_ELx_FSC diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index d2a6d81bc..f6aa519b1 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -23,10 +23,10 @@ import ( "unsafe" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" ktime "gvisor.dev/gvisor/pkg/sentry/time" @@ -265,7 +265,7 @@ func (c *vCPU) getOneRegister(reg *kvmOneReg) error { } // SwitchToUser unpacks architectural-details. -func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (hostarch.AccessType, error) { +func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) (hostarch.AccessType, error) { // Check for canonical addresses. if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Pc) { return nonCanonical(regs.Pc, int32(unix.SIGSEGV), info) @@ -312,14 +312,14 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) case ring0.El0SyncUndef: return c.fault(int32(unix.SIGILL), info) case ring0.El0SyncDbg: - *info = arch.SignalInfo{ + *info = linux.SignalInfo{ Signo: int32(unix.SIGTRAP), Code: 1, // TRAP_BRKPT (breakpoint). } info.SetAddr(switchOpts.Registers.Pc) // Include address. return hostarch.AccessType{}, platform.ErrContextSignal case ring0.El0SyncSpPc: - *info = arch.SignalInfo{ + *info = linux.SignalInfo{ Signo: int32(unix.SIGBUS), Code: 2, // BUS_ADRERR (physical address does not exist). } diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go index ef7814a6f..a26bc2316 100644 --- a/pkg/sentry/platform/platform.go +++ b/pkg/sentry/platform/platform.go @@ -195,8 +195,8 @@ type Context interface { // - nil: The Context invoked a system call. // // - ErrContextSignal: The Context was interrupted by a signal. The - // returned *arch.SignalInfo contains information about the signal. If - // arch.SignalInfo.Signo == SIGSEGV, the returned hostarch.AccessType + // returned *linux.SignalInfo contains information about the signal. If + // linux.SignalInfo.Signo == SIGSEGV, the returned hostarch.AccessType // contains the access type of the triggering fault. The caller owns // the returned SignalInfo. // @@ -207,7 +207,7 @@ type Context interface { // concurrent call to Switch(). // // - ErrContextCPUPreempted: See the definition of that error for details. - Switch(ctx context.Context, mm MemoryManager, ac arch.Context, cpu int32) (*arch.SignalInfo, hostarch.AccessType, error) + Switch(ctx context.Context, mm MemoryManager, ac arch.Context, cpu int32) (*linux.SignalInfo, hostarch.AccessType, error) // PullFullState() pulls a full state of the application thread. // diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go index 828458ce2..319b0cf1d 100644 --- a/pkg/sentry/platform/ptrace/ptrace.go +++ b/pkg/sentry/platform/ptrace/ptrace.go @@ -73,7 +73,7 @@ var ( type context struct { // signalInfo is the signal info, if and when a signal is received. - signalInfo arch.SignalInfo + signalInfo linux.SignalInfo // interrupt is the interrupt context. interrupt interrupt.Forwarder @@ -96,7 +96,7 @@ type context struct { } // Switch runs the provided context in the given address space. -func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, cpu int32) (*arch.SignalInfo, hostarch.AccessType, error) { +func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, cpu int32) (*linux.SignalInfo, hostarch.AccessType, error) { as := mm.AddressSpace() s := as.(*subprocess) isSyscall := s.switchToApp(c, ac) diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go index facb96011..cc93396a9 100644 --- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go +++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go @@ -101,7 +101,7 @@ func (t *thread) setFPRegs(fpState *fpu.State, fpLen uint64, useXsave bool) erro } // getSignalInfo retrieves information about the signal that caused the stop. -func (t *thread) getSignalInfo(si *arch.SignalInfo) error { +func (t *thread) getSignalInfo(si *linux.SignalInfo) error { _, _, errno := unix.RawSyscall6( unix.SYS_PTRACE, unix.PTRACE_GETSIGINFO, diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 9c73a725a..0931795c5 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -20,6 +20,7 @@ import ( "runtime" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/procid" @@ -524,7 +525,7 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool { // Check for interrupts, and ensure that future interrupts will signal t. if !c.interrupt.Enable(t) { // Pending interrupt; simulate. - c.signalInfo = arch.SignalInfo{Signo: int32(platform.SignalInterrupt)} + c.signalInfo = linux.SignalInfo{Signo: int32(platform.SignalInterrupt)} return false } defer c.interrupt.Disable() diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go index 9252c0bd7..90b1ead56 100644 --- a/pkg/sentry/platform/ptrace/subprocess_amd64.go +++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go @@ -155,7 +155,7 @@ func initChildProcessPPID(initregs *arch.Registers, ppid int32) { // // Note that this should only be called after verifying that the signalInfo has // been generated by the kernel. -func patchSignalInfo(regs *arch.Registers, signalInfo *arch.SignalInfo) { +func patchSignalInfo(regs *arch.Registers, signalInfo *linux.SignalInfo) { if linux.Signal(signalInfo.Signo) == linux.SIGSYS { signalInfo.Signo = int32(linux.SIGSEGV) diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go index c0cbc0686..e4257e3bf 100644 --- a/pkg/sentry/platform/ptrace/subprocess_arm64.go +++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go @@ -138,7 +138,7 @@ func initChildProcessPPID(initregs *arch.Registers, ppid int32) { // // Note that this should only be called after verifying that the signalInfo has // been generated by the kernel. -func patchSignalInfo(regs *arch.Registers, signalInfo *arch.SignalInfo) { +func patchSignalInfo(regs *arch.Registers, signalInfo *linux.SignalInfo) { if linux.Signal(signalInfo.Signo) == linux.SIGSYS { signalInfo.Signo = int32(linux.SIGSEGV) diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sentry/sighandling/sighandling_unsafe.go index d6a2fbe34..3fe5c6770 100644 --- a/pkg/sentry/sighandling/sighandling_unsafe.go +++ b/pkg/sentry/sighandling/sighandling_unsafe.go @@ -21,25 +21,16 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" ) -// FIXME(gvisor.dev/issue/214): Move to pkg/abi/linux along with definitions in -// pkg/sentry/arch. -type sigaction struct { - handler uintptr - flags uint64 - restorer uintptr - mask uint64 -} - // IgnoreChildStop sets the SA_NOCLDSTOP flag, causing child processes to not // generate SIGCHLD when they stop. func IgnoreChildStop() error { - var sa sigaction + var sa linux.SigAction // Get the existing signal handler information, and set the flag. if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(unix.SIGCHLD), 0, uintptr(unsafe.Pointer(&sa)), linux.SignalSetSize, 0, 0); e != 0 { return e } - sa.flags |= linux.SA_NOCLDSTOP + sa.Flags |= linux.SA_NOCLDSTOP if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(unix.SIGCHLD), uintptr(unsafe.Pointer(&sa)), 0, linux.SignalSetSize, 0, 0); e != 0 { return e } diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 52ae4bc9c..b9473da6c 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -67,23 +67,6 @@ type socketOperations struct { socketOpsCommon } -// socketOpsCommon contains the socket operations common to VFS1 and VFS2. -// -// +stateify savable -type socketOpsCommon struct { - socket.SendReceiveTimeout - - family int // Read-only. - stype linux.SockType // Read-only. - protocol int // Read-only. - queue waiter.Queue - - // fd is the host socket fd. It must have O_NONBLOCK, so that operations - // will return EWOULDBLOCK instead of blocking on the host. This allows us to - // handle blocking behavior independently in the sentry. - fd int -} - var _ = socket.Socket(&socketOperations{}) func newSocketFile(ctx context.Context, family int, stype linux.SockType, protocol int, fd int, nonblock bool) (*fs.File, *syserr.Error) { @@ -103,29 +86,6 @@ func newSocketFile(ctx context.Context, family int, stype linux.SockType, protoc return fs.NewFile(ctx, dirent, fs.FileFlags{NonBlocking: nonblock, Read: true, Write: true, NonSeekable: true}, s), nil } -// Release implements fs.FileOperations.Release. -func (s *socketOpsCommon) Release(context.Context) { - fdnotifier.RemoveFD(int32(s.fd)) - unix.Close(s.fd) -} - -// Readiness implements waiter.Waitable.Readiness. -func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { - return fdnotifier.NonBlockingPoll(int32(s.fd), mask) -} - -// EventRegister implements waiter.Waitable.EventRegister. -func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) { - s.queue.EventRegister(e, mask) - fdnotifier.UpdateFD(int32(s.fd)) -} - -// EventUnregister implements waiter.Waitable.EventUnregister. -func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { - s.queue.EventUnregister(e) - fdnotifier.UpdateFD(int32(s.fd)) -} - // Ioctl implements fs.FileOperations.Ioctl. func (s *socketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { return ioctl(ctx, s.fd, io, args) @@ -177,6 +137,96 @@ func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO return int64(n), err } +// Socket implements socket.Provider.Socket. +func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) { + // Check that we are using the host network stack. + stack := t.NetworkContext() + if stack == nil { + return nil, nil + } + if _, ok := stack.(*Stack); !ok { + return nil, nil + } + + // Only accept TCP and UDP. + stype := stypeflags & linux.SOCK_TYPE_MASK + switch stype { + case unix.SOCK_STREAM: + switch protocol { + case 0, unix.IPPROTO_TCP: + // ok + default: + return nil, nil + } + case unix.SOCK_DGRAM: + switch protocol { + case 0, unix.IPPROTO_UDP: + // ok + default: + return nil, nil + } + default: + return nil, nil + } + + // Conservatively ignore all flags specified by the application and add + // SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0 + // to simplify the syscall filters, since 0 and IPPROTO_* are equivalent. + fd, err := unix.Socket(p.family, int(stype)|unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC, 0) + if err != nil { + return nil, syserr.FromError(err) + } + return newSocketFile(t, p.family, stype, protocol, fd, stypeflags&unix.SOCK_NONBLOCK != 0) +} + +// Pair implements socket.Provider.Pair. +func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { + // Not supported by AF_INET/AF_INET6. + return nil, nil, nil +} + +// LINT.ThenChange(./socket_vfs2.go) + +// socketOpsCommon contains the socket operations common to VFS1 and VFS2. +// +// +stateify savable +type socketOpsCommon struct { + socket.SendReceiveTimeout + + family int // Read-only. + stype linux.SockType // Read-only. + protocol int // Read-only. + queue waiter.Queue + + // fd is the host socket fd. It must have O_NONBLOCK, so that operations + // will return EWOULDBLOCK instead of blocking on the host. This allows us to + // handle blocking behavior independently in the sentry. + fd int +} + +// Release implements fs.FileOperations.Release. +func (s *socketOpsCommon) Release(context.Context) { + fdnotifier.RemoveFD(int32(s.fd)) + unix.Close(s.fd) +} + +// Readiness implements waiter.Waitable.Readiness. +func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { + return fdnotifier.NonBlockingPoll(int32(s.fd), mask) +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + s.queue.EventRegister(e, mask) + fdnotifier.UpdateFD(int32(s.fd)) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { + s.queue.EventUnregister(e) + fdnotifier.UpdateFD(int32(s.fd)) +} + // Connect implements socket.Socket.Connect. func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { if len(sockaddr) > sizeofSockaddr { @@ -596,6 +646,17 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b return 0, syserr.ErrInvalidArgument } + // If the src is zero-length, call SENDTO directly with a null buffer in + // order to generate poll/epoll notifications. + if src.NumBytes() == 0 { + sysflags := flags | unix.MSG_DONTWAIT + n, _, errno := unix.Syscall6(unix.SYS_SENDTO, uintptr(s.fd), 0, 0, uintptr(sysflags), uintptr(firstBytePtr(to)), uintptr(len(to))) + if errno != 0 { + return 0, syserr.FromError(errno) + } + return int(n), nil + } + space := uint64(control.CmsgsSpace(t, controlMessages)) if space > maxControlLen { space = maxControlLen @@ -709,56 +770,6 @@ type socketProvider struct { family int } -// Socket implements socket.Provider.Socket. -func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) { - // Check that we are using the host network stack. - stack := t.NetworkContext() - if stack == nil { - return nil, nil - } - if _, ok := stack.(*Stack); !ok { - return nil, nil - } - - // Only accept TCP and UDP. - stype := stypeflags & linux.SOCK_TYPE_MASK - switch stype { - case unix.SOCK_STREAM: - switch protocol { - case 0, unix.IPPROTO_TCP: - // ok - default: - return nil, nil - } - case unix.SOCK_DGRAM: - switch protocol { - case 0, unix.IPPROTO_UDP: - // ok - default: - return nil, nil - } - default: - return nil, nil - } - - // Conservatively ignore all flags specified by the application and add - // SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0 - // to simplify the syscall filters, since 0 and IPPROTO_* are equivalent. - fd, err := unix.Socket(p.family, int(stype)|unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC, 0) - if err != nil { - return nil, syserr.FromError(err) - } - return newSocketFile(t, p.family, stype, protocol, fd, stypeflags&unix.SOCK_NONBLOCK != 0) -} - -// Pair implements socket.Provider.Pair. -func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { - // Not supported by AF_INET/AF_INET6. - return nil, nil, nil -} - -// LINT.ThenChange(./socket_vfs2.go) - func init() { for _, family := range []int{unix.AF_INET, unix.AF_INET6} { socket.RegisterProvider(family, &socketProvider{family}) diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index 61b2c9755..608474fa1 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -25,6 +25,7 @@ go_library( "//pkg/log", "//pkg/marshal", "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", "//pkg/syserr", "//pkg/tcpip", "//pkg/tcpip/header", diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index 6fc7781ad..3f1b4a17b 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -19,20 +19,12 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) -// TODO(gvisor.dev/issue/170): The following per-matcher params should be -// supported: -// - Table name -// - Match size -// - User size -// - Hooks -// - Proto -// - Family - // matchMaker knows how to (un)marshal the matcher named name(). type matchMaker interface { // name is the matcher name as stored in the xt_entry_match struct. @@ -43,7 +35,7 @@ type matchMaker interface { // unmarshal converts from the ABI matcher struct to an // stack.Matcher. - unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) + unmarshal(task *kernel.Task, buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) } type matcher interface { @@ -94,12 +86,12 @@ func marshalEntryMatch(name string, data []byte) []byte { return buf } -func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) { +func unmarshalMatcher(task *kernel.Task, match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) { matchMaker, ok := matchMakers[match.Name.String()] if !ok { return nil, fmt.Errorf("unsupported matcher with name %q", match.Name.String()) } - return matchMaker.unmarshal(buf, filter) + return matchMaker.unmarshal(task, buf, filter) } // targetMaker knows how to (un)marshal a target. Once registered, diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index cb78ef60b..d8bd86292 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -123,7 +124,7 @@ func getEntries4(table stack.Table, tablename linux.TableName) (linux.KernelIPTG return entries, info } -func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) { +func modifyEntries4(task *kernel.Task, stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) { nflog("set entries: setting entries in table %q", replace.Name.String()) // Convert input into a list of rules and their offsets. @@ -148,23 +149,19 @@ func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, return nil, syserr.ErrInvalidArgument } - // TODO(gvisor.dev/issue/170): We should support more IPTIP - // filtering fields. filter, err := filterFromIPTIP(entry.IP) if err != nil { nflog("bad iptip: %v", err) return nil, syserr.ErrInvalidArgument } - // TODO(gvisor.dev/issue/170): Matchers and targets can specify - // that they only work for certain protocols, hooks, tables. // Get matchers. matchersSize := entry.TargetOffset - linux.SizeOfIPTEntry if len(optVal) < int(matchersSize) { nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal)) return nil, syserr.ErrInvalidArgument } - matchers, err := parseMatchers(filter, optVal[:matchersSize]) + matchers, err := parseMatchers(task, filter, optVal[:matchersSize]) if err != nil { nflog("failed to parse matchers: %v", err) return nil, syserr.ErrInvalidArgument diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go index 5cb7fe4aa..c68230847 100644 --- a/pkg/sentry/socket/netfilter/ipv6.go +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -126,7 +127,7 @@ func getEntries6(table stack.Table, tablename linux.TableName) (linux.KernelIP6T return entries, info } -func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) { +func modifyEntries6(task *kernel.Task, stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) { nflog("set entries: setting entries in table %q", replace.Name.String()) // Convert input into a list of rules and their offsets. @@ -151,23 +152,19 @@ func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, return nil, syserr.ErrInvalidArgument } - // TODO(gvisor.dev/issue/170): We should support more IPTIP - // filtering fields. filter, err := filterFromIP6TIP(entry.IPv6) if err != nil { nflog("bad iptip: %v", err) return nil, syserr.ErrInvalidArgument } - // TODO(gvisor.dev/issue/170): Matchers and targets can specify - // that they only work for certain protocols, hooks, tables. // Get matchers. matchersSize := entry.TargetOffset - linux.SizeOfIP6TEntry if len(optVal) < int(matchersSize) { nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal)) return nil, syserr.ErrInvalidArgument } - matchers, err := parseMatchers(filter, optVal[:matchersSize]) + matchers, err := parseMatchers(task, filter, optVal[:matchersSize]) if err != nil { nflog("failed to parse matchers: %v", err) return nil, syserr.ErrInvalidArgument diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index e1c4b06fc..e3eade180 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -174,13 +174,12 @@ func setHooksAndUnderflow(info *linux.IPTGetinfo, table stack.Table, offset uint // SetEntries sets iptables rules for a single table. See // net/ipv4/netfilter/ip_tables.c:translate_table for reference. -func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { +func SetEntries(task *kernel.Task, stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { var replace linux.IPTReplace replaceBuf := optVal[:linux.SizeOfIPTReplace] optVal = optVal[linux.SizeOfIPTReplace:] replace.UnmarshalBytes(replaceBuf) - // TODO(gvisor.dev/issue/170): Support other tables. var table stack.Table switch replace.Name.String() { case filterTable: @@ -188,16 +187,16 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { case natTable: table = stack.EmptyNATTable() default: - nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) + nflog("unknown iptables table %q", replace.Name.String()) return syserr.ErrInvalidArgument } var err *syserr.Error var offsets map[uint32]int if ipv6 { - offsets, err = modifyEntries6(stk, optVal, &replace, &table) + offsets, err = modifyEntries6(task, stk, optVal, &replace, &table) } else { - offsets, err = modifyEntries4(stk, optVal, &replace, &table) + offsets, err = modifyEntries4(task, stk, optVal, &replace, &table) } if err != nil { return err @@ -272,7 +271,6 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { table.Rules[ruleIdx] = rule } - // TODO(gvisor.dev/issue/170): Support other chains. // Since we don't support FORWARD, yet, make sure all other chains point to // ACCEPT rules. for hook, ruleIdx := range table.BuiltinChains { @@ -287,7 +285,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { } } - // TODO(gvisor.dev/issue/170): Check the following conditions: + // TODO(gvisor.dev/issue/6167): Check the following conditions: // - There are no loops. // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. @@ -297,7 +295,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { // parseMatchers parses 0 or more matchers from optVal. optVal should contain // only the matchers. -func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, error) { +func parseMatchers(task *kernel.Task, filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, error) { nflog("set entries: parsing matchers of size %d", len(optVal)) var matchers []stack.Matcher for len(optVal) > 0 { @@ -321,13 +319,13 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, } // Parse the specific matcher. - matcher, err := unmarshalMatcher(match, filter, optVal[linux.SizeOfXTEntryMatch:match.MatchSize]) + matcher, err := unmarshalMatcher(task, match, filter, optVal[linux.SizeOfXTEntryMatch:match.MatchSize]) if err != nil { return nil, fmt.Errorf("failed to create matcher: %v", err) } matchers = append(matchers, matcher) - // TODO(gvisor.dev/issue/170): Check the revision field. + // TODO(gvisor.dev/issue/6167): Check the revision field. optVal = optVal[match.MatchSize:] } diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go index 60845cab3..6eff2ae65 100644 --- a/pkg/sentry/socket/netfilter/owner_matcher.go +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -19,6 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -40,8 +42,8 @@ func (ownerMarshaler) name() string { func (ownerMarshaler) marshal(mr matcher) []byte { matcher := mr.(*OwnerMatcher) iptOwnerInfo := linux.IPTOwnerInfo{ - UID: matcher.uid, - GID: matcher.gid, + UID: uint32(matcher.uid), + GID: uint32(matcher.gid), } // Support for UID and GID match. @@ -63,7 +65,7 @@ func (ownerMarshaler) marshal(mr matcher) []byte { } // unmarshal implements matchMaker.unmarshal. -func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { +func (ownerMarshaler) unmarshal(task *kernel.Task, buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { if len(buf) < linux.SizeOfIPTOwnerInfo { return nil, fmt.Errorf("buf has insufficient size for owner match: %d", len(buf)) } @@ -72,11 +74,12 @@ func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack. // exceed what's strictly necessary to hold matchData. var matchData linux.IPTOwnerInfo matchData.UnmarshalUnsafe(buf[:linux.SizeOfIPTOwnerInfo]) - nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData) + nflog("parsed IPTOwnerInfo: %+v", matchData) var owner OwnerMatcher - owner.uid = matchData.UID - owner.gid = matchData.GID + creds := task.Credentials() + owner.uid = creds.UserNamespace.MapToKUID(auth.UID(matchData.UID)) + owner.gid = creds.UserNamespace.MapToKGID(auth.GID(matchData.GID)) // Check flags. if matchData.Match&linux.XT_OWNER_UID != 0 { @@ -97,8 +100,8 @@ func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack. // OwnerMatcher matches against a UID and/or GID. type OwnerMatcher struct { - uid uint32 - gid uint32 + uid auth.KUID + gid auth.KGID matchUID bool matchGID bool invertUID bool @@ -113,7 +116,6 @@ func (*OwnerMatcher) name() string { // Match implements Matcher.Match. func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) { // Support only for OUTPUT chain. - // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also. if hook != stack.Output { return false, true } @@ -126,7 +128,7 @@ func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ str var matches bool // Check for UID match. if om.matchUID { - if pkt.Owner.UID() == om.uid { + if auth.KUID(pkt.Owner.KUID()) == om.uid { matches = true } if matches == om.invertUID { @@ -137,7 +139,7 @@ func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ str // Check for GID match. if om.matchGID { matches = false - if pkt.Owner.GID() == om.gid { + if auth.KGID(pkt.Owner.KGID()) == om.gid { matches = true } if matches == om.invertGID { diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index fa5456eee..ea56f39c1 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -331,7 +331,6 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return nil, syserr.ErrInvalidArgument } - // TODO(gvisor.dev/issue/170): Check if the flags are valid. // Also check if we need to map ports or IP. // For now, redirect target only supports destination port change. // Port range and IP range are not supported yet. @@ -340,7 +339,6 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return nil, syserr.ErrInvalidArgument } - // TODO(gvisor.dev/issue/170): Port range is not supported yet. if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort { nflog("redirectTargetMaker: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort) return nil, syserr.ErrInvalidArgument @@ -420,7 +418,6 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar return nil, syserr.ErrInvalidArgument } - // TODO(gvisor.dev/issue/3549): Check for other flags. // For now, redirect target only supports destination change. if natRange.Flags != linux.NF_NAT_RANGE_PROTO_SPECIFIED { nflog("nfNATTargetMaker: invalid range flags %d", natRange.Flags) @@ -502,7 +499,6 @@ func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta return nil, syserr.ErrInvalidArgument } - // TODO(gvisor.dev/issue/170): Port range is not supported yet. if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort { nflog("snatTargetMakerV4: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort) return nil, syserr.ErrInvalidArgument @@ -594,7 +590,6 @@ func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta // translateToStandardTarget translates from the value in a // linux.XTStandardTarget to an stack.Verdict. func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (target, *syserr.Error) { - // TODO(gvisor.dev/issue/170): Support other verdicts. switch val { case -linux.NF_ACCEPT - 1: return &acceptTarget{stack.AcceptTarget{ diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 95bb9826e..e5b73a976 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -50,7 +51,7 @@ func (tcpMarshaler) marshal(mr matcher) []byte { } // unmarshal implements matchMaker.unmarshal. -func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { +func (tcpMarshaler) unmarshal(_ *kernel.Task, buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { if len(buf) < linux.SizeOfXTTCP { return nil, fmt.Errorf("buf has insufficient size for TCP match: %d", len(buf)) } @@ -95,8 +96,6 @@ func (*TCPMatcher) name() string { // Match implements Matcher.Match. func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) { - // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved - // into the stack.Check codepath as matchers are added. switch pkt.NetworkProtocolNumber { case header.IPv4ProtocolNumber: netHeader := header.IPv4(pkt.NetworkHeader().View()) diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index fb8be27e6..aa72ee70c 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -50,7 +51,7 @@ func (udpMarshaler) marshal(mr matcher) []byte { } // unmarshal implements matchMaker.unmarshal. -func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { +func (udpMarshaler) unmarshal(_ *kernel.Task, buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { if len(buf) < linux.SizeOfXTUDP { return nil, fmt.Errorf("buf has insufficient size for UDP match: %d", len(buf)) } @@ -92,8 +93,6 @@ func (*UDPMatcher) name() string { // Match implements Matcher.Match. func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) { - // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved - // into the stack.Check codepath as matchers are added. switch pkt.NetworkProtocolNumber { case header.IPv4ProtocolNumber: netHeader := header.IPv4(pkt.NetworkHeader().View()) diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 64cd263da..6b83698ad 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -14,6 +14,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/abi/linux/errno", "//pkg/bits", "//pkg/context", "//pkg/hostarch", diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 280563d09..c9f784cf4 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -20,6 +20,7 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/abi/linux/errno" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" @@ -56,7 +57,7 @@ const ( maxSendBufferSize = 4 << 20 // 4MB ) -var errNoFilter = syserr.New("no filter attached", linux.ENOENT) +var errNoFilter = syserr.New("no filter attached", errno.ENOENT) // netlinkSocketDevice is the netlink socket virtual device. var netlinkSocketDevice = device.NewAnonDevice() diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index 9561b7c25..96c425619 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -19,6 +19,7 @@ go_library( ], deps = [ "//pkg/abi/linux", + "//pkg/abi/linux/errno", "//pkg/context", "//pkg/hostarch", "//pkg/log", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 037ccfec8..66d0fcb47 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -36,6 +36,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/abi/linux/errno" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" @@ -289,7 +290,7 @@ const DefaultTTL = 64 const sizeOfInt32 int = 4 -var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL) +var errStackType = syserr.New("expected but did not receive a netstack.Stack", errno.EINVAL) // commonEndpoint represents the intersection of a tcpip.Endpoint and a // transport.Endpoint. @@ -1798,11 +1799,6 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := hostarch.ByteOrder.Uint32(optVal) - - if v == 0 { - socket.SetSockOptEmitUnimplementedEvent(t, name) - } - ep.SocketOptions().SetOutOfBandInline(v != 0) return nil @@ -2110,10 +2106,10 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return syserr.ErrNoDevice } // Stack must be a netstack stack. - return netfilter.SetEntries(stack.(*Stack).Stack, optVal, true) + return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, true) case linux.IP6T_SO_SET_ADD_COUNTERS: - // TODO(gvisor.dev/issue/170): Counter support. + log.Infof("IP6T_SO_SET_ADD_COUNTERS is not supported") return nil default: @@ -2353,10 +2349,10 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return syserr.ErrNoDevice } // Stack must be a netstack stack. - return netfilter.SetEntries(stack.(*Stack).Stack, optVal, false) + return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, false) case linux.IPT_SO_SET_ADD_COUNTERS: - // TODO(gvisor.dev/issue/170): Counter support. + log.Infof("IPT_SO_SET_ADD_COUNTERS is not supported") return nil case linux.IP_ADD_SOURCE_MEMBERSHIP, diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 353f4ade0..f5da3c509 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -659,7 +659,6 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) return &out, uint32(sockAddrInet6Size) case linux.AF_PACKET: - // TODO(gvisor.dev/issue/173): Return protocol too. var out linux.SockAddrLink out.Family = linux.AF_PACKET out.InterfaceIndex = int32(addr.NIC) @@ -749,7 +748,6 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - // TODO(gvisor.dev/issue/173): Return protocol too. return tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go index 167754537..2f0aba4e2 100644 --- a/pkg/sentry/state/state.go +++ b/pkg/sentry/state/state.go @@ -110,7 +110,7 @@ type LoadOpts struct { } // Load loads the given kernel, setting the provided platform and stack. -func (opts LoadOpts) Load(ctx context.Context, k *kernel.Kernel, n inet.Stack, clocks time.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error { +func (opts LoadOpts) Load(ctx context.Context, k *kernel.Kernel, timeReady chan struct{}, n inet.Stack, clocks time.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error { // Open the file. r, m, err := statefile.NewReader(opts.Source, opts.Key) if err != nil { @@ -120,5 +120,5 @@ func (opts LoadOpts) Load(ctx context.Context, k *kernel.Kernel, n inet.Stack, c previousMetadata = m // Restore the Kernel object graph. - return k.LoadFrom(ctx, r, n, clocks, vfsOpts) + return k.LoadFrom(ctx, r, timeReady, n, clocks, vfsOpts) } diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index 1fbbd133c..369541c7a 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -11,6 +11,7 @@ go_library( "futex.go", "linux64_amd64.go", "linux64_arm64.go", + "mmap.go", "open.go", "poll.go", "ptrace.go", @@ -35,7 +36,6 @@ go_library( "//pkg/sentry/socket", "//pkg/sentry/socket/netlink", "//pkg/sentry/syscalls/linux", - "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/strace/clone.go b/pkg/sentry/strace/clone.go index ab1060426..bfb4d7f5c 100644 --- a/pkg/sentry/strace/clone.go +++ b/pkg/sentry/strace/clone.go @@ -15,98 +15,98 @@ package strace import ( - "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi" + "gvisor.dev/gvisor/pkg/abi/linux" ) // CloneFlagSet is the set of clone(2) flags. var CloneFlagSet = abi.FlagSet{ { - Flag: unix.CLONE_VM, + Flag: linux.CLONE_VM, Name: "CLONE_VM", }, { - Flag: unix.CLONE_FS, + Flag: linux.CLONE_FS, Name: "CLONE_FS", }, { - Flag: unix.CLONE_FILES, + Flag: linux.CLONE_FILES, Name: "CLONE_FILES", }, { - Flag: unix.CLONE_SIGHAND, + Flag: linux.CLONE_SIGHAND, Name: "CLONE_SIGHAND", }, { - Flag: unix.CLONE_PTRACE, + Flag: linux.CLONE_PTRACE, Name: "CLONE_PTRACE", }, { - Flag: unix.CLONE_VFORK, + Flag: linux.CLONE_VFORK, Name: "CLONE_VFORK", }, { - Flag: unix.CLONE_PARENT, + Flag: linux.CLONE_PARENT, Name: "CLONE_PARENT", }, { - Flag: unix.CLONE_THREAD, + Flag: linux.CLONE_THREAD, Name: "CLONE_THREAD", }, { - Flag: unix.CLONE_NEWNS, + Flag: linux.CLONE_NEWNS, Name: "CLONE_NEWNS", }, { - Flag: unix.CLONE_SYSVSEM, + Flag: linux.CLONE_SYSVSEM, Name: "CLONE_SYSVSEM", }, { - Flag: unix.CLONE_SETTLS, + Flag: linux.CLONE_SETTLS, Name: "CLONE_SETTLS", }, { - Flag: unix.CLONE_PARENT_SETTID, + Flag: linux.CLONE_PARENT_SETTID, Name: "CLONE_PARENT_SETTID", }, { - Flag: unix.CLONE_CHILD_CLEARTID, + Flag: linux.CLONE_CHILD_CLEARTID, Name: "CLONE_CHILD_CLEARTID", }, { - Flag: unix.CLONE_DETACHED, + Flag: linux.CLONE_DETACHED, Name: "CLONE_DETACHED", }, { - Flag: unix.CLONE_UNTRACED, + Flag: linux.CLONE_UNTRACED, Name: "CLONE_UNTRACED", }, { - Flag: unix.CLONE_CHILD_SETTID, + Flag: linux.CLONE_CHILD_SETTID, Name: "CLONE_CHILD_SETTID", }, { - Flag: unix.CLONE_NEWUTS, + Flag: linux.CLONE_NEWUTS, Name: "CLONE_NEWUTS", }, { - Flag: unix.CLONE_NEWIPC, + Flag: linux.CLONE_NEWIPC, Name: "CLONE_NEWIPC", }, { - Flag: unix.CLONE_NEWUSER, + Flag: linux.CLONE_NEWUSER, Name: "CLONE_NEWUSER", }, { - Flag: unix.CLONE_NEWPID, + Flag: linux.CLONE_NEWPID, Name: "CLONE_NEWPID", }, { - Flag: unix.CLONE_NEWNET, + Flag: linux.CLONE_NEWNET, Name: "CLONE_NEWNET", }, { - Flag: unix.CLONE_IO, + Flag: linux.CLONE_IO, Name: "CLONE_IO", }, } diff --git a/pkg/sentry/strace/linux64_amd64.go b/pkg/sentry/strace/linux64_amd64.go index d66befe81..6ce1bb592 100644 --- a/pkg/sentry/strace/linux64_amd64.go +++ b/pkg/sentry/strace/linux64_amd64.go @@ -33,7 +33,7 @@ var linuxAMD64 = SyscallMap{ 6: makeSyscallInfo("lstat", Path, Stat), 7: makeSyscallInfo("poll", PollFDs, Hex, Hex), 8: makeSyscallInfo("lseek", Hex, Hex, Hex), - 9: makeSyscallInfo("mmap", Hex, Hex, Hex, Hex, FD, Hex), + 9: makeSyscallInfo("mmap", Hex, Hex, MmapProt, MmapFlags, FD, Hex), 10: makeSyscallInfo("mprotect", Hex, Hex, Hex), 11: makeSyscallInfo("munmap", Hex, Hex), 12: makeSyscallInfo("brk", Hex), diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go index 1a2d7d75f..ce5594301 100644 --- a/pkg/sentry/strace/linux64_arm64.go +++ b/pkg/sentry/strace/linux64_arm64.go @@ -246,7 +246,7 @@ var linuxARM64 = SyscallMap{ 219: makeSyscallInfo("keyctl", Hex, Hex, Hex, Hex, Hex), 220: makeSyscallInfo("clone", CloneFlags, Hex, Hex, Hex, Hex), 221: makeSyscallInfo("execve", Path, ExecveStringVector, ExecveStringVector), - 222: makeSyscallInfo("mmap", Hex, Hex, Hex, Hex, FD, Hex), + 222: makeSyscallInfo("mmap", Hex, Hex, MmapProt, MmapFlags, FD, Hex), 223: makeSyscallInfo("fadvise64", FD, Hex, Hex, Hex), 224: makeSyscallInfo("swapon", Hex, Hex), 225: makeSyscallInfo("swapoff", Hex), diff --git a/pkg/sentry/strace/mmap.go b/pkg/sentry/strace/mmap.go new file mode 100644 index 000000000..0035be586 --- /dev/null +++ b/pkg/sentry/strace/mmap.go @@ -0,0 +1,92 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package strace + +import ( + "gvisor.dev/gvisor/pkg/abi" + "gvisor.dev/gvisor/pkg/abi/linux" +) + +// ProtectionFlagSet represents the protection to mmap(2). +var ProtectionFlagSet = abi.FlagSet{ + { + Flag: linux.PROT_READ, + Name: "PROT_READ", + }, + { + Flag: linux.PROT_WRITE, + Name: "PROT_WRITE", + }, + { + Flag: linux.PROT_EXEC, + Name: "PROT_EXEC", + }, +} + +// MmapFlagSet is the set of mmap(2) flags. +var MmapFlagSet = abi.FlagSet{ + { + Flag: linux.MAP_SHARED, + Name: "MAP_SHARED", + }, + { + Flag: linux.MAP_PRIVATE, + Name: "MAP_PRIVATE", + }, + { + Flag: linux.MAP_FIXED, + Name: "MAP_FIXED", + }, + { + Flag: linux.MAP_ANONYMOUS, + Name: "MAP_ANONYMOUS", + }, + { + Flag: linux.MAP_GROWSDOWN, + Name: "MAP_GROWSDOWN", + }, + { + Flag: linux.MAP_DENYWRITE, + Name: "MAP_DENYWRITE", + }, + { + Flag: linux.MAP_EXECUTABLE, + Name: "MAP_EXECUTABLE", + }, + { + Flag: linux.MAP_LOCKED, + Name: "MAP_LOCKED", + }, + { + Flag: linux.MAP_NORESERVE, + Name: "MAP_NORESERVE", + }, + { + Flag: linux.MAP_POPULATE, + Name: "MAP_POPULATE", + }, + { + Flag: linux.MAP_NONBLOCK, + Name: "MAP_NONBLOCK", + }, + { + Flag: linux.MAP_STACK, + Name: "MAP_STACK", + }, + { + Flag: linux.MAP_HUGETLB, + Name: "MAP_HUGETLB", + }, +} diff --git a/pkg/sentry/strace/open.go b/pkg/sentry/strace/open.go index 5769360da..e7c7649f4 100644 --- a/pkg/sentry/strace/open.go +++ b/pkg/sentry/strace/open.go @@ -15,61 +15,61 @@ package strace import ( - "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi" + "gvisor.dev/gvisor/pkg/abi/linux" ) // OpenMode represents the mode to open(2) a file. var OpenMode = abi.ValueSet{ - unix.O_RDWR: "O_RDWR", - unix.O_WRONLY: "O_WRONLY", - unix.O_RDONLY: "O_RDONLY", + linux.O_RDWR: "O_RDWR", + linux.O_WRONLY: "O_WRONLY", + linux.O_RDONLY: "O_RDONLY", } // OpenFlagSet is the set of open(2) flags. var OpenFlagSet = abi.FlagSet{ { - Flag: unix.O_APPEND, + Flag: linux.O_APPEND, Name: "O_APPEND", }, { - Flag: unix.O_ASYNC, + Flag: linux.O_ASYNC, Name: "O_ASYNC", }, { - Flag: unix.O_CLOEXEC, + Flag: linux.O_CLOEXEC, Name: "O_CLOEXEC", }, { - Flag: unix.O_CREAT, + Flag: linux.O_CREAT, Name: "O_CREAT", }, { - Flag: unix.O_DIRECT, + Flag: linux.O_DIRECT, Name: "O_DIRECT", }, { - Flag: unix.O_DIRECTORY, + Flag: linux.O_DIRECTORY, Name: "O_DIRECTORY", }, { - Flag: unix.O_EXCL, + Flag: linux.O_EXCL, Name: "O_EXCL", }, { - Flag: unix.O_NOATIME, + Flag: linux.O_NOATIME, Name: "O_NOATIME", }, { - Flag: unix.O_NOCTTY, + Flag: linux.O_NOCTTY, Name: "O_NOCTTY", }, { - Flag: unix.O_NOFOLLOW, + Flag: linux.O_NOFOLLOW, Name: "O_NOFOLLOW", }, { - Flag: unix.O_NONBLOCK, + Flag: linux.O_NONBLOCK, Name: "O_NONBLOCK", }, { @@ -77,18 +77,22 @@ var OpenFlagSet = abi.FlagSet{ Name: "O_PATH", }, { - Flag: unix.O_SYNC, + Flag: linux.O_SYNC, Name: "O_SYNC", }, { - Flag: unix.O_TRUNC, + Flag: linux.O_TMPFILE, + Name: "O_TMPFILE", + }, + { + Flag: linux.O_TRUNC, Name: "O_TRUNC", }, } func open(val uint64) string { - s := OpenMode.Parse(val & unix.O_ACCMODE) - if flags := OpenFlagSet.Parse(val &^ unix.O_ACCMODE); flags != "" { + s := OpenMode.Parse(val & linux.O_ACCMODE) + if flags := OpenFlagSet.Parse(val &^ linux.O_ACCMODE); flags != "" { s += "|" + flags } return s diff --git a/pkg/sentry/strace/signal.go b/pkg/sentry/strace/signal.go index e5b379a20..5afc9525b 100644 --- a/pkg/sentry/strace/signal.go +++ b/pkg/sentry/strace/signal.go @@ -130,8 +130,8 @@ func sigAction(t *kernel.Task, addr hostarch.Addr) string { return "null" } - sa, err := t.CopyInSignalAct(addr) - if err != nil { + var sa linux.SigAction + if _, err := sa.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error copying sigaction: %v)", addr, err) } diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go index ec5d5f846..af7088847 100644 --- a/pkg/sentry/strace/strace.go +++ b/pkg/sentry/strace/strace.go @@ -489,6 +489,10 @@ func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlo output = append(output, epollEvents(t, args[arg].Pointer(), 0 /* numEvents */, uint64(maximumBlobSize))) case SelectFDSet: output = append(output, fdSet(t, int(args[0].Int()), args[arg].Pointer())) + case MmapProt: + output = append(output, ProtectionFlagSet.Parse(uint64(args[arg].Uint()))) + case MmapFlags: + output = append(output, MmapFlagSet.Parse(uint64(args[arg].Uint()))) case Oct: output = append(output, "0o"+strconv.FormatUint(args[arg].Uint64(), 8)) case Hex: diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go index 7e69b9279..5893443a7 100644 --- a/pkg/sentry/strace/syscalls.go +++ b/pkg/sentry/strace/syscalls.go @@ -238,6 +238,12 @@ const ( // EpollEvents is an array of struct epoll_event. It is the events // argument in epoll_wait(2)/epoll_pwait(2). EpollEvents + + // MmapProt is the protection argument in mmap(2). + MmapProt + + // MmapFlags is the flags argument in mmap(2). + MmapFlags ) // defaultFormat is the syscall argument format to use if the actual format is diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 37443ab78..90a719ba2 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -1569,9 +1569,9 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } if uint64(length) >= t.ThreadGroup().Limits().Get(limits.FileSize).Cur { - t.SendSignal(&arch.SignalInfo{ + t.SendSignal(&linux.SignalInfo{ Signo: int32(linux.SIGXFSZ), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, }) return 0, nil, syserror.EFBIG } @@ -1632,9 +1632,9 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } if uint64(length) >= t.ThreadGroup().Limits().Get(limits.FileSize).Cur { - t.SendSignal(&arch.SignalInfo{ + t.SendSignal(&linux.SignalInfo{ Signo: int32(linux.SIGXFSZ), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, }) return 0, nil, syserror.EFBIG } @@ -2140,9 +2140,9 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, syserror.EFBIG } if uint64(size) >= t.ThreadGroup().Limits().Get(limits.FileSize).Cur { - t.SendSignal(&arch.SignalInfo{ + t.SendSignal(&linux.SignalInfo{ Signo: int32(linux.SIGXFSZ), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, }) return 0, nil, syserror.EFBIG } diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go index 53b12dc41..27a7f7fe1 100644 --- a/pkg/sentry/syscalls/linux/sys_signal.go +++ b/pkg/sentry/syscalls/linux/sys_signal.go @@ -84,9 +84,9 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if !mayKill(t, target, sig) { return 0, nil, syserror.EPERM } - info := &arch.SignalInfo{ + info := &linux.SignalInfo{ Signo: int32(sig), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, } info.SetPID(int32(target.PIDNamespace().IDOfTask(t))) info.SetUID(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow())) @@ -123,9 +123,9 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC // depend on the iteration order. We at least implement the // semantics documented by the man page: "On success (at least // one signal was sent), zero is returned." - info := &arch.SignalInfo{ + info := &linux.SignalInfo{ Signo: int32(sig), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, } info.SetPID(int32(tg.PIDNamespace().IDOfTask(t))) info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) @@ -167,9 +167,9 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC continue } - info := &arch.SignalInfo{ + info := &linux.SignalInfo{ Signo: int32(sig), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, } info.SetPID(int32(tg.PIDNamespace().IDOfTask(t))) info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) @@ -184,10 +184,10 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC } } -func tkillSigInfo(sender, receiver *kernel.Task, sig linux.Signal) *arch.SignalInfo { - info := &arch.SignalInfo{ +func tkillSigInfo(sender, receiver *kernel.Task, sig linux.Signal) *linux.SignalInfo { + info := &linux.SignalInfo{ Signo: int32(sig), - Code: arch.SignalInfoTkill, + Code: linux.SI_TKILL, } info.SetPID(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup()))) info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) @@ -251,20 +251,20 @@ func RtSigaction(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S return 0, nil, syserror.EINVAL } - var newactptr *arch.SignalAct + var newactptr *linux.SigAction if newactarg != 0 { - newact, err := t.CopyInSignalAct(newactarg) - if err != nil { + var newact linux.SigAction + if _, err := newact.CopyIn(t, newactarg); err != nil { return 0, nil, err } newactptr = &newact } - oldact, err := t.ThreadGroup().SetSignalAct(sig, newactptr) + oldact, err := t.ThreadGroup().SetSigAction(sig, newactptr) if err != nil { return 0, nil, err } if oldactarg != 0 { - if err := t.CopyOutSignalAct(oldactarg, &oldact); err != nil { + if _, err := oldact.CopyOut(t, oldactarg); err != nil { return 0, nil, err } } @@ -325,13 +325,12 @@ func Sigaltstack(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S alt := t.SignalStack() if oldaddr != 0 { - if err := t.CopyOutSignalStack(oldaddr, &alt); err != nil { + if _, err := alt.CopyOut(t, oldaddr); err != nil { return 0, nil, err } } if setaddr != 0 { - alt, err := t.CopyInSignalStack(setaddr) - if err != nil { + if _, err := alt.CopyIn(t, setaddr); err != nil { return 0, nil, err } // The signal stack cannot be changed if the task is currently @@ -410,7 +409,7 @@ func RtSigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne // We must ensure that the Signo is set (Linux overrides this in the // same way), and that the code is in the allowed set. This same logic // appears below in RtSigtgqueueinfo and should be kept in sync. - var info arch.SignalInfo + var info linux.SignalInfo if _, err := info.CopyIn(t, infoAddr); err != nil { return 0, nil, err } @@ -426,7 +425,7 @@ func RtSigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne // If the sender is not the receiver, it can't use si_codes used by the // kernel or SI_TKILL. - if (info.Code >= 0 || info.Code == arch.SignalInfoTkill) && target != t { + if (info.Code >= 0 || info.Code == linux.SI_TKILL) && target != t { return 0, nil, syserror.EPERM } @@ -454,7 +453,7 @@ func RtTgsigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ker } // Copy in the info. See RtSigqueueinfo above. - var info arch.SignalInfo + var info linux.SignalInfo if _, err := info.CopyIn(t, infoAddr); err != nil { return 0, nil, err } @@ -469,7 +468,7 @@ func RtTgsigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ker // If the sender is not the receiver, it can't use si_codes used by the // kernel or SI_TKILL. - if (info.Code >= 0 || info.Code == arch.SignalInfoTkill) && target != t { + if (info.Code >= 0 || info.Code == linux.SI_TKILL) && target != t { return 0, nil, syserror.EPERM } diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go index 3185ea527..0d5056303 100644 --- a/pkg/sentry/syscalls/linux/sys_thread.go +++ b/pkg/sentry/syscalls/linux/sys_thread.go @@ -398,7 +398,7 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // out the fields it would set for a successful waitid in this case // as well. if infop != 0 { - var si arch.SignalInfo + var si linux.SignalInfo _, err = si.CopyOut(t, infop) } } @@ -413,7 +413,7 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if infop == 0 { return 0, nil, nil } - si := arch.SignalInfo{ + si := linux.SignalInfo{ Signo: int32(linux.SIGCHLD), } si.SetPID(int32(wr.TID)) @@ -423,24 +423,24 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal s := unix.WaitStatus(wr.Status) switch { case s.Exited(): - si.Code = arch.CLD_EXITED + si.Code = linux.CLD_EXITED si.SetStatus(int32(s.ExitStatus())) case s.Signaled(): - si.Code = arch.CLD_KILLED + si.Code = linux.CLD_KILLED si.SetStatus(int32(s.Signal())) case s.CoreDump(): - si.Code = arch.CLD_DUMPED + si.Code = linux.CLD_DUMPED si.SetStatus(int32(s.Signal())) case s.Stopped(): if wr.Event == kernel.EventTraceeStop { - si.Code = arch.CLD_TRAPPED + si.Code = linux.CLD_TRAPPED si.SetStatus(int32(s.TrapCause())) } else { - si.Code = arch.CLD_STOPPED + si.Code = linux.CLD_STOPPED si.SetStatus(int32(s.StopSignal())) } case s.Continued(): - si.Code = arch.CLD_CONTINUED + si.Code = linux.CLD_CONTINUED si.SetStatus(int32(linux.SIGCONT)) default: t.Warningf("waitid got incomprehensible wait status %d", s) diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go index c6330c21a..647e089d0 100644 --- a/pkg/sentry/syscalls/linux/vfs2/setstat.go +++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go @@ -242,9 +242,9 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } limit := limits.FromContext(t).Get(limits.FileSize).Cur if uint64(size) >= limit { - t.SendSignal(&arch.SignalInfo{ + t.SendSignal(&linux.SignalInfo{ Signo: int32(linux.SIGXFSZ), - Code: arch.SignalInfoUser, + Code: linux.SI_USER, }) return 0, nil, syserror.EFBIG } diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index 1f617ca8f..202486a1e 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "seqatomic_parameters_unsafe.go", package = "time", suffix = "Parameters", - template = "//pkg/sync:generic_seqatomic", + template = "//pkg/sync/seqatomic:generic_seqatomic", types = { "Value": "Parameters", }, @@ -25,6 +25,8 @@ go_library( "muldiv_arm64.s", "parameters.go", "sampler.go", + "sampler_amd64.go", + "sampler_arm64.go", "sampler_unsafe.go", "seqatomic_parameters_unsafe.go", "tsc_amd64.s", diff --git a/pkg/sentry/time/sampler.go b/pkg/sentry/time/sampler.go index 4ac9c4474..24a47f5d5 100644 --- a/pkg/sentry/time/sampler.go +++ b/pkg/sentry/time/sampler.go @@ -21,13 +21,6 @@ import ( ) const ( - // defaultOverheadTSC is the default estimated syscall overhead in TSC cycles. - // It is further refined as syscalls are made. - defaultOverheadCycles = 1 * 1000 - - // maxOverheadCycles is the maximum allowed syscall overhead in TSC cycles. - maxOverheadCycles = 100 * defaultOverheadCycles - // maxSampleLoops is the maximum number of times to try to get a clock sample // under the expected overhead. maxSampleLoops = 5 diff --git a/pkg/sentry/time/sampler_amd64.go b/pkg/sentry/time/sampler_amd64.go new file mode 100644 index 000000000..9f1b4b2fb --- /dev/null +++ b/pkg/sentry/time/sampler_amd64.go @@ -0,0 +1,26 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//+build amd64 + +package time + +const ( + // defaultOverheadTSC is the default estimated syscall overhead in TSC cycles. + // It is further refined as syscalls are made. + defaultOverheadCycles = 1 * 1000 + + // maxOverheadCycles is the maximum allowed syscall overhead in TSC cycles. + maxOverheadCycles = 100 * defaultOverheadCycles +) diff --git a/pkg/sentry/time/sampler_arm64.go b/pkg/sentry/time/sampler_arm64.go new file mode 100644 index 000000000..4c8d33ae4 --- /dev/null +++ b/pkg/sentry/time/sampler_arm64.go @@ -0,0 +1,42 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//+build arm64 + +package time + +// getCNTFRQ get ARM counter-timer frequency +func getCNTFRQ() TSCValue + +// getDefaultArchOverheadCycles get default OverheadCycles based on +// ARM counter-timer frequency. Usually ARM counter-timer frequency +// is range from 1-50Mhz which is much less than that on x86, so we +// calibrate defaultOverheadCycles for ARM. +func getDefaultArchOverheadCycles() TSCValue { + // estimated the clock frequency on x86 is 1Ghz. + // 1Ghz devided by counter-timer frequency of ARM to get + // frqRatio. defaultOverheadCycles of ARM equals to that on + // x86 devided by frqRatio + cntfrq := getCNTFRQ() + frqRatio := 1000000000 / cntfrq + overheadCycles := (1 * 1000) / frqRatio + return overheadCycles +} + +// defaultOverheadTSC is the default estimated syscall overhead in TSC cycles. +// It is further refined as syscalls are made. +var defaultOverheadCycles = getDefaultArchOverheadCycles() + +// maxOverheadCycles is the maximum allowed syscall overhead in TSC cycles. +var maxOverheadCycles = 100 * defaultOverheadCycles diff --git a/pkg/sentry/time/tsc_arm64.s b/pkg/sentry/time/tsc_arm64.s index da9fa4112..711349fa1 100644 --- a/pkg/sentry/time/tsc_arm64.s +++ b/pkg/sentry/time/tsc_arm64.s @@ -20,3 +20,9 @@ TEXT ·Rdtsc(SB),NOSPLIT,$0-8 WORD $0xd53be040 //MRS CNTVCT_EL0, R0 MOVD R0, ret+0(FP) RET + +TEXT ·getCNTFRQ(SB),NOSPLIT,$0-8 + // Get the virtual counter frequency. + WORD $0xd53be000 //MRS CNTFRQ_EL0, R0 + MOVD R0, ret+0(FP) + RET diff --git a/pkg/sentry/vfs/memxattr/BUILD b/pkg/sentry/vfs/memxattr/BUILD index d8c4d27b9..ea82f4987 100644 --- a/pkg/sentry/vfs/memxattr/BUILD +++ b/pkg/sentry/vfs/memxattr/BUILD @@ -8,6 +8,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", "//pkg/sync", "//pkg/syserror", diff --git a/pkg/sentry/vfs/memxattr/xattr.go b/pkg/sentry/vfs/memxattr/xattr.go index 638b5d830..9b7953fa3 100644 --- a/pkg/sentry/vfs/memxattr/xattr.go +++ b/pkg/sentry/vfs/memxattr/xattr.go @@ -17,7 +17,10 @@ package memxattr import ( + "strings" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -26,6 +29,9 @@ import ( // SimpleExtendedAttributes implements extended attributes using a map of // names to values. // +// SimpleExtendedAttributes calls vfs.CheckXattrPermissions, so callers are not +// required to do so. +// // +stateify savable type SimpleExtendedAttributes struct { // mu protects the below fields. @@ -34,7 +40,11 @@ type SimpleExtendedAttributes struct { } // GetXattr returns the value at 'name'. -func (x *SimpleExtendedAttributes) GetXattr(opts *vfs.GetXattrOptions) (string, error) { +func (x *SimpleExtendedAttributes) GetXattr(creds *auth.Credentials, mode linux.FileMode, kuid auth.KUID, opts *vfs.GetXattrOptions) (string, error) { + if err := vfs.CheckXattrPermissions(creds, vfs.MayRead, mode, kuid, opts.Name); err != nil { + return "", err + } + x.mu.RLock() value, ok := x.xattrs[opts.Name] x.mu.RUnlock() @@ -50,7 +60,11 @@ func (x *SimpleExtendedAttributes) GetXattr(opts *vfs.GetXattrOptions) (string, } // SetXattr sets 'value' at 'name'. -func (x *SimpleExtendedAttributes) SetXattr(opts *vfs.SetXattrOptions) error { +func (x *SimpleExtendedAttributes) SetXattr(creds *auth.Credentials, mode linux.FileMode, kuid auth.KUID, opts *vfs.SetXattrOptions) error { + if err := vfs.CheckXattrPermissions(creds, vfs.MayWrite, mode, kuid, opts.Name); err != nil { + return err + } + x.mu.Lock() defer x.mu.Unlock() if x.xattrs == nil { @@ -73,12 +87,19 @@ func (x *SimpleExtendedAttributes) SetXattr(opts *vfs.SetXattrOptions) error { } // ListXattr returns all names in xattrs. -func (x *SimpleExtendedAttributes) ListXattr(size uint64) ([]string, error) { +func (x *SimpleExtendedAttributes) ListXattr(creds *auth.Credentials, size uint64) ([]string, error) { // Keep track of the size of the buffer needed in listxattr(2) for the list. listSize := 0 x.mu.RLock() names := make([]string, 0, len(x.xattrs)) + haveCap := creds.HasCapability(linux.CAP_SYS_ADMIN) for n := range x.xattrs { + // Hide extended attributes in the "trusted" namespace from + // non-privileged users. This is consistent with Linux's + // fs/xattr.c:simple_xattr_list(). + if !haveCap && strings.HasPrefix(n, linux.XATTR_TRUSTED_PREFIX) { + continue + } names = append(names, n) // Add one byte per null terminator. listSize += len(n) + 1 @@ -91,7 +112,11 @@ func (x *SimpleExtendedAttributes) ListXattr(size uint64) ([]string, error) { } // RemoveXattr removes the xattr at 'name'. -func (x *SimpleExtendedAttributes) RemoveXattr(name string) error { +func (x *SimpleExtendedAttributes) RemoveXattr(creds *auth.Credentials, mode linux.FileMode, kuid auth.KUID, name string) error { + if err := vfs.CheckXattrPermissions(creds, vfs.MayWrite, mode, kuid, name); err != nil { + return err + } + x.mu.Lock() defer x.mu.Unlock() if _, ok := x.xattrs[name]; !ok { diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 82fd382c2..f93da3af1 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -220,7 +220,7 @@ func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Cr vdDentry := vd.dentry vdDentry.mu.Lock() for { - if vdDentry.dead { + if vd.mount.umounted || vdDentry.dead { vdDentry.mu.Unlock() vfs.mountMu.Unlock() vd.DecRef(ctx) diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index dfe85f31d..8d563d53a 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -115,14 +115,14 @@ func (a *Action) Get() interface{} { } // String returns Action's string representation. -func (a *Action) String() string { - switch *a { +func (a Action) String() string { + switch a { case LogWarning: return "logWarning" case Panic: return "panic" default: - panic(fmt.Sprintf("Invalid watchdog action: %d", *a)) + panic(fmt.Sprintf("Invalid watchdog action: %d", a)) } } diff --git a/pkg/shim/BUILD b/pkg/shim/BUILD index fd6127b97..367765209 100644 --- a/pkg/shim/BUILD +++ b/pkg/shim/BUILD @@ -6,6 +6,7 @@ go_library( name = "shim", srcs = [ "api.go", + "debug.go", "epoll.go", "options.go", "service.go", diff --git a/pkg/shim/debug.go b/pkg/shim/debug.go new file mode 100644 index 000000000..49f01990e --- /dev/null +++ b/pkg/shim/debug.go @@ -0,0 +1,48 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shim + +import ( + "os" + "os/signal" + "runtime" + "sync" + "syscall" + + "github.com/containerd/containerd/log" +) + +var once sync.Once + +func setDebugSigHandler() { + once.Do(func() { + dumpCh := make(chan os.Signal, 1) + signal.Notify(dumpCh, syscall.SIGUSR2) + go func() { + buf := make([]byte, 10240) + for range dumpCh { + for { + n := runtime.Stack(buf, true) + if n >= len(buf) { + buf = make([]byte, 2*len(buf)) + continue + } + log.L.Debugf("User requested stack trace:\n%s", buf[:n]) + } + } + }() + log.L.Debugf("For full process dump run: kill -%d %d", syscall.SIGUSR2, os.Getpid()) + }) +} diff --git a/pkg/shim/proc/deleted_state.go b/pkg/shim/proc/deleted_state.go index d9b970c4d..b0bbe4d7e 100644 --- a/pkg/shim/proc/deleted_state.go +++ b/pkg/shim/proc/deleted_state.go @@ -22,28 +22,38 @@ import ( "github.com/containerd/console" "github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/pkg/process" + runc "github.com/containerd/go-runc" ) type deletedState struct{} -func (*deletedState) Resize(ws console.WinSize) error { - return fmt.Errorf("cannot resize a deleted process.ss") +func (*deletedState) Resize(console.WinSize) error { + return fmt.Errorf("cannot resize a deleted container/process") } -func (*deletedState) Start(ctx context.Context) error { - return fmt.Errorf("cannot start a deleted process.ss") +func (*deletedState) Start(context.Context) error { + return fmt.Errorf("cannot start a deleted container/process") } -func (*deletedState) Delete(ctx context.Context) error { - return fmt.Errorf("cannot delete a deleted process.ss: %w", errdefs.ErrNotFound) +func (*deletedState) Delete(context.Context) error { + return fmt.Errorf("cannot delete a deleted container/process: %w", errdefs.ErrNotFound) } -func (*deletedState) Kill(ctx context.Context, sig uint32, all bool) error { - return fmt.Errorf("cannot kill a deleted process.ss: %w", errdefs.ErrNotFound) +func (*deletedState) Kill(_ context.Context, signal uint32, _ bool) error { + return handleStoppedKill(signal) } -func (*deletedState) SetExited(status int) {} +func (*deletedState) SetExited(int) {} -func (*deletedState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { +func (*deletedState) Exec(context.Context, string, *ExecConfig) (process.Process, error) { return nil, fmt.Errorf("cannot exec in a deleted state") } + +func (s *deletedState) State(context.Context) (string, error) { + // There is no "deleted" state, closest one is stopped. + return "stopped", nil +} + +func (s *deletedState) Stats(context.Context, string) (*runc.Stats, error) { + return nil, fmt.Errorf("cannot stat a stopped container/process") +} diff --git a/pkg/shim/proc/exec.go b/pkg/shim/proc/exec.go index e7968d9d5..14df3a778 100644 --- a/pkg/shim/proc/exec.go +++ b/pkg/shim/proc/exec.go @@ -145,16 +145,13 @@ func (e *execProcess) Kill(ctx context.Context, sig uint32, _ bool) error { func (e *execProcess) kill(ctx context.Context, sig uint32, _ bool) error { internalPid := e.internalPid - if internalPid != 0 { - if err := e.parent.runtime.Kill(ctx, e.parent.id, int(sig), &runsc.KillOpts{ - Pid: internalPid, - }); err != nil { - // If this returns error, consider the process has - // already stopped. - // - // TODO: Fix after signal handling is fixed. - return fmt.Errorf("%s: %w", err.Error(), errdefs.ErrNotFound) - } + if internalPid == 0 { + return nil + } + + opts := runsc.KillOpts{Pid: internalPid} + if err := e.parent.runtime.Kill(ctx, e.parent.id, int(sig), &opts); err != nil { + return fmt.Errorf("%s: %w", err.Error(), errdefs.ErrNotFound) } return nil } diff --git a/pkg/shim/proc/exec_state.go b/pkg/shim/proc/exec_state.go index 4dcda8b44..04a5d19b4 100644 --- a/pkg/shim/proc/exec_state.go +++ b/pkg/shim/proc/exec_state.go @@ -34,18 +34,21 @@ type execCreatedState struct { p *execProcess } -func (s *execCreatedState) transition(name string) error { - switch name { - case "running": +func (s *execCreatedState) name() string { + return "created" +} + +func (s *execCreatedState) transition(transition stateTransition) { + switch transition { + case running: s.p.execState = &execRunningState{p: s.p} - case "stopped": + case stopped: s.p.execState = &execStoppedState{p: s.p} - case "deleted": + case deleted: s.p.execState = &deletedState{} default: - return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + panic(fmt.Sprintf("invalid state transition %q to %q", s.name(), transition)) } - return nil } func (s *execCreatedState) Resize(ws console.WinSize) error { @@ -56,14 +59,16 @@ func (s *execCreatedState) Start(ctx context.Context) error { if err := s.p.start(ctx); err != nil { return err } - return s.transition("running") + s.transition(running) + return nil } func (s *execCreatedState) Delete(ctx context.Context) error { if err := s.p.delete(ctx); err != nil { return err } - return s.transition("deleted") + s.transition(deleted) + return nil } func (s *execCreatedState) Kill(ctx context.Context, sig uint32, all bool) error { @@ -72,35 +77,35 @@ func (s *execCreatedState) Kill(ctx context.Context, sig uint32, all bool) error func (s *execCreatedState) SetExited(status int) { s.p.setExited(status) - - if err := s.transition("stopped"); err != nil { - panic(err) - } + s.transition(stopped) } type execRunningState struct { p *execProcess } -func (s *execRunningState) transition(name string) error { - switch name { - case "stopped": +func (s *execRunningState) name() string { + return "running" +} + +func (s *execRunningState) transition(transition stateTransition) { + switch transition { + case stopped: s.p.execState = &execStoppedState{p: s.p} default: - return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + panic(fmt.Sprintf("invalid state transition %q to %q", s.name(), transition)) } - return nil } func (s *execRunningState) Resize(ws console.WinSize) error { return s.p.resize(ws) } -func (s *execRunningState) Start(ctx context.Context) error { +func (s *execRunningState) Start(context.Context) error { return fmt.Errorf("cannot start a running process") } -func (s *execRunningState) Delete(ctx context.Context) error { +func (s *execRunningState) Delete(context.Context) error { return fmt.Errorf("cannot delete a running process") } @@ -110,31 +115,31 @@ func (s *execRunningState) Kill(ctx context.Context, sig uint32, all bool) error func (s *execRunningState) SetExited(status int) { s.p.setExited(status) - - if err := s.transition("stopped"); err != nil { - panic(err) - } + s.transition(stopped) } type execStoppedState struct { p *execProcess } -func (s *execStoppedState) transition(name string) error { - switch name { - case "deleted": +func (s *execStoppedState) name() string { + return "stopped" +} + +func (s *execStoppedState) transition(transition stateTransition) { + switch transition { + case deleted: s.p.execState = &deletedState{} default: - return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + panic(fmt.Sprintf("invalid state transition %q to %q", s.name(), transition)) } - return nil } -func (s *execStoppedState) Resize(ws console.WinSize) error { +func (s *execStoppedState) Resize(console.WinSize) error { return fmt.Errorf("cannot resize a stopped container") } -func (s *execStoppedState) Start(ctx context.Context) error { +func (s *execStoppedState) Start(context.Context) error { return fmt.Errorf("cannot start a stopped process") } @@ -142,13 +147,14 @@ func (s *execStoppedState) Delete(ctx context.Context) error { if err := s.p.delete(ctx); err != nil { return err } - return s.transition("deleted") + s.transition(deleted) + return nil } func (s *execStoppedState) Kill(ctx context.Context, sig uint32, all bool) error { return s.p.kill(ctx, sig, all) } -func (s *execStoppedState) SetExited(status int) { +func (s *execStoppedState) SetExited(int) { // no op } diff --git a/pkg/shim/proc/init.go b/pkg/shim/proc/init.go index 664465e0d..6bf090813 100644 --- a/pkg/shim/proc/init.go +++ b/pkg/shim/proc/init.go @@ -39,6 +39,8 @@ import ( "gvisor.dev/gvisor/pkg/shim/runsc" ) +const statusStopped = "stopped" + // Init represents an initial process for a container. type Init struct { wg sync.WaitGroup @@ -201,10 +203,15 @@ func (p *Init) ExitedAt() time.Time { func (p *Init) Status(ctx context.Context) (string, error) { p.mu.Lock() defer p.mu.Unlock() + + return p.initState.State(ctx) +} + +func (p *Init) state(ctx context.Context) (string, error) { c, err := p.runtime.State(ctx, p.id) if err != nil { if strings.Contains(err.Error(), "does not exist") { - return "stopped", nil + return statusStopped, nil } return "", p.runtimeError(err, "OCI runtime state failed") } @@ -231,10 +238,7 @@ func (p *Init) start(ctx context.Context) error { status, err := p.runtime.Wait(context.Background(), p.id) if err != nil { log.G(ctx).WithError(err).Errorf("Failed to wait for container %q", p.id) - // TODO(random-liu): Handle runsc kill error. - if err := p.killAll(ctx); err != nil { - log.G(ctx).WithError(err).Errorf("Failed to kill container %q", p.id) - } + p.killAllLocked(ctx) status = internalErrorCode } ExitCh <- Exit{ @@ -255,6 +259,12 @@ func (p *Init) SetExited(status int) { } func (p *Init) setExited(status int) { + if !p.exited.IsZero() { + log.L.Debugf("Status already set to %d, ignoring status: %d", p.status, status) + return + } + + log.L.Debugf("Setting status: %d", status) p.exited = time.Now() p.status = status p.Platform.ShutdownConsole(context.Background(), p.console) @@ -270,15 +280,16 @@ func (p *Init) Delete(ctx context.Context) error { } func (p *Init) delete(ctx context.Context) error { - p.killAll(ctx) + p.killAllLocked(ctx) p.wg.Wait() + err := p.runtime.Delete(ctx, p.id, nil) - // ignore errors if a runtime has already deleted the process - // but we still hold metadata and pipes - // - // this is common during a checkpoint, runc will delete the container state - // after a checkpoint and the container will no longer exist within runc if err != nil { + // ignore errors if a runtime has already deleted the process + // but we still hold metadata and pipes + // + // this is common during a checkpoint, runc will delete the container state + // after a checkpoint and the container will no longer exist within runc if strings.Contains(err.Error(), "does not exist") { err = nil } else { @@ -326,29 +337,24 @@ func (p *Init) Kill(ctx context.Context, signal uint32, all bool) error { return p.initState.Kill(ctx, signal, all) } -func (p *Init) kill(context context.Context, signal uint32, all bool) error { +func (p *Init) kill(ctx context.Context, signal uint32, all bool) error { var ( killErr error backoff = 100 * time.Millisecond ) - timeout := 1 * time.Second - for start := time.Now(); time.Now().Sub(start) < timeout; { - c, err := p.runtime.State(context, p.id) + const timeout = time.Second + for start := time.Now(); time.Since(start) < timeout; { + state, err := p.initState.State(ctx) if err != nil { - if strings.Contains(err.Error(), "does not exist") { - return fmt.Errorf("no such process: %w", errdefs.ErrNotFound) - } return p.runtimeError(err, "OCI runtime state failed") } // For runsc, signal only works when container is running state. // If the container is not in running state, directly return // "no such process" - if p.convertStatus(c.Status) == "stopped" { + if state == statusStopped { return fmt.Errorf("no such process: %w", errdefs.ErrNotFound) } - killErr = p.runtime.Kill(context, p.id, int(signal), &runsc.KillOpts{ - All: all, - }) + killErr = p.runtime.Kill(ctx, p.id, int(signal), &runsc.KillOpts{All: all}) if killErr == nil { return nil } @@ -358,22 +364,18 @@ func (p *Init) kill(context context.Context, signal uint32, all bool) error { return p.runtimeError(killErr, "kill timeout") } -// KillAll kills all processes belonging to the init process. -func (p *Init) KillAll(context context.Context) error { +// KillAll kills all processes belonging to the init process. If +// `runsc kill --all` returns error, assume the container has already stopped. +func (p *Init) KillAll(context context.Context) { p.mu.Lock() defer p.mu.Unlock() - return p.killAll(context) + p.killAllLocked(context) } -func (p *Init) killAll(context context.Context) error { - p.runtime.Kill(context, p.id, int(unix.SIGKILL), &runsc.KillOpts{ - All: true, - }) - // Ignore error handling for `runsc kill --all` for now. - // * If it doesn't return error, it is good; - // * If it returns error, consider the container has already stopped. - // TODO: Fix `runsc kill --all` error handling. - return nil +func (p *Init) killAllLocked(context context.Context) { + if err := p.runtime.Kill(context, p.id, int(unix.SIGKILL), &runsc.KillOpts{All: true}); err != nil { + log.L.Warningf("Ignoring error killing container %q: %v", p.id, err) + } } // Stdin returns the stdin of the process. @@ -396,7 +398,6 @@ func (p *Init) Exec(ctx context.Context, path string, r *ExecConfig) (process.Pr // exec returns a new exec'd process. func (p *Init) exec(path string, r *ExecConfig) (process.Process, error) { - // process exec request var spec specs.Process if err := json.Unmarshal(r.Spec.Value, &spec); err != nil { return nil, err @@ -420,6 +421,17 @@ func (p *Init) exec(path string, r *ExecConfig) (process.Process, error) { return e, nil } +func (p *Init) Stats(ctx context.Context, id string) (*runc.Stats, error) { + p.mu.Lock() + defer p.mu.Unlock() + + return p.initState.Stats(ctx, id) +} + +func (p *Init) stats(ctx context.Context, id string) (*runc.Stats, error) { + return p.Runtime().Stats(ctx, id) +} + // Stdio returns the stdio of the process. func (p *Init) Stdio() stdio.Stdio { return p.stdio @@ -444,7 +456,7 @@ func (p *Init) runtimeError(rErr error, msg string) error { func (p *Init) convertStatus(status string) string { if status == "created" && !p.Sandbox && p.status == internalErrorCode { // Treat start failure state for non-root container as stopped. - return "stopped" + return statusStopped } return status } diff --git a/pkg/shim/proc/init_state.go b/pkg/shim/proc/init_state.go index 0065fc385..d65020e76 100644 --- a/pkg/shim/proc/init_state.go +++ b/pkg/shim/proc/init_state.go @@ -19,16 +19,39 @@ import ( "context" "fmt" - "github.com/containerd/console" "github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/pkg/process" + runc "github.com/containerd/go-runc" + "golang.org/x/sys/unix" ) +type stateTransition int + +const ( + running stateTransition = iota + stopped + deleted +) + +func (s stateTransition) String() string { + switch s { + case running: + return "running" + case stopped: + return "stopped" + case deleted: + return "deleted" + default: + panic(fmt.Sprintf("unknown state: %d", s)) + } +} + type initState interface { - Resize(console.WinSize) error Start(context.Context) error Delete(context.Context) error Exec(context.Context, string, *ExecConfig) (process.Process, error) + State(ctx context.Context) (string, error) + Stats(context.Context, string) (*runc.Stats, error) Kill(context.Context, uint32, bool) error SetExited(int) } @@ -37,22 +60,21 @@ type createdState struct { p *Init } -func (s *createdState) transition(name string) error { - switch name { - case "running": +func (s *createdState) name() string { + return "created" +} + +func (s *createdState) transition(transition stateTransition) { + switch transition { + case running: s.p.initState = &runningState{p: s.p} - case "stopped": - s.p.initState = &stoppedState{p: s.p} - case "deleted": + case stopped: + s.p.initState = &stoppedState{process: s.p} + case deleted: s.p.initState = &deletedState{} default: - return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + panic(fmt.Sprintf("invalid state transition %q to %q", s.name(), transition)) } - return nil -} - -func (s *createdState) Resize(ws console.WinSize) error { - return s.p.resize(ws) } func (s *createdState) Start(ctx context.Context) error { @@ -66,20 +88,20 @@ func (s *createdState) Start(ctx context.Context) error { if !s.p.Sandbox { s.p.io.Close() s.p.setExited(internalErrorCode) - if err := s.transition("stopped"); err != nil { - panic(err) - } + s.transition(stopped) } return err } - return s.transition("running") + s.transition(running) + return nil } func (s *createdState) Delete(ctx context.Context) error { if err := s.p.delete(ctx); err != nil { return err } - return s.transition("deleted") + s.transition(deleted) + return nil } func (s *createdState) Kill(ctx context.Context, sig uint32, all bool) error { @@ -88,40 +110,48 @@ func (s *createdState) Kill(ctx context.Context, sig uint32, all bool) error { func (s *createdState) SetExited(status int) { s.p.setExited(status) - - if err := s.transition("stopped"); err != nil { - panic(err) - } + s.transition(stopped) } func (s *createdState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { return s.p.exec(path, r) } +func (s *createdState) State(ctx context.Context) (string, error) { + state, err := s.p.state(ctx) + if err == nil && state == statusStopped { + s.transition(stopped) + } + return state, err +} + +func (s *createdState) Stats(ctx context.Context, id string) (*runc.Stats, error) { + return s.p.stats(ctx, id) +} + type runningState struct { p *Init } -func (s *runningState) transition(name string) error { - switch name { - case "stopped": - s.p.initState = &stoppedState{p: s.p} - default: - return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) - } - return nil +func (s *runningState) name() string { + return "running" } -func (s *runningState) Resize(ws console.WinSize) error { - return s.p.resize(ws) +func (s *runningState) transition(transition stateTransition) { + switch transition { + case stopped: + s.p.initState = &stoppedState{process: s.p} + default: + panic(fmt.Sprintf("invalid state transition %q to %q", s.name(), transition)) + } } func (s *runningState) Start(ctx context.Context) error { - return fmt.Errorf("cannot start a running process.ss") + return fmt.Errorf("cannot start a running container") } func (s *runningState) Delete(ctx context.Context) error { - return fmt.Errorf("cannot delete a running process.ss") + return fmt.Errorf("cannot delete a running container") } func (s *runningState) Kill(ctx context.Context, sig uint32, all bool) error { @@ -130,53 +160,81 @@ func (s *runningState) Kill(ctx context.Context, sig uint32, all bool) error { func (s *runningState) SetExited(status int) { s.p.setExited(status) + s.transition(stopped) +} + +func (s *runningState) Exec(_ context.Context, path string, r *ExecConfig) (process.Process, error) { + return s.p.exec(path, r) +} - if err := s.transition("stopped"); err != nil { - panic(err) +func (s *runningState) State(ctx context.Context) (string, error) { + state, err := s.p.state(ctx) + if err == nil && state == "stopped" { + s.transition(stopped) } + return state, err } -func (s *runningState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { - return s.p.exec(path, r) +func (s *runningState) Stats(ctx context.Context, id string) (*runc.Stats, error) { + return s.p.stats(ctx, id) } type stoppedState struct { - p *Init + process *Init } -func (s *stoppedState) transition(name string) error { - switch name { - case "deleted": - s.p.initState = &deletedState{} - default: - return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) - } - return nil +func (s *stoppedState) name() string { + return "stopped" } -func (s *stoppedState) Resize(ws console.WinSize) error { - return fmt.Errorf("cannot resize a stopped container") +func (s *stoppedState) transition(transition stateTransition) { + switch transition { + case deleted: + s.process.initState = &deletedState{} + default: + panic(fmt.Sprintf("invalid state transition %q to %q", s.name(), transition)) + } } -func (s *stoppedState) Start(ctx context.Context) error { - return fmt.Errorf("cannot start a stopped process.ss") +func (s *stoppedState) Start(context.Context) error { + return fmt.Errorf("cannot start a stopped container") } func (s *stoppedState) Delete(ctx context.Context) error { - if err := s.p.delete(ctx); err != nil { + if err := s.process.delete(ctx); err != nil { return err } - return s.transition("deleted") + s.transition(deleted) + return nil } -func (s *stoppedState) Kill(ctx context.Context, sig uint32, all bool) error { - return errdefs.ToGRPCf(errdefs.ErrNotFound, "process.ss %s not found", s.p.id) +func (s *stoppedState) Kill(_ context.Context, signal uint32, _ bool) error { + return handleStoppedKill(signal) } func (s *stoppedState) SetExited(status int) { - // no op + s.process.setExited(status) } -func (s *stoppedState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { +func (s *stoppedState) Exec(context.Context, string, *ExecConfig) (process.Process, error) { return nil, fmt.Errorf("cannot exec in a stopped state") } + +func (s *stoppedState) State(context.Context) (string, error) { + return "stopped", nil +} + +func (s *stoppedState) Stats(context.Context, string) (*runc.Stats, error) { + return nil, fmt.Errorf("cannot stat a stopped container") +} + +func handleStoppedKill(signal uint32) error { + switch unix.Signal(signal) { + case unix.SIGTERM, unix.SIGKILL: + // Container is already stopped, so everything inside the container has + // already been killed. + return nil + default: + return errdefs.ToGRPCf(errdefs.ErrNotFound, "process not found") + } +} diff --git a/pkg/shim/proc/proc.go b/pkg/shim/proc/proc.go index edba3fca5..89ad3f505 100644 --- a/pkg/shim/proc/proc.go +++ b/pkg/shim/proc/proc.go @@ -17,23 +17,5 @@ // the sandbox process running the container. package proc -import ( - "fmt" -) - // RunscRoot is the path to the root runsc state directory. const RunscRoot = "/run/containerd/runsc" - -func stateName(v interface{}) string { - switch v.(type) { - case *runningState, *execRunningState: - return "running" - case *createdState, *execCreatedState: - return "created" - case *deletedState: - return "deleted" - case *stoppedState: - return "stopped" - } - panic(fmt.Errorf("invalid state %v", v)) -} diff --git a/pkg/shim/runsc/runsc.go b/pkg/shim/runsc/runsc.go index ff0521d73..888cb0bcb 100644 --- a/pkg/shim/runsc/runsc.go +++ b/pkg/shim/runsc/runsc.go @@ -17,6 +17,7 @@ package runsc import ( + "bytes" "context" "encoding/json" "fmt" @@ -73,9 +74,9 @@ type Runsc struct { // List returns all containers created inside the provided runsc root directory. func (r *Runsc) List(context context.Context) ([]*runc.Container, error) { - data, err := cmdOutput(r.command(context, "list", "--format=json"), false) + data, stderr, err := cmdOutput(r.command(context, "list", "--format=json"), false) if err != nil { - return nil, err + return nil, fmt.Errorf("%w: %s", err, stderr) } var out []*runc.Container if err := json.Unmarshal(data, &out); err != nil { @@ -86,9 +87,9 @@ func (r *Runsc) List(context context.Context) ([]*runc.Container, error) { // State returns the state for the container provided by id. func (r *Runsc) State(context context.Context, id string) (*runc.Container, error) { - data, err := cmdOutput(r.command(context, "state", id), true) + data, stderr, err := cmdOutput(r.command(context, "state", id), false) if err != nil { - return nil, fmt.Errorf("%s: %s", err, data) + return nil, fmt.Errorf("%w: %s", err, stderr) } var c runc.Container if err := json.Unmarshal(data, &c); err != nil { @@ -142,9 +143,9 @@ func (r *Runsc) Create(context context.Context, id, bundle string, opts *CreateO } if cmd.Stdout == nil && cmd.Stderr == nil { - data, err := cmdOutput(cmd, true) + out, _, err := cmdOutput(cmd, true) if err != nil { - return fmt.Errorf("%s: %s", err, data) + return fmt.Errorf("%w: %s", err, out) } return nil } @@ -168,15 +169,15 @@ func (r *Runsc) Create(context context.Context, id, bundle string, opts *CreateO } func (r *Runsc) Pause(context context.Context, id string) error { - if _, err := cmdOutput(r.command(context, "pause", id), true); err != nil { - return fmt.Errorf("unable to pause: %w", err) + if out, _, err := cmdOutput(r.command(context, "pause", id), true); err != nil { + return fmt.Errorf("unable to pause: %w: %s", err, out) } return nil } func (r *Runsc) Resume(context context.Context, id string) error { - if _, err := cmdOutput(r.command(context, "resume", id), true); err != nil { - return fmt.Errorf("unable to resume: %w", err) + if out, _, err := cmdOutput(r.command(context, "resume", id), true); err != nil { + return fmt.Errorf("unable to resume: %w: %s", err, out) } return nil } @@ -189,9 +190,9 @@ func (r *Runsc) Start(context context.Context, id string, cio runc.IO) error { } if cmd.Stdout == nil && cmd.Stderr == nil { - data, err := cmdOutput(cmd, true) + out, _, err := cmdOutput(cmd, true) if err != nil { - return fmt.Errorf("%s: %s", err, data) + return fmt.Errorf("%w: %s", err, out) } return nil } @@ -221,12 +222,10 @@ type waitResult struct { } // Wait will wait for a running container, and return its exit status. -// -// TODO(random-liu): Add exec process support. func (r *Runsc) Wait(context context.Context, id string) (int, error) { - data, err := cmdOutput(r.command(context, "wait", id), true) + data, stderr, err := cmdOutput(r.command(context, "wait", id), false) if err != nil { - return 0, fmt.Errorf("%s: %s", err, data) + return 0, fmt.Errorf("%w: %s", err, stderr) } var res waitResult if err := json.Unmarshal(data, &res); err != nil { @@ -294,9 +293,9 @@ func (r *Runsc) Exec(context context.Context, id string, spec specs.Process, opt opts.Set(cmd) } if cmd.Stdout == nil && cmd.Stderr == nil { - data, err := cmdOutput(cmd, true) + out, _, err := cmdOutput(cmd, true) if err != nil { - return fmt.Errorf("%s: %s", err, data) + return fmt.Errorf("%w: %s", err, out) } return nil } @@ -391,20 +390,12 @@ func (r *Runsc) Kill(context context.Context, id string, sig int, opts *KillOpts // Stats return the stats for a container like cpu, memory, and I/O. func (r *Runsc) Stats(context context.Context, id string) (*runc.Stats, error) { cmd := r.command(context, "events", "--stats", id) - rd, err := cmd.StdoutPipe() - if err != nil { - return nil, err - } - ec, err := Monitor.Start(cmd) + data, stderr, err := cmdOutput(cmd, false) if err != nil { - return nil, err + return nil, fmt.Errorf("%w: %s", err, stderr) } - defer func() { - rd.Close() - Monitor.Wait(cmd, ec) - }() var e runc.Event - if err := json.NewDecoder(rd).Decode(&e); err != nil { + if err := json.Unmarshal(data, &e); err != nil { log.L.Debugf("Parsing events error: %v", err) return nil, err } @@ -459,9 +450,9 @@ func (r *Runsc) Events(context context.Context, id string, interval time.Duratio // Ps lists all the processes inside the container returning their pids. func (r *Runsc) Ps(context context.Context, id string) ([]int, error) { - data, err := cmdOutput(r.command(context, "ps", "--format", "json", id), true) + data, stderr, err := cmdOutput(r.command(context, "ps", "--format", "json", id), false) if err != nil { - return nil, fmt.Errorf("%s: %s", err, data) + return nil, fmt.Errorf("%w: %s", err, stderr) } var pids []int if err := json.Unmarshal(data, &pids); err != nil { @@ -472,9 +463,9 @@ func (r *Runsc) Ps(context context.Context, id string) ([]int, error) { // Top lists all the processes inside the container returning the full ps data. func (r *Runsc) Top(context context.Context, id string) (*runc.TopResults, error) { - data, err := cmdOutput(r.command(context, "ps", "--format", "table", id), true) + data, stderr, err := cmdOutput(r.command(context, "ps", "--format", "table", id), false) if err != nil { - return nil, fmt.Errorf("%s: %s", err, data) + return nil, fmt.Errorf("%w: %s", err, stderr) } topResults, err := runc.ParsePSOutput(data) @@ -517,9 +508,9 @@ func (r *Runsc) runOrError(cmd *exec.Cmd) error { } return err } - data, err := cmdOutput(cmd, true) + out, _, err := cmdOutput(cmd, true) if err != nil { - return fmt.Errorf("%s: %s", err, data) + return fmt.Errorf("%w: %s", err, out) } return nil } @@ -540,23 +531,29 @@ func (r *Runsc) command(context context.Context, args ...string) *exec.Cmd { return cmd } -func cmdOutput(cmd *exec.Cmd, combined bool) ([]byte, error) { - b := getBuf() - defer putBuf(b) +func cmdOutput(cmd *exec.Cmd, combined bool) ([]byte, []byte, error) { + stdout := getBuf() + defer putBuf(stdout) + cmd.Stdout = stdout + cmd.Stderr = stdout - cmd.Stdout = b - if combined { - cmd.Stderr = b + var stderr *bytes.Buffer + if !combined { + stderr = getBuf() + defer putBuf(stderr) + cmd.Stderr = stderr } ec, err := Monitor.Start(cmd) if err != nil { - return nil, err + return nil, nil, err } status, err := Monitor.Wait(cmd, ec) if err == nil && status != 0 { - err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0]) + err = fmt.Errorf("%q did not terminate sucessfully", cmd.Args[0]) } - - return b.Bytes(), err + if stderr == nil { + return stdout.Bytes(), nil, err + } + return stdout.Bytes(), stderr.Bytes(), err } diff --git a/pkg/shim/service.go b/pkg/shim/service.go index 1f9adcb65..ea9a1ae10 100644 --- a/pkg/shim/service.go +++ b/pkg/shim/service.go @@ -81,8 +81,6 @@ const ( // New returns a new shim service that can be used via GRPC. func New(ctx context.Context, id string, publisher shim.Publisher, cancel func()) (shim.Shim, error) { - log.L.Debugf("service.New, id: %s", id) - var opts shim.Opts if ctxOpts := ctx.Value(shim.OptsKey{}); ctxOpts != nil { opts = ctxOpts.(shim.Opts) @@ -304,8 +302,6 @@ func (s *service) Cleanup(ctx context.Context) (*taskAPI.DeleteResponse, error) // Create creates a new initial process and container with the underlying OCI // runtime. func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (*taskAPI.CreateTaskResponse, error) { - log.L.Debugf("Create, id: %s, bundle: %q", r.ID, r.Bundle) - s.mu.Lock() defer s.mu.Unlock() @@ -396,6 +392,9 @@ func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (*ta log.L.Debugf("stdout: %s", r.Stdout) log.L.Debugf("stderr: %s", r.Stderr) log.L.Debugf("***************************") + if log.L.Logger.IsLevelEnabled(logrus.DebugLevel) { + setDebugSigHandler() + } } // Save state before any action is taken to ensure Cleanup() will have all @@ -506,9 +505,6 @@ func (s *service) Delete(ctx context.Context, r *taskAPI.DeleteRequest) (*taskAP if err != nil { return nil, err } - if p == nil { - return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") - } if err := p.Delete(ctx); err != nil { return nil, err } @@ -580,10 +576,12 @@ func (s *service) State(ctx context.Context, r *taskAPI.StateRequest) (*taskAPI. p, err := s.getProcess(r.ExecID) if err != nil { + log.L.Debugf("State failed to find process: %v", err) return nil, err } st, err := p.Status(ctx) if err != nil { + log.L.Debugf("State failed: %v", err) return nil, err } status := task.StatusUnknown @@ -596,7 +594,7 @@ func (s *service) State(ctx context.Context, r *taskAPI.StateRequest) (*taskAPI. status = task.StatusStopped } sio := p.Stdio() - return &taskAPI.StateResponse{ + res := &taskAPI.StateResponse{ ID: p.ID(), Bundle: s.bundle, Pid: uint32(p.Pid()), @@ -607,7 +605,9 @@ func (s *service) State(ctx context.Context, r *taskAPI.StateRequest) (*taskAPI. Terminal: sio.Terminal, ExitStatus: uint32(p.ExitStatus()), ExitedAt: p.ExitedAt(), - }, nil + } + log.L.Debugf("State succeeded, response: %+v", res) + return res, nil } // Pause the container. @@ -646,12 +646,11 @@ func (s *service) Kill(ctx context.Context, r *taskAPI.KillRequest) (*types.Empt if err != nil { return nil, err } - if p == nil { - return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") - } if err := p.Kill(ctx, r.Signal, r.All); err != nil { + log.L.Debugf("Kill failed: %v", err) return nil, errdefs.ToGRPC(err) } + log.L.Debugf("Kill succeeded") return empty, nil } @@ -740,7 +739,7 @@ func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI. log.L.Debugf("Stats error, id: %s: container not created", r.ID) return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") } - stats, err := s.task.Runtime().Stats(ctx, s.id) + stats, err := s.task.Stats(ctx, s.id) if err != nil { log.L.Debugf("Stats error, id: %s: %v", r.ID, err) return nil, err @@ -821,17 +820,17 @@ func (s *service) Wait(ctx context.Context, r *taskAPI.WaitRequest) (*taskAPI.Wa p, err := s.getProcess(r.ExecID) if err != nil { + log.L.Debugf("Wait failed to find process: %v", err) return nil, err } - if p == nil { - return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") - } p.Wait() - return &taskAPI.WaitResponse{ + res := &taskAPI.WaitResponse{ ExitStatus: uint32(p.ExitStatus()), ExitedAt: p.ExitedAt(), - }, nil + } + log.L.Debugf("Wait succeeded, response: %+v", res) + return res, nil } func (s *service) processExits(ctx context.Context) { @@ -848,10 +847,7 @@ func (s *service) checkProcesses(ctx context.Context, e proc.Exit) { if ip, ok := p.(*proc.Init); ok { // Ensure all children are killed. log.L.Debugf("Container init process exited, killing all container processes") - if err := ip.KillAll(ctx); err != nil { - log.G(ctx).WithError(err).WithField("id", ip.ID()). - Error("failed to kill init's children") - } + ip.KillAll(ctx) } p.SetExited(e.Status) s.events <- &events.TaskExit{ @@ -909,9 +905,14 @@ func (s *service) forward(ctx context.Context, publisher shim.Publisher) { func (s *service) getProcess(execID string) (process.Process, error) { s.mu.Lock() defer s.mu.Unlock() + if execID == "" { + if s.task == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } return s.task, nil } + p := s.processes[execID] if p == nil { return nil, errdefs.ToGRPCf(errdefs.ErrNotFound, "process does not exist %s", execID) diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index 8b3a11c64..73791b456 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -1,5 +1,4 @@ load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template") package( default_visibility = ["//:sandbox"], @@ -8,45 +7,6 @@ package( exports_files(["LICENSE"]) -go_template( - name = "generic_atomicptr", - srcs = ["generic_atomicptr_unsafe.go"], - types = [ - "Value", - ], -) - -go_template( - name = "generic_atomicptrmap", - srcs = ["generic_atomicptrmap_unsafe.go"], - opt_consts = [ - "ShardOrder", - ], - opt_types = [ - "Hasher", - ], - types = [ - "Key", - "Value", - ], - deps = [ - ":sync", - "//pkg/gohacks", - ], -) - -go_template( - name = "generic_seqatomic", - srcs = ["generic_seqatomic_unsafe.go"], - types = [ - "Value", - ], - deps = [ - ":sync", - "//pkg/gohacks", - ], -) - go_library( name = "sync", srcs = [ diff --git a/pkg/sync/atomicptrtest/BUILD b/pkg/sync/atomicptr/BUILD index e97553254..a6a7f01ac 100644 --- a/pkg/sync/atomicptrtest/BUILD +++ b/pkg/sync/atomicptr/BUILD @@ -1,14 +1,23 @@ load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) +go_template( + name = "generic_atomicptr", + srcs = ["generic_atomicptr_unsafe.go"], + types = [ + "Value", + ], + visibility = ["//:sandbox"], +) + go_template_instance( name = "atomicptr_int", out = "atomicptr_int_unsafe.go", package = "atomicptr", suffix = "Int", - template = "//pkg/sync:generic_atomicptr", + template = ":generic_atomicptr", types = { "Value": "int", }, diff --git a/pkg/sync/atomicptrtest/atomicptr_test.go b/pkg/sync/atomicptr/atomicptr_test.go index 8fdc5112e..8fdc5112e 100644 --- a/pkg/sync/atomicptrtest/atomicptr_test.go +++ b/pkg/sync/atomicptr/atomicptr_test.go diff --git a/pkg/sync/generic_atomicptr_unsafe.go b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go index 82b6df18c..82b6df18c 100644 --- a/pkg/sync/generic_atomicptr_unsafe.go +++ b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go diff --git a/pkg/sync/atomicptrmaptest/BUILD b/pkg/sync/atomicptrmap/BUILD index 3f71ae97d..b0e218c79 100644 --- a/pkg/sync/atomicptrmaptest/BUILD +++ b/pkg/sync/atomicptrmap/BUILD @@ -1,17 +1,36 @@ load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package( default_visibility = ["//visibility:private"], licenses = ["notice"], ) +go_template( + name = "generic_atomicptrmap", + srcs = ["generic_atomicptrmap_unsafe.go"], + opt_consts = [ + "ShardOrder", + ], + opt_types = [ + "Hasher", + ], + types = [ + "Key", + "Value", + ], + deps = [ + "//pkg/gohacks", + "//pkg/sync", + ], +) + go_template_instance( name = "test_atomicptrmap", out = "test_atomicptrmap_unsafe.go", package = "atomicptrmap", prefix = "test", - template = "//pkg/sync:generic_atomicptrmap", + template = ":generic_atomicptrmap", types = { "Key": "int64", "Value": "testValue", @@ -27,7 +46,7 @@ go_template_instance( package = "atomicptrmap", prefix = "test", suffix = "Sharded", - template = "//pkg/sync:generic_atomicptrmap", + template = ":generic_atomicptrmap", types = { "Key": "int64", "Value": "testValue", diff --git a/pkg/sync/atomicptrmaptest/atomicptrmap.go b/pkg/sync/atomicptrmap/atomicptrmap.go index 867821ce9..867821ce9 100644 --- a/pkg/sync/atomicptrmaptest/atomicptrmap.go +++ b/pkg/sync/atomicptrmap/atomicptrmap.go diff --git a/pkg/sync/atomicptrmaptest/atomicptrmap_test.go b/pkg/sync/atomicptrmap/atomicptrmap_test.go index 75a9997ef..75a9997ef 100644 --- a/pkg/sync/atomicptrmaptest/atomicptrmap_test.go +++ b/pkg/sync/atomicptrmap/atomicptrmap_test.go diff --git a/pkg/sync/generic_atomicptrmap_unsafe.go b/pkg/sync/atomicptrmap/generic_atomicptrmap_unsafe.go index 3e98cb309..3e98cb309 100644 --- a/pkg/sync/generic_atomicptrmap_unsafe.go +++ b/pkg/sync/atomicptrmap/generic_atomicptrmap_unsafe.go diff --git a/pkg/sync/seqatomictest/BUILD b/pkg/sync/seqatomic/BUILD index 5f9164117..60f79ab54 100644 --- a/pkg/sync/seqatomictest/BUILD +++ b/pkg/sync/seqatomic/BUILD @@ -1,14 +1,27 @@ load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) +go_template( + name = "generic_seqatomic", + srcs = ["generic_seqatomic_unsafe.go"], + types = [ + "Value", + ], + visibility = ["//:sandbox"], + deps = [ + ":sync", + "//pkg/gohacks", + ], +) + go_template_instance( name = "seqatomic_int", out = "seqatomic_int_unsafe.go", package = "seqatomic", suffix = "Int", - template = "//pkg/sync:generic_seqatomic", + template = ":generic_seqatomic", types = { "Value": "int", }, diff --git a/pkg/sync/generic_seqatomic_unsafe.go b/pkg/sync/seqatomic/generic_seqatomic_unsafe.go index 9578c9c52..9578c9c52 100644 --- a/pkg/sync/generic_seqatomic_unsafe.go +++ b/pkg/sync/seqatomic/generic_seqatomic_unsafe.go diff --git a/pkg/sync/seqatomictest/seqatomic_test.go b/pkg/sync/seqatomic/seqatomic_test.go index 2c4568b07..2c4568b07 100644 --- a/pkg/sync/seqatomictest/seqatomic_test.go +++ b/pkg/sync/seqatomic/seqatomic_test.go diff --git a/pkg/syserr/BUILD b/pkg/syserr/BUILD index 9cc9e3bf2..7b3160309 100644 --- a/pkg/syserr/BUILD +++ b/pkg/syserr/BUILD @@ -11,7 +11,7 @@ go_library( ], visibility = ["//visibility:public"], deps = [ - "//pkg/abi/linux", + "//pkg/abi/linux/errno", "//pkg/syserror", "//pkg/tcpip", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index 90be24e15..eb44f1254 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -17,7 +17,7 @@ package syserr import ( "fmt" - "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/abi/linux/errno" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -25,33 +25,33 @@ import ( // Mapping for tcpip.Error types. var ( - ErrUnknownProtocol = New((&tcpip.ErrUnknownProtocol{}).String(), linux.EINVAL) - ErrUnknownNICID = New((&tcpip.ErrUnknownNICID{}).String(), linux.ENODEV) - ErrUnknownDevice = New((&tcpip.ErrUnknownDevice{}).String(), linux.ENODEV) - ErrUnknownProtocolOption = New((&tcpip.ErrUnknownProtocolOption{}).String(), linux.ENOPROTOOPT) - ErrDuplicateNICID = New((&tcpip.ErrDuplicateNICID{}).String(), linux.EEXIST) - ErrDuplicateAddress = New((&tcpip.ErrDuplicateAddress{}).String(), linux.EEXIST) - ErrAlreadyBound = New((&tcpip.ErrAlreadyBound{}).String(), linux.EINVAL) - ErrInvalidEndpointState = New((&tcpip.ErrInvalidEndpointState{}).String(), linux.EINVAL) - ErrAlreadyConnecting = New((&tcpip.ErrAlreadyConnecting{}).String(), linux.EALREADY) - ErrNoPortAvailable = New((&tcpip.ErrNoPortAvailable{}).String(), linux.EAGAIN) - ErrPortInUse = New((&tcpip.ErrPortInUse{}).String(), linux.EADDRINUSE) - ErrBadLocalAddress = New((&tcpip.ErrBadLocalAddress{}).String(), linux.EADDRNOTAVAIL) - ErrClosedForSend = New((&tcpip.ErrClosedForSend{}).String(), linux.EPIPE) - ErrClosedForReceive = New((&tcpip.ErrClosedForReceive{}).String(), linux.NOERRNO) - ErrTimeout = New((&tcpip.ErrTimeout{}).String(), linux.ETIMEDOUT) - ErrAborted = New((&tcpip.ErrAborted{}).String(), linux.EPIPE) - ErrConnectStarted = New((&tcpip.ErrConnectStarted{}).String(), linux.EINPROGRESS) - ErrDestinationRequired = New((&tcpip.ErrDestinationRequired{}).String(), linux.EDESTADDRREQ) - ErrNotSupported = New((&tcpip.ErrNotSupported{}).String(), linux.EOPNOTSUPP) - ErrQueueSizeNotSupported = New((&tcpip.ErrQueueSizeNotSupported{}).String(), linux.ENOTTY) - ErrNoSuchFile = New((&tcpip.ErrNoSuchFile{}).String(), linux.ENOENT) - ErrInvalidOptionValue = New((&tcpip.ErrInvalidOptionValue{}).String(), linux.EINVAL) - ErrBroadcastDisabled = New((&tcpip.ErrBroadcastDisabled{}).String(), linux.EACCES) - ErrNotPermittedNet = New((&tcpip.ErrNotPermitted{}).String(), linux.EPERM) - ErrBadBuffer = New((&tcpip.ErrBadBuffer{}).String(), linux.EFAULT) - ErrMalformedHeader = New((&tcpip.ErrMalformedHeader{}).String(), linux.EINVAL) - ErrInvalidPortRange = New((&tcpip.ErrInvalidPortRange{}).String(), linux.EINVAL) + ErrUnknownProtocol = New((&tcpip.ErrUnknownProtocol{}).String(), errno.EINVAL) + ErrUnknownNICID = New((&tcpip.ErrUnknownNICID{}).String(), errno.ENODEV) + ErrUnknownDevice = New((&tcpip.ErrUnknownDevice{}).String(), errno.ENODEV) + ErrUnknownProtocolOption = New((&tcpip.ErrUnknownProtocolOption{}).String(), errno.ENOPROTOOPT) + ErrDuplicateNICID = New((&tcpip.ErrDuplicateNICID{}).String(), errno.EEXIST) + ErrDuplicateAddress = New((&tcpip.ErrDuplicateAddress{}).String(), errno.EEXIST) + ErrAlreadyBound = New((&tcpip.ErrAlreadyBound{}).String(), errno.EINVAL) + ErrInvalidEndpointState = New((&tcpip.ErrInvalidEndpointState{}).String(), errno.EINVAL) + ErrAlreadyConnecting = New((&tcpip.ErrAlreadyConnecting{}).String(), errno.EALREADY) + ErrNoPortAvailable = New((&tcpip.ErrNoPortAvailable{}).String(), errno.EAGAIN) + ErrPortInUse = New((&tcpip.ErrPortInUse{}).String(), errno.EADDRINUSE) + ErrBadLocalAddress = New((&tcpip.ErrBadLocalAddress{}).String(), errno.EADDRNOTAVAIL) + ErrClosedForSend = New((&tcpip.ErrClosedForSend{}).String(), errno.EPIPE) + ErrClosedForReceive = New((&tcpip.ErrClosedForReceive{}).String(), errno.NOERRNO) + ErrTimeout = New((&tcpip.ErrTimeout{}).String(), errno.ETIMEDOUT) + ErrAborted = New((&tcpip.ErrAborted{}).String(), errno.EPIPE) + ErrConnectStarted = New((&tcpip.ErrConnectStarted{}).String(), errno.EINPROGRESS) + ErrDestinationRequired = New((&tcpip.ErrDestinationRequired{}).String(), errno.EDESTADDRREQ) + ErrNotSupported = New((&tcpip.ErrNotSupported{}).String(), errno.EOPNOTSUPP) + ErrQueueSizeNotSupported = New((&tcpip.ErrQueueSizeNotSupported{}).String(), errno.ENOTTY) + ErrNoSuchFile = New((&tcpip.ErrNoSuchFile{}).String(), errno.ENOENT) + ErrInvalidOptionValue = New((&tcpip.ErrInvalidOptionValue{}).String(), errno.EINVAL) + ErrBroadcastDisabled = New((&tcpip.ErrBroadcastDisabled{}).String(), errno.EACCES) + ErrNotPermittedNet = New((&tcpip.ErrNotPermitted{}).String(), errno.EPERM) + ErrBadBuffer = New((&tcpip.ErrBadBuffer{}).String(), errno.EFAULT) + ErrMalformedHeader = New((&tcpip.ErrMalformedHeader{}).String(), errno.EINVAL) + ErrInvalidPortRange = New((&tcpip.ErrInvalidPortRange{}).String(), errno.EINVAL) ) // TranslateNetstackError converts an error from the tcpip package to a sentry diff --git a/pkg/syserr/syserr.go b/pkg/syserr/syserr.go index d70521f32..fb77ac8bd 100644 --- a/pkg/syserr/syserr.go +++ b/pkg/syserr/syserr.go @@ -21,7 +21,7 @@ import ( "fmt" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/abi/linux/errno" "gvisor.dev/gvisor/pkg/syserror" ) @@ -31,26 +31,25 @@ type Error struct { message string // noTranslation indicates that this Error cannot be translated to a - // linux.Errno. + // errno.Errno. noTranslation bool - // errno is the linux.Errno this Error should be translated to. - errno linux.Errno + // errno is the errno.Errno this Error should be translated to. + errno errno.Errno } // New creates a new Error and adds a translation for it. // // New must only be called at init. -func New(message string, linuxTranslation linux.Errno) *Error { +func New(message string, linuxTranslation errno.Errno) *Error { err := &Error{message: message, errno: linuxTranslation} // TODO(b/34162363): Remove this. - errno := linuxTranslation - if errno < 0 || int(errno) >= len(linuxBackwardsTranslations) { - panic(fmt.Sprint("invalid errno: ", errno)) + if int(err.errno) >= len(linuxBackwardsTranslations) { + panic(fmt.Sprint("invalid errno: ", err.errno)) } - e := error(unix.Errno(errno)) + e := error(unix.Errno(err.errno)) // syserror.ErrWouldBlock gets translated to syserror.EWOULDBLOCK and // enables proper blocking semantics. This should temporary address the // class of blocking bugs that keep popping up with the current state of @@ -58,7 +57,7 @@ func New(message string, linuxTranslation linux.Errno) *Error { if e == syserror.EWOULDBLOCK { e = syserror.ErrWouldBlock } - linuxBackwardsTranslations[errno] = linuxBackwardsTranslation{err: e, ok: true} + linuxBackwardsTranslations[err.errno] = linuxBackwardsTranslation{err: e, ok: true} return err } @@ -69,7 +68,7 @@ func New(message string, linuxTranslation linux.Errno) *Error { // NewDynamic should only be used sparingly and not be used for static error // messages. Errors with static error messages should be declared with New as // global variables. -func NewDynamic(message string, linuxTranslation linux.Errno) *Error { +func NewDynamic(message string, linuxTranslation errno.Errno) *Error { return &Error{message: message, errno: linuxTranslation} } @@ -82,7 +81,7 @@ func NewWithoutTranslation(message string) *Error { return &Error{message: message, noTranslation: true} } -func newWithHost(message string, linuxTranslation linux.Errno, hostErrno unix.Errno) *Error { +func newWithHost(message string, linuxTranslation errno.Errno, hostErrno unix.Errno) *Error { e := New(message, linuxTranslation) addLinuxHostTranslation(hostErrno, e) return e @@ -114,19 +113,19 @@ func (e *Error) ToError() error { if e.noTranslation { panic(fmt.Sprintf("error %q does not support translation", e.message)) } - errno := int(e.errno) - if errno == linux.NOERRNO { + err := int(e.errno) + if err == errno.NOERRNO { return nil } - if errno <= 0 || errno >= len(linuxBackwardsTranslations) || !linuxBackwardsTranslations[errno].ok { - panic(fmt.Sprintf("unknown error %q (%d)", e.message, errno)) + if err >= len(linuxBackwardsTranslations) || !linuxBackwardsTranslations[err].ok { + panic(fmt.Sprintf("unknown error %q (%d)", e.message, err)) } - return linuxBackwardsTranslations[errno].err + return linuxBackwardsTranslations[err].err } // ToLinux converts the Error to a Linux ABI error that can be returned to the // application. -func (e *Error) ToLinux() linux.Errno { +func (e *Error) ToLinux() errno.Errno { if e.noTranslation { panic(fmt.Sprintf("No Linux ABI translation available for %q", e.message)) } @@ -138,137 +137,137 @@ func (e *Error) ToLinux() linux.Errno { // Some of the errors should be replaced with package specific errors and // others should be removed entirely. var ( - ErrNotPermitted = newWithHost("operation not permitted", linux.EPERM, unix.EPERM) - ErrNoFileOrDir = newWithHost("no such file or directory", linux.ENOENT, unix.ENOENT) - ErrNoProcess = newWithHost("no such process", linux.ESRCH, unix.ESRCH) - ErrInterrupted = newWithHost("interrupted system call", linux.EINTR, unix.EINTR) - ErrIO = newWithHost("I/O error", linux.EIO, unix.EIO) - ErrDeviceOrAddress = newWithHost("no such device or address", linux.ENXIO, unix.ENXIO) - ErrTooManyArgs = newWithHost("argument list too long", linux.E2BIG, unix.E2BIG) - ErrEcec = newWithHost("exec format error", linux.ENOEXEC, unix.ENOEXEC) - ErrBadFD = newWithHost("bad file number", linux.EBADF, unix.EBADF) - ErrNoChild = newWithHost("no child processes", linux.ECHILD, unix.ECHILD) - ErrTryAgain = newWithHost("try again", linux.EAGAIN, unix.EAGAIN) - ErrNoMemory = newWithHost("out of memory", linux.ENOMEM, unix.ENOMEM) - ErrPermissionDenied = newWithHost("permission denied", linux.EACCES, unix.EACCES) - ErrBadAddress = newWithHost("bad address", linux.EFAULT, unix.EFAULT) - ErrNotBlockDevice = newWithHost("block device required", linux.ENOTBLK, unix.ENOTBLK) - ErrBusy = newWithHost("device or resource busy", linux.EBUSY, unix.EBUSY) - ErrExists = newWithHost("file exists", linux.EEXIST, unix.EEXIST) - ErrCrossDeviceLink = newWithHost("cross-device link", linux.EXDEV, unix.EXDEV) - ErrNoDevice = newWithHost("no such device", linux.ENODEV, unix.ENODEV) - ErrNotDir = newWithHost("not a directory", linux.ENOTDIR, unix.ENOTDIR) - ErrIsDir = newWithHost("is a directory", linux.EISDIR, unix.EISDIR) - ErrInvalidArgument = newWithHost("invalid argument", linux.EINVAL, unix.EINVAL) - ErrFileTableOverflow = newWithHost("file table overflow", linux.ENFILE, unix.ENFILE) - ErrTooManyOpenFiles = newWithHost("too many open files", linux.EMFILE, unix.EMFILE) - ErrNotTTY = newWithHost("not a typewriter", linux.ENOTTY, unix.ENOTTY) - ErrTestFileBusy = newWithHost("text file busy", linux.ETXTBSY, unix.ETXTBSY) - ErrFileTooBig = newWithHost("file too large", linux.EFBIG, unix.EFBIG) - ErrNoSpace = newWithHost("no space left on device", linux.ENOSPC, unix.ENOSPC) - ErrIllegalSeek = newWithHost("illegal seek", linux.ESPIPE, unix.ESPIPE) - ErrReadOnlyFS = newWithHost("read-only file system", linux.EROFS, unix.EROFS) - ErrTooManyLinks = newWithHost("too many links", linux.EMLINK, unix.EMLINK) - ErrBrokenPipe = newWithHost("broken pipe", linux.EPIPE, unix.EPIPE) - ErrDomain = newWithHost("math argument out of domain of func", linux.EDOM, unix.EDOM) - ErrRange = newWithHost("math result not representable", linux.ERANGE, unix.ERANGE) - ErrDeadlock = newWithHost("resource deadlock would occur", linux.EDEADLOCK, unix.EDEADLOCK) - ErrNameTooLong = newWithHost("file name too long", linux.ENAMETOOLONG, unix.ENAMETOOLONG) - ErrNoLocksAvailable = newWithHost("no record locks available", linux.ENOLCK, unix.ENOLCK) - ErrInvalidSyscall = newWithHost("invalid system call number", linux.ENOSYS, unix.ENOSYS) - ErrDirNotEmpty = newWithHost("directory not empty", linux.ENOTEMPTY, unix.ENOTEMPTY) - ErrLinkLoop = newWithHost("too many symbolic links encountered", linux.ELOOP, unix.ELOOP) - ErrNoMessage = newWithHost("no message of desired type", linux.ENOMSG, unix.ENOMSG) - ErrIdentifierRemoved = newWithHost("identifier removed", linux.EIDRM, unix.EIDRM) - ErrChannelOutOfRange = newWithHost("channel number out of range", linux.ECHRNG, unix.ECHRNG) - ErrLevelTwoNotSynced = newWithHost("level 2 not synchronized", linux.EL2NSYNC, unix.EL2NSYNC) - ErrLevelThreeHalted = newWithHost("level 3 halted", linux.EL3HLT, unix.EL3HLT) - ErrLevelThreeReset = newWithHost("level 3 reset", linux.EL3RST, unix.EL3RST) - ErrLinkNumberOutOfRange = newWithHost("link number out of range", linux.ELNRNG, unix.ELNRNG) - ErrProtocolDriverNotAttached = newWithHost("protocol driver not attached", linux.EUNATCH, unix.EUNATCH) - ErrNoCSIAvailable = newWithHost("no CSI structure available", linux.ENOCSI, unix.ENOCSI) - ErrLevelTwoHalted = newWithHost("level 2 halted", linux.EL2HLT, unix.EL2HLT) - ErrInvalidExchange = newWithHost("invalid exchange", linux.EBADE, unix.EBADE) - ErrInvalidRequestDescriptor = newWithHost("invalid request descriptor", linux.EBADR, unix.EBADR) - ErrExchangeFull = newWithHost("exchange full", linux.EXFULL, unix.EXFULL) - ErrNoAnode = newWithHost("no anode", linux.ENOANO, unix.ENOANO) - ErrInvalidRequestCode = newWithHost("invalid request code", linux.EBADRQC, unix.EBADRQC) - ErrInvalidSlot = newWithHost("invalid slot", linux.EBADSLT, unix.EBADSLT) - ErrBadFontFile = newWithHost("bad font file format", linux.EBFONT, unix.EBFONT) - ErrNotStream = newWithHost("device not a stream", linux.ENOSTR, unix.ENOSTR) - ErrNoDataAvailable = newWithHost("no data available", linux.ENODATA, unix.ENODATA) - ErrTimerExpired = newWithHost("timer expired", linux.ETIME, unix.ETIME) - ErrStreamsResourceDepleted = newWithHost("out of streams resources", linux.ENOSR, unix.ENOSR) - ErrMachineNotOnNetwork = newWithHost("machine is not on the network", linux.ENONET, unix.ENONET) - ErrPackageNotInstalled = newWithHost("package not installed", linux.ENOPKG, unix.ENOPKG) - ErrIsRemote = newWithHost("object is remote", linux.EREMOTE, unix.EREMOTE) - ErrNoLink = newWithHost("link has been severed", linux.ENOLINK, unix.ENOLINK) - ErrAdvertise = newWithHost("advertise error", linux.EADV, unix.EADV) - ErrSRMount = newWithHost("srmount error", linux.ESRMNT, unix.ESRMNT) - ErrSendCommunication = newWithHost("communication error on send", linux.ECOMM, unix.ECOMM) - ErrProtocol = newWithHost("protocol error", linux.EPROTO, unix.EPROTO) - ErrMultihopAttempted = newWithHost("multihop attempted", linux.EMULTIHOP, unix.EMULTIHOP) - ErrRFS = newWithHost("RFS specific error", linux.EDOTDOT, unix.EDOTDOT) - ErrInvalidDataMessage = newWithHost("not a data message", linux.EBADMSG, unix.EBADMSG) - ErrOverflow = newWithHost("value too large for defined data type", linux.EOVERFLOW, unix.EOVERFLOW) - ErrNetworkNameNotUnique = newWithHost("name not unique on network", linux.ENOTUNIQ, unix.ENOTUNIQ) - ErrFDInBadState = newWithHost("file descriptor in bad state", linux.EBADFD, unix.EBADFD) - ErrRemoteAddressChanged = newWithHost("remote address changed", linux.EREMCHG, unix.EREMCHG) - ErrSharedLibraryInaccessible = newWithHost("can not access a needed shared library", linux.ELIBACC, unix.ELIBACC) - ErrCorruptedSharedLibrary = newWithHost("accessing a corrupted shared library", linux.ELIBBAD, unix.ELIBBAD) - ErrLibSectionCorrupted = newWithHost(".lib section in a.out corrupted", linux.ELIBSCN, unix.ELIBSCN) - ErrTooManySharedLibraries = newWithHost("attempting to link in too many shared libraries", linux.ELIBMAX, unix.ELIBMAX) - ErrSharedLibraryExeced = newWithHost("cannot exec a shared library directly", linux.ELIBEXEC, unix.ELIBEXEC) - ErrIllegalByteSequence = newWithHost("illegal byte sequence", linux.EILSEQ, unix.EILSEQ) - ErrShouldRestart = newWithHost("interrupted system call should be restarted", linux.ERESTART, unix.ERESTART) - ErrStreamPipe = newWithHost("streams pipe error", linux.ESTRPIPE, unix.ESTRPIPE) - ErrTooManyUsers = newWithHost("too many users", linux.EUSERS, unix.EUSERS) - ErrNotASocket = newWithHost("socket operation on non-socket", linux.ENOTSOCK, unix.ENOTSOCK) - ErrDestinationAddressRequired = newWithHost("destination address required", linux.EDESTADDRREQ, unix.EDESTADDRREQ) - ErrMessageTooLong = newWithHost("message too long", linux.EMSGSIZE, unix.EMSGSIZE) - ErrWrongProtocolForSocket = newWithHost("protocol wrong type for socket", linux.EPROTOTYPE, unix.EPROTOTYPE) - ErrProtocolNotAvailable = newWithHost("protocol not available", linux.ENOPROTOOPT, unix.ENOPROTOOPT) - ErrProtocolNotSupported = newWithHost("protocol not supported", linux.EPROTONOSUPPORT, unix.EPROTONOSUPPORT) - ErrSocketNotSupported = newWithHost("socket type not supported", linux.ESOCKTNOSUPPORT, unix.ESOCKTNOSUPPORT) - ErrEndpointOperation = newWithHost("operation not supported on transport endpoint", linux.EOPNOTSUPP, unix.EOPNOTSUPP) - ErrProtocolFamilyNotSupported = newWithHost("protocol family not supported", linux.EPFNOSUPPORT, unix.EPFNOSUPPORT) - ErrAddressFamilyNotSupported = newWithHost("address family not supported by protocol", linux.EAFNOSUPPORT, unix.EAFNOSUPPORT) - ErrAddressInUse = newWithHost("address already in use", linux.EADDRINUSE, unix.EADDRINUSE) - ErrAddressNotAvailable = newWithHost("cannot assign requested address", linux.EADDRNOTAVAIL, unix.EADDRNOTAVAIL) - ErrNetworkDown = newWithHost("network is down", linux.ENETDOWN, unix.ENETDOWN) - ErrNetworkUnreachable = newWithHost("network is unreachable", linux.ENETUNREACH, unix.ENETUNREACH) - ErrNetworkReset = newWithHost("network dropped connection because of reset", linux.ENETRESET, unix.ENETRESET) - ErrConnectionAborted = newWithHost("software caused connection abort", linux.ECONNABORTED, unix.ECONNABORTED) - ErrConnectionReset = newWithHost("connection reset by peer", linux.ECONNRESET, unix.ECONNRESET) - ErrNoBufferSpace = newWithHost("no buffer space available", linux.ENOBUFS, unix.ENOBUFS) - ErrAlreadyConnected = newWithHost("transport endpoint is already connected", linux.EISCONN, unix.EISCONN) - ErrNotConnected = newWithHost("transport endpoint is not connected", linux.ENOTCONN, unix.ENOTCONN) - ErrShutdown = newWithHost("cannot send after transport endpoint shutdown", linux.ESHUTDOWN, unix.ESHUTDOWN) - ErrTooManyRefs = newWithHost("too many references: cannot splice", linux.ETOOMANYREFS, unix.ETOOMANYREFS) - ErrTimedOut = newWithHost("connection timed out", linux.ETIMEDOUT, unix.ETIMEDOUT) - ErrConnectionRefused = newWithHost("connection refused", linux.ECONNREFUSED, unix.ECONNREFUSED) - ErrHostDown = newWithHost("host is down", linux.EHOSTDOWN, unix.EHOSTDOWN) - ErrNoRoute = newWithHost("no route to host", linux.EHOSTUNREACH, unix.EHOSTUNREACH) - ErrAlreadyInProgress = newWithHost("operation already in progress", linux.EALREADY, unix.EALREADY) - ErrInProgress = newWithHost("operation now in progress", linux.EINPROGRESS, unix.EINPROGRESS) - ErrStaleFileHandle = newWithHost("stale file handle", linux.ESTALE, unix.ESTALE) - ErrStructureNeedsCleaning = newWithHost("structure needs cleaning", linux.EUCLEAN, unix.EUCLEAN) - ErrIsNamedFile = newWithHost("is a named type file", linux.ENOTNAM, unix.ENOTNAM) - ErrRemoteIO = newWithHost("remote I/O error", linux.EREMOTEIO, unix.EREMOTEIO) - ErrQuotaExceeded = newWithHost("quota exceeded", linux.EDQUOT, unix.EDQUOT) - ErrNoMedium = newWithHost("no medium found", linux.ENOMEDIUM, unix.ENOMEDIUM) - ErrWrongMediumType = newWithHost("wrong medium type", linux.EMEDIUMTYPE, unix.EMEDIUMTYPE) - ErrCanceled = newWithHost("operation canceled", linux.ECANCELED, unix.ECANCELED) - ErrNoKey = newWithHost("required key not available", linux.ENOKEY, unix.ENOKEY) - ErrKeyExpired = newWithHost("key has expired", linux.EKEYEXPIRED, unix.EKEYEXPIRED) - ErrKeyRevoked = newWithHost("key has been revoked", linux.EKEYREVOKED, unix.EKEYREVOKED) - ErrKeyRejected = newWithHost("key was rejected by service", linux.EKEYREJECTED, unix.EKEYREJECTED) - ErrOwnerDied = newWithHost("owner died", linux.EOWNERDEAD, unix.EOWNERDEAD) - ErrNotRecoverable = newWithHost("state not recoverable", linux.ENOTRECOVERABLE, unix.ENOTRECOVERABLE) + ErrNotPermitted = newWithHost("operation not permitted", errno.EPERM, unix.EPERM) + ErrNoFileOrDir = newWithHost("no such file or directory", errno.ENOENT, unix.ENOENT) + ErrNoProcess = newWithHost("no such process", errno.ESRCH, unix.ESRCH) + ErrInterrupted = newWithHost("interrupted system call", errno.EINTR, unix.EINTR) + ErrIO = newWithHost("I/O error", errno.EIO, unix.EIO) + ErrDeviceOrAddress = newWithHost("no such device or address", errno.ENXIO, unix.ENXIO) + ErrTooManyArgs = newWithHost("argument list too long", errno.E2BIG, unix.E2BIG) + ErrEcec = newWithHost("exec format error", errno.ENOEXEC, unix.ENOEXEC) + ErrBadFD = newWithHost("bad file number", errno.EBADF, unix.EBADF) + ErrNoChild = newWithHost("no child processes", errno.ECHILD, unix.ECHILD) + ErrTryAgain = newWithHost("try again", errno.EAGAIN, unix.EAGAIN) + ErrNoMemory = newWithHost("out of memory", errno.ENOMEM, unix.ENOMEM) + ErrPermissionDenied = newWithHost("permission denied", errno.EACCES, unix.EACCES) + ErrBadAddress = newWithHost("bad address", errno.EFAULT, unix.EFAULT) + ErrNotBlockDevice = newWithHost("block device required", errno.ENOTBLK, unix.ENOTBLK) + ErrBusy = newWithHost("device or resource busy", errno.EBUSY, unix.EBUSY) + ErrExists = newWithHost("file exists", errno.EEXIST, unix.EEXIST) + ErrCrossDeviceLink = newWithHost("cross-device link", errno.EXDEV, unix.EXDEV) + ErrNoDevice = newWithHost("no such device", errno.ENODEV, unix.ENODEV) + ErrNotDir = newWithHost("not a directory", errno.ENOTDIR, unix.ENOTDIR) + ErrIsDir = newWithHost("is a directory", errno.EISDIR, unix.EISDIR) + ErrInvalidArgument = newWithHost("invalid argument", errno.EINVAL, unix.EINVAL) + ErrFileTableOverflow = newWithHost("file table overflow", errno.ENFILE, unix.ENFILE) + ErrTooManyOpenFiles = newWithHost("too many open files", errno.EMFILE, unix.EMFILE) + ErrNotTTY = newWithHost("not a typewriter", errno.ENOTTY, unix.ENOTTY) + ErrTestFileBusy = newWithHost("text file busy", errno.ETXTBSY, unix.ETXTBSY) + ErrFileTooBig = newWithHost("file too large", errno.EFBIG, unix.EFBIG) + ErrNoSpace = newWithHost("no space left on device", errno.ENOSPC, unix.ENOSPC) + ErrIllegalSeek = newWithHost("illegal seek", errno.ESPIPE, unix.ESPIPE) + ErrReadOnlyFS = newWithHost("read-only file system", errno.EROFS, unix.EROFS) + ErrTooManyLinks = newWithHost("too many links", errno.EMLINK, unix.EMLINK) + ErrBrokenPipe = newWithHost("broken pipe", errno.EPIPE, unix.EPIPE) + ErrDomain = newWithHost("math argument out of domain of func", errno.EDOM, unix.EDOM) + ErrRange = newWithHost("math result not representable", errno.ERANGE, unix.ERANGE) + ErrDeadlock = newWithHost("resource deadlock would occur", errno.EDEADLOCK, unix.EDEADLOCK) + ErrNameTooLong = newWithHost("file name too long", errno.ENAMETOOLONG, unix.ENAMETOOLONG) + ErrNoLocksAvailable = newWithHost("no record locks available", errno.ENOLCK, unix.ENOLCK) + ErrInvalidSyscall = newWithHost("invalid system call number", errno.ENOSYS, unix.ENOSYS) + ErrDirNotEmpty = newWithHost("directory not empty", errno.ENOTEMPTY, unix.ENOTEMPTY) + ErrLinkLoop = newWithHost("too many symbolic links encountered", errno.ELOOP, unix.ELOOP) + ErrNoMessage = newWithHost("no message of desired type", errno.ENOMSG, unix.ENOMSG) + ErrIdentifierRemoved = newWithHost("identifier removed", errno.EIDRM, unix.EIDRM) + ErrChannelOutOfRange = newWithHost("channel number out of range", errno.ECHRNG, unix.ECHRNG) + ErrLevelTwoNotSynced = newWithHost("level 2 not synchronized", errno.EL2NSYNC, unix.EL2NSYNC) + ErrLevelThreeHalted = newWithHost("level 3 halted", errno.EL3HLT, unix.EL3HLT) + ErrLevelThreeReset = newWithHost("level 3 reset", errno.EL3RST, unix.EL3RST) + ErrLinkNumberOutOfRange = newWithHost("link number out of range", errno.ELNRNG, unix.ELNRNG) + ErrProtocolDriverNotAttached = newWithHost("protocol driver not attached", errno.EUNATCH, unix.EUNATCH) + ErrNoCSIAvailable = newWithHost("no CSI structure available", errno.ENOCSI, unix.ENOCSI) + ErrLevelTwoHalted = newWithHost("level 2 halted", errno.EL2HLT, unix.EL2HLT) + ErrInvalidExchange = newWithHost("invalid exchange", errno.EBADE, unix.EBADE) + ErrInvalidRequestDescriptor = newWithHost("invalid request descriptor", errno.EBADR, unix.EBADR) + ErrExchangeFull = newWithHost("exchange full", errno.EXFULL, unix.EXFULL) + ErrNoAnode = newWithHost("no anode", errno.ENOANO, unix.ENOANO) + ErrInvalidRequestCode = newWithHost("invalid request code", errno.EBADRQC, unix.EBADRQC) + ErrInvalidSlot = newWithHost("invalid slot", errno.EBADSLT, unix.EBADSLT) + ErrBadFontFile = newWithHost("bad font file format", errno.EBFONT, unix.EBFONT) + ErrNotStream = newWithHost("device not a stream", errno.ENOSTR, unix.ENOSTR) + ErrNoDataAvailable = newWithHost("no data available", errno.ENODATA, unix.ENODATA) + ErrTimerExpired = newWithHost("timer expired", errno.ETIME, unix.ETIME) + ErrStreamsResourceDepleted = newWithHost("out of streams resources", errno.ENOSR, unix.ENOSR) + ErrMachineNotOnNetwork = newWithHost("machine is not on the network", errno.ENONET, unix.ENONET) + ErrPackageNotInstalled = newWithHost("package not installed", errno.ENOPKG, unix.ENOPKG) + ErrIsRemote = newWithHost("object is remote", errno.EREMOTE, unix.EREMOTE) + ErrNoLink = newWithHost("link has been severed", errno.ENOLINK, unix.ENOLINK) + ErrAdvertise = newWithHost("advertise error", errno.EADV, unix.EADV) + ErrSRMount = newWithHost("srmount error", errno.ESRMNT, unix.ESRMNT) + ErrSendCommunication = newWithHost("communication error on send", errno.ECOMM, unix.ECOMM) + ErrProtocol = newWithHost("protocol error", errno.EPROTO, unix.EPROTO) + ErrMultihopAttempted = newWithHost("multihop attempted", errno.EMULTIHOP, unix.EMULTIHOP) + ErrRFS = newWithHost("RFS specific error", errno.EDOTDOT, unix.EDOTDOT) + ErrInvalidDataMessage = newWithHost("not a data message", errno.EBADMSG, unix.EBADMSG) + ErrOverflow = newWithHost("value too large for defined data type", errno.EOVERFLOW, unix.EOVERFLOW) + ErrNetworkNameNotUnique = newWithHost("name not unique on network", errno.ENOTUNIQ, unix.ENOTUNIQ) + ErrFDInBadState = newWithHost("file descriptor in bad state", errno.EBADFD, unix.EBADFD) + ErrRemoteAddressChanged = newWithHost("remote address changed", errno.EREMCHG, unix.EREMCHG) + ErrSharedLibraryInaccessible = newWithHost("can not access a needed shared library", errno.ELIBACC, unix.ELIBACC) + ErrCorruptedSharedLibrary = newWithHost("accessing a corrupted shared library", errno.ELIBBAD, unix.ELIBBAD) + ErrLibSectionCorrupted = newWithHost(".lib section in a.out corrupted", errno.ELIBSCN, unix.ELIBSCN) + ErrTooManySharedLibraries = newWithHost("attempting to link in too many shared libraries", errno.ELIBMAX, unix.ELIBMAX) + ErrSharedLibraryExeced = newWithHost("cannot exec a shared library directly", errno.ELIBEXEC, unix.ELIBEXEC) + ErrIllegalByteSequence = newWithHost("illegal byte sequence", errno.EILSEQ, unix.EILSEQ) + ErrShouldRestart = newWithHost("interrupted system call should be restarted", errno.ERESTART, unix.ERESTART) + ErrStreamPipe = newWithHost("streams pipe error", errno.ESTRPIPE, unix.ESTRPIPE) + ErrTooManyUsers = newWithHost("too many users", errno.EUSERS, unix.EUSERS) + ErrNotASocket = newWithHost("socket operation on non-socket", errno.ENOTSOCK, unix.ENOTSOCK) + ErrDestinationAddressRequired = newWithHost("destination address required", errno.EDESTADDRREQ, unix.EDESTADDRREQ) + ErrMessageTooLong = newWithHost("message too long", errno.EMSGSIZE, unix.EMSGSIZE) + ErrWrongProtocolForSocket = newWithHost("protocol wrong type for socket", errno.EPROTOTYPE, unix.EPROTOTYPE) + ErrProtocolNotAvailable = newWithHost("protocol not available", errno.ENOPROTOOPT, unix.ENOPROTOOPT) + ErrProtocolNotSupported = newWithHost("protocol not supported", errno.EPROTONOSUPPORT, unix.EPROTONOSUPPORT) + ErrSocketNotSupported = newWithHost("socket type not supported", errno.ESOCKTNOSUPPORT, unix.ESOCKTNOSUPPORT) + ErrEndpointOperation = newWithHost("operation not supported on transport endpoint", errno.EOPNOTSUPP, unix.EOPNOTSUPP) + ErrProtocolFamilyNotSupported = newWithHost("protocol family not supported", errno.EPFNOSUPPORT, unix.EPFNOSUPPORT) + ErrAddressFamilyNotSupported = newWithHost("address family not supported by protocol", errno.EAFNOSUPPORT, unix.EAFNOSUPPORT) + ErrAddressInUse = newWithHost("address already in use", errno.EADDRINUSE, unix.EADDRINUSE) + ErrAddressNotAvailable = newWithHost("cannot assign requested address", errno.EADDRNOTAVAIL, unix.EADDRNOTAVAIL) + ErrNetworkDown = newWithHost("network is down", errno.ENETDOWN, unix.ENETDOWN) + ErrNetworkUnreachable = newWithHost("network is unreachable", errno.ENETUNREACH, unix.ENETUNREACH) + ErrNetworkReset = newWithHost("network dropped connection because of reset", errno.ENETRESET, unix.ENETRESET) + ErrConnectionAborted = newWithHost("software caused connection abort", errno.ECONNABORTED, unix.ECONNABORTED) + ErrConnectionReset = newWithHost("connection reset by peer", errno.ECONNRESET, unix.ECONNRESET) + ErrNoBufferSpace = newWithHost("no buffer space available", errno.ENOBUFS, unix.ENOBUFS) + ErrAlreadyConnected = newWithHost("transport endpoint is already connected", errno.EISCONN, unix.EISCONN) + ErrNotConnected = newWithHost("transport endpoint is not connected", errno.ENOTCONN, unix.ENOTCONN) + ErrShutdown = newWithHost("cannot send after transport endpoint shutdown", errno.ESHUTDOWN, unix.ESHUTDOWN) + ErrTooManyRefs = newWithHost("too many references: cannot splice", errno.ETOOMANYREFS, unix.ETOOMANYREFS) + ErrTimedOut = newWithHost("connection timed out", errno.ETIMEDOUT, unix.ETIMEDOUT) + ErrConnectionRefused = newWithHost("connection refused", errno.ECONNREFUSED, unix.ECONNREFUSED) + ErrHostDown = newWithHost("host is down", errno.EHOSTDOWN, unix.EHOSTDOWN) + ErrNoRoute = newWithHost("no route to host", errno.EHOSTUNREACH, unix.EHOSTUNREACH) + ErrAlreadyInProgress = newWithHost("operation already in progress", errno.EALREADY, unix.EALREADY) + ErrInProgress = newWithHost("operation now in progress", errno.EINPROGRESS, unix.EINPROGRESS) + ErrStaleFileHandle = newWithHost("stale file handle", errno.ESTALE, unix.ESTALE) + ErrStructureNeedsCleaning = newWithHost("structure needs cleaning", errno.EUCLEAN, unix.EUCLEAN) + ErrIsNamedFile = newWithHost("is a named type file", errno.ENOTNAM, unix.ENOTNAM) + ErrRemoteIO = newWithHost("remote I/O error", errno.EREMOTEIO, unix.EREMOTEIO) + ErrQuotaExceeded = newWithHost("quota exceeded", errno.EDQUOT, unix.EDQUOT) + ErrNoMedium = newWithHost("no medium found", errno.ENOMEDIUM, unix.ENOMEDIUM) + ErrWrongMediumType = newWithHost("wrong medium type", errno.EMEDIUMTYPE, unix.EMEDIUMTYPE) + ErrCanceled = newWithHost("operation canceled", errno.ECANCELED, unix.ECANCELED) + ErrNoKey = newWithHost("required key not available", errno.ENOKEY, unix.ENOKEY) + ErrKeyExpired = newWithHost("key has expired", errno.EKEYEXPIRED, unix.EKEYEXPIRED) + ErrKeyRevoked = newWithHost("key has been revoked", errno.EKEYREVOKED, unix.EKEYREVOKED) + ErrKeyRejected = newWithHost("key was rejected by service", errno.EKEYREJECTED, unix.EKEYREJECTED) + ErrOwnerDied = newWithHost("owner died", errno.EOWNERDEAD, unix.EOWNERDEAD) + ErrNotRecoverable = newWithHost("state not recoverable", errno.ENOTRECOVERABLE, unix.ENOTRECOVERABLE) // ErrWouldBlock translates to EWOULDBLOCK which is the same as EAGAIN // on Linux. - ErrWouldBlock = New("operation would block", linux.EWOULDBLOCK) + ErrWouldBlock = New("operation would block", errno.EWOULDBLOCK) ) // FromError converts a generic error to an *Error. diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index d6cad3a94..b1f39e6e6 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -148,15 +148,10 @@ const ( // NDP option. That is, the length field for NDP options is in units of // 8 octets, as per RFC 4861 section 4.6. lengthByteUnits = 8 -) -var ( // NDPInfiniteLifetime is a value that represents infinity for the // 4-byte lifetime fields found in various NDP options. Its value is // (2^32 - 1)s = 4294967295s. - // - // This is a variable instead of a constant so that tests can change - // this value to a smaller value. It should only be modified by tests. NDPInfiniteLifetime = time.Second * math.MaxUint32 ) diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go index bf7610863..7e2f0c797 100644 --- a/pkg/tcpip/header/ndp_router_advert.go +++ b/pkg/tcpip/header/ndp_router_advert.go @@ -19,12 +19,72 @@ import ( "time" ) +// NDPRoutePreference is the preference values for default routers or +// more-specific routes. +// +// As per RFC 4191 section 2.1, +// +// Default router preferences and preferences for more-specific routes +// are encoded the same way. +// +// Preference values are encoded as a two-bit signed integer, as +// follows: +// +// 01 High +// 00 Medium (default) +// 11 Low +// 10 Reserved - MUST NOT be sent +// +// Note that implementations can treat the value as a two-bit signed +// integer. +// +// Having just three values reinforces that they are not metrics and +// more values do not appear to be necessary for reasonable scenarios. +type NDPRoutePreference uint8 + +const ( + // HighRoutePreference indicates a high preference, as per + // RFC 4191 section 2.1. + HighRoutePreference NDPRoutePreference = 0b01 + + // MediumRoutePreference indicates a medium preference, as per + // RFC 4191 section 2.1. + // + // This is the default preference value. + MediumRoutePreference = 0b00 + + // LowRoutePreference indicates a low preference, as per + // RFC 4191 section 2.1. + LowRoutePreference = 0b11 + + // ReservedRoutePreference is a reserved preference value, as per + // RFC 4191 section 2.1. + // + // It MUST NOT be sent. + ReservedRoutePreference = 0b10 +) + // NDPRouterAdvert is an NDP Router Advertisement message. It will only contain // the body of an ICMPv6 packet. // -// See RFC 4861 section 4.2 for more details. +// See RFC 4861 section 4.2 and RFC 4191 section 2.2 for more details. type NDPRouterAdvert []byte +// As per RFC 4191 section 2.2, +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type | Code | Checksum | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cur Hop Limit |M|O|H|Prf|Resvd| Router Lifetime | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Reachable Time | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Retrans Timer | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Options ... +// +-+-+-+-+-+-+-+-+-+-+-+- const ( // NDPRAMinimumSize is the minimum size of a valid NDP Router // Advertisement message (body of an ICMPv6 packet). @@ -47,6 +107,14 @@ const ( // within the bit-field/flags byte of an NDPRouterAdvert. ndpRAOtherConfFlagMask = (1 << 6) + // ndpDefaultRouterPreferenceShift is the shift of the Prf (Default Router + // Preference) field within the flags byte of an NDPRouterAdvert. + ndpDefaultRouterPreferenceShift = 3 + + // ndpDefaultRouterPreferenceMask is the mask of the Prf (Default Router + // Preference) field within the flags byte of an NDPRouterAdvert. + ndpDefaultRouterPreferenceMask = (0b11 << ndpDefaultRouterPreferenceShift) + // ndpRARouterLifetimeOffset is the start of the 2-byte Router Lifetime // field within an NDPRouterAdvert. ndpRARouterLifetimeOffset = 2 @@ -80,6 +148,11 @@ func (b NDPRouterAdvert) OtherConfFlag() bool { return b[ndpRAFlagsOffset]&ndpRAOtherConfFlagMask != 0 } +// DefaultRouterPreference returns the Default Router Preference field. +func (b NDPRouterAdvert) DefaultRouterPreference() NDPRoutePreference { + return NDPRoutePreference((b[ndpRAFlagsOffset] & ndpDefaultRouterPreferenceMask) >> ndpDefaultRouterPreferenceShift) +} + // RouterLifetime returns the lifetime associated with the default router. A // value of 0 means the source of the Router Advertisement is not a default // router and SHOULD NOT appear on the default router list. Note, a value of 0 diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index 1b5093e58..8fd1f7d13 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -126,36 +126,83 @@ func TestNDPNeighborAdvert(t *testing.T) { } func TestNDPRouterAdvert(t *testing.T) { - b := []byte{ - 64, 128, 1, 2, - 3, 4, 5, 6, - 7, 8, 9, 10, + tests := []struct { + hopLimit uint8 + managedFlag, otherConfFlag bool + prf NDPRoutePreference + routerLifetimeS uint16 + reachableTimeMS, retransTimerMS uint32 + }{ + { + hopLimit: 1, + managedFlag: false, + otherConfFlag: true, + prf: HighRoutePreference, + routerLifetimeS: 2, + reachableTimeMS: 3, + retransTimerMS: 4, + }, + { + hopLimit: 64, + managedFlag: true, + otherConfFlag: false, + prf: LowRoutePreference, + routerLifetimeS: 258, + reachableTimeMS: 78492, + retransTimerMS: 13213, + }, } - ra := NDPRouterAdvert(b) + for i, test := range tests { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + flags := uint8(0) + if test.managedFlag { + flags |= 1 << 7 + } + if test.otherConfFlag { + flags |= 1 << 6 + } + flags |= uint8(test.prf) << 3 - if got := ra.CurrHopLimit(); got != 64 { - t.Errorf("got ra.CurrHopLimit = %d, want = 64", got) - } + b := []byte{ + test.hopLimit, flags, 1, 2, + 3, 4, 5, 6, + 7, 8, 9, 10, + } + binary.BigEndian.PutUint16(b[2:], test.routerLifetimeS) + binary.BigEndian.PutUint32(b[4:], test.reachableTimeMS) + binary.BigEndian.PutUint32(b[8:], test.retransTimerMS) - if got := ra.ManagedAddrConfFlag(); !got { - t.Errorf("got ManagedAddrConfFlag = false, want = true") - } + ra := NDPRouterAdvert(b) - if got := ra.OtherConfFlag(); got { - t.Errorf("got OtherConfFlag = true, want = false") - } + if got := ra.CurrHopLimit(); got != test.hopLimit { + t.Errorf("got ra.CurrHopLimit() = %d, want = %d", got, test.hopLimit) + } - if got, want := ra.RouterLifetime(), time.Second*258; got != want { - t.Errorf("got ra.RouterLifetime = %d, want = %d", got, want) - } + if got := ra.ManagedAddrConfFlag(); got != test.managedFlag { + t.Errorf("got ManagedAddrConfFlag() = %t, want = %t", got, test.managedFlag) + } - if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want { - t.Errorf("got ra.ReachableTime = %d, want = %d", got, want) - } + if got := ra.OtherConfFlag(); got != test.otherConfFlag { + t.Errorf("got OtherConfFlag() = %t, want = %t", got, test.otherConfFlag) + } + + if got := ra.DefaultRouterPreference(); got != test.prf { + t.Errorf("got DefaultRouterPreference() = %d, want = %d", got, test.prf) + } - if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want { - t.Errorf("got ra.RetransTimer = %d, want = %d", got, want) + if got, want := ra.RouterLifetime(), time.Second*time.Duration(test.routerLifetimeS); got != want { + t.Errorf("got ra.RouterLifetime() = %d, want = %d", got, want) + } + + if got, want := ra.ReachableTime(), time.Millisecond*time.Duration(test.reachableTimeMS); got != want { + t.Errorf("got ra.ReachableTime() = %d, want = %d", got, want) + } + + if got, want := ra.RetransTimer(), time.Millisecond*time.Duration(test.retransTimerMS); got != want { + t.Errorf("got ra.RetransTimer() = %d, want = %d", got, want) + } + }) } } diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index bddb1d0a2..735c28da1 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -41,7 +41,6 @@ package fdbased import ( "fmt" - "math" "sync/atomic" "golang.org/x/sys/unix" @@ -196,8 +195,12 @@ type Options struct { // option for an FD with a fanoutID already in use by another FD for a different // NIC will return an EINVAL. // +// Since fanoutID must be unique within the network namespace, we start with +// the PID to avoid collisions. The only way to be sure of avoiding collisions +// is to run in a new network namespace. +// // Must be accessed using atomic operations. -var fanoutID int32 = 0 +var fanoutID int32 = int32(unix.Getpid()) // New creates a new fd-based endpoint. // @@ -292,11 +295,6 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool, fID int32) (lin } switch sa.(type) { case *unix.SockaddrLinklayer: - // See: PACKET_FANOUT_MAX in net/packet/internal.h - const packetFanoutMax = 1 << 16 - if fID > packetFanoutMax { - return nil, fmt.Errorf("host fanoutID limit exceeded, fanoutID must be <= %d", math.MaxUint16) - } // Enable PACKET_FANOUT mode if the underlying socket is of type // AF_PACKET. We do not enable PACKET_FANOUT_FLAG_DEFRAG as that will // prevent gvisor from receiving fragmented packets and the host does the @@ -317,7 +315,7 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool, fID int32) (lin // // See: https://github.com/torvalds/linux/blob/7acac4b3196caee5e21fb5ea53f8bc124e6a16fc/net/packet/af_packet.c#L3881 const fanoutType = unix.PACKET_FANOUT_HASH - fanoutArg := int(fID) | fanoutType<<16 + fanoutArg := (int(fID) & 0xffff) | fanoutType<<16 if err := unix.SetsockoptInt(fd, unix.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil { return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err) } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index bd63e0289..771b9173a 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -88,6 +88,7 @@ type testObject struct { dataCalls int controlCalls int + rawCalls int } // checkValues verifies that the transport protocol, data contents, src & dst @@ -148,6 +149,10 @@ func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpi t.controlCalls++ } +func (t *testObject) DeliverRawPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) { + t.rawCalls++ +} + // Attach is only implemented to satisfy the LinkEndpoint interface. func (*testObject) Attach(stack.NetworkDispatcher) {} @@ -717,7 +722,10 @@ func TestReceive(t *testing.T) { } test.handlePacket(t, ep, &nic) if nic.testObject.dataCalls != 1 { - t.Errorf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) + t.Errorf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls) + } + if nic.testObject.rawCalls != 1 { + t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls) } if got := stat.Value(); got != 1 { t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got) @@ -968,7 +976,10 @@ func TestIPv4FragmentationReceive(t *testing.T) { ep.HandlePacket(pkt) if nic.testObject.dataCalls != 0 { - t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls) + t.Fatalf("Bad number of data calls: got %d, want 0", nic.testObject.dataCalls) + } + if nic.testObject.rawCalls != 0 { + t.Errorf("Bad number of raw calls: got %d, want 0", nic.testObject.rawCalls) } // Send second segment. @@ -977,7 +988,10 @@ func TestIPv4FragmentationReceive(t *testing.T) { }) ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) + t.Fatalf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls) + } + if nic.testObject.rawCalls != 1 { + t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls) } } @@ -1310,7 +1324,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv4Addr, }) return hdr.View().ToVectorisedView() }, @@ -1351,7 +1365,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv4Addr, }) ip.SetHeaderLength(header.IPv4MinimumSize - 1) return hdr.View().ToVectorisedView() @@ -1370,7 +1384,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv4Addr, }) return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, @@ -1388,7 +1402,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv4Addr, }) return buffer.View(ip).ToVectorisedView() }, @@ -1430,7 +1444,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv4Addr, Options: ipv4Options, }) return hdr.View().ToVectorisedView() @@ -1469,7 +1483,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv4Addr, Options: ipv4Options, }) vv := buffer.View(ip).ToVectorisedView() @@ -1515,7 +1529,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { TransportProtocol: transportProto, HopLimit: ipv6.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv6Addr, }) return hdr.View().ToVectorisedView() }, @@ -1560,7 +1574,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier), HopLimit: ipv6.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv6Addr, }) return hdr.View().ToVectorisedView() }, @@ -1595,7 +1609,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { TransportProtocol: transportProto, HopLimit: ipv6.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv6Addr, }) return buffer.View(ip).ToVectorisedView() }, @@ -1630,7 +1644,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { TransportProtocol: transportProto, HopLimit: ipv6.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv4Addr, }) return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 5f6b0c6af..2aa38eb98 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -173,9 +173,8 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { received := e.stats.icmp.packetsReceived - // TODO(gvisor.dev/issue/170): ICMP packets don't have their - // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a - // full explanation. + // ICMP packets don't have their TransportHeader fields set. See + // icmp/protocol.go:protocol.Parse for a full explanation. v, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize) if !ok { received.invalid.Increment() diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 6bee55634..44c85bdb8 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -429,9 +429,9 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // based on destination address and do not send the packet to link // layer. // - // TODO(gvisor.dev/issue/170): We should do this for every - // packet, rather than only NATted packets, but removing this check - // short circuits broadcasts before they are sent out to other hosts. + // We should do this for every packet, rather than only NATted packets, but + // removing this check short circuits broadcasts before they are sent out to + // other hosts. if pkt.NatDone { netHeader := header.IPv4(pkt.NetworkHeader().View()) if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil { @@ -614,10 +614,6 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu ipH.SetSourceAddress(r.LocalAddress()) } - // Set the destination. If the packet already included a destination, it will - // be part of the route anyways. - ipH.SetDestinationAddress(r.RemoteAddress()) - // Set the packet ID when zero. if ipH.ID() == 0 { // RFC 6864 section 4.3 mandates uniqueness of ID values for @@ -720,7 +716,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { return nil } - ep.handleValidatedPacket(h, pkt) + // The packet originally arrived on e so provide its NIC as the input NIC. + ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) return nil } @@ -836,7 +833,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } // handleLocalPacket is like HandlePacket except it does not perform the @@ -855,10 +852,17 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum return } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } -func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) { +func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, inNICName string) { + // Raw socket packets are delivered based solely on the transport protocol + // number. We only require that the packet be valid IPv4, and that they not + // be fragmented. + if !h.More() && h.FragmentOffset() == 0 { + e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt) + } + pkt.NICID = e.nic.ID() stats := e.stats stats.ip.ValidPacketsReceived.Increment() @@ -920,8 +924,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.ip.IPTablesInputDropped.Increment() return @@ -995,6 +998,9 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) // to do it here. h.SetTotalLength(uint16(pkt.Data().Size() + len(h))) h.SetFlagsFragmentOffset(0, 0) + + // Now that the packet is reassembled, it can be sent to raw sockets. + e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt) } stats.ip.PacketsDelivered.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 23fc94303..94caaae6c 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -285,8 +285,8 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, routerAlert *header.IPv6RouterAlertOption) { sent := e.stats.icmp.packetsSent received := e.stats.icmp.packetsReceived - // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader - // fields set. See icmp/protocol.go:protocol.Parse for a full explanation. + // ICMP packets don't have their TransportHeader fields set. See + // icmp/protocol.go:protocol.Parse for a full explanation. v, ok := pkt.Data().PullUp(header.ICMPv6HeaderSize) if !ok { received.invalid.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index c2e9544c1..7c2a3e56b 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -90,6 +90,10 @@ func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *st return stack.TransportPacketHandled } +func (*stubDispatcher) DeliverRawPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) { + // No-op. +} + var _ stack.NetworkInterface = (*testInterface)(nil) type testInterface struct { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 6103574f7..f5693defe 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -755,9 +755,9 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // based on destination address and do not send the packet to link // layer. // - // TODO(gvisor.dev/issue/170): We should do this for every - // packet, rather than only NATted packets, but removing this check - // short circuits broadcasts before they are sent out to other hosts. + // We should do this for every packet, rather than only NATted packets, but + // removing this check short circuits broadcasts before they are sent out to + // other hosts. if pkt.NatDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil { @@ -928,10 +928,6 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu ipH.SetSourceAddress(r.LocalAddress()) } - // Set the destination. If the packet already included a destination, it will - // be part of the route anyways. - ipH.SetDestinationAddress(r.RemoteAddress()) - // Populate the packet buffer's network header and don't allow an invalid // packet to be sent. // @@ -991,7 +987,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { return nil } - ep.handleValidatedPacket(h, pkt) + // The packet originally arrived on e so provide its NIC as the input NIC. + ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) return nil } @@ -1104,7 +1101,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } // handleLocalPacket is like HandlePacket except it does not perform the @@ -1123,10 +1120,14 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum return } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } -func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) { +func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, inNICName string) { + // Raw socket packets are delivered based solely on the transport protocol + // number. We only require that the packet be valid IPv6. + e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt) + pkt.NICID = e.nic.ID() stats := e.stats.ip stats.ValidPacketsReceived.Increment() @@ -1175,8 +1176,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.IPTablesInputDropped.Increment() return @@ -1627,8 +1627,8 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { // TODO(b/169350103): add checks here after making sure we no longer receive // an empty address. - e.mu.RLock() - defer e.mu.RUnlock() + e.mu.Lock() + defer e.mu.Unlock() return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated) } @@ -1669,8 +1669,8 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre // RemovePermanentAddress implements stack.AddressableEndpoint. func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { - e.mu.RLock() - defer e.mu.RUnlock() + e.mu.Lock() + defer e.mu.Unlock() addressEndpoint := e.getAddressRLocked(addr) if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() { diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 851cd6e75..c44e4ac4e 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -127,25 +127,17 @@ const ( // maxSLAACAddrLocalRegenAttempts is the maximum number of times to attempt // SLAAC address regenerations in response to an IPv6 endpoint-local conflict. maxSLAACAddrLocalRegenAttempts = 10 -) -var ( // MinPrefixInformationValidLifetimeForUpdate is the minimum Valid // Lifetime to update the valid lifetime of a generated address by // SLAAC. // - // This is exported as a variable (instead of a constant) so tests - // can update it to a smaller value. - // // Min = 2hrs. MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour // MaxDesyncFactor is the upper bound for the preferred lifetime's desync // factor for temporary SLAAC addresses. // - // This is exported as a variable (instead of a constant) so tests - // can update it to a smaller value. - // // Must be greater than 0. // // Max = 10m (from RFC 4941 section 5). @@ -154,9 +146,6 @@ var ( // MinMaxTempAddrPreferredLifetime is the minimum value allowed for the // maximum preferred lifetime for temporary SLAAC addresses. // - // This is exported as a variable (instead of a constant) so tests - // can update it to a smaller value. - // // This value guarantees that a temporary address is preferred for at // least 1hr if the SLAAC prefix is valid for at least that time. MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour @@ -164,9 +153,6 @@ var ( // MinMaxTempAddrValidLifetime is the minimum value allowed for the // maximum valid lifetime for temporary SLAAC addresses. // - // This is exported as a variable (instead of a constant) so tests - // can update it to a smaller value. - // // This value guarantees that a temporary address is valid for at least // 2hrs if the SLAAC prefix is valid for at least that time. MinMaxTempAddrValidLifetime = 2 * time.Hour @@ -214,28 +200,23 @@ type NDPDispatcher interface { // is also not permitted to call into the stack. OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) - // OnDefaultRouterDiscovered is called when a new default router is - // discovered. Implementations must return true if the newly discovered - // router should be remembered. + // OnOffLinkRouteUpdated is called when an off-link route is updated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool + OnOffLinkRouteUpdated(tcpip.NICID, tcpip.Subnet, tcpip.Address, header.NDPRoutePreference) - // OnDefaultRouterInvalidated is called when a discovered default router that - // was remembered is invalidated. + // OnOffLinkRouteInvalidated is called when an off-link route is invalidated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) + OnOffLinkRouteInvalidated(tcpip.NICID, tcpip.Subnet, tcpip.Address) // OnOnLinkPrefixDiscovered is called when a new on-link prefix is discovered. - // Implementations must return true if the newly discovered on-link prefix - // should be remembered. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool + OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) // OnOnLinkPrefixInvalidated is called when a discovered on-link prefix that // was remembered is invalidated. @@ -245,13 +226,11 @@ type NDPDispatcher interface { OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) // OnAutoGenAddress is called when a new prefix with its autonomous address- - // configuration flag set is received and SLAAC was performed. Implementations - // may prevent the stack from assigning the address to the NIC by returning - // false. + // configuration flag set is received and SLAAC was performed. // // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. - OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool + OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) // OnAutoGenAddressDeprecated is called when an auto-generated address (SLAAC) // is deprecated, but is still considered valid. Note, if an address is @@ -515,6 +494,8 @@ type ndpState struct { // defaultRouterState holds data associated with a default router discovered by // a Router Advertisement (RA). type defaultRouterState struct { + prf header.NDPRoutePreference + // Job to invalidate the default router. // // Must not be nil. @@ -571,11 +552,11 @@ type slaacPrefixState struct { // Must not be nil. invalidationJob *tcpip.Job - // Nonzero only when the address is not valid forever. - validUntil tcpip.MonotonicTime + // nil iff the address is valid forever. + validUntil *tcpip.MonotonicTime - // Nonzero only when the address is not preferred forever. - preferredUntil tcpip.MonotonicTime + // nil iff the address is preferred forever. + preferredUntil *tcpip.MonotonicTime // State associated with the stable address generated for the prefix. stableAddr struct { @@ -733,6 +714,19 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // Is the IPv6 endpoint configured to discover default routers? if ndp.configs.DiscoverDefaultRouters { + prf := ra.DefaultRouterPreference() + if prf == header.ReservedRoutePreference { + // As per RFC 4191 section 2.2, + // + // Prf (Default Router Preference) + // + // If the Reserved (10) value is received, the receiver MUST treat the + // value as if it were (00). + // + // Note that the value 00 is the medium (default) router preference value. + prf = header.MediumRoutePreference + } + rtr, ok := ndp.defaultRouters[ip] rl := ra.RouterLifetime() switch { @@ -742,7 +736,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // Only remember it if we currently know about less than // MaxDiscoveredDefaultRouters routers. if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters { - ndp.rememberDefaultRouter(ip, rl) + ndp.rememberDefaultRouter(ip, rl, prf) } case ok && rl != 0: @@ -750,6 +744,14 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // the invalidation job. rtr.invalidationJob.Cancel() rtr.invalidationJob.Schedule(rl) + + if prf != rtr.prf { + rtr.prf = prf + + // Inform the integrator about router preference updates. + ndp.ep.protocol.options.NDPDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), header.IPv6EmptySubnet, ip, prf) + } + ndp.defaultRouters[ip] = rtr case ok && rl == 0: @@ -831,7 +833,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { // Let the integrator know a discovered default router is invalidated. if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { - ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip) + ndpDisp.OnOffLinkRouteInvalidated(ndp.ep.nic.ID(), header.IPv6EmptySubnet, ip) } } @@ -841,20 +843,17 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { // The router identified by ip MUST NOT already be known by the IPv6 endpoint. // // The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { +func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration, prf header.NDPRoutePreference) { ndpDisp := ndp.ep.protocol.options.NDPDisp if ndpDisp == nil { return } // Inform the integrator when we discovered a default router. - if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.nic.ID(), ip) { - // Informed by the integrator to not remember the router, do - // nothing further. - return - } + ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), header.IPv6EmptySubnet, ip, prf) state := defaultRouterState{ + prf: prf, invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { ndp.invalidateDefaultRouter(ip) }), @@ -878,11 +877,7 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) } // Inform the integrator when we discovered an on-link prefix. - if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) { - // Informed by the integrator to not remember the prefix, do - // nothing further. - return - } + ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) state := onLinkPrefixState{ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { @@ -1055,7 +1050,8 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { // The time an address is preferred until is needed to properly generate the // address. if pl < header.NDPInfiniteLifetime { - state.preferredUntil = now.Add(pl) + t := now.Add(pl) + state.preferredUntil = &t } if !ndp.generateSLAACAddr(prefix, &state) { @@ -1073,7 +1069,8 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { if vl < header.NDPInfiniteLifetime { state.invalidationJob.Schedule(vl) - state.validUntil = now.Add(vl) + t := now.Add(vl) + state.validUntil = &t } // If the address is assigned (DAD resolved), generate a temporary address. @@ -1096,16 +1093,13 @@ func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, config return nil } - if !ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) { - // Informed by the integrator not to add the address. - return nil - } - addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated) if err != nil { panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err)) } + ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) + return addressEndpoint } @@ -1181,7 +1175,8 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt state.stableAddr.localGenerationFailures++ } - if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, ndp.ep.protocol.stack.Clock().NowMonotonic().Sub(state.preferredUntil) >= 0 /* deprecated */); addressEndpoint != nil { + deprecated := state.preferredUntil != nil && !state.preferredUntil.After(ndp.ep.protocol.stack.Clock().NowMonotonic()) + if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, deprecated); addressEndpoint != nil { state.stableAddr.addressEndpoint = addressEndpoint state.generationAttempts++ return true @@ -1242,7 +1237,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla // address is the lower of the valid lifetime of the stable address or the // maximum temporary address valid lifetime. vl := ndp.configs.MaxTempAddrValidLifetime - if prefixState.validUntil != (tcpip.MonotonicTime{}) { + if prefixState.validUntil != nil { if prefixVL := prefixState.validUntil.Sub(now); vl > prefixVL { vl = prefixVL } @@ -1258,7 +1253,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla // maximum temporary address preferred lifetime - the temporary address desync // factor. pl := ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor - if prefixState.preferredUntil != (tcpip.MonotonicTime{}) { + if prefixState.preferredUntil != nil { if prefixPL := prefixState.preferredUntil.Sub(now); pl > prefixPL { // Respect the preferred lifetime of the prefix, as per RFC 4941 section // 3.3 step 4. @@ -1400,9 +1395,10 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat if !deprecated { prefixState.deprecationJob.Schedule(pl) } - prefixState.preferredUntil = now.Add(pl) + t := now.Add(pl) + prefixState.preferredUntil = &t } else { - prefixState.preferredUntil = tcpip.MonotonicTime{} + prefixState.preferredUntil = nil } // As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix: @@ -1420,14 +1416,14 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // Handle the infinite valid lifetime separately as we do not schedule a // job in this case. prefixState.invalidationJob.Cancel() - prefixState.validUntil = tcpip.MonotonicTime{} + prefixState.validUntil = nil } else { var effectiveVl time.Duration var rl time.Duration // If the prefix was originally set to be valid forever, assume the // remaining time to be the maximum possible value. - if prefixState.validUntil == (tcpip.MonotonicTime{}) { + if prefixState.validUntil == nil { rl = header.NDPInfiniteLifetime } else { rl = prefixState.validUntil.Sub(now) @@ -1442,7 +1438,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat if effectiveVl != 0 { prefixState.invalidationJob.Cancel() prefixState.invalidationJob.Schedule(effectiveVl) - prefixState.validUntil = now.Add(effectiveVl) + t := now.Add(effectiveVl) + prefixState.validUntil = &t } } @@ -1462,8 +1459,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // maximum temporary address valid lifetime. Note, the valid lifetime of a // temporary address is relative to the address's creation time. validUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrValidLifetime) - if prefixState.validUntil != (tcpip.MonotonicTime{}) && validUntil.Sub(prefixState.validUntil) > 0 { - validUntil = prefixState.validUntil + if prefixState.validUntil != nil && prefixState.validUntil.Before(validUntil) { + validUntil = *prefixState.validUntil } // If the address is no longer valid, invalidate it immediately. Otherwise, @@ -1482,14 +1479,15 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // desync factor. Note, the preferred lifetime of a temporary address is // relative to the address's creation time. preferredUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor) - if prefixState.preferredUntil != (tcpip.MonotonicTime{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 { - preferredUntil = prefixState.preferredUntil + if prefixState.preferredUntil != nil && prefixState.preferredUntil.Before(preferredUntil) { + preferredUntil = *prefixState.preferredUntil } // If the address is no longer preferred, deprecate it immediately. // Otherwise, schedule the deprecation job again. newPreferredLifetime := preferredUntil.Sub(now) tempAddrState.deprecationJob.Cancel() + if newPreferredLifetime <= 0 { ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint) } else { @@ -1859,9 +1857,7 @@ func (ndp *ndpState) init(ep *endpoint, dadOptions ip.DADOptions) { ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState) header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID()) - if MaxDesyncFactor != 0 { - ndp.temporaryAddressDesyncFactor = time.Duration(ep.protocol.stack.Rand().Int63n(int64(MaxDesyncFactor))) - } + ndp.temporaryAddressDesyncFactor = time.Duration(ep.protocol.stack.Rand().Int63n(int64(MaxDesyncFactor))) } func (ndp *ndpState) SendDADMessage(addr tcpip.Address, nonce []byte) tcpip.Error { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 3438deb79..130438f3b 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -42,24 +42,21 @@ type testNDPDispatcher struct { func (*testNDPDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) { } -func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool { +func (t *testNDPDispatcher) OnOffLinkRouteUpdated(_ tcpip.NICID, _ tcpip.Subnet, addr tcpip.Address, _ header.NDPRoutePreference) { t.addr = addr - return true } -func (t *testNDPDispatcher) OnDefaultRouterInvalidated(_ tcpip.NICID, addr tcpip.Address) { +func (t *testNDPDispatcher) OnOffLinkRouteInvalidated(_ tcpip.NICID, _ tcpip.Subnet, addr tcpip.Address) { t.addr = addr } -func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool { - return false +func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) { } func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) { } -func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool { - return false +func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) { } func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) { @@ -96,7 +93,7 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) { ipv6EP := ep.(*endpoint) ipv6EP.mu.Lock() - ipv6EP.mu.ndp.rememberDefaultRouter(lladdr1, time.Hour) + ipv6EP.mu.ndp.rememberDefaultRouter(lladdr1, time.Hour, header.MediumRoutePreference) ipv6EP.mu.Unlock() if ndpDisp.addr != lladdr1 { diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 395ff9a07..e0847e58a 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -95,7 +95,7 @@ go_library( go_test( name = "stack_x_test", - size = "medium", + size = "small", srcs = [ "addressable_endpoint_state_test.go", "ndp_test.go", diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index f7fbcbaa7..18e0d4374 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -35,7 +35,6 @@ import ( // Currently, only TCP tracking is supported. // Our hash table has 16K buckets. -// TODO(gvisor.dev/issue/170): These should be tunable. const numBuckets = 1 << 14 // Direction of the tuple. @@ -165,8 +164,6 @@ func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) { // Update the state of tcb. tcb assumes it's always initialized on the // client. However, we only need to know whether the connection is // established or not, so the client/server distinction isn't important. - // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle - // other tcp states. if cn.tcb.IsEmpty() { cn.tcb.Init(tcpHeader) } else if hook == cn.tcbHook { @@ -246,8 +243,7 @@ func (ct *ConnTrack) init() { // connFor gets the conn for pkt if it exists, or returns nil // if it does not. It returns an error when pkt does not contain a valid TCP // header. -// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support -// other transport protocols. +// TODO(gvisor.dev/issue/6168): Support UDP. func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { tid, err := packetToTupleID(pkt) if err != nil { @@ -385,7 +381,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { return false } - // TODO(gvisor.dev/issue/170): Support other transport protocols. + // TODO(gvisor.dev/issue/6168): Support UDP. if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { return false } @@ -466,8 +462,6 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { } // Update the state of tcb. - // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle - // other tcp states. conn.mu.Lock() defer conn.mu.Unlock() @@ -544,8 +538,6 @@ func (ct *ConnTrack) bucket(id tupleID) int { // reapUnused returns the next bucket that should be checked and the time after // which it should be called again. func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) { - // TODO(gvisor.dev/issue/170): This can be more finely controlled, as - // it is in Linux via sysctl. const fractionPerReaping = 128 const maxExpiredPct = 50 const maxFullTraversal = 60 * time.Second diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 0a26f6dd8..f152c0d83 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -268,10 +268,6 @@ const ( // should continue traversing the network stack and false when it should be // dropped. // -// TODO(gvisor.dev/issue/170): PacketBuffer should hold the route, from -// which address can be gathered. Currently, address is only needed for -// prerouting. -// // Precondition: pkt.NetworkHeader is set. func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 2812c89aa..91e266de8 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -87,9 +87,6 @@ func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Addre // destination port/IP. Outgoing packets are redirected to the loopback device, // and incoming packets are redirected to the incoming interface (rather than // forwarded). -// -// TODO(gvisor.dev/issue/170): Other flags need to be added after we support -// them. type RedirectTarget struct { // Port indicates port used to redirect. It is immutable. Port uint16 @@ -100,9 +97,6 @@ type RedirectTarget struct { } // Action implements Target.Action. -// TODO(gvisor.dev/issue/170): Parse headers without copying. The current -// implementation only works for Prerouting and calls pkt.Clone(), neither -// of which should be the case. func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { // Sanity check. if rt.NetworkProtocol != pkt.NetworkProtocolNumber { @@ -136,8 +130,6 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r panic("redirect target is supported only on output and prerouting hooks") } - // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if - // we need to change dest address (for OUTPUT chain) or ports. switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: udpHeader := header.UDP(pkt.TransportHeader().View()) diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 93592e7f5..66e5f22ac 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -242,7 +242,6 @@ type IPHeaderFilter struct { func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool { // Extract header fields. var ( - // TODO(gvisor.dev/issue/170): Support other filter fields. transProto tcpip.TransportProtocolNumber dstAddr tcpip.Address srcAddr tcpip.Address @@ -291,7 +290,6 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicNa return true case Postrouting: - // TODO(gvisor.dev/issue/170): Add the check for POSTROUTING. return true default: panic(fmt.Sprintf("unknown hook: %d", hook)) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 133bacdd0..bd512d312 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -52,17 +52,6 @@ const ( linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") defaultPrefixLen = 128 - - // Extra time to use when waiting for an async event to occur. - defaultAsyncPositiveEventTimeout = 10 * time.Second - - // Extra time to use when waiting for an async event to not occur. - // - // Since a negative check is used to make sure an event did not happen, it is - // okay to use a smaller timeout compared to the positive case since execution - // stall in regards to the monotonic clock will not affect the expected - // outcome. - defaultAsyncNegativeEventTimeout = time.Second ) var ( @@ -112,11 +101,13 @@ type ndpDADEvent struct { res stack.DADResult } -type ndpRouterEvent struct { - nicID tcpip.NICID - addr tcpip.Address - // true if router was discovered, false if invalidated. - discovered bool +type ndpOffLinkRouteEvent struct { + nicID tcpip.NICID + subnet tcpip.Subnet + router tcpip.Address + prf header.NDPRoutePreference + // true if route was updated, false if invalidated. + updated bool } type ndpPrefixEvent struct { @@ -140,6 +131,10 @@ type ndpAutoGenAddrEvent struct { eventType ndpAutoGenAddrEventType } +func (e ndpAutoGenAddrEvent) String() string { + return fmt.Sprintf("%T{nicID=%d addr=%s eventType=%d}", e, e.nicID, e.addr, e.eventType) +} + type ndpRDNSS struct { addrs []tcpip.Address lifetime time.Duration @@ -167,10 +162,8 @@ var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) // related events happen for test purposes. type ndpDispatcher struct { dadC chan ndpDADEvent - routerC chan ndpRouterEvent - rememberRouter bool + offLinkRouteC chan ndpOffLinkRouteEvent prefixC chan ndpPrefixEvent - rememberPrefix bool autoGenAddrC chan ndpAutoGenAddrEvent rdnssC chan ndpRDNSSEvent dnsslC chan ndpDNSSLEvent @@ -189,32 +182,35 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionResult(nicID tcpip.NICID, add } } -// Implements ipv6.NDPDispatcher.OnDefaultRouterDiscovered. -func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool { - if c := n.routerC; c != nil { - c <- ndpRouterEvent{ +// Implements ipv6.NDPDispatcher.OnOffLinkRouteUpdated. +func (n *ndpDispatcher) OnOffLinkRouteUpdated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address, prf header.NDPRoutePreference) { + if c := n.offLinkRouteC; c != nil { + c <- ndpOffLinkRouteEvent{ nicID, - addr, + subnet, + router, + prf, true, } } - - return n.rememberRouter } -// Implements ipv6.NDPDispatcher.OnDefaultRouterInvalidated. -func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) { - if c := n.routerC; c != nil { - c <- ndpRouterEvent{ +// Implements ipv6.NDPDispatcher.OnOffLinkRouteInvalidated. +func (n *ndpDispatcher) OnOffLinkRouteInvalidated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address) { + if c := n.offLinkRouteC; c != nil { + var prf header.NDPRoutePreference + c <- ndpOffLinkRouteEvent{ nicID, - addr, + subnet, + router, + prf, false, } } } // Implements ipv6.NDPDispatcher.OnOnLinkPrefixDiscovered. -func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool { +func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) { if c := n.prefixC; c != nil { c <- ndpPrefixEvent{ nicID, @@ -222,8 +218,6 @@ func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip true, } } - - return n.rememberPrefix } // Implements ipv6.NDPDispatcher.OnOnLinkPrefixInvalidated. @@ -237,7 +231,7 @@ func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpi } } -func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool { +func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { if c := n.autoGenAddrC; c != nil { c <- ndpAutoGenAddrEvent{ nicID, @@ -245,7 +239,6 @@ func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWi newAddr, } } - return true } func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { @@ -1039,9 +1032,12 @@ func TestSetNDPConfigurations(t *testing.T) { } } -// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options -// and DHCPv6 configurations specified. -func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { +// raBuf returns a valid NDP Router Advertisement with options, router +// preference and DHCPv6 configurations specified. +func raBuf(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, prf header.NDPRoutePreference, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { + const flagsByte = 1 + const routerLifetimeOffset = 2 + icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + optSer.Length() hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) pkt := header.ICMPv6(hdr.Prepend(icmpSize)) @@ -1050,19 +1046,19 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo raPayload := pkt.MessageBody() ra := header.NDPRouterAdvert(raPayload) // Populate the Router Lifetime. - binary.BigEndian.PutUint16(raPayload[2:], rl) + binary.BigEndian.PutUint16(raPayload[routerLifetimeOffset:], rl) // Populate the Managed Address flag field. if managedAddress { - // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing) - // of the RA payload. - raPayload[1] |= 1 << 7 + // The Managed Addresses flag field is the 7th bit of the flags byte. + raPayload[flagsByte] |= 1 << 7 } // Populate the Other Configurations flag field. if otherConfigurations { - // The Other Configurations flag field is the 6th bit of byte #1 - // (0-indexing) of the RA payload. - raPayload[1] |= 1 << 6 + // The Other Configurations flag field is the 6th bit of the flags byte. + raPayload[flagsByte] |= 1 << 6 } + // The Prf field is held in the flags byte. + raPayload[flagsByte] |= byte(prf) << 3 opts := ra.Options() opts.Serialize(optSer) pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ @@ -1090,7 +1086,7 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo // Note, raBufWithOpts does not populate any of the RA fields other than the // Router Lifetime. func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { - return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer) + return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, 0 /* prf */, optSer) } // raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related @@ -1098,18 +1094,26 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ // // Note, raBufWithDHCPv6 does not populate any of the RA fields other than the // DHCPv6 related ones. -func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) *stack.PacketBuffer { - return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{}) +func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfigurations bool) *stack.PacketBuffer { + return raBuf(ip, 0, managedAddresses, otherConfigurations, 0 /* prf */, header.NDPOptionsSerializer{}) } // raBuf returns a valid NDP Router Advertisement. // // Note, raBuf does not populate any of the RA fields other than the // Router Lifetime. -func raBuf(ip tcpip.Address, rl uint16) *stack.PacketBuffer { +func raBufSimple(ip tcpip.Address, rl uint16) *stack.PacketBuffer { return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) } +// raBufWithPrf returns a valid NDP Router Advertisement with a preference. +// +// Note, raBufWithPrf does not populate any of the RA fields other than the +// Router Lifetime and Default Router Preference fields. +func raBufWithPrf(ip tcpip.Address, rl uint16, prf header.NDPRoutePreference) *stack.PacketBuffer { + return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, prf, header.NDPOptionsSerializer{}) +} + // raBufWithPI returns a valid NDP Router Advertisement with a single Prefix // Information option. // @@ -1169,7 +1173,7 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { config: func(enable bool) ipv6.NDPConfigurations { return ipv6.NDPConfigurations{DiscoverDefaultRouters: enable} }, - ra: raBuf(llAddr2, 1000), + ra: raBufSimple(llAddr2, 1000), }, { name: "No Prefix Discovery", @@ -1205,9 +1209,9 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) { ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - prefixC: make(chan ndpPrefixEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), + prefixC: make(chan ndpPrefixEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } ndpConfigs := test.config(enable) ndpConfigs.HandleRAs = handle @@ -1277,8 +1281,8 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want) } select { - case e := <-ndpDisp.routerC: - t.Errorf("unexpectedly discovered a router when configured not to: %#v", e) + case e := <-ndpDisp.offLinkRouteC: + t.Errorf("unexpectedly updated an off-link route when configured not to: %#v", e) default: } select { @@ -1304,10 +1308,8 @@ func boolToUint64(v bool) uint64 { return 0 } -// Check e to make sure that the event is for addr on nic with ID 1, and the -// discovered flag set to discovered. -func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string { - return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e)) +func checkOffLinkRouteEvent(e ndpOffLinkRouteEvent, nicID tcpip.NICID, router tcpip.Address, prf header.NDPRoutePreference, updated bool) string { + return cmp.Diff(ndpOffLinkRouteEvent{nicID: nicID, subnet: header.IPv6EmptySubnet, router: router, prf: prf, updated: updated}, e, cmp.AllowUnexported(e)) } func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) { @@ -1340,57 +1342,12 @@ func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, b } } -// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not -// remember a discovered router when the dispatcher asks it not to. -func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA for a router we should not remember. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds)) - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr2, true); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected router discovery event") - } - - // Wait for the invalidation time plus some buffer to make sure we do - // not actually receive any invalidation events as we should not have - // remembered the router in the first place. - clock.Advance(lifetimeSeconds * time.Second) - select { - case <-ndpDisp.routerC: - t.Fatal("should not have received any router events") - default: - } -} - func TestRouterDiscovery(t *testing.T) { + const nicID = 1 + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), } e := channel.New(0, 1280, linkAddr1) clock := faketime.NewManualClock() @@ -1405,27 +1362,28 @@ func TestRouterDiscovery(t *testing.T) { Clock: clock, }) - expectRouterEvent := func(addr tcpip.Address, discovered bool) { + expectOffLinkRouteEvent := func(addr tcpip.Address, prf header.NDPRoutePreference, updated bool) { t.Helper() select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, discovered); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + if diff := checkOffLinkRouteEvent(e, nicID, addr, prf, updated); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected router discovery event") } } - expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { + expectAsyncOffLinkRouteInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { t.Helper() clock.Advance(timeout) select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, false); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + var prf header.NDPRoutePreference + if diff := checkOffLinkRouteEvent(e, nicID, addr, prf, false); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: t.Fatal("timed out waiting for router discovery event") @@ -1436,37 +1394,44 @@ func TestRouterDiscovery(t *testing.T) { t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } // Rx an RA from lladdr2 with zero lifetime. It should not be // remembered. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 0)) select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router with 0 lifetime") + case <-ndpDisp.offLinkRouteC: + t.Fatal("unexpectedly updated an off-link route with 0 lifetime") default: } - // Rx an RA from lladdr2 with a huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + // Rx an RA from lladdr2 with a huge lifetime and reserved preference value + // (which should be interpreted as the default (medium) preference value). + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr2, 1000, header.ReservedRoutePreference)) + expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true) - // Rx an RA from another router (lladdr3) with non-zero lifetime. + // Rx an RA from another router (lladdr3) with non-zero lifetime and + // non-default preference value. const l3LifetimeSeconds = 6 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) - expectRouterEvent(llAddr3, true) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr3, l3LifetimeSeconds, header.HighRoutePreference)) + expectOffLinkRouteEvent(llAddr3, header.HighRoutePreference, true) - // Rx an RA from lladdr2 with lesser lifetime. + // Rx an RA from lladdr2 with lesser lifetime and default (medium) + // preference value. const l2LifetimeSeconds = 2 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, l2LifetimeSeconds)) select { - case <-ndpDisp.routerC: - t.Fatal("Should not receive a router event when updating lifetimes for known routers") + case <-ndpDisp.offLinkRouteC: + t.Fatal("should not receive a off-link route event when updating lifetimes for known routers") default: } + // Rx an RA from lladdr2 with a different preference. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr2, l2LifetimeSeconds, header.LowRoutePreference)) + expectOffLinkRouteEvent(llAddr2, header.LowRoutePreference, true) + // Wait for lladdr2's router invalidation job to execute. The lifetime // of the router should have been updated to the most recent (smaller) // lifetime. @@ -1474,15 +1439,15 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second) + expectAsyncOffLinkRouteInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second) // Rx an RA from lladdr2 with huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 1000)) + expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true) // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - expectRouterEvent(llAddr2, false) + e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 0)) + expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, false) // Wait for lladdr3's router invalidation job to execute. The lifetime // of the router should have been updated to the most recent (smaller) @@ -1491,16 +1456,17 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second) + expectAsyncOffLinkRouteInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second) }) } // TestRouterDiscoveryMaxRouters tests that only // ipv6.MaxDiscoveredDefaultRouters discovered routers are remembered. func TestRouterDiscoveryMaxRouters(t *testing.T) { + const nicID = 1 + ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ @@ -1513,8 +1479,8 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { })}, }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } // Receive an RA from 2 more than the max number of discovered routers. @@ -1523,13 +1489,13 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { linkAddr[5] = byte(i) llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr)) - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr, 5)) if i <= ipv6.MaxDiscoveredDefaultRouters { select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr, true); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + if diff := checkOffLinkRouteEvent(e, nicID, llAddr, header.MediumRoutePreference, true); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected router discovery event") @@ -1537,7 +1503,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } else { select { - case <-ndpDisp.routerC: + case <-ndpDisp.offLinkRouteC: t.Fatal("should not have discovered a new router after we already discovered the max number of routers") default: } @@ -1551,54 +1517,6 @@ func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) st return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e)) } -// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not -// remember a discovered on-link prefix when the dispatcher asks it not to. -func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { - prefix, subnet, _ := prefixSubnetAddr(0, "") - - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA with prefix that we should not remember. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0)) - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet, true); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected prefix discovery event") - } - - // Wait for the invalidation time plus some buffer to make sure we do - // not actually receive any invalidation events as we should not have - // remembered the prefix in the first place. - clock.Advance(lifetimeSeconds * time.Second) - select { - case <-ndpDisp.prefixC: - t.Fatal("should not have received any prefix events") - default: - } -} - func TestPrefixDiscovery(t *testing.T) { prefix1, subnet1, _ := prefixSubnetAddr(0, "") prefix2, subnet2, _ := prefixSubnetAddr(1, "") @@ -1606,8 +1524,7 @@ func TestPrefixDiscovery(t *testing.T) { testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, + prefixC: make(chan ndpPrefixEvent, 1), } e := channel.New(0, 1280, linkAddr1) clock := faketime.NewManualClock() @@ -1697,17 +1614,6 @@ func TestPrefixDiscovery(t *testing.T) { } func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { - // Update the infinite lifetime value to a smaller value so we can test - // that when we receive a PI with such a lifetime value, we do not - // invalidate the prefix. - const testInfiniteLifetimeSeconds = 2 - const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second - saved := header.NDPInfiniteLifetime - header.NDPInfiniteLifetime = testInfiniteLifetime - defer func() { - header.NDPInfiniteLifetime = saved - }() - prefix := tcpip.AddressWithPrefix{ Address: testutil.MustParse6("102:304:506:708::"), PrefixLen: 64, @@ -1715,8 +1621,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { subnet := prefix.Subnet() ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, + prefixC: make(chan ndpPrefixEvent, 1), } e := channel.New(0, 1280, linkAddr1) clock := faketime.NewManualClock() @@ -1750,9 +1655,9 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with prefix in an NDP Prefix Information option (PI) // with infinite valid lifetime which should not get invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds, 0)) expectPrefixEvent(subnet, true) - clock.Advance(testInfiniteLifetime) + clock.Advance(header.NDPInfiniteLifetime) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") @@ -1760,9 +1665,8 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { } // Receive an RA with finite lifetime. - // The prefix should get invalidated after 1s. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) - clock.Advance(testInfiniteLifetime) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds-1, 0)) + clock.Advance(header.NDPInfiniteLifetime - time.Second) select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, subnet, false); diff != "" { @@ -1773,23 +1677,13 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { } // Receive an RA with finite lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds-1, 0)) expectPrefixEvent(subnet, true) // Receive an RA with prefix with an infinite lifetime. // The prefix should not be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) - clock.Advance(testInfiniteLifetime) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - default: - } - - // Receive an RA with a prefix with a lifetime value greater than the - // set infinite lifetime value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0)) - clock.Advance((testInfiniteLifetimeSeconds + 1) * time.Second) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds, 0)) + clock.Advance(header.NDPInfiniteLifetime) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") @@ -1806,8 +1700,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // ipv6.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3), - rememberPrefix: true, + prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3), } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ @@ -1884,17 +1777,12 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e)) } +const minVLSeconds = uint32(ipv6.MinPrefixInformationValidLifetimeForUpdate / time.Second) +const infiniteLifetimeSeconds = uint32(header.NDPInfiniteLifetime / time.Second) + // TestAutoGenAddr tests that an address is properly generated and invalidated // when configured to do so. func TestAutoGenAddr(t *testing.T) { - const newMinVL = 2 - newMinVLDuration := newMinVL * time.Second - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) @@ -1903,6 +1791,7 @@ func TestAutoGenAddr(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1911,6 +1800,7 @@ func TestAutoGenAddr(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { @@ -1960,8 +1850,9 @@ func TestAutoGenAddr(t *testing.T) { default: } - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + // Receive an RA with prefix2 in a PI with a valid lifetime that exceeds + // the minimum. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, minVLSeconds+1, 0)) expectAutoGenAddrEvent(addr2, newAddr) if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { t.Fatalf("Should have %s in the list of addresses", addr1) @@ -1971,7 +1862,7 @@ func TestAutoGenAddr(t *testing.T) { } // Refresh valid lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") @@ -1979,12 +1870,13 @@ func TestAutoGenAddr(t *testing.T) { } // Wait for addr of prefix1 to be invalidated. + clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { @@ -2014,20 +1906,7 @@ func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []t // TestAutoGenTempAddr tests that temporary SLAAC addresses are generated when // configured to do so as part of IPv6 Privacy Extensions. func TestAutoGenTempAddr(t *testing.T) { - const ( - nicID = 1 - newMinVL = 5 - newMinVLDuration = newMinVL * time.Second - ) - - savedMinPrefixInformationValidLifetimeForUpdate := ipv6.MinPrefixInformationValidLifetimeForUpdate - savedMaxDesync := ipv6.MaxDesyncFactor - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate - ipv6.MaxDesyncFactor = savedMaxDesync - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - ipv6.MaxDesyncFactor = time.Nanosecond + const nicID = 1 prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) @@ -2047,218 +1926,211 @@ func TestAutoGenTempAddr(t *testing.T) { }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for i, test := range tests { - i := i - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - seed := []byte{uint8(i)} - var tempIIDHistory [header.IIDSize]byte - header.InitialTempIID(tempIIDHistory[:], seed, nicID) - newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix { - return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr) - } - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 2), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: test.dupAddrTransmits, - RetransmitTimer: test.retransmitTimer, - }, - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - TempIIDSeed: seed, - })}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + for i, test := range tests { + t.Run(test.name, func(t *testing.T) { + seed := []byte{uint8(i)} + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], seed, nicID) + newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix { + return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr) + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 2), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + DADConfigs: stack.DADConfigurations{ + DupAddrDetectTransmits: test.dupAddrTransmits, + RetransmitTimer: test.retransmitTimer, + }, + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + MaxTempAddrValidLifetime: 2 * ipv6.MinPrefixInformationValidLifetimeForUpdate, + MaxTempAddrPreferredLifetime: 2 * ipv6.MinPrefixInformationValidLifetimeForUpdate, + }, + NDPDisp: &ndpDisp, + TempIIDSeed: seed, + })}, + Clock: clock, + }) - expectDADEventAsync := func(addr tcpip.Address) { - t.Helper() + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD event") - } - } + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) select { case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e) + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } default: + t.Fatal("expected addr auto gen event") } + } + + expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - expectDADEventAsync(addr1.Address) + clock.RunImmediatelyScheduledJobs() select { case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto gen addr event = %+v", e) + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } default: + t.Fatal("timed out waiting for addr auto gen event") } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } + } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid & preferred lifetimes. - tempAddr1 := newTempAddr(addr1.Address) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - expectAutoGenAddrEvent(tempAddr1, newAddr) - expectDADEventAsync(tempAddr1.Address) - if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } + expectDADEventAsync := func(addr tcpip.Address) { + t.Helper() - // Receive an RA with prefix2 in an NDP Prefix Information option (PI) - // with preferred lifetime > valid lifetime - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + clock.Advance(time.Duration(test.dupAddrTransmits) * test.retransmitTimer) select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e) + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) + } default: + t.Fatal("timed out waiting for DAD event") } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } + } - // Receive an RA with prefix2 in a PI w/ non-zero valid and preferred - // lifetimes. - tempAddr2 := newTempAddr(addr2.Address) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - expectDADEventAsync(addr2.Address) - expectAutoGenAddrEventAsync(tempAddr2, newAddr) - expectDADEventAsync(tempAddr2.Address) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e) + default: + } - // Deprecate prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + expectDADEventAsync(addr1.Address) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto gen addr event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } - // Refresh lifetimes for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid & preferred lifetimes. + tempAddr1 := newTempAddr(addr1.Address) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + expectAutoGenAddrEvent(tempAddr1, newAddr) + expectDADEventAsync(tempAddr1.Address) + if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } - // Reduce valid lifetime and deprecate addresses of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } + // Receive an RA with prefix2 in an NDP Prefix Information option (PI) + // with preferred lifetime > valid lifetime + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } - // Wait for addrs of prefix1 to be invalidated. They should be - // invalidated at the same time. - select { - case e := <-ndpDisp.autoGenAddrC: - var nextAddr tcpip.AddressWithPrefix - if e.addr == addr1 { - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - nextAddr = tempAddr1 - } else { - if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - nextAddr = addr1 - } + // Receive an RA with prefix2 in a PI with a valid lifetime that exceeds + // the minimum and won't be reached in this test. + tempAddr2 := newTempAddr(addr2.Address) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 2*minVLSeconds, 2*minVLSeconds)) + expectAutoGenAddrEvent(addr2, newAddr) + expectDADEventAsync(addr2.Address) + expectAutoGenAddrEventAsync(tempAddr2, newAddr) + expectDADEventAsync(tempAddr2.Address) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") + // Deprecate prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Refresh lifetimes for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Reduce valid lifetime and deprecate addresses of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Wait for addrs of prefix1 to be invalidated. They should be + // invalidated at the same time. + clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) + select { + case e := <-ndpDisp.autoGenAddrC: + var nextAddr tcpip.AddressWithPrefix + if e.addr == addr1 { + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { - t.Fatal(mismatch) + nextAddr = tempAddr1 + } else { + if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + nextAddr = addr1 } - // Receive an RA with prefix2 in a PI w/ 0 lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0)) - expectAutoGenAddrEvent(addr2, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr2, deprecatedAddr) select { case e := <-ndpDisp.autoGenAddrC: - t.Errorf("got unexpected auto gen addr event = %+v", e) + if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } default: + t.Fatal("timed out waiting for addr auto gen event") } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { - t.Fatal(mismatch) - } - }) - } - }) + default: + t.Fatal("timed out waiting for addr auto gen event") + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { + t.Fatal(mismatch) + } + + // Receive an RA with prefix2 in a PI w/ 0 lifetimes. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0)) + expectAutoGenAddrEvent(addr2, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr2, deprecatedAddr) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("got unexpected auto gen addr event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { + t.Fatal(mismatch) + } + }) + } } // TestNoAutoGenTempAddrForLinkLocal test that temporary SLAAC addresses are not @@ -2266,12 +2138,6 @@ func TestAutoGenTempAddr(t *testing.T) { func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { const nicID = 1 - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - }() - ipv6.MaxDesyncFactor = time.Nanosecond - tests := []struct { name string dupAddrTransmits uint8 @@ -2287,66 +2153,56 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - AutoGenLinkLocal: true, - })}, - }) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + AutoGenLinkLocal: true, + })}, + Clock: clock, + }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - // The stable link-local address should auto-generate and resolve DAD. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") + // The stable link-local address should auto-generate and resolve DAD. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD event") + default: + t.Fatal("expected addr auto gen event") + } + clock.Advance(time.Duration(test.dupAddrTransmits) * test.retransmitTimer) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("timed out waiting for DAD event") + } - // No new addresses should be generated. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Errorf("got unxpected auto gen addr event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): - } - }) - } - }) + // No new addresses should be generated. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("got unxpected auto gen addr event = %+v", e) + default: + } + }) + } } // TestNoAutoGenTempAddrWithoutStableAddr tests that a temporary SLAAC address @@ -2359,12 +2215,6 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { retransmitTimer = 2 * time.Second ) - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - }() - ipv6.MaxDesyncFactor = 0 - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte header.InitialTempIID(tempIIDHistory[:], nil, nicID) @@ -2375,6 +2225,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ @@ -2388,6 +2239,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2417,12 +2269,13 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { // Wait for DAD to complete for the stable address then expect the temporary // address to be generated. + clock.Advance(dadTransmits * retransmitTimer) select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for DAD event") } select { @@ -2430,7 +2283,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } @@ -2439,46 +2292,44 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { // regenerated. func TestAutoGenTempAddrRegen(t *testing.T) { const ( - nicID = 1 - regenAfter = 2 * time.Second - newMinVL = 10 - newMinVLDuration = newMinVL * time.Second - ) + nicID = 1 + regenAdv = 2 * time.Second - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime - }() - ipv6.MaxDesyncFactor = 0 - ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration - ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration + numTempAddrs = 3 + maxTempAddrValidLifetime = numTempAddrs * ipv6.MinPrefixInformationValidLifetimeForUpdate + ) prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte header.InitialTempIID(tempIIDHistory[:], nil, nicID) - tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + var tempAddrs [numTempAddrs]tcpip.AddressWithPrefix + for i := 0; i < len(tempAddrs); i++ { + tempAddrs[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + } ndpDisp := ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - RegenAdvanceDuration: newMinVLDuration - regenAfter, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + RegenAdvanceDuration: regenAdv, + MaxTempAddrValidLifetime: maxTempAddrValidLifetime, + MaxTempAddrPreferredLifetime: ipv6.MinPrefixInformationValidLifetimeForUpdate, + } + clock := faketime.NewManualClock() + randSource := savingRandSource{ + s: rand.NewSource(time.Now().UnixNano()), } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ndpConfigs, NDPDisp: &ndpDisp, })}, + Clock: clock, + RandSource: &randSource, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2501,36 +2352,43 @@ func TestAutoGenTempAddrRegen(t *testing.T) { expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { t.Helper() + clock.Advance(timeout) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } + tempDesyncFactor := time.Duration(randSource.lastInt63) % ipv6.MaxDesyncFactor + effectiveMaxTempAddrPL := ipv6.MinPrefixInformationValidLifetimeForUpdate - tempDesyncFactor + // The time since the last regeneration before a new temporary address is + // generated. + tempAddrRegenenerationTime := effectiveMaxTempAddrPL - regenAdv + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with non-zero valid & preferred lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds)) expectAutoGenAddrEvent(addr, newAddr) - expectAutoGenAddrEvent(tempAddr1, newAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { + expectAutoGenAddrEvent(tempAddrs[0], newAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0]}, nil); mismatch != "" { t.Fatal(mismatch) } // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" { + expectAutoGenAddrEventAsync(tempAddrs[1], newAddr, tempAddrRegenenerationTime) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds)) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0], tempAddrs[1]}, nil); mismatch != "" { t.Fatal(mismatch) } + expectAutoGenAddrEventAsync(tempAddrs[0], deprecatedAddr, regenAdv) // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" { - t.Fatal(mismatch) - } + expectAutoGenAddrEventAsync(tempAddrs[2], newAddr, tempAddrRegenenerationTime-regenAdv) + expectAutoGenAddrEventAsync(tempAddrs[1], deprecatedAddr, regenAdv) // Stop generating temporary addresses ndpConfigs.AutoGenTempGlobalAddresses = false @@ -2541,45 +2399,24 @@ func TestAutoGenTempAddrRegen(t *testing.T) { ndpEP.SetNDPConfigurations(ndpConfigs) } + // Refresh lifetimes and wait for the last temporary address to be deprecated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds)) + expectAutoGenAddrEventAsync(tempAddrs[2], deprecatedAddr, effectiveMaxTempAddrPL-regenAdv) + + // Refresh lifetimes such that the prefix is valid and preferred forever. + // + // This should not affect the lifetimes of temporary addresses because they + // are capped by the maximum valid and preferred lifetimes for temporary + // addresses. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, infiniteLifetimeSeconds, infiniteLifetimeSeconds)) + // Wait for all the temporary addresses to get invalidated. - tempAddrs := []tcpip.AddressWithPrefix{tempAddr1, tempAddr2, tempAddr3} - invalidateAfter := newMinVLDuration - 2*regenAfter + invalidateAfter := maxTempAddrValidLifetime - clock.NowMonotonic().Sub(tcpip.MonotonicTime{}) for _, addr := range tempAddrs { - // Wait for a deprecation then invalidation event, or just an invalidation - // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation jobs could execute in any - // order. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff == "" { - // If we get a deprecation event first, we should get an invalidation - // event almost immediately after. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we shouldn't get a deprecation - // event after. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto-generated event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): - } - } else { - t.Fatalf("got unexpected auto-generated event = %+v", e) - } - case <-time.After(invalidateAfter + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - - invalidateAfter = regenAfter + expectAutoGenAddrEventAsync(addr, invalidatedAddr, invalidateAfter) + invalidateAfter = tempAddrRegenenerationTime } - if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs); mismatch != "" { + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs[:]); mismatch != "" { t.Fatal(mismatch) } } @@ -2588,52 +2425,54 @@ func TestAutoGenTempAddrRegen(t *testing.T) { // regeneration job gets updated when refreshing the address's lifetimes. func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { const ( - nicID = 1 - regenAfter = 2 * time.Second - newMinVL = 10 - newMinVLDuration = newMinVL * time.Second - ) + nicID = 1 + regenAdv = 2 * time.Second - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime - }() - ipv6.MaxDesyncFactor = 0 - ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration - ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration + numTempAddrs = 3 + maxTempAddrPreferredLifetime = ipv6.MinPrefixInformationValidLifetimeForUpdate + maxTempAddrPreferredLifetimeSeconds = uint32(maxTempAddrPreferredLifetime / time.Second) + ) prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte header.InitialTempIID(tempIIDHistory[:], nil, nicID) - tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + var tempAddrs [numTempAddrs]tcpip.AddressWithPrefix + for i := 0; i < len(tempAddrs); i++ { + tempAddrs[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + } ndpDisp := ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - RegenAdvanceDuration: newMinVLDuration - regenAfter, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + RegenAdvanceDuration: regenAdv, + MaxTempAddrPreferredLifetime: maxTempAddrPreferredLifetime, + MaxTempAddrValidLifetime: maxTempAddrPreferredLifetime * 2, + } + clock := faketime.NewManualClock() + initialTime := clock.NowMonotonic() + randSource := savingRandSource{ + s: rand.NewSource(time.Now().UnixNano()), } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ndpConfigs, NDPDisp: &ndpDisp, })}, + Clock: clock, + RandSource: &randSource, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } + tempDesyncFactor := time.Duration(randSource.lastInt63) % ipv6.MaxDesyncFactor + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() @@ -2650,22 +2489,23 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { t.Helper() + clock.Advance(timeout) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with non-zero valid & preferred lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, maxTempAddrPreferredLifetimeSeconds)) expectAutoGenAddrEvent(addr, newAddr) - expectAutoGenAddrEvent(tempAddr1, newAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { + expectAutoGenAddrEvent(tempAddrs[0], newAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0]}, nil); mismatch != "" { t.Fatal(mismatch) } @@ -2673,13 +2513,27 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { // // A new temporary address should be generated after the regeneration // time has passed since the prefix is deprecated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, 0)) expectAutoGenAddrEvent(addr, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddrs[0], deprecatedAddr) select { case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpected auto gen addr event = %+v", e) - case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): + t.Fatalf("unexpected auto gen addr event = %#v", e) + default: + } + + effectiveMaxTempAddrPL := maxTempAddrPreferredLifetime - tempDesyncFactor + // The time since the last regeneration before a new temporary address is + // generated. + tempAddrRegenenerationTime := effectiveMaxTempAddrPL - regenAdv + + // Advance the clock by the regeneration time but don't expect a new temporary + // address as the prefix is deprecated. + clock.Advance(tempAddrRegenenerationTime) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %#v", e) + default: } // Prefer the prefix again. @@ -2687,8 +2541,15 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { // A new temporary address should immediately be generated since the // regeneration time has already passed since the last address was generated // - this regeneration does not depend on a job. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEvent(tempAddr2, newAddr) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, maxTempAddrPreferredLifetimeSeconds)) + expectAutoGenAddrEvent(tempAddrs[1], newAddr) + // Wait for the first temporary address to be deprecated. + expectAutoGenAddrEventAsync(tempAddrs[0], deprecatedAddr, regenAdv) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %s", e) + default: + } // Increase the maximum lifetimes for temporary addresses to large values // then refresh the lifetimes of the prefix. @@ -2699,34 +2560,30 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { // regenerate a new temporary address. Note, new addresses are only // regenerated after the preferred lifetime - the regenerate advance duration // as paased. - ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second - ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second + const largeLifetimeSeconds = minVLSeconds * 2 + const largeLifetime = time.Duration(largeLifetimeSeconds) * time.Second + ndpConfigs.MaxTempAddrValidLifetime = 2 * largeLifetime + ndpConfigs.MaxTempAddrPreferredLifetime = largeLifetime ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber) if err != nil { t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) } ndpEP := ipv6Ep.(ipv6.NDPEndpoint) ndpEP.SetNDPConfigurations(ndpConfigs) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + timeSinceInitialTime := clock.NowMonotonic().Sub(initialTime) + clock.Advance(largeLifetime - timeSinceInitialTime) + expectAutoGenAddrEvent(tempAddrs[0], deprecatedAddr) + // to offset the advement of time to test the first temporary address's + // deprecation after the second was generated + advLess := regenAdv + expectAutoGenAddrEventAsync(tempAddrs[2], newAddr, timeSinceInitialTime-advLess-(tempDesyncFactor+regenAdv)) + expectAutoGenAddrEventAsync(tempAddrs[1], deprecatedAddr, regenAdv) select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpected auto gen addr event = %+v", e) - case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): + default: } - - // Set the maximum lifetimes for temporary addresses such that on the next - // RA, the regeneration job gets scheduled again. - // - // The maximum lifetime is the sum of the minimum lifetimes for temporary - // addresses + the time that has already passed since the last address was - // generated so that the regeneration job is needed to generate the next - // address. - newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout - ndpConfigs.MaxTempAddrValidLifetime = newLifetimes - ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes - ndpEP.SetNDPConfigurations(ndpConfigs) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) } // TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response @@ -2954,13 +2811,14 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { // stack.Stack will have a default route through the router (llAddr3) installed // and a static link-address (linkAddr3) added to the link address cache for the // router. -func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { +func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) { t.Helper() ndpDisp := &ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -2970,6 +2828,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd NDPDisp: ndpDisp, })}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -2983,7 +2842,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd if err := s.AddStaticNeighbor(nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3); err != nil { t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3, err) } - return ndpDisp, e, s + return ndpDisp, e, s, clock } // addrForNewConnectionTo returns the local address used when creating a new @@ -3057,7 +2916,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + ndpDisp, e, s, _ := stackAndNdpDispatcherWithDefaultRoute(t, nicID) expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() @@ -3160,19 +3019,11 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { // when its preferred lifetime expires. func TestAutoGenAddrJobDeprecation(t *testing.T) { const nicID = 1 - const newMinVL = 2 - newMinVLDuration := newMinVL * time.Second - - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + ndpDisp, e, s, clock := stackAndNdpDispatcherWithDefaultRoute(t, nicID) expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() @@ -3190,12 +3041,13 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { t.Helper() + clock.Advance(timeout) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } @@ -3213,7 +3065,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { } // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, infiniteLifetimeSeconds, infiniteLifetimeSeconds)) expectAutoGenAddrEvent(addr2, newAddr) if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) @@ -3232,7 +3084,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Refresh lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, minVLSeconds-1)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") @@ -3241,7 +3093,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, ipv6.MinPrefixInformationValidLifetimeForUpdate-time.Second) if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3251,6 +3103,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { // addr2 should be the primary endpoint now since addr1 is deprecated but // addr2 is not. expectPrimaryAddr(addr2) + // addr1 is deprecated but if explicitly requested, it should be used. fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { @@ -3259,7 +3112,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make // sure we do not get a deprecation event again. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, 0)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") @@ -3271,7 +3124,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { } // Refresh lifetimes for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, minVLSeconds-1)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") @@ -3281,7 +3134,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, ipv6.MinPrefixInformationValidLifetimeForUpdate-time.Second) if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3295,7 +3148,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { } // Wait for addr of prefix1 to be invalidated. - expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) + expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second) if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3305,7 +3158,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr(addr2) // Refresh both lifetimes for addr of prefix2 to the same value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, minVLSeconds, minVLSeconds)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") @@ -3317,6 +3170,17 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { // cases because the deprecation and invalidation handlers could be handled in // either deprecation then invalidation, or invalidation then deprecation // (which should be cancelled by the invalidation handler). + // + // Since we're about to cause both events to fire, we need the dispatcher + // channel to be able to hold both. + if got, want := len(ndpDisp.autoGenAddrC), 0; got != want { + t.Fatalf("got len(ndpDisp.autoGenAddrC) = %d, want %d", got, want) + } + if got, want := cap(ndpDisp.autoGenAddrC), 1; got != want { + t.Fatalf("got cap(ndpDisp.autoGenAddrC) = %d, want %d", got, want) + } + ndpDisp.autoGenAddrC = make(chan ndpAutoGenAddrEvent, 2) + clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { @@ -3327,21 +3191,21 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we should not get a deprecation + // If we get an invalidation event first, we should not get a deprecation // event after. select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") - case <-time.After(defaultAsyncNegativeEventTimeout): + default: } } else { t.Fatalf("got unexpected auto-generated event") } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { @@ -3378,15 +3242,6 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { // infinite values. func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { const infiniteVLSeconds = 2 - const minVLSeconds = 1 - savedIL := header.NDPInfiniteLifetime - savedMinVL := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinVL - header.NDPInfiniteLifetime = savedIL - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second - header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second prefix, _, addr := prefixSubnetAddr(0, linkAddr1) @@ -3410,68 +3265,58 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + Clock: clock, + }) - // Receive an RA with finite prefix. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - default: - t.Fatal("expected addr auto gen event") + // Receive an RA with finite prefix. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - // Receive an new RA with prefix with infinite VL. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0)) + default: + t.Fatal("expected addr auto gen event") + } - // Receive a new RA with prefix with finite VL. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) + // Receive an new RA with prefix with infinite VL. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } + // Receive a new RA with prefix with finite VL. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) - case <-time.After(minVLSeconds*time.Second + defaultAsyncPositiveEventTimeout): - t.Fatal("timeout waiting for addr auto gen event") + clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - }) - } - }) + + default: + t.Fatal("timeout waiting for addr auto gen event") + } + }) + } } // TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an @@ -3479,12 +3324,6 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { // RFC 4862 section 5.5.3.e. func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { const infiniteVL = 4294967295 - const newMinVL = 4 - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second prefix, _, addr := prefixSubnetAddr(0, linkAddr1) @@ -3495,137 +3334,129 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { evl uint32 }{ // Should update the VL to the minimum VL for updating if the - // new VL is less than newMinVL but was originally greater than + // new VL is less than minVLSeconds but was originally greater than // it. { "LargeVLToVLLessThanMinVLForUpdate", 9999, 1, - newMinVL, + minVLSeconds, }, { "LargeVLTo0", 9999, 0, - newMinVL, + minVLSeconds, }, { "InfiniteVLToVLLessThanMinVLForUpdate", infiniteVL, 1, - newMinVL, + minVLSeconds, }, { "InfiniteVLTo0", infiniteVL, 0, - newMinVL, + minVLSeconds, }, - // Should not update VL if original VL was less than newMinVL - // and the new VL is also less than newMinVL. + // Should not update VL if original VL was less than minVLSeconds + // and the new VL is also less than minVLSeconds. { "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate", - newMinVL - 1, - newMinVL - 3, - newMinVL - 1, + minVLSeconds - 1, + minVLSeconds - 3, + minVLSeconds - 1, }, // Should take the new VL if the new VL is greater than the - // remaining time or is greater than newMinVL. + // remaining time or is greater than minVLSeconds. { "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate", - newMinVL + 5, - newMinVL + 3, - newMinVL + 3, + minVLSeconds + 5, + minVLSeconds + 3, + minVLSeconds + 3, }, { "SmallVLToGreaterVLButStillLessThanMinVLForUpdate", - newMinVL - 3, - newMinVL - 1, - newMinVL - 1, + minVLSeconds - 3, + minVLSeconds - 1, + minVLSeconds - 1, }, { "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate", - newMinVL - 3, - newMinVL + 1, - newMinVL + 1, + minVLSeconds - 3, + minVLSeconds + 1, + minVLSeconds + 1, }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + Clock: clock, + }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - // Receive an RA with prefix with initial VL, - // test.ovl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") + // Receive an RA with prefix with initial VL, + // test.ovl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("expected addr auto gen event") + } - // Receive an new RA with prefix with new VL, - // test.nvl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) + // Receive an new RA with prefix with new VL, + // test.nvl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) - // - // Validate that the VL for the address got set - // to test.evl. - // + // + // Validate that the VL for the address got set + // to test.evl. + // - // The address should not be invalidated until the effective valid - // lifetime has passed. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncNegativeEventTimeout): - } + // The address should not be invalidated until the effective valid + // lifetime has passed. + const delta = 1 + clock.Advance(time.Duration(test.evl)*time.Second - delta) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event") + default: + } - // Wait for the invalidation event. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timeout waiting for addr auto gen event") + // Wait for the invalidation event. + clock.Advance(delta) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - }) - } - }) + default: + t.Fatal("timeout waiting for addr auto gen event") + } + }) + } } // TestAutoGenAddrRemoval tests that when auto-generated addresses are removed @@ -3696,7 +3527,7 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) { prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + ndpDisp, e, s, _ := stackAndNdpDispatcherWithDefaultRoute(t, nicID) expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() @@ -3976,13 +3807,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { const maxMaxRetries = 3 const lifetimeSeconds = 10 - // Needed for the temporary address sub test. - savedMaxDesync := ipv6.MaxDesyncFactor - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesync - }() - ipv6.MaxDesyncFactor = time.Nanosecond - secretKey := makeSecretKey(t) prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) @@ -4008,22 +3832,24 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } } - expectAutoGenAddrEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + expectAutoGenAddrEventAsync := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() + clock.RunImmediatelyScheduledJobs() select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } - expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { + expectDADEvent := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { t.Helper() + clock.RunImmediatelyScheduledJobs() select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr, res); diff != "" { @@ -4034,15 +3860,16 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } } - expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { + expectDADEventAsync := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { t.Helper() + clock.Advance(dadTransmits * retransmitTimer) select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr, res); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for DAD event") } } @@ -4053,7 +3880,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { name string ndpConfigs ipv6.NDPConfigurations autoGenLinkLocal bool - prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix + prepareFn func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix }{ { @@ -4062,7 +3889,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, - prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { + prepareFn: func(_ *testing.T, _ *faketime.ManualClock, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { // Receive an RA with prefix1 in a PI. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) return nil @@ -4076,7 +3903,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { name: "LinkLocal address", ndpConfigs: ipv6.NDPConfigurations{}, autoGenLinkLocal: true, - prepareFn: func(*testing.T, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix { + prepareFn: func(*testing.T, *faketime.ManualClock, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix { return nil }, addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix { @@ -4090,14 +3917,14 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, - prepareFn: func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix { + prepareFn: func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix { header.InitialTempIID(tempIIDHistory, nil, nicID) // Generate a stable SLAAC address so temporary addresses will be // generated. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr) - expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{}) + expectDADEventAsync(t, clock, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{}) // The stable address will be assigned throughout the test. return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest} @@ -4109,14 +3936,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } for _, addrType := range addrTypes { - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the parallel - // tests complete and limit the number of parallel tests running at the same - // time to reduce flakes. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. t.Run(addrType.name, func(t *testing.T) { for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ { for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ { @@ -4125,8 +3944,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { addrType := addrType t.Run(fmt.Sprintf("%d max retries and %d failures", maxRetries, numFailures), func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), @@ -4134,6 +3951,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { e := channel.New(0, 1280, linkAddr1) ndpConfigs := addrType.ndpConfigs ndpConfigs.AutoGenAddressConflictRetries = maxRetries + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: addrType.autoGenLinkLocal, @@ -4150,6 +3968,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { SecretKey: secretKey, }, })}, + Clock: clock, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -4157,12 +3976,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } var tempIIDHistory [header.IIDSize]byte - stableAddrs := addrType.prepareFn(t, &ndpDisp, e, tempIIDHistory[:]) + stableAddrs := addrType.prepareFn(t, clock, &ndpDisp, e, tempIIDHistory[:]) // Simulate DAD conflicts so the address is regenerated. for i := uint8(0); i < numFailures; i++ { addr := addrType.addrGenFn(i, tempIIDHistory[:]) - expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) + expectAutoGenAddrEventAsync(t, clock, &ndpDisp, addr, newAddr) // Should not have any new addresses assigned to the NIC. if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" { @@ -4172,7 +3991,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // Simulate a DAD conflict. rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr) - expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{}) + expectDADEvent(t, clock, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{}) // Attempting to add the address manually should not fail if the // address's state was cleaned up when DAD failed. @@ -4182,7 +4001,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { if err := s.RemoveAddress(nicID, addr.Address); err != nil { t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) } - expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADAborted{}) + expectDADEvent(t, clock, &ndpDisp, addr.Address, &stack.DADAborted{}) } // Should not have any new addresses assigned to the NIC. @@ -4194,8 +4013,8 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // an address after DAD resolves. if maxRetries+1 > numFailures { addr := addrType.addrGenFn(numFailures, tempIIDHistory[:]) - expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) - expectDADEventAsync(t, &ndpDisp, addr.Address, &stack.DADSucceeded{}) + expectAutoGenAddrEventAsync(t, clock, &ndpDisp, addr, newAddr) + expectDADEventAsync(t, clock, &ndpDisp, addr.Address, &stack.DADSucceeded{}) if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" { t.Fatal(mismatch) } @@ -4205,7 +4024,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): + default: } }) } @@ -4718,11 +4537,9 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { ) ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), + prefixC: make(chan ndpPrefixEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ @@ -4765,17 +4582,17 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { ), ) select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr3, true /* discovered */); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + if diff := checkOffLinkRouteEvent(e, nicID, llAddr3, header.MediumRoutePreference, true /* discovered */); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID) + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID) } select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID) @@ -4797,8 +4614,8 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } select { - case e := <-ndpDisp.routerC: - t.Errorf("unexpected router event = %#v", e) + case e := <-ndpDisp.offLinkRouteC: + t.Errorf("unexpected off-link route event = %#v", e) default: } select { @@ -4884,11 +4701,9 @@ func TestCleanupNDPState(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents), - rememberRouter: true, - prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents), - rememberPrefix: true, - autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), + offLinkRouteC: make(chan ndpOffLinkRouteEvent, maxRouterAndPrefixEvents), + prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), } clock := faketime.NewManualClock() s := stack.New(stack.Options{ @@ -4905,14 +4720,14 @@ func TestCleanupNDPState(t *testing.T) { Clock: clock, }) - expectRouterEvent := func() (bool, ndpRouterEvent) { + expectOffLinkRouteEvent := func() (bool, ndpOffLinkRouteEvent) { select { - case e := <-ndpDisp.routerC: + case e := <-ndpDisp.offLinkRouteC: return true, e default: } - return false, ndpRouterEvent{} + return false, ndpOffLinkRouteEvent{} } expectPrefixEvent := func() (bool, ndpPrefixEvent) { @@ -4957,8 +4772,8 @@ func TestCleanupNDPState(t *testing.T) { // multiple addresses. e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID1) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) @@ -4968,8 +4783,8 @@ func TestCleanupNDPState(t *testing.T) { } e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID1) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) @@ -4979,8 +4794,8 @@ func TestCleanupNDPState(t *testing.T) { } e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID2) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) @@ -4990,8 +4805,8 @@ func TestCleanupNDPState(t *testing.T) { } e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID2) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) @@ -5032,14 +4847,14 @@ func TestCleanupNDPState(t *testing.T) { test.cleanupFn(t, s) // Collect invalidation events after having NDP state cleaned up. - gotRouterEvents := make(map[ndpRouterEvent]int) + gotOffLinkRouteEvents := make(map[ndpOffLinkRouteEvent]int) for i := 0; i < maxRouterAndPrefixEvents; i++ { - ok, e := expectRouterEvent() + ok, e := expectOffLinkRouteEvent() if !ok { - t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) + t.Errorf("expected %d off-link route events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) break } - gotRouterEvents[e]++ + gotOffLinkRouteEvents[e]++ } gotPrefixEvents := make(map[ndpPrefixEvent]int) for i := 0; i < maxRouterAndPrefixEvents; i++ { @@ -5066,14 +4881,14 @@ func TestCleanupNDPState(t *testing.T) { t.FailNow() } - expectedRouterEvents := map[ndpRouterEvent]int{ - {nicID: nicID1, addr: llAddr3, discovered: false}: 1, - {nicID: nicID1, addr: llAddr4, discovered: false}: 1, - {nicID: nicID2, addr: llAddr3, discovered: false}: 1, - {nicID: nicID2, addr: llAddr4, discovered: false}: 1, + expectedOffLinkRouteEvents := map[ndpOffLinkRouteEvent]int{ + {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1, + {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1, + {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1, + {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1, } - if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { - t.Errorf("router events mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(expectedOffLinkRouteEvents, gotOffLinkRouteEvents); diff != "" { + t.Errorf("off-link route events mismatch (-want +got):\n%s", diff) } expectedPrefixEvents := map[ndpPrefixEvent]int{ {nicID: nicID1, prefix: subnet1, discovered: false}: 1, @@ -5137,8 +4952,8 @@ func TestCleanupNDPState(t *testing.T) { // cancelled when the NDP state was cleaned up). clock.Advance(lifetimeSeconds * time.Second) select { - case <-ndpDisp.routerC: - t.Error("unexpected router event") + case <-ndpDisp.offLinkRouteC: + t.Error("unexpected off-link route event") default: } select { @@ -5163,7 +4978,6 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { ndpDisp := ndpDispatcher{ dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1), - rememberRouter: true, } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 378389db2..b854d868c 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -779,17 +779,11 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt transProto := state.proto - // Raw socket packets are delivered based solely on the transport - // protocol number. We do not inspect the payload to ensure it's - // validly formed. - n.stack.demux.deliverRawPacket(protocol, pkt) - // TransportHeader is empty only when pkt is an ICMP packet or was reassembled // from fragments. if pkt.TransportHeader().View().IsEmpty() { - // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader - // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a - // full explanation. + // ICMP packets don't have their TransportHeader fields set yet, parse it + // here. See icmp/protocol.go:protocol.Parse for a full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { // ICMP packets may be longer, but until icmp.Parse is implemented, here // we parse it using the minimum size. @@ -878,6 +872,17 @@ func (n *nic) DeliverTransportError(local, remote tcpip.Address, net tcpip.Netwo } } +// DeliverRawPacket implements TransportDispatcher. +func (n *nic) DeliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) { + // For ICMPv4 only we validate the header length for compatibility with + // raw(7) ICMP_FILTER. The same check is made in Linux here: + // https://github.com/torvalds/linux/blob/70585216/net/ipv4/raw.c#L189. + if protocol == header.ICMPv4ProtocolNumber && pkt.TransportHeader().View().Size()+pkt.Data().Size() < header.ICMPv4MinimumSize { + return + } + n.stack.demux.deliverRawPacket(protocol, pkt) +} + // ID implements NetworkInterface. func (n *nic) ID() tcpip.NICID { return n.id diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 4ca702121..9192d8433 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -134,7 +134,7 @@ type PacketBuffer struct { // https://www.man7.org/linux/man-pages/man7/packet.7.html. PktType tcpip.PacketType - // NICID is the ID of the interface the network packet was received at. + // NICID is the ID of the last interface the network packet was handled at. NICID tcpip.NICID // RXTransportChecksumValidated indicates that transport checksum verification diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index a038389e0..dfe2c886f 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -265,6 +265,11 @@ type TransportDispatcher interface { // // DeliverTransportError takes ownership of the packet buffer. DeliverTransportError(local, remote tcpip.Address, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, _ TransportError, _ *PacketBuffer) + + // DeliverRawPacket delivers a packet to any subscribed raw sockets. + // + // DeliverRawPacket does NOT take ownership of the packet buffer. + DeliverRawPacket(tcpip.TransportProtocolNumber, *PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 40d277312..81fabe29a 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -108,7 +108,7 @@ type Stack struct { handleLocal bool // tables are the iptables packet filtering and manipulation rules. - // TODO(gvisor.dev/issue/170): S/R this field. + // TODO(gvisor.dev/issue/4595): S/R this field. tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the @@ -1872,9 +1872,8 @@ const ( // ParsePacketBufferTransport parses the provided packet buffer's transport // header. func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult { - // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader - // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a - // full explanation. + // ICMP packets don't have their TransportHeader fields set yet, parse it + // here. See icmp/protocol.go:protocol.Parse for a full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { return ParsedOK } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 8a8454a6a..dda57e225 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -16,6 +16,7 @@ package stack import ( "fmt" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" @@ -215,10 +216,17 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t netProto: netProto, transProto: transProto, } - epsByNIC.endpoints[bindToDevice] = multiPortEp } - return multiPortEp.singleRegisterEndpoint(t, flags) + if err := multiPortEp.singleRegisterEndpoint(t, flags); err != nil { + return err + } + // Only add this newly created multiportEndpoint if the singleRegisterEndpoint + // succeeded. + if !ok { + epsByNIC.endpoints[bindToDevice] = multiPortEp + } + return nil } func (epsByNIC *endpointsByNIC) checkEndpoint(flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { @@ -405,7 +413,6 @@ func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *Packet func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - bits := flags.Bits() & ports.MultiBindFlagMask if len(ep.endpoints) != 0 { @@ -468,17 +475,21 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps.mu.Lock() defer eps.mu.Unlock() - epsByNIC, ok := eps.endpoints[id] if !ok { epsByNIC = &endpointsByNIC{ endpoints: make(map[tcpip.NICID]*multiPortEndpoint), seed: d.stack.Seed(), } + } + if err := epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice); err != nil { + return err + } + // Only add this newly created epsByNIC if registerEndpoint succeeded. + if !ok { eps.endpoints[id] = epsByNIC } - - return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice) + return nil } func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 0972c94de..45b09110d 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -203,6 +203,56 @@ func TestTransportDemuxerRegister(t *testing.T) { } } +func TestTransportDemuxerRegisterMultiple(t *testing.T) { + type test struct { + flags ports.Flags + want tcpip.Error + } + for _, subtest := range []struct { + name string + tests []test + }{ + {"zeroFlags", []test{ + {ports.Flags{}, nil}, + {ports.Flags{}, &tcpip.ErrPortInUse{}}, + }}, + {"multibindFlags", []test{ + // Allow multiple registrations same TransportEndpointID with multibind flags. + {ports.Flags{LoadBalanced: true, MostRecent: true}, nil}, + {ports.Flags{LoadBalanced: true, MostRecent: true}, nil}, + // Disallow registration w/same ID for a non-multibindflag. + {ports.Flags{TupleOnly: true}, &tcpip.ErrPortInUse{}}, + }}, + } { + t.Run(subtest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + var eps []tcpip.Endpoint + for idx, test := range subtest.tests { + var wq waiter.Queue + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatal(err) + } + eps = append(eps, ep) + tEP, ok := ep.(stack.TransportEndpoint) + if !ok { + t.Fatalf("%T does not implement stack.TransportEndpoint", ep) + } + id := stack.TransportEndpointID{LocalPort: 1} + if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber}, udp.ProtocolNumber, id, tEP, test.flags, 0), test.want; got != want { + t.Fatalf("test index: %d, s.RegisterTransportEndpoint(ipv4.ProtocolNumber, udp.ProtocolNumber, _, _, %+v, 0) = %s, want %s", idx, test.flags, got, want) + } + } + for _, ep := range eps { + ep.Close() + } + }) + } +} + // TestBindToDeviceDistribution injects varied packets on input devices and checks that // the distribution of packets received matches expectations. func TestBindToDeviceDistribution(t *testing.T) { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 91622fa4c..8f2658f64 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -465,11 +465,11 @@ type ControlMessages struct { // PacketOwner is used to get UID and GID of the packet. type PacketOwner interface { - // UID returns UID of the packet. - UID() uint32 + // UID returns KUID of the packet. + KUID() uint32 - // GID returns GID of the packet. - GID() uint32 + // GID returns KGID of the packet. + KGID() uint32 } // ReadOptions contains options for Endpoint.Read. diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index 07ba2b837..f9ab7d0af 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -166,7 +166,7 @@ func TestIPTablesStatsForInput(t *testing.T) { // Make sure the packet is not dropped by the next rule. filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -187,7 +187,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -207,7 +207,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -227,7 +227,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -250,7 +250,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -273,7 +273,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -293,7 +293,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -313,7 +313,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -465,7 +465,7 @@ func TestIPTableWritePackets(t *testing.T) { } if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil { - t.Fatalf("RelaceTable(%d, _, false): %s", stack.FilterID, err) + t.Fatalf("ReplaceTable(%d, _, false): %s", stack.FilterID, err) } }, genPacket: func(r *stack.Route) stack.PacketBufferList { @@ -556,7 +556,7 @@ func TestIPTableWritePackets(t *testing.T) { } if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil { - t.Fatalf("RelaceTable(%d, _, true): %s", stack.FilterID, err) + t.Fatalf("ReplaceTable(%d, _, true): %s", stack.FilterID, err) } }, genPacket: func(r *stack.Route) stack.PacketBufferList { @@ -681,6 +681,32 @@ func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Addr checker.ICMPv6Type(header.ICMPv6EchoReply))) } +func boolToInt(v bool) uint64 { + if v { + return 1 + } + return 0 +} + +func setupDropFilter(hook stack.Hook, f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { + return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) { + t.Helper() + + ipv6 := netProto == ipv6.ProtocolNumber + + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, ipv6) + ruleIdx := filter.BuiltinChains[hook] + filter.Rules[ruleIdx].Filter = f + filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto} + // Make sure the packet is not dropped by the next rule. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto} + if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err) + } + } +} + func TestForwardingHook(t *testing.T) { const ( nicID1 = 1 @@ -740,32 +766,6 @@ func TestForwardingHook(t *testing.T) { }, } - setupDropFilter := func(f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { - return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) { - t.Helper() - - ipv6 := netProto == ipv6.ProtocolNumber - - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, ipv6) - ruleIdx := filter.BuiltinChains[stack.Forward] - filter.Rules[ruleIdx].Filter = f - filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto} - // Make sure the packet is not dropped by the next rule. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto} - if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err) - } - } - } - - boolToInt := func(v bool) uint64 { - if v { - return 1 - } - return 0 - } - subTests := []struct { name string setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) @@ -779,59 +779,59 @@ func TestForwardingHook(t *testing.T) { { name: "Drop", - setupFilter: setupDropFilter(stack.IPHeaderFilter{}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{}), expectForward: false, }, { name: "Drop with input NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name}), expectForward: false, }, { name: "Drop with output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name}), expectForward: false, }, { name: "Drop with input and output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}), expectForward: false, }, { name: "Drop with other input NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName}), expectForward: true, }, { name: "Drop with other output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: otherNICName}), expectForward: true, }, { name: "Drop with other input and output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}), expectForward: true, }, { name: "Drop with input and other output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}), expectForward: true, }, { name: "Drop with other input and other output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}), expectForward: true, }, { name: "Drop with inverted input NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}), expectForward: true, }, { name: "Drop with inverted output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}), expectForward: true, }, } @@ -941,3 +941,194 @@ func TestForwardingHook(t *testing.T) { }) } } + +func TestInputHookWithLocalForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + nic1Name = "nic1" + nic2Name = "nic2" + + otherNICName = "otherNIC" + ) + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + rx func(*channel.Endpoint) + checker func(*testing.T, []byte) + }{ + { + name: "IPv4", + netProto: ipv4.ProtocolNumber, + rx: func(e *channel.Endpoint) { + utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address, ttl) + }, + checker: func(t *testing.T, b []byte) { + checker.IPv4(t, b, + checker.SrcAddr(utils.Ipv4Addr2.AddressWithPrefix.Address), + checker.DstAddr(utils.RemoteIPv4Addr), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4EchoReply))) + }, + }, + { + name: "IPv6", + netProto: ipv6.ProtocolNumber, + rx: func(e *channel.Endpoint) { + utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address, ttl) + }, + checker: func(t *testing.T, b []byte) { + checker.IPv6(t, b, + checker.SrcAddr(utils.Ipv6Addr2.AddressWithPrefix.Address), + checker.DstAddr(utils.RemoteIPv6Addr), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoReply))) + }, + }, + } + + subTests := []struct { + name string + setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) + expectDrop bool + }{ + { + name: "Accept", + setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ }, + expectDrop: false, + }, + + { + name: "Drop", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{}), + expectDrop: true, + }, + { + name: "Drop with input NIC filtering on arrival NIC", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic1Name}), + expectDrop: true, + }, + { + name: "Drop with input NIC filtering on delivered NIC", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic2Name}), + expectDrop: false, + }, + + { + name: "Drop with input NIC filtering on other NIC", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: otherNICName}), + expectDrop: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + }) + + subTest.setupFilter(t, s, test.netProto) + + e1 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil { + t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err) + } + if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err) + } + if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err) + } + + e2 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil { + t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) + } + if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err) + } + if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv6Addr2, err) + } + + if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) + } + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID1, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: nicID1, + }, + }) + + test.rx(e1) + + ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err) + } + ep1Stats := ep1.Stats() + ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats) + if !ok { + t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats) + } + ip1Stats := ipEP1Stats.IPStats() + + if got := ip1Stats.PacketsReceived.Value(); got != 1 { + t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got) + } + if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 { + t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got) + } + if got, want := ip1Stats.PacketsSent.Value(), boolToInt(!subTest.expectDrop); got != want { + t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = %d", got, want) + } + + ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err) + } + ep2Stats := ep2.Stats() + ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats) + if !ok { + t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats) + } + ip2Stats := ipEP2Stats.IPStats() + if got := ip2Stats.PacketsReceived.Value(); got != 0 { + t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got) + } + if got := ip2Stats.ValidPacketsReceived.Value(); got != 1 { + t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = 1", got) + } + if got, want := ip2Stats.IPTablesInputDropped.Value(), boolToInt(subTest.expectDrop); got != want { + t.Errorf("got ip2Stats.IPTablesInputDropped.Value() = %d, want = %d", got, want) + } + if got := ip2Stats.PacketsSent.Value(); got != 0 { + t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = 0", got) + } + + if p, ok := e1.Read(); ok == subTest.expectDrop { + t.Errorf("got e1.Read() = (%#v, %t), want = (_, %t)", p, ok, !subTest.expectDrop) + } else if !subTest.expectDrop { + test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) + } + if p, ok := e2.Read(); ok { + t.Errorf("got e1.Read() = (%#v, true), want = (_, false)", p) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 87d36e1dd..b2008f0b2 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -44,20 +44,17 @@ type ndpDispatcher struct{} func (*ndpDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) { } -func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool { - return false +func (*ndpDispatcher) OnOffLinkRouteUpdated(tcpip.NICID, tcpip.Subnet, tcpip.Address, header.NDPRoutePreference) { } -func (*ndpDispatcher) OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) {} +func (*ndpDispatcher) OnOffLinkRouteInvalidated(tcpip.NICID, tcpip.Subnet, tcpip.Address) {} -func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool { - return false +func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) { } func (*ndpDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {} -func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool { - return true +func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) { } func (*ndpDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {} diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index fb77febcf..cb316d27a 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -758,8 +758,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB switch e.NetProto { case header.IPv4ProtocolNumber: h := header.ICMPv4(pkt.TransportHeader().View()) - // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed - // after early parsing. if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -767,8 +765,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } case header.IPv6ProtocolNumber: h := header.ICMPv6(pkt.TransportHeader().View()) - // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed - // after early parsing. if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 47f7dd1cb..fa82affc1 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -123,8 +123,6 @@ func (*protocol) Wait() {} // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - // TODO(gvisor.dev/issue/170): Implement parsing of ICMP. - // // Right now, the Parse() method is tied to enabled protocols passed into // stack.New. This works for UDP and TCP, but we handle ICMP traffic even // when netstack users don't pass ICMP as a supported protocol. diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index cd8c99d41..8e7bb6c6e 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -208,7 +208,6 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul } func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error) { - // TODO(gvisor.dev/issue/173): Implement. return 0, &tcpip.ErrInvalidOptionValue{} } @@ -244,8 +243,6 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi // Bind implements tcpip.Endpoint.Bind. func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { - // TODO(gvisor.dev/issue/173): Add Bind support. - // "By default, all packets of the specified protocol type are passed // to a packet socket. To get packets only from a specific interface // use bind(2) specifying an address in a struct sockaddr_ll to bind @@ -385,7 +382,6 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, // Push new packet into receive list and increment the buffer size. var packet packet - // TODO(gvisor.dev/issue/173): Return network protocol. if !pkt.LinkHeader().View().IsEmpty() { // Get info directly from the ethernet header. hdr := header.Ethernet(pkt.LinkHeader().View()) @@ -424,7 +420,6 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, default: panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt)) } - } else { // Raw packets need their ethernet headers prepended before // queueing. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 1bce2769a..b6687911a 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -286,26 +286,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return nil, nil, nil, &tcpip.ErrBadBuffer{} } - // If this is an unassociated socket and callee provided a nonzero - // destination address, route using that address. - if e.ops.GetHeaderIncluded() { - ip := header.IPv4(payloadBytes) - if !ip.IsValid(len(payloadBytes)) { - return nil, nil, nil, &tcpip.ErrInvalidOptionValue{} - } - dstAddr := ip.DestinationAddress() - // Update dstAddr with the address in the IP header, unless - // opts.To is set (e.g. if sendto specifies a specific - // address). - if dstAddr != tcpip.Address([]byte{0, 0, 0, 0}) && opts.To == nil { - opts.To = &tcpip.FullAddress{ - NIC: 0, // NIC is unset. - Addr: dstAddr, // The address from the payload. - Port: 0, // There are no ports here. - } - } - } - // Did the user caller provide a destination? If not, use the connected // destination. if opts.To == nil { diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 2c65b737d..d807b13b7 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -560,6 +560,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err } switch { + case s.flags.Contains(header.TCPFlagRst): + e.stack.Stats().DroppedPackets.Increment() + return nil + case s.flags == header.TCPFlagSyn: if e.acceptQueueIsFull() { e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() @@ -611,7 +615,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() return nil - case (s.flags & header.TCPFlagAck) != 0: + case s.flags.Contains(header.TCPFlagAck): if e.acceptQueueIsFull() { // Silently drop the ack as the application can't accept // the connection at this point. The ack will be @@ -736,6 +740,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err mss: rcvdSynOptions.MSS, }) + // Requeue the segment if the ACK completing the handshake has more info + // to be procesed by the newly established endpoint. + if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && n.enqueueSegment(s) { + s.incRef() + n.newSegmentWaker.Assert() + } + // Do the delivery in a separate goroutine so // that we don't block the listen loop in case // the application is slow to accept or stops @@ -753,6 +764,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err return nil default: + e.stack.Stats().DroppedPackets.Increment() return nil } } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 570e5081c..f86450298 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -406,11 +406,11 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { h.ep.transitionToStateEstablishedLocked(h) - // If the segment has data then requeue it for the receiver - // to process it again once main loop is started. - if s.data.Size() > 0 { + // Requeue the segment if the ACK completing the handshake has more info + // to be procesed by the newly established endpoint. + if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && h.ep.enqueueSegment(s) { s.incRef() - h.ep.enqueueSegment(s) + h.ep.newSegmentWaker.Assert() } return nil } @@ -1130,7 +1130,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err tcpip.Error) { func (e *endpoint) handleSegmentsLocked(fastPath bool) tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { - if e.EndpointState().closed() { + if state := e.EndpointState(); state.closed() || state == StateTimeWait { return nil } s := e.segmentQueue.dequeue() @@ -1474,11 +1474,19 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ return &tcpip.ErrConnectionReset{} } - if n¬ifyClose != 0 && closeTimer == nil { - if e.EndpointState() == StateFinWait2 && e.closed { + if n¬ifyClose != 0 && e.closed { + switch e.EndpointState() { + case StateEstablished: + // Perform full shutdown if the endpoint is still + // established. This can occur when notifyClose + // was asserted just before becoming established. + e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead) + case StateFinWait2: // The socket has been closed and we are in FIN_WAIT2 // so start the FIN_WAIT2 timer. - closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert) + if closeTimer == nil { + closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert) + } } } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 661ca604a..9ce8fcae9 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -559,7 +559,6 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // (2) returns to TIME-WAIT state if the SYN turns out // to be an old duplicate". if s.flags.Contains(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) { - return false, true } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index e7ede7662..9bbe9bc3e 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -6238,6 +6238,54 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { } } +func TestListenDropIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + c.Create(-1 /*epRcvBuf*/) + + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if err := c.EP.Listen(1 /*backlog*/); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + initialDropped := stats.DroppedPackets.Value() + + // Send RST, FIN segments, that are expected to be dropped by the listener. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagRst, + }) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin, + }) + + // To ensure that the RST, FIN sent earlier are indeed received and ignored + // by the listener, send a SYN and wait for the SYN to be ACKd. + irs := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + }) + checker.IPv4(t, c.GetPacket(), checker.TCP(checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.TCPAckNum(uint32(irs)+1), + )) + + if got, want := stats.DroppedPackets.Value(), initialDropped+2; got != want { + t.Fatalf("got stats.DroppedPackets.Value() = %d, want = %d", got, want) + } +} + func TestEndpointBindListenAcceptState(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() diff --git a/pkg/usermem/BUILD b/pkg/usermem/BUILD index 3dba36f12..d7decd78a 100644 --- a/pkg/usermem/BUILD +++ b/pkg/usermem/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "bytes_io.go", "bytes_io_unsafe.go", + "marshal.go", "usermem.go", ], visibility = ["//:sandbox"], diff --git a/pkg/usermem/marshal.go b/pkg/usermem/marshal.go new file mode 100644 index 000000000..5b5a662dc --- /dev/null +++ b/pkg/usermem/marshal.go @@ -0,0 +1,43 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package usermem + +import ( + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/hostarch" +) + +// IOCopyContext wraps an object implementing hostarch.IO to implement +// marshal.CopyContext. +type IOCopyContext struct { + Ctx context.Context + IO IO + Opts IOOpts +} + +// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer. +func (i *IOCopyContext) CopyScratchBuffer(size int) []byte { + return make([]byte, size) +} + +// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. +func (i *IOCopyContext) CopyOutBytes(addr hostarch.Addr, b []byte) (int, error) { + return i.IO.CopyOut(i.Ctx, addr, b, i.Opts) +} + +// CopyInBytes implements marshal.CopyContext.CopyInBytes. +func (i *IOCopyContext) CopyInBytes(addr hostarch.Addr, b []byte) (int, error) { + return i.IO.CopyIn(i.Ctx, addr, b, i.Opts) +} |