diff options
Diffstat (limited to 'pkg')
35 files changed, 837 insertions, 365 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 064a54547..a461bb65e 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -85,6 +85,8 @@ go_library( go_test( name = "linux_test", size = "small", - srcs = ["netfilter_test.go"], + srcs = [ + "netfilter_test.go", + ], library = ":linux", ) diff --git a/pkg/abi/linux/errors.go b/pkg/abi/linux/errors.go index 93f85a864..b08b2687e 100644 --- a/pkg/abi/linux/errors.go +++ b/pkg/abi/linux/errors.go @@ -15,158 +15,149 @@ package linux // Errno represents a Linux errno value. -type Errno struct { - number int - name string -} - -// Number returns the errno number. -func (e *Errno) Number() int { - return e.number -} - -// String implements fmt.Stringer.String. -func (e *Errno) String() string { - return e.name -} +type Errno int // Errno values from include/uapi/asm-generic/errno-base.h. -var ( - EPERM = &Errno{1, "operation not permitted"} - ENOENT = &Errno{2, "no such file or directory"} - ESRCH = &Errno{3, "no such process"} - EINTR = &Errno{4, "interrupted system call"} - EIO = &Errno{5, "I/O error"} - ENXIO = &Errno{6, "no such device or address"} - E2BIG = &Errno{7, "argument list too long"} - ENOEXEC = &Errno{8, "exec format error"} - EBADF = &Errno{9, "bad file number"} - ECHILD = &Errno{10, "no child processes"} - EAGAIN = &Errno{11, "try again"} - ENOMEM = &Errno{12, "out of memory"} - EACCES = &Errno{13, "permission denied"} - EFAULT = &Errno{14, "bad address"} - ENOTBLK = &Errno{15, "block device required"} - EBUSY = &Errno{16, "device or resource busy"} - EEXIST = &Errno{17, "file exists"} - EXDEV = &Errno{18, "cross-device link"} - ENODEV = &Errno{19, "no such device"} - ENOTDIR = &Errno{20, "not a directory"} - EISDIR = &Errno{21, "is a directory"} - EINVAL = &Errno{22, "invalid argument"} - ENFILE = &Errno{23, "file table overflow"} - EMFILE = &Errno{24, "too many open files"} - ENOTTY = &Errno{25, "not a typewriter"} - ETXTBSY = &Errno{26, "text file busy"} - EFBIG = &Errno{27, "file too large"} - ENOSPC = &Errno{28, "no space left on device"} - ESPIPE = &Errno{29, "illegal seek"} - EROFS = &Errno{30, "read-only file system"} - EMLINK = &Errno{31, "too many links"} - EPIPE = &Errno{32, "broken pipe"} - EDOM = &Errno{33, "math argument out of domain of func"} - ERANGE = &Errno{34, "math result not representable"} +const ( + NOERRNO = iota + EPERM + ENOENT + ESRCH + EINTR + EIO + ENXIO + E2BIG + ENOEXEC + EBADF + ECHILD // 10 + EAGAIN + ENOMEM + EACCES + EFAULT + ENOTBLK + EBUSY + EEXIST + EXDEV + ENODEV + ENOTDIR // 20 + EISDIR + EINVAL + ENFILE + EMFILE + ENOTTY + ETXTBSY + EFBIG + ENOSPC + ESPIPE + EROFS // 30 + EMLINK + EPIPE + EDOM + ERANGE + // Errno values from include/uapi/asm-generic/errno.h. + EDEADLK + ENAMETOOLONG + ENOLCK + ENOSYS + ENOTEMPTY + ELOOP //40 + _ // Skip for EWOULDBLOCK = EAGAIN + ENOMSG //42 + EIDRM + ECHRNG + EL2NSYNC + EL3HLT + EL3RST + ELNRNG + EUNATCH + ENOCSI + EL2HLT // 50 + EBADE + EBADR + EXFULL + ENOANO + EBADRQC + EBADSLT + _ // Skip for EDEADLOCK = EDEADLK + EBFONT + ENOSTR // 60 + ENODATA + ETIME + ENOSR + ENONET + ENOPKG + EREMOTE + ENOLINK + EADV + ESRMNT + ECOMM // 70 + EPROTO + EMULTIHOP + EDOTDOT + EBADMSG + EOVERFLOW + ENOTUNIQ + EBADFD + EREMCHG + ELIBACC + ELIBBAD // 80 + ELIBSCN + ELIBMAX + ELIBEXEC + EILSEQ + ERESTART + ESTRPIPE + EUSERS + ENOTSOCK + EDESTADDRREQ + EMSGSIZE // 90 + EPROTOTYPE + ENOPROTOOPT + EPROTONOSUPPORT + ESOCKTNOSUPPORT + EOPNOTSUPP + EPFNOSUPPORT + EAFNOSUPPORT + EADDRINUSE + EADDRNOTAVAIL + ENETDOWN // 100 + ENETUNREACH + ENETRESET + ECONNABORTED + ECONNRESET + ENOBUFS + EISCONN + ENOTCONN + ESHUTDOWN + ETOOMANYREFS + ETIMEDOUT // 110 + ECONNREFUSED + EHOSTDOWN + EHOSTUNREACH + EALREADY + EINPROGRESS + ESTALE + EUCLEAN + ENOTNAM + ENAVAIL + EISNAM // 120 + EREMOTEIO + EDQUOT + ENOMEDIUM + EMEDIUMTYPE + ECANCELED + ENOKEY + EKEYEXPIRED + EKEYREVOKED + EKEYREJECTED + EOWNERDEAD // 130 + ENOTRECOVERABLE + ERFKILL + EHWPOISON ) -// Errno values from include/uapi/asm-generic/errno.h. -var ( - EDEADLK = &Errno{35, "resource deadlock would occur"} - ENAMETOOLONG = &Errno{36, "file name too long"} - ENOLCK = &Errno{37, "no record locks available"} - ENOSYS = &Errno{38, "invalid system call number"} - ENOTEMPTY = &Errno{39, "directory not empty"} - ELOOP = &Errno{40, "too many symbolic links encountered"} - EWOULDBLOCK = &Errno{EAGAIN.number, "operation would block"} - ENOMSG = &Errno{42, "no message of desired type"} - EIDRM = &Errno{43, "identifier removed"} - ECHRNG = &Errno{44, "channel number out of range"} - EL2NSYNC = &Errno{45, "level 2 not synchronized"} - EL3HLT = &Errno{46, "level 3 halted"} - EL3RST = &Errno{47, "level 3 reset"} - ELNRNG = &Errno{48, "link number out of range"} - EUNATCH = &Errno{49, "protocol driver not attached"} - ENOCSI = &Errno{50, "no CSI structure available"} - EL2HLT = &Errno{51, "level 2 halted"} - EBADE = &Errno{52, "invalid exchange"} - EBADR = &Errno{53, "invalid request descriptor"} - EXFULL = &Errno{54, "exchange full"} - ENOANO = &Errno{55, "no anode"} - EBADRQC = &Errno{56, "invalid request code"} - EBADSLT = &Errno{57, "invalid slot"} - EDEADLOCK = EDEADLK - EBFONT = &Errno{59, "bad font file format"} - ENOSTR = &Errno{60, "device not a stream"} - ENODATA = &Errno{61, "no data available"} - ETIME = &Errno{62, "timer expired"} - ENOSR = &Errno{63, "out of streams resources"} - ENONET = &Errno{64, "machine is not on the network"} - ENOPKG = &Errno{65, "package not installed"} - EREMOTE = &Errno{66, "object is remote"} - ENOLINK = &Errno{67, "link has been severed"} - EADV = &Errno{68, "advertise error"} - ESRMNT = &Errno{69, "srmount error"} - ECOMM = &Errno{70, "communication error on send"} - EPROTO = &Errno{71, "protocol error"} - EMULTIHOP = &Errno{72, "multihop attempted"} - EDOTDOT = &Errno{73, "RFS specific error"} - EBADMSG = &Errno{74, "not a data message"} - EOVERFLOW = &Errno{75, "value too large for defined data type"} - ENOTUNIQ = &Errno{76, "name not unique on network"} - EBADFD = &Errno{77, "file descriptor in bad state"} - EREMCHG = &Errno{78, "remote address changed"} - ELIBACC = &Errno{79, "can not access a needed shared library"} - ELIBBAD = &Errno{80, "accessing a corrupted shared library"} - ELIBSCN = &Errno{81, ".lib section in a.out corrupted"} - ELIBMAX = &Errno{82, "attempting to link in too many shared libraries"} - ELIBEXEC = &Errno{83, "cannot exec a shared library directly"} - EILSEQ = &Errno{84, "illegal byte sequence"} - ERESTART = &Errno{85, "interrupted system call should be restarted"} - ESTRPIPE = &Errno{86, "streams pipe error"} - EUSERS = &Errno{87, "too many users"} - ENOTSOCK = &Errno{88, "socket operation on non-socket"} - EDESTADDRREQ = &Errno{89, "destination address required"} - EMSGSIZE = &Errno{90, "message too long"} - EPROTOTYPE = &Errno{91, "protocol wrong type for socket"} - ENOPROTOOPT = &Errno{92, "protocol not available"} - EPROTONOSUPPORT = &Errno{93, "protocol not supported"} - ESOCKTNOSUPPORT = &Errno{94, "socket type not supported"} - EOPNOTSUPP = &Errno{95, "operation not supported on transport endpoint"} - EPFNOSUPPORT = &Errno{96, "protocol family not supported"} - EAFNOSUPPORT = &Errno{97, "address family not supported by protocol"} - EADDRINUSE = &Errno{98, "address already in use"} - EADDRNOTAVAIL = &Errno{99, "cannot assign requested address"} - ENETDOWN = &Errno{100, "network is down"} - ENETUNREACH = &Errno{101, "network is unreachable"} - ENETRESET = &Errno{102, "network dropped connection because of reset"} - ECONNABORTED = &Errno{103, "software caused connection abort"} - ECONNRESET = &Errno{104, "connection reset by peer"} - ENOBUFS = &Errno{105, "no buffer space available"} - EISCONN = &Errno{106, "transport endpoint is already connected"} - ENOTCONN = &Errno{107, "transport endpoint is not connected"} - ESHUTDOWN = &Errno{108, "cannot send after transport endpoint shutdown"} - ETOOMANYREFS = &Errno{109, "too many references: cannot splice"} - ETIMEDOUT = &Errno{110, "connection timed out"} - ECONNREFUSED = &Errno{111, "connection refused"} - EHOSTDOWN = &Errno{112, "host is down"} - EHOSTUNREACH = &Errno{113, "no route to host"} - EALREADY = &Errno{114, "operation already in progress"} - EINPROGRESS = &Errno{115, "operation now in progress"} - ESTALE = &Errno{116, "stale file handle"} - EUCLEAN = &Errno{117, "structure needs cleaning"} - ENOTNAM = &Errno{118, "not a XENIX named type file"} - ENAVAIL = &Errno{119, "no XENIX semaphores available"} - EISNAM = &Errno{120, "is a named type file"} - EREMOTEIO = &Errno{121, "remote I/O error"} - EDQUOT = &Errno{122, "quota exceeded"} - ENOMEDIUM = &Errno{123, "no medium found"} - EMEDIUMTYPE = &Errno{124, "wrong medium type"} - ECANCELED = &Errno{125, "operation Canceled"} - ENOKEY = &Errno{126, "required key not available"} - EKEYEXPIRED = &Errno{127, "key has expired"} - EKEYREVOKED = &Errno{128, "key has been revoked"} - EKEYREJECTED = &Errno{129, "key was rejected by service"} - EOWNERDEAD = &Errno{130, "owner died"} - ENOTRECOVERABLE = &Errno{131, "state not recoverable"} - ERFKILL = &Errno{132, "operation not possible due to RF-kill"} - EHWPOISON = &Errno{133, "memory page has hardware error"} +// errnos derived from other errnos +const ( + EWOULDBLOCK = EAGAIN + EDEADLOCK = EDEADLK ) diff --git a/pkg/sentry/devices/quotedev/BUILD b/pkg/sentry/devices/quotedev/BUILD new file mode 100644 index 000000000..d09214e3e --- /dev/null +++ b/pkg/sentry/devices/quotedev/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "quotedev", + srcs = ["quotedev.go"], + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/sentry/fsimpl/devtmpfs", + "//pkg/sentry/vfs", + "//pkg/syserror", + ], +) diff --git a/pkg/sentry/devices/quotedev/quotedev.go b/pkg/sentry/devices/quotedev/quotedev.go new file mode 100644 index 000000000..6114cb724 --- /dev/null +++ b/pkg/sentry/devices/quotedev/quotedev.go @@ -0,0 +1,52 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package quotedev implements a vfs.Device for /dev/gvisor_quote. +package quotedev + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +const ( + quoteDevMinor = 0 +) + +// quoteDevice implements vfs.Device for /dev/gvisor_quote +// +// +stateify savable +type quoteDevice struct{} + +// Open implements vfs.Device.Open. +// TODO(b/157161182): Add support for attestation ioctls. +func (quoteDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + return nil, syserror.EIO +} + +// Register registers all devices implemented by this package in vfsObj. +func Register(vfsObj *vfs.VirtualFilesystem) error { + return vfsObj.RegisterDevice(vfs.CharDevice, linux.UNNAMED_MAJOR, quoteDevMinor, quoteDevice{}, &vfs.RegisterDeviceOptions{ + GroupName: "gvisor_quote", + }) +} + +// CreateDevtmpfsFiles creates device special files in dev representing all +// devices implemented by this package. +func CreateDevtmpfsFiles(ctx context.Context, dev *devtmpfs.Accessor) error { + return dev.CreateDeviceFile(ctx, "gvisor_quote", vfs.CharDevice, linux.UNNAMED_MAJOR, quoteDevMinor, 0666 /* mode */) +} diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index 7c7543f14..cf905fae4 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -65,6 +65,7 @@ var _ kernfs.Inode = (*tasksInode)(nil) func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, fakeCgroupControllers map[string]string) *tasksInode { root := auth.NewRootCredentials(pidns.UserNamespace()) contents := map[string]kernfs.Inode{ + "cmdline": fs.newInode(ctx, root, 0444, &cmdLineData{}), "cpuinfo": fs.newInode(ctx, root, 0444, newStaticFileSetStat(cpuInfoData(k))), "filesystems": fs.newInode(ctx, root, 0444, &filesystemsData{}), "loadavg": fs.newInode(ctx, root, 0444, &loadavgData{}), diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go index e1a8b4409..045ed7a2d 100644 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -336,15 +336,6 @@ var _ dynamicInode = (*versionData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (*versionData) Generate(ctx context.Context, buf *bytes.Buffer) error { - k := kernel.KernelFromContext(ctx) - init := k.GlobalInit() - if init == nil { - // Attempted to read before the init Task is created. This can - // only occur during startup, which should never need to read - // this file. - panic("Attempted to read version before initial Task is available") - } - // /proc/version takes the form: // // "SYSNAME version RELEASE (COMPILE_USER@COMPILE_HOST) @@ -364,7 +355,7 @@ func (*versionData) Generate(ctx context.Context, buf *bytes.Buffer) error { // FIXME(mpratt): Using Version from the init task SyscallTable // disregards the different version a task may have (e.g., in a uts // namespace). - ver := init.Leader().SyscallTable().Version + ver := kernelVersion(ctx) fmt.Fprintf(buf, "%s version %s %s\n", ver.Sysname, ver.Release, ver.Version) return nil } @@ -400,3 +391,31 @@ func (*cgroupsData) Generate(ctx context.Context, buf *bytes.Buffer) error { r.GenerateProcCgroups(buf) return nil } + +// cmdLineData backs /proc/cmdline. +// +// +stateify savable +type cmdLineData struct { + dynamicBytesFileSetAttr +} + +var _ dynamicInode = (*cmdLineData)(nil) + +// Generate implements vfs.DynamicByteSource.Generate. +func (*cmdLineData) Generate(ctx context.Context, buf *bytes.Buffer) error { + fmt.Fprintf(buf, "BOOT_IMAGE=/vmlinuz-%s-gvisor quiet\n", kernelVersion(ctx).Release) + return nil +} + +// kernelVersion returns the kernel version. +func kernelVersion(ctx context.Context) kernel.Version { + k := kernel.KernelFromContext(ctx) + init := k.GlobalInit() + if init == nil { + // Attempted to read before the init Task is created. This can + // only occur during startup, which should never need to read + // this file. + panic("Attempted to read version before initial Task is available") + } + return init.Leader().SyscallTable().Version +} diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index d6f076cd6..e534fbca8 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -47,6 +47,7 @@ var ( var ( tasksStaticFiles = map[string]testutil.DirentType{ + "cmdline": linux.DT_REG, "cpuinfo": linux.DT_REG, "filesystems": linux.DT_REG, "loadavg": linux.DT_REG, diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index 99f036bba..1b5d5f66e 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -75,6 +75,9 @@ type machine struct { // nextID is the next vCPU ID. nextID uint32 + + // machineArchState is the architecture-specific state. + machineArchState } const ( @@ -196,12 +199,7 @@ func newMachine(vm int) (*machine, error) { m.available.L = &m.mu // Pull the maximum vCPUs. - maxVCPUs, _, errno := unix.RawSyscall(unix.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS) - if errno != 0 { - m.maxVCPUs = _KVM_NR_VCPUS - } else { - m.maxVCPUs = int(maxVCPUs) - } + m.getMaxVCPU() log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs) m.vCPUsByTID = make(map[uint64]*vCPU) m.vCPUsByID = make([]*vCPU, m.maxVCPUs) @@ -427,9 +425,8 @@ func (m *machine) Get() *vCPU { } } - // Create a new vCPU (maybe). - if int(m.nextID) < m.maxVCPUs { - c := m.newVCPU() + // Get a new vCPU (maybe). + if c := m.getNewVCPU(); c != nil { c.lock() m.vCPUsByTID[tid] = c m.mu.Unlock() diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index f727e61b0..9a2337654 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -63,6 +63,9 @@ func (m *machine) initArchState() error { return nil } +type machineArchState struct { +} + type vCPUArchState struct { // PCIDs is the set of PCIDs for this vCPU. // @@ -499,3 +502,22 @@ func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { physical) } } + +// getMaxVCPU get max vCPU number +func (m *machine) getMaxVCPU() { + maxVCPUs, _, errno := unix.RawSyscall(unix.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS) + if errno != 0 { + m.maxVCPUs = _KVM_NR_VCPUS + } else { + m.maxVCPUs = int(maxVCPUs) + } +} + +// getNewVCPU create a new vCPU (maybe) +func (m *machine) getNewVCPU() *vCPU { + if int(m.nextID) < m.maxVCPUs { + c := m.newVCPU() + return c + } + return nil +} diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index cd912f922..8926b1d9f 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -17,6 +17,10 @@ package kvm import ( + "runtime" + "sync/atomic" + + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" @@ -25,6 +29,11 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" ) +type machineArchState struct { + //initialvCPUs is the machine vCPUs which has initialized but not used + initialvCPUs map[int]*vCPU +} + type vCPUArchState struct { // PCIDs is the set of PCIDs for this vCPU. // @@ -182,3 +191,30 @@ func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType, return accessType, platform.ErrContextSignal } + +// getMaxVCPU get max vCPU number +func (m *machine) getMaxVCPU() { + rmaxVCPUs := runtime.NumCPU() + smaxVCPUs, _, errno := unix.RawSyscall(unix.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS) + // compare the max vcpu number from runtime and syscall, use smaller one. + if errno != 0 { + m.maxVCPUs = rmaxVCPUs + } else { + if rmaxVCPUs < int(smaxVCPUs) { + m.maxVCPUs = rmaxVCPUs + } else { + m.maxVCPUs = int(smaxVCPUs) + } + } +} + +// getNewVCPU() scan for an available vCPU from initialvCPUs +func (m *machine) getNewVCPU() *vCPU { + for CID, c := range m.initialvCPUs { + if atomic.CompareAndSwapUint32(&c.state, vCPUReady, vCPUUser) { + delete(m.initialvCPUs, CID) + return c + } + } + return nil +} diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 634e55ec0..92edc992b 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -29,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" + ktime "gvisor.dev/gvisor/pkg/sentry/time" ) type kvmVcpuInit struct { @@ -47,6 +48,19 @@ func (m *machine) initArchState() error { uintptr(unsafe.Pointer(&vcpuInit))); errno != 0 { panic(fmt.Sprintf("error setting KVM_ARM_PREFERRED_TARGET failed: %v", errno)) } + + // Initialize all vCPUs on ARM64, while this does not happen on x86_64. + // The reason for the difference is that ARM64 and x86_64 have different KVM timer mechanisms. + // If we create vCPU dynamically on ARM64, the timer for vCPU would mess up for a short time. + // For more detail, please refer to https://github.com/google/gvisor/issues/5739 + m.initialvCPUs = make(map[int]*vCPU) + m.mu.Lock() + for int(m.nextID) < m.maxVCPUs-1 { + c := m.newVCPU() + c.state = 0 + m.initialvCPUs[c.id] = c + } + m.mu.Unlock() return nil } @@ -174,9 +188,58 @@ func (c *vCPU) setTSC(value uint64) error { return nil } +// getTSC gets the counter Physical Counter minus Virtual Offset. +func (c *vCPU) getTSC() error { + var ( + reg kvmOneReg + data uint64 + ) + + reg.addr = uint64(reflect.ValueOf(&data).Pointer()) + reg.id = _KVM_ARM64_REGS_TIMER_CNT + + if err := c.getOneRegister(®); err != nil { + return err + } + + return nil +} + // setSystemTime sets the vCPU to the system time. func (c *vCPU) setSystemTime() error { - return c.setSystemTimeLegacy() + const minIterations = 10 + minimum := uint64(0) + for iter := 0; ; iter++ { + // Use get the TSC to an estimate of where it will be + // on the host during a "fast" system call iteration. + // replace getTSC to another setOneRegister syscall can get more accurate value? + start := uint64(ktime.Rdtsc()) + if err := c.getTSC(); err != nil { + return err + } + // See if this is our new minimum call time. Note that this + // serves two functions: one, we make sure that we are + // accurately predicting the offset we need to set. Second, we + // don't want to do the final set on a slow call, which could + // produce a really bad result. + end := uint64(ktime.Rdtsc()) + if end < start { + continue // Totally bogus: unstable TSC? + } + current := end - start + if current < minimum || iter == 0 { + minimum = current // Set our new minimum. + } + // Is this past minIterations and within ~10% of minimum? + upperThreshold := (((minimum << 3) + minimum) >> 3) + if iter >= minIterations && (current <= upperThreshold || minimum < 50) { + // Try to set the TSC + if err := c.setTSC(end + (minimum / 2)); err != nil { + return err + } + return nil + } + } } //go:nosplit @@ -203,7 +266,7 @@ func (c *vCPU) getOneRegister(reg *kvmOneReg) error { uintptr(c.fd), _KVM_GET_ONE_REG, uintptr(unsafe.Pointer(reg))); errno != 0 { - return fmt.Errorf("error setting one register: %v", errno) + return fmt.Errorf("error getting one register: %v", errno) } return nil } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index d75a2879f..280563d09 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -656,7 +656,7 @@ func dumpErrorMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr Type: linux.NLMSG_ERROR, }) m.Put(&linux.NetlinkErrorMessage{ - Error: int32(-err.ToLinux().Number()), + Error: int32(-err.ToLinux()), Header: hdr, }) } diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 60ef33360..3fd22f936 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -200,11 +200,12 @@ var Metrics = tcpip.Stats{ OptionRouterAlertReceived: mustCreateMetric("/netstack/ip/options/router_alert_received", "Number of router alert options found in received IP packets."), OptionUnknownReceived: mustCreateMetric("/netstack/ip/options/unknown_received", "Number of unknown options found in received IP packets."), Forwarding: tcpip.IPForwardingStats{ - Unrouteable: mustCreateMetric("/netstack/ip/forwarding/unrouteable", "Number of IP packets received which couldn't be routed and thus were not forwarded."), - ExhaustedTTL: mustCreateMetric("/netstack/ip/forwarding/exhausted_ttl", "Number of IP packets received which could not be forwarded due to an exhausted TTL."), - LinkLocalSource: mustCreateMetric("/netstack/ip/forwarding/link_local_source_address", "Number of IP packets received which could not be forwarded due to a link-local source address."), - LinkLocalDestination: mustCreateMetric("/netstack/ip/forwarding/link_local_destination_address", "Number of IP packets received which could not be forwarded due to a link-local destination address."), - Errors: mustCreateMetric("/netstack/ip/forwarding/errors", "Number of IP packets which couldn't be forwarded."), + Unrouteable: mustCreateMetric("/netstack/ip/forwarding/unrouteable", "Number of IP packets received which couldn't be routed and thus were not forwarded."), + ExhaustedTTL: mustCreateMetric("/netstack/ip/forwarding/exhausted_ttl", "Number of IP packets received which could not be forwarded due to an exhausted TTL."), + LinkLocalSource: mustCreateMetric("/netstack/ip/forwarding/link_local_source_address", "Number of IP packets received which could not be forwarded due to a link-local source address."), + LinkLocalDestination: mustCreateMetric("/netstack/ip/forwarding/link_local_destination_address", "Number of IP packets received which could not be forwarded due to a link-local destination address."), + ExtensionHeaderProblem: mustCreateMetric("/netstack/ip/forwarding/extension_header_problem", "Number of IP packets received which could not be forwarded due to a problem processing their IPv6 extension headers."), + Errors: mustCreateMetric("/netstack/ip/forwarding/errors", "Number of IP packets which couldn't be forwarded."), }, }, ARP: tcpip.ARPStats{ @@ -850,7 +851,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return &optP, nil } - optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux().Number()) + optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux()) return &optP, nil case linux.SO_PEERCRED: diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 9e56487a6..353f4ade0 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -80,7 +80,7 @@ func sockErrCmsgToLinux(sockErr *tcpip.SockError) linux.SockErrCMsg { } ee := linux.SockExtendedErr{ - Errno: uint32(syserr.TranslateNetstackError(sockErr.Err).ToLinux().Number()), + Errno: uint32(syserr.TranslateNetstackError(sockErr.Err).ToLinux()), Origin: errOriginToLinux(sockErr.Cause.Origin()), Type: sockErr.Cause.Type(), Code: sockErr.Cause.Code(), diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index 79e564de6..90be24e15 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -38,7 +38,7 @@ var ( ErrPortInUse = New((&tcpip.ErrPortInUse{}).String(), linux.EADDRINUSE) ErrBadLocalAddress = New((&tcpip.ErrBadLocalAddress{}).String(), linux.EADDRNOTAVAIL) ErrClosedForSend = New((&tcpip.ErrClosedForSend{}).String(), linux.EPIPE) - ErrClosedForReceive = New((&tcpip.ErrClosedForReceive{}).String(), nil) + ErrClosedForReceive = New((&tcpip.ErrClosedForReceive{}).String(), linux.NOERRNO) ErrTimeout = New((&tcpip.ErrTimeout{}).String(), linux.ETIMEDOUT) ErrAborted = New((&tcpip.ErrAborted{}).String(), linux.EPIPE) ErrConnectStarted = New((&tcpip.ErrConnectStarted{}).String(), linux.EINPROGRESS) diff --git a/pkg/syserr/syserr.go b/pkg/syserr/syserr.go index b5881ea3c..d70521f32 100644 --- a/pkg/syserr/syserr.go +++ b/pkg/syserr/syserr.go @@ -34,24 +34,19 @@ type Error struct { // linux.Errno. noTranslation bool - // errno is the linux.Errno this Error should be translated to. nil means - // that this Error should be translated to a nil linux.Errno. - errno *linux.Errno + // errno is the linux.Errno this Error should be translated to. + errno linux.Errno } // New creates a new Error and adds a translation for it. // // New must only be called at init. -func New(message string, linuxTranslation *linux.Errno) *Error { +func New(message string, linuxTranslation linux.Errno) *Error { err := &Error{message: message, errno: linuxTranslation} - if linuxTranslation == nil { - return err - } - // TODO(b/34162363): Remove this. - errno := linuxTranslation.Number() - if errno <= 0 || errno >= len(linuxBackwardsTranslations) { + errno := linuxTranslation + if errno < 0 || int(errno) >= len(linuxBackwardsTranslations) { panic(fmt.Sprint("invalid errno: ", errno)) } @@ -74,7 +69,7 @@ func New(message string, linuxTranslation *linux.Errno) *Error { // NewDynamic should only be used sparingly and not be used for static error // messages. Errors with static error messages should be declared with New as // global variables. -func NewDynamic(message string, linuxTranslation *linux.Errno) *Error { +func NewDynamic(message string, linuxTranslation linux.Errno) *Error { return &Error{message: message, errno: linuxTranslation} } @@ -87,7 +82,7 @@ func NewWithoutTranslation(message string) *Error { return &Error{message: message, noTranslation: true} } -func newWithHost(message string, linuxTranslation *linux.Errno, hostErrno unix.Errno) *Error { +func newWithHost(message string, linuxTranslation linux.Errno, hostErrno unix.Errno) *Error { e := New(message, linuxTranslation) addLinuxHostTranslation(hostErrno, e) return e @@ -119,10 +114,10 @@ func (e *Error) ToError() error { if e.noTranslation { panic(fmt.Sprintf("error %q does not support translation", e.message)) } - if e.errno == nil { + errno := int(e.errno) + if errno == linux.NOERRNO { return nil } - errno := e.errno.Number() if errno <= 0 || errno >= len(linuxBackwardsTranslations) || !linuxBackwardsTranslations[errno].ok { panic(fmt.Sprintf("unknown error %q (%d)", e.message, errno)) } @@ -131,7 +126,7 @@ func (e *Error) ToError() error { // ToLinux converts the Error to a Linux ABI error that can be returned to the // application. -func (e *Error) ToLinux() *linux.Errno { +func (e *Error) ToLinux() linux.Errno { if e.noTranslation { panic(fmt.Sprintf("No Linux ABI translation available for %q", e.message)) } diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 12c39dfa3..18e6cc3cd 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -1607,6 +1607,17 @@ func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker { } } +// IPv6UnknownOption validates that an extension header option is the +// unknown header option. +func IPv6UnknownOption() IPv6ExtHdrOptionChecker { + return func(t *testing.T, opt header.IPv6ExtHdrOption) { + _, ok := opt.(*header.IPv6UnknownExtHdrOption) + if !ok { + t.Errorf("got = %T, want = header.IPv6UnknownExtHdrOption", opt) + } + } +} + // IgnoreCmpPath returns a cmp.Option that ignores listed field paths. func IgnoreCmpPath(paths ...string) cmp.Option { ignores := map[string]struct{}{} diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index 6905b9ccb..a72eb1aad 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -47,7 +47,7 @@ go_test( library = ":arp", deps = [ "//pkg/tcpip", - "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", ], ) diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go index e867b3c3f..0df39ae81 100644 --- a/pkg/tcpip/network/arp/stats_test.go +++ b/pkg/tcpip/network/arp/stats_test.go @@ -19,8 +19,8 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) var _ stack.NetworkInterface = (*testInterface)(nil) diff --git a/pkg/tcpip/network/internal/ip/errors.go b/pkg/tcpip/network/internal/ip/errors.go index 50fabfd79..d3577b377 100644 --- a/pkg/tcpip/network/internal/ip/errors.go +++ b/pkg/tcpip/network/internal/ip/errors.go @@ -34,13 +34,13 @@ func (*ErrTTLExceeded) isForwardingError() {} func (*ErrTTLExceeded) String() string { return "ttl exceeded" } -// ErrIPOptProblem indicates the received packet had a problem with an IP -// option. -type ErrIPOptProblem struct{} +// ErrParameterProblem indicates the received packet had a problem with an IP +// parameter. +type ErrParameterProblem struct{} -func (*ErrIPOptProblem) isForwardingError() {} +func (*ErrParameterProblem) isForwardingError() {} -func (*ErrIPOptProblem) String() string { return "ip option problem" } +func (*ErrParameterProblem) String() string { return "parameter problem" } // ErrLinkLocalSourceAddress indicates the received packet had a link-local // source address. diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go index 392f0b0c7..68b8b550e 100644 --- a/pkg/tcpip/network/internal/ip/stats.go +++ b/pkg/tcpip/network/internal/ip/stats.go @@ -16,7 +16,7 @@ package ip import "gvisor.dev/gvisor/pkg/tcpip" -// LINT.IfChange(MultiCounterIPStats) +// LINT.IfChange(MultiCounterIPForwardingStats) // MultiCounterIPForwardingStats holds IP forwarding statistics. Each counter // may have several versions. @@ -38,11 +38,30 @@ type MultiCounterIPForwardingStats struct { // because they contained a link-local destination address. LinkLocalDestination tcpip.MultiCounterStat + // ExtensionHeaderProblem is the number of IP packets which were dropped + // because of a problem encountered when processing an IPv6 extension + // header. + ExtensionHeaderProblem tcpip.MultiCounterStat + // Errors is the number of IP packets received which could not be // successfully forwarded. Errors tcpip.MultiCounterStat } +// Init sets internal counters to track a and b counters. +func (m *MultiCounterIPForwardingStats) Init(a, b *tcpip.IPForwardingStats) { + m.Unrouteable.Init(a.Unrouteable, b.Unrouteable) + m.Errors.Init(a.Errors, b.Errors) + m.LinkLocalSource.Init(a.LinkLocalSource, b.LinkLocalSource) + m.LinkLocalDestination.Init(a.LinkLocalDestination, b.LinkLocalDestination) + m.ExtensionHeaderProblem.Init(a.ExtensionHeaderProblem, b.ExtensionHeaderProblem) + m.ExhaustedTTL.Init(a.ExhaustedTTL, b.ExhaustedTTL) +} + +// LINT.ThenChange(:MultiCounterIPForwardingStats, ../../../tcpip.go:IPForwardingStats) + +// LINT.IfChange(MultiCounterIPStats) + // MultiCounterIPStats holds IP statistics, each counter may have several // versions. type MultiCounterIPStats struct { @@ -120,15 +139,6 @@ type MultiCounterIPStats struct { } // Init sets internal counters to track a and b counters. -func (m *MultiCounterIPForwardingStats) Init(a, b *tcpip.IPForwardingStats) { - m.Unrouteable.Init(a.Unrouteable, b.Unrouteable) - m.Errors.Init(a.Errors, b.Errors) - m.LinkLocalSource.Init(a.LinkLocalSource, b.LinkLocalSource) - m.LinkLocalDestination.Init(a.LinkLocalDestination, b.LinkLocalDestination) - m.ExhaustedTTL.Init(a.ExhaustedTTL, b.ExhaustedTTL) -} - -// Init sets internal counters to track a and b counters. func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) { m.PacketsReceived.Init(a.PacketsReceived, b.PacketsReceived) m.DisabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived) diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD index 1c4f583c7..cec3e62c4 100644 --- a/pkg/tcpip/network/internal/testutil/BUILD +++ b/pkg/tcpip/network/internal/testutil/BUILD @@ -4,10 +4,7 @@ package(licenses = ["notice"]) go_library( name = "testutil", - srcs = [ - "testutil.go", - "testutil_unsafe.go", - ], + srcs = ["testutil.go"], visibility = [ "//pkg/tcpip/network/arp:__pkg__", "//pkg/tcpip/network/internal/fragmentation:__pkg__", diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go index e2cf24b67..605e9ef8d 100644 --- a/pkg/tcpip/network/internal/testutil/testutil.go +++ b/pkg/tcpip/network/internal/testutil/testutil.go @@ -19,8 +19,6 @@ package testutil import ( "fmt" "math/rand" - "reflect" - "strings" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -129,69 +127,3 @@ func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSi } return pkt } - -func checkFieldCounts(ref, multi reflect.Value) error { - refTypeName := ref.Type().Name() - multiTypeName := multi.Type().Name() - refNumField := ref.NumField() - multiNumField := multi.NumField() - - if refNumField != multiNumField { - return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName) - } - - return nil -} - -func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error { - s, ok := ref.Addr().Interface().(**tcpip.StatCounter) - if !ok { - return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name()) - } - - // The field names are expected to match (case insensitive). - if !strings.EqualFold(refName, multiName) { - return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName) - } - - base := (*s).Value() - m.Increment() - if (*s).Value() != base+1 { - return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName) - } - - return nil -} - -// ValidateMultiCounterStats verifies that every counter stored in multi is -// correctly tracking its counterpart in the given counters. -func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error { - for _, c := range counters { - if err := checkFieldCounts(c, multi); err != nil { - return err - } - } - - for i := 0; i < multi.NumField(); i++ { - multiName := multi.Type().Field(i).Name - multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i)) - - if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok { - for _, c := range counters { - if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil { - return err - } - } - } else { - var countersNextField []reflect.Value - for _, c := range counters { - countersNextField = append(countersNextField, c.Field(i)) - } - if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil { - return err - } - } - } - - return nil -} diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 7ee0495d9..c90974693 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -62,7 +62,7 @@ go_test( library = ":ipv4", deps = [ "//pkg/tcpip", - "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", ], ) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index b11e56c6a..4031032d0 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -645,7 +645,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { forwarding: true, }, pkt) } - return &ip.ErrIPOptProblem{} + return &ip.ErrParameterProblem{} } copied := copy(opts, newOpts) if copied != len(newOpts) { @@ -827,7 +827,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) stats.ip.Forwarding.ExhaustedTTL.Increment() case *ip.ErrNoRoute: stats.ip.Forwarding.Unrouteable.Increment() - case *ip.ErrIPOptProblem: + case *ip.ErrParameterProblem: e.protocol.stack.Stats().MalformedRcvdPackets.Increment() stats.ip.MalformedPacketsReceived.Increment() default: @@ -990,8 +990,8 @@ func (e *endpoint) Close() { // AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) if err == nil { @@ -1002,8 +1002,8 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p // RemovePermanentAddress implements stack.AddressableEndpoint. func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() return e.mu.addressableEndpointState.RemovePermanentAddress(addr) } @@ -1016,8 +1016,8 @@ func (e *endpoint) MainAddress() tcpip.AddressWithPrefix { // AcquireAssignedAddress implements stack.AddressableEndpoint. func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() loopback := e.nic.IsLoopback() return e.mu.addressableEndpointState.AcquireAssignedAddressOrMatching(localAddr, func(addressEndpoint stack.AddressEndpoint) bool { diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go index a637f9d50..d1f9e3cf5 100644 --- a/pkg/tcpip/network/ipv4/stats_test.go +++ b/pkg/tcpip/network/ipv4/stats_test.go @@ -19,8 +19,8 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) var _ stack.NetworkInterface = (*testInterface)(nil) diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index db998e83e..f99cbf8f3 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -45,6 +45,7 @@ go_test( "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index ebb0b73df..247a07dc2 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -984,11 +984,15 @@ type icmpReasonParameterProblem struct { // packet if the field in error is beyond what can fit // in the maximum size of an ICMPv6 error message. pointer uint32 + + // forwarding indicates that the problem arose while we were trying to forward + // a packet. + forwarding bool } func (*icmpReasonParameterProblem) isICMPReason() {} -func (*icmpReasonParameterProblem) isForwarding() bool { - return false +func (p *icmpReasonParameterProblem) isForwarding() bool { + return p.forwarding } // icmpReasonPortUnreachable is an error where the transport protocol has no diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 659057fa7..029d5f51b 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -941,6 +941,11 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { return nil } + // Check extension headers for any errors requiring action during forwarding. + if err := e.processExtensionHeaders(h, pkt, true /* forwarding */); err != nil { + return &ip.ErrParameterProblem{} + } + r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) switch err.(type) { case nil: @@ -1084,6 +1089,8 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) e.stats.ip.Forwarding.ExhaustedTTL.Increment() case *ip.ErrNoRoute: e.stats.ip.Forwarding.Unrouteable.Increment() + case *ip.ErrParameterProblem: + e.stats.ip.Forwarding.ExtensionHeaderProblem.Increment() default: panic(fmt.Sprintf("unexpected error %s while trying to forward packet: %#v", err, pkt)) } @@ -1091,6 +1098,28 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) return } + // iptables filtering. All packets that reach here are intended for + // this machine and need not be forwarded. + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { + // iptables is telling us to drop the packet. + stats.IPTablesInputDropped.Increment() + return + } + + // Any returned error is only useful for terminating execution early, but + // we have nothing left to do, so we can drop it. + _ = e.processExtensionHeaders(h, pkt, false /* forwarding */) +} + +// processExtensionHeaders processes the extension headers in the given packet. +// Returns an error if the processing of a header failed or if the packet should +// be discarded. +func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffer, forwarding bool) error { + stats := e.stats.ip + srcAddr := h.SourceAddress() + dstAddr := h.DestinationAddress() + // Create a VV to parse the packet. We don't plan to modify anything here. // vv consists of: // - Any IPv6 header bytes after the first 40 (i.e. extensions). @@ -1101,15 +1130,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) vv.AppendViews(pkt.Data().Views()) it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv) - // iptables filtering. All packets that reach here are intended for - // this machine and need not be forwarded. - inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { - // iptables is telling us to drop the packet. - stats.IPTablesInputDropped.Increment() - return - } - var ( hasFragmentHeader bool routerAlert *header.IPv6RouterAlertOption @@ -1122,22 +1142,41 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) extHdr, done, err := it.Next() if err != nil { stats.MalformedPacketsReceived.Increment() - return + return err } if done { break } + // As per RFC 8200, section 4: + // + // Extension headers (except for the Hop-by-Hop Options header) are + // not processed, inserted, or deleted by any node along a packet's + // delivery path until the packet reaches the node identified in the + // Destination Address field of the IPv6 header. + // + // Furthermore, as per RFC 8200 section 4.1, the Hop By Hop extension + // header is restricted to appear first in the list of extension headers. + // + // Therefore, we can immediately return once we hit any header other + // than the Hop-by-Hop header while forwarding a packet. + if forwarding { + if _, ok := extHdr.(header.IPv6HopByHopOptionsExtHdr); !ok { + return nil + } + } + switch extHdr := extHdr.(type) { case header.IPv6HopByHopOptionsExtHdr: // As per RFC 8200 section 4.1, the Hop By Hop extension header is // restricted to appear immediately after an IPv6 fixed header. if previousHeaderStart != 0 { _ = e.protocol.returnError(&icmpReasonParameterProblem{ - code: header.ICMPv6UnknownHeader, - pointer: previousHeaderStart, + code: header.ICMPv6UnknownHeader, + pointer: previousHeaderStart, + forwarding: forwarding, }, pkt) - return + return fmt.Errorf("found Hop-by-Hop header = %#v with non-zero previous header offset = %d", extHdr, previousHeaderStart) } optsIt := extHdr.Iter() @@ -1146,7 +1185,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) opt, done, err := optsIt.Next() if err != nil { stats.MalformedPacketsReceived.Increment() - return + return err } if done { break @@ -1161,7 +1200,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // There MUST only be one option of this type, regardless of // value, per Hop-by-Hop header. stats.MalformedPacketsReceived.Increment() - return + return fmt.Errorf("found multiple Router Alert options (%#v, %#v)", opt, routerAlert) } routerAlert = opt stats.OptionRouterAlertReceived.Increment() @@ -1169,10 +1208,10 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) switch opt.UnknownAction() { case header.IPv6OptionUnknownActionSkip: case header.IPv6OptionUnknownActionDiscard: - return + return fmt.Errorf("found unknown Hop-by-Hop header option = %#v with discard action", opt) case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: if header.IsV6MulticastAddress(dstAddr) { - return + return fmt.Errorf("found unknown hop-by-hop header option = %#v with discard action", opt) } fallthrough case header.IPv6OptionUnknownActionDiscardSendICMP: @@ -1187,10 +1226,11 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) code: header.ICMPv6UnknownOption, pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, + forwarding: forwarding, }, pkt) - return + return fmt.Errorf("found unknown hop-by-hop header option = %#v with discard action", opt) default: - panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt)) + panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %#v", opt)) } } } @@ -1212,8 +1252,13 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: it.ParseOffset(), + // For the sake of consistency, we're using the value of `forwarding` + // here, even though it should always be false if we've reached this + // point. If `forwarding` is true here, we're executing undefined + // behavior no matter what. + forwarding: forwarding, }, pkt) - return + return fmt.Errorf("found unrecognized routing type with non-zero segments left in header = %#v", extHdr) } case header.IPv6FragmentExtHdr: @@ -1248,7 +1293,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) if err != nil { stats.MalformedPacketsReceived.Increment() stats.MalformedFragmentsReceived.Increment() - return + return err } if done { break @@ -1276,7 +1321,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) default: stats.MalformedPacketsReceived.Increment() stats.MalformedFragmentsReceived.Increment() - return + return fmt.Errorf("known extension header = %#v present after fragment header in a non-initial fragment", lastHdr) } } @@ -1285,7 +1330,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // Drop the packet as it's marked as a fragment but has no payload. stats.MalformedPacketsReceived.Increment() stats.MalformedFragmentsReceived.Increment() - return + return fmt.Errorf("fragment has no payload") } // As per RFC 2460 Section 4.5: @@ -1303,7 +1348,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) code: header.ICMPv6ErroneousHeader, pointer: header.IPv6PayloadLenOffset, }, pkt) - return + return fmt.Errorf("found fragment length = %d that is not a multiple of 8 octets", fragmentPayloadLen) } // The packet is a fragment, let's try to reassemble it. @@ -1317,14 +1362,15 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // Parameter Problem, Code 0, message should be sent to the source of // the fragment, pointing to the Fragment Offset field of the fragment // packet. - if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize { + lengthAfterReassembly := int(start) + fragmentPayloadLen + if lengthAfterReassembly > header.IPv6MaximumPayloadSize { stats.MalformedPacketsReceived.Increment() stats.MalformedFragmentsReceived.Increment() _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: fragmentFieldOffset, }, pkt) - return + return fmt.Errorf("determined that reassembled packet length = %d would exceed allowed length = %d", lengthAfterReassembly, header.IPv6MaximumPayloadSize) } // Note that pkt doesn't have its transport header set after reassembly, @@ -1346,7 +1392,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) if err != nil { stats.MalformedPacketsReceived.Increment() stats.MalformedFragmentsReceived.Increment() - return + return err } if ready { @@ -1368,7 +1414,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) opt, done, err := optsIt.Next() if err != nil { stats.MalformedPacketsReceived.Increment() - return + return err } if done { break @@ -1379,10 +1425,10 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) switch opt.UnknownAction() { case header.IPv6OptionUnknownActionSkip: case header.IPv6OptionUnknownActionDiscard: - return + return fmt.Errorf("found unknown destination header option = %#v with discard action", opt) case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: if header.IsV6MulticastAddress(dstAddr) { - return + return fmt.Errorf("found unknown destination header option %#v with discard action", opt) } fallthrough case header.IPv6OptionUnknownActionDiscardSendICMP: @@ -1399,9 +1445,9 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, }, pkt) - return + return fmt.Errorf("found unknown destination header option %#v with discard action", opt) default: - panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt)) + panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %#v", opt)) } } @@ -1432,6 +1478,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // transport protocol (e.g., UDP) has no listener, if that transport // protocol has no alternative means to inform the sender. _ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt) + return fmt.Errorf("destination port unreachable") case stack.TransportPacketProtocolUnreachable: // As per RFC 8200 section 4. (page 7): // Extension headers are numbered from IANA IP Protocol Numbers @@ -1463,6 +1510,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) code: header.ICMPv6UnknownHeader, pointer: prevHdrIDOffset, }, pkt) + return fmt.Errorf("transport protocol unreachable") default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) } @@ -1476,6 +1524,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) } } + return nil } // Close cleans up resources associated with the endpoint. @@ -1497,8 +1546,8 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { // TODO(b/169350103): add checks here after making sure we no longer receive // an empty address. - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated) } @@ -1539,8 +1588,8 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre // RemovePermanentAddress implements stack.AddressableEndpoint. func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() addressEndpoint := e.getAddressRLocked(addr) if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() { @@ -1617,8 +1666,8 @@ func (e *endpoint) MainAddress() tcpip.AddressWithPrefix { // AcquireAssignedAddress implements stack.AddressableEndpoint. func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() return e.acquireAddressOrCreateTempLocked(localAddr, allowTemp, tempPEB) } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 4fbe39528..8ebca735b 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -31,8 +31,9 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" + iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -2603,7 +2604,7 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) + ep := iptestutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) rt := buildRoute(t, ep) var pkts stack.PacketBufferList for i := 0; i < nPackets; i++ { @@ -2802,9 +2803,9 @@ func TestFragmentationWritePacket(t *testing.T) { for _, ft := range fragmentationTests { t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) source := pkt.Clone() - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -2858,7 +2859,7 @@ func TestFragmentationWritePackets(t *testing.T) { insertAfter: 1, }, } - tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber) + tinyPacket := iptestutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber) for _, test := range tests { t.Run(test.description, func(t *testing.T) { @@ -2868,14 +2869,14 @@ func TestFragmentationWritePackets(t *testing.T) { for i := 0; i < test.insertBefore; i++ { pkts.PushBack(tinyPacket.Clone()) } - pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) source := pkt pkts.PushBack(pkt.Clone()) for i := 0; i < test.insertAfter; i++ { pkts.PushBack(tinyPacket.Clone()) } - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter @@ -2980,8 +2981,8 @@ func TestFragmentationErrors(t *testing.T) { for _, ft := range tests { t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) - ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) + pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) r := buildRoute(t, ep) err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -3025,6 +3026,7 @@ func TestForwarding(t *testing.T) { tests := []struct { name string + extHdr func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) TTL uint8 expectErrorICMP bool expectPacketForwarded bool @@ -3036,6 +3038,7 @@ func TestForwarding(t *testing.T) { expectPacketUnrouteableError bool expectLinkLocalSourceError bool expectLinkLocalDestError bool + expectExtensionHeaderError bool }{ { name: "TTL of zero", @@ -3108,6 +3111,158 @@ func TestForwarding(t *testing.T) { destAddr: remoteIPv6Addr2, expectLinkLocalSourceError: true, }, + { + name: "Hopbyhop with unknown option skippable action", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Skippable unknown. + 62, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, checker.IPv6ExtHdr(checker.IPv6HopByHopExtensionHeader(checker.IPv6UnknownOption(), checker.IPv6UnknownOption())) + }, + expectPacketForwarded: true, + }, + { + name: "Hopbyhop with unknown option discard action", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard unknown. + 127, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, nil + }, + expectExtensionHeaderError: true, + }, + { + name: "Hopbyhop with unknown option discard and send icmp action (unicast)", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, nil + }, + expectErrorICMP: true, + icmpType: header.ICMPv6ParamProblem, + icmpCode: header.ICMPv6UnknownOption, + expectExtensionHeaderError: true, + }, + { + name: "Hopbyhop with unknown option discard and send icmp action (multicast)", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: multicastIPv6Addr, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, nil + }, + expectErrorICMP: true, + icmpType: header.ICMPv6ParamProblem, + icmpCode: header.ICMPv6UnknownOption, + expectExtensionHeaderError: true, + }, + { + name: "Hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, nil + }, + expectErrorICMP: true, + icmpType: header.ICMPv6ParamProblem, + icmpCode: header.ICMPv6UnknownOption, + expectExtensionHeaderError: true, + }, + { + name: "Hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: multicastIPv6Addr, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, nil + }, + expectExtensionHeaderError: true, + }, + { + name: "Hopbyhop with router alert option", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 0, + + // Router Alert option. + 5, 2, 0, 0, 0, 0, + }, hopByHopExtHdrID, checker.IPv6ExtHdr(checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD))) + }, + expectPacketForwarded: true, + }, + { + name: "Hopbyhop with two router alert options", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Router Alert option. + 5, 2, 0, 0, 0, 0, + + // Router Alert option. + 5, 2, 0, 0, 0, 0, + }, hopByHopExtHdrID, nil + }, + expectExtensionHeaderError: true, + }, } for _, test := range tests { @@ -3150,7 +3305,17 @@ func TestForwarding(t *testing.T) { t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err) } - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6MinimumSize) + transportProtocol := header.ICMPv6ProtocolNumber + extHdrBytes := []byte{} + extHdrChecker := checker.IPv6ExtHdr() + if test.extHdr != nil { + nextHdrID := hopByHopExtHdrID + extHdrBytes, nextHdrID, extHdrChecker = test.extHdr(uint8(header.ICMPv6ProtocolNumber)) + transportProtocol = tcpip.TransportProtocolNumber(nextHdrID) + } + extHdrLen := len(extHdrBytes) + + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6MinimumSize + extHdrLen) icmp := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) icmp.SetIdent(randomIdent) icmp.SetSequence(randomSequence) @@ -3162,10 +3327,11 @@ func TestForwarding(t *testing.T) { Src: test.sourceAddr, Dst: test.destAddr, })) + copy(hdr.Prepend(extHdrLen), extHdrBytes) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: header.ICMPv6MinimumSize, - TransportProtocol: header.ICMPv6ProtocolNumber, + TransportProtocol: transportProtocol, HopLimit: test.TTL, SrcAddr: test.sourceAddr, DstAddr: test.destAddr, @@ -3205,10 +3371,11 @@ func TestForwarding(t *testing.T) { t.Fatal("expected ICMP Echo Request packet through outgoing NIC") } - checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.IPv6WithExtHdr(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), checker.SrcAddr(test.sourceAddr), checker.DstAddr(test.destAddr), checker.TTL(test.TTL-1), + extHdrChecker, checker.ICMPv6( checker.ICMPv6Type(header.ICMPv6EchoRequest), checker.ICMPv6Code(header.ICMPv6UnusedCode), @@ -3249,6 +3416,10 @@ func TestForwarding(t *testing.T) { if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want { t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want) } + + if got, want := s.Stats().IP.Forwarding.ExtensionHeaderProblem.Value(), boolToInt(test.expectExtensionHeaderError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.ExtensionHeaderProblem.Value() = %d, want = %d", got, want) + } }) } } diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index e5590ecc0..ce9cebdaa 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -440,33 +440,54 @@ func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*ad // Regardless how the address was obtained, it will be acquired before it is // returned. func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tcpip.Address, f func(AddressEndpoint) bool, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { - a.mu.Lock() - defer a.mu.Unlock() + lookup := func() *addressState { + if addrState, ok := a.mu.endpoints[localAddr]; ok { + if !addrState.IsAssigned(allowTemp) { + return nil + } - if addrState, ok := a.mu.endpoints[localAddr]; ok { - if !addrState.IsAssigned(allowTemp) { - return nil - } + if !addrState.IncRef() { + panic(fmt.Sprintf("failed to increase the reference count for address = %s", addrState.addr)) + } - if !addrState.IncRef() { - panic(fmt.Sprintf("failed to increase the reference count for address = %s", addrState.addr)) + return addrState } - return addrState - } - - if f != nil { - for _, addrState := range a.mu.endpoints { - if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() { - return addrState + if f != nil { + for _, addrState := range a.mu.endpoints { + if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() { + return addrState + } } } + return nil + } + // Avoid exclusive lock on mu unless we need to add a new address. + a.mu.RLock() + ep := lookup() + a.mu.RUnlock() + + if ep != nil { + return ep } if !allowTemp { return nil } + // Acquire state lock in exclusive mode as we need to add a new temporary + // endpoint. + a.mu.Lock() + defer a.mu.Unlock() + + // Do the lookup again in case another goroutine added the address in the time + // we released and acquired the lock. + ep = lookup() + if ep != nil { + return ep + } + + // Proceed to add a new temporary endpoint. addr := localAddr.WithPrefix() ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */) if err != nil { @@ -475,6 +496,7 @@ func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tc // expect no error. panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err)) } + // From https://golang.org/doc/faq#nil_error: // // Under the covers, interfaces are implemented as two elements, a type T and diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 7e7415df4..f9acd4bb8 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -1530,9 +1530,10 @@ type IGMPStats struct { // IPForwardingStats collects stats related to IP forwarding (both v4 and v6). type IPForwardingStats struct { + // LINT.IfChange(IPForwardingStats) + // Unrouteable is the number of IP packets received which were dropped - // because the netstack could not construct a route to their - // destination. + // because a route to their destination could not be constructed. Unrouteable *StatCounter // ExhaustedTTL is the number of IP packets received which were dropped @@ -1547,9 +1548,16 @@ type IPForwardingStats struct { // because they contained a link-local destination address. LinkLocalDestination *StatCounter + // ExtensionHeaderProblem is the number of IP packets which were dropped + // because of a problem encountered when processing an IPv6 extension + // header. + ExtensionHeaderProblem *StatCounter + // Errors is the number of IP packets received which could not be // successfully forwarded. Errors *StatCounter + + // LINT.ThenChange(network/internal/ip/stats.go:multiCounterIPForwardingStats) } // IPStats collects IP-specific stats (both v4 and v6). diff --git a/pkg/tcpip/testutil/BUILD b/pkg/tcpip/testutil/BUILD index 472545a5d..02ee86ff1 100644 --- a/pkg/tcpip/testutil/BUILD +++ b/pkg/tcpip/testutil/BUILD @@ -5,7 +5,10 @@ package(licenses = ["notice"]) go_library( name = "testutil", testonly = True, - srcs = ["testutil.go"], + srcs = [ + "testutil.go", + "testutil_unsafe.go", + ], visibility = ["//visibility:public"], deps = ["//pkg/tcpip"], ) diff --git a/pkg/tcpip/testutil/testutil.go b/pkg/tcpip/testutil/testutil.go index 1aaed590f..f84d399fb 100644 --- a/pkg/tcpip/testutil/testutil.go +++ b/pkg/tcpip/testutil/testutil.go @@ -18,6 +18,8 @@ package testutil import ( "fmt" "net" + "reflect" + "strings" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -41,3 +43,69 @@ func MustParse6(addr string) tcpip.Address { } return tcpip.Address(ip) } + +func checkFieldCounts(ref, multi reflect.Value) error { + refTypeName := ref.Type().Name() + multiTypeName := multi.Type().Name() + refNumField := ref.NumField() + multiNumField := multi.NumField() + + if refNumField != multiNumField { + return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName) + } + + return nil +} + +func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error { + s, ok := ref.Addr().Interface().(**tcpip.StatCounter) + if !ok { + return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name()) + } + + // The field names are expected to match (case insensitive). + if !strings.EqualFold(refName, multiName) { + return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName) + } + + base := (*s).Value() + m.Increment() + if (*s).Value() != base+1 { + return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName) + } + + return nil +} + +// ValidateMultiCounterStats verifies that every counter stored in multi is +// correctly tracking its counterpart in the given counters. +func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error { + for _, c := range counters { + if err := checkFieldCounts(c, multi); err != nil { + return err + } + } + + for i := 0; i < multi.NumField(); i++ { + multiName := multi.Type().Field(i).Name + multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i)) + + if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok { + for _, c := range counters { + if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil { + return err + } + } + } else { + var countersNextField []reflect.Value + for _, c := range counters { + countersNextField = append(countersNextField, c.Field(i)) + } + if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil { + return err + } + } + } + + return nil +} diff --git a/pkg/tcpip/network/internal/testutil/testutil_unsafe.go b/pkg/tcpip/testutil/testutil_unsafe.go index 5ff764800..5ff764800 100644 --- a/pkg/tcpip/network/internal/testutil/testutil_unsafe.go +++ b/pkg/tcpip/testutil/testutil_unsafe.go |