Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/BUILD4
-rw-r--r--pkg/abi/linux/context.go36
-rw-r--r--pkg/abi/linux/errno/BUILD9
-rw-r--r--pkg/abi/linux/errno/errno.go (renamed from pkg/abi/linux/errors.go)14
-rw-r--r--pkg/abi/linux/ioctl_tun.go3
-rw-r--r--pkg/abi/linux/netfilter_ipv6.go1
-rw-r--r--pkg/abi/linux/signal.go344
-rw-r--r--pkg/abi/linux/xattr.go6
-rw-r--r--pkg/control/server/server.go5
-rw-r--r--pkg/crypto/crypto_stdlib.go8
-rw-r--r--pkg/errors/BUILD10
-rw-r--r--pkg/errors/errors.go40
-rw-r--r--pkg/errors/linuxerr/BUILD26
-rw-r--r--pkg/errors/linuxerr/linuxerr.go347
-rw-r--r--pkg/errors/linuxerr/linuxerr_test.go306
-rw-r--r--pkg/flipcall/BUILD4
-rw-r--r--pkg/flipcall/flipcall.go7
-rw-r--r--pkg/flipcall/packet_window.go (renamed from pkg/flipcall/packet_window_allocator.go)0
-rw-r--r--pkg/iovec/BUILD18
-rw-r--r--pkg/iovec/iovec.go71
-rw-r--r--pkg/iovec/iovec_test.go120
-rw-r--r--pkg/marshal/primitive/BUILD2
-rw-r--r--pkg/marshal/primitive/primitive.go25
-rw-r--r--pkg/memutil/BUILD6
-rw-r--r--pkg/memutil/memfd_linux_unsafe.go (renamed from pkg/memutil/memutil_unsafe.go)1
-rw-r--r--pkg/memutil/memutil.go (renamed from pkg/tcpip/time.s)3
-rw-r--r--pkg/memutil/mmap.go (renamed from pkg/flipcall/packet_window_mmap_amd64.go)18
-rw-r--r--pkg/merkletree/merkletree.go18
-rw-r--r--pkg/merkletree/merkletree_test.go40
-rw-r--r--pkg/metric/metric.go74
-rw-r--r--pkg/metric/metric_test.go18
-rw-r--r--pkg/p9/BUILD3
-rw-r--r--pkg/p9/handlers.go42
-rw-r--r--pkg/p9/server.go19
-rw-r--r--pkg/refs/refcounter.go6
-rw-r--r--pkg/ring0/kernel_amd64.go22
-rw-r--r--pkg/safecopy/BUILD1
-rw-r--r--pkg/safecopy/safecopy_unsafe.go14
-rw-r--r--pkg/seccomp/seccomp.go12
-rw-r--r--pkg/sentry/arch/BUILD5
-rw-r--r--pkg/sentry/arch/arch.go14
-rw-r--r--pkg/sentry/arch/arch_aarch64.go5
-rw-r--r--pkg/sentry/arch/arch_x86.go5
-rw-r--r--pkg/sentry/arch/signal.go276
-rw-r--r--pkg/sentry/arch/signal_act.go83
-rw-r--r--pkg/sentry/arch/signal_amd64.go26
-rw-r--r--pkg/sentry/arch/signal_arm64.go26
-rw-r--r--pkg/sentry/arch/signal_info.go66
-rw-r--r--pkg/sentry/arch/signal_stack.go68
-rw-r--r--pkg/sentry/arch/stack.go2
-rw-r--r--pkg/sentry/control/proc.go9
-rw-r--r--pkg/sentry/devices/tundev/BUILD1
-rw-r--r--pkg/sentry/devices/tundev/tundev.go3
-rw-r--r--pkg/sentry/fs/BUILD2
-rw-r--r--pkg/sentry/fs/attr.go14
-rw-r--r--pkg/sentry/fs/copy_up.go5
-rw-r--r--pkg/sentry/fs/dev/BUILD1
-rw-r--r--pkg/sentry/fs/dev/dev.go8
-rw-r--r--pkg/sentry/fs/dev/net_tun.go3
-rw-r--r--pkg/sentry/fs/dirent.go5
-rw-r--r--pkg/sentry/fs/fdpipe/BUILD2
-rw-r--r--pkg/sentry/fs/fdpipe/pipe.go3
-rw-r--r--pkg/sentry/fs/fdpipe/pipe_test.go6
-rw-r--r--pkg/sentry/fs/file_overlay.go7
-rw-r--r--pkg/sentry/fs/fsutil/BUILD1
-rw-r--r--pkg/sentry/fs/fsutil/file.go31
-rw-r--r--pkg/sentry/fs/fsutil/host_mappable.go21
-rw-r--r--pkg/sentry/fs/fsutil/inode.go7
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached.go27
-rw-r--r--pkg/sentry/fs/gofer/BUILD1
-rw-r--r--pkg/sentry/fs/gofer/cache_policy.go2
-rw-r--r--pkg/sentry/fs/gofer/file.go17
-rw-r--r--pkg/sentry/fs/gofer/inode.go18
-rw-r--r--pkg/sentry/fs/gofer/path.go25
-rw-r--r--pkg/sentry/fs/host/BUILD3
-rw-r--r--pkg/sentry/fs/host/socket.go3
-rw-r--r--pkg/sentry/fs/host/socket_iovec.go4
-rw-r--r--pkg/sentry/fs/host/tty.go5
-rw-r--r--pkg/sentry/fs/host/util.go4
-rw-r--r--pkg/sentry/fs/inode_overlay.go7
-rw-r--r--pkg/sentry/fs/inode_overlay_test.go5
-rw-r--r--pkg/sentry/fs/inotify.go9
-rw-r--r--pkg/sentry/fs/mock.go3
-rw-r--r--pkg/sentry/fs/mounts.go3
-rw-r--r--pkg/sentry/fs/overlay.go4
-rw-r--r--pkg/sentry/fs/proc/BUILD1
-rw-r--r--pkg/sentry/fs/proc/exec_args.go4
-rw-r--r--pkg/sentry/fs/proc/net.go4
-rw-r--r--pkg/sentry/fs/proc/proc.go5
-rw-r--r--pkg/sentry/fs/proc/sys.go22
-rw-r--r--pkg/sentry/fs/proc/sys_net.go4
-rw-r--r--pkg/sentry/fs/proc/task.go5
-rw-r--r--pkg/sentry/fs/proc/uid_gid_map.go10
-rw-r--r--pkg/sentry/fs/proc/uptime.go4
-rw-r--r--pkg/sentry/fs/splice.go5
-rw-r--r--pkg/sentry/fs/timerfd/BUILD1
-rw-r--r--pkg/sentry/fs/timerfd/timerfd.go5
-rw-r--r--pkg/sentry/fs/tmpfs/BUILD1
-rw-r--r--pkg/sentry/fs/tmpfs/fs.go2
-rw-r--r--pkg/sentry/fs/tmpfs/inode_file.go22
-rw-r--r--pkg/sentry/fs/tmpfs/tmpfs.go31
-rw-r--r--pkg/sentry/fs/tty/BUILD1
-rw-r--r--pkg/sentry/fs/tty/fs.go4
-rw-r--r--pkg/sentry/fs/user/BUILD1
-rw-r--r--pkg/sentry/fs/user/path.go5
-rw-r--r--pkg/sentry/fs/user/user_test.go5
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/BUILD1
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/base.go25
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/cgroupfs.go12
-rw-r--r--pkg/sentry/fsimpl/devpts/BUILD1
-rw-r--r--pkg/sentry/fsimpl/devpts/devpts.go3
-rw-r--r--pkg/sentry/fsimpl/devpts/master.go3
-rw-r--r--pkg/sentry/fsimpl/devpts/replica.go3
-rw-r--r--pkg/sentry/fsimpl/ext/BUILD3
-rw-r--r--pkg/sentry/fsimpl/ext/block_map_file.go3
-rw-r--r--pkg/sentry/fsimpl/ext/directory.go6
-rw-r--r--pkg/sentry/fsimpl/ext/ext.go6
-rw-r--r--pkg/sentry/fsimpl/ext/ext_test.go8
-rw-r--r--pkg/sentry/fsimpl/ext/extent_file.go5
-rw-r--r--pkg/sentry/fsimpl/ext/filesystem.go3
-rw-r--r--pkg/sentry/fsimpl/ext/inode.go3
-rw-r--r--pkg/sentry/fsimpl/ext/regular_file.go5
-rw-r--r--pkg/sentry/fsimpl/fuse/BUILD2
-rw-r--r--pkg/sentry/fsimpl/fuse/connection_test.go4
-rw-r--r--pkg/sentry/fsimpl/fuse/dev.go9
-rw-r--r--pkg/sentry/fsimpl/fuse/fusefs.go51
-rw-r--r--pkg/sentry/fsimpl/fuse/read_write.go5
-rw-r--r--pkg/sentry/fsimpl/fuse/regular_file.go11
-rw-r--r--pkg/sentry/fsimpl/gofer/BUILD1
-rw-r--r--pkg/sentry/fsimpl/gofer/directory.go8
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go51
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go56
-rw-r--r--pkg/sentry/fsimpl/gofer/host_named_pipe.go3
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go60
-rw-r--r--pkg/sentry/fsimpl/gofer/save_restore.go4
-rw-r--r--pkg/sentry/fsimpl/gofer/special_file.go10
-rw-r--r--pkg/sentry/fsimpl/host/BUILD2
-rw-r--r--pkg/sentry/fsimpl/host/host.go19
-rw-r--r--pkg/sentry/fsimpl/host/socket.go3
-rw-r--r--pkg/sentry/fsimpl/host/socket_iovec.go4
-rw-r--r--pkg/sentry/fsimpl/host/tty.go5
-rw-r--r--pkg/sentry/fsimpl/host/util.go4
-rw-r--r--pkg/sentry/fsimpl/kernfs/BUILD2
-rw-r--r--pkg/sentry/fsimpl/kernfs/fd_impl_util.go5
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go28
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go5
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs_test.go5
-rw-r--r--pkg/sentry/fsimpl/overlay/BUILD1
-rw-r--r--pkg/sentry/fsimpl/overlay/copy_up.go3
-rw-r--r--pkg/sentry/fsimpl/overlay/directory.go7
-rw-r--r--pkg/sentry/fsimpl/overlay/filesystem.go35
-rw-r--r--pkg/sentry/fsimpl/overlay/overlay.go21
-rw-r--r--pkg/sentry/fsimpl/overlay/regular_file.go13
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD3
-rw-r--r--pkg/sentry/fsimpl/proc/filesystem.go4
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go11
-rw-r--r--pkg/sentry/fsimpl/proc/task_net.go4
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_files.go5
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go15
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_test.go6
-rw-r--r--pkg/sentry/fsimpl/proc/yama.go6
-rw-r--r--pkg/sentry/fsimpl/sys/BUILD1
-rw-r--r--pkg/sentry/fsimpl/sys/kcov.go3
-rw-r--r--pkg/sentry/fsimpl/sys/sys.go3
-rw-r--r--pkg/sentry/fsimpl/testutil/BUILD1
-rw-r--r--pkg/sentry/fsimpl/testutil/kernel.go10
-rw-r--r--pkg/sentry/fsimpl/timerfd/BUILD1
-rw-r--r--pkg/sentry/fsimpl/timerfd/timerfd.go3
-rw-r--r--pkg/sentry/fsimpl/tmpfs/BUILD2
-rw-r--r--pkg/sentry/fsimpl/tmpfs/directory.go6
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go30
-rw-r--r--pkg/sentry/fsimpl/tmpfs/pipe_test.go3
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go15
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go72
-rw-r--r--pkg/sentry/fsimpl/verity/BUILD3
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go40
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go71
-rw-r--r--pkg/sentry/fsimpl/verity/verity_test.go16
-rw-r--r--pkg/sentry/fsmetric/fsmetric.go1
-rw-r--r--pkg/sentry/hostfd/hostfd_linux.go9
-rw-r--r--pkg/sentry/hostfd/hostfd_unsafe.go17
-rw-r--r--pkg/sentry/kernel/BUILD7
-rw-r--r--pkg/sentry/kernel/auth/BUILD3
-rw-r--r--pkg/sentry/kernel/auth/credentials.go11
-rw-r--r--pkg/sentry/kernel/auth/id_map.go21
-rw-r--r--pkg/sentry/kernel/cgroup.go29
-rw-r--r--pkg/sentry/kernel/fasync/BUILD3
-rw-r--r--pkg/sentry/kernel/fasync/fasync.go9
-rw-r--r--pkg/sentry/kernel/fd_table.go4
-rw-r--r--pkg/sentry/kernel/futex/BUILD3
-rw-r--r--pkg/sentry/kernel/futex/futex.go5
-rw-r--r--pkg/sentry/kernel/kcov.go13
-rw-r--r--pkg/sentry/kernel/kernel.go80
-rw-r--r--pkg/sentry/kernel/pending_signals.go9
-rw-r--r--pkg/sentry/kernel/pending_signals_state.go6
-rw-r--r--pkg/sentry/kernel/pipe/BUILD2
-rw-r--r--pkg/sentry/kernel/pipe/node.go3
-rw-r--r--pkg/sentry/kernel/pipe/node_test.go3
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go3
-rw-r--r--pkg/sentry/kernel/pipe/pipe_util.go8
-rw-r--r--pkg/sentry/kernel/pipe/vfs.go5
-rw-r--r--pkg/sentry/kernel/posixtimer.go22
-rw-r--r--pkg/sentry/kernel/ptrace.go36
-rw-r--r--pkg/sentry/kernel/rseq.go21
-rw-r--r--pkg/sentry/kernel/seccomp.go6
-rw-r--r--pkg/sentry/kernel/semaphore/BUILD1
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore.go23
-rw-r--r--pkg/sentry/kernel/sessions.go8
-rw-r--r--pkg/sentry/kernel/shm/BUILD1
-rw-r--r--pkg/sentry/kernel/shm/shm.go9
-rw-r--r--pkg/sentry/kernel/signal.go15
-rw-r--r--pkg/sentry/kernel/signal_handlers.go17
-rw-r--r--pkg/sentry/kernel/signalfd/BUILD1
-rw-r--r--pkg/sentry/kernel/signalfd/signalfd.go3
-rw-r--r--pkg/sentry/kernel/task.go19
-rw-r--r--pkg/sentry/kernel/task_acct.go6
-rw-r--r--pkg/sentry/kernel/task_block.go3
-rw-r--r--pkg/sentry/kernel/task_cgroup.go8
-rw-r--r--pkg/sentry/kernel/task_clone.go15
-rw-r--r--pkg/sentry/kernel/task_context.go5
-rw-r--r--pkg/sentry/kernel/task_exec.go3
-rw-r--r--pkg/sentry/kernel/task_exit.go43
-rw-r--r--pkg/sentry/kernel/task_identity.go19
-rw-r--r--pkg/sentry/kernel/task_image.go4
-rw-r--r--pkg/sentry/kernel/task_sched.go14
-rw-r--r--pkg/sentry/kernel/task_signals.go131
-rw-r--r--pkg/sentry/kernel/task_start.go3
-rw-r--r--pkg/sentry/kernel/task_syscall.go6
-rw-r--r--pkg/sentry/kernel/task_usermem.go5
-rw-r--r--pkg/sentry/kernel/thread_group.go20
-rw-r--r--pkg/sentry/kernel/time/BUILD2
-rw-r--r--pkg/sentry/kernel/time/time.go25
-rw-r--r--pkg/sentry/kernel/timekeeper.go42
-rw-r--r--pkg/sentry/kernel/timekeeper_test.go4
-rw-r--r--pkg/sentry/loader/BUILD2
-rw-r--r--pkg/sentry/loader/elf.go7
-rw-r--r--pkg/sentry/loader/interpreter.go2
-rw-r--r--pkg/sentry/loader/loader.go3
-rw-r--r--pkg/sentry/loader/vdso.go3
-rw-r--r--pkg/sentry/mm/BUILD3
-rw-r--r--pkg/sentry/mm/aio_context.go13
-rw-r--r--pkg/sentry/mm/mm_test.go8
-rw-r--r--pkg/sentry/mm/shm.go6
-rw-r--r--pkg/sentry/mm/special_mappable.go5
-rw-r--r--pkg/sentry/mm/syscalls.go59
-rw-r--r--pkg/sentry/pgalloc/pgalloc.go18
-rw-r--r--pkg/sentry/platform/kvm/BUILD3
-rw-r--r--pkg/sentry/platform/kvm/address_space.go9
-rw-r--r--pkg/sentry/platform/kvm/address_space_amd64.go (renamed from pkg/flipcall/packet_window_mmap_arm64.go)19
-rw-r--r--pkg/sentry/platform/kvm/address_space_arm64.go (renamed from pkg/tcpip/transport/tcp/rcv_state.go)18
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.go23
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go29
-rw-r--r--pkg/sentry/platform/kvm/context.go7
-rw-r--r--pkg/sentry/platform/kvm/kvm_amd64_test.go5
-rw-r--r--pkg/sentry/platform/kvm/kvm_test.go33
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go36
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64.go12
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go21
-rw-r--r--pkg/sentry/platform/platform.go6
-rw-r--r--pkg/sentry/platform/ptrace/ptrace.go4
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_unsafe.go2
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go3
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_amd64.go2
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_arm64.go2
-rw-r--r--pkg/sentry/sighandling/sighandling_unsafe.go13
-rw-r--r--pkg/sentry/socket/control/BUILD1
-rw-r--r--pkg/sentry/socket/control/control.go35
-rw-r--r--pkg/sentry/socket/hostinet/BUILD1
-rw-r--r--pkg/sentry/socket/hostinet/socket.go194
-rw-r--r--pkg/sentry/socket/hostinet/socket_unsafe.go18
-rw-r--r--pkg/sentry/socket/netfilter/BUILD1
-rw-r--r--pkg/sentry/socket/netfilter/extensions.go16
-rw-r--r--pkg/sentry/socket/netfilter/ipv4.go9
-rw-r--r--pkg/sentry/socket/netfilter/ipv6.go9
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go22
-rw-r--r--pkg/sentry/socket/netfilter/owner_matcher.go24
-rw-r--r--pkg/sentry/socket/netfilter/targets.go5
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go5
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go5
-rw-r--r--pkg/sentry/socket/netlink/BUILD2
-rw-r--r--pkg/sentry/socket/netlink/socket.go6
-rw-r--r--pkg/sentry/socket/netstack/BUILD2
-rw-r--r--pkg/sentry/socket/netstack/netstack.go192
-rw-r--r--pkg/sentry/socket/netstack/stack.go9
-rw-r--r--pkg/sentry/socket/netstack/tun.go6
-rw-r--r--pkg/sentry/socket/socket.go2
-rw-r--r--pkg/sentry/socket/unix/BUILD1
-rw-r--r--pkg/sentry/socket/unix/unix.go5
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go3
-rw-r--r--pkg/sentry/state/BUILD2
-rw-r--r--pkg/sentry/state/state.go8
-rw-r--r--pkg/sentry/strace/BUILD2
-rw-r--r--pkg/sentry/strace/clone.go46
-rw-r--r--pkg/sentry/strace/linux64_amd64.go2
-rw-r--r--pkg/sentry/strace/linux64_arm64.go2
-rw-r--r--pkg/sentry/strace/mmap.go92
-rw-r--r--pkg/sentry/strace/open.go42
-rw-r--r--pkg/sentry/strace/signal.go4
-rw-r--r--pkg/sentry/strace/strace.go4
-rw-r--r--pkg/sentry/strace/syscalls.go6
-rw-r--r--pkg/sentry/syscalls/BUILD1
-rw-r--r--pkg/sentry/syscalls/epoll.go3
-rw-r--r--pkg/sentry/syscalls/linux/BUILD1
-rw-r--r--pkg/sentry/syscalls/linux/error.go25
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go5
-rw-r--r--pkg/sentry/syscalls/linux/sigset.go3
-rw-r--r--pkg/sentry/syscalls/linux/sys_aio.go27
-rw-r--r--pkg/sentry/syscalls/linux/sys_capability.go7
-rw-r--r--pkg/sentry/syscalls/linux/sys_epoll.go7
-rw-r--r--pkg/sentry/syscalls/linux/sys_eventfd.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go113
-rw-r--r--pkg/sentry/syscalls/linux/sys_futex.go9
-rw-r--r--pkg/sentry/syscalls/linux/sys_getdents.go5
-rw-r--r--pkg/sentry/syscalls/linux/sys_identity.go8
-rw-r--r--pkg/sentry/syscalls/linux/sys_inotify.go7
-rw-r--r--pkg/sentry/syscalls/linux/sys_lseek.go3
-rw-r--r--pkg/sentry/syscalls/linux/sys_membarrier.go25
-rw-r--r--pkg/sentry/syscalls/linux/sys_mempolicy.go33
-rw-r--r--pkg/sentry/syscalls/linux/sys_mmap.go26
-rw-r--r--pkg/sentry/syscalls/linux/sys_mount.go10
-rw-r--r--pkg/sentry/syscalls/linux/sys_pipe.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_poll.go17
-rw-r--r--pkg/sentry/syscalls/linux/sys_prctl.go25
-rw-r--r--pkg/sentry/syscalls/linux/sys_random.go6
-rw-r--r--pkg/sentry/syscalls/linux/sys_read.go19
-rw-r--r--pkg/sentry/syscalls/linux/sys_rlimit.go9
-rw-r--r--pkg/sentry/syscalls/linux/sys_rseq.go3
-rw-r--r--pkg/sentry/syscalls/linux/sys_rusage.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_sched.go15
-rw-r--r--pkg/sentry/syscalls/linux/sys_seccomp.go8
-rw-r--r--pkg/sentry/syscalls/linux/sys_sem.go43
-rw-r--r--pkg/sentry/syscalls/linux/sys_shm.go12
-rw-r--r--pkg/sentry/syscalls/linux/sys_signal.go68
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go45
-rw-r--r--pkg/sentry/syscalls/linux/sys_splice.go23
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat.go7
-rw-r--r--pkg/sentry/syscalls/linux/sys_sync.go5
-rw-r--r--pkg/sentry/syscalls/linux/sys_syslog.go3
-rw-r--r--pkg/sentry/syscalls/linux/sys_thread.go41
-rw-r--r--pkg/sentry/syscalls/linux/sys_time.go87
-rw-r--r--pkg/sentry/syscalls/linux/sys_timerfd.go11
-rw-r--r--pkg/sentry/syscalls/linux/sys_tls_amd64.go3
-rw-r--r--pkg/sentry/syscalls/linux/sys_utsname.go5
-rw-r--r--pkg/sentry/syscalls/linux/sys_write.go13
-rw-r--r--pkg/sentry/syscalls/linux/sys_xattr.go9
-rw-r--r--pkg/sentry/syscalls/linux/timespec.go3
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/BUILD1
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/aio.go18
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/epoll.go17
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/eventfd.go4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/execve.go6
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fd.go21
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/filesystem.go8
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/getdents.go8
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/inotify.go7
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/ioctl.go3
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/lock.go3
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/memfd.go4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/mmap.go6
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/mount.go8
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/pipe.go7
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/poll.go21
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/read_write.go35
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/setstat.go30
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/signal.go8
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/socket.go45
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/splice.go31
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/stat.go16
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/sync.go5
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/timerfd.go11
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/xattr.go13
-rw-r--r--pkg/sentry/time/BUILD5
-rw-r--r--pkg/sentry/time/calibrated_clock.go4
-rw-r--r--pkg/sentry/time/sampler.go7
-rw-r--r--pkg/sentry/time/sampler_amd64.go26
-rw-r--r--pkg/sentry/time/sampler_arm64.go42
-rw-r--r--pkg/sentry/time/tsc_arm64.s6
-rw-r--r--pkg/sentry/usage/memory.go2
-rw-r--r--pkg/sentry/usage/memory_unsafe.go6
-rw-r--r--pkg/sentry/vfs/BUILD2
-rw-r--r--pkg/sentry/vfs/anonfs.go3
-rw-r--r--pkg/sentry/vfs/file_description.go9
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go15
-rw-r--r--pkg/sentry/vfs/file_description_impl_util_test.go9
-rw-r--r--pkg/sentry/vfs/inotify.go9
-rw-r--r--pkg/sentry/vfs/memxattr/BUILD1
-rw-r--r--pkg/sentry/vfs/memxattr/xattr.go33
-rw-r--r--pkg/sentry/vfs/mount.go11
-rw-r--r--pkg/sentry/vfs/vfs.go35
-rw-r--r--pkg/sentry/watchdog/watchdog.go13
-rw-r--r--pkg/shim/BUILD1
-rw-r--r--pkg/shim/debug.go48
-rw-r--r--pkg/shim/proc/BUILD2
-rw-r--r--pkg/shim/proc/deleted_state.go30
-rw-r--r--pkg/shim/proc/exec.go106
-rw-r--r--pkg/shim/proc/exec_state.go84
-rw-r--r--pkg/shim/proc/init.go84
-rw-r--r--pkg/shim/proc/init_state.go175
-rw-r--r--pkg/shim/proc/proc.go18
-rw-r--r--pkg/shim/runsc/runsc.go87
-rw-r--r--pkg/shim/service.go73
-rw-r--r--pkg/shim/utils/BUILD14
-rw-r--r--pkg/shim/utils/errors.go74
-rw-r--r--pkg/shim/utils/errors_test.go50
-rw-r--r--pkg/sync/BUILD40
-rw-r--r--pkg/sync/README.md2
-rw-r--r--pkg/sync/atomicptr/BUILD (renamed from pkg/sync/atomicptrtest/BUILD)13
-rw-r--r--pkg/sync/atomicptr/atomicptr_test.go (renamed from pkg/sync/atomicptrtest/atomicptr_test.go)0
-rw-r--r--pkg/sync/atomicptr/generic_atomicptr_unsafe.go (renamed from pkg/sync/generic_atomicptr_unsafe.go)0
-rw-r--r--pkg/sync/atomicptrmap/BUILD (renamed from pkg/sync/atomicptrmaptest/BUILD)25
-rw-r--r--pkg/sync/atomicptrmap/atomicptrmap.go (renamed from pkg/sync/atomicptrmaptest/atomicptrmap.go)0
-rw-r--r--pkg/sync/atomicptrmap/atomicptrmap_test.go (renamed from pkg/sync/atomicptrmaptest/atomicptrmap_test.go)0
-rw-r--r--pkg/sync/atomicptrmap/generic_atomicptrmap_unsafe.go (renamed from pkg/sync/generic_atomicptrmap_unsafe.go)0
-rw-r--r--pkg/sync/seqatomic/BUILD (renamed from pkg/sync/seqatomictest/BUILD)17
-rw-r--r--pkg/sync/seqatomic/generic_seqatomic_unsafe.go (renamed from pkg/sync/generic_seqatomic_unsafe.go)0
-rw-r--r--pkg/sync/seqatomic/seqatomic_test.go (renamed from pkg/sync/seqatomictest/seqatomic_test.go)0
-rw-r--r--pkg/syserr/BUILD4
-rw-r--r--pkg/syserr/netstack.go56
-rw-r--r--pkg/syserr/syserr.go300
-rw-r--r--pkg/syserror/BUILD11
-rw-r--r--pkg/syserror/syserror.go3
-rw-r--r--pkg/syserror/syserror_test.go132
-rw-r--r--pkg/tcpip/BUILD19
-rw-r--r--pkg/tcpip/checker/checker.go14
-rw-r--r--pkg/tcpip/faketime/faketime.go27
-rw-r--r--pkg/tcpip/faketime/faketime_test.go4
-rw-r--r--pkg/tcpip/header/checksum.go62
-rw-r--r--pkg/tcpip/header/checksum_test.go203
-rw-r--r--pkg/tcpip/header/interfaces.go38
-rw-r--r--pkg/tcpip/header/ipv4.go36
-rw-r--r--pkg/tcpip/header/ndp_options.go11
-rw-r--r--pkg/tcpip/header/ndp_router_advert.go75
-rw-r--r--pkg/tcpip/header/ndp_test.go91
-rw-r--r--pkg/tcpip/header/tcp.go39
-rw-r--r--pkg/tcpip/header/udp.go29
-rw-r--r--pkg/tcpip/link/channel/channel.go2
-rw-r--r--pkg/tcpip/link/fdbased/BUILD1
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go221
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go61
-rw-r--r--pkg/tcpip/link/tun/BUILD1
-rw-r--r--pkg/tcpip/link/tun/device.go9
-rw-r--r--pkg/tcpip/network/arp/BUILD2
-rw-r--r--pkg/tcpip/network/arp/arp.go2
-rw-r--r--pkg/tcpip/network/arp/arp_test.go129
-rw-r--r--pkg/tcpip/network/internal/fragmentation/fragmentation.go2
-rw-r--r--pkg/tcpip/network/internal/fragmentation/fragmentation_test.go24
-rw-r--r--pkg/tcpip/network/internal/fragmentation/reassembler.go22
-rw-r--r--pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go16
-rw-r--r--pkg/tcpip/network/internal/ip/generic_multicast_protocol.go2
-rw-r--r--pkg/tcpip/network/internal/ip/stats.go5
-rw-r--r--pkg/tcpip/network/internal/testutil/BUILD1
-rw-r--r--pkg/tcpip/network/ip_test.go40
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go33
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go100
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go101
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go35
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go41
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go93
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go149
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go4
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go260
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go104
-rw-r--r--pkg/tcpip/ports/ports.go10
-rw-r--r--pkg/tcpip/ports/ports_test.go7
-rw-r--r--pkg/tcpip/socketops.go15
-rw-r--r--pkg/tcpip/stack/BUILD3
-rw-r--r--pkg/tcpip/stack/conntrack.go65
-rw-r--r--pkg/tcpip/stack/forwarding_test.go62
-rw-r--r--pkg/tcpip/stack/iptables.go9
-rw-r--r--pkg/tcpip/stack/iptables_targets.go105
-rw-r--r--pkg/tcpip/stack/iptables_types.go2
-rw-r--r--pkg/tcpip/stack/ndp_test.go1813
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go2
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go330
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go42
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go231
-rw-r--r--pkg/tcpip/stack/nic.go80
-rw-r--r--pkg/tcpip/stack/nic_stats.go74
-rw-r--r--pkg/tcpip/stack/nic_test.go32
-rw-r--r--pkg/tcpip/stack/nud.go80
-rw-r--r--pkg/tcpip/stack/nud_test.go37
-rw-r--r--pkg/tcpip/stack/packet_buffer.go6
-rw-r--r--pkg/tcpip/stack/rand.go4
-rw-r--r--pkg/tcpip/stack/registration.go15
-rw-r--r--pkg/tcpip/stack/stack.go91
-rw-r--r--pkg/tcpip/stack/stack_global_state.go72
-rw-r--r--pkg/tcpip/stack/stack_test.go199
-rw-r--r--pkg/tcpip/stack/tcp.go17
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go29
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go123
-rw-r--r--pkg/tcpip/stdclock.go16
-rw-r--r--pkg/tcpip/tcpip.go157
-rw-r--r--pkg/tcpip/tcpip_test.go21
-rw-r--r--pkg/tcpip/tests/integration/BUILD2
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go285
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go453
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go11
-rw-r--r--pkg/tcpip/testutil/testutil.go12
-rw-r--r--pkg/tcpip/timer_test.go131
-rw-r--r--pkg/tcpip/transport/icmp/BUILD21
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go41
-rw-r--r--pkg/tcpip/transport/icmp/endpoint_state.go12
-rw-r--r--pkg/tcpip/transport/icmp/icmp_test.go235
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go2
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go29
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go12
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go42
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go12
-rw-r--r--pkg/tcpip/transport/tcp/BUILD3
-rw-r--r--pkg/tcpip/transport/tcp/accept.go28
-rw-r--r--pkg/tcpip/transport/tcp/connect.go131
-rw-r--r--pkg/tcpip/transport/tcp/cubic.go19
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go26
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go145
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go35
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go12
-rw-r--r--pkg/tcpip/transport/tcp/rack.go14
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go38
-rw-r--r--pkg/tcpip/transport/tcp/segment.go31
-rw-r--r--pkg/tcpip/transport/tcp/segment_state.go22
-rw-r--r--pkg/tcpip/transport/tcp/segment_test.go6
-rw-r--r--pkg/tcpip/transport/tcp/snd.go38
-rw-r--r--pkg/tcpip/transport/tcp/snd_state.go42
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go6
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go211
-rw-r--r--pkg/tcpip/transport/tcp/timer.go51
-rw-r--r--pkg/tcpip/transport/tcp/timer_test.go7
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go71
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go26
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go115
-rw-r--r--pkg/test/testutil/BUILD1
-rw-r--r--pkg/test/testutil/testutil.go2
-rw-r--r--pkg/urpc/urpc.go51
-rw-r--r--pkg/usermem/BUILD3
-rw-r--r--pkg/usermem/bytes_io.go5
-rw-r--r--pkg/usermem/marshal.go43
-rw-r--r--pkg/usermem/usermem.go8
-rw-r--r--pkg/usermem/usermem_test.go5
542 files changed, 9764 insertions, 6932 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index a461bb65e..eb004a7f6 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -15,12 +15,12 @@ go_library(
"bpf.go",
"capability.go",
"clone.go",
+ "context.go",
"dev.go",
"elf.go",
"epoll.go",
"epoll_amd64.go",
"epoll_arm64.go",
- "errors.go",
"errqueue.go",
"eventfd.go",
"exec.go",
@@ -77,6 +77,8 @@ go_library(
deps = [
"//pkg/abi",
"//pkg/bits",
+ "//pkg/context",
+ "//pkg/hostarch",
"//pkg/marshal",
"//pkg/marshal/primitive",
],
diff --git a/pkg/abi/linux/context.go b/pkg/abi/linux/context.go
new file mode 100644
index 000000000..d2dbba183
--- /dev/null
+++ b/pkg/abi/linux/context.go
@@ -0,0 +1,36 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+// contextID is the linux package's type for context.Context.Value keys.
+type contextID int
+
+const (
+ // CtxSignalNoInfoFunc is a Context.Value key for a function to send signals.
+ CtxSignalNoInfoFunc contextID = iota
+)
+
+// SignalNoInfoFuncFromContext returns the signal-sending callback stored in
+// ctx, or nil if no such callback is present.
+func SignalNoInfoFuncFromContext(ctx context.Context) func(Signal) error {
+ if f := ctx.Value(CtxSignalNoInfoFunc); f != nil {
+ return f.(func(Signal) error)
+ }
+ return nil
+}
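For orientation, here is a minimal standalone sketch of the pattern context.go introduces. It uses the standard library context package in place of gVisor's pkg/context (which wraps it), and the caller that installs the callback is hypothetical; in the sentry the key is populated by the task layer.

package main

import (
	"context"
	"fmt"
)

// Signal stands in for linux.Signal; contextID and CtxSignalNoInfoFunc mirror
// the declarations added above.
type Signal int

type contextID int

const CtxSignalNoInfoFunc contextID = iota

// SignalNoInfoFuncFromContext mirrors the accessor added by this change: it
// returns the stored callback, or nil when the context carries none.
func SignalNoInfoFuncFromContext(ctx context.Context) func(Signal) error {
	if f := ctx.Value(CtxSignalNoInfoFunc); f != nil {
		return f.(func(Signal) error)
	}
	return nil
}

func main() {
	// Hypothetical installer of the callback.
	ctx := context.WithValue(context.Background(), CtxSignalNoInfoFunc,
		func(s Signal) error { fmt.Println("sending signal", s); return nil })

	if send := SignalNoInfoFuncFromContext(ctx); send != nil {
		_ = send(Signal(9)) // e.g. SIGKILL
	}
}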
diff --git a/pkg/abi/linux/errno/BUILD b/pkg/abi/linux/errno/BUILD
new file mode 100644
index 000000000..d003342d5
--- /dev/null
+++ b/pkg/abi/linux/errno/BUILD
@@ -0,0 +1,9 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "errno",
+ srcs = ["errno.go"],
+ visibility = ["//visibility:public"],
+)
diff --git a/pkg/abi/linux/errors.go b/pkg/abi/linux/errno/errno.go
index b08b2687e..5a09c6605 100644
--- a/pkg/abi/linux/errors.go
+++ b/pkg/abi/linux/errno/errno.go
@@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package linux
+// Package errno holds errno codes for abi/linux.
+package errno
// Errno represents a Linux errno value.
-type Errno int
+type Errno uint32
// Errno values from include/uapi/asm-generic/errno-base.h.
const (
@@ -60,8 +61,8 @@ const (
ENOLCK
ENOSYS
ENOTEMPTY
- ELOOP //40
- _ // Skip for EWOULDBLOCK = EAGAIN
+ ELOOP // 40
+ _ // Skip for EWOULDBLOCK = EAGAIN.
ENOMSG //42
EIDRM
ECHRNG
@@ -78,13 +79,13 @@ const (
ENOANO
EBADRQC
EBADSLT
- _ // Skip for EDEADLOCK = EDEADLK
+ _ // Skip for EDEADLOCK = EDEADLK.
EBFONT
ENOSTR // 60
ENODATA
ETIME
ENOSR
- ENONET
+ _ // Skip for ENOENT = ENONET.
ENOPKG
EREMOTE
ENOLINK
@@ -160,4 +161,5 @@ const (
const (
EWOULDBLOCK = EAGAIN
EDEADLOCK = EDEADLK
+ ENONET = ENOENT
)
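The alias constants above work because iota keeps counting through the blank identifiers, so the skipped slots preserve Linux's numbering. A toy sketch of that pattern (values hard-coded here for illustration, not taken from errno.go):

package main

import "fmt"

const EAGAIN = 11 // as in errno-base.h

const (
	ELOOP  = iota + 40 // 40
	_                  // 41, reserved so that EWOULDBLOCK can alias EAGAIN
	ENOMSG             // 42
)

const EWOULDBLOCK = EAGAIN

func main() {
	fmt.Println(ELOOP, ENOMSG, EWOULDBLOCK) // 40 42 11
}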
diff --git a/pkg/abi/linux/ioctl_tun.go b/pkg/abi/linux/ioctl_tun.go
index c59c9c136..ea4fdca0f 100644
--- a/pkg/abi/linux/ioctl_tun.go
+++ b/pkg/abi/linux/ioctl_tun.go
@@ -26,4 +26,7 @@ const (
IFF_TAP = 0x0002
IFF_NO_PI = 0x1000
IFF_NOFILTER = 0x1000
+
+ // According to linux/if_tun.h, "This flag has no real effect".
+ IFF_ONE_QUEUE = 0x2000
)
diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go
index b088b207c..f8c0e891e 100644
--- a/pkg/abi/linux/netfilter_ipv6.go
+++ b/pkg/abi/linux/netfilter_ipv6.go
@@ -41,7 +41,6 @@ const (
// IP6T_ORIGINAL_DST is the ip6tables SOL_IPV6 socket option. Corresponds to
// the value in include/uapi/linux/netfilter_ipv6/ip6_tables.h.
-// TODO(gvisor.dev/issue/3549): Support IPv6 original destination.
const IP6T_ORIGINAL_DST = 80
// IP6TReplace is the argument for the IP6T_SO_SET_REPLACE sockopt. It
diff --git a/pkg/abi/linux/signal.go b/pkg/abi/linux/signal.go
index 6ca57ffbb..06a4c6401 100644
--- a/pkg/abi/linux/signal.go
+++ b/pkg/abi/linux/signal.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
const (
@@ -165,7 +166,7 @@ const (
SIG_IGN = 1
)
-// Signal action flags for rt_sigaction(2), from uapi/asm-generic/signal.h
+// Signal action flags for rt_sigaction(2), from uapi/asm-generic/signal.h.
const (
SA_NOCLDSTOP = 0x00000001
SA_NOCLDWAIT = 0x00000002
@@ -179,21 +180,17 @@ const (
SA_ONESHOT = SA_RESETHAND
)
-// Signal info types.
+// Signal stack flags for signalstack(2), from include/uapi/linux/signal.h.
const (
- SI_MASK = 0xffff0000
- SI_KILL = 0 << 16
- SI_TIMER = 1 << 16
- SI_POLL = 2 << 16
- SI_FAULT = 3 << 16
- SI_CHLD = 4 << 16
- SI_RT = 5 << 16
- SI_MESGQ = 6 << 16
- SI_SYS = 7 << 16
+ SS_ONSTACK = 1
+ SS_DISABLE = 2
)
// SIGPOLL si_codes.
const (
+ // SI_POLL is defined as __SI_POLL in Linux 2.6.
+ SI_POLL = 2 << 16
+
+ // POLL_IN indicates that data input is available.
POLL_IN = SI_POLL | 1
@@ -213,6 +210,75 @@ const (
POLL_HUP = SI_POLL | 6
)
+// Possible values for si_code.
+const (
+ // SI_USER is sent by kill, sigsend, raise.
+ SI_USER = 0
+
+ // SI_KERNEL is sent by the kernel from somewhere.
+ SI_KERNEL = 0x80
+
+ // SI_QUEUE is sent by sigqueue.
+ SI_QUEUE = -1
+
+ // SI_TIMER is sent by timer expiration.
+ SI_TIMER = -2
+
+ // SI_MESGQ is sent by real time mesq state change.
+ SI_MESGQ = -3
+
+ // SI_ASYNCIO is sent by AIO completion.
+ SI_ASYNCIO = -4
+
+ // SI_SIGIO is sent by queued SIGIO.
+ SI_SIGIO = -5
+
+ // SI_TKILL is sent by tkill system call.
+ SI_TKILL = -6
+
+ // SI_DETHREAD is sent by execve() killing subsidiary threads.
+ SI_DETHREAD = -7
+
+ // SI_ASYNCNL is sent by glibc async name lookup completion.
+ SI_ASYNCNL = -60
+)
+
+// CLD_* codes are only meaningful for SIGCHLD.
+const (
+ // CLD_EXITED indicates that a task exited.
+ CLD_EXITED = 1
+
+ // CLD_KILLED indicates that a task was killed by a signal.
+ CLD_KILLED = 2
+
+ // CLD_DUMPED indicates that a task was killed by a signal and then dumped
+ // core.
+ CLD_DUMPED = 3
+
+ // CLD_TRAPPED indicates that a task was stopped by ptrace.
+ CLD_TRAPPED = 4
+
+ // CLD_STOPPED indicates that a thread group completed a group stop.
+ CLD_STOPPED = 5
+
+ // CLD_CONTINUED indicates that a group-stopped thread group was continued.
+ CLD_CONTINUED = 6
+)
+
+// SYS_* codes are only meaningful for SIGSYS.
+const (
+ // SYS_SECCOMP indicates that a signal originates from seccomp.
+ SYS_SECCOMP = 1
+)
+
+// Possible values for Sigevent.Notify, aka struct sigevent::sigev_notify.
+const (
+ SIGEV_SIGNAL = 0
+ SIGEV_NONE = 1
+ SIGEV_THREAD = 2
+ SIGEV_THREAD_ID = 4
+)
+
// Sigevent represents struct sigevent.
//
// +marshal
@@ -227,10 +293,252 @@ type Sigevent struct {
UnRemainder [44]byte
}
-// Possible values for Sigevent.Notify, aka struct sigevent::sigev_notify.
-const (
- SIGEV_SIGNAL = 0
- SIGEV_NONE = 1
- SIGEV_THREAD = 2
- SIGEV_THREAD_ID = 4
-)
+// SigAction represents struct sigaction.
+//
+// +marshal
+// +stateify savable
+type SigAction struct {
+ Handler uint64
+ Flags uint64
+ Restorer uint64
+ Mask SignalSet
+}
+
+// SignalStack represents information about a user stack, and is equivalent to
+// stack_t.
+//
+// +marshal
+// +stateify savable
+type SignalStack struct {
+ Addr uint64
+ Flags uint32
+ _ uint32
+ Size uint64
+}
+
+// Contains checks if the stack pointer is within this stack.
+func (s *SignalStack) Contains(sp hostarch.Addr) bool {
+ return hostarch.Addr(s.Addr) < sp && sp <= hostarch.Addr(s.Addr+s.Size)
+}
+
+// Top returns the stack's top address.
+func (s *SignalStack) Top() hostarch.Addr {
+ return hostarch.Addr(s.Addr + s.Size)
+}
+
+// IsEnabled returns true iff this signal stack is marked as enabled.
+func (s *SignalStack) IsEnabled() bool {
+ return s.Flags&SS_DISABLE == 0
+}
+
+// SignalInfo represents information about a signal being delivered, and is
+// equivalent to struct siginfo in the Linux kernel (linux/include/uapi/asm-generic/siginfo.h).
+//
+// +marshal
+// +stateify savable
+type SignalInfo struct {
+ Signo int32 // Signal number
+ Errno int32 // Errno value
+ Code int32 // Signal code
+ _ uint32
+
+ // struct siginfo::_sifields is a union. In SignalInfo, fields in the union
+ // are accessed through methods.
+ //
+ // For reference, here is the definition of _sifields: (_sigfault._trapno,
+ // which does not exist on x86, omitted for clarity)
+ //
+ // union {
+ // int _pad[SI_PAD_SIZE];
+ //
+ // /* kill() */
+ // struct {
+ // __kernel_pid_t _pid; /* sender's pid */
+ // __ARCH_SI_UID_T _uid; /* sender's uid */
+ // } _kill;
+ //
+ // /* POSIX.1b timers */
+ // struct {
+ // __kernel_timer_t _tid; /* timer id */
+ // int _overrun; /* overrun count */
+ // char _pad[sizeof( __ARCH_SI_UID_T) - sizeof(int)];
+ // sigval_t _sigval; /* same as below */
+ // int _sys_private; /* not to be passed to user */
+ // } _timer;
+ //
+ // /* POSIX.1b signals */
+ // struct {
+ // __kernel_pid_t _pid; /* sender's pid */
+ // __ARCH_SI_UID_T _uid; /* sender's uid */
+ // sigval_t _sigval;
+ // } _rt;
+ //
+ // /* SIGCHLD */
+ // struct {
+ // __kernel_pid_t _pid; /* which child */
+ // __ARCH_SI_UID_T _uid; /* sender's uid */
+ // int _status; /* exit code */
+ // __ARCH_SI_CLOCK_T _utime;
+ // __ARCH_SI_CLOCK_T _stime;
+ // } _sigchld;
+ //
+ // /* SIGILL, SIGFPE, SIGSEGV, SIGBUS */
+ // struct {
+ // void *_addr; /* faulting insn/memory ref. */
+ // short _addr_lsb; /* LSB of the reported address */
+ // } _sigfault;
+ //
+ // /* SIGPOLL */
+ // struct {
+ // __ARCH_SI_BAND_T _band; /* POLL_IN, POLL_OUT, POLL_MSG */
+ // int _fd;
+ // } _sigpoll;
+ //
+ // /* SIGSYS */
+ // struct {
+ // void *_call_addr; /* calling user insn */
+ // int _syscall; /* triggering system call number */
+ // unsigned int _arch; /* AUDIT_ARCH_* of syscall */
+ // } _sigsys;
+ // } _sifields;
+ //
+ // _sifields is padded so that the size of siginfo is SI_MAX_SIZE = 128
+ // bytes.
+ Fields [128 - 16]byte
+}
+
+// FixSignalCodeForUser fixes up si_code.
+//
+// The si_code we get from Linux may contain the kernel-specific code in the
+// top 16 bits if it's positive (e.g., from ptrace). Linux's
+// copy_siginfo_to_user does
+// err |= __put_user((short)from->si_code, &to->si_code);
+// to mask out those bits and we need to do the same.
+func (s *SignalInfo) FixSignalCodeForUser() {
+ if s.Code > 0 {
+ s.Code &= 0x0000ffff
+ }
+}
+
+// PID returns the si_pid field.
+func (s *SignalInfo) PID() int32 {
+ return int32(hostarch.ByteOrder.Uint32(s.Fields[0:4]))
+}
+
+// SetPID mutates the si_pid field.
+func (s *SignalInfo) SetPID(val int32) {
+ hostarch.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
+}
+
+// UID returns the si_uid field.
+func (s *SignalInfo) UID() int32 {
+ return int32(hostarch.ByteOrder.Uint32(s.Fields[4:8]))
+}
+
+// SetUID mutates the si_uid field.
+func (s *SignalInfo) SetUID(val int32) {
+ hostarch.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
+}
+
+// Sigval returns the sigval field, which is aliased to both si_int and si_ptr.
+func (s *SignalInfo) Sigval() uint64 {
+ return hostarch.ByteOrder.Uint64(s.Fields[8:16])
+}
+
+// SetSigval mutates the sigval field.
+func (s *SignalInfo) SetSigval(val uint64) {
+ hostarch.ByteOrder.PutUint64(s.Fields[8:16], val)
+}
+
+// TimerID returns the si_timerid field.
+func (s *SignalInfo) TimerID() TimerID {
+ return TimerID(hostarch.ByteOrder.Uint32(s.Fields[0:4]))
+}
+
+// SetTimerID sets the si_timerid field.
+func (s *SignalInfo) SetTimerID(val TimerID) {
+ hostarch.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
+}
+
+// Overrun returns the si_overrun field.
+func (s *SignalInfo) Overrun() int32 {
+ return int32(hostarch.ByteOrder.Uint32(s.Fields[4:8]))
+}
+
+// SetOverrun sets the si_overrun field.
+func (s *SignalInfo) SetOverrun(val int32) {
+ hostarch.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
+}
+
+// Addr returns the si_addr field.
+func (s *SignalInfo) Addr() uint64 {
+ return hostarch.ByteOrder.Uint64(s.Fields[0:8])
+}
+
+// SetAddr sets the si_addr field.
+func (s *SignalInfo) SetAddr(val uint64) {
+ hostarch.ByteOrder.PutUint64(s.Fields[0:8], val)
+}
+
+// Status returns the si_status field.
+func (s *SignalInfo) Status() int32 {
+ return int32(hostarch.ByteOrder.Uint32(s.Fields[8:12]))
+}
+
+// SetStatus mutates the si_status field.
+func (s *SignalInfo) SetStatus(val int32) {
+ hostarch.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
+}
+
+// CallAddr returns the si_call_addr field.
+func (s *SignalInfo) CallAddr() uint64 {
+ return hostarch.ByteOrder.Uint64(s.Fields[0:8])
+}
+
+// SetCallAddr mutates the si_call_addr field.
+func (s *SignalInfo) SetCallAddr(val uint64) {
+ hostarch.ByteOrder.PutUint64(s.Fields[0:8], val)
+}
+
+// Syscall returns the si_syscall field.
+func (s *SignalInfo) Syscall() int32 {
+ return int32(hostarch.ByteOrder.Uint32(s.Fields[8:12]))
+}
+
+// SetSyscall mutates the si_syscall field.
+func (s *SignalInfo) SetSyscall(val int32) {
+ hostarch.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
+}
+
+// Arch returns the si_arch field.
+func (s *SignalInfo) Arch() uint32 {
+ return hostarch.ByteOrder.Uint32(s.Fields[12:16])
+}
+
+// SetArch mutates the si_arch field.
+func (s *SignalInfo) SetArch(val uint32) {
+ hostarch.ByteOrder.PutUint32(s.Fields[12:16], val)
+}
+
+// Band returns the si_band field.
+func (s *SignalInfo) Band() int64 {
+ return int64(hostarch.ByteOrder.Uint64(s.Fields[0:8]))
+}
+
+// SetBand mutates the si_band field.
+func (s *SignalInfo) SetBand(val int64) {
+ // Note: this assumes the platform uses `long` as `__ARCH_SI_BAND_T`.
+ // On some platforms, which gVisor doesn't support, `__ARCH_SI_BAND_T` is
+ // `int`. See siginfo.h.
+ hostarch.ByteOrder.PutUint64(s.Fields[0:8], uint64(val))
+}
+
+// FD returns the si_fd field.
+func (s *SignalInfo) FD() uint32 {
+ return hostarch.ByteOrder.Uint32(s.Fields[8:12])
+}
+
+// SetFD mutates the si_fd field.
+func (s *SignalInfo) SetFD(val uint32) {
+ hostarch.ByteOrder.PutUint32(s.Fields[8:12], val)
+}
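Since the union inside SignalInfo is exposed only through byte-offset accessors, a standalone sketch of the SIGCHLD layout may help; encoding/binary's little-endian order stands in for hostarch.ByteOrder here, a host-order assumption that holds on the supported amd64/arm64 targets:

package main

import (
	"encoding/binary"
	"fmt"
)

// siginfo mimics SignalInfo's raw union area; in the SIGCHLD layout, si_pid
// occupies bytes 0:4, si_uid bytes 4:8 and si_status bytes 8:12 of Fields.
type siginfo struct {
	Signo, Errno, Code int32
	_                  uint32
	Fields             [128 - 16]byte
}

func (s *siginfo) SetPID(pid int32) {
	binary.LittleEndian.PutUint32(s.Fields[0:4], uint32(pid))
}

func (s *siginfo) PID() int32 {
	return int32(binary.LittleEndian.Uint32(s.Fields[0:4]))
}

func (s *siginfo) SetStatus(status int32) {
	binary.LittleEndian.PutUint32(s.Fields[8:12], uint32(status))
}

func (s *siginfo) Status() int32 {
	return int32(binary.LittleEndian.Uint32(s.Fields[8:12]))
}

func main() {
	si := siginfo{Signo: 17 /* SIGCHLD */, Code: 1 /* CLD_EXITED */}
	si.SetPID(1234)
	si.SetStatus(0)
	fmt.Println(si.PID(), si.Status()) // 1234 0
}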
diff --git a/pkg/abi/linux/xattr.go b/pkg/abi/linux/xattr.go
index 8ef837f27..1fa7a4f4f 100644
--- a/pkg/abi/linux/xattr.go
+++ b/pkg/abi/linux/xattr.go
@@ -23,6 +23,12 @@ const (
XATTR_CREATE = 1
XATTR_REPLACE = 2
+ XATTR_SECURITY_PREFIX = "security."
+ XATTR_SECURITY_PREFIX_LEN = len(XATTR_SECURITY_PREFIX)
+
+ XATTR_SYSTEM_PREFIX = "system."
+ XATTR_SYSTEM_PREFIX_LEN = len(XATTR_SYSTEM_PREFIX)
+
XATTR_TRUSTED_PREFIX = "trusted."
XATTR_TRUSTED_PREFIX_LEN = len(XATTR_TRUSTED_PREFIX)
diff --git a/pkg/control/server/server.go b/pkg/control/server/server.go
index 629dae8f4..889568177 100644
--- a/pkg/control/server/server.go
+++ b/pkg/control/server/server.go
@@ -22,6 +22,7 @@ package server
import (
"os"
+ "time"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sync"
@@ -65,13 +66,13 @@ func (s *Server) Wait() {
// Stop stops the server. Note that this function should only be called once
// and the server should not be used afterwards.
-func (s *Server) Stop() {
+func (s *Server) Stop(timeout time.Duration) {
s.socket.Close()
s.Wait()
// This will cause existing clients to be terminated safely. If the
// registered handlers have a Stop callback, it will be called.
- s.server.Stop()
+ s.server.Stop(timeout)
}
// StartServing starts listening for connect and spawns the main service
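A hedged sketch of what a call site looks like with the new Stop signature; the interface, fakeServer, and the 10-second drain budget are illustrative, not part of this change:

package main

import (
	"fmt"
	"time"
)

// stopper mirrors the shape of (*server.Server).Stop after this change:
// callers now bound how long existing clients may take to wind down.
type stopper interface {
	Stop(timeout time.Duration)
}

type fakeServer struct{}

func (fakeServer) Stop(timeout time.Duration) {
	fmt.Println("stopping; draining clients for at most", timeout)
}

func main() {
	var s stopper = fakeServer{}
	s.Stop(10 * time.Second) // hypothetical drain budget
}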
diff --git a/pkg/crypto/crypto_stdlib.go b/pkg/crypto/crypto_stdlib.go
index 74a55a123..514592b08 100644
--- a/pkg/crypto/crypto_stdlib.go
+++ b/pkg/crypto/crypto_stdlib.go
@@ -22,11 +22,11 @@ import (
// EcdsaVerify verifies the signature in r, s of hash using ECDSA and the
// public key, pub. Its return value records whether the signature is valid.
-func EcdsaVerify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) bool {
- return ecdsa.Verify(pub, hash, r, s)
+func EcdsaVerify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) (bool, error) {
+ return ecdsa.Verify(pub, hash, r, s), nil
}
// SumSha384 returns the SHA384 checksum of the data.
-func SumSha384(data []byte) (sum384 [sha512.Size384]byte) {
- return sha512.Sum384(data)
+func SumSha384(data []byte) ([sha512.Size384]byte, error) {
+ return sha512.Sum384(data), nil
}
diff --git a/pkg/errors/BUILD b/pkg/errors/BUILD
new file mode 100644
index 000000000..36ea5da0c
--- /dev/null
+++ b/pkg/errors/BUILD
@@ -0,0 +1,10 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "errors",
+ srcs = ["errors.go"],
+ visibility = ["//:sandbox"],
+ deps = ["//pkg/abi/linux/errno"],
+)
diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go
new file mode 100644
index 000000000..7984a770e
--- /dev/null
+++ b/pkg/errors/errors.go
@@ -0,0 +1,40 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package errors holds the standardized error definition for gVisor.
+package errors
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux/errno"
+)
+
+// Error represents a syscall errno with a descriptive message.
+type Error struct {
+ errno errno.Errno
+ message string
+}
+
+// New creates a new *Error.
+func New(err errno.Errno, message string) *Error {
+ return &Error{
+ errno: err,
+ message: message,
+ }
+}
+
+// Error implements error.Error.
+func (e *Error) Error() string { return e.message }
+
+// Errno returns the underlying errno.Errno value.
+func (e *Error) Errno() errno.Errno { return e.errno }
diff --git a/pkg/errors/linuxerr/BUILD b/pkg/errors/linuxerr/BUILD
new file mode 100644
index 000000000..201727780
--- /dev/null
+++ b/pkg/errors/linuxerr/BUILD
@@ -0,0 +1,26 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "linuxerr",
+ srcs = ["linuxerr.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/abi/linux/errno",
+ "//pkg/errors",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "linuxerr_test",
+ srcs = ["linuxerr_test.go"],
+ deps = [
+ ":linuxerr",
+ "//pkg/abi/linux/errno",
+ "//pkg/errors",
+ "//pkg/syserror",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/errors/linuxerr/linuxerr.go b/pkg/errors/linuxerr/linuxerr.go
new file mode 100644
index 000000000..9246f2e89
--- /dev/null
+++ b/pkg/errors/linuxerr/linuxerr.go
@@ -0,0 +1,347 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package linuxerr contains syscall error codes exported as error interface
+// pointers. This allows for fast comparison and return operations comparable
+// to unix.Errno constants.
+package linuxerr
+
+import (
+ "fmt"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux/errno"
+ "gvisor.dev/gvisor/pkg/errors"
+)
+
+const maxErrno uint32 = errno.EHWPOISON + 1
+
+var (
+ NOERROR = errors.New(errno.NOERRNO, "not an error")
+ EPERM = errors.New(errno.EPERM, "operation not permitted")
+ ENOENT = errors.New(errno.ENOENT, "no such file or directory")
+ ESRCH = errors.New(errno.ESRCH, "no such process")
+ EINTR = errors.New(errno.EINTR, "interrupted system call")
+ EIO = errors.New(errno.EIO, "I/O error")
+ ENXIO = errors.New(errno.ENXIO, "no such device or address")
+ E2BIG = errors.New(errno.E2BIG, "argument list too long")
+ ENOEXEC = errors.New(errno.ENOEXEC, "exec format error")
+ EBADF = errors.New(errno.EBADF, "bad file number")
+ ECHILD = errors.New(errno.ECHILD, "no child processes")
+ EAGAIN = errors.New(errno.EAGAIN, "try again")
+ ENOMEM = errors.New(errno.ENOMEM, "out of memory")
+ EACCES = errors.New(errno.EACCES, "permission denied")
+ EFAULT = errors.New(errno.EFAULT, "bad address")
+ ENOTBLK = errors.New(errno.ENOTBLK, "block device required")
+ EBUSY = errors.New(errno.EBUSY, "device or resource busy")
+ EEXIST = errors.New(errno.EEXIST, "file exists")
+ EXDEV = errors.New(errno.EXDEV, "cross-device link")
+ ENODEV = errors.New(errno.ENODEV, "no such device")
+ ENOTDIR = errors.New(errno.ENOTDIR, "not a directory")
+ EISDIR = errors.New(errno.EISDIR, "is a directory")
+ EINVAL = errors.New(errno.EINVAL, "invalid argument")
+ ENFILE = errors.New(errno.ENFILE, "file table overflow")
+ EMFILE = errors.New(errno.EMFILE, "too many open files")
+ ENOTTY = errors.New(errno.ENOTTY, "not a typewriter")
+ ETXTBSY = errors.New(errno.ETXTBSY, "text file busy")
+ EFBIG = errors.New(errno.EFBIG, "file too large")
+ ENOSPC = errors.New(errno.ENOSPC, "no space left on device")
+ ESPIPE = errors.New(errno.ESPIPE, "illegal seek")
+ EROFS = errors.New(errno.EROFS, "read-only file system")
+ EMLINK = errors.New(errno.EMLINK, "too many links")
+ EPIPE = errors.New(errno.EPIPE, "broken pipe")
+ EDOM = errors.New(errno.EDOM, "math argument out of domain of func")
+ ERANGE = errors.New(errno.ERANGE, "math result not representable")
+
+ // Errno values from include/uapi/asm-generic/errno.h.
+ EDEADLK = errors.New(errno.EDEADLK, "resource deadlock would occur")
+ ENAMETOOLONG = errors.New(errno.ENAMETOOLONG, "file name too long")
+ ENOLCK = errors.New(errno.ENOLCK, "no record locks available")
+ ENOSYS = errors.New(errno.ENOSYS, "invalid system call number")
+ ENOTEMPTY = errors.New(errno.ENOTEMPTY, "directory not empty")
+ ELOOP = errors.New(errno.ELOOP, "too many symbolic links encountered")
+ ENOMSG = errors.New(errno.ENOMSG, "no message of desired type")
+ EIDRM = errors.New(errno.EIDRM, "identifier removed")
+ ECHRNG = errors.New(errno.ECHRNG, "channel number out of range")
+ EL2NSYNC = errors.New(errno.EL2NSYNC, "level 2 not synchronized")
+ EL3HLT = errors.New(errno.EL3HLT, "level 3 halted")
+ EL3RST = errors.New(errno.EL3RST, "level 3 reset")
+ ELNRNG = errors.New(errno.ELNRNG, "link number out of range")
+ EUNATCH = errors.New(errno.EUNATCH, "protocol driver not attached")
+ ENOCSI = errors.New(errno.ENOCSI, "no CSI structure available")
+ EL2HLT = errors.New(errno.EL2HLT, "level 2 halted")
+ EBADE = errors.New(errno.EBADE, "invalid exchange")
+ EBADR = errors.New(errno.EBADR, "invalid request descriptor")
+ EXFULL = errors.New(errno.EXFULL, "exchange full")
+ ENOANO = errors.New(errno.ENOANO, "no anode")
+ EBADRQC = errors.New(errno.EBADRQC, "invalid request code")
+ EBADSLT = errors.New(errno.EBADSLT, "invalid slot")
+ EBFONT = errors.New(errno.EBFONT, "bad font file format")
+ ENOSTR = errors.New(errno.ENOSTR, "device not a stream")
+ ENODATA = errors.New(errno.ENODATA, "no data available")
+ ETIME = errors.New(errno.ETIME, "timer expired")
+ ENOSR = errors.New(errno.ENOSR, "out of streams resources")
+ ENOPKG = errors.New(errno.ENOPKG, "package not installed")
+ EREMOTE = errors.New(errno.EREMOTE, "object is remote")
+ ENOLINK = errors.New(errno.ENOLINK, "link has been severed")
+ EADV = errors.New(errno.EADV, "advertise error")
+ ESRMNT = errors.New(errno.ESRMNT, "srmount error")
+ ECOMM = errors.New(errno.ECOMM, "communication error on send")
+ EPROTO = errors.New(errno.EPROTO, "protocol error")
+ EMULTIHOP = errors.New(errno.EMULTIHOP, "multihop attempted")
+ EDOTDOT = errors.New(errno.EDOTDOT, "RFS specific error")
+ EBADMSG = errors.New(errno.EBADMSG, "not a data message")
+ EOVERFLOW = errors.New(errno.EOVERFLOW, "value too large for defined data type")
+ ENOTUNIQ = errors.New(errno.ENOTUNIQ, "name not unique on network")
+ EBADFD = errors.New(errno.EBADFD, "file descriptor in bad state")
+ EREMCHG = errors.New(errno.EREMCHG, "remote address changed")
+ ELIBACC = errors.New(errno.ELIBACC, "can not access a needed shared library")
+ ELIBBAD = errors.New(errno.ELIBBAD, "accessing a corrupted shared library")
+ ELIBSCN = errors.New(errno.ELIBSCN, ".lib section in a.out corrupted")
+ ELIBMAX = errors.New(errno.ELIBMAX, "attempting to link in too many shared libraries")
+ ELIBEXEC = errors.New(errno.ELIBEXEC, "cannot exec a shared library directly")
+ EILSEQ = errors.New(errno.EILSEQ, "illegal byte sequence")
+ ERESTART = errors.New(errno.ERESTART, "interrupted system call should be restarted")
+ ESTRPIPE = errors.New(errno.ESTRPIPE, "streams pipe error")
+ EUSERS = errors.New(errno.EUSERS, "too many users")
+ ENOTSOCK = errors.New(errno.ENOTSOCK, "socket operation on non-socket")
+ EDESTADDRREQ = errors.New(errno.EDESTADDRREQ, "destination address required")
+ EMSGSIZE = errors.New(errno.EMSGSIZE, "message too long")
+ EPROTOTYPE = errors.New(errno.EPROTOTYPE, "protocol wrong type for socket")
+ ENOPROTOOPT = errors.New(errno.ENOPROTOOPT, "protocol not available")
+ EPROTONOSUPPORT = errors.New(errno.EPROTONOSUPPORT, "protocol not supported")
+ ESOCKTNOSUPPORT = errors.New(errno.ESOCKTNOSUPPORT, "socket type not supported")
+ EOPNOTSUPP = errors.New(errno.EOPNOTSUPP, "operation not supported on transport endpoint")
+ EPFNOSUPPORT = errors.New(errno.EPFNOSUPPORT, "protocol family not supported")
+ EAFNOSUPPORT = errors.New(errno.EAFNOSUPPORT, "address family not supported by protocol")
+ EADDRINUSE = errors.New(errno.EADDRINUSE, "address already in use")
+ EADDRNOTAVAIL = errors.New(errno.EADDRNOTAVAIL, "cannot assign requested address")
+ ENETDOWN = errors.New(errno.ENETDOWN, "network is down")
+ ENETUNREACH = errors.New(errno.ENETUNREACH, "network is unreachable")
+ ENETRESET = errors.New(errno.ENETRESET, "network dropped connection because of reset")
+ ECONNABORTED = errors.New(errno.ECONNABORTED, "software caused connection abort")
+ ECONNRESET = errors.New(errno.ECONNRESET, "connection reset by peer")
+ ENOBUFS = errors.New(errno.ENOBUFS, "no buffer space available")
+ EISCONN = errors.New(errno.EISCONN, "transport endpoint is already connected")
+ ENOTCONN = errors.New(errno.ENOTCONN, "transport endpoint is not connected")
+ ESHUTDOWN = errors.New(errno.ESHUTDOWN, "cannot send after transport endpoint shutdown")
+ ETOOMANYREFS = errors.New(errno.ETOOMANYREFS, "too many references: cannot splice")
+ ETIMEDOUT = errors.New(errno.ETIMEDOUT, "connection timed out")
+ ECONNREFUSED = errors.New(errno.ECONNREFUSED, "connection refused")
+ EHOSTDOWN = errors.New(errno.EHOSTDOWN, "host is down")
+ EHOSTUNREACH = errors.New(errno.EHOSTUNREACH, "no route to host")
+ EALREADY = errors.New(errno.EALREADY, "operation already in progress")
+ EINPROGRESS = errors.New(errno.EINPROGRESS, "operation now in progress")
+ ESTALE = errors.New(errno.ESTALE, "stale file handle")
+ EUCLEAN = errors.New(errno.EUCLEAN, "structure needs cleaning")
+ ENOTNAM = errors.New(errno.ENOTNAM, "not a XENIX named type file")
+ ENAVAIL = errors.New(errno.ENAVAIL, "no XENIX semaphores available")
+ EISNAM = errors.New(errno.EISNAM, "is a named type file")
+ EREMOTEIO = errors.New(errno.EREMOTEIO, "remote I/O error")
+ EDQUOT = errors.New(errno.EDQUOT, "quota exceeded")
+ ENOMEDIUM = errors.New(errno.ENOMEDIUM, "no medium found")
+ EMEDIUMTYPE = errors.New(errno.EMEDIUMTYPE, "wrong medium type")
+ ECANCELED = errors.New(errno.ECANCELED, "operation Canceled")
+ ENOKEY = errors.New(errno.ENOKEY, "required key not available")
+ EKEYEXPIRED = errors.New(errno.EKEYEXPIRED, "key has expired")
+ EKEYREVOKED = errors.New(errno.EKEYREVOKED, "key has been revoked")
+ EKEYREJECTED = errors.New(errno.EKEYREJECTED, "key was rejected by service")
+ EOWNERDEAD = errors.New(errno.EOWNERDEAD, "owner died")
+ ENOTRECOVERABLE = errors.New(errno.ENOTRECOVERABLE, "state not recoverable")
+ ERFKILL = errors.New(errno.ERFKILL, "operation not possible due to RF-kill")
+ EHWPOISON = errors.New(errno.EHWPOISON, "memory page has hardware error")
+
+ // Errors equivalent to other errors.
+ EWOULDBLOCK = EAGAIN
+ EDEADLOCK = EDEADLK
+ ENONET = ENOENT
+)
+
+// A nil *errors.Error denotes no error and is placed at the 0 index of
+// errorSlice. Thus, any other empty index should not be nil or a valid error.
+// This marks that index as an invalid error so any comparison to nil or a
+// valid linuxerr fails.
+var errNotValidError = errors.New(errno.Errno(maxErrno), "not a valid error")
+
+// The following errorSlice holds errors by errno for fast translation between
+// errnos (especially uint32(syscall.Errno)) and *Error.
+var errorSlice = []*errors.Error{
+ // Errno values from include/uapi/asm-generic/errno-base.h.
+ errno.NOERRNO: NOERROR,
+ errno.EPERM: EPERM,
+ errno.ENOENT: ENOENT,
+ errno.ESRCH: ESRCH,
+ errno.EINTR: EINTR,
+ errno.EIO: EIO,
+ errno.ENXIO: ENXIO,
+ errno.E2BIG: E2BIG,
+ errno.ENOEXEC: ENOEXEC,
+ errno.EBADF: EBADF,
+ errno.ECHILD: ECHILD,
+ errno.EAGAIN: EAGAIN,
+ errno.ENOMEM: ENOMEM,
+ errno.EACCES: EACCES,
+ errno.EFAULT: EFAULT,
+ errno.ENOTBLK: ENOTBLK,
+ errno.EBUSY: EBUSY,
+ errno.EEXIST: EEXIST,
+ errno.EXDEV: EXDEV,
+ errno.ENODEV: ENODEV,
+ errno.ENOTDIR: ENOTDIR,
+ errno.EISDIR: EISDIR,
+ errno.EINVAL: EINVAL,
+ errno.ENFILE: ENFILE,
+ errno.EMFILE: EMFILE,
+ errno.ENOTTY: ENOTTY,
+ errno.ETXTBSY: ETXTBSY,
+ errno.EFBIG: EFBIG,
+ errno.ENOSPC: ENOSPC,
+ errno.ESPIPE: ESPIPE,
+ errno.EROFS: EROFS,
+ errno.EMLINK: EMLINK,
+ errno.EPIPE: EPIPE,
+ errno.EDOM: EDOM,
+ errno.ERANGE: ERANGE,
+
+ // Errno values from include/uapi/asm-generic/errno.h.
+ errno.EDEADLK: EDEADLK,
+ errno.ENAMETOOLONG: ENAMETOOLONG,
+ errno.ENOLCK: ENOLCK,
+ errno.ENOSYS: ENOSYS,
+ errno.ENOTEMPTY: ENOTEMPTY,
+ errno.ELOOP: ELOOP,
+ errno.ELOOP + 1: errNotValidError, // No valid errno between ELOOP and ENOMSG.
+ errno.ENOMSG: ENOMSG,
+ errno.EIDRM: EIDRM,
+ errno.ECHRNG: ECHRNG,
+ errno.EL2NSYNC: EL2NSYNC,
+ errno.EL3HLT: EL3HLT,
+ errno.EL3RST: EL3RST,
+ errno.ELNRNG: ELNRNG,
+ errno.EUNATCH: EUNATCH,
+ errno.ENOCSI: ENOCSI,
+ errno.EL2HLT: EL2HLT,
+ errno.EBADE: EBADE,
+ errno.EBADR: EBADR,
+ errno.EXFULL: EXFULL,
+ errno.ENOANO: ENOANO,
+ errno.EBADRQC: EBADRQC,
+ errno.EBADSLT: EBADSLT,
+ errno.EBADSLT + 1: errNotValidError, // No valid errno between EBADSLT and EBFONT.
+ errno.EBFONT: EBFONT,
+ errno.ENOSTR: ENOSTR,
+ errno.ENODATA: ENODATA,
+ errno.ETIME: ETIME,
+ errno.ENOSR: ENOSR,
+ errno.ENOSR + 1: errNotValidError, // No valid errno between ENOSR and ENOPKG.
+ errno.ENOPKG: ENOPKG,
+ errno.EREMOTE: EREMOTE,
+ errno.ENOLINK: ENOLINK,
+ errno.EADV: EADV,
+ errno.ESRMNT: ESRMNT,
+ errno.ECOMM: ECOMM,
+ errno.EPROTO: EPROTO,
+ errno.EMULTIHOP: EMULTIHOP,
+ errno.EDOTDOT: EDOTDOT,
+ errno.EBADMSG: EBADMSG,
+ errno.EOVERFLOW: EOVERFLOW,
+ errno.ENOTUNIQ: ENOTUNIQ,
+ errno.EBADFD: EBADFD,
+ errno.EREMCHG: EREMCHG,
+ errno.ELIBACC: ELIBACC,
+ errno.ELIBBAD: ELIBBAD,
+ errno.ELIBSCN: ELIBSCN,
+ errno.ELIBMAX: ELIBMAX,
+ errno.ELIBEXEC: ELIBEXEC,
+ errno.EILSEQ: EILSEQ,
+ errno.ERESTART: ERESTART,
+ errno.ESTRPIPE: ESTRPIPE,
+ errno.EUSERS: EUSERS,
+ errno.ENOTSOCK: ENOTSOCK,
+ errno.EDESTADDRREQ: EDESTADDRREQ,
+ errno.EMSGSIZE: EMSGSIZE,
+ errno.EPROTOTYPE: EPROTOTYPE,
+ errno.ENOPROTOOPT: ENOPROTOOPT,
+ errno.EPROTONOSUPPORT: EPROTONOSUPPORT,
+ errno.ESOCKTNOSUPPORT: ESOCKTNOSUPPORT,
+ errno.EOPNOTSUPP: EOPNOTSUPP,
+ errno.EPFNOSUPPORT: EPFNOSUPPORT,
+ errno.EAFNOSUPPORT: EAFNOSUPPORT,
+ errno.EADDRINUSE: EADDRINUSE,
+ errno.EADDRNOTAVAIL: EADDRNOTAVAIL,
+ errno.ENETDOWN: ENETDOWN,
+ errno.ENETUNREACH: ENETUNREACH,
+ errno.ENETRESET: ENETRESET,
+ errno.ECONNABORTED: ECONNABORTED,
+ errno.ECONNRESET: ECONNRESET,
+ errno.ENOBUFS: ENOBUFS,
+ errno.EISCONN: EISCONN,
+ errno.ENOTCONN: ENOTCONN,
+ errno.ESHUTDOWN: ESHUTDOWN,
+ errno.ETOOMANYREFS: ETOOMANYREFS,
+ errno.ETIMEDOUT: ETIMEDOUT,
+ errno.ECONNREFUSED: ECONNREFUSED,
+ errno.EHOSTDOWN: EHOSTDOWN,
+ errno.EHOSTUNREACH: EHOSTUNREACH,
+ errno.EALREADY: EALREADY,
+ errno.EINPROGRESS: EINPROGRESS,
+ errno.ESTALE: ESTALE,
+ errno.EUCLEAN: EUCLEAN,
+ errno.ENOTNAM: ENOTNAM,
+ errno.ENAVAIL: ENAVAIL,
+ errno.EISNAM: EISNAM,
+ errno.EREMOTEIO: EREMOTEIO,
+ errno.EDQUOT: EDQUOT,
+ errno.ENOMEDIUM: ENOMEDIUM,
+ errno.EMEDIUMTYPE: EMEDIUMTYPE,
+ errno.ECANCELED: ECANCELED,
+ errno.ENOKEY: ENOKEY,
+ errno.EKEYEXPIRED: EKEYEXPIRED,
+ errno.EKEYREVOKED: EKEYREVOKED,
+ errno.EKEYREJECTED: EKEYREJECTED,
+ errno.EOWNERDEAD: EOWNERDEAD,
+ errno.ENOTRECOVERABLE: ENOTRECOVERABLE,
+ errno.ERFKILL: ERFKILL,
+ errno.EHWPOISON: EHWPOISON,
+}
+
+// ErrorFromErrno gets an error from the list and panics if an invalid entry is requested.
+func ErrorFromErrno(e errno.Errno) *errors.Error {
+ err := errorSlice[e]
+ // Done this way because a single comparison in benchmarks is 2-3x faster
+ // than a combined check like (err == nil && err > 0).
+ if err != errNotValidError {
+ return err
+ }
+ panic(fmt.Sprintf("invalid error requested with errno: %d", e))
+}
+
+// Equals compares a linuxerr to a given error.
+// TODO(b/34162363): Remove when syserror is removed.
+func Equals(e *errors.Error, err error) bool {
+ if err == nil {
+ return e == NOERROR || e == nil
+ }
+ if e == nil {
+ return err == NOERROR || err == unix.Errno(0)
+ }
+
+ switch err.(type) {
+ case *errors.Error:
+ return e == err
+ case unix.Errno, error:
+ return unix.Errno(e.Errno()) == err
+ }
+ return false
+}
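(Not part of the patch: a minimal usage sketch of the helpers added above, assuming only the APIs defined in this file and the golang.org/x/sys/unix and pkg/abi/linux/errno imports already shown in this diff; the translate function itself is hypothetical.)

        // translate is a hypothetical caller converting a host errno into a linuxerr
        // error and checking it against the canonical values defined above.
        func translate(hostErr unix.Errno) bool {
                e := linuxerr.ErrorFromErrno(errno.Errno(hostErr)) // panics on an invalid errno
                // Aliased errnos share one object, so EWOULDBLOCK compares equal to EAGAIN.
                if e == linuxerr.EWOULDBLOCK {
                        return true
                }
                // Equals bridges *errors.Error and unix.Errno values during the migration.
                return linuxerr.Equals(e, hostErr)
        }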
diff --git a/pkg/errors/linuxerr/linuxerr_test.go b/pkg/errors/linuxerr/linuxerr_test.go
new file mode 100644
index 000000000..b1250d24f
--- /dev/null
+++ b/pkg/errors/linuxerr/linuxerr_test.go
@@ -0,0 +1,306 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syserror_test
+
+import (
+ "errors"
+ "io"
+ "io/fs"
+ "syscall"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux/errno"
+ gErrors "gvisor.dev/gvisor/pkg/errors"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+var globalError error
+
+func BenchmarkAssignUnix(b *testing.B) {
+ for i := b.N; i > 0; i-- {
+ globalError = unix.EINVAL
+ }
+}
+
+func BenchmarkAssignLinuxerr(b *testing.B) {
+ for i := b.N; i > 0; i-- {
+ globalError = linuxerr.EINVAL
+ }
+}
+
+func BenchmarkAssignSyserror(b *testing.B) {
+ for i := b.N; i > 0; i-- {
+ globalError = linuxerr.ENOMSG
+ }
+}
+
+func BenchmarkCompareUnix(b *testing.B) {
+ globalError = unix.EAGAIN
+ j := 0
+ for i := b.N; i > 0; i-- {
+ if globalError == unix.EINVAL {
+ j++
+ }
+ }
+}
+
+func BenchmarkCompareLinuxerr(b *testing.B) {
+ globalError = linuxerr.E2BIG
+ j := 0
+ for i := b.N; i > 0; i-- {
+ if globalError == linuxerr.EINVAL {
+ j++
+ }
+ }
+}
+
+func BenchmarkCompareSyserror(b *testing.B) {
+ globalError = syserror.EAGAIN
+ j := 0
+ for i := b.N; i > 0; i-- {
+ if globalError == syserror.EACCES {
+ j++
+ }
+ }
+}
+
+func BenchmarkSwitchUnix(b *testing.B) {
+ globalError = unix.EPERM
+ j := 0
+ for i := b.N; i > 0; i-- {
+ switch globalError {
+ case unix.EINVAL:
+ j++
+ case unix.EINTR:
+ j += 2
+ case unix.EAGAIN:
+ j += 3
+ }
+ }
+}
+
+func BenchmarkSwitchLinuxerr(b *testing.B) {
+ globalError = linuxerr.EPERM
+ j := 0
+ for i := b.N; i > 0; i-- {
+ switch globalError {
+ case linuxerr.EINVAL:
+ j++
+ case linuxerr.EINTR:
+ j += 2
+ case linuxerr.EAGAIN:
+ j += 3
+ }
+ }
+}
+
+func BenchmarkSwitchSyserror(b *testing.B) {
+ globalError = syserror.EPERM
+ j := 0
+ for i := b.N; i > 0; i-- {
+ switch globalError {
+ case syserror.EACCES:
+ j++
+ case syserror.EINTR:
+ j += 2
+ case syserror.EAGAIN:
+ j += 3
+ }
+ }
+}
+
+func BenchmarkReturnUnix(b *testing.B) {
+ var localError error
+ f := func() error {
+ return unix.EINVAL
+ }
+ for i := b.N; i > 0; i-- {
+ localError = f()
+ }
+ if localError != nil {
+ return
+ }
+}
+
+func BenchmarkReturnLinuxerr(b *testing.B) {
+ var localError error
+ f := func() error {
+ return linuxerr.EINVAL
+ }
+ for i := b.N; i > 0; i-- {
+ localError = f()
+ }
+ if localError != nil {
+ return
+ }
+}
+
+func BenchmarkConvertUnixLinuxerr(b *testing.B) {
+ var localError error
+ for i := b.N; i > 0; i-- {
+ localError = linuxerr.ErrorFromErrno(errno.Errno(unix.EINVAL))
+ }
+ if localError != nil {
+ return
+ }
+}
+
+func BenchmarkConvertUnixLinuxerrZero(b *testing.B) {
+ var localError error
+ for i := b.N; i > 0; i-- {
+ localError = linuxerr.ErrorFromErrno(errno.Errno(0))
+ }
+ if localError != nil {
+ return
+ }
+}
+
+type translationTestTable struct {
+ fn string
+ errIn error
+ syscallErrorIn unix.Errno
+ expectedBool bool
+ expectedTranslation unix.Errno
+}
+
+func TestErrorTranslation(t *testing.T) {
+ myError := errors.New("My test error")
+ myError2 := errors.New("Another test error")
+ testTable := []translationTestTable{
+ {"TranslateError", myError, 0, false, 0},
+ {"TranslateError", myError2, 0, false, 0},
+ {"AddErrorTranslation", myError, unix.EAGAIN, true, 0},
+ {"AddErrorTranslation", myError, unix.EAGAIN, false, 0},
+ {"AddErrorTranslation", myError, unix.EPERM, false, 0},
+ {"TranslateError", myError, 0, true, unix.EAGAIN},
+ {"TranslateError", myError2, 0, false, 0},
+ {"AddErrorTranslation", myError2, unix.EPERM, true, 0},
+ {"AddErrorTranslation", myError2, unix.EPERM, false, 0},
+ {"AddErrorTranslation", myError2, unix.EAGAIN, false, 0},
+ {"TranslateError", myError, 0, true, unix.EAGAIN},
+ {"TranslateError", myError2, 0, true, unix.EPERM},
+ }
+ for _, tt := range testTable {
+ switch tt.fn {
+ case "TranslateError":
+ err, ok := syserror.TranslateError(tt.errIn)
+ if ok != tt.expectedBool {
+ t.Fatalf("%v(%v) => %v expected %v", tt.fn, tt.errIn, ok, tt.expectedBool)
+ } else if err != tt.expectedTranslation {
+ t.Fatalf("%v(%v) (error) => %v expected %v", tt.fn, tt.errIn, err, tt.expectedTranslation)
+ }
+ case "AddErrorTranslation":
+ ok := syserror.AddErrorTranslation(tt.errIn, tt.syscallErrorIn)
+ if ok != tt.expectedBool {
+ t.Fatalf("%v(%v) => %v expected %v", tt.fn, tt.errIn, ok, tt.expectedBool)
+ }
+ default:
+ t.Fatalf("Unknown function %v", tt.fn)
+ }
+ }
+}
+
+func TestSyscallErrnoToErrors(t *testing.T) {
+ for _, tc := range []struct {
+ errno syscall.Errno
+ err *gErrors.Error
+ }{
+ {errno: syscall.EACCES, err: linuxerr.EACCES},
+ {errno: syscall.EAGAIN, err: linuxerr.EAGAIN},
+ {errno: syscall.EBADF, err: linuxerr.EBADF},
+ {errno: syscall.EBUSY, err: linuxerr.EBUSY},
+ {errno: syscall.EDOM, err: linuxerr.EDOM},
+ {errno: syscall.EEXIST, err: linuxerr.EEXIST},
+ {errno: syscall.EFAULT, err: linuxerr.EFAULT},
+ {errno: syscall.EFBIG, err: linuxerr.EFBIG},
+ {errno: syscall.EINTR, err: linuxerr.EINTR},
+ {errno: syscall.EINVAL, err: linuxerr.EINVAL},
+ {errno: syscall.EIO, err: linuxerr.EIO},
+ {errno: syscall.ENOTDIR, err: linuxerr.ENOTDIR},
+ {errno: syscall.ENOTTY, err: linuxerr.ENOTTY},
+ {errno: syscall.EPERM, err: linuxerr.EPERM},
+ {errno: syscall.EPIPE, err: linuxerr.EPIPE},
+ {errno: syscall.ESPIPE, err: linuxerr.ESPIPE},
+ {errno: syscall.EWOULDBLOCK, err: linuxerr.EAGAIN},
+ } {
+ t.Run(tc.errno.Error(), func(t *testing.T) {
+ e := linuxerr.ErrorFromErrno(errno.Errno(tc.errno))
+ if e != tc.err {
+ t.Fatalf("Mismatch errors: want: %+v (%d) got: %+v %d", tc.err, tc.err.Errno(), e, e.Errno())
+ }
+ })
+ }
+}
+
+// TestEqualsMethod tests that the Equals method correctly compares syserror,
+// unix.Errno and linuxerr.
+// TODO(b/34162363): Remove this.
+func TestEqualsMethod(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ linuxErr []*gErrors.Error
+ err []error
+ equal bool
+ }{
+ {
+ name: "compare nil",
+ linuxErr: []*gErrors.Error{nil, linuxerr.NOERROR},
+ err: []error{nil, linuxerr.NOERROR, unix.Errno(0)},
+ equal: true,
+ },
+ {
+ name: "linuxerr nil error not",
+ linuxErr: []*gErrors.Error{nil, linuxerr.NOERROR},
+ err: []error{unix.Errno(1), linuxerr.EPERM, syserror.EACCES},
+ equal: false,
+ },
+ {
+ name: "linuxerr not nil error nil",
+ linuxErr: []*gErrors.Error{linuxerr.ENOENT},
+ err: []error{nil, unix.Errno(0), linuxerr.NOERROR},
+ equal: false,
+ },
+ {
+ name: "equal errors",
+ linuxErr: []*gErrors.Error{linuxerr.ESRCH},
+ err: []error{linuxerr.ESRCH, syserror.ESRCH, unix.Errno(linuxerr.ESRCH.Errno())},
+ equal: true,
+ },
+ {
+ name: "unequal errors",
+ linuxErr: []*gErrors.Error{linuxerr.ENOENT},
+ err: []error{linuxerr.ESRCH, syserror.ESRCH, unix.Errno(linuxerr.ESRCH.Errno())},
+ equal: false,
+ },
+ {
+ name: "other error",
+ linuxErr: []*gErrors.Error{nil, linuxerr.NOERROR, linuxerr.E2BIG, linuxerr.EINVAL},
+ err: []error{fs.ErrInvalid, io.EOF},
+ equal: false,
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ for _, le := range tc.linuxErr {
+ for _, e := range tc.err {
+ if linuxerr.Equals(le, e) != tc.equal {
+ t.Fatalf("Expected %t from Equals method for linuxerr: %s %T and error: %s %T", tc.equal, le, le, e, e)
+ }
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD
index 9730b88c1..c810c7946 100644
--- a/pkg/flipcall/BUILD
+++ b/pkg/flipcall/BUILD
@@ -10,9 +10,7 @@ go_library(
"flipcall_unsafe.go",
"futex_linux.go",
"io.go",
- "packet_window_allocator.go",
- "packet_window_mmap_amd64.go",
- "packet_window_mmap_arm64.go",
+ "packet_window.go",
],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go
index 8d8309a73..f0e4ff487 100644
--- a/pkg/flipcall/flipcall.go
+++ b/pkg/flipcall/flipcall.go
@@ -22,6 +22,7 @@ import (
"sync/atomic"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/memutil"
)
// An Endpoint provides the ability to synchronously transfer data and control
@@ -96,9 +97,9 @@ func (ep *Endpoint) Init(side EndpointSide, pwd PacketWindowDescriptor, opts ...
if pwd.Length > math.MaxUint32 {
return fmt.Errorf("packet window size (%d) exceeds maximum (%d)", pwd.Length, math.MaxUint32)
}
- m, e := packetWindowMmap(pwd)
- if e != 0 {
- return fmt.Errorf("failed to mmap packet window: %v", e)
+ m, err := memutil.MapFile(0, uintptr(pwd.Length), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset))
+ if err != nil {
+ return fmt.Errorf("failed to mmap packet window: %v", err)
}
ep.packet = m
ep.dataCap = uint32(pwd.Length) - uint32(PacketHeaderBytes)
diff --git a/pkg/flipcall/packet_window_allocator.go b/pkg/flipcall/packet_window.go
index 9122c97b7..9122c97b7 100644
--- a/pkg/flipcall/packet_window_allocator.go
+++ b/pkg/flipcall/packet_window.go
diff --git a/pkg/iovec/BUILD b/pkg/iovec/BUILD
deleted file mode 100644
index f4e9a6af9..000000000
--- a/pkg/iovec/BUILD
+++ /dev/null
@@ -1,18 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "iovec",
- srcs = ["iovec.go"],
- visibility = ["//:sandbox"],
- deps = ["@org_golang_x_sys//unix:go_default_library"],
-)
-
-go_test(
- name = "iovec_test",
- size = "small",
- srcs = ["iovec_test.go"],
- library = ":iovec",
- deps = ["@org_golang_x_sys//unix:go_default_library"],
-)
diff --git a/pkg/iovec/iovec.go b/pkg/iovec/iovec.go
deleted file mode 100644
index a281c05b6..000000000
--- a/pkg/iovec/iovec.go
+++ /dev/null
@@ -1,71 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build linux
-
-// Package iovec provides helpers to interact with vectorized I/O on host
-// system.
-package iovec
-
-import (
- "golang.org/x/sys/unix"
-)
-
-// MaxIovs is the maximum number of iovecs host platform can accept.
-var MaxIovs = 1024
-
-// Builder is a builder for slice of unix.Iovec.
-type Builder struct {
- iovec []unix.Iovec
- storage [8]unix.Iovec
-
- // overflow tracks the last buffer when iovec length is at MaxIovs.
- overflow []byte
-}
-
-// Add adds buf to b preparing to be written. Zero-length buf won't be added.
-func (b *Builder) Add(buf []byte) {
- if len(buf) == 0 {
- return
- }
- if b.iovec == nil {
- b.iovec = b.storage[:0]
- }
- if len(b.iovec) >= MaxIovs {
- b.addByAppend(buf)
- return
- }
-
- b.iovec = append(b.iovec, unix.Iovec{Base: &buf[0]})
- b.iovec[len(b.iovec)-1].SetLen(len(buf))
-
- // Keep the last buf if iovec is at max capacity. We will need to append to it
- // for later bufs.
- if len(b.iovec) == MaxIovs {
- n := len(buf)
- b.overflow = buf[:n:n]
- }
-}
-
-func (b *Builder) addByAppend(buf []byte) {
- b.overflow = append(b.overflow, buf...)
- b.iovec[len(b.iovec)-1] = unix.Iovec{Base: &b.overflow[0]}
- b.iovec[len(b.iovec)-1].SetLen(len(b.overflow))
-}
-
-// Build returns the final Iovec slice. The length of returned iovec will not
-// excceed MaxIovs.
-func (b *Builder) Build() []unix.Iovec {
- return b.iovec
-}
diff --git a/pkg/iovec/iovec_test.go b/pkg/iovec/iovec_test.go
deleted file mode 100644
index f6deb4208..000000000
--- a/pkg/iovec/iovec_test.go
+++ /dev/null
@@ -1,120 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build linux
-
-package iovec
-
-import (
- "bytes"
- "fmt"
- "testing"
- "unsafe"
-
- "golang.org/x/sys/unix"
-)
-
-func TestBuilderEmpty(t *testing.T) {
- var builder Builder
- iovecs := builder.Build()
- if got, want := len(iovecs), 0; got != want {
- t.Errorf("len(iovecs) = %d, want %d", got, want)
- }
-}
-
-func TestBuilderBuild(t *testing.T) {
- a := []byte{1, 2}
- b := []byte{3, 4, 5}
-
- var builder Builder
- builder.Add(a)
- builder.Add(b)
- builder.Add(nil) // Nil slice won't be added.
- builder.Add([]byte{}) // Empty slice won't be added.
- iovecs := builder.Build()
-
- if got, want := len(iovecs), 2; got != want {
- t.Fatalf("len(iovecs) = %d, want %d", got, want)
- }
- for i, data := range [][]byte{a, b} {
- if got, want := *iovecs[i].Base, data[0]; got != want {
- t.Fatalf("*iovecs[%d].Base = %d, want %d", i, got, want)
- }
- if got, want := iovecs[i].Len, uint64(len(data)); got != want {
- t.Fatalf("iovecs[%d].Len = %d, want %d", i, got, want)
- }
- }
-}
-
-func TestBuilderBuildMaxIov(t *testing.T) {
- for _, test := range []struct {
- numIov int
- }{
- {
- numIov: MaxIovs - 1,
- },
- {
- numIov: MaxIovs,
- },
- {
- numIov: MaxIovs + 1,
- },
- {
- numIov: MaxIovs + 10,
- },
- } {
- name := fmt.Sprintf("numIov=%v", test.numIov)
- t.Run(name, func(t *testing.T) {
- var data []byte
- var builder Builder
- for i := 0; i < test.numIov; i++ {
- buf := []byte{byte(i)}
- builder.Add(buf)
- data = append(data, buf...)
- }
- iovec := builder.Build()
-
- // Check the expected length of iovec.
- wantNum := test.numIov
- if wantNum > MaxIovs {
- wantNum = MaxIovs
- }
- if got, want := len(iovec), wantNum; got != want {
- t.Errorf("len(iovec) = %d, want %d", got, want)
- }
-
- // Test a real read-write.
- var fds [2]int
- if err := unix.Pipe(fds[:]); err != nil {
- t.Fatalf("Pipe: %v", err)
- }
- defer unix.Close(fds[0])
- defer unix.Close(fds[1])
-
- wrote, _, e := unix.RawSyscall(unix.SYS_WRITEV, uintptr(fds[1]), uintptr(unsafe.Pointer(&iovec[0])), uintptr(len(iovec)))
- if int(wrote) != len(data) || e != 0 {
- t.Fatalf("writev: %v, %v; want %v, 0", wrote, e, len(data))
- }
-
- got := make([]byte, len(data))
- if n, err := unix.Read(fds[0], got); n != len(got) || err != nil {
- t.Fatalf("read: %v, %v; want %v, nil", n, err, len(got))
- }
-
- if !bytes.Equal(got, data) {
- t.Errorf("read: got data %v, want %v", got, data)
- }
- })
- }
-}
diff --git a/pkg/marshal/primitive/BUILD b/pkg/marshal/primitive/BUILD
index 190b57c29..6e5ce136d 100644
--- a/pkg/marshal/primitive/BUILD
+++ b/pkg/marshal/primitive/BUILD
@@ -12,9 +12,7 @@ go_library(
"//:sandbox",
],
deps = [
- "//pkg/context",
"//pkg/hostarch",
"//pkg/marshal",
- "//pkg/usermem",
],
)
diff --git a/pkg/marshal/primitive/primitive.go b/pkg/marshal/primitive/primitive.go
index 6f38992b7..1c49cf082 100644
--- a/pkg/marshal/primitive/primitive.go
+++ b/pkg/marshal/primitive/primitive.go
@@ -19,10 +19,8 @@ package primitive
import (
"io"
- "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
- "gvisor.dev/gvisor/pkg/usermem"
)
// Int8 is a marshal.Marshallable implementation for int8.
@@ -400,26 +398,3 @@ func CopyStringOut(cc marshal.CopyContext, addr hostarch.Addr, src string) (int,
srcP := ByteSlice(src)
return srcP.CopyOut(cc, addr)
}
-
-// IOCopyContext wraps an object implementing hostarch.IO to implement
-// marshal.CopyContext.
-type IOCopyContext struct {
- Ctx context.Context
- IO usermem.IO
- Opts usermem.IOOpts
-}
-
-// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer.
-func (i *IOCopyContext) CopyScratchBuffer(size int) []byte {
- return make([]byte, size)
-}
-
-// CopyOutBytes implements marshal.CopyContext.CopyOutBytes.
-func (i *IOCopyContext) CopyOutBytes(addr hostarch.Addr, b []byte) (int, error) {
- return i.IO.CopyOut(i.Ctx, addr, b, i.Opts)
-}
-
-// CopyInBytes implements marshal.CopyContext.CopyInBytes.
-func (i *IOCopyContext) CopyInBytes(addr hostarch.Addr, b []byte) (int, error) {
- return i.IO.CopyIn(i.Ctx, addr, b, i.Opts)
-}
diff --git a/pkg/memutil/BUILD b/pkg/memutil/BUILD
index 9d07d98b4..bea595286 100644
--- a/pkg/memutil/BUILD
+++ b/pkg/memutil/BUILD
@@ -4,7 +4,11 @@ package(licenses = ["notice"])
go_library(
name = "memutil",
- srcs = ["memutil_unsafe.go"],
+ srcs = [
+ "memfd_linux_unsafe.go",
+ "memutil.go",
+ "mmap.go",
+ ],
visibility = ["//visibility:public"],
deps = ["@org_golang_x_sys//unix:go_default_library"],
)
diff --git a/pkg/memutil/memutil_unsafe.go b/pkg/memutil/memfd_linux_unsafe.go
index 6676d1ce3..504382213 100644
--- a/pkg/memutil/memutil_unsafe.go
+++ b/pkg/memutil/memfd_linux_unsafe.go
@@ -14,7 +14,6 @@
// +build linux
-// Package memutil provides a wrapper for the memfd_create() system call.
package memutil
import (
diff --git a/pkg/tcpip/time.s b/pkg/memutil/memutil.go
index fb37360ac..3185882fd 100644
--- a/pkg/tcpip/time.s
+++ b/pkg/memutil/memutil.go
@@ -12,4 +12,5 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Empty assembly file so empty func definitions work.
+// Package memutil provides utilities for working with shared memory files.
+package memutil
diff --git a/pkg/flipcall/packet_window_mmap_amd64.go b/pkg/memutil/mmap.go
index ced587a2a..7c939293f 100644
--- a/pkg/flipcall/packet_window_mmap_amd64.go
+++ b/pkg/memutil/mmap.go
@@ -12,12 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package flipcall
+package memutil
-import "golang.org/x/sys/unix"
+import (
+ "golang.org/x/sys/unix"
+)
-// Return a memory mapping of the pwd in memory that can be shared outside the sandbox.
-func packetWindowMmap(pwd PacketWindowDescriptor) (uintptr, unix.Errno) {
- m, _, err := unix.RawSyscall6(unix.SYS_MMAP, 0, uintptr(pwd.Length), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset))
- return m, err
+// MapFile returns a memory mapping configured by the given options as per
+// mmap(2).
+func MapFile(addr, len, prot, flags, fd, offset uintptr) (uintptr, error) {
+ m, _, e := unix.RawSyscall6(unix.SYS_MMAP, addr, len, prot, flags, fd, offset)
+ if e != 0 {
+ return 0, e
+ }
+ return m, nil
}
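(Not part of the patch: a rough sketch of how the new MapFile helper is called, mirroring the flipcall change earlier in this diff; fd and length are caller-supplied and the mapShared wrapper is hypothetical.)

        // mapShared maps length bytes of fd as a shared, read-write region, using the
        // mmap(2) argument order MapFile exposes: addr, len, prot, flags, fd, offset.
        func mapShared(fd int, length int) (uintptr, error) {
                return memutil.MapFile(0, uintptr(length), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED, uintptr(fd), 0)
        }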
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
index ac7868ad9..0b961d3d9 100644
--- a/pkg/merkletree/merkletree.go
+++ b/pkg/merkletree/merkletree.go
@@ -151,21 +151,21 @@ type VerityDescriptor struct {
Mode uint32
UID uint32
GID uint32
- Children map[string]struct{}
+ Children []string
SymlinkTarget string
RootHash []byte
}
-func (d *VerityDescriptor) String() string {
+func (d *VerityDescriptor) encode() []byte {
b := new(bytes.Buffer)
e := gob.NewEncoder(b)
- e.Encode(d.Children)
- return fmt.Sprintf("Name: %s, Size: %d, Mode: %d, UID: %d, GID: %d, Children: %v, Symlink: %s, RootHash: %v", d.Name, d.FileSize, d.Mode, d.UID, d.GID, b.Bytes(), d.SymlinkTarget, d.RootHash)
+ e.Encode(d)
+ return b.Bytes()
}
// verify generates a hash from d, and compares it with expected.
func (d *VerityDescriptor) verify(expected []byte, hashAlgorithms int) error {
- h, err := hashData([]byte(d.String()), hashAlgorithms)
+ h, err := hashData(d.encode(), hashAlgorithms)
if err != nil {
return err
}
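(Not part of the patch: the switch from String() to encode(), with Children now a []string, matters because gob writes map entries in Go's randomized iteration order, whereas a slice encodes deterministically, so the descriptor bytes and hence the hash are stable. A rough sketch of the encode-then-hash flow, with sha256 standing in for the configured hash algorithm.)

        // hashDescriptor gob-encodes the whole descriptor and hashes the resulting
        // byte representation; illustrative only, not the package's hashData helper.
        func hashDescriptor(d *VerityDescriptor) ([]byte, error) {
                var b bytes.Buffer
                if err := gob.NewEncoder(&b).Encode(d); err != nil {
                        return nil, err
                }
                sum := sha256.Sum256(b.Bytes())
                return sum[:], nil
        }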
@@ -210,7 +210,7 @@ type GenerateParams struct {
GID uint32
// Children is a map of children names for a directory. It should be
// empty for a regular file.
- Children map[string]struct{}
+ Children []string
// SymlinkTarget is the target path of a symlink file, or "" if the file is not a symlink.
SymlinkTarget string
// HashAlgorithms is the algorithms used to hash data.
@@ -242,7 +242,7 @@ func Generate(params *GenerateParams) ([]byte, error) {
// If file is a symlink do not generate root hash for file content.
if params.SymlinkTarget != "" {
- return hashData([]byte(descriptor.String()), params.HashAlgorithms)
+ return hashData(descriptor.encode(), params.HashAlgorithms)
}
layout, err := InitLayout(params.Size, params.HashAlgorithms, params.DataAndTreeInSameFile)
@@ -315,7 +315,7 @@ func Generate(params *GenerateParams) ([]byte, error) {
numBlocks = (numBlocks + layout.hashesPerBlock() - 1) / layout.hashesPerBlock()
}
descriptor.RootHash = root
- return hashData([]byte(descriptor.String()), params.HashAlgorithms)
+ return hashData(descriptor.encode(), params.HashAlgorithms)
}
// VerifyParams contains the params used to verify a portion of a file against
@@ -339,7 +339,7 @@ type VerifyParams struct {
GID uint32
// Children is a map of children names for a directory. It should be
// empty for a regular file.
- Children map[string]struct{}
+ Children []string
// SymlinkTarget is the target path of a symlink file, or "" if the file is not a symlink.
SymlinkTarget string
// HashAlgorithms is the algorithms used to hash data.
diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go
index 5d6f8df1b..1447fd139 100644
--- a/pkg/merkletree/merkletree_test.go
+++ b/pkg/merkletree/merkletree_test.go
@@ -206,112 +206,112 @@ func TestGenerate(t *testing.T) {
data: bytes.Repeat([]byte{0}, hostarch.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: false,
- expectedHash: []byte{9, 115, 238, 230, 38, 140, 195, 70, 207, 144, 202, 118, 23, 113, 32, 129, 226, 239, 177, 69, 161, 26, 14, 113, 16, 37, 30, 96, 19, 148, 132, 27},
+ expectedHash: []byte{78, 38, 225, 107, 61, 246, 26, 6, 71, 163, 254, 97, 112, 200, 87, 232, 190, 87, 231, 160, 119, 124, 61, 229, 49, 126, 90, 223, 134, 51, 77, 182},
},
{
name: "OnePageZeroesSHA256SameFile",
data: bytes.Repeat([]byte{0}, hostarch.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: true,
- expectedHash: []byte{9, 115, 238, 230, 38, 140, 195, 70, 207, 144, 202, 118, 23, 113, 32, 129, 226, 239, 177, 69, 161, 26, 14, 113, 16, 37, 30, 96, 19, 148, 132, 27},
+ expectedHash: []byte{78, 38, 225, 107, 61, 246, 26, 6, 71, 163, 254, 97, 112, 200, 87, 232, 190, 87, 231, 160, 119, 124, 61, 229, 49, 126, 90, 223, 134, 51, 77, 182},
},
{
name: "OnePageZeroesSHA512SeparateFile",
data: bytes.Repeat([]byte{0}, hostarch.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
dataAndTreeInSameFile: false,
- expectedHash: []byte{127, 8, 95, 11, 83, 101, 51, 39, 170, 235, 39, 43, 135, 243, 145, 118, 148, 58, 27, 155, 182, 205, 44, 47, 5, 223, 215, 17, 35, 16, 43, 104, 43, 11, 8, 88, 171, 7, 249, 243, 14, 62, 126, 218, 23, 159, 237, 237, 42, 226, 39, 25, 87, 48, 253, 191, 116, 213, 37, 3, 187, 152, 154, 14},
+ expectedHash: []byte{221, 45, 182, 132, 61, 212, 227, 145, 150, 131, 98, 221, 195, 5, 89, 21, 188, 36, 250, 101, 85, 78, 197, 253, 193, 23, 74, 219, 28, 108, 77, 47, 65, 79, 123, 144, 50, 245, 109, 72, 71, 80, 24, 77, 158, 95, 242, 185, 109, 163, 105, 183, 67, 106, 55, 194, 223, 46, 12, 242, 165, 203, 172, 254},
},
{
name: "OnePageZeroesSHA512SameFile",
data: bytes.Repeat([]byte{0}, hostarch.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
dataAndTreeInSameFile: true,
- expectedHash: []byte{127, 8, 95, 11, 83, 101, 51, 39, 170, 235, 39, 43, 135, 243, 145, 118, 148, 58, 27, 155, 182, 205, 44, 47, 5, 223, 215, 17, 35, 16, 43, 104, 43, 11, 8, 88, 171, 7, 249, 243, 14, 62, 126, 218, 23, 159, 237, 237, 42, 226, 39, 25, 87, 48, 253, 191, 116, 213, 37, 3, 187, 152, 154, 14},
+ expectedHash: []byte{221, 45, 182, 132, 61, 212, 227, 145, 150, 131, 98, 221, 195, 5, 89, 21, 188, 36, 250, 101, 85, 78, 197, 253, 193, 23, 74, 219, 28, 108, 77, 47, 65, 79, 123, 144, 50, 245, 109, 72, 71, 80, 24, 77, 158, 95, 242, 185, 109, 163, 105, 183, 67, 106, 55, 194, 223, 46, 12, 242, 165, 203, 172, 254},
},
{
name: "MultiplePageZeroesSHA256SeparateFile",
data: bytes.Repeat([]byte{0}, 128*hostarch.PageSize+1),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: false,
- expectedHash: []byte{247, 158, 42, 215, 180, 106, 0, 28, 77, 64, 132, 162, 74, 65, 250, 161, 243, 66, 129, 44, 197, 8, 145, 14, 94, 206, 156, 184, 145, 145, 20, 185},
+ expectedHash: []byte{131, 122, 73, 143, 4, 202, 193, 156, 218, 169, 196, 223, 70, 100, 117, 191, 241, 113, 134, 11, 229, 231, 105, 157, 156, 0, 66, 213, 122, 145, 174, 8},
},
{
name: "MultiplePageZeroesSHA256SameFile",
data: bytes.Repeat([]byte{0}, 128*hostarch.PageSize+1),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: true,
- expectedHash: []byte{247, 158, 42, 215, 180, 106, 0, 28, 77, 64, 132, 162, 74, 65, 250, 161, 243, 66, 129, 44, 197, 8, 145, 14, 94, 206, 156, 184, 145, 145, 20, 185},
+ expectedHash: []byte{131, 122, 73, 143, 4, 202, 193, 156, 218, 169, 196, 223, 70, 100, 117, 191, 241, 113, 134, 11, 229, 231, 105, 157, 156, 0, 66, 213, 122, 145, 174, 8},
},
{
name: "MultiplePageZeroesSHA512SeparateFile",
data: bytes.Repeat([]byte{0}, 128*hostarch.PageSize+1),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
dataAndTreeInSameFile: false,
- expectedHash: []byte{100, 121, 14, 30, 104, 200, 142, 182, 190, 78, 23, 68, 157, 174, 23, 75, 174, 250, 250, 25, 66, 45, 235, 103, 129, 49, 78, 127, 173, 154, 121, 35, 37, 115, 60, 217, 26, 205, 253, 253, 236, 145, 107, 109, 232, 19, 72, 92, 4, 191, 181, 205, 191, 57, 234, 177, 144, 235, 143, 30, 15, 197, 109, 81},
+ expectedHash: []byte{211, 48, 232, 110, 240, 51, 99, 241, 123, 138, 42, 76, 94, 86, 59, 200, 3, 246, 137, 148, 189, 226, 111, 103, 146, 29, 12, 218, 40, 182, 33, 99, 193, 163, 238, 26, 184, 13, 165, 187, 68, 173, 139, 9, 208, 59, 0, 192, 180, 50, 221, 35, 43, 119, 194, 16, 64, 84, 116, 63, 158, 195, 194, 226},
},
{
name: "MultiplePageZeroesSHA512SameFile",
data: bytes.Repeat([]byte{0}, 128*hostarch.PageSize+1),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
dataAndTreeInSameFile: true,
- expectedHash: []byte{100, 121, 14, 30, 104, 200, 142, 182, 190, 78, 23, 68, 157, 174, 23, 75, 174, 250, 250, 25, 66, 45, 235, 103, 129, 49, 78, 127, 173, 154, 121, 35, 37, 115, 60, 217, 26, 205, 253, 253, 236, 145, 107, 109, 232, 19, 72, 92, 4, 191, 181, 205, 191, 57, 234, 177, 144, 235, 143, 30, 15, 197, 109, 81},
+ expectedHash: []byte{211, 48, 232, 110, 240, 51, 99, 241, 123, 138, 42, 76, 94, 86, 59, 200, 3, 246, 137, 148, 189, 226, 111, 103, 146, 29, 12, 218, 40, 182, 33, 99, 193, 163, 238, 26, 184, 13, 165, 187, 68, 173, 139, 9, 208, 59, 0, 192, 180, 50, 221, 35, 43, 119, 194, 16, 64, 84, 116, 63, 158, 195, 194, 226},
},
{
name: "SingleASHA256SeparateFile",
data: []byte{'a'},
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: false,
- expectedHash: []byte{90, 124, 194, 100, 206, 242, 75, 152, 47, 249, 16, 27, 136, 161, 223, 228, 121, 241, 126, 158, 126, 122, 100, 120, 117, 15, 81, 78, 201, 133, 119, 111},
+ expectedHash: []byte{26, 47, 238, 138, 235, 244, 140, 231, 129, 240, 155, 252, 219, 44, 46, 72, 57, 249, 139, 88, 132, 238, 86, 108, 181, 115, 96, 72, 99, 210, 134, 47},
},
{
name: "SingleASHA256SameFile",
data: []byte{'a'},
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: true,
- expectedHash: []byte{90, 124, 194, 100, 206, 242, 75, 152, 47, 249, 16, 27, 136, 161, 223, 228, 121, 241, 126, 158, 126, 122, 100, 120, 117, 15, 81, 78, 201, 133, 119, 111},
+ expectedHash: []byte{26, 47, 238, 138, 235, 244, 140, 231, 129, 240, 155, 252, 219, 44, 46, 72, 57, 249, 139, 88, 132, 238, 86, 108, 181, 115, 96, 72, 99, 210, 134, 47},
},
{
name: "SingleASHA512SeparateFile",
data: []byte{'a'},
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
dataAndTreeInSameFile: false,
- expectedHash: []byte{24, 10, 13, 25, 113, 62, 169, 99, 151, 70, 166, 113, 81, 81, 163, 85, 5, 25, 29, 15, 46, 37, 104, 120, 142, 218, 52, 178, 187, 83, 30, 166, 101, 87, 70, 196, 188, 61, 123, 20, 13, 254, 126, 52, 212, 111, 75, 203, 33, 233, 233, 47, 181, 161, 43, 193, 131, 41, 99, 33, 164, 73, 89, 152},
+ expectedHash: []byte{44, 30, 224, 12, 102, 119, 163, 171, 119, 175, 212, 121, 231, 188, 125, 171, 79, 28, 144, 234, 75, 122, 44, 75, 15, 101, 173, 92, 233, 109, 234, 60, 173, 148, 125, 85, 94, 234, 95, 91, 16, 196, 88, 175, 23, 129, 226, 110, 24, 238, 5, 49, 186, 128, 72, 188, 193, 180, 207, 193, 203, 119, 40, 191},
},
{
name: "SingleASHA512SameFile",
data: []byte{'a'},
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
dataAndTreeInSameFile: true,
- expectedHash: []byte{24, 10, 13, 25, 113, 62, 169, 99, 151, 70, 166, 113, 81, 81, 163, 85, 5, 25, 29, 15, 46, 37, 104, 120, 142, 218, 52, 178, 187, 83, 30, 166, 101, 87, 70, 196, 188, 61, 123, 20, 13, 254, 126, 52, 212, 111, 75, 203, 33, 233, 233, 47, 181, 161, 43, 193, 131, 41, 99, 33, 164, 73, 89, 152},
+ expectedHash: []byte{44, 30, 224, 12, 102, 119, 163, 171, 119, 175, 212, 121, 231, 188, 125, 171, 79, 28, 144, 234, 75, 122, 44, 75, 15, 101, 173, 92, 233, 109, 234, 60, 173, 148, 125, 85, 94, 234, 95, 91, 16, 196, 88, 175, 23, 129, 226, 110, 24, 238, 5, 49, 186, 128, 72, 188, 193, 180, 207, 193, 203, 119, 40, 191},
},
{
name: "OnePageASHA256SeparateFile",
data: bytes.Repeat([]byte{'a'}, hostarch.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: false,
- expectedHash: []byte{132, 54, 112, 142, 156, 19, 50, 140, 138, 240, 192, 154, 100, 120, 242, 69, 64, 217, 62, 166, 127, 88, 23, 197, 100, 66, 255, 215, 214, 229, 54, 1},
+ expectedHash: []byte{166, 254, 83, 46, 241, 111, 18, 47, 79, 6, 181, 197, 176, 143, 211, 204, 53, 5, 245, 134, 172, 95, 97, 131, 236, 132, 197, 138, 123, 78, 43, 13},
},
{
name: "OnePageASHA256SameFile",
data: bytes.Repeat([]byte{'a'}, hostarch.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: true,
- expectedHash: []byte{132, 54, 112, 142, 156, 19, 50, 140, 138, 240, 192, 154, 100, 120, 242, 69, 64, 217, 62, 166, 127, 88, 23, 197, 100, 66, 255, 215, 214, 229, 54, 1},
+ expectedHash: []byte{166, 254, 83, 46, 241, 111, 18, 47, 79, 6, 181, 197, 176, 143, 211, 204, 53, 5, 245, 134, 172, 95, 97, 131, 236, 132, 197, 138, 123, 78, 43, 13},
},
{
name: "OnePageASHA512SeparateFile",
data: bytes.Repeat([]byte{'a'}, hostarch.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
dataAndTreeInSameFile: false,
- expectedHash: []byte{165, 46, 176, 116, 47, 209, 101, 193, 64, 185, 30, 9, 52, 22, 24, 154, 135, 220, 232, 168, 215, 45, 222, 226, 207, 104, 160, 10, 156, 98, 245, 250, 76, 21, 68, 204, 65, 118, 69, 52, 210, 155, 36, 109, 233, 103, 1, 40, 218, 89, 125, 38, 247, 194, 2, 225, 119, 155, 65, 99, 182, 111, 110, 145},
+ expectedHash: []byte{23, 69, 6, 79, 39, 232, 90, 246, 62, 55, 4, 229, 47, 36, 230, 24, 233, 47, 55, 36, 26, 139, 196, 78, 242, 12, 194, 77, 109, 81, 151, 188, 63, 201, 127, 235, 81, 214, 91, 200, 19, 232, 240, 14, 197, 1, 99, 224, 18, 213, 203, 242, 44, 102, 25, 62, 90, 189, 106, 107, 129, 61, 115, 39},
},
{
name: "OnePageASHA512SameFile",
data: bytes.Repeat([]byte{'a'}, hostarch.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
dataAndTreeInSameFile: true,
- expectedHash: []byte{165, 46, 176, 116, 47, 209, 101, 193, 64, 185, 30, 9, 52, 22, 24, 154, 135, 220, 232, 168, 215, 45, 222, 226, 207, 104, 160, 10, 156, 98, 245, 250, 76, 21, 68, 204, 65, 118, 69, 52, 210, 155, 36, 109, 233, 103, 1, 40, 218, 89, 125, 38, 247, 194, 2, 225, 119, 155, 65, 99, 182, 111, 110, 145},
+ expectedHash: []byte{23, 69, 6, 79, 39, 232, 90, 246, 62, 55, 4, 229, 47, 36, 230, 24, 233, 47, 55, 36, 26, 139, 196, 78, 242, 12, 194, 77, 109, 81, 151, 188, 63, 201, 127, 235, 81, 214, 91, 200, 19, 232, 240, 14, 197, 1, 99, 224, 18, 213, 203, 242, 44, 102, 25, 62, 90, 189, 106, 107, 129, 61, 115, 39},
},
}
@@ -324,7 +324,7 @@ func TestGenerate(t *testing.T) {
Mode: defaultMode,
UID: defaultUID,
GID: defaultGID,
- Children: make(map[string]struct{}),
+ Children: []string{},
HashAlgorithms: tc.hashAlgorithms,
TreeReader: &tree,
TreeWriter: &tree,
@@ -366,7 +366,7 @@ func prepareVerify(t *testing.T, dataSize int64, hashAlgorithm int, dataAndTreeI
Mode: defaultMode,
UID: defaultUID,
GID: defaultGID,
- Children: make(map[string]struct{}),
+ Children: []string{},
HashAlgorithms: hashAlgorithm,
TreeReader: &tree,
TreeWriter: &tree,
@@ -398,7 +398,7 @@ func prepareVerify(t *testing.T, dataSize int64, hashAlgorithm int, dataAndTreeI
Mode: defaultMode,
UID: defaultUID,
GID: defaultGID,
- Children: make(map[string]struct{}),
+ Children: []string{},
HashAlgorithms: hashAlgorithm,
ReadOffset: verifyStart,
ReadSize: verifySize,
@@ -627,7 +627,7 @@ func TestVerifyModifiedChildren(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
var buf bytes.Buffer
_, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf)
- params.Children["abc"] = struct{}{}
+ params.Children = append(params.Children, "abc")
if _, err := Verify(&params); errors.Is(err, nil) {
t.Errorf("Verification succeeded when expected to fail")
}
diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go
index fdeee3a5f..4829ae7ce 100644
--- a/pkg/metric/metric.go
+++ b/pkg/metric/metric.go
@@ -36,17 +36,22 @@ var (
// new metric after initialization.
ErrInitializationDone = errors.New("metric cannot be created after initialization is complete")
- // createdSentryMetrics indicates that the sentry metrics are created.
- createdSentryMetrics = false
-
// WeirdnessMetric is a metric with fields created to track the number
// of weird occurrences such as time fallback, partial_result, vsyscall
// count, watchdog startup timeouts and stuck tasks.
- WeirdnessMetric *Uint64Metric
+ WeirdnessMetric = MustCreateNewUint64Metric("/weirdness", true /* sync */, "Increment for weird occurrences of problems such as time fallback, partial result, vsyscalls invoked in the sandbox, watchdog startup timeouts and stuck tasks.",
+ Field{
+ name: "weirdness_type",
+ allowedValues: []string{"time_fallback", "partial_result", "vsyscall_count", "watchdog_stuck_startup", "watchdog_stuck_tasks"},
+ })
// SuspiciousOperationsMetric is a metric with fields created to detect
// operations such as opening an executable file to write from a gofer.
- SuspiciousOperationsMetric *Uint64Metric
+ SuspiciousOperationsMetric = MustCreateNewUint64Metric("/suspicious_operations", true /* sync */, "Increment for suspicious operations such as opening an executable file to write from a gofer.",
+ Field{
+ name: "operation_type",
+ allowedValues: []string{"opened_write_execute_file"},
+ })
)
// Uint64Metric encapsulates a uint64 that represents some kind of metric to be
@@ -84,17 +89,21 @@ var (
// Precondition:
// * All metrics are registered.
// * Initialize/Disable has not been called.
-func Initialize() {
+func Initialize() error {
if initialized {
- panic("Initialize/Disable called more than once")
+ return errors.New("metric.Initialize called after metric.Initialize or metric.Disable")
}
- initialized = true
m := pb.MetricRegistration{}
for _, v := range allMetrics.m {
m.Metrics = append(m.Metrics, v.metadata)
}
- eventchannel.Emit(&m)
+ if err := eventchannel.Emit(&m); err != nil {
+ return fmt.Errorf("unable to emit metric initialize event: %w", err)
+ }
+
+ initialized = true
+ return nil
}
// Disable sends an empty metric registration event over the event channel,
@@ -103,16 +112,18 @@ func Initialize() {
// Precondition:
// * All metrics are registered.
// * Initialize/Disable has not been called.
-func Disable() {
+func Disable() error {
if initialized {
- panic("Initialize/Disable called more than once")
+ return errors.New("metric.Disable called after metric.Initialize or metric.Disable")
}
- initialized = true
m := pb.MetricRegistration{}
if err := eventchannel.Emit(&m); err != nil {
- panic("unable to emit metric disable event: " + err.Error())
+ return fmt.Errorf("unable to emit metric disable event: %w", err)
}
+
+ initialized = true
+ return nil
}
type customUint64Metric struct {
@@ -165,8 +176,8 @@ func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.Met
}
// Metrics can exist without fields.
- if len(fields) > 1 {
- panic("Sentry metrics support at most one field")
+ if l := len(fields); l > 1 {
+ return fmt.Errorf("%d fields provided, must be <= 1", l)
}
for _, field := range fields {
@@ -182,7 +193,7 @@ func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.Met
// without fields and panics if it returns an error.
func MustRegisterCustomUint64Metric(name string, cumulative, sync bool, description string, value func(...string) uint64, fields ...Field) {
if err := RegisterCustomUint64Metric(name, cumulative, sync, pb.MetricMetadata_UNITS_NONE, description, value, fields...); err != nil {
- panic(fmt.Sprintf("Unable to register metric %q: %v", name, err))
+ panic(fmt.Sprintf("Unable to register metric %q: %s", name, err))
}
}
@@ -209,7 +220,7 @@ func NewUint64Metric(name string, sync bool, units pb.MetricMetadata_Units, desc
func MustCreateNewUint64Metric(name string, sync bool, description string, fields ...Field) *Uint64Metric {
m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NONE, description, fields...)
if err != nil {
- panic(fmt.Sprintf("Unable to create metric %q: %v", name, err))
+ panic(fmt.Sprintf("Unable to create metric %q: %s", name, err))
}
return m
}
@@ -219,7 +230,7 @@ func MustCreateNewUint64Metric(name string, sync bool, description string, field
func MustCreateNewUint64NanosecondsMetric(name string, sync bool, description string) *Uint64Metric {
m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NANOSECONDS, description)
if err != nil {
- panic(fmt.Sprintf("Unable to create metric %q: %v", name, err))
+ panic(fmt.Sprintf("Unable to create metric %q: %s", name, err))
}
return m
}
@@ -354,7 +365,7 @@ func EmitMetricUpdate() {
m.Metrics = append(m.Metrics, &pb.MetricValue{
Name: k,
- Value: &pb.MetricValue_Uint64Value{t},
+ Value: &pb.MetricValue_Uint64Value{Uint64Value: t},
})
case map[string]uint64:
for fieldValue, metricValue := range t {
@@ -369,7 +380,7 @@ func EmitMetricUpdate() {
m.Metrics = append(m.Metrics, &pb.MetricValue{
Name: k,
FieldValues: []string{fieldValue},
- Value: &pb.MetricValue_Uint64Value{metricValue},
+ Value: &pb.MetricValue_Uint64Value{Uint64Value: metricValue},
})
}
}
@@ -390,26 +401,7 @@ func EmitMetricUpdate() {
}
}
- eventchannel.Emit(&m)
-}
-
-// CreateSentryMetrics creates the sentry metrics during kernel initialization.
-func CreateSentryMetrics() {
- if createdSentryMetrics {
- return
+ if err := eventchannel.Emit(&m); err != nil {
+ log.Warningf("Unable to emit metrics: %s", err)
}
-
- createdSentryMetrics = true
-
- WeirdnessMetric = MustCreateNewUint64Metric("/weirdness", true /* sync */, "Increment for weird occurrences of problems such as time fallback, partial result, vsyscalls invoked in the sandbox, watchdog startup timeouts and stuck tasks.",
- Field{
- name: "weirdness_type",
- allowedValues: []string{"time_fallback", "partial_result", "vsyscall_count", "watchdog_stuck_startup", "watchdog_stuck_tasks"},
- })
-
- SuspiciousOperationsMetric = MustCreateNewUint64Metric("/suspicious_operations", true /* sync */, "Increment for suspicious operations such as opening an executable file to write from a gofer.",
- Field{
- name: "operation_type",
- allowedValues: []string{"opened_write_execute_file"},
- })
}
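(Not part of the patch: with Initialize and Disable now returning errors instead of panicking, callers are expected to propagate failures; a minimal sketch of the new call pattern, where startMetrics and the increment call site are illustrative and assume the fielded Increment helper.)

        func startMetrics() error {
                if err := metric.Initialize(); err != nil {
                        return fmt.Errorf("registering metrics: %w", err)
                }
                // The package-level metrics above are created at init time, so they can
                // be used as soon as registration succeeds.
                metric.WeirdnessMetric.Increment("time_fallback")
                return nil
        }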
diff --git a/pkg/metric/metric_test.go b/pkg/metric/metric_test.go
index c71dfd460..1b4a9e73a 100644
--- a/pkg/metric/metric_test.go
+++ b/pkg/metric/metric_test.go
@@ -48,6 +48,8 @@ func (s *sliceEmitter) Reset() {
var emitter sliceEmitter
func init() {
+ reset()
+
eventchannel.AddEmitter(&emitter)
}
@@ -77,7 +79,9 @@ func TestInitialize(t *testing.T) {
t.Fatalf("NewUint64Metric got err %v want nil", err)
}
- Initialize()
+ if err := Initialize(); err != nil {
+ t.Fatalf("Initialize(): %s", err)
+ }
if len(emitter) != 1 {
t.Fatalf("Initialize emitted %d events want 1", len(emitter))
@@ -149,7 +153,9 @@ func TestDisable(t *testing.T) {
t.Fatalf("NewUint64Metric got err %v want nil", err)
}
- Disable()
+ if err := Disable(); err != nil {
+ t.Fatalf("Disable(): %s", err)
+ }
if len(emitter) != 1 {
t.Fatalf("Initialize emitted %d events want 1", len(emitter))
@@ -178,7 +184,9 @@ func TestEmitMetricUpdate(t *testing.T) {
t.Fatalf("NewUint64Metric got err %v want nil", err)
}
- Initialize()
+ if err := Initialize(); err != nil {
+ t.Fatalf("Initialize(): %s", err)
+ }
// Don't care about the registration metrics.
emitter.Reset()
@@ -270,7 +278,9 @@ func TestEmitMetricUpdateWithFields(t *testing.T) {
t.Fatalf("NewUint64Metric got err %v want nil", err)
}
- Initialize()
+ if err := Initialize(); err != nil {
+ t.Fatalf("Initialize(): %s", err)
+ }
// Don't care about the registration metrics.
emitter.Reset()
diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD
index b2291ef97..2b22b2203 100644
--- a/pkg/p9/BUILD
+++ b/pkg/p9/BUILD
@@ -22,6 +22,9 @@ go_library(
"version.go",
],
deps = [
+ "//pkg/abi/linux/errno",
+ "//pkg/errors",
+ "//pkg/errors/linuxerr",
"//pkg/fd",
"//pkg/fdchannel",
"//pkg/flipcall",
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
index 758e11b13..161b451cc 100644
--- a/pkg/p9/handlers.go
+++ b/pkg/p9/handlers.go
@@ -23,6 +23,9 @@ import (
"sync/atomic"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux/errno"
+ "gvisor.dev/gvisor/pkg/errors"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
)
@@ -62,6 +65,45 @@ func newErr(err error) *Rlerror {
return &Rlerror{Error: uint32(ExtractErrno(err))}
}
+// ExtractLinuxerrErrno extracts a *errors.Error from an error, best effort.
+// TODO(b/34162363): Merge this with ExtractErrno.
+func ExtractLinuxerrErrno(err error) *errors.Error {
+ switch err {
+ case os.ErrNotExist:
+ return linuxerr.ENOENT
+ case os.ErrExist:
+ return linuxerr.EEXIST
+ case os.ErrPermission:
+ return linuxerr.EACCES
+ case os.ErrInvalid:
+ return linuxerr.EINVAL
+ }
+
+ // Attempt to unwrap.
+ switch e := err.(type) {
+ case *errors.Error:
+ return e
+ case unix.Errno:
+ return linuxerr.ErrorFromErrno(errno.Errno(e))
+ case *os.PathError:
+ return ExtractLinuxerrErrno(e.Err)
+ case *os.SyscallError:
+ return ExtractLinuxerrErrno(e.Err)
+ case *os.LinkError:
+ return ExtractLinuxerrErrno(e.Err)
+ }
+
+ // Default case.
+ log.Warningf("unknown error: %v", err)
+ return linuxerr.EIO
+}
+
+// newErrFromLinuxerr returns an Rlerror from the linuxerr list.
+// TODO(b/34162363): Merge this with newErr.
+func newErrFromLinuxerr(err error) *Rlerror {
+ return &Rlerror{Error: uint32(ExtractLinuxerrErrno(err).Errno())}
+}
+
// handler is implemented for server-handled messages.
//
// See server.go for call information.
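(Not part of the patch: a small sketch of how the new helpers unwrap nested OS errors into a linuxerr value before building an Rlerror; the openAndReport function is hypothetical.)

        func openAndReport(path string) *Rlerror {
                f, err := os.Open(path)
                if err != nil {
                        // A missing file arrives as *os.PathError wrapping ENOENT; the
                        // helper unwraps it to linuxerr.ENOENT and the Rlerror carries
                        // uint32 of its errno.
                        return newErrFromLinuxerr(err)
                }
                f.Close()
                return nil
        }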
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
index ff1172ed6..241ab44ef 100644
--- a/pkg/p9/server.go
+++ b/pkg/p9/server.go
@@ -19,7 +19,8 @@ import (
"runtime/debug"
"sync/atomic"
- "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux/errno"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/fdchannel"
"gvisor.dev/gvisor/pkg/flipcall"
@@ -483,7 +484,7 @@ func (cs *connState) lookupChannel(id uint32) *channel {
func (cs *connState) handle(m message) (r message) {
if !cs.reqGate.Enter() {
// connState.stop() has been called; the connection is shutting down.
- r = newErr(unix.ECONNRESET)
+ r = newErrFromLinuxerr(linuxerr.ECONNRESET)
return
}
defer func() {
@@ -498,15 +499,23 @@ func (cs *connState) handle(m message) (r message) {
// Wrap in an EFAULT error; we don't really have a
// better way to describe this kind of error. It will
// usually manifest as a result of the test framework.
- r = newErr(unix.EFAULT)
+ r = newErrFromLinuxerr(linuxerr.EFAULT)
}
}()
if handler, ok := m.(handler); ok {
// Call the message handler.
r = handler.handle(cs)
+ // TODO(b/34162363): This is only here to make sure the server works with
+ // only linuxerr Errors, as the handlers work with both client and server.
+ // It will be removed in a followup, when all the unix.Errno errors are
+ // replaced with linuxerr.
+ if rlError, ok := r.(*Rlerror); ok {
+ e := linuxerr.ErrorFromErrno(errno.Errno(rlError.Error))
+ r = newErrFromLinuxerr(e)
+ }
} else {
// Produce an ENOSYS error.
- r = newErr(unix.ENOSYS)
+ r = newErrFromLinuxerr(linuxerr.ENOSYS)
}
return
}
@@ -553,7 +562,7 @@ func (cs *connState) handleRequest() bool {
// If it's not a connection error, but some other protocol error,
// we can send a response immediately.
cs.sendMu.Lock()
- err := send(cs.conn, tag, newErr(err))
+ err := send(cs.conn, tag, newErrFromLinuxerr(err))
cs.sendMu.Unlock()
if err != nil {
log.Debugf("p9.send: %v", err)
diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go
index 4aecb8007..1bbcae045 100644
--- a/pkg/refs/refcounter.go
+++ b/pkg/refs/refcounter.go
@@ -261,8 +261,8 @@ func (l *LeakMode) Get() interface{} {
}
// String implements flag.Value.
-func (l *LeakMode) String() string {
- switch *l {
+func (l LeakMode) String() string {
+ switch l {
case UninitializedLeakChecking:
return "uninitialized"
case NoLeakChecking:
@@ -272,7 +272,7 @@ func (l *LeakMode) String() string {
case LeaksLogTraces:
return "log-traces"
}
- panic(fmt.Sprintf("invalid ref leak mode %d", *l))
+ panic(fmt.Sprintf("invalid ref leak mode %d", l))
}
// leakMode stores the current mode for the reference leak checker.
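(Not part of the patch: the move to a value receiver matters when a LeakMode is formatted by value, for example when a flag package prints its default; a tiny sketch.)

        var mode refs.LeakMode = refs.LeaksLogTraces
        fmt.Println(mode) // "log-traces": the value receiver satisfies fmt.Stringer for
                          // non-pointer values; with the pointer receiver this printed
                          // the underlying integer instead.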
diff --git a/pkg/ring0/kernel_amd64.go b/pkg/ring0/kernel_amd64.go
index 41dfd0bf9..f63af8b76 100644
--- a/pkg/ring0/kernel_amd64.go
+++ b/pkg/ring0/kernel_amd64.go
@@ -254,6 +254,8 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
return
}
+var sentryXCR0 = xgetbv(0)
+
// start is the CPU entrypoint.
//
// This is called from the Start asm stub (see entry_amd64.go); on return the
@@ -265,24 +267,10 @@ func start(c *CPU) {
WriteGS(kernelAddr(c.kernelEntry))
WriteFS(uintptr(c.registers.Fs_base))
- // Initialize floating point.
- //
- // Note that on skylake, the valid XCR0 mask reported seems to be 0xff.
- // This breaks down as:
- //
- // bit0 - x87
- // bit1 - SSE
- // bit2 - AVX
- // bit3-4 - MPX
- // bit5-7 - AVX512
- //
- // For some reason, enabled MPX & AVX512 on platforms that report them
- // seems to be cause a general protection fault. (Maybe there are some
- // virtualization issues and these aren't exported to the guest cpuid.)
- // This needs further investigation, but we can limit the floating
- // point operations to x87, SSE & AVX for now.
fninit()
- xsetbv(0, validXCR0Mask&0x7)
+ // Need to sync XCR0 with the host, because xsave and xrstor can be
+ // called from different contexts.
+ xsetbv(0, sentryXCR0)
// Set the syscall target.
wrmsr(_MSR_LSTAR, kernelFunc(sysenter))
diff --git a/pkg/safecopy/BUILD b/pkg/safecopy/BUILD
index b77c40279..db5787302 100644
--- a/pkg/safecopy/BUILD
+++ b/pkg/safecopy/BUILD
@@ -18,6 +18,7 @@ go_library(
],
visibility = ["//:sandbox"],
deps = [
+ "//pkg/abi/linux",
"//pkg/syserror",
"@org_golang_x_sys//unix:go_default_library",
],
diff --git a/pkg/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go
index efbc2ddc1..2365b2c0d 100644
--- a/pkg/safecopy/safecopy_unsafe.go
+++ b/pkg/safecopy/safecopy_unsafe.go
@@ -20,6 +20,7 @@ import (
"unsafe"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
)
// maxRegisterSize is the maximum register size used in memcpy and memclr. It
@@ -342,12 +343,7 @@ func errorFromFaultSignal(addr uintptr, sig int32) error {
// handler however, and if this is function is being used externally then the
// same courtesy is expected.
func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) error {
- var sa struct {
- handler uintptr
- flags uint64
- restorer uintptr
- mask uint64
- }
+ var sa linux.SigAction
const maskLen = 8
// Get the existing signal handler information, and save the current
@@ -358,14 +354,14 @@ func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) e
}
// Fail if there isn't a previous handler.
- if sa.handler == 0 {
+ if sa.Handler == 0 {
return fmt.Errorf("previous handler for signal %x isn't set", sig)
}
- *previous = sa.handler
+ *previous = uintptr(sa.Handler)
// Install our own handler.
- sa.handler = handler
+ sa.Handler = uint64(handler)
if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 {
return e
}
diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go
index daea51c4d..8ffa1db37 100644
--- a/pkg/seccomp/seccomp.go
+++ b/pkg/seccomp/seccomp.go
@@ -36,14 +36,10 @@ const (
// Install generates BPF code based on the set of syscalls provided. It only
// allows syscalls that conform to the specification. Syscalls that violate the
-// specification will trigger RET_KILL_PROCESS, except for the cases below.
-//
-// RET_TRAP is used in violations, instead of RET_KILL_PROCESS, in the
-// following cases:
-// 1. Kernel doesn't support RET_KILL_PROCESS: RET_KILL_THREAD only kills the
-// offending thread and often keeps the sentry hanging.
-// 2. Debug: RET_TRAP generates a panic followed by a stack trace which is
-// much easier to debug then RET_KILL_PROCESS which can't be caught.
+// specification will trigger RET_KILL_PROCESS. If RET_KILL_PROCESS is not
+// supported, violations will trigger RET_TRAP instead. RET_KILL_THREAD is not
+// used because it only kills the offending thread and often keeps the sentry
+// hanging.
//
// Be aware that RET_TRAP sends SIGSYS to the process and it may be ignored,
// making it possible for the process to continue running after a violation.
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
index c9c52530d..068a0c8d9 100644
--- a/pkg/sentry/arch/BUILD
+++ b/pkg/sentry/arch/BUILD
@@ -14,12 +14,8 @@ go_library(
"arch_x86.go",
"arch_x86_impl.go",
"auxv.go",
- "signal.go",
- "signal_act.go",
"signal_amd64.go",
"signal_arm64.go",
- "signal_info.go",
- "signal_stack.go",
"stack.go",
"stack_unsafe.go",
"syscalls_amd64.go",
@@ -32,6 +28,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/cpuid",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/log",
"//pkg/marshal",
diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go
index 290863ee6..c9393b091 100644
--- a/pkg/sentry/arch/arch.go
+++ b/pkg/sentry/arch/arch.go
@@ -134,21 +134,13 @@ type Context interface {
// RegisterMap returns a map of all registers.
RegisterMap() (map[string]uintptr, error)
- // NewSignalAct returns a new object that is equivalent to struct sigaction
- // in the guest architecture.
- NewSignalAct() NativeSignalAct
-
- // NewSignalStack returns a new object that is equivalent to stack_t in the
- // guest architecture.
- NewSignalStack() NativeSignalStack
-
// SignalSetup modifies the context in preparation for handling the
// given signal.
//
// st is the stack where the signal handler frame should be
// constructed.
//
- // act is the SignalAct that specifies how this signal is being
+ // act is the SigAction that specifies how this signal is being
// handled.
//
// info is the SignalInfo of the signal being delivered.
@@ -157,7 +149,7 @@ type Context interface {
// stack is not going to be used).
//
// sigset is the signal mask before entering the signal handler.
- SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error
+ SignalSetup(st *Stack, act *linux.SigAction, info *linux.SignalInfo, alt *linux.SignalStack, sigset linux.SignalSet) error
// SignalRestore restores context after returning from a signal
// handler.
@@ -167,7 +159,7 @@ type Context interface {
// rt is true if SignalRestore is being entered from rt_sigreturn and
// false if SignalRestore is being entered from sigreturn.
// SignalRestore returns the thread's new signal mask.
- SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error)
+ SignalRestore(st *Stack, rt bool) (linux.SignalSet, linux.SignalStack, error)
// CPUIDEmulate emulates a CPUID instruction according to current register state.
CPUIDEmulate(l log.Logger)
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
index 08789f517..7def71fef 100644
--- a/pkg/sentry/arch/arch_aarch64.go
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
@@ -237,7 +238,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int,
}
return s.PtraceGetRegs(dst)
default:
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
}
@@ -250,7 +251,7 @@ func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int,
}
return s.PtraceSetRegs(src)
default:
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
}
diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go
index e8e52d3a8..d13e12f8c 100644
--- a/pkg/sentry/arch/arch_x86.go
+++ b/pkg/sentry/arch/arch_x86.go
@@ -23,6 +23,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
@@ -361,7 +362,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int,
case _NT_X86_XSTATE:
return s.fpState.PtraceGetXstateRegs(dst, maxlen, s.FeatureSet)
default:
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
}
@@ -378,7 +379,7 @@ func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int,
case _NT_X86_XSTATE:
return s.fpState.PtraceSetXstateRegs(src, maxlen, s.FeatureSet)
default:
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
}
diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go
deleted file mode 100644
index 67d7edf68..000000000
--- a/pkg/sentry/arch/signal.go
+++ /dev/null
@@ -1,276 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package arch
-
-import (
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/hostarch"
-)
-
-// SignalAct represents the action that should be taken when a signal is
-// delivered, and is equivalent to struct sigaction.
-//
-// +marshal
-// +stateify savable
-type SignalAct struct {
- Handler uint64
- Flags uint64
- Restorer uint64 // Only used on amd64.
- Mask linux.SignalSet
-}
-
-// SerializeFrom implements NativeSignalAct.SerializeFrom.
-func (s *SignalAct) SerializeFrom(other *SignalAct) {
- *s = *other
-}
-
-// DeserializeTo implements NativeSignalAct.DeserializeTo.
-func (s *SignalAct) DeserializeTo(other *SignalAct) {
- *other = *s
-}
-
-// SignalStack represents information about a user stack, and is equivalent to
-// stack_t.
-//
-// +marshal
-// +stateify savable
-type SignalStack struct {
- Addr uint64
- Flags uint32
- _ uint32
- Size uint64
-}
-
-// SerializeFrom implements NativeSignalStack.SerializeFrom.
-func (s *SignalStack) SerializeFrom(other *SignalStack) {
- *s = *other
-}
-
-// DeserializeTo implements NativeSignalStack.DeserializeTo.
-func (s *SignalStack) DeserializeTo(other *SignalStack) {
- *other = *s
-}
-
-// SignalInfo represents information about a signal being delivered, and is
-// equivalent to struct siginfo in linux kernel(linux/include/uapi/asm-generic/siginfo.h).
-//
-// +marshal
-// +stateify savable
-type SignalInfo struct {
- Signo int32 // Signal number
- Errno int32 // Errno value
- Code int32 // Signal code
- _ uint32
-
- // struct siginfo::_sifields is a union. In SignalInfo, fields in the union
- // are accessed through methods.
- //
- // For reference, here is the definition of _sifields: (_sigfault._trapno,
- // which does not exist on x86, omitted for clarity)
- //
- // union {
- // int _pad[SI_PAD_SIZE];
- //
- // /* kill() */
- // struct {
- // __kernel_pid_t _pid; /* sender's pid */
- // __ARCH_SI_UID_T _uid; /* sender's uid */
- // } _kill;
- //
- // /* POSIX.1b timers */
- // struct {
- // __kernel_timer_t _tid; /* timer id */
- // int _overrun; /* overrun count */
- // char _pad[sizeof( __ARCH_SI_UID_T) - sizeof(int)];
- // sigval_t _sigval; /* same as below */
- // int _sys_private; /* not to be passed to user */
- // } _timer;
- //
- // /* POSIX.1b signals */
- // struct {
- // __kernel_pid_t _pid; /* sender's pid */
- // __ARCH_SI_UID_T _uid; /* sender's uid */
- // sigval_t _sigval;
- // } _rt;
- //
- // /* SIGCHLD */
- // struct {
- // __kernel_pid_t _pid; /* which child */
- // __ARCH_SI_UID_T _uid; /* sender's uid */
- // int _status; /* exit code */
- // __ARCH_SI_CLOCK_T _utime;
- // __ARCH_SI_CLOCK_T _stime;
- // } _sigchld;
- //
- // /* SIGILL, SIGFPE, SIGSEGV, SIGBUS */
- // struct {
- // void *_addr; /* faulting insn/memory ref. */
- // short _addr_lsb; /* LSB of the reported address */
- // } _sigfault;
- //
- // /* SIGPOLL */
- // struct {
- // __ARCH_SI_BAND_T _band; /* POLL_IN, POLL_OUT, POLL_MSG */
- // int _fd;
- // } _sigpoll;
- //
- // /* SIGSYS */
- // struct {
- // void *_call_addr; /* calling user insn */
- // int _syscall; /* triggering system call number */
- // unsigned int _arch; /* AUDIT_ARCH_* of syscall */
- // } _sigsys;
- // } _sifields;
- //
- // _sifields is padded so that the size of siginfo is SI_MAX_SIZE = 128
- // bytes.
- Fields [128 - 16]byte
-}
-
-// FixSignalCodeForUser fixes up si_code.
-//
-// The si_code we get from Linux may contain the kernel-specific code in the
-// top 16 bits if it's positive (e.g., from ptrace). Linux's
-// copy_siginfo_to_user does
-// err |= __put_user((short)from->si_code, &to->si_code);
-// to mask out those bits and we need to do the same.
-func (s *SignalInfo) FixSignalCodeForUser() {
- if s.Code > 0 {
- s.Code &= 0x0000ffff
- }
-}
-
-// PID returns the si_pid field.
-func (s *SignalInfo) PID() int32 {
- return int32(hostarch.ByteOrder.Uint32(s.Fields[0:4]))
-}
-
-// SetPID mutates the si_pid field.
-func (s *SignalInfo) SetPID(val int32) {
- hostarch.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
-}
-
-// UID returns the si_uid field.
-func (s *SignalInfo) UID() int32 {
- return int32(hostarch.ByteOrder.Uint32(s.Fields[4:8]))
-}
-
-// SetUID mutates the si_uid field.
-func (s *SignalInfo) SetUID(val int32) {
- hostarch.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
-}
-
-// Sigval returns the sigval field, which is aliased to both si_int and si_ptr.
-func (s *SignalInfo) Sigval() uint64 {
- return hostarch.ByteOrder.Uint64(s.Fields[8:16])
-}
-
-// SetSigval mutates the sigval field.
-func (s *SignalInfo) SetSigval(val uint64) {
- hostarch.ByteOrder.PutUint64(s.Fields[8:16], val)
-}
-
-// TimerID returns the si_timerid field.
-func (s *SignalInfo) TimerID() linux.TimerID {
- return linux.TimerID(hostarch.ByteOrder.Uint32(s.Fields[0:4]))
-}
-
-// SetTimerID sets the si_timerid field.
-func (s *SignalInfo) SetTimerID(val linux.TimerID) {
- hostarch.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
-}
-
-// Overrun returns the si_overrun field.
-func (s *SignalInfo) Overrun() int32 {
- return int32(hostarch.ByteOrder.Uint32(s.Fields[4:8]))
-}
-
-// SetOverrun sets the si_overrun field.
-func (s *SignalInfo) SetOverrun(val int32) {
- hostarch.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
-}
-
-// Addr returns the si_addr field.
-func (s *SignalInfo) Addr() uint64 {
- return hostarch.ByteOrder.Uint64(s.Fields[0:8])
-}
-
-// SetAddr sets the si_addr field.
-func (s *SignalInfo) SetAddr(val uint64) {
- hostarch.ByteOrder.PutUint64(s.Fields[0:8], val)
-}
-
-// Status returns the si_status field.
-func (s *SignalInfo) Status() int32 {
- return int32(hostarch.ByteOrder.Uint32(s.Fields[8:12]))
-}
-
-// SetStatus mutates the si_status field.
-func (s *SignalInfo) SetStatus(val int32) {
- hostarch.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
-}
-
-// CallAddr returns the si_call_addr field.
-func (s *SignalInfo) CallAddr() uint64 {
- return hostarch.ByteOrder.Uint64(s.Fields[0:8])
-}
-
-// SetCallAddr mutates the si_call_addr field.
-func (s *SignalInfo) SetCallAddr(val uint64) {
- hostarch.ByteOrder.PutUint64(s.Fields[0:8], val)
-}
-
-// Syscall returns the si_syscall field.
-func (s *SignalInfo) Syscall() int32 {
- return int32(hostarch.ByteOrder.Uint32(s.Fields[8:12]))
-}
-
-// SetSyscall mutates the si_syscall field.
-func (s *SignalInfo) SetSyscall(val int32) {
- hostarch.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
-}
-
-// Arch returns the si_arch field.
-func (s *SignalInfo) Arch() uint32 {
- return hostarch.ByteOrder.Uint32(s.Fields[12:16])
-}
-
-// SetArch mutates the si_arch field.
-func (s *SignalInfo) SetArch(val uint32) {
- hostarch.ByteOrder.PutUint32(s.Fields[12:16], val)
-}
-
-// Band returns the si_band field.
-func (s *SignalInfo) Band() int64 {
- return int64(hostarch.ByteOrder.Uint64(s.Fields[0:8]))
-}
-
-// SetBand mutates the si_band field.
-func (s *SignalInfo) SetBand(val int64) {
- // Note: this assumes the platform uses `long` as `__ARCH_SI_BAND_T`.
- // On some platforms, which gVisor doesn't support, `__ARCH_SI_BAND_T` is
- // `int`. See siginfo.h.
- hostarch.ByteOrder.PutUint64(s.Fields[0:8], uint64(val))
-}
-
-// FD returns the si_fd field.
-func (s *SignalInfo) FD() uint32 {
- return hostarch.ByteOrder.Uint32(s.Fields[8:12])
-}
-
-// SetFD mutates the si_fd field.
-func (s *SignalInfo) SetFD(val uint32) {
- hostarch.ByteOrder.PutUint32(s.Fields[8:12], val)
-}
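The deleted arch.SignalInfo above (together with its union accessors) moves into pkg/abi/linux, per the signal.go growth in the diffstat. A hedged sketch of building a SIGCHLD siginfo, assuming the accessors and CLD_* constants were carried over unchanged to linux.SignalInfo:

    // Assumed import: "gvisor.dev/gvisor/pkg/abi/linux"
    func sigchldInfo(pid, uid, status int32) *linux.SignalInfo {
        info := &linux.SignalInfo{
            Signo: int32(linux.SIGCHLD),
            Code:  linux.CLD_EXITED, // assuming CLD_* moved into abi/linux as well
        }
        // _sifields is a union: for SIGCHLD, si_pid lives at Fields[0:4],
        // si_uid at Fields[4:8], and si_status at Fields[8:12].
        info.SetPID(pid)
        info.SetUID(uid)
        info.SetStatus(status)
        return info
    }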
diff --git a/pkg/sentry/arch/signal_act.go b/pkg/sentry/arch/signal_act.go
deleted file mode 100644
index d3e2324a8..000000000
--- a/pkg/sentry/arch/signal_act.go
+++ /dev/null
@@ -1,83 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package arch
-
-import "gvisor.dev/gvisor/pkg/marshal"
-
-// Special values for SignalAct.Handler.
-const (
- // SignalActDefault is SIG_DFL and specifies that the default behavior for
- // a signal should be taken.
- SignalActDefault = 0
-
- // SignalActIgnore is SIG_IGN and specifies that a signal should be
- // ignored.
- SignalActIgnore = 1
-)
-
-// Available signal flags.
-const (
- SignalFlagNoCldStop = 0x00000001
- SignalFlagNoCldWait = 0x00000002
- SignalFlagSigInfo = 0x00000004
- SignalFlagRestorer = 0x04000000
- SignalFlagOnStack = 0x08000000
- SignalFlagRestart = 0x10000000
- SignalFlagInterrupt = 0x20000000
- SignalFlagNoDefer = 0x40000000
- SignalFlagResetHandler = 0x80000000
-)
-
-// IsSigInfo returns true iff this handle expects siginfo.
-func (s SignalAct) IsSigInfo() bool {
- return s.Flags&SignalFlagSigInfo != 0
-}
-
-// IsNoDefer returns true iff this SignalAct has the NoDefer flag set.
-func (s SignalAct) IsNoDefer() bool {
- return s.Flags&SignalFlagNoDefer != 0
-}
-
-// IsRestart returns true iff this SignalAct has the Restart flag set.
-func (s SignalAct) IsRestart() bool {
- return s.Flags&SignalFlagRestart != 0
-}
-
-// IsResetHandler returns true iff this SignalAct has the ResetHandler flag set.
-func (s SignalAct) IsResetHandler() bool {
- return s.Flags&SignalFlagResetHandler != 0
-}
-
-// IsOnStack returns true iff this SignalAct has the OnStack flag set.
-func (s SignalAct) IsOnStack() bool {
- return s.Flags&SignalFlagOnStack != 0
-}
-
-// HasRestorer returns true iff this SignalAct has the Restorer flag set.
-func (s SignalAct) HasRestorer() bool {
- return s.Flags&SignalFlagRestorer != 0
-}
-
-// NativeSignalAct is a type that is equivalent to struct sigaction in the
-// guest architecture.
-type NativeSignalAct interface {
- marshal.Marshallable
-
- // SerializeFrom copies the data in the host SignalAct s into this object.
- SerializeFrom(s *SignalAct)
-
- // DeserializeTo copies the data in this object into the host SignalAct s.
- DeserializeTo(s *SignalAct)
-}
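With the arch.SignalAct predicates gone, call sites test the linux.SA_* bits directly, as the amd64 and arm64 hunks below show. Roughly:

    // Equivalent to the removed act.IsOnStack(), act.HasRestorer() and
    // act.IsSigInfo() helpers; act is a *linux.SigAction.
    func flagChecks(act *linux.SigAction) (onStack, hasRestorer, wantsSigInfo bool) {
        onStack = act.Flags&linux.SA_ONSTACK != 0
        hasRestorer = act.Flags&linux.SA_RESTORER != 0
        wantsSigInfo = act.Flags&linux.SA_SIGINFO != 0
        return onStack, hasRestorer, wantsSigInfo
    }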
diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go
index 082ed92b1..58e28dbba 100644
--- a/pkg/sentry/arch/signal_amd64.go
+++ b/pkg/sentry/arch/signal_amd64.go
@@ -76,21 +76,11 @@ const (
type UContext64 struct {
Flags uint64
Link uint64
- Stack SignalStack
+ Stack linux.SignalStack
MContext SignalContext64
Sigset linux.SignalSet
}
-// NewSignalAct implements Context.NewSignalAct.
-func (c *context64) NewSignalAct() NativeSignalAct {
- return &SignalAct{}
-}
-
-// NewSignalStack implements Context.NewSignalStack.
-func (c *context64) NewSignalStack() NativeSignalStack {
- return &SignalStack{}
-}
-
// From Linux 'arch/x86/include/uapi/asm/sigcontext.h' the following is the
// size of the magic cookie at the end of the xsave frame.
//
@@ -110,7 +100,7 @@ func (c *context64) fpuFrameSize() (size int, useXsave bool) {
// SignalSetup implements Context.SignalSetup. (Compare to Linux's
// arch/x86/kernel/signal.c:__setup_rt_frame().)
-func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error {
+func (c *context64) SignalSetup(st *Stack, act *linux.SigAction, info *linux.SignalInfo, alt *linux.SignalStack, sigset linux.SignalSet) error {
sp := st.Bottom
// "The 128-byte area beyond the location pointed to by %rsp is considered
@@ -187,7 +177,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
// Prior to proceeding, figure out if the frame will exhaust the range
// for the signal stack. This is not allowed, and should immediately
// force signal delivery (reverting to the default handler).
- if act.IsOnStack() && alt.IsEnabled() && !alt.Contains(frameBottom) {
+ if act.Flags&linux.SA_ONSTACK != 0 && alt.IsEnabled() && !alt.Contains(frameBottom) {
return unix.EFAULT
}
@@ -203,7 +193,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
return err
}
ucAddr := st.Bottom
- if act.HasRestorer() {
+ if act.Flags&linux.SA_RESTORER != 0 {
// Push the restorer return address.
// Note that this doesn't need to be popped.
if _, err := primitive.CopyUint64Out(st, StackBottomMagic, act.Restorer); err != nil {
@@ -237,15 +227,15 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
// SignalRestore implements Context.SignalRestore. (Compare to Linux's
// arch/x86/kernel/signal.c:sys_rt_sigreturn().)
-func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) {
+func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, linux.SignalStack, error) {
// Copy out the stack frame.
var uc UContext64
if _, err := uc.CopyIn(st, StackBottomMagic); err != nil {
- return 0, SignalStack{}, err
+ return 0, linux.SignalStack{}, err
}
- var info SignalInfo
+ var info linux.SignalInfo
if _, err := info.CopyIn(st, StackBottomMagic); err != nil {
- return 0, SignalStack{}, err
+ return 0, linux.SignalStack{}, err
}
// Restore registers.
diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go
index da71fb873..80df90076 100644
--- a/pkg/sentry/arch/signal_arm64.go
+++ b/pkg/sentry/arch/signal_arm64.go
@@ -61,7 +61,7 @@ type FpsimdContext struct {
type UContext64 struct {
Flags uint64
Link uint64
- Stack SignalStack
+ Stack linux.SignalStack
Sigset linux.SignalSet
// glibc uses a 1024-bit sigset_t
_pad [120]byte // (1024 - 64) / 8 = 120
@@ -71,18 +71,8 @@ type UContext64 struct {
MContext SignalContext64
}
-// NewSignalAct implements Context.NewSignalAct.
-func (c *context64) NewSignalAct() NativeSignalAct {
- return &SignalAct{}
-}
-
-// NewSignalStack implements Context.NewSignalStack.
-func (c *context64) NewSignalStack() NativeSignalStack {
- return &SignalStack{}
-}
-
// SignalSetup implements Context.SignalSetup.
-func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error {
+func (c *context64) SignalSetup(st *Stack, act *linux.SigAction, info *linux.SignalInfo, alt *linux.SignalStack, sigset linux.SignalSet) error {
sp := st.Bottom
// Construct the UContext64 now since we need its size.
@@ -114,7 +104,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
// Prior to proceeding, figure out if the frame will exhaust the range
// for the signal stack. This is not allowed, and should immediately
// force signal delivery (reverting to the default handler).
- if act.IsOnStack() && alt.IsEnabled() && !alt.Contains(frameBottom) {
+ if act.Flags&linux.SA_ONSTACK != 0 && alt.IsEnabled() && !alt.Contains(frameBottom) {
return unix.EFAULT
}
@@ -137,7 +127,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
c.Regs.Regs[0] = uint64(info.Signo)
c.Regs.Regs[1] = uint64(infoAddr)
c.Regs.Regs[2] = uint64(ucAddr)
- c.Regs.Regs[30] = uint64(act.Restorer)
+ c.Regs.Regs[30] = act.Restorer
// Save the thread's floating point state.
c.sigFPState = append(c.sigFPState, c.fpState)
@@ -147,15 +137,15 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
}
// SignalRestore implements Context.SignalRestore.
-func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) {
+func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, linux.SignalStack, error) {
// Copy out the stack frame.
var uc UContext64
if _, err := uc.CopyIn(st, StackBottomMagic); err != nil {
- return 0, SignalStack{}, err
+ return 0, linux.SignalStack{}, err
}
- var info SignalInfo
+ var info linux.SignalInfo
if _, err := info.CopyIn(st, StackBottomMagic); err != nil {
- return 0, SignalStack{}, err
+ return 0, linux.SignalStack{}, err
}
// Restore registers.
diff --git a/pkg/sentry/arch/signal_info.go b/pkg/sentry/arch/signal_info.go
deleted file mode 100644
index f93ee8b46..000000000
--- a/pkg/sentry/arch/signal_info.go
+++ /dev/null
@@ -1,66 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package arch
-
-// Possible values for SignalInfo.Code. These values originate from the Linux
-// kernel's include/uapi/asm-generic/siginfo.h.
-const (
- // SignalInfoUser (properly SI_USER) indicates that a signal was sent from
- // a kill() or raise() syscall.
- SignalInfoUser = 0
-
- // SignalInfoKernel (properly SI_KERNEL) indicates that the signal was sent
- // by the kernel.
- SignalInfoKernel = 0x80
-
- // SignalInfoTimer (properly SI_TIMER) indicates that the signal was sent
- // by an expired timer.
- SignalInfoTimer = -2
-
- // SignalInfoTkill (properly SI_TKILL) indicates that the signal was sent
- // from a tkill() or tgkill() syscall.
- SignalInfoTkill = -6
-
- // CLD_* codes are only meaningful for SIGCHLD.
-
- // CLD_EXITED indicates that a task exited.
- CLD_EXITED = 1
-
- // CLD_KILLED indicates that a task was killed by a signal.
- CLD_KILLED = 2
-
- // CLD_DUMPED indicates that a task was killed by a signal and then dumped
- // core.
- CLD_DUMPED = 3
-
- // CLD_TRAPPED indicates that a task was stopped by ptrace.
- CLD_TRAPPED = 4
-
- // CLD_STOPPED indicates that a thread group completed a group stop.
- CLD_STOPPED = 5
-
- // CLD_CONTINUED indicates that a group-stopped thread group was continued.
- CLD_CONTINUED = 6
-
- // SYS_* codes are only meaningful for SIGSYS.
-
- // SYS_SECCOMP indicates that a signal originates from seccomp.
- SYS_SECCOMP = 1
-
- // TRAP_* codes are only meaningful for SIGTRAP.
-
- // TRAP_BRKPT indicates a breakpoint trap.
- TRAP_BRKPT = 1
-)
diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go
deleted file mode 100644
index c732c7503..000000000
--- a/pkg/sentry/arch/signal_stack.go
+++ /dev/null
@@ -1,68 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build 386 amd64 arm64
-
-package arch
-
-import (
- "gvisor.dev/gvisor/pkg/hostarch"
- "gvisor.dev/gvisor/pkg/marshal"
-)
-
-const (
- // SignalStackFlagOnStack is possible set on return from getaltstack,
- // in order to indicate that the thread is currently on the alt stack.
- SignalStackFlagOnStack = 1
-
- // SignalStackFlagDisable is a flag to indicate the stack is disabled.
- SignalStackFlagDisable = 2
-)
-
-// IsEnabled returns true iff this signal stack is marked as enabled.
-func (s SignalStack) IsEnabled() bool {
- return s.Flags&SignalStackFlagDisable == 0
-}
-
-// Top returns the stack's top address.
-func (s SignalStack) Top() hostarch.Addr {
- return hostarch.Addr(s.Addr + s.Size)
-}
-
-// SetOnStack marks this signal stack as in use.
-//
-// Note that there is no corresponding ClearOnStack, and that this should only
-// be called on copies that are serialized to userspace.
-func (s *SignalStack) SetOnStack() {
- s.Flags |= SignalStackFlagOnStack
-}
-
-// Contains checks if the stack pointer is within this stack.
-func (s *SignalStack) Contains(sp hostarch.Addr) bool {
- return hostarch.Addr(s.Addr) < sp && sp <= hostarch.Addr(s.Addr+s.Size)
-}
-
-// NativeSignalStack is a type that is equivalent to stack_t in the guest
-// architecture.
-type NativeSignalStack interface {
- marshal.Marshallable
-
- // SerializeFrom copies the data in the host SignalStack s into this
- // object.
- SerializeFrom(s *SignalStack)
-
- // DeserializeTo copies the data in this object into the host SignalStack
- // s.
- DeserializeTo(s *SignalStack)
-}
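The stack helpers removed above (IsEnabled, Top, SetOnStack, Contains) now live on linux.SignalStack, and the SignalSetup hunks use them unchanged. A small sketch of the alternate-stack bounds check under that assumption:

    // alt is the task's sigaltstack as a *linux.SignalStack.
    if act.Flags&linux.SA_ONSTACK != 0 && alt.IsEnabled() && !alt.Contains(frameBottom) {
        // The frame would not fit on the alternate stack: fail the setup so
        // the caller falls back to default signal delivery.
        return unix.EFAULT
    }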
diff --git a/pkg/sentry/arch/stack.go b/pkg/sentry/arch/stack.go
index 65a794c7c..85e3515af 100644
--- a/pkg/sentry/arch/stack.go
+++ b/pkg/sentry/arch/stack.go
@@ -45,7 +45,7 @@ type Stack struct {
}
// scratchBufLen is the default length of Stack.scratchBuf. The
-// largest structs the stack regularly serializes are arch.SignalInfo
+// largest structs the stack regularly serializes are linux.SignalInfo
// and arch.UContext64. We'll set the default size as the larger of
// the two, arch.UContext64.
var scratchBufLen = (*UContext64)(nil).SizeBytes()
diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go
index 367849e75..221e98a01 100644
--- a/pkg/sentry/control/proc.go
+++ b/pkg/sentry/control/proc.go
@@ -99,6 +99,9 @@ type ExecArgs struct {
// PIDNamespace is the pid namespace for the process being executed.
PIDNamespace *kernel.PIDNamespace
+
+ // Limits is the limit set for the process being executed.
+ Limits *limits.LimitSet
}
// String prints the arguments as a string.
@@ -151,6 +154,10 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
if pidns == nil {
pidns = proc.Kernel.RootPIDNamespace()
}
+ limitSet := args.Limits
+ if limitSet == nil {
+ limitSet = limits.NewLimitSet()
+ }
initArgs := kernel.CreateProcessArgs{
Filename: args.Filename,
Argv: args.Argv,
@@ -161,7 +168,7 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
Credentials: creds,
FDTable: fdTable,
Umask: 0022,
- Limits: limits.NewLimitSet(),
+ Limits: limitSet,
MaxSymlinkTraversals: linux.MaxSymlinkTraversals,
UTSNamespace: proc.Kernel.RootUTSNamespace(),
IPCNamespace: proc.Kernel.RootIPCNamespace(),
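The new ExecArgs.Limits field lets a caller pin the rlimits of an exec'd process; when it is nil, execAsync falls back to limits.NewLimitSet() as before. A minimal sketch, with the filename and argv purely illustrative:

    // Assumed imports: "gvisor.dev/gvisor/pkg/sentry/control", "gvisor.dev/gvisor/pkg/sentry/limits"
    ls := limits.NewLimitSet()
    // Hypothetical: adjust individual limits on ls here (e.g. copy them from
    // the sandbox's init process) before starting the exec.
    args := &control.ExecArgs{
        Filename: "/bin/true",
        Argv:     []string{"/bin/true"},
        Limits:   ls, // leaving this nil keeps the old behavior: a fresh default LimitSet
    }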
diff --git a/pkg/sentry/devices/tundev/BUILD b/pkg/sentry/devices/tundev/BUILD
index 8b38d574d..37229e7ba 100644
--- a/pkg/sentry/devices/tundev/BUILD
+++ b/pkg/sentry/devices/tundev/BUILD
@@ -9,6 +9,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/sentry/arch",
"//pkg/sentry/fsimpl/devtmpfs",
diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go
index a12eeb8e7..4ef91a600 100644
--- a/pkg/sentry/devices/tundev/tundev.go
+++ b/pkg/sentry/devices/tundev/tundev.go
@@ -18,6 +18,7 @@ package tundev
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
@@ -81,7 +82,7 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg
}
stack, ok := t.NetworkContext().(*netstack.Stack)
if !ok {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
var req linux.IFReq
diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD
index 0dc100f9b..74adbfa55 100644
--- a/pkg/sentry/fs/BUILD
+++ b/pkg/sentry/fs/BUILD
@@ -48,6 +48,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/amutex",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/log",
"//pkg/p9",
@@ -110,6 +111,7 @@ go_test(
deps = [
":fs",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/ramfs",
"//pkg/sentry/fs/tmpfs",
diff --git a/pkg/sentry/fs/attr.go b/pkg/sentry/fs/attr.go
index b90f7c1be..4c99944e7 100644
--- a/pkg/sentry/fs/attr.go
+++ b/pkg/sentry/fs/attr.go
@@ -478,6 +478,20 @@ func (f FilePermissions) AnyRead() bool {
return f.User.Read || f.Group.Read || f.Other.Read
}
+// HasSetUIDOrGID returns true if either the setuid or setgid bit is set.
+func (f FilePermissions) HasSetUIDOrGID() bool {
+ return f.SetUID || f.SetGID
+}
+
+// DropSetUIDAndMaybeGID turns off setuid, and turns off setgid if f allows
+// group execution.
+func (f *FilePermissions) DropSetUIDAndMaybeGID() {
+ f.SetUID = false
+ if f.Group.Execute {
+ f.SetGID = false
+ }
+}
+
// FileOwner represents ownership of a file.
//
// +stateify savable
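The two helpers added above encode Linux's rule for writes and truncates: setuid is always stripped, while setgid is stripped only when the group-execute bit is set (setgid without group-execute marks mandatory locking and is left alone). Usage at the call sites below is mechanical:

    if uattr.Perms.HasSetUIDOrGID() {
        uattr.Perms.DropSetUIDAndMaybeGID()
        // ...then write the reduced permissions back to the backing file.
    }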
diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go
index 5aa668873..a8591052c 100644
--- a/pkg/sentry/fs/copy_up.go
+++ b/pkg/sentry/fs/copy_up.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -161,7 +162,7 @@ func doCopyUp(ctx context.Context, d *Dirent) error {
// then try to take copyMu for writing here, we'd deadlock.
t := d.Inode.overlay.lower.StableAttr.Type
if t != RegularFile && t != Directory && t != Symlink {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// Wait to get exclusive access to the upper Inode.
@@ -410,7 +411,7 @@ func copyAttributesLocked(ctx context.Context, upper *Inode, lower *Inode) error
return err
}
lowerXattr, err := lower.ListXattr(ctx, linux.XATTR_SIZE_MAX)
- if err != nil && err != syserror.EOPNOTSUPP {
+ if err != nil && !linuxerr.Equals(linuxerr.EOPNOTSUPP, err) {
return err
}
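Error comparisons migrate from pointer equality against syserror values to linuxerr.Equals, which is intended to match both the new error values and the plain errno values still returned in places. The pattern repeated throughout this change:

    // Before: if err == syserror.ENOENT { ... }
    // After:
    if linuxerr.Equals(linuxerr.ENOENT, err) {
        // treat as "no such entry"
    }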
diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD
index 23a3a9a2d..e28a8961b 100644
--- a/pkg/sentry/fs/dev/BUILD
+++ b/pkg/sentry/fs/dev/BUILD
@@ -18,6 +18,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/rand",
"//pkg/safemem",
diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go
index e84ba7a5d..c62effd52 100644
--- a/pkg/sentry/fs/dev/dev.go
+++ b/pkg/sentry/fs/dev/dev.go
@@ -16,6 +16,7 @@
package dev
import (
+ "fmt"
"math"
"gvisor.dev/gvisor/pkg/context"
@@ -90,6 +91,11 @@ func newSymlink(ctx context.Context, target string, msrc *fs.MountSource) *fs.In
// New returns the root node of a device filesystem.
func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
+ shm, err := tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc, nil /* parent */)
+ if err != nil {
+ panic(fmt.Sprintf("tmpfs.NewDir failed: %v", err))
+ }
+
contents := map[string]*fs.Inode{
"fd": newSymlink(ctx, "/proc/self/fd", msrc),
"stdin": newSymlink(ctx, "/proc/self/fd/0", msrc),
@@ -108,7 +114,7 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
"random": newMemDevice(ctx, newRandomDevice(ctx, fs.RootOwner, 0444), msrc, randomDevMinor),
"urandom": newMemDevice(ctx, newRandomDevice(ctx, fs.RootOwner, 0444), msrc, urandomDevMinor),
- "shm": tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc),
+ "shm": shm,
// A devpts is typically mounted at /dev/pts to provide
// pseudoterminal support. Place an empty directory there for
diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go
index 77e8d222a..5674978bd 100644
--- a/pkg/sentry/fs/dev/net_tun.go
+++ b/pkg/sentry/fs/dev/net_tun.go
@@ -17,6 +17,7 @@ package dev
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -102,7 +103,7 @@ func (n *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io user
}
stack, ok := t.NetworkContext().(*netstack.Stack)
if !ok {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
var req linux.IFReq
diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go
index 9d5d40954..e21c9d78e 100644
--- a/pkg/sentry/fs/dirent.go
+++ b/pkg/sentry/fs/dirent.go
@@ -22,6 +22,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
@@ -963,7 +964,7 @@ func (d *Dirent) mount(ctx context.Context, inode *Inode) (newChild *Dirent, err
//
// See Linux equivalent in fs/namespace.c:do_add_mount.
if IsSymlink(inode.StableAttr) {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
// Dirent that'll replace d.
@@ -1439,7 +1440,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
// replaced is the dirent that is being overwritten by rename.
replaced, err := newParent.walk(ctx, root, newName, false /* may unlock */)
if err != nil {
- if err != syserror.ENOENT {
+ if !linuxerr.Equals(linuxerr.ENOENT, err) {
return err
}
diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD
index 2120f2bad..1bd2055d0 100644
--- a/pkg/sentry/fs/fdpipe/BUILD
+++ b/pkg/sentry/fs/fdpipe/BUILD
@@ -13,6 +13,7 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fd",
"//pkg/fdnotifier",
"//pkg/log",
@@ -38,6 +39,7 @@ go_test(
library = ":fdpipe",
deps = [
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fd",
"//pkg/fdnotifier",
"//pkg/hostarch",
diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go
index 757b7d511..f8a29816b 100644
--- a/pkg/sentry/fs/fdpipe/pipe.go
+++ b/pkg/sentry/fs/fdpipe/pipe.go
@@ -20,6 +20,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/log"
@@ -158,7 +159,7 @@ func (p *pipeOperations) Write(ctx context.Context, file *fs.File, src usermem.I
// isBlockError unwraps os errors and checks if they are caused by EAGAIN or
// EWOULDBLOCK. This is so they can be transformed into syserror.ErrWouldBlock.
func isBlockError(err error) bool {
- if err == syserror.EAGAIN || err == syserror.EWOULDBLOCK {
+ if linuxerr.Equals(linuxerr.EAGAIN, err) || linuxerr.Equals(linuxerr.EWOULDBLOCK, err) {
return true
}
if pe, ok := err.(*os.PathError); ok {
diff --git a/pkg/sentry/fs/fdpipe/pipe_test.go b/pkg/sentry/fs/fdpipe/pipe_test.go
index ab0e9dac7..6ea49cbb7 100644
--- a/pkg/sentry/fs/fdpipe/pipe_test.go
+++ b/pkg/sentry/fs/fdpipe/pipe_test.go
@@ -21,14 +21,14 @@ import (
"testing"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
func singlePipeFD() (int, error) {
@@ -214,7 +214,7 @@ func TestPipeRequest(t *testing.T) {
{
desc: "Fsync on pipe returns EINVAL",
context: &Fsync{},
- err: unix.EINVAL,
+ err: linuxerr.EINVAL,
},
{
desc: "Seek on pipe returns ESPIPE",
diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go
index 696613f3a..7e2f107e0 100644
--- a/pkg/sentry/fs/file_overlay.go
+++ b/pkg/sentry/fs/file_overlay.go
@@ -18,6 +18,7 @@ import (
"io"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -417,7 +418,7 @@ func (f *overlayFileOperations) FifoSize(ctx context.Context, overlayFile *File)
err = f.onTop(ctx, overlayFile, func(file *File, ops FileOperations) error {
sz, ok := ops.(FifoSizer)
if !ok {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
rv, err = sz.FifoSize(ctx, file)
return err
@@ -432,11 +433,11 @@ func (f *overlayFileOperations) SetFifoSize(size int64) (rv int64, err error) {
if f.upper == nil {
// Named pipes cannot be copied up and changes to the lower are prohibited.
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
sz, ok := f.upper.FileOperations.(FifoSizer)
if !ok {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
return sz.SetFifoSize(size)
}
diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD
index 6469cc3a9..ebc90b41f 100644
--- a/pkg/sentry/fs/fsutil/BUILD
+++ b/pkg/sentry/fs/fsutil/BUILD
@@ -76,6 +76,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/log",
"//pkg/safemem",
diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go
index dc9efa5df..c3525ba8e 100644
--- a/pkg/sentry/fs/fsutil/file.go
+++ b/pkg/sentry/fs/fsutil/file.go
@@ -18,6 +18,7 @@ import (
"io"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -63,12 +64,12 @@ func SeekWithDirCursor(ctx context.Context, file *fs.File, whence fs.SeekWhence,
switch inode.StableAttr.Type {
case fs.RegularFile, fs.SpecialFile, fs.BlockDevice:
if offset < 0 {
- return current, syserror.EINVAL
+ return current, linuxerr.EINVAL
}
return offset, nil
case fs.Directory, fs.SpecialDirectory:
if offset != 0 {
- return current, syserror.EINVAL
+ return current, linuxerr.EINVAL
}
// SEEK_SET to 0 moves the directory "cursor" to the beginning.
if dirCursor != nil {
@@ -76,22 +77,22 @@ func SeekWithDirCursor(ctx context.Context, file *fs.File, whence fs.SeekWhence,
}
return 0, nil
default:
- return current, syserror.EINVAL
+ return current, linuxerr.EINVAL
}
case fs.SeekCurrent:
switch inode.StableAttr.Type {
case fs.RegularFile, fs.SpecialFile, fs.BlockDevice:
if current+offset < 0 {
- return current, syserror.EINVAL
+ return current, linuxerr.EINVAL
}
return current + offset, nil
case fs.Directory, fs.SpecialDirectory:
if offset != 0 {
- return current, syserror.EINVAL
+ return current, linuxerr.EINVAL
}
return current, nil
default:
- return current, syserror.EINVAL
+ return current, linuxerr.EINVAL
}
case fs.SeekEnd:
switch inode.StableAttr.Type {
@@ -103,14 +104,14 @@ func SeekWithDirCursor(ctx context.Context, file *fs.File, whence fs.SeekWhence,
}
sz := uattr.Size
if sz+offset < 0 {
- return current, syserror.EINVAL
+ return current, linuxerr.EINVAL
}
return sz + offset, nil
// FIXME(b/34778850): This is not universally correct.
// Remove SpecialDirectory.
case fs.SpecialDirectory:
if offset != 0 {
- return current, syserror.EINVAL
+ return current, linuxerr.EINVAL
}
// SEEK_END to 0 moves the directory "cursor" to the end.
//
@@ -121,12 +122,12 @@ func SeekWithDirCursor(ctx context.Context, file *fs.File, whence fs.SeekWhence,
// futile (EOF will always be the result).
return fs.FileMaxOffset, nil
default:
- return current, syserror.EINVAL
+ return current, linuxerr.EINVAL
}
}
// Not a valid seek request.
- return current, syserror.EINVAL
+ return current, linuxerr.EINVAL
}
// FileGenericSeek implements fs.FileOperations.Seek for files that use a
@@ -152,7 +153,7 @@ type FileNoSeek struct{}
// Seek implements fs.FileOperations.Seek.
func (FileNoSeek) Seek(context.Context, *fs.File, fs.SeekWhence, int64) (int64, error) {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// FilePipeSeek implements fs.FileOperations.Seek and can be used for files
@@ -178,7 +179,7 @@ type FileNoFsync struct{}
// Fsync implements fs.FileOperations.Fsync.
func (FileNoFsync) Fsync(context.Context, *fs.File, int64, int64, fs.SyncType) error {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// FileNoopFsync implements fs.FileOperations.Fsync for files that don't need
@@ -345,7 +346,7 @@ func NewFileStaticContentReader(b []byte) FileStaticContentReader {
// Read implements fs.FileOperations.Read.
func (scr *FileStaticContentReader) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if offset >= int64(len(scr.content)) {
return 0, nil
@@ -367,7 +368,7 @@ type FileNoRead struct{}
// Read implements fs.FileOperations.Read.
func (FileNoRead) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// FileNoWrite implements fs.FileOperations.Write to return EINVAL.
@@ -375,7 +376,7 @@ type FileNoWrite struct{}
// Write implements fs.FileOperations.Write.
func (FileNoWrite) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// FileNoopRead implements fs.FileOperations.Read as a noop.
diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go
index e1e38b498..8ac3738e9 100644
--- a/pkg/sentry/fs/fsutil/host_mappable.go
+++ b/pkg/sentry/fs/fsutil/host_mappable.go
@@ -155,12 +155,20 @@ func (h *HostMappable) DecRef(fr memmap.FileRange) {
// T2: Appends to file causing it to grow
// T2: Writes to mapped pages and COW happens
// T1: Continues and wrongly invalidates the page mapped in step above.
-func (h *HostMappable) Truncate(ctx context.Context, newSize int64) error {
+func (h *HostMappable) Truncate(ctx context.Context, newSize int64, uattr fs.UnstableAttr) error {
h.truncateMu.Lock()
defer h.truncateMu.Unlock()
mask := fs.AttrMask{Size: true}
attr := fs.UnstableAttr{Size: newSize}
+
+ // Truncating a file clears privilege bits.
+ if uattr.Perms.HasSetUIDOrGID() {
+ mask.Perms = true
+ attr.Perms = uattr.Perms
+ attr.Perms.DropSetUIDAndMaybeGID()
+ }
+
if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr, false); err != nil {
return err
}
@@ -193,10 +201,17 @@ func (h *HostMappable) Allocate(ctx context.Context, offset int64, length int64)
}
// Write writes to the file backing this mappable.
-func (h *HostMappable) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+func (h *HostMappable) Write(ctx context.Context, src usermem.IOSequence, offset int64, uattr fs.UnstableAttr) (int64, error) {
h.truncateMu.RLock()
+ defer h.truncateMu.RUnlock()
n, err := src.CopyInTo(ctx, &writer{ctx: ctx, hostMappable: h, off: offset})
- h.truncateMu.RUnlock()
+ if n > 0 && uattr.Perms.HasSetUIDOrGID() {
+ mask := fs.AttrMask{Perms: true}
+ uattr.Perms.DropSetUIDAndMaybeGID()
+ if err := h.backingFile.SetMaskedAttributes(ctx, mask, uattr, false); err != nil {
+ return n, err
+ }
+ }
return n, err
}
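Both HostMappable.Truncate and HostMappable.Write now take the file's UnstableAttr so they can drop the privilege bits as part of a successful size change or write. Callers therefore fetch the attributes first, roughly like this (unstableAttr stands in for whatever accessor the caller has, as the gofer inode code below illustrates):

    uattr, err := unstableAttr(ctx) // hypothetical accessor for this sketch
    if err != nil {
        return err
    }
    if err := hostMappable.Truncate(ctx, newSize, uattr); err != nil {
        return err
    }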
diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go
index 85e7e35db..bda07275d 100644
--- a/pkg/sentry/fs/fsutil/inode.go
+++ b/pkg/sentry/fs/fsutil/inode.go
@@ -17,6 +17,7 @@ package fsutil
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fs"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -376,7 +377,7 @@ func (InodeNotDirectory) RemoveDirectory(context.Context, *fs.Inode, string) err
// Rename implements fs.FileOperations.Rename.
func (InodeNotDirectory) Rename(context.Context, *fs.Inode, *fs.Inode, string, *fs.Inode, string, bool) error {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// InodeNotSocket can be used by Inodes that are not sockets.
@@ -392,7 +393,7 @@ type InodeNotTruncatable struct{}
// Truncate implements fs.InodeOperations.Truncate.
func (InodeNotTruncatable) Truncate(context.Context, *fs.Inode, int64) error {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// InodeIsDirTruncate implements fs.InodeOperations.Truncate for directories.
@@ -416,7 +417,7 @@ type InodeNotRenameable struct{}
// Rename implements fs.InodeOperations.Rename.
func (InodeNotRenameable) Rename(context.Context, *fs.Inode, *fs.Inode, string, *fs.Inode, string, bool) error {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// InodeNotOpenable can be used by Inodes that cannot be opened.
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index 7856b354b..855029b84 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -310,6 +310,12 @@ func (c *CachingInodeOperations) Truncate(ctx context.Context, inode *fs.Inode,
now := ktime.NowFromContext(ctx)
masked := fs.AttrMask{Size: true}
attr := fs.UnstableAttr{Size: size}
+ if c.attr.Perms.HasSetUIDOrGID() {
+ masked.Perms = true
+ attr.Perms = c.attr.Perms
+ attr.Perms.DropSetUIDAndMaybeGID()
+ c.attr.Perms = attr.Perms
+ }
if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr, false); err != nil {
c.dataMu.Unlock()
return err
@@ -685,13 +691,14 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
return done, nil
}
-// maybeGrowFile grows the file's size if data has been written past the old
-// size.
+// maybeUpdateAttrs updates the file's attributes after a write. It updates
+// size if data has been written past the old size, and clears setuid/setgid
+// if any bytes were written.

//
// Preconditions:
// * rw.c.attrMu must be locked.
// * rw.c.dataMu must be locked.
-func (rw *inodeReadWriter) maybeGrowFile() {
+func (rw *inodeReadWriter) maybeUpdateAttrs(nwritten uint64) {
// If the write ends beyond the file's previous size, it causes the
// file to grow.
if rw.offset > rw.c.attr.Size {
@@ -705,6 +712,12 @@ func (rw *inodeReadWriter) maybeGrowFile() {
rw.c.attr.Usage = rw.offset
rw.c.dirtyAttr.Usage = true
}
+
+ // If bytes were written, ensure setuid and setgid are cleared.
+ if nwritten > 0 && rw.c.attr.Perms.HasSetUIDOrGID() {
+ rw.c.dirtyAttr.Perms = true
+ rw.c.attr.Perms.DropSetUIDAndMaybeGID()
+ }
}
// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
@@ -732,7 +745,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error
segMR := seg.Range().Intersect(mr)
ims, err := mf.MapInternal(seg.FileRangeOf(segMR), hostarch.Write)
if err != nil {
- rw.maybeGrowFile()
+ rw.maybeUpdateAttrs(done)
rw.c.dataMu.Unlock()
return done, err
}
@@ -744,7 +757,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error
srcs = srcs.DropFirst64(n)
rw.c.dirty.MarkDirty(segMR)
if err != nil {
- rw.maybeGrowFile()
+ rw.maybeUpdateAttrs(done)
rw.c.dataMu.Unlock()
return done, err
}
@@ -765,7 +778,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error
srcs = srcs.DropFirst64(n)
// Partial writes are fine. But we must stop writing.
if n != src.NumBytes() || err != nil {
- rw.maybeGrowFile()
+ rw.maybeUpdateAttrs(done)
rw.c.dataMu.Unlock()
return done, err
}
@@ -774,7 +787,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error
seg, gap = gap.NextSegment(), FileRangeGapIterator{}
}
}
- rw.maybeGrowFile()
+ rw.maybeUpdateAttrs(done)
rw.c.dataMu.Unlock()
return done, nil
}
diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD
index 94cb05246..c08301d19 100644
--- a/pkg/sentry/fs/gofer/BUILD
+++ b/pkg/sentry/fs/gofer/BUILD
@@ -26,6 +26,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fd",
"//pkg/hostarch",
"//pkg/log",
diff --git a/pkg/sentry/fs/gofer/cache_policy.go b/pkg/sentry/fs/gofer/cache_policy.go
index 07a564e92..f8b7a60fc 100644
--- a/pkg/sentry/fs/gofer/cache_policy.go
+++ b/pkg/sentry/fs/gofer/cache_policy.go
@@ -139,7 +139,7 @@ func (cp cachePolicy) revalidate(ctx context.Context, name string, parent, child
// Walk from parent to child again.
//
- // TODO(b/112031682): If we have a directory FD in the parent
+ // NOTE(b/112031682): If we have a directory FD in the parent
// inodeOperations, then we can use fstatat(2) to get the inode
// attributes instead of making this RPC.
qids, f, mask, attr, err := parentIops.fileState.file.walkGetAttr(ctx, []string{name})
diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go
index bcdb2dda2..73d80d9b5 100644
--- a/pkg/sentry/fs/gofer/file.go
+++ b/pkg/sentry/fs/gofer/file.go
@@ -92,7 +92,6 @@ func NewFile(ctx context.Context, dirent *fs.Dirent, name string, flags fs.FileF
}
if flags.Write {
if err := dirent.Inode.CheckPermission(ctx, fs.PermMask{Execute: true}); err == nil {
- fsmetric.GoferOpensWX.Increment()
metric.SuspiciousOperationsMetric.Increment("opened_write_execute_file")
log.Warningf("Opened a writable executable: %q", name)
}
@@ -238,10 +237,20 @@ func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.I
// and availability of a host-mappable FD.
if f.inodeOperations.session().cachePolicy.useCachingInodeOps(file.Dirent.Inode) {
n, err = f.inodeOperations.cachingInodeOps.Write(ctx, src, offset)
- } else if f.inodeOperations.fileState.hostMappable != nil {
- n, err = f.inodeOperations.fileState.hostMappable.Write(ctx, src, offset)
} else {
- n, err = src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset))
+ uattr, e := f.UnstableAttr(ctx, file)
+ if e != nil {
+ return 0, e
+ }
+ if f.inodeOperations.fileState.hostMappable != nil {
+ n, err = f.inodeOperations.fileState.hostMappable.Write(ctx, src, offset, uattr)
+ } else {
+ n, err = src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset))
+ if n > 0 && uattr.Perms.HasSetUIDOrGID() {
+ uattr.Perms.DropSetUIDAndMaybeGID()
+ f.inodeOperations.SetPermissions(ctx, file.Dirent.Inode, uattr.Perms)
+ }
+ }
}
if n == 0 {
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index b97635ec4..da3178527 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -600,11 +600,25 @@ func (i *inodeOperations) Truncate(ctx context.Context, inode *fs.Inode, length
if i.session().cachePolicy.useCachingInodeOps(inode) {
return i.cachingInodeOps.Truncate(ctx, inode, length)
}
+
+ uattr, err := i.fileState.unstableAttr(ctx)
+ if err != nil {
+ return err
+ }
+
if i.session().cachePolicy == cacheRemoteRevalidating {
- return i.fileState.hostMappable.Truncate(ctx, length)
+ return i.fileState.hostMappable.Truncate(ctx, length, uattr)
+ }
+
+ mask := p9.SetAttrMask{Size: true}
+ attr := p9.SetAttr{Size: uint64(length)}
+ if uattr.Perms.HasSetUIDOrGID() {
+ mask.Permissions = true
+ uattr.Perms.DropSetUIDAndMaybeGID()
+ attr.Permissions = p9.FileMode(uattr.Perms.LinuxMode())
}
- return i.fileState.file.setAttr(ctx, p9.SetAttrMask{Size: true}, p9.SetAttr{Size: uint64(length)})
+ return i.fileState.file.setAttr(ctx, mask, attr)
}
// GetXattr implements fs.InodeOperations.GetXattr.
diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go
index 6b3627813..1a6f353d0 100644
--- a/pkg/sentry/fs/gofer/path.go
+++ b/pkg/sentry/fs/gofer/path.go
@@ -18,6 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/device"
@@ -66,7 +67,7 @@ func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string
// Get a p9.File for name.
qids, newFile, mask, p9attr, err := i.fileState.file.walkGetAttr(ctx, []string{name})
if err != nil {
- if err == syserror.ENOENT {
+ if linuxerr.Equals(linuxerr.ENOENT, err) {
if cp.cacheNegativeDirents() {
// Return a negative Dirent. It will stay cached until something
// is created over it.
@@ -130,7 +131,16 @@ func (i *inodeOperations) Create(ctx context.Context, dir *fs.Inode, name string
panic(fmt.Sprintf("Create called with unknown or unset open flags: %v", flags))
}
+ // If the parent directory has setgid enabled, change the new file's owner.
owner := fs.FileOwnerFromContext(ctx)
+ parentUattr, err := dir.UnstableAttr(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if parentUattr.Perms.SetGID {
+ owner.GID = parentUattr.Owner.GID
+ }
+
hostFile, err := newFile.create(ctx, name, openFlags, p9.FileMode(perm.LinuxMode()), p9.UID(owner.UID), p9.GID(owner.GID))
if err != nil {
// Could not create the file.
@@ -225,7 +235,18 @@ func (i *inodeOperations) CreateDirectory(ctx context.Context, dir *fs.Inode, s
return syserror.ENAMETOOLONG
}
+ // If the parent directory has setgid enabled, change the new directory's
+ // owner and enable setgid.
owner := fs.FileOwnerFromContext(ctx)
+ parentUattr, err := dir.UnstableAttr(ctx)
+ if err != nil {
+ return err
+ }
+ if parentUattr.Perms.SetGID {
+ owner.GID = parentUattr.Owner.GID
+ perm.SetGID = true
+ }
+
if _, err := i.fileState.file.mkdir(ctx, s, p9.FileMode(perm.LinuxMode()), p9.UID(owner.UID), p9.GID(owner.GID)); err != nil {
return err
}
@@ -278,7 +299,7 @@ func (i *inodeOperations) CreateFifo(ctx context.Context, dir *fs.Inode, name st
// N.B. FIFOs use major/minor numbers 0.
if _, err := i.fileState.file.mknod(ctx, name, mode, 0, 0, p9.UID(owner.UID), p9.GID(owner.GID)); err != nil {
- if i.session().overrides == nil || err != syserror.EPERM {
+ if i.session().overrides == nil || !linuxerr.Equals(linuxerr.EPERM, err) {
return err
}
// If gofer doesn't support mknod, check if we can create an internal fifo.
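The Create and CreateDirectory hunks above implement setgid directory semantics: a file created inside a setgid directory inherits the directory's group, and a new subdirectory additionally inherits the setgid bit itself. Condensed (creatingDirectory is illustrative only):

    parentUattr, err := dir.UnstableAttr(ctx)
    if err != nil {
        return err
    }
    if parentUattr.Perms.SetGID {
        owner.GID = parentUattr.Owner.GID // the new node takes the directory's group
        if creatingDirectory {
            perm.SetGID = true // subdirectories also inherit the setgid bit itself
        }
    }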
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index 3c45f6cc5..24fc6305c 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -28,9 +28,9 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fd",
"//pkg/fdnotifier",
- "//pkg/iovec",
"//pkg/log",
"//pkg/marshal/primitive",
"//pkg/refs",
@@ -40,6 +40,7 @@ go_library(
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/hostfd",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index 46a2dc47d..225244868 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -21,6 +21,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/refs"
@@ -213,7 +214,7 @@ func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMess
// block (and only for stream sockets).
err = syserror.EAGAIN
}
- if n > 0 && err != syserror.EAGAIN {
+ if n > 0 && !linuxerr.Equals(linuxerr.EAGAIN, err) {
// The caller may need to block to send more data, but
// otherwise there isn't anything that can be done about an
// error with a partial write.
diff --git a/pkg/sentry/fs/host/socket_iovec.go b/pkg/sentry/fs/host/socket_iovec.go
index 7380d75e7..fd48aff11 100644
--- a/pkg/sentry/fs/host/socket_iovec.go
+++ b/pkg/sentry/fs/host/socket_iovec.go
@@ -16,7 +16,7 @@ package host
import (
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/iovec"
+ "gvisor.dev/gvisor/pkg/sentry/hostfd"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -72,7 +72,7 @@ func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovec
}
}
- if iovsRequired > iovec.MaxIovs {
+ if iovsRequired > hostfd.MaxSendRecvMsgIov {
// The kernel will reject our call if we pass this many iovs.
// Use a single intermediate buffer instead.
b := make([]byte, stopLen)
diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go
index 1183727ab..2ff520100 100644
--- a/pkg/sentry/fs/host/tty.go
+++ b/pkg/sentry/fs/host/tty.go
@@ -17,6 +17,7 @@ package host
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -191,7 +192,7 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO
if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
// drivers/tty/tty_io.c:tiocspgrp() converts -EIO from
// tty_check_change() to -ENOTTY.
- if err == syserror.EIO {
+ if linuxerr.Equals(linuxerr.EIO, err) {
return 0, syserror.ENOTTY
}
return 0, err
@@ -211,7 +212,7 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO
// pgID must be non-negative.
if pgID < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Process group with pgID must exist in this PID namespace.
diff --git a/pkg/sentry/fs/host/util.go b/pkg/sentry/fs/host/util.go
index ab74724a3..e7db79189 100644
--- a/pkg/sentry/fs/host/util.go
+++ b/pkg/sentry/fs/host/util.go
@@ -19,12 +19,12 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/syserror"
)
func nodeType(s *unix.Stat_t) fs.InodeType {
@@ -98,7 +98,7 @@ type dirInfo struct {
// isBlockError unwraps os errors and checks if they are caused by EAGAIN or
// EWOULDBLOCK. This is so they can be transformed into syserror.ErrWouldBlock.
func isBlockError(err error) bool {
- if err == syserror.EAGAIN || err == syserror.EWOULDBLOCK {
+ if linuxerr.Equals(linuxerr.EAGAIN, err) || linuxerr.Equals(linuxerr.EWOULDBLOCK, err) {
return true
}
if pe, ok := err.(*os.PathError); ok {
diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go
index e97afc626..bd1125dcc 100644
--- a/pkg/sentry/fs/inode_overlay.go
+++ b/pkg/sentry/fs/inode_overlay.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
@@ -71,7 +72,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name
// A file could have been created over a whiteout, so we need to
// check if something exists in the upper file system first.
child, err := parent.upper.Lookup(ctx, name)
- if err != nil && err != syserror.ENOENT {
+ if err != nil && !linuxerr.Equals(linuxerr.ENOENT, err) {
// We encountered an error that an overlay cannot handle,
// we must propagate it to the caller.
parent.copyMu.RUnlock()
@@ -125,7 +126,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name
// Check the lower file system.
child, err := parent.lower.Lookup(ctx, name)
// Same song and dance as above.
- if err != nil && err != syserror.ENOENT {
+ if err != nil && !linuxerr.Equals(linuxerr.ENOENT, err) {
// Don't leak resources.
if upperInode != nil {
upperInode.DecRef(ctx)
@@ -396,7 +397,7 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena
// newName has been removed out from under us. That's fine;
// filesystems where that can happen must handle stale
// 'replaced'.
- if err != nil && err != syserror.ENOENT {
+ if err != nil && !linuxerr.Equals(linuxerr.ENOENT, err) {
return err
}
if err == nil {
diff --git a/pkg/sentry/fs/inode_overlay_test.go b/pkg/sentry/fs/inode_overlay_test.go
index aa9851b26..cc5ffa6f1 100644
--- a/pkg/sentry/fs/inode_overlay_test.go
+++ b/pkg/sentry/fs/inode_overlay_test.go
@@ -18,6 +18,7 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
@@ -191,11 +192,11 @@ func TestLookup(t *testing.T) {
} {
t.Run(test.desc, func(t *testing.T) {
dirent, err := test.dir.Lookup(ctx, test.name)
- if test.found && (err == syserror.ENOENT || dirent.IsNegative()) {
+ if test.found && (linuxerr.Equals(linuxerr.ENOENT, err) || dirent.IsNegative()) {
t.Fatalf("lookup %q expected to find positive dirent, got dirent %v err %v", test.name, dirent, err)
}
if !test.found {
- if err != syserror.ENOENT && !dirent.IsNegative() {
+ if !linuxerr.Equals(linuxerr.ENOENT, err) && !dirent.IsNegative() {
t.Errorf("lookup %q expected to return ENOENT or negative dirent, got dirent %v err %v", test.name, dirent, err)
}
// Nothing more to check.
diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go
index 1b83643db..4e07043c7 100644
--- a/pkg/sentry/fs/inotify.go
+++ b/pkg/sentry/fs/inotify.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -132,7 +133,7 @@ func (*Inotify) Write(context.Context, *File, usermem.IOSequence, int64) (int64,
// Read implements FileOperations.Read.
func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ int64) (int64, error) {
if dst.NumBytes() < inotifyEventBaseSize {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
i.evMu.Lock()
@@ -156,7 +157,7 @@ func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ i
// write some events out.
return writeLen, nil
}
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Linux always dequeues an available event as long as there's enough
@@ -183,7 +184,7 @@ func (*Inotify) WriteTo(context.Context, *File, io.Writer, int64, bool) (int64,
// Fsync implements FileOperations.Fsync.
func (*Inotify) Fsync(context.Context, *File, int64, int64, SyncType) error {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// ReadFrom implements FileOperations.ReadFrom.
@@ -329,7 +330,7 @@ func (i *Inotify) RmWatch(ctx context.Context, wd int32) error {
watch, ok := i.watches[wd]
if !ok {
i.mu.Unlock()
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// Remove the watch from this instance.
diff --git a/pkg/sentry/fs/mock.go b/pkg/sentry/fs/mock.go
index 1d6ea5736..2a54c1242 100644
--- a/pkg/sentry/fs/mock.go
+++ b/pkg/sentry/fs/mock.go
@@ -16,6 +16,7 @@ package fs
import (
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -109,7 +110,7 @@ func (n *MockInodeOperations) SetPermissions(context.Context, *Inode, FilePermis
// SetOwner implements fs.InodeOperations.SetOwner.
func (*MockInodeOperations) SetOwner(context.Context, *Inode, FileOwner) error {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// SetTimestamps implements fs.InodeOperations.SetTimestamps.
diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go
index 243098a09..340441974 100644
--- a/pkg/sentry/fs/mounts.go
+++ b/pkg/sentry/fs/mounts.go
@@ -20,6 +20,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sync"
@@ -357,7 +358,7 @@ func (mns *MountNamespace) Unmount(ctx context.Context, node *Dirent, detachOnly
orig, ok := mns.mounts[node]
if !ok {
// node is not a mount point.
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if orig.previous == nil {
diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go
index f96f5a3e5..7e72e47b5 100644
--- a/pkg/sentry/fs/overlay.go
+++ b/pkg/sentry/fs/overlay.go
@@ -19,11 +19,11 @@ import (
"strings"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
// The virtual filesystem implements an overlay configuration. For a high-level
@@ -218,7 +218,7 @@ func newOverlayEntry(ctx context.Context, upper *Inode, lower *Inode, lowerExist
// We don't support copying up from character devices,
// named pipes, or anything weird (like proc files).
log.Warningf("%s not supported in lower filesytem", lower.StableAttr.Type)
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
}
return &overlayEntry{
diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD
index 7af7e0b45..e6d74b949 100644
--- a/pkg/sentry/fs/proc/BUILD
+++ b/pkg/sentry/fs/proc/BUILD
@@ -30,6 +30,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/log",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/fs/proc/exec_args.go b/pkg/sentry/fs/proc/exec_args.go
index 24426b225..379429ab2 100644
--- a/pkg/sentry/fs/proc/exec_args.go
+++ b/pkg/sentry/fs/proc/exec_args.go
@@ -21,11 +21,11 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -104,7 +104,7 @@ var _ fs.FileOperations = (*execArgFile)(nil)
// Read reads the exec arg from the process's address space.
func (f *execArgFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
m, err := getTaskMM(f.t)
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
index 91c35eea9..187e9a921 100644
--- a/pkg/sentry/fs/proc/net.go
+++ b/pkg/sentry/fs/proc/net.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -34,7 +35,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -291,7 +291,7 @@ func (n *netSnmp) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s
continue
}
if err := n.s.Statistics(stat, line.prefix); err != nil {
- if err == syserror.EOPNOTSUPP {
+ if linuxerr.Equals(linuxerr.EOPNOTSUPP, err) {
log.Infof("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err)
} else {
log.Warningf("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err)
diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go
index 2f2a9f920..546b57287 100644
--- a/pkg/sentry/fs/proc/proc.go
+++ b/pkg/sentry/fs/proc/proc.go
@@ -21,6 +21,7 @@ import (
"strconv"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
@@ -130,7 +131,7 @@ func (s *self) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
}
// Who is reading this link?
- return "", syserror.EINVAL
+ return "", linuxerr.EINVAL
}
// threadSelf is more magical than "self" link.
@@ -154,7 +155,7 @@ func (s *threadSelf) Readlink(ctx context.Context, inode *fs.Inode) (string, err
}
// Who is reading this link?
- return "", syserror.EINVAL
+ return "", linuxerr.EINVAL
}
// Lookup loads an Inode at name into a Dirent.
diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go
index b998fb75d..085aa6d61 100644
--- a/pkg/sentry/fs/proc/sys.go
+++ b/pkg/sentry/fs/proc/sys.go
@@ -77,6 +77,27 @@ func (*overcommitMemory) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandl
}, 0
}
+// +stateify savable
+type maxMapCount struct{}
+
+// NeedsUpdate implements seqfile.SeqSource.
+func (*maxMapCount) NeedsUpdate(int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.
+func (*maxMapCount) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+ return []seqfile.SeqData{
+ {
+ Buf: []byte("2147483647\n"),
+ Handle: (*maxMapCount)(nil),
+ },
+ }, 0
+}
+
func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
h := hostname{
SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
@@ -96,6 +117,7 @@ func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode
func (p *proc) newVMDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
children := map[string]*fs.Inode{
+ "max_map_count": seqfile.NewSeqFileInode(ctx, &maxMapCount{}, msrc),
"mmap_min_addr": seqfile.NewSeqFileInode(ctx, &mmapMinAddrData{p.k}, msrc),
"overcommit_memory": seqfile.NewSeqFileInode(ctx, &overcommitMemory{}, msrc),
}
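
The new max_map_count entry above is a read-only seqfile source that always reports a single fixed record. A hypothetical sibling entry following the same pattern would look like the sketch below; the name dirtyRatioData and the value are made up, and the snippet assumes it sits in the same sys.go file so it can reuse that file's context and seqfile imports.

// +stateify savable
type dirtyRatioData struct{}

// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
func (*dirtyRatioData) NeedsUpdate(int64) bool {
	return true
}

// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
func (*dirtyRatioData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
	if h != nil {
		// A non-nil handle means the reader already consumed the single
		// record below; there is nothing more to emit.
		return nil, 0
	}
	return []seqfile.SeqData{
		{
			Buf:    []byte("20\n"),
			Handle: (*dirtyRatioData)(nil),
		},
	}, 0
}

It would then be registered in newVMDir's children map exactly like max_map_count, via seqfile.NewSeqFileInode(ctx, &dirtyRatioData{}, msrc).
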
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
index 4893af56b..71f37d582 100644
--- a/pkg/sentry/fs/proc/sys_net.go
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -28,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
@@ -592,7 +592,7 @@ func (pf *portRangeFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSe
// Port numbers must be uint16s.
if ports[0] < 0 || ports[1] < 0 || ports[0] > math.MaxUint16 || ports[1] > math.MaxUint16 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if err := pf.inode.stack.SetPortRange(uint16(ports[0]), uint16(ports[1])); err != nil {
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index ae5ed25f9..7ece1377a 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -867,7 +868,7 @@ var _ fs.FileOperations = (*commFile)(nil)
// Read implements fs.FileOperations.Read.
func (f *commFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
buf := []byte(f.t.Name() + "\n")
@@ -922,7 +923,7 @@ type auxvecFile struct {
// Read implements fs.FileOperations.Read.
func (f *auxvecFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
m, err := getTaskMM(f.t)
diff --git a/pkg/sentry/fs/proc/uid_gid_map.go b/pkg/sentry/fs/proc/uid_gid_map.go
index 30d5ad4cf..fcdc1e7bd 100644
--- a/pkg/sentry/fs/proc/uid_gid_map.go
+++ b/pkg/sentry/fs/proc/uid_gid_map.go
@@ -21,12 +21,12 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -108,7 +108,7 @@ const maxIDMapLines = 5
// Read implements fs.FileOperations.Read.
func (imfo *idMapFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
var entries []auth.IDMapEntry
if imfo.iops.gids {
@@ -134,7 +134,7 @@ func (imfo *idMapFileOperations) Write(ctx context.Context, file *fs.File, src u
// the file ..." - user_namespaces(7)
srclen := src.NumBytes()
if srclen >= hostarch.PageSize || offset != 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
b := make([]byte, srclen)
if _, err := src.CopyIn(ctx, b); err != nil {
@@ -154,7 +154,7 @@ func (imfo *idMapFileOperations) Write(ctx context.Context, file *fs.File, src u
}
lines := bytes.SplitN(b, []byte("\n"), maxIDMapLines+1)
if len(lines) > maxIDMapLines {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
entries := make([]auth.IDMapEntry, len(lines))
@@ -162,7 +162,7 @@ func (imfo *idMapFileOperations) Write(ctx context.Context, file *fs.File, src u
var e auth.IDMapEntry
_, err := fmt.Sscan(string(l), &e.FirstID, &e.FirstParentID, &e.Length)
if err != nil {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
entries[i] = e
}
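
The Write path above enforces the uid_map/gid_map format: at most one page written at offset 0, at most maxIDMapLines lines, and three integers per line parsed with fmt.Sscan, with EINVAL on any violation. Here is a standalone sketch of just the parsing step, using a local stand-in for auth.IDMapEntry; the trailing-newline trim is an addition for this sketch, not taken from the hunk.

package main

import (
	"bytes"
	"fmt"
)

const maxIDMapLines = 5

// idMapEntry is a local stand-in for auth.IDMapEntry.
type idMapEntry struct {
	FirstID, FirstParentID, Length uint32
}

func parseIDMap(b []byte) ([]idMapEntry, error) {
	// Drop one trailing newline so a final "\n" does not produce an empty
	// line (an addition for this standalone sketch).
	if n := len(b); n > 0 && b[n-1] == '\n' {
		b = b[:n-1]
	}
	lines := bytes.SplitN(b, []byte("\n"), maxIDMapLines+1)
	if len(lines) > maxIDMapLines {
		return nil, fmt.Errorf("more than %d lines", maxIDMapLines) // EINVAL in the sentry
	}
	entries := make([]idMapEntry, len(lines))
	for i, l := range lines {
		var e idMapEntry
		if _, err := fmt.Sscan(string(l), &e.FirstID, &e.FirstParentID, &e.Length); err != nil {
			return nil, fmt.Errorf("malformed line %q: %v", l, err) // EINVAL in the sentry
		}
		entries[i] = e
	}
	return entries, nil
}

func main() {
	entries, err := parseIDMap([]byte("0 1000 1\n1 100000 65536\n"))
	fmt.Println(entries, err)
}
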
diff --git a/pkg/sentry/fs/proc/uptime.go b/pkg/sentry/fs/proc/uptime.go
index c0f6fb802..ac896f963 100644
--- a/pkg/sentry/fs/proc/uptime.go
+++ b/pkg/sentry/fs/proc/uptime.go
@@ -20,10 +20,10 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -74,7 +74,7 @@ type uptimeFile struct {
// Read implements fs.FileOperations.Read.
func (f *uptimeFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
now := ktime.NowFromContext(ctx)
diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go
index 33da82868..ca9f645f6 100644
--- a/pkg/sentry/fs/splice.go
+++ b/pkg/sentry/fs/splice.go
@@ -19,6 +19,7 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -139,7 +140,7 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64,
// Attempt to do a WriteTo; this is likely the most efficient.
n, err := src.FileOperations.WriteTo(ctx, src, w, opts.Length, opts.Dup)
- if n == 0 && err == syserror.ENOSYS && !opts.Dup {
+ if n == 0 && linuxerr.Equals(linuxerr.ENOSYS, err) && !opts.Dup {
// Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also be
// more efficient than a copy if buffers are cached or readily
// available. (It's unlikely that they can actually be donated).
@@ -151,7 +152,7 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64,
// if we block at some point, we could lose data. If the source is
// not a pipe then reading is not destructive; if the destination
// is a regular file, then it is guaranteed not to block writing.
- if n == 0 && err == syserror.ENOSYS && !opts.Dup && (!dstPipe || !srcPipe) {
+ if n == 0 && linuxerr.Equals(linuxerr.ENOSYS, err) && !opts.Dup && (!dstPipe || !srcPipe) {
// Fallback to an in-kernel copy.
n, err = io.Copy(w, &io.LimitedReader{
R: r,
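
Splice above tries the fastest path first: the source's WriteTo, then the destination's ReadFrom, and only then an in-kernel copy, with the fallbacks gated on ENOSYS and !opts.Dup. The sketch below shows the same ordering using plain io interfaces; it decides on a fast path with type assertions rather than ENOSYS, and it omits the pipe and length bookkeeping of the sentry code.

package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"
)

// splice prefers the source's WriteTo, then the destination's ReadFrom, and
// finally falls back to a bounded copy through an intermediate buffer.
func splice(dst io.Writer, src io.Reader, limit int64) (int64, error) {
	if wt, ok := src.(io.WriterTo); ok {
		// Length limiting is omitted on this path for brevity.
		return wt.WriteTo(dst)
	}
	if rf, ok := dst.(io.ReaderFrom); ok {
		return rf.ReadFrom(io.LimitReader(src, limit))
	}
	return io.Copy(dst, &io.LimitedReader{R: src, N: limit})
}

func main() {
	var buf bytes.Buffer
	n, err := splice(&buf, strings.NewReader("hello"), 16)
	fmt.Println(n, err, buf.String())
}
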
diff --git a/pkg/sentry/fs/timerfd/BUILD b/pkg/sentry/fs/timerfd/BUILD
index c7977a217..0148b33cf 100644
--- a/pkg/sentry/fs/timerfd/BUILD
+++ b/pkg/sentry/fs/timerfd/BUILD
@@ -8,6 +8,7 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/sentry/fs",
"//pkg/sentry/fs/anon",
diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go
index c8ebe256c..093a14c1f 100644
--- a/pkg/sentry/fs/timerfd/timerfd.go
+++ b/pkg/sentry/fs/timerfd/timerfd.go
@@ -20,6 +20,7 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
@@ -121,7 +122,7 @@ func (t *TimerOperations) EventUnregister(e *waiter.Entry) {
func (t *TimerOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
const sizeofUint64 = 8
if dst.NumBytes() < sizeofUint64 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if val := atomic.SwapUint64(&t.val, 0); val != 0 {
var buf [sizeofUint64]byte
@@ -138,7 +139,7 @@ func (t *TimerOperations) Read(ctx context.Context, file *fs.File, dst usermem.I
// Write implements fs.FileOperations.Write.
func (t *TimerOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Notify implements ktime.TimerListener.Notify.
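
TimerOperations.Read above follows the timerfd_create(2) contract: the destination must hold at least 8 bytes, and a successful read atomically swaps the expiration counter to zero and writes out the old value. A standalone sketch of that logic with local types; the real code writes through usermem with the host byte order and blocks via the waiter package instead of returning a would-block error directly.

package main

import (
	"encoding/binary"
	"errors"
	"fmt"
	"sync/atomic"
)

var (
	errInval      = errors.New("EINVAL")
	errWouldBlock = errors.New("EWOULDBLOCK")
)

// timerFD holds the number of timer expirations not yet read.
type timerFD struct {
	val uint64
}

func (t *timerFD) read(dst []byte) (int, error) {
	const sizeofUint64 = 8
	if len(dst) < sizeofUint64 {
		return 0, errInval
	}
	// Consume the pending expirations atomically, exactly once.
	if val := atomic.SwapUint64(&t.val, 0); val != 0 {
		binary.LittleEndian.PutUint64(dst, val)
		return sizeofUint64, nil
	}
	// The sentry blocks on a waiter here (or returns EAGAIN for
	// non-blocking files) rather than returning directly.
	return 0, errWouldBlock
}

func main() {
	t := &timerFD{val: 3}
	buf := make([]byte, 8)
	n, err := t.read(buf)
	fmt.Println(n, err, binary.LittleEndian.Uint64(buf)) // 8 <nil> 3
	fmt.Println(t.read(buf))                             // 0 EWOULDBLOCK
}
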
diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD
index 90398376a..c36a20afe 100644
--- a/pkg/sentry/fs/tmpfs/BUILD
+++ b/pkg/sentry/fs/tmpfs/BUILD
@@ -15,6 +15,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/safemem",
"//pkg/sentry/device",
diff --git a/pkg/sentry/fs/tmpfs/fs.go b/pkg/sentry/fs/tmpfs/fs.go
index bc117ca6a..b48d475ed 100644
--- a/pkg/sentry/fs/tmpfs/fs.go
+++ b/pkg/sentry/fs/tmpfs/fs.go
@@ -151,5 +151,5 @@ func (f *Filesystem) Mount(ctx context.Context, device string, flags fs.MountSou
}
// Construct the tmpfs root.
- return NewDir(ctx, nil, owner, perms, msrc), nil
+ return NewDir(ctx, nil, owner, perms, msrc, nil /* parent */)
}
diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go
index f4de8c968..ce6be6386 100644
--- a/pkg/sentry/fs/tmpfs/inode_file.go
+++ b/pkg/sentry/fs/tmpfs/inode_file.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -226,6 +227,12 @@ func (f *fileInodeOperations) Truncate(ctx context.Context, _ *fs.Inode, size in
now := ktime.NowFromContext(ctx)
f.attr.ModificationTime = now
f.attr.StatusChangeTime = now
+
+ // Truncating clears privilege bits.
+ f.attr.Perms.SetUID = false
+ if f.attr.Perms.Group.Execute {
+ f.attr.Perms.SetGID = false
+ }
}
f.dataMu.Unlock()
@@ -363,7 +370,14 @@ func (f *fileInodeOperations) write(ctx context.Context, src usermem.IOSequence,
now := ktime.NowFromContext(ctx)
f.attr.ModificationTime = now
f.attr.StatusChangeTime = now
- return src.CopyInTo(ctx, &fileReadWriter{f, offset})
+ nwritten, err := src.CopyInTo(ctx, &fileReadWriter{f, offset})
+
+ // Writing clears privilege bits.
+ if nwritten > 0 {
+ f.attr.Perms.DropSetUIDAndMaybeGID()
+ }
+
+ return nwritten, err
}
type fileReadWriter struct {
@@ -442,7 +456,7 @@ func (rw *fileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error)
end := fs.WriteEndOffset(rw.offset, int64(srcs.NumBytes()))
if end == math.MaxInt64 {
// Overflow.
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Check if seals prevent either file growth or all writes.
@@ -642,7 +656,7 @@ func GetSeals(inode *fs.Inode) (uint32, error) {
return f.seals, nil
}
// Not a memfd inode.
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// AddSeals adds new file seals to a memfd inode.
@@ -670,5 +684,5 @@ func AddSeals(inode *fs.Inode, val uint32) error {
return nil
}
// Not a memfd inode.
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
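
The Truncate and write changes above drop privilege bits the way Linux does: the setuid bit is always cleared, while the setgid bit is cleared only when group-execute is set (setgid without group-execute marks mandatory locking and is preserved). A small sketch of that rule with stand-in types; DropSetUIDAndMaybeGID is assumed to implement the same logic shown inline in Truncate.

package main

import "fmt"

// perms is a local stand-in for fs.FilePermissions.
type perms struct {
	SetUID, SetGID, GroupExecute bool
}

// dropSetUIDAndMaybeGID clears setuid unconditionally and setgid only when
// the group-execute bit is set, mirroring the inline logic in Truncate above.
func (p *perms) dropSetUIDAndMaybeGID() {
	p.SetUID = false
	if p.GroupExecute {
		p.SetGID = false
	}
}

func main() {
	locking := perms{SetUID: true, SetGID: true, GroupExecute: false}
	locking.dropSetUIDAndMaybeGID()
	fmt.Printf("%+v\n", locking) // SetGID survives: no group-execute bit

	normal := perms{SetUID: true, SetGID: true, GroupExecute: true}
	normal.dropSetUIDAndMaybeGID()
	fmt.Printf("%+v\n", normal) // both bits cleared
}
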
diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go
index 577052888..6aa8ff331 100644
--- a/pkg/sentry/fs/tmpfs/tmpfs.go
+++ b/pkg/sentry/fs/tmpfs/tmpfs.go
@@ -87,7 +87,20 @@ type Dir struct {
var _ fs.InodeOperations = (*Dir)(nil)
// NewDir returns a new directory.
-func NewDir(ctx context.Context, contents map[string]*fs.Inode, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode {
+func NewDir(ctx context.Context, contents map[string]*fs.Inode, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource, parent *fs.Inode) (*fs.Inode, error) {
+ // If the parent has setgid enabled, the new directory enables it and changes
+ // its GID.
+ if parent != nil {
+ parentUattr, err := parent.UnstableAttr(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if parentUattr.Perms.SetGID {
+ owner.GID = parentUattr.Owner.GID
+ perms.SetGID = true
+ }
+ }
+
d := &Dir{
ramfsDir: ramfs.NewDir(ctx, contents, owner, perms),
kernel: kernel.KernelFromContext(ctx),
@@ -101,7 +114,7 @@ func NewDir(ctx context.Context, contents map[string]*fs.Inode, owner fs.FileOwn
InodeID: tmpfsDevice.NextIno(),
BlockSize: hostarch.PageSize,
Type: fs.Directory,
- })
+ }), nil
}
// afterLoad is invoked by stateify.
@@ -219,11 +232,21 @@ func (d *Dir) SetTimestamps(ctx context.Context, i *fs.Inode, ts fs.TimeSpec) er
func (d *Dir) newCreateOps() *ramfs.CreateOps {
return &ramfs.CreateOps{
NewDir: func(ctx context.Context, dir *fs.Inode, perms fs.FilePermissions) (*fs.Inode, error) {
- return NewDir(ctx, nil, fs.FileOwnerFromContext(ctx), perms, dir.MountSource), nil
+ return NewDir(ctx, nil, fs.FileOwnerFromContext(ctx), perms, dir.MountSource, dir)
},
NewFile: func(ctx context.Context, dir *fs.Inode, perms fs.FilePermissions) (*fs.Inode, error) {
+ // If the parent has setgid enabled, change the GID of the new file.
+ owner := fs.FileOwnerFromContext(ctx)
+ parentUattr, err := dir.UnstableAttr(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if parentUattr.Perms.SetGID {
+ owner.GID = parentUattr.Owner.GID
+ }
+
uattr := fs.WithCurrentTime(ctx, fs.UnstableAttr{
- Owner: fs.FileOwnerFromContext(ctx),
+ Owner: owner,
Perms: perms,
// Always start unlinked.
Links: 0,
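
NewDir and newCreateOps above implement setgid-directory semantics: children created in a setgid directory take the directory's group, and child directories additionally inherit the setgid bit itself. The sketch below captures just that decision with local stand-in types.

package main

import "fmt"

// attr is a local stand-in for the owner/permission fields involved.
type attr struct {
	GID    int
	SetGID bool
}

// applySetgidInheritance adjusts a child's attributes given its parent's:
// any child takes the parent's group; only directories also inherit setgid.
func applySetgidInheritance(parent attr, child *attr, childIsDir bool) {
	if !parent.SetGID {
		return
	}
	child.GID = parent.GID
	if childIsDir {
		child.SetGID = true
	}
}

func main() {
	parent := attr{GID: 1000, SetGID: true}
	dir, file := attr{GID: 0}, attr{GID: 0}
	applySetgidInheritance(parent, &dir, true)
	applySetgidInheritance(parent, &file, false)
	fmt.Printf("dir=%+v file=%+v\n", dir, file)
}
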
diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD
index 86ada820e..5933cb67b 100644
--- a/pkg/sentry/fs/tty/BUILD
+++ b/pkg/sentry/fs/tty/BUILD
@@ -17,6 +17,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/marshal/primitive",
"//pkg/refs",
diff --git a/pkg/sentry/fs/tty/fs.go b/pkg/sentry/fs/tty/fs.go
index 13f4901db..0e5916380 100644
--- a/pkg/sentry/fs/tty/fs.go
+++ b/pkg/sentry/fs/tty/fs.go
@@ -16,9 +16,9 @@ package tty
import (
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/syserror"
)
// ptsDevice is the pseudo-filesystem device.
@@ -64,7 +64,7 @@ func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSou
// No options are supported.
if data != "" {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
return newDir(ctx, fs.NewMountSource(ctx, &superOperations{}, f, flags)), nil
diff --git a/pkg/sentry/fs/user/BUILD b/pkg/sentry/fs/user/BUILD
index 66e949c95..4acc73ee0 100644
--- a/pkg/sentry/fs/user/BUILD
+++ b/pkg/sentry/fs/user/BUILD
@@ -12,6 +12,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fspath",
"//pkg/log",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/fs/user/path.go b/pkg/sentry/fs/user/path.go
index 124bc95ed..f6eaab2bd 100644
--- a/pkg/sentry/fs/user/path.go
+++ b/pkg/sentry/fs/user/path.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -93,7 +94,7 @@ func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name s
binPath := path.Join(p, name)
traversals := uint(linux.MaxSymlinkTraversals)
d, err := mns.FindInode(ctx, root, nil, binPath, &traversals)
- if err == syserror.ENOENT || err == syserror.EACCES {
+ if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.EACCES, err) {
// Didn't find it here.
continue
}
@@ -142,7 +143,7 @@ func resolveVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNam
Flags: linux.O_RDONLY,
}
dentry, err := root.Mount().Filesystem().VirtualFilesystem().OpenAt(ctx, creds, pop, opts)
- if err == syserror.ENOENT || err == syserror.EACCES {
+ if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.EACCES, err) {
// Didn't find it here.
continue
}
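
resolve and resolveVFS2 above walk a list of search directories and treat ENOENT and EACCES from any single directory as "keep looking" rather than as a failure. The same shape written against the host filesystem with the standard library, instead of the sentry VFS, looks like this.

package main

import (
	"errors"
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"strings"
)

// lookPath walks a colon-separated search path. A missing or unreadable
// candidate in one directory is not fatal; only other errors abort the walk.
func lookPath(name, pathEnv string) (string, error) {
	for _, dir := range strings.Split(pathEnv, ":") {
		if dir == "" {
			continue
		}
		candidate := filepath.Join(dir, name)
		if _, err := os.Stat(candidate); err != nil {
			if errors.Is(err, fs.ErrNotExist) || errors.Is(err, fs.ErrPermission) {
				continue // didn't find it here; try the next directory
			}
			return "", err // anything else is a real failure
		}
		return candidate, nil
	}
	return "", fmt.Errorf("%q not found in search path", name)
}

func main() {
	fmt.Println(lookPath("ls", "/usr/local/bin:/usr/bin:/bin"))
}
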
diff --git a/pkg/sentry/fs/user/user_test.go b/pkg/sentry/fs/user/user_test.go
index 12b786224..7f8fa8038 100644
--- a/pkg/sentry/fs/user/user_test.go
+++ b/pkg/sentry/fs/user/user_test.go
@@ -104,7 +104,10 @@ func TestGetExecUserHome(t *testing.T) {
t.Run(name, func(t *testing.T) {
ctx := contexttest.Context(t)
msrc := fs.NewPseudoMountSource(ctx)
- rootInode := tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc)
+ rootInode, err := tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc, nil /* parent */)
+ if err != nil {
+ t.Fatalf("tmpfs.NewDir failed: %v", err)
+ }
mns, err := fs.NewMountNamespace(ctx, rootInode)
if err != nil {
diff --git a/pkg/sentry/fsimpl/cgroupfs/BUILD b/pkg/sentry/fsimpl/cgroupfs/BUILD
index 37efb641a..4c9c5b344 100644
--- a/pkg/sentry/fsimpl/cgroupfs/BUILD
+++ b/pkg/sentry/fsimpl/cgroupfs/BUILD
@@ -31,6 +31,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/coverage",
+ "//pkg/errors/linuxerr",
"//pkg/log",
"//pkg/refs",
"//pkg/refsvfs2",
diff --git a/pkg/sentry/fsimpl/cgroupfs/base.go b/pkg/sentry/fsimpl/cgroupfs/base.go
index 6512e9cdb..4290ffe0d 100644
--- a/pkg/sentry/fsimpl/cgroupfs/base.go
+++ b/pkg/sentry/fsimpl/cgroupfs/base.go
@@ -23,10 +23,10 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -133,6 +133,17 @@ func (c *cgroupInode) Controllers() []kernel.CgroupController {
return c.fs.kcontrollers
}
+// tasks returns a snapshot of the tasks inside the cgroup.
+func (c *cgroupInode) tasks() []*kernel.Task {
+ c.fs.tasksMu.RLock()
+ defer c.fs.tasksMu.RUnlock()
+ ts := make([]*kernel.Task, 0, len(c.ts))
+ for t := range c.ts {
+ ts = append(ts, t)
+ }
+ return ts
+}
+
// Enter implements kernel.CgroupImpl.Enter.
func (c *cgroupInode) Enter(t *kernel.Task) {
c.fs.tasksMu.Lock()
@@ -163,10 +174,7 @@ func (d *cgroupProcsData) Generate(ctx context.Context, buf *bytes.Buffer) error
pgids := make(map[kernel.ThreadID]struct{})
- d.fs.tasksMu.RLock()
- defer d.fs.tasksMu.RUnlock()
-
- for task := range d.ts {
+ for _, task := range d.tasks() {
// Map dedups pgid, since iterating over all tasks produces multiple
// entries for the group leaders.
if pgid := currPidns.IDOfThreadGroup(task.ThreadGroup()); pgid != 0 {
@@ -205,10 +213,7 @@ func (d *tasksData) Generate(ctx context.Context, buf *bytes.Buffer) error {
var pids []kernel.ThreadID
- d.fs.tasksMu.RLock()
- defer d.fs.tasksMu.RUnlock()
-
- for task := range d.ts {
+ for _, task := range d.tasks() {
if pid := currPidns.IDOfTask(task); pid != 0 {
pids = append(pids, pid)
}
@@ -248,7 +253,7 @@ func parseInt64FromString(ctx context.Context, src usermem.IOSequence, offset in
// Note: This also handles zero-len writes if offset is beyond the end
// of src, or src is empty.
ctx.Warningf("cgroupfs.parseInt64FromString: failed to parse %q: %v", string(buf), err)
- return 0, int64(n), syserror.EINVAL
+ return 0, int64(n), linuxerr.EINVAL
}
return val, int64(n), nil
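
The new cgroupInode.tasks helper above snapshots the task set under tasksMu so the seqfile generators can do their pid-namespace lookups with no lock held, which keeps tasksMu at the bottom of the lock order updated in the cgroupfs.go hunk that follows. A generic sketch of that copy-under-lock pattern:

package main

import (
	"fmt"
	"sync"
)

// memberSet guards its membership map with an RWMutex, as the cgroup
// filesystem guards its task set with tasksMu.
type memberSet struct {
	mu      sync.RWMutex
	members map[int]struct{}
}

// snapshot copies the membership while holding the lock so callers can do
// slow per-member work afterwards without the lock held.
func (s *memberSet) snapshot() []int {
	s.mu.RLock()
	defer s.mu.RUnlock()
	out := make([]int, 0, len(s.members))
	for m := range s.members {
		out = append(out, m)
	}
	return out
}

func main() {
	s := &memberSet{members: map[int]struct{}{1: {}, 2: {}, 3: {}}}
	for _, m := range s.snapshot() {
		// In the sentry this is where pid lookups happen; they may take
		// other kernel locks safely because tasksMu is not held here.
		fmt.Println(m)
	}
}
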
diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
index 54050de3c..b5883cbd2 100644
--- a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
+++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
@@ -49,8 +49,9 @@
//
// kernel.CgroupRegistry.mu
// cgroupfs.filesystem.mu
-// Task.mu
-// cgroupfs.filesystem.tasksMu.
+// kernel.TaskSet.mu
+// kernel.Task.mu
+// cgroupfs.filesystem.tasksMu.
package cgroupfs
import (
@@ -61,6 +62,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -166,7 +168,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
maxCachedDentries, err = strconv.ParseUint(str, 10, 64)
if err != nil {
ctx.Warningf("sys.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
}
@@ -194,7 +196,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
if _, ok := mopts["all"]; ok {
if len(wantControllers) > 0 {
ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: other controllers specified with all: %v", wantControllers)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
delete(mopts, "all")
@@ -208,7 +210,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
if len(mopts) != 0 {
ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: unknown options: %v", mopts)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
k := kernel.KernelFromContext(ctx)
diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD
index 6af3c3781..50b4c02ef 100644
--- a/pkg/sentry/fsimpl/devpts/BUILD
+++ b/pkg/sentry/fsimpl/devpts/BUILD
@@ -29,6 +29,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/log",
"//pkg/marshal",
"//pkg/marshal/primitive",
diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go
index e75954105..7a488e9fd 100644
--- a/pkg/sentry/fsimpl/devpts/devpts.go
+++ b/pkg/sentry/fsimpl/devpts/devpts.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -56,7 +57,7 @@ func (*FilesystemType) Name() string {
func (fstype *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
// No data allowed.
if opts.Data != "" {
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
fstype.initOnce.Do(func() {
diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go
index 93c031c89..1374fd3be 100644
--- a/pkg/sentry/fsimpl/devpts/master.go
+++ b/pkg/sentry/fsimpl/devpts/master.go
@@ -17,6 +17,7 @@ package devpts
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
@@ -80,7 +81,7 @@ func (mi *masterInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs
// SetStat implements kernfs.Inode.SetStat
func (mi *masterInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
if opts.Stat.Mask&linux.STATX_SIZE != 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
return mi.InodeAttrs.SetStat(ctx, vfsfs, creds, opts)
}
diff --git a/pkg/sentry/fsimpl/devpts/replica.go b/pkg/sentry/fsimpl/devpts/replica.go
index 96d2054cb..81572b991 100644
--- a/pkg/sentry/fsimpl/devpts/replica.go
+++ b/pkg/sentry/fsimpl/devpts/replica.go
@@ -17,6 +17,7 @@ package devpts
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
@@ -92,7 +93,7 @@ func (ri *replicaInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vf
// SetStat implements kernfs.Inode.SetStat
func (ri *replicaInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
if opts.Stat.Mask&linux.STATX_SIZE != 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
return ri.InodeAttrs.SetStat(ctx, vfsfs, creds, opts)
}
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index 2dbc6bfd5..5e8b464a0 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -47,6 +47,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fd",
"//pkg/fspath",
"//pkg/log",
@@ -88,13 +89,13 @@ go_test(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fspath",
"//pkg/marshal/primitive",
"//pkg/sentry/contexttest",
"//pkg/sentry/fsimpl/ext/disklayout",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
- "//pkg/syserror",
"//pkg/test/testutil",
"//pkg/usermem",
"@com_github_google_go_cmp//cmp:go_default_library",
diff --git a/pkg/sentry/fsimpl/ext/block_map_file.go b/pkg/sentry/fsimpl/ext/block_map_file.go
index 1165234f9..79719faed 100644
--- a/pkg/sentry/fsimpl/ext/block_map_file.go
+++ b/pkg/sentry/fsimpl/ext/block_map_file.go
@@ -18,6 +18,7 @@ import (
"io"
"math"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -84,7 +85,7 @@ func (f *blockMapFile) ReadAt(dst []byte, off int64) (int, error) {
}
if off < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
offset := uint64(off)
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index 512b70ede..cc067c20e 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -17,12 +17,12 @@ package ext
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
// directory represents a directory inode. It holds the childList in memory.
@@ -218,7 +218,7 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
// Seek implements vfs.FileDescriptionImpl.Seek.
func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
if whence != linux.SEEK_SET && whence != linux.SEEK_CUR {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
dir := fd.inode().impl.(*directory)
@@ -234,7 +234,7 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in
if offset < 0 {
// lseek(2) specifies that EINVAL should be returned if the resulting offset
// is negative.
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
n := int64(len(dir.childMap))
diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go
index 38fb7962b..80854b501 100644
--- a/pkg/sentry/fsimpl/ext/ext.go
+++ b/pkg/sentry/fsimpl/ext/ext.go
@@ -22,12 +22,12 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
// Name is the name of this filesystem.
@@ -133,13 +133,13 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// mount(2) specifies that EINVAL should be returned if the superblock is
// invalid.
fs.vfsfs.DecRef(ctx)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
// Refuse to mount if the filesystem is incompatible.
if !isCompatible(fs.sb) {
fs.vfsfs.DecRef(ctx)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
fs.bgs, err = readBlockGroups(dev, fs.sb)
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
index d9fd4590c..db712e71f 100644
--- a/pkg/sentry/fsimpl/ext/ext_test.go
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -26,12 +26,12 @@ import (
"github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/test/testutil"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -173,7 +173,7 @@ func TestSeek(t *testing.T) {
}
// EINVAL should be returned if the resulting offset is negative.
- if _, err := fd.Seek(ctx, -1, linux.SEEK_SET); err != syserror.EINVAL {
+ if _, err := fd.Seek(ctx, -1, linux.SEEK_SET); !linuxerr.Equals(linuxerr.EINVAL, err) {
t.Errorf("expected error EINVAL but got %v", err)
}
@@ -187,7 +187,7 @@ func TestSeek(t *testing.T) {
}
// EINVAL should be returned if the resulting offset is negative.
- if _, err := fd.Seek(ctx, -(size + 2), linux.SEEK_CUR); err != syserror.EINVAL {
+ if _, err := fd.Seek(ctx, -(size + 2), linux.SEEK_CUR); !linuxerr.Equals(linuxerr.EINVAL, err) {
t.Errorf("expected error EINVAL but got %v", err)
}
@@ -204,7 +204,7 @@ func TestSeek(t *testing.T) {
}
// EINVAL should be returned if the resulting offset is negative.
- if _, err := fd.Seek(ctx, -(size + 1), linux.SEEK_END); err != syserror.EINVAL {
+ if _, err := fd.Seek(ctx, -(size + 1), linux.SEEK_END); !linuxerr.Equals(linuxerr.EINVAL, err) {
t.Errorf("expected error EINVAL but got %v", err)
}
}
diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go
index 778460107..f449bc8bd 100644
--- a/pkg/sentry/fsimpl/ext/extent_file.go
+++ b/pkg/sentry/fsimpl/ext/extent_file.go
@@ -18,6 +18,7 @@ import (
"io"
"sort"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -65,7 +66,7 @@ func (f *extentFile) buildExtTree() error {
if f.root.Header.NumEntries > 4 {
// read(2) specifies that EINVAL should be returned if the file is unsuitable
// for reading.
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
f.root.Entries = make([]disklayout.ExtentEntryPair, f.root.Header.NumEntries)
@@ -145,7 +146,7 @@ func (f *extentFile) ReadAt(dst []byte, off int64) (int, error) {
}
if off < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if uint64(off) >= f.regFile.inode.diskInode.Size() {
diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go
index d4fc484a2..1d2eaa0d4 100644
--- a/pkg/sentry/fsimpl/ext/filesystem.go
+++ b/pkg/sentry/fsimpl/ext/filesystem.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -344,7 +345,7 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st
}
symlink, ok := inode.impl.(*symlink)
if !ok {
- return "", syserror.EINVAL
+ return "", linuxerr.EINVAL
}
return symlink.target, nil
}
diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go
index 4a555bf72..b3df2337f 100644
--- a/pkg/sentry/fsimpl/ext/inode.go
+++ b/pkg/sentry/fsimpl/ext/inode.go
@@ -19,6 +19,7 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -147,7 +148,7 @@ func newInode(fs *filesystem, inodeNum uint32) (*inode, error) {
return &f.inode, nil
default:
// TODO(b/134676337): Return appropriate errors for sockets, pipes and devices.
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
}
diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go
index 5ad9befcd..9a094716a 100644
--- a/pkg/sentry/fsimpl/ext/regular_file.go
+++ b/pkg/sentry/fsimpl/ext/regular_file.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -139,10 +140,10 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (
case linux.SEEK_END:
offset += int64(fd.inode().diskInode.Size())
default:
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
fd.off = offset
return offset, nil
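
The ext Seek implementations above normalize the offset for SEEK_SET, SEEK_CUR and SEEK_END and return EINVAL both for an unknown whence and for a resulting negative offset, per lseek(2). The same logic as a tiny standalone function, with a plain error value standing in for EINVAL:

package main

import (
	"errors"
	"fmt"
	"io"
)

var errInval = errors.New("EINVAL")

// seek resolves a whence-relative offset and applies the lseek(2) rules the
// ext file descriptions enforce above.
func seek(cur, size, offset int64, whence int) (int64, error) {
	switch whence {
	case io.SeekStart:
		// offset is already absolute.
	case io.SeekCurrent:
		offset += cur
	case io.SeekEnd:
		offset += size
	default:
		return 0, errInval // unknown whence
	}
	if offset < 0 {
		return 0, errInval // resulting offset may not be negative
	}
	return offset, nil
}

func main() {
	fmt.Println(seek(10, 100, -5, io.SeekCurrent)) // 5 <nil>
	fmt.Println(seek(10, 100, -1, io.SeekStart))   // 0 EINVAL
	fmt.Println(seek(10, 100, -101, io.SeekEnd))   // 0 EINVAL
}
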
diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD
index 3a4777fbe..871df5984 100644
--- a/pkg/sentry/fsimpl/fuse/BUILD
+++ b/pkg/sentry/fsimpl/fuse/BUILD
@@ -46,6 +46,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/log",
"//pkg/marshal",
@@ -76,6 +77,7 @@ go_test(
library = ":fuse",
deps = [
"//pkg/abi/linux",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/marshal",
"//pkg/sentry/fsimpl/testutil",
diff --git a/pkg/sentry/fsimpl/fuse/connection_test.go b/pkg/sentry/fsimpl/fuse/connection_test.go
index 78ea6a31e..1fddd858e 100644
--- a/pkg/sentry/fsimpl/fuse/connection_test.go
+++ b/pkg/sentry/fsimpl/fuse/connection_test.go
@@ -19,9 +19,9 @@ import (
"testing"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/syserror"
)
// TestConnectionInitBlock tests if initialization
@@ -104,7 +104,7 @@ func TestConnectionAbort(t *testing.T) {
// After abort, Call() should return directly with ENOTCONN.
req := conn.NewRequest(creds, 0, 0, 0, testObj)
_, err = conn.Call(task, req)
- if err != syserror.ENOTCONN {
+ if !linuxerr.Equals(linuxerr.ENOTCONN, err) {
t.Fatalf("Incorrect error code received for Call() after connection aborted")
}
diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go
index 5d2bae14e..0d0eed543 100644
--- a/pkg/sentry/fsimpl/fuse/dev.go
+++ b/pkg/sentry/fsimpl/fuse/dev.go
@@ -18,6 +18,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -149,7 +150,7 @@ func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.R
// If the read buffer is too small, error out.
if dst.NumBytes() < int64(minBuffSize) {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
fd.mu.Lock()
@@ -293,7 +294,7 @@ func (fd *DeviceFD) writeLocked(ctx context.Context, src usermem.IOSequence, opt
// Assert that the header isn't read into the writeBuf yet.
if fd.writeCursor >= hdrLen {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// We don't have the full common response header yet.
@@ -322,7 +323,7 @@ func (fd *DeviceFD) writeLocked(ctx context.Context, src usermem.IOSequence, opt
if !ok {
// Server sent us a response for a request we never sent,
// or for which we already received a reply (e.g. aborted), an unlikely event.
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
delete(fd.completions, hdr.Unique)
@@ -434,7 +435,7 @@ func (fd *DeviceFD) sendError(ctx context.Context, errno int32, unique linux.FUS
if !ok {
// A response for a request we never sent,
// or for which we already received a reply (e.g. aborted).
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
delete(fd.completions, respHdr.Unique)
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
index 167c899e2..be5bcd6af 100644
--- a/pkg/sentry/fsimpl/fuse/fusefs.go
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
@@ -121,30 +122,30 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
deviceDescriptorStr, ok := mopts["fd"]
if !ok {
ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option fd missing")
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
delete(mopts, "fd")
deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */)
if err != nil {
ctx.Debugf("fusefs.FilesystemType.GetFilesystem: invalid fd: %q (%v)", deviceDescriptorStr, err)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
kernelTask := kernel.TaskFromContext(ctx)
if kernelTask == nil {
log.Warningf("%s.GetFilesystem: couldn't get kernel task from context", fsType.Name())
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
fuseFDGeneric := kernelTask.GetFileVFS2(int32(deviceDescriptor))
if fuseFDGeneric == nil {
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
defer fuseFDGeneric.DecRef(ctx)
fuseFD, ok := fuseFDGeneric.Impl().(*DeviceFD)
if !ok {
log.Warningf("%s.GetFilesystem: device FD is %T, not a FUSE device", fsType.Name, fuseFDGeneric)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
// Parse and set all the other supported FUSE mount options.
@@ -154,17 +155,17 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
uid, err := strconv.ParseUint(uidStr, 10, 32)
if err != nil {
log.Warningf("%s.GetFilesystem: invalid user_id: user_id=%s", fsType.Name(), uidStr)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
kuid := creds.UserNamespace.MapToKUID(auth.UID(uid))
if !kuid.Ok() {
ctx.Warningf("fusefs.FilesystemType.GetFilesystem: unmapped uid: %d", uid)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
fsopts.uid = kuid
} else {
ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option user_id missing")
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
if gidStr, ok := mopts["group_id"]; ok {
@@ -172,17 +173,17 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
gid, err := strconv.ParseUint(gidStr, 10, 32)
if err != nil {
log.Warningf("%s.GetFilesystem: invalid group_id: group_id=%s", fsType.Name(), gidStr)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
kgid := creds.UserNamespace.MapToKGID(auth.GID(gid))
if !kgid.Ok() {
ctx.Warningf("fusefs.FilesystemType.GetFilesystem: unmapped gid: %d", gid)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
fsopts.gid = kgid
} else {
ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option group_id missing")
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
if modeStr, ok := mopts["rootmode"]; ok {
@@ -190,12 +191,12 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
mode, err := strconv.ParseUint(modeStr, 8, 32)
if err != nil {
log.Warningf("%s.GetFilesystem: invalid mode: %q", fsType.Name(), modeStr)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
fsopts.rootMode = linux.FileMode(mode)
} else {
ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option rootmode missing")
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
// Set the maxInFlightRequests option.
@@ -206,7 +207,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
maxRead, err := strconv.ParseUint(maxReadStr, 10, 32)
if err != nil {
log.Warningf("%s.GetFilesystem: invalid max_read: max_read=%s", fsType.Name(), maxReadStr)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
if maxRead < fuseMinMaxRead {
maxRead = fuseMinMaxRead
@@ -229,7 +230,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// Check for unparsed options.
if len(mopts) != 0 {
log.Warningf("%s.GetFilesystem: unsupported or unknown options: %v", fsType.Name(), mopts)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
// Create a new FUSE filesystem.
@@ -258,7 +259,7 @@ func newFUSEFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, fsTyp
conn, err := newFUSEConnection(ctx, fuseFD, opts)
if err != nil {
log.Warningf("fuse.NewFUSEFilesystem: NewFUSEConnection failed with error: %v", err)
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
fs := &filesystem{
@@ -418,7 +419,7 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentr
kernelTask := kernel.TaskFromContext(ctx)
if kernelTask == nil {
log.Warningf("fusefs.Inode.Open: couldn't get kernel task from context")
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
// Build the request.
@@ -440,7 +441,7 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentr
if err != nil {
return nil, err
}
- if err := res.Error(); err == syserror.ENOSYS && !isDir {
+ if err := res.Error(); linuxerr.Equals(linuxerr.ENOSYS, err) && !isDir {
i.fs.conn.noOpen = true
} else if err != nil {
return nil, err
@@ -512,7 +513,7 @@ func (i *inode) NewFile(ctx context.Context, name string, opts vfs.OpenOptions)
kernelTask := kernel.TaskFromContext(ctx)
if kernelTask == nil {
log.Warningf("fusefs.Inode.NewFile: couldn't get kernel task from context", i.nodeID)
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
in := linux.FUSECreateIn{
CreateMeta: linux.FUSECreateMeta{
@@ -552,7 +553,7 @@ func (i *inode) Unlink(ctx context.Context, name string, child kernfs.Inode) err
kernelTask := kernel.TaskFromContext(ctx)
if kernelTask == nil {
log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID)
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
in := linux.FUSEUnlinkIn{Name: name}
req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in)
@@ -596,7 +597,7 @@ func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMo
kernelTask := kernel.TaskFromContext(ctx)
if kernelTask == nil {
log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID)
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, payload)
res, err := i.fs.conn.Call(kernelTask, req)
@@ -626,13 +627,13 @@ func (i *inode) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry,
// Readlink implements kernfs.Inode.Readlink.
func (i *inode) Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) {
if i.Mode().FileType()&linux.S_IFLNK == 0 {
- return "", syserror.EINVAL
+ return "", linuxerr.EINVAL
}
if len(i.link) == 0 {
kernelTask := kernel.TaskFromContext(ctx)
if kernelTask == nil {
log.Warningf("fusefs.Inode.Readlink: couldn't get kernel task from context")
- return "", syserror.EINVAL
+ return "", linuxerr.EINVAL
}
req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_READLINK, &linux.FUSEEmptyIn{})
res, err := i.fs.conn.Call(kernelTask, req)
@@ -728,7 +729,7 @@ func (i *inode) getAttr(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOp
task := kernel.TaskFromContext(ctx)
if task == nil {
log.Warningf("couldn't get kernel task from context")
- return linux.FUSEAttr{}, syserror.EINVAL
+ return linux.FUSEAttr{}, linuxerr.EINVAL
}
creds := auth.CredentialsFromContext(ctx)
@@ -833,7 +834,7 @@ func (i *inode) setAttr(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
task := kernel.TaskFromContext(ctx)
if task == nil {
log.Warningf("couldn't get kernel task from context")
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// We should retain the original file type when assigning new mode.
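
GetFilesystem above parses the FUSE mount options out of a string map: fd, user_id, group_id and rootmode are mandatory, each is validated with strconv (rootmode in octal), and any key left over after parsing is rejected. A standalone sketch of that pull-validate-reject flow; the option struct and error values are local stand-ins, and the credential mapping through MapToKUID/MapToKGID is omitted.

package main

import (
	"fmt"
	"strconv"
)

// fuseOpts is a local stand-in for the parsed option set.
type fuseOpts struct {
	fd       int32
	uid, gid uint32
	rootMode uint32
}

func parseFuseOpts(mopts map[string]string) (fuseOpts, error) {
	var opts fuseOpts

	fdStr, ok := mopts["fd"]
	if !ok {
		return opts, fmt.Errorf("mandatory mount option fd missing") // EINVAL
	}
	delete(mopts, "fd")
	fd, err := strconv.ParseInt(fdStr, 10 /* base */, 32 /* bitSize */)
	if err != nil {
		return opts, fmt.Errorf("invalid fd: %q", fdStr) // EINVAL
	}
	opts.fd = int32(fd)

	for _, key := range []string{"user_id", "group_id"} {
		s, ok := mopts[key]
		if !ok {
			return opts, fmt.Errorf("mandatory mount option %s missing", key)
		}
		delete(mopts, key)
		id, err := strconv.ParseUint(s, 10, 32)
		if err != nil {
			return opts, fmt.Errorf("invalid %s: %q", key, s)
		}
		if key == "user_id" {
			opts.uid = uint32(id)
		} else {
			opts.gid = uint32(id)
		}
	}

	modeStr, ok := mopts["rootmode"]
	if !ok {
		return opts, fmt.Errorf("mandatory mount option rootmode missing")
	}
	delete(mopts, "rootmode")
	mode, err := strconv.ParseUint(modeStr, 8 /* octal */, 32)
	if err != nil {
		return opts, fmt.Errorf("invalid rootmode: %q", modeStr)
	}
	opts.rootMode = uint32(mode)

	// Anything not consumed above is an unsupported option.
	if len(mopts) != 0 {
		return opts, fmt.Errorf("unsupported or unknown options: %v", mopts)
	}
	return opts, nil
}

func main() {
	opts, err := parseFuseOpts(map[string]string{
		"fd": "7", "user_id": "1000", "group_id": "1000", "rootmode": "40000",
	})
	fmt.Printf("%+v %v\n", opts, err)
}
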
diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go
index 66ea889f9..35d0ab6f4 100644
--- a/pkg/sentry/fsimpl/fuse/read_write.go
+++ b/pkg/sentry/fsimpl/fuse/read_write.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -39,7 +40,7 @@ func (fs *filesystem) ReadInPages(ctx context.Context, fd *regularFileFD, off ui
t := kernel.TaskFromContext(ctx)
if t == nil {
log.Warningf("fusefs.Read: couldn't get kernel task from context")
- return nil, 0, syserror.EINVAL
+ return nil, 0, linuxerr.EINVAL
}
// Round up to a multiple of page size.
@@ -155,7 +156,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64,
t := kernel.TaskFromContext(ctx)
if t == nil {
log.Warningf("fusefs.Read: couldn't get kernel task from context")
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// One request cannot exceed either maxWrite or maxPages.
diff --git a/pkg/sentry/fsimpl/fuse/regular_file.go b/pkg/sentry/fsimpl/fuse/regular_file.go
index 5bdd096c3..a0802cd32 100644
--- a/pkg/sentry/fsimpl/fuse/regular_file.go
+++ b/pkg/sentry/fsimpl/fuse/regular_file.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -39,7 +40,7 @@ type regularFileFD struct {
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Check that flags are supported.
@@ -56,7 +57,7 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
} else if size > math.MaxUint32 {
// FUSE only supports uint32 for size.
// Overflow.
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// TODO(gvisor.dev/issue/3678): Add direct IO support.
@@ -143,7 +144,7 @@ func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts
// final offset should be ignored by PWrite.
func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) {
if offset < 0 {
- return 0, offset, syserror.EINVAL
+ return 0, offset, linuxerr.EINVAL
}
// Check that flags are supported.
@@ -171,11 +172,11 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
if srclen > math.MaxUint32 {
// FUSE only supports uint32 for size.
// Overflow.
- return 0, offset, syserror.EINVAL
+ return 0, offset, linuxerr.EINVAL
}
if end := offset + srclen; end < offset {
// Overflow.
- return 0, offset, syserror.EINVAL
+ return 0, offset, linuxerr.EINVAL
}
srclen, err = vfs.CheckLimit(ctx, offset, srclen)
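
pwrite above guards against two overflows before issuing FUSE writes: the request size must fit in a uint32 (the wire format's size field) and offset plus srclen must not wrap around int64. Those checks in isolation, with a plain error value standing in for EINVAL:

package main

import (
	"errors"
	"fmt"
	"math"
)

var errInval = errors.New("EINVAL")

// checkWriteBounds applies the pwrite guards shown above: a non-negative
// offset, a size that fits the uint32 FUSE size field, and no int64 wrap.
func checkWriteBounds(offset, srclen int64) error {
	if offset < 0 {
		return errInval
	}
	if srclen > math.MaxUint32 {
		return errInval // FUSE only supports uint32 sizes
	}
	if end := offset + srclen; end < offset {
		return errInval // offset + srclen overflowed int64
	}
	return nil
}

func main() {
	fmt.Println(checkWriteBounds(10, 100))                    // <nil>
	fmt.Println(checkWriteBounds(math.MaxInt64-5, 100))       // EINVAL: wrap
	fmt.Println(checkWriteBounds(0, int64(math.MaxUint32)+1)) // EINVAL: size
}
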
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index 368272f12..752060044 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -49,6 +49,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fd",
"//pkg/fdnotifier",
"//pkg/fspath",
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
index 177e42649..5c48a9fee 100644
--- a/pkg/sentry/fsimpl/gofer/directory.go
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/refsvfs2"
@@ -28,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
func (d *dentry) isDir() bool {
@@ -297,7 +297,7 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in
switch whence {
case linux.SEEK_SET:
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if offset == 0 {
// Ensure that the next call to fd.IterDirents() calls
@@ -309,13 +309,13 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in
case linux.SEEK_CUR:
offset += fd.off
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Don't clear fd.dirents in this case, even if offset == 0.
fd.off = offset
return fd.off, nil
default:
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
}
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 91ec4a142..067b7aac1 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
@@ -255,7 +256,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name)
if err != nil {
- if err == syserror.ENOENT {
+ if linuxerr.Equals(linuxerr.ENOENT, err) {
parent.cacheNegativeLookupLocked(name)
}
return nil, err
@@ -382,7 +383,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
return syserror.EEXIST
}
checkExistence := func() error {
- if child, err := fs.getChildLocked(ctx, parent, name, &ds); err != nil && err != syserror.ENOENT {
+ if child, err := fs.getChildLocked(ctx, parent, name, &ds); err != nil && !linuxerr.Equals(linuxerr.ENOENT, err) {
return err
} else if child != nil {
return syserror.EEXIST
@@ -469,7 +470,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
name := rp.Component()
if dir {
if name == "." {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if name == ".." {
return syserror.ENOTEMPTY
@@ -715,7 +716,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
mode |= linux.S_ISGID
}
if _, err := parent.file.mkdir(ctx, name, p9.FileMode(mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid)); err != nil {
- if !opts.ForSyntheticMountpoint || err == syserror.EEXIST {
+ if !opts.ForSyntheticMountpoint || linuxerr.Equals(linuxerr.EEXIST, err) {
return err
}
ctx.Infof("Failed to create remote directory %q: %v; falling back to synthetic directory", name, err)
@@ -752,7 +753,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, ds **[]*dentry) error {
creds := rp.Credentials()
_, err := parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
- if err != syserror.EPERM {
+ if !linuxerr.Equals(linuxerr.EPERM, err) {
return err
}
@@ -765,7 +766,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
case err == nil:
// Step succeeded, another file exists.
return syserror.EEXIST
- case err != syserror.ENOENT:
+ case !linuxerr.Equals(linuxerr.ENOENT, err):
// Unexpected error.
return err
}
@@ -862,7 +863,7 @@ afterTrailingSymlink:
// Determine whether or not we need to create a file.
parent.dirMu.Lock()
child, _, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
- if err == syserror.ENOENT && mayCreate {
+ if linuxerr.Equals(linuxerr.ENOENT, err) && mayCreate {
if parent.isSynthetic() {
parent.dirMu.Unlock()
return nil, syserror.EPERM
@@ -942,7 +943,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
return nil, syserror.EISDIR
}
if opts.Flags&linux.O_DIRECT != 0 {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
if !d.isSynthetic() {
if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, false /* write */, false /* trunc */); err != nil {
@@ -998,7 +999,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
func (d *dentry) openSocketByConnecting(ctx context.Context, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
if opts.Flags&linux.O_DIRECT != 0 {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
fdObj, err := d.file.connect(ctx, p9.AnonymousSocket)
if err != nil {
@@ -1019,7 +1020,7 @@ func (d *dentry) openSocketByConnecting(ctx context.Context, opts *vfs.OpenOptio
func (d *dentry) openSpecialFile(ctx context.Context, mnt *vfs.Mount, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
ats := vfs.AccessTypesForOpenFlags(opts)
if opts.Flags&linux.O_DIRECT != 0 {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
// We assume that the server silently inserts O_NONBLOCK in the open flags
// for all named pipes (because all existing gofers do this).
@@ -1033,7 +1034,7 @@ func (d *dentry) openSpecialFile(ctx context.Context, mnt *vfs.Mount, opts *vfs.
retry:
h, err := openHandle(ctx, d.file, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0)
if err != nil {
- if isBlockingOpenOfNamedPipe && ats == vfs.MayWrite && err == syserror.ENXIO {
+ if isBlockingOpenOfNamedPipe && ats == vfs.MayWrite && linuxerr.Equals(linuxerr.ENXIO, err) {
// An attempt to open a named pipe with O_WRONLY|O_NONBLOCK fails
// with ENXIO if opening the same named pipe with O_WRONLY would
// block because there are no readers of the pipe.
@@ -1187,18 +1188,14 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st
return "", err
}
if !d.isSymlink() {
- return "", syserror.EINVAL
+ return "", linuxerr.EINVAL
}
return d.readlink(ctx, rp.Mount())
}
// RenameAt implements vfs.FilesystemImpl.RenameAt.
func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
- if opts.Flags != 0 {
- // Requires 9P support.
- return syserror.EINVAL
- }
-
+ // Resolve newParent first to verify that it's on this Mount.
var ds *[]*dentry
fs.renameMu.Lock()
defer fs.renameMuUnlockAndCheckCaching(ctx, &ds)
@@ -1206,8 +1203,21 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if err != nil {
return err
}
+
+ if opts.Flags&^linux.RENAME_NOREPLACE != 0 {
+ return linuxerr.EINVAL
+ }
+ if fs.opts.interop == InteropModeShared && opts.Flags&linux.RENAME_NOREPLACE != 0 {
+ // Requires 9P support to synchronize with other remote filesystem
+ // users.
+ return linuxerr.EINVAL
+ }
+
newName := rp.Component()
if newName == "." || newName == ".." {
+ if opts.Flags&linux.RENAME_NOREPLACE != 0 {
+ return syserror.EEXIST
+ }
return syserror.EBUSY
}
mnt := rp.Mount()
@@ -1251,7 +1261,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
if renamed.isDir() {
if renamed == newParent || genericIsAncestorDentry(renamed, newParent) {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if oldParent != newParent {
if err := renamed.checkPermissions(creds, vfs.MayWrite); err != nil {
@@ -1275,11 +1285,14 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
return syserror.ENOENT
}
replaced, err := fs.getChildLocked(ctx, newParent, newName, &ds)
- if err != nil && err != syserror.ENOENT {
+ if err != nil && !linuxerr.Equals(linuxerr.ENOENT, err) {
return err
}
var replacedVFSD *vfs.Dentry
if replaced != nil {
+ if opts.Flags&linux.RENAME_NOREPLACE != 0 {
+ return syserror.EEXIST
+ }
replacedVFSD = &replaced.vfsd
if replaced.isDir() {
if !renamed.isDir() {
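Editor's note: RenameAt above now resolves the new parent before validating flags, accepts only RENAME_NOREPLACE, and still rejects that flag under shared interop mode because enforcing it race-free would need 9P support. A hedged sketch of just the flag validation; the helper name and plain bool parameter are hypothetical.

package sketch

import (
	"gvisor.dev/gvisor/pkg/abi/linux"
	"gvisor.dev/gvisor/pkg/errors/linuxerr"
)

// checkRenameFlags mirrors the validation above: any flag other than
// RENAME_NOREPLACE is rejected, and RENAME_NOREPLACE itself is rejected when
// the mount is shared with other remote filesystem users.
func checkRenameFlags(flags uint32, sharedInterop bool) error {
	if flags&^linux.RENAME_NOREPLACE != 0 {
		return linuxerr.EINVAL
	}
	if sharedInterop && flags&linux.RENAME_NOREPLACE != 0 {
		return linuxerr.EINVAL
	}
	return nil
}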
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 21692d2ac..c7ebd435c 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -46,6 +46,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
@@ -318,7 +319,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
mfp := pgalloc.MemoryFileProviderFromContext(ctx)
if mfp == nil {
ctx.Warningf("gofer.FilesystemType.GetFilesystem: context does not provide a pgalloc.MemoryFileProvider")
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
mopts := vfs.GenericParseMountOptions(opts.Data)
@@ -354,7 +355,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
fsopts.interop = InteropModeShared
default:
ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid cache policy: %s=%s", moptCache, cache)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
}
@@ -365,7 +366,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
dfltuid, err := strconv.ParseUint(dfltuidstr, 10, 32)
if err != nil {
ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: %s=%s", moptDfltUID, dfltuidstr)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
// In Linux, dfltuid is interpreted as a UID and is converted to a KUID
// in the caller's user namespace, but goferfs isn't
@@ -378,7 +379,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
dfltgid, err := strconv.ParseUint(dfltgidstr, 10, 32)
if err != nil {
ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: %s=%s", moptDfltGID, dfltgidstr)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
fsopts.dfltgid = auth.KGID(dfltgid)
}
@@ -390,7 +391,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
msize, err := strconv.ParseUint(msizestr, 10, 32)
if err != nil {
ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid message size: %s=%s", moptMsize, msizestr)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
fsopts.msize = uint32(msize)
}
@@ -409,7 +410,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
maxCachedDentries, err := strconv.ParseUint(str, 10, 64)
if err != nil {
ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid dentry cache limit: %s=%s", moptDentryCacheLimit, str)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
fsopts.maxCachedDentries = maxCachedDentries
}
@@ -433,14 +434,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// Check for unparsed options.
if len(mopts) != 0 {
ctx.Warningf("gofer.FilesystemType.GetFilesystem: unknown options: %v", mopts)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
// Handle internal options.
iopts, ok := opts.InternalData.(InternalFilesystemOptions)
if opts.InternalData != nil && !ok {
ctx.Warningf("gofer.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted gofer.InternalFilesystemOptions", opts.InternalData)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
// If !ok, iopts being the zero value is correct.
@@ -503,7 +504,7 @@ func getFDFromMountOptionsMap(ctx context.Context, mopts map[string]string) (int
trans, ok := mopts[moptTransport]
if !ok || trans != transportModeFD {
ctx.Warningf("gofer.getFDFromMountOptionsMap: transport must be specified as '%s=%s'", moptTransport, transportModeFD)
- return -1, syserror.EINVAL
+ return -1, linuxerr.EINVAL
}
delete(mopts, moptTransport)
@@ -511,28 +512,28 @@ func getFDFromMountOptionsMap(ctx context.Context, mopts map[string]string) (int
rfdstr, ok := mopts[moptReadFD]
if !ok {
ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD must be specified as '%s=<file descriptor>'", moptReadFD)
- return -1, syserror.EINVAL
+ return -1, linuxerr.EINVAL
}
delete(mopts, moptReadFD)
rfd, err := strconv.Atoi(rfdstr)
if err != nil {
ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid read FD: %s=%s", moptReadFD, rfdstr)
- return -1, syserror.EINVAL
+ return -1, linuxerr.EINVAL
}
wfdstr, ok := mopts[moptWriteFD]
if !ok {
ctx.Warningf("gofer.getFDFromMountOptionsMap: write FD must be specified as '%s=<file descriptor>'", moptWriteFD)
- return -1, syserror.EINVAL
+ return -1, linuxerr.EINVAL
}
delete(mopts, moptWriteFD)
wfd, err := strconv.Atoi(wfdstr)
if err != nil {
ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid write FD: %s=%s", moptWriteFD, wfdstr)
- return -1, syserror.EINVAL
+ return -1, linuxerr.EINVAL
}
if rfd != wfd {
ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD (%d) and write FD (%d) must be equal", rfd, wfd)
- return -1, syserror.EINVAL
+ return -1, linuxerr.EINVAL
}
return rfd, nil
}
@@ -1110,7 +1111,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
case linux.S_IFDIR:
return syserror.EISDIR
default:
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
}
@@ -1282,9 +1283,12 @@ func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes)
}
func (d *dentry) checkXattrPermissions(creds *auth.Credentials, name string, ats vfs.AccessTypes) error {
- // We only support xattrs prefixed with "user." (see b/148380782). Currently,
- // there is no need to expose any other xattrs through a gofer.
- if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ // Deny access to the "security" and "system" namespaces since applications
+ // may expect these to affect kernel behavior in unimplemented ways
+ // (b/148380782). Allow all other extended attributes to be passed through
+ // to the remote filesystem. This is inconsistent with Linux's 9p client,
+ // but consistent with other filesystems (e.g. FUSE).
+ if strings.HasPrefix(name, linux.XATTR_SECURITY_PREFIX) || strings.HasPrefix(name, linux.XATTR_SYSTEM_PREFIX) {
return syserror.EOPNOTSUPP
}
mode := linux.FileMode(atomic.LoadUint32(&d.mode))
@@ -1684,7 +1688,7 @@ func (d *dentry) setDeleted() {
}
func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) {
- if d.file.isNil() || !d.userXattrSupported() {
+ if d.file.isNil() {
return nil, nil
}
xattrMap, err := d.file.listXattr(ctx, size)
@@ -1693,10 +1697,7 @@ func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size ui
}
xattrs := make([]string, 0, len(xattrMap))
for x := range xattrMap {
- // We only support xattrs in the user.* namespace.
- if strings.HasPrefix(x, linux.XATTR_USER_PREFIX) {
- xattrs = append(xattrs, x)
- }
+ xattrs = append(xattrs, x)
}
return xattrs, nil
}
@@ -1731,13 +1732,6 @@ func (d *dentry) removeXattr(ctx context.Context, creds *auth.Credentials, name
return d.file.removeXattr(ctx, name)
}
-// Extended attributes in the user.* namespace are only supported for regular
-// files and directories.
-func (d *dentry) userXattrSupported() bool {
- filetype := linux.FileMode(atomic.LoadUint32(&d.mode)).FileType()
- return filetype == linux.ModeRegular || filetype == linux.ModeDirectory
-}
-
// Preconditions:
// * !d.isSynthetic().
// * d.isRegularFile() || d.isDir().
@@ -1770,7 +1764,7 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
openReadable := !d.readFile.isNil() || read
openWritable := !d.writeFile.isNil() || write
h, err := openHandle(ctx, d.file, openReadable, openWritable, trunc)
- if err == syserror.EACCES && (openReadable != read || openWritable != write) {
+ if linuxerr.Equals(linuxerr.EACCES, err) && (openReadable != read || openWritable != write) {
// It may not be possible to use a single handle for both
// reading and writing, since permissions on the file may have
// changed to e.g. disallow reading after previously being
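Editor's note: the checkXattrPermissions change earlier in this file flips the xattr policy from an allowlist of user.* to a denylist of security.* and system.*, letting all other namespaces pass through to the remote filesystem. A sketch of the new predicate; the helper name is hypothetical.

package sketch

import (
	"strings"

	"gvisor.dev/gvisor/pkg/abi/linux"
)

// xattrPassedThrough reports whether an extended attribute name is forwarded
// to the gofer under the new policy: everything except the "security." and
// "system." namespaces.
func xattrPassedThrough(name string) bool {
	return !strings.HasPrefix(name, linux.XATTR_SECURITY_PREFIX) &&
		!strings.HasPrefix(name, linux.XATTR_SYSTEM_PREFIX)
}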
diff --git a/pkg/sentry/fsimpl/gofer/host_named_pipe.go b/pkg/sentry/fsimpl/gofer/host_named_pipe.go
index c7bf10007..398288ee3 100644
--- a/pkg/sentry/fsimpl/gofer/host_named_pipe.go
+++ b/pkg/sentry/fsimpl/gofer/host_named_pipe.go
@@ -21,6 +21,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -78,7 +79,7 @@ func nonblockingPipeHasWriter(fd int32) (bool, error) {
defer tempPipeMu.Unlock()
// Copy 1 byte from fd into the temporary pipe.
n, err := unix.Tee(int(fd), tempPipeWriteFD, 1, unix.SPLICE_F_NONBLOCK)
- if err == syserror.EAGAIN {
+ if linuxerr.Equals(linuxerr.EAGAIN, err) {
// The pipe represented by fd is empty, but has a writer.
return true, nil
}
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 0a954c138..89eab04cd 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/metric"
@@ -60,7 +61,6 @@ func newRegularFileFD(mnt *vfs.Mount, d *dentry, flags uint32) (*regularFileFD,
return nil, err
}
if fd.vfsfd.IsWritable() && (atomic.LoadUint32(&d.mode)&0111 != 0) {
- fsmetric.GoferOpensWX.Increment()
metric.SuspiciousOperationsMetric.Increment("opened_write_execute_file")
}
if atomic.LoadInt32(&d.mmapFD) >= 0 {
@@ -125,7 +125,7 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
}()
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Check that flags are supported.
@@ -195,7 +195,7 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
// offset should be ignored by PWrite.
func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) {
if offset < 0 {
- return 0, offset, syserror.EINVAL
+ return 0, offset, linuxerr.EINVAL
}
// Check that flags are supported.
@@ -298,7 +298,7 @@ func (fd *regularFileFD) writeCache(ctx context.Context, d *dentry, offset int64
pgstart := hostarch.PageRoundDown(uint64(offset))
pgend, ok := hostarch.PageRoundUp(uint64(offset + src.NumBytes()))
if !ok {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
mr := memmap.MappableRange{pgstart, pgend}
var freed []memmap.FileRange
@@ -663,10 +663,10 @@ func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int6
offset = size
}
default:
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
return offset, nil
}
@@ -679,28 +679,28 @@ func (fd *regularFileFD) Sync(ctx context.Context) error {
// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
d := fd.dentry()
- switch d.fs.opts.interop {
- case InteropModeExclusive:
- // Any mapping is fine.
- case InteropModeWritethrough:
- // Shared writable mappings require a host FD, since otherwise we can't
- // synchronously flush memory-mapped writes to the remote file.
- if opts.Private || !opts.MaxPerms.Write {
- break
- }
- fallthrough
- case InteropModeShared:
- // All mappings require a host FD to be coherent with other filesystem
- // users.
- if d.fs.opts.forcePageCache {
- // Whether or not we have a host FD, we're not allowed to use it.
- return syserror.ENODEV
- }
- if atomic.LoadInt32(&d.mmapFD) < 0 {
- return syserror.ENODEV
+ // Force sentry page caching at your own risk.
+ if !d.fs.opts.forcePageCache {
+ switch d.fs.opts.interop {
+ case InteropModeExclusive:
+ // Any mapping is fine.
+ case InteropModeWritethrough:
+ // Shared writable mappings require a host FD, since otherwise we
+ // can't synchronously flush memory-mapped writes to the remote
+ // file.
+ if opts.Private || !opts.MaxPerms.Write {
+ break
+ }
+ fallthrough
+ case InteropModeShared:
+ // All mappings require a host FD to be coherent with other
+ // filesystem users.
+ if atomic.LoadInt32(&d.mmapFD) < 0 {
+ return syserror.ENODEV
+ }
+ default:
+ panic(fmt.Sprintf("unknown InteropMode %v", d.fs.opts.interop))
}
- default:
- panic(fmt.Sprintf("unknown InteropMode %v", d.fs.opts.interop))
}
// After this point, d may be used as a memmap.Mappable.
d.pf.hostFileMapperInitOnce.Do(d.pf.hostFileMapper.Init)
@@ -709,12 +709,12 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt
}
func (d *dentry) mayCachePages() bool {
- if d.fs.opts.interop == InteropModeShared {
- return false
- }
if d.fs.opts.forcePageCache {
return true
}
+ if d.fs.opts.interop == InteropModeShared {
+ return false
+ }
return atomic.LoadInt32(&d.mmapFD) >= 0
}
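Editor's note: after the reordering above, forcePageCache takes precedence over the interop mode when deciding whether a dentry's pages may be cached. A hedged sketch of the resulting decision order, with plain parameters standing in for the dentry and filesystem fields.

package sketch

// mayCachePages mirrors the reordered logic above: forced sentry page
// caching always caches, shared interop mode never caches, and otherwise
// caching requires a host FD that can be memory-mapped.
func mayCachePages(forcePageCache, sharedInterop bool, mmapFD int32) bool {
	if forcePageCache {
		return true
	}
	if sharedInterop {
		return false
	}
	return mmapFD >= 0
}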
diff --git a/pkg/sentry/fsimpl/gofer/save_restore.go b/pkg/sentry/fsimpl/gofer/save_restore.go
index 83e841a51..e67422a2f 100644
--- a/pkg/sentry/fsimpl/gofer/save_restore.go
+++ b/pkg/sentry/fsimpl/gofer/save_restore.go
@@ -21,13 +21,13 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
type saveRestoreContextID int
@@ -92,7 +92,7 @@ func (fd *specialFileFD) savePipeData(ctx context.Context) error {
fd.buf = append(fd.buf, buf[:n]...)
}
if err != nil {
- if err == io.EOF || err == syserror.EAGAIN {
+ if err == io.EOF || linuxerr.Equals(linuxerr.EAGAIN, err) {
break
}
return err
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
index dc019ebd5..2a922d120 100644
--- a/pkg/sentry/fsimpl/gofer/special_file.go
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -20,6 +20,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/p9"
@@ -101,7 +102,6 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, flags uint32) (*speci
d.fs.specialFileFDs[fd] = struct{}{}
d.fs.syncMu.Unlock()
if fd.vfsfd.IsWritable() && (atomic.LoadUint32(&d.mode)&0111 != 0) {
- fsmetric.GoferOpensWX.Increment()
metric.SuspiciousOperationsMetric.Increment("opened_write_execute_file")
}
if h.fd >= 0 {
@@ -184,7 +184,7 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
}()
if fd.seekable && offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Check that flags are supported.
@@ -229,7 +229,7 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
// Just buffer the read instead.
buf := make([]byte, dst.NumBytes())
n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset))
- if err == syserror.EAGAIN {
+ if linuxerr.Equals(linuxerr.EAGAIN, err) {
err = syserror.ErrWouldBlock
}
if n == 0 {
@@ -264,7 +264,7 @@ func (fd *specialFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
// offset should be ignored by PWrite.
func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) {
if fd.seekable && offset < 0 {
- return 0, offset, syserror.EINVAL
+ return 0, offset, linuxerr.EINVAL
}
// Check that flags are supported.
@@ -317,7 +317,7 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
return 0, offset, copyErr
}
n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:copied])), uint64(offset))
- if err == syserror.EAGAIN {
+ if linuxerr.Equals(linuxerr.EAGAIN, err) {
err = syserror.ErrWouldBlock
}
// Update offset if the offset is valid.
diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD
index b94dfeb7f..476545d00 100644
--- a/pkg/sentry/fsimpl/host/BUILD
+++ b/pkg/sentry/fsimpl/host/BUILD
@@ -45,10 +45,10 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fdnotifier",
"//pkg/fspath",
"//pkg/hostarch",
- "//pkg/iovec",
"//pkg/log",
"//pkg/marshal/primitive",
"//pkg/refs",
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index a81f550b1..4d2b282a0 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -24,6 +24,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/hostarch"
@@ -109,7 +110,7 @@ type inode struct {
func newInode(ctx context.Context, fs *filesystem, hostFD int, savable bool, fileType linux.FileMode, isTTY bool) (*inode, error) {
// Determine if hostFD is seekable.
_, err := unix.Seek(hostFD, 0, linux.SEEK_CUR)
- seekable := err != syserror.ESPIPE
+ seekable := !linuxerr.Equals(linuxerr.ESPIPE, err)
// We expect regular files to be seekable, as this is required for them to
// be memory-mappable.
if !seekable && fileType == unix.S_IFREG {
@@ -289,10 +290,10 @@ func (i *inode) Mode() linux.FileMode {
// Stat implements kernfs.Inode.Stat.
func (i *inode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
if opts.Mask&linux.STATX__RESERVED != 0 {
- return linux.Statx{}, syserror.EINVAL
+ return linux.Statx{}, linuxerr.EINVAL
}
if opts.Sync&linux.AT_STATX_SYNC_TYPE == linux.AT_STATX_SYNC_TYPE {
- return linux.Statx{}, syserror.EINVAL
+ return linux.Statx{}, linuxerr.EINVAL
}
fs := vfsfs.Impl().(*filesystem)
@@ -301,7 +302,7 @@ func (i *inode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOp
mask := opts.Mask & linux.STATX_ALL
var s unix.Statx_t
err := unix.Statx(i.hostFD, "", int(unix.AT_EMPTY_PATH|opts.Sync), int(mask), &s)
- if err == syserror.ENOSYS {
+ if linuxerr.Equals(linuxerr.ENOSYS, err) {
// Fallback to fstat(2), if statx(2) is not supported on the host.
//
// TODO(b/151263641): Remove fallback.
@@ -425,7 +426,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
}
if m&linux.STATX_SIZE != 0 {
if hostStat.Mode&linux.S_IFMT != linux.S_IFREG {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if err := unix.Ftruncate(i.hostFD, int64(s.Size)); err != nil {
return err
@@ -730,7 +731,7 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i
switch whence {
case linux.SEEK_SET:
if offset < 0 {
- return f.offset, syserror.EINVAL
+ return f.offset, linuxerr.EINVAL
}
f.offset = offset
@@ -740,7 +741,7 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i
return f.offset, syserror.EOVERFLOW
}
if f.offset+offset < 0 {
- return f.offset, syserror.EINVAL
+ return f.offset, linuxerr.EINVAL
}
f.offset += offset
@@ -756,7 +757,7 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i
return f.offset, syserror.EOVERFLOW
}
if size+offset < 0 {
- return f.offset, syserror.EINVAL
+ return f.offset, linuxerr.EINVAL
}
f.offset = size + offset
@@ -773,7 +774,7 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i
default:
// Invalid whence.
- return f.offset, syserror.EINVAL
+ return f.offset, linuxerr.EINVAL
}
return f.offset, nil
diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go
index ca85f5601..8cce36212 100644
--- a/pkg/sentry/fsimpl/host/socket.go
+++ b/pkg/sentry/fsimpl/host/socket.go
@@ -21,6 +21,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/socket/control"
@@ -160,7 +161,7 @@ func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMess
// block (and only for stream sockets).
err = syserror.EAGAIN
}
- if n > 0 && err != syserror.EAGAIN {
+ if n > 0 && !linuxerr.Equals(linuxerr.EAGAIN, err) {
// The caller may need to block to send more data, but
// otherwise there isn't anything that can be done about an
// error with a partial write.
diff --git a/pkg/sentry/fsimpl/host/socket_iovec.go b/pkg/sentry/fsimpl/host/socket_iovec.go
index b123a63ee..e090bb725 100644
--- a/pkg/sentry/fsimpl/host/socket_iovec.go
+++ b/pkg/sentry/fsimpl/host/socket_iovec.go
@@ -16,7 +16,7 @@ package host
import (
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/iovec"
+ "gvisor.dev/gvisor/pkg/sentry/hostfd"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -70,7 +70,7 @@ func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovec
}
}
- if iovsRequired > iovec.MaxIovs {
+ if iovsRequired > hostfd.MaxSendRecvMsgIov {
// The kernel will reject our call if we pass this many iovs.
// Use a single intermediate buffer instead.
b := make([]byte, stopLen)
diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go
index 0f9e20a84..c7bf563f0 100644
--- a/pkg/sentry/fsimpl/host/tty.go
+++ b/pkg/sentry/fsimpl/host/tty.go
@@ -17,6 +17,7 @@ package host
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -211,7 +212,7 @@ func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch
if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
// drivers/tty/tty_io.c:tiocspgrp() converts -EIO from tty_check_change()
// to -ENOTTY.
- if err == syserror.EIO {
+ if linuxerr.Equals(linuxerr.EIO, err) {
return 0, syserror.ENOTTY
}
return 0, err
@@ -230,7 +231,7 @@ func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch
// pgID must be non-negative.
if pgID < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Process group with pgID must exist in this PID namespace.
diff --git a/pkg/sentry/fsimpl/host/util.go b/pkg/sentry/fsimpl/host/util.go
index 63b465859..95d7ebe2e 100644
--- a/pkg/sentry/fsimpl/host/util.go
+++ b/pkg/sentry/fsimpl/host/util.go
@@ -17,7 +17,7 @@ package host
import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
)
func toTimespec(ts linux.StatxTimestamp, omit bool) unix.Timespec {
@@ -44,5 +44,5 @@ func timespecToStatxTimestamp(ts unix.Timespec) linux.StatxTimestamp {
// isBlockError checks if an error is EAGAIN or EWOULDBLOCK.
// If so, they can be transformed into syserror.ErrWouldBlock.
func isBlockError(err error) bool {
- return err == syserror.EAGAIN || err == syserror.EWOULDBLOCK
+ return linuxerr.Equals(linuxerr.EAGAIN, err) || linuxerr.Equals(linuxerr.EWOULDBLOCK, err)
}
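Editor's note: isBlockError above is the smallest example of the idiom this whole change migrates to: instead of comparing err directly against syserror values, code matches errors with linuxerr.Equals against linuxerr sentinels, which also covers errnos returned by host syscalls (as in the unix.Seek and unix.Tee call sites earlier in the diff). A standalone sketch of the pattern; the function name is hypothetical.

package sketch

import "gvisor.dev/gvisor/pkg/errors/linuxerr"

// classifyReadError shows the comparison idiom used throughout this change:
// linuxerr.Equals matches a linuxerr sentinel against an error from either
// the sentry or a host syscall, replacing == comparisons against syserror.
func classifyReadError(err error) string {
	switch {
	case err == nil:
		return "ok"
	case linuxerr.Equals(linuxerr.EAGAIN, err) || linuxerr.Equals(linuxerr.EWOULDBLOCK, err):
		return "would block"
	case linuxerr.Equals(linuxerr.ENOENT, err):
		return "missing"
	default:
		return "other error"
	}
}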
diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD
index b7d13cced..d53937db6 100644
--- a/pkg/sentry/fsimpl/kernfs/BUILD
+++ b/pkg/sentry/fsimpl/kernfs/BUILD
@@ -104,6 +104,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fspath",
"//pkg/hostarch",
"//pkg/log",
@@ -135,6 +136,7 @@ go_test(
":kernfs",
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/log",
"//pkg/refs",
"//pkg/refsvfs2",
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index e55111af0..8b008dc10 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -248,10 +249,10 @@ func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int
panic(fmt.Sprintf("Invalid GenericDirectoryFD.seekEnd = %v", fd.seekEnd))
}
default:
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
fd.off = offset
return offset, nil
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index f50b0fb08..1a314f59e 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
@@ -411,7 +412,7 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
defer rp.Mount().EndWrite()
childI, err := parent.inode.NewDir(ctx, pc, opts)
if err != nil {
- if !opts.ForSyntheticMountpoint || err == syserror.EEXIST {
+ if !opts.ForSyntheticMountpoint || linuxerr.Equals(linuxerr.EEXIST, err) {
return err
}
childI = newSyntheticDirectory(ctx, rp.Credentials(), opts.Mode)
@@ -546,7 +547,7 @@ afterTrailingSymlink:
}
// Determine whether or not we need to create a file.
child, err := fs.stepExistingLocked(ctx, rp, parent, false /* mayFollowSymlinks */)
- if err == syserror.ENOENT {
+ if linuxerr.Equals(linuxerr.ENOENT, err) {
// Already checked for searchability above; now check for writability.
if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil {
return nil, err
@@ -622,7 +623,7 @@ func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st
}
if !d.isSymlink() {
fs.mu.RUnlock()
- return "", syserror.EINVAL
+ return "", linuxerr.EINVAL
}
// Inode.Readlink() cannot be called holding fs locks.
@@ -635,12 +636,6 @@ func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st
// RenameAt implements vfs.FilesystemImpl.RenameAt.
func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
- // Only RENAME_NOREPLACE is supported.
- if opts.Flags&^linux.RENAME_NOREPLACE != 0 {
- return syserror.EINVAL
- }
- noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0
-
fs.mu.Lock()
defer fs.processDeferredDecRefs(ctx)
defer fs.mu.Unlock()
@@ -651,6 +646,13 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if err != nil {
return err
}
+
+ // Only RENAME_NOREPLACE is supported.
+ if opts.Flags&^linux.RENAME_NOREPLACE != 0 {
+ return linuxerr.EINVAL
+ }
+ noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0
+
mnt := rp.Mount()
if mnt != oldParentVD.Mount() {
return syserror.EXDEV
@@ -683,10 +685,12 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
return syserror.EBUSY
}
- switch err := checkCreateLocked(ctx, rp.Credentials(), newName, dstDir); err {
- case nil:
+
+ err = checkCreateLocked(ctx, rp.Credentials(), newName, dstDir)
+ switch {
+ case err == nil:
// Ok, continue with rename as replacement.
- case syserror.EEXIST:
+ case linuxerr.Equals(linuxerr.EEXIST, err):
if noReplace {
// Won't overwrite existing node since RENAME_NOREPLACE was requested.
return syserror.EEXIST
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index 3d0866ecf..62872946e 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -158,12 +159,12 @@ type InodeNotSymlink struct{}
// Readlink implements Inode.Readlink.
func (InodeNotSymlink) Readlink(context.Context, *vfs.Mount) (string, error) {
- return "", syserror.EINVAL
+ return "", linuxerr.EINVAL
}
// Getlink implements Inode.Getlink.
func (InodeNotSymlink) Getlink(context.Context, *vfs.Mount) (vfs.VirtualDentry, string, error) {
- return vfs.VirtualDentry{}, "", syserror.EINVAL
+ return vfs.VirtualDentry{}, "", linuxerr.EINVAL
}
// InodeAttrs partially implements the Inode interface, specifically the
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index 1cd3137e6..de046ce1f 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -22,6 +22,7 @@ import (
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
@@ -318,10 +319,10 @@ func TestDirFDReadWrite(t *testing.T) {
defer fd.DecRef(sys.Ctx)
// Read/Write should fail for directory FDs.
- if _, err := fd.Read(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.ReadOptions{}); err != syserror.EISDIR {
+ if _, err := fd.Read(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.ReadOptions{}); !linuxerr.Equals(linuxerr.EISDIR, err) {
t.Fatalf("Read for directory FD failed with unexpected error: %v", err)
}
- if _, err := fd.Write(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.WriteOptions{}); err != syserror.EBADF {
+ if _, err := fd.Write(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.WriteOptions{}); !linuxerr.Equals(linuxerr.EBADF, err) {
t.Fatalf("Write for directory FD failed with unexpected error: %v", err)
}
}
diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD
index 5504476c8..ed730e215 100644
--- a/pkg/sentry/fsimpl/overlay/BUILD
+++ b/pkg/sentry/fsimpl/overlay/BUILD
@@ -29,6 +29,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fspath",
"//pkg/hostarch",
"//pkg/log",
diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go
index 45aa5a494..8fd51e9d0 100644
--- a/pkg/sentry/fsimpl/overlay/copy_up.go
+++ b/pkg/sentry/fsimpl/overlay/copy_up.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -349,7 +350,7 @@ func (d *dentry) copyXattrsLocked(ctx context.Context) error {
lowerXattrs, err := vfsObj.ListXattrAt(ctx, d.fs.creds, lowerPop, 0)
if err != nil {
- if err == syserror.EOPNOTSUPP {
+ if linuxerr.Equals(linuxerr.EOPNOTSUPP, err) {
// There are no guarantees as to the contents of lowerXattrs.
return nil
}
diff --git a/pkg/sentry/fsimpl/overlay/directory.go b/pkg/sentry/fsimpl/overlay/directory.go
index df4492346..417a7c630 100644
--- a/pkg/sentry/fsimpl/overlay/directory.go
+++ b/pkg/sentry/fsimpl/overlay/directory.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
@@ -256,7 +257,7 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in
switch whence {
case linux.SEEK_SET:
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if offset == 0 {
// Ensure that the next call to fd.IterDirents() calls
@@ -268,13 +269,13 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in
case linux.SEEK_CUR:
offset += fd.off
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Don't clear fd.dirents in this case, even if offset == 0.
fd.off = offset
return fd.off, nil
default:
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
}
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index 46c500427..e792677f5 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -218,7 +219,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
Start: parentVD,
Path: childPath,
}, &vfs.GetDentryOptions{})
- if err == syserror.ENOENT || err == syserror.ENAMETOOLONG {
+ if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.ENAMETOOLONG, err) {
// The file doesn't exist on this layer. Proceed to the next one.
return true
}
@@ -352,7 +353,7 @@ func (fs *filesystem) lookupLayerLocked(ctx context.Context, parent *dentry, nam
}, &vfs.StatOptions{
Mask: linux.STATX_TYPE,
})
- if err == syserror.ENOENT || err == syserror.ENAMETOOLONG {
+ if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.ENAMETOOLONG, err) {
// The file doesn't exist on this layer. Proceed to the next
// one.
return true
@@ -811,7 +812,7 @@ afterTrailingSymlink:
// Determine whether or not we need to create a file.
parent.dirMu.Lock()
child, topLookupLayer, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
- if err == syserror.ENOENT && mayCreate {
+ if linuxerr.Equals(linuxerr.ENOENT, err) && mayCreate {
fd, err := fs.createAndOpenLocked(ctx, rp, parent, &opts, &ds, topLookupLayer == lookupLayerUpperWhiteout)
parent.dirMu.Unlock()
return fd, err
@@ -871,7 +872,7 @@ func (d *dentry) openCopiedUp(ctx context.Context, rp *vfs.ResolvingPath, opts *
return nil, syserror.EISDIR
}
if opts.Flags&linux.O_DIRECT != 0 {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
fd := &directoryFD{}
fd.LockFD.Init(&d.locks)
@@ -1017,10 +1018,7 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st
// RenameAt implements vfs.FilesystemImpl.RenameAt.
func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
- if opts.Flags != 0 {
- return syserror.EINVAL
- }
-
+ // Resolve newParent first to verify that it's on this Mount.
var ds *[]*dentry
fs.renameMu.Lock()
defer fs.renameMuUnlockAndCheckDrop(ctx, &ds)
@@ -1028,8 +1026,16 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if err != nil {
return err
}
+
+ if opts.Flags&^linux.RENAME_NOREPLACE != 0 {
+ return linuxerr.EINVAL
+ }
+
newName := rp.Component()
if newName == "." || newName == ".." {
+ if opts.Flags&linux.RENAME_NOREPLACE != 0 {
+ return syserror.EEXIST
+ }
return syserror.EBUSY
}
mnt := rp.Mount()
@@ -1059,7 +1065,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
if renamed.isDir() {
if renamed == newParent || genericIsAncestorDentry(renamed, newParent) {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if oldParent != newParent {
if err := renamed.checkPermissions(creds, vfs.MayWrite); err != nil {
@@ -1089,10 +1095,13 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
whiteouts map[string]bool
)
replaced, replacedLayer, err = fs.getChildLocked(ctx, newParent, newName, &ds)
- if err != nil && err != syserror.ENOENT {
+ if err != nil && !linuxerr.Equals(linuxerr.ENOENT, err) {
return err
}
if replaced != nil {
+ if opts.Flags&linux.RENAME_NOREPLACE != 0 {
+ return syserror.EEXIST
+ }
replacedVFSD = &replaced.vfsd
if replaced.isDir() {
if !renamed.isDir() {
@@ -1169,7 +1178,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
Root: replaced.upperVD,
Start: replaced.upperVD,
Path: fspath.Parse(whiteoutName),
- }); err != nil && err != syserror.EEXIST {
+ }); err != nil && !linuxerr.Equals(linuxerr.EEXIST, err) {
panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to recreate deleted whiteout after RenameAt failure: %v", err))
}
}
@@ -1277,7 +1286,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
defer rp.Mount().EndWrite()
name := rp.Component()
if name == "." {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if name == ".." {
return syserror.ENOTEMPTY
@@ -1336,7 +1345,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
Root: child.upperVD,
Start: child.upperVD,
Path: fspath.Parse(whiteoutName),
- }); err != nil && err != syserror.EEXIST {
+ }); err != nil && !linuxerr.Equals(linuxerr.EEXIST, err) {
panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to recreate deleted whiteout after RmdirAt failure: %v", err))
}
}
diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go
index 454c20d4f..4c7243764 100644
--- a/pkg/sentry/fsimpl/overlay/overlay.go
+++ b/pkg/sentry/fsimpl/overlay/overlay.go
@@ -40,6 +40,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -135,7 +136,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
fsopts, ok := fsoptsRaw.(FilesystemOptions)
if fsoptsRaw != nil && !ok {
ctx.Infof("overlay.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted overlay.FilesystemOptions or nil", fsoptsRaw)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
vfsroot := vfs.RootFromContext(ctx)
if vfsroot.Ok() {
@@ -145,7 +146,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
if upperPathname, ok := mopts["upperdir"]; ok {
if fsopts.UpperRoot.Ok() {
ctx.Infof("overlay.FilesystemType.GetFilesystem: both upperdir and FilesystemOptions.UpperRoot are specified")
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
delete(mopts, "upperdir")
// Linux overlayfs also requires a workdir when upperdir is
@@ -154,7 +155,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
upperPath := fspath.Parse(upperPathname)
if !upperPath.Absolute {
ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{
Root: vfsroot,
@@ -181,7 +182,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
if lowerPathnamesStr, ok := mopts["lowerdir"]; ok {
if len(fsopts.LowerRoots) != 0 {
ctx.Infof("overlay.FilesystemType.GetFilesystem: both lowerdir and FilesystemOptions.LowerRoots are specified")
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
delete(mopts, "lowerdir")
lowerPathnames := strings.Split(lowerPathnamesStr, ":")
@@ -189,7 +190,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
lowerPath := fspath.Parse(lowerPathname)
if !lowerPath.Absolute {
ctx.Infof("overlay.FilesystemType.GetFilesystem: lowerdir %q must be absolute", lowerPathname)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
lowerRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{
Root: vfsroot,
@@ -216,21 +217,21 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
if len(mopts) != 0 {
ctx.Infof("overlay.FilesystemType.GetFilesystem: unused options: %v", mopts)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
if len(fsopts.LowerRoots) == 0 {
ctx.Infof("overlay.FilesystemType.GetFilesystem: at least one lower layer is required")
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() {
ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lower layers are required when no upper layer is present")
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
const maxLowerLayers = 500 // Linux: fs/overlayfs/super.c:OVL_MAX_STACK
if len(fsopts.LowerRoots) > maxLowerLayers {
ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lower layers specified, maximum %d", len(fsopts.LowerRoots), maxLowerLayers)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
// Take extra references held by the filesystem.
@@ -283,7 +284,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
ctx.Infof("overlay.FilesystemType.GetFilesystem: filesystem root is a whiteout")
root.destroyLocked(ctx)
fs.vfsfs.DecRef(ctx)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
root.mode = uint32(rootStat.Mode)
root.uid = rootStat.UID
diff --git a/pkg/sentry/fsimpl/overlay/regular_file.go b/pkg/sentry/fsimpl/overlay/regular_file.go
index 43bfd69a3..82491a0f8 100644
--- a/pkg/sentry/fsimpl/overlay/regular_file.go
+++ b/pkg/sentry/fsimpl/overlay/regular_file.go
@@ -207,9 +207,10 @@ func (fd *regularFileFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) e
return err
}
- // Changing owners may clear one or both of the setuid and setgid bits,
- // so we may have to update opts before setting d.mode.
- if opts.Stat.Mask&(linux.STATX_UID|linux.STATX_GID) != 0 {
+ // Changing owners or truncating may clear one or both of the setuid and
+ // setgid bits, so we may have to update opts before setting d.mode.
+ inotifyMask := opts.Stat.Mask
+ if opts.Stat.Mask&(linux.STATX_UID|linux.STATX_GID|linux.STATX_SIZE) != 0 {
stat, err := wrappedFD.Stat(ctx, vfs.StatOptions{
Mask: linux.STATX_MODE,
})
@@ -218,10 +219,14 @@ func (fd *regularFileFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) e
}
opts.Stat.Mode = stat.Mode
opts.Stat.Mask |= linux.STATX_MODE
+ // Don't generate inotify IN_ATTRIB for size-only changes (truncations).
+ if opts.Stat.Mask&(linux.STATX_UID|linux.STATX_GID) != 0 {
+ inotifyMask |= linux.STATX_MODE
+ }
}
d.updateAfterSetStatLocked(&opts)
- if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
+ if ev := vfs.InotifyEventFromStatMask(inotifyMask); ev != 0 {
d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent)
}
return nil
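Editor's note: the SetStat change above fetches the file mode whenever ownership or size changes (since chown and truncate can clear the setuid/setgid bits), but adds STATX_MODE to the inotify mask only for ownership changes, so size-only truncations do not emit a spurious IN_ATTRIB. A sketch of the mask computation; the helper name is hypothetical.

package sketch

import "gvisor.dev/gvisor/pkg/abi/linux"

// inotifyMaskForSetStat mirrors the mask handling above: the mode bit is
// reported to inotify only when ownership changed, not for size-only updates.
func inotifyMaskForSetStat(statMask uint32) uint32 {
	inotifyMask := statMask
	if statMask&(linux.STATX_UID|linux.STATX_GID) != 0 {
		inotifyMask |= linux.STATX_MODE
	}
	return inotifyMask
}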
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index 2b628bd55..1d3d2d95f 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -81,6 +81,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/log",
"//pkg/refs",
@@ -119,6 +120,7 @@ go_test(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fspath",
"//pkg/sentry/contexttest",
"//pkg/sentry/fsimpl/testutil",
@@ -127,7 +129,6 @@ go_test(
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
- "//pkg/syserror",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
index ce8f55b1f..f2697c12d 100644
--- a/pkg/sentry/fsimpl/proc/filesystem.go
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -21,11 +21,11 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
const (
@@ -76,7 +76,7 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF
maxCachedDentries, err = strconv.ParseUint(str, 10, 64)
if err != nil {
ctx.Warningf("proc.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
}
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index b294dfd6a..9187f5b11 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
@@ -325,7 +326,7 @@ func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset in
// the file ..." - user_namespaces(7)
srclen := src.NumBytes()
if srclen >= hostarch.PageSize || offset != 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
b := make([]byte, srclen)
if _, err := src.CopyIn(ctx, b); err != nil {
@@ -345,7 +346,7 @@ func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset in
}
lines := bytes.SplitN(b, []byte("\n"), maxIDMapLines+1)
if len(lines) > maxIDMapLines {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
entries := make([]auth.IDMapEntry, len(lines))
@@ -353,7 +354,7 @@ func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset in
var e auth.IDMapEntry
_, err := fmt.Sscan(string(l), &e.FirstID, &e.FirstParentID, &e.Length)
if err != nil {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
entries[i] = e
}
@@ -461,10 +462,10 @@ func (fd *memFD) Seek(ctx context.Context, offset int64, whence int32) (int64, e
case linux.SEEK_CUR:
offset += fd.offset
default:
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
fd.offset = offset
return offset, nil
diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go
index 177cb828f..ab47ea5a7 100644
--- a/pkg/sentry/fsimpl/proc/task_net.go
+++ b/pkg/sentry/fsimpl/proc/task_net.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
@@ -33,7 +34,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/unix"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -679,7 +679,7 @@ func (d *netSnmpData) Generate(ctx context.Context, buf *bytes.Buffer) error {
continue
}
if err := d.stack.Statistics(stat, line.prefix); err != nil {
- if err == syserror.EOPNOTSUPP {
+ if linuxerr.Equals(linuxerr.EOPNOTSUPP, err) {
log.Infof("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err)
} else {
log.Warningf("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err)
diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go
index 045ed7a2d..2def1ca48 100644
--- a/pkg/sentry/fsimpl/proc/tasks_files.go
+++ b/pkg/sentry/fsimpl/proc/tasks_files.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -53,7 +54,7 @@ func (s *selfSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error
t := kernel.TaskFromContext(ctx)
if t == nil {
// Who is reading this link?
- return "", syserror.EINVAL
+ return "", linuxerr.EINVAL
}
tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup())
if tgid == 0 {
@@ -94,7 +95,7 @@ func (s *threadSelfSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string,
t := kernel.TaskFromContext(ctx)
if t == nil {
// Who is reading this link?
- return "", syserror.EINVAL
+ return "", linuxerr.EINVAL
}
tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup())
tid := s.pidns.IDOfTask(t)
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index 88ab49048..99f64a9d8 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/inet"
@@ -28,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -55,6 +55,7 @@ func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k *
}),
}),
"vm": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
+ "max_map_count": fs.newInode(ctx, root, 0444, newStaticFile("2147483647\n")),
"mmap_min_addr": fs.newInode(ctx, root, 0444, &mmapMinAddrData{k: k}),
"overcommit_memory": fs.newInode(ctx, root, 0444, newStaticFile("0\n")),
}),
@@ -208,7 +209,7 @@ func (d *tcpSackData) Generate(ctx context.Context, buf *bytes.Buffer) error {
func (d *tcpSackData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
if offset != 0 {
// No need to handle partial writes thus far.
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if src.NumBytes() == 0 {
return 0, nil
@@ -256,7 +257,7 @@ func (d *tcpRecoveryData) Generate(ctx context.Context, buf *bytes.Buffer) error
func (d *tcpRecoveryData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
if offset != 0 {
// No need to handle partial writes thus far.
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if src.NumBytes() == 0 {
return 0, nil
@@ -310,7 +311,7 @@ func (d *tcpMemData) Generate(ctx context.Context, buf *bytes.Buffer) error {
func (d *tcpMemData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
if offset != 0 {
// No need to handle partial writes thus far.
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if src.NumBytes() == 0 {
return 0, nil
@@ -395,7 +396,7 @@ func (ipf *ipForwarding) Generate(ctx context.Context, buf *bytes.Buffer) error
func (ipf *ipForwarding) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
if offset != 0 {
// No need to handle partial writes thus far.
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if src.NumBytes() == 0 {
return 0, nil
@@ -448,7 +449,7 @@ func (pr *portRange) Generate(ctx context.Context, buf *bytes.Buffer) error {
func (pr *portRange) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
if offset != 0 {
// No need to handle partial writes thus far.
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if src.NumBytes() == 0 {
return 0, nil
@@ -466,7 +467,7 @@ func (pr *portRange) Write(ctx context.Context, src usermem.IOSequence, offset i
// Port numbers must be uint16s.
if ports[0] < 0 || ports[1] < 0 || ports[0] > math.MaxUint16 || ports[1] > math.MaxUint16 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if err := pr.stack.SetPortRange(uint16(ports[0]), uint16(ports[1])); err != nil {
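Editor's note: the portRange write handler above rejects values outside the uint16 range before handing them to the netstack. A standalone sketch of that bounds check; the helper name is hypothetical.

package sketch

import (
	"math"

	"gvisor.dev/gvisor/pkg/errors/linuxerr"
)

// checkPortRange mirrors the validation above: both ends of an
// ip_local_port_range write must be representable as uint16.
func checkPortRange(lo, hi int) error {
	if lo < 0 || hi < 0 || lo > math.MaxUint16 || hi > math.MaxUint16 {
		return linuxerr.EINVAL
	}
	return nil
}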
diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go
index e534fbca8..14f806c3c 100644
--- a/pkg/sentry/fsimpl/proc/tasks_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_test.go
@@ -23,13 +23,13 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -227,7 +227,7 @@ func TestTasks(t *testing.T) {
defer fd.DecRef(s.Ctx)
buf := make([]byte, 1)
bufIOSeq := usermem.BytesIOSequence(buf)
- if _, err := fd.Read(s.Ctx, bufIOSeq, vfs.ReadOptions{}); err != syserror.EISDIR {
+ if _, err := fd.Read(s.Ctx, bufIOSeq, vfs.ReadOptions{}); !linuxerr.Equals(linuxerr.EISDIR, err) {
t.Errorf("wrong error reading directory: %v", err)
}
}
@@ -237,7 +237,7 @@ func TestTasks(t *testing.T) {
s.Creds,
s.PathOpAtRoot("/proc/9999"),
&vfs.OpenOptions{},
- ); err != syserror.ENOENT {
+ ); !linuxerr.Equals(linuxerr.ENOENT, err) {
t.Fatalf("wrong error from vfsfs.OpenAt(/proc/9999): %v", err)
}
}
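Many hunks in this change replace direct sentinel comparisons (err != syserror.EISDIR) with !linuxerr.Equals(linuxerr.EISDIR, err), as in the two test cases above. The standalone sketch below only illustrates the shape of such a comparison helper; the errnoError type and equals function are hypothetical stand-ins, not the real pkg/errors/linuxerr API, and they assume the sentinel errors are backed by errno values.

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

// errnoError is an illustrative stand-in for an errno-backed sentinel error.
type errnoError struct {
	errno unix.Errno
	msg   string
}

func (e *errnoError) Error() string { return e.msg }

// equals reports whether err carries the same errno as the sentinel want.
// It tolerates nil and raw unix.Errno values returned by host syscalls.
func equals(want *errnoError, err error) bool {
	if err == nil {
		return false
	}
	switch e := err.(type) {
	case *errnoError:
		return e.errno == want.errno
	case unix.Errno:
		return e == want.errno
	}
	return false
}

func main() {
	eisdir := &errnoError{errno: unix.EISDIR, msg: "is a directory"}
	fmt.Println(equals(eisdir, unix.EISDIR)) // true
	fmt.Println(equals(eisdir, nil))         // false
}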
diff --git a/pkg/sentry/fsimpl/proc/yama.go b/pkg/sentry/fsimpl/proc/yama.go
index e039ec45e..7240563d7 100644
--- a/pkg/sentry/fsimpl/proc/yama.go
+++ b/pkg/sentry/fsimpl/proc/yama.go
@@ -21,11 +21,11 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -56,7 +56,7 @@ func (s *yamaPtraceScope) Generate(ctx context.Context, buf *bytes.Buffer) error
func (s *yamaPtraceScope) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
if offset != 0 {
// Ignore partial writes.
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if src.NumBytes() == 0 {
return 0, nil
@@ -73,7 +73,7 @@ func (s *yamaPtraceScope) Write(ctx context.Context, src usermem.IOSequence, off
// We do not support YAMA levels > YAMA_SCOPE_RELATIONAL.
if v < linux.YAMA_SCOPE_DISABLED || v > linux.YAMA_SCOPE_RELATIONAL {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
atomic.StoreInt32(s.level, v)
diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD
index 09043b572..1af0a5cbc 100644
--- a/pkg/sentry/fsimpl/sys/BUILD
+++ b/pkg/sentry/fsimpl/sys/BUILD
@@ -26,6 +26,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/coverage",
+ "//pkg/errors/linuxerr",
"//pkg/log",
"//pkg/refs",
"//pkg/refsvfs2",
diff --git a/pkg/sentry/fsimpl/sys/kcov.go b/pkg/sentry/fsimpl/sys/kcov.go
index b13f141a8..d06aea162 100644
--- a/pkg/sentry/fsimpl/sys/kcov.go
+++ b/pkg/sentry/fsimpl/sys/kcov.go
@@ -17,6 +17,7 @@ package sys
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -85,7 +86,7 @@ func (fd *kcovFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallAr
case linux.KCOV_DISABLE:
if arg != 0 {
// This arg is unused; it should be 0.
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
return 0, fd.kcov.DisableTrace(ctx)
default:
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index 14eb10dcd..546f54a5a 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/coverage"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -74,7 +75,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
maxCachedDentries, err = strconv.ParseUint(str, 10, 64)
if err != nil {
ctx.Warningf("sys.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
}
diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD
index c766164c7..b3f9d1010 100644
--- a/pkg/sentry/fsimpl/testutil/BUILD
+++ b/pkg/sentry/fsimpl/testutil/BUILD
@@ -17,7 +17,6 @@ go_library(
"//pkg/fspath",
"//pkg/hostarch",
"//pkg/memutil",
- "//pkg/metric",
"//pkg/sentry/fsbridge",
"//pkg/sentry/fsimpl/tmpfs",
"//pkg/sentry/kernel",
diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go
index 33e52ce64..473b41cff 100644
--- a/pkg/sentry/fsimpl/testutil/kernel.go
+++ b/pkg/sentry/fsimpl/testutil/kernel.go
@@ -25,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/memutil"
- "gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -63,8 +62,6 @@ func Boot() (*kernel.Kernel, error) {
return nil, fmt.Errorf("creating platform: %v", err)
}
- metric.CreateSentryMetrics()
-
kernel.VFS2Enabled = true
k := &kernel.Kernel{
Platform: plat,
@@ -83,10 +80,7 @@ func Boot() (*kernel.Kernel, error) {
}
// Create timekeeper.
- tk, err := kernel.NewTimekeeper(k, vdso.ParamPage.FileRange())
- if err != nil {
- return nil, fmt.Errorf("creating timekeeper: %v", err)
- }
+ tk := kernel.NewTimekeeper(k, vdso.ParamPage.FileRange())
tk.SetClocks(time.NewCalibratedClocks())
creds := auth.NewRootCredentials(auth.NewRootUserNamespace())
@@ -181,7 +175,7 @@ func createMemoryFile() (*pgalloc.MemoryFile, error) {
memfile := os.NewFile(uintptr(memfd), memfileName)
mf, err := pgalloc.NewMemoryFile(memfile, pgalloc.MemoryFileOpts{})
if err != nil {
- memfile.Close()
+ _ = memfile.Close()
return nil, fmt.Errorf("error creating pgalloc.MemoryFile: %v", err)
}
return mf, nil
diff --git a/pkg/sentry/fsimpl/timerfd/BUILD b/pkg/sentry/fsimpl/timerfd/BUILD
index 7ce7dc429..e6980a314 100644
--- a/pkg/sentry/fsimpl/timerfd/BUILD
+++ b/pkg/sentry/fsimpl/timerfd/BUILD
@@ -8,6 +8,7 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/sentry/kernel/time",
"//pkg/sentry/vfs",
diff --git a/pkg/sentry/fsimpl/timerfd/timerfd.go b/pkg/sentry/fsimpl/timerfd/timerfd.go
index cbb8b67c5..655a1c76a 100644
--- a/pkg/sentry/fsimpl/timerfd/timerfd.go
+++ b/pkg/sentry/fsimpl/timerfd/timerfd.go
@@ -19,6 +19,7 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -69,7 +70,7 @@ func New(ctx context.Context, vfsObj *vfs.VirtualFilesystem, clock ktime.Clock,
func (tfd *TimerFileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
const sizeofUint64 = 8
if dst.NumBytes() < sizeofUint64 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if val := atomic.SwapUint64(&tfd.val, 0); val != 0 {
var buf [sizeofUint64]byte
diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD
index e21fddd7f..ae612aae0 100644
--- a/pkg/sentry/fsimpl/tmpfs/BUILD
+++ b/pkg/sentry/fsimpl/tmpfs/BUILD
@@ -58,6 +58,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/amutex",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fspath",
"//pkg/hostarch",
"//pkg/log",
@@ -118,6 +119,7 @@ go_test(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fspath",
"//pkg/sentry/contexttest",
"//pkg/sentry/fs/lock",
diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go
index e8d256495..c25494c0b 100644
--- a/pkg/sentry/fsimpl/tmpfs/directory.go
+++ b/pkg/sentry/fsimpl/tmpfs/directory.go
@@ -19,10 +19,10 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
// +stateify savable
@@ -196,10 +196,10 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in
case linux.SEEK_CUR:
offset += fd.off
default:
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// If the offset isn't changing (e.g. due to lseek(0, SEEK_CUR)), don't
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index 766289e60..590f7118a 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -300,7 +301,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
case linux.S_IFSOCK:
childInode = fs.newSocketFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode, opts.Endpoint, parentDir)
default:
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
child := fs.newDentry(childInode)
parentDir.insertChildLocked(child, name)
@@ -488,7 +489,7 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st
}
symlink, ok := d.inode.impl.(*symlink)
if !ok {
- return "", syserror.EINVAL
+ return "", linuxerr.EINVAL
}
symlink.inode.touchAtime(rp.Mount())
return symlink.target, nil
@@ -496,20 +497,24 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st
// RenameAt implements vfs.FilesystemImpl.RenameAt.
func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
- if opts.Flags != 0 {
- // TODO(b/145974740): Support renameat2 flags.
- return syserror.EINVAL
- }
-
- // Resolve newParent first to verify that it's on this Mount.
+ // Resolve newParentDir first to verify that it's on this Mount.
fs.mu.Lock()
defer fs.mu.Unlock()
newParentDir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry))
if err != nil {
return err
}
+
+ if opts.Flags&^linux.RENAME_NOREPLACE != 0 {
+ // TODO(b/145974740): Support other renameat2 flags.
+ return linuxerr.EINVAL
+ }
+
newName := rp.Component()
if newName == "." || newName == ".." {
+ if opts.Flags&linux.RENAME_NOREPLACE != 0 {
+ return syserror.EEXIST
+ }
return syserror.EBUSY
}
mnt := rp.Mount()
@@ -537,7 +542,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// mounted filesystem.
if renamed.inode.isDir() {
if renamed == &newParentDir.dentry || genericIsAncestorDentry(renamed, &newParentDir.dentry) {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if oldParentDir != newParentDir {
// Writability is needed to change renamed's "..".
@@ -556,6 +561,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
replaced, ok := newParentDir.childMap[newName]
if ok {
+ if opts.Flags&linux.RENAME_NOREPLACE != 0 {
+ return syserror.EEXIST
+ }
replacedDir, ok := replaced.inode.impl.(*directory)
if ok {
if !renamed.inode.isDir() {
@@ -639,7 +647,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
}
name := rp.Component()
if name == "." {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if name == ".." {
return syserror.ENOTEMPTY
@@ -815,7 +823,7 @@ func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, si
if err != nil {
return nil, err
}
- return d.inode.listXattr(size)
+ return d.inode.listXattr(rp.Credentials(), size)
}
// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt.
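The RenameAt hunks above start accepting RENAME_NOREPLACE while still rejecting every other renameat2 flag, using Go's bit-clear operator (&^). Below is a minimal sketch of that flag check; the flag values are assumed for illustration only and the helper name is hypothetical, not part of the tmpfs code.

package main

import "fmt"

const (
	renameNoReplace uint32 = 1 << 0 // assumed value, for illustration only
	renameExchange  uint32 = 1 << 1 // assumed value, for illustration only
)

// checkRenameFlags accepts RENAME_NOREPLACE alone; any remaining bit after
// clearing it means an unsupported flag, which maps to EINVAL above.
func checkRenameFlags(flags uint32) error {
	if unsupported := flags &^ renameNoReplace; unsupported != 0 {
		return fmt.Errorf("EINVAL: unsupported rename flags %#x", unsupported)
	}
	return nil
}

func main() {
	fmt.Println(checkRenameFlags(renameNoReplace))                  // <nil>
	fmt.Println(checkRenameFlags(renameNoReplace | renameExchange)) // EINVAL: ...
}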
diff --git a/pkg/sentry/fsimpl/tmpfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
index 2f856ce36..418c7994e 100644
--- a/pkg/sentry/fsimpl/tmpfs/pipe_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -114,7 +115,7 @@ func TestNonblockingWriteError(t *testing.T) {
}
openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY | linux.O_NONBLOCK}
_, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
- if err != syserror.ENXIO {
+ if !linuxerr.Equals(linuxerr.ENXIO, err) {
t.Fatalf("expected ENXIO, but got error: %v", err)
}
}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index c45bddff6..0bc1911d9 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -366,7 +367,7 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
fsmetric.TmpfsReads.Increment()
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Check that flags are supported. RWF_DSYNC/RWF_SYNC can be ignored since
@@ -407,7 +408,7 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
// final offset should be ignored by PWrite.
func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) {
if offset < 0 {
- return 0, offset, syserror.EINVAL
+ return 0, offset, linuxerr.EINVAL
}
// Check that flags are supported. RWF_DSYNC/RWF_SYNC can be ignored since
@@ -432,7 +433,7 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
}
if end := offset + srclen; end < offset {
// Overflow.
- return 0, offset, syserror.EINVAL
+ return 0, offset, linuxerr.EINVAL
}
srclen, err = vfs.CheckLimit(ctx, offset, srclen)
@@ -476,10 +477,10 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (
case linux.SEEK_END:
offset += int64(atomic.LoadUint64(&fd.inode().impl.(*regularFile).size))
default:
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
fd.off = offset
return offset, nil
@@ -684,7 +685,7 @@ exitLoop:
func GetSeals(fd *vfs.FileDescription) (uint32, error) {
f, ok := fd.Impl().(*regularFileFD)
if !ok {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
rf := f.inode().impl.(*regularFile)
rf.dataMu.RLock()
@@ -696,7 +697,7 @@ func GetSeals(fd *vfs.FileDescription) (uint32, error) {
func AddSeals(fd *vfs.FileDescription, val uint32) error {
f, ok := fd.Impl().(*regularFileFD)
if !ok {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
rf := f.inode().impl.(*regularFile)
rf.mapsMu.Lock()
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index 9ae25ce9e..bc40aad0d 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -36,6 +36,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -138,7 +139,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
mode, err := strconv.ParseUint(modeStr, 8, 32)
if err != nil {
ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: invalid mode: %q", modeStr)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
rootMode = linux.FileMode(mode & 07777)
}
@@ -149,12 +150,12 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
uid, err := strconv.ParseUint(uidStr, 10, 32)
if err != nil {
ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: invalid uid: %q", uidStr)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
kuid := creds.UserNamespace.MapToKUID(auth.UID(uid))
if !kuid.Ok() {
ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: unmapped uid: %d", uid)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
rootKUID = kuid
}
@@ -165,18 +166,18 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
gid, err := strconv.ParseUint(gidStr, 10, 32)
if err != nil {
ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: invalid gid: %q", gidStr)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
kgid := creds.UserNamespace.MapToKGID(auth.GID(gid))
if !kgid.Ok() {
ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: unmapped gid: %d", gid)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
rootKGID = kgid
}
if len(mopts) != 0 {
ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: unknown options: %v", mopts)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
devMinor, err := vfsObj.GetAnonBlockDevMinor()
@@ -557,7 +558,7 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.
case *directory:
return syserror.EISDIR
default:
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
}
if mask&linux.STATX_UID != 0 {
@@ -717,44 +718,63 @@ func (i *inode) touchCMtimeLocked() {
atomic.StoreInt64(&i.ctime, now)
}
-func (i *inode) listXattr(size uint64) ([]string, error) {
- return i.xattrs.ListXattr(size)
+func checkXattrName(name string) error {
+ // Linux's tmpfs supports "security" and "trusted" xattr namespaces, and
+ // (depending on build configuration) POSIX ACL xattr namespaces
+ // ("system.posix_acl_access" and "system.posix_acl_default"). We don't
+ // support POSIX ACLs or the "security" namespace (b/148380782).
+ if strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX) {
+ return nil
+ }
+ // We support the "user" namespace because we have tests that depend on
+ // this feature.
+ if strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ return nil
+ }
+ return syserror.EOPNOTSUPP
+}
+
+func (i *inode) listXattr(creds *auth.Credentials, size uint64) ([]string, error) {
+ return i.xattrs.ListXattr(creds, size)
}
func (i *inode) getXattr(creds *auth.Credentials, opts *vfs.GetXattrOptions) (string, error) {
- if err := i.checkXattrPermissions(creds, opts.Name, vfs.MayRead); err != nil {
+ if err := checkXattrName(opts.Name); err != nil {
return "", err
}
- return i.xattrs.GetXattr(opts)
+ mode := linux.FileMode(atomic.LoadUint32(&i.mode))
+ kuid := auth.KUID(atomic.LoadUint32(&i.uid))
+ kgid := auth.KGID(atomic.LoadUint32(&i.gid))
+ if err := vfs.GenericCheckPermissions(creds, vfs.MayRead, mode, kuid, kgid); err != nil {
+ return "", err
+ }
+ return i.xattrs.GetXattr(creds, mode, kuid, opts)
}
func (i *inode) setXattr(creds *auth.Credentials, opts *vfs.SetXattrOptions) error {
- if err := i.checkXattrPermissions(creds, opts.Name, vfs.MayWrite); err != nil {
+ if err := checkXattrName(opts.Name); err != nil {
return err
}
- return i.xattrs.SetXattr(opts)
-}
-
-func (i *inode) removeXattr(creds *auth.Credentials, name string) error {
- if err := i.checkXattrPermissions(creds, name, vfs.MayWrite); err != nil {
+ mode := linux.FileMode(atomic.LoadUint32(&i.mode))
+ kuid := auth.KUID(atomic.LoadUint32(&i.uid))
+ kgid := auth.KGID(atomic.LoadUint32(&i.gid))
+ if err := vfs.GenericCheckPermissions(creds, vfs.MayWrite, mode, kuid, kgid); err != nil {
return err
}
- return i.xattrs.RemoveXattr(name)
+ return i.xattrs.SetXattr(creds, mode, kuid, opts)
}
-func (i *inode) checkXattrPermissions(creds *auth.Credentials, name string, ats vfs.AccessTypes) error {
- // We currently only support extended attributes in the user.* and
- // trusted.* namespaces. See b/148380782.
- if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) && !strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX) {
- return syserror.EOPNOTSUPP
+func (i *inode) removeXattr(creds *auth.Credentials, name string) error {
+ if err := checkXattrName(name); err != nil {
+ return err
}
mode := linux.FileMode(atomic.LoadUint32(&i.mode))
kuid := auth.KUID(atomic.LoadUint32(&i.uid))
kgid := auth.KGID(atomic.LoadUint32(&i.gid))
- if err := vfs.GenericCheckPermissions(creds, ats, mode, kuid, kgid); err != nil {
+ if err := vfs.GenericCheckPermissions(creds, vfs.MayWrite, mode, kuid, kgid); err != nil {
return err
}
- return vfs.CheckXattrPermissions(creds, ats, mode, kuid, name)
+ return i.xattrs.RemoveXattr(creds, mode, kuid, name)
}
// fileDescription is embedded by tmpfs implementations of
@@ -807,7 +827,7 @@ func (fd *fileDescription) StatFS(ctx context.Context) (linux.Statfs, error) {
// ListXattr implements vfs.FileDescriptionImpl.ListXattr.
func (fd *fileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) {
- return fd.inode().listXattr(size)
+ return fd.inode().listXattr(auth.CredentialsFromContext(ctx), size)
}
// GetXattr implements vfs.FileDescriptionImpl.GetXattr.
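The tmpfs hunks above centralize xattr namespace filtering in checkXattrName, accepting only the "trusted." and "user." prefixes and returning EOPNOTSUPP for everything else. The self-contained sketch below shows that filtering on a few example names; errOpNotSupp is a stand-in for the sentry's error value, and the literal prefixes stand in for the linux.XATTR_*_PREFIX constants used in the real code.

package main

import (
	"errors"
	"fmt"
	"strings"
)

var errOpNotSupp = errors.New("EOPNOTSUPP") // stand-in for the real sentinel

// checkXattrName mirrors the prefix filtering in the hunk above.
func checkXattrName(name string) error {
	if strings.HasPrefix(name, "trusted.") || strings.HasPrefix(name, "user.") {
		return nil
	}
	return errOpNotSupp
}

func main() {
	for _, name := range []string{"user.foo", "trusted.bar", "security.selinux", "system.posix_acl_access"} {
		fmt.Printf("%-25s -> %v\n", name, checkXattrName(name))
	}
}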
diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD
index d473a922d..1d855234c 100644
--- a/pkg/sentry/fsimpl/verity/BUILD
+++ b/pkg/sentry/fsimpl/verity/BUILD
@@ -13,6 +13,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fspath",
"//pkg/hostarch",
"//pkg/marshal/primitive",
@@ -41,6 +42,7 @@ go_test(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fspath",
"//pkg/sentry/arch",
"//pkg/sentry/fsimpl/testutil",
@@ -48,7 +50,6 @@ go_test(
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
- "//pkg/syserror",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 3582d14c9..b5735a86d 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/merkletree"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -195,7 +196,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
// The Merkle tree file for the child should have been created and
// contains the expected xattrs. If the file or the xattr does not
// exist, it indicates unexpected modifications to the file system.
- if err == syserror.ENOENT || err == syserror.ENODATA {
+ if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.ENODATA, err) {
return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err))
}
if err != nil {
@@ -218,7 +219,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
// The parent Merkle tree file should have been created. If it's
// missing, it indicates an unexpected modification to the file system.
- if err == syserror.ENOENT {
+ if linuxerr.Equals(linuxerr.ENOENT, err) {
return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err))
}
if err != nil {
@@ -238,7 +239,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
// The Merkle tree file for the child should have been created and
// contains the expected xattrs. If the file or the xattr does not
// exist, it indicates unexpected modifications to the file system.
- if err == syserror.ENOENT || err == syserror.ENODATA {
+ if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.ENODATA, err) {
return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err))
}
if err != nil {
@@ -261,7 +262,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
Root: parent.lowerVD,
Start: parent.lowerVD,
}, &vfs.StatOptions{})
- if err == syserror.ENOENT {
+ if linuxerr.Equals(linuxerr.ENOENT, err) {
return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err))
}
if err != nil {
@@ -282,7 +283,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
Mode: uint32(parentStat.Mode),
UID: parentStat.UID,
GID: parentStat.GID,
- Children: parent.childrenNames,
+ Children: parent.childrenList,
HashAlgorithms: fs.alg.toLinuxHashAlg(),
ReadOffset: int64(offset),
ReadSize: int64(merkletree.DigestSize(fs.alg.toLinuxHashAlg())),
@@ -327,7 +328,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry
}, &vfs.OpenOptions{
Flags: linux.O_RDONLY,
})
- if err == syserror.ENOENT {
+ if linuxerr.Equals(linuxerr.ENOENT, err) {
return fs.alertIntegrityViolation(fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err))
}
if err != nil {
@@ -341,7 +342,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry
Size: sizeOfStringInt32,
})
- if err == syserror.ENODATA {
+ if linuxerr.Equals(linuxerr.ENODATA, err) {
return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err))
}
if err != nil {
@@ -359,7 +360,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry
Size: sizeOfStringInt32,
})
- if err == syserror.ENODATA {
+ if linuxerr.Equals(linuxerr.ENODATA, err) {
return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenOffsetXattr, childPath, err))
}
if err != nil {
@@ -375,7 +376,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry
Size: sizeOfStringInt32,
})
- if err == syserror.ENODATA {
+ if linuxerr.Equals(linuxerr.ENODATA, err) {
return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenSizeXattr, childPath, err))
}
if err != nil {
@@ -403,6 +404,9 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry
var buf bytes.Buffer
d.hashMu.RLock()
+
+ d.generateChildrenList()
+
params := &merkletree.VerifyParams{
Out: &buf,
Tree: &fdReader,
@@ -411,7 +415,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry
Mode: uint32(stat.Mode),
UID: stat.UID,
GID: stat.GID,
- Children: d.childrenNames,
+ Children: d.childrenList,
HashAlgorithms: fs.alg.toLinuxHashAlg(),
ReadOffset: 0,
// Set read size to 0 so only the metadata is verified.
@@ -465,7 +469,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
}
childVD, err := parent.getLowerAt(ctx, vfsObj, name)
- if err == syserror.ENOENT {
+ if linuxerr.Equals(linuxerr.ENOENT, err) {
// The file was previously accessed. If the
// file does not exist now, it indicates an
// unexpected modification to the file system.
@@ -480,7 +484,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
// The Merkle tree file was previous accessed. If it
// does not exist now, it indicates an unexpected
// modification to the file system.
- if err == syserror.ENOENT {
+ if linuxerr.Equals(linuxerr.ENOENT, err) {
return nil, fs.alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", path))
}
if err != nil {
@@ -551,7 +555,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
}
childVD, err := parent.getLowerAt(ctx, vfsObj, name)
- if parent.verityEnabled() && err == syserror.ENOENT {
+ if parent.verityEnabled() && linuxerr.Equals(linuxerr.ENOENT, err) {
return nil, fs.alertIntegrityViolation(fmt.Sprintf("file %s expected but not found", parentPath+"/"+name))
}
if err != nil {
@@ -564,7 +568,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
childMerkleVD, err := parent.getLowerAt(ctx, vfsObj, merklePrefix+name)
if err != nil {
- if err == syserror.ENOENT {
+ if linuxerr.Equals(linuxerr.ENOENT, err) {
if parent.verityEnabled() {
return nil, fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath+"/"+name))
}
@@ -854,7 +858,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
// The file should exist, as we succeeded in finding its dentry. If it's
// missing, it indicates an unexpected modification to the file system.
if err != nil {
- if err == syserror.ENOENT {
+ if linuxerr.Equals(linuxerr.ENOENT, err) {
return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("File %s expected but not found", path))
}
return nil, err
@@ -877,7 +881,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
// dentry. If it's missing, it indicates an unexpected modification to
// the file system.
if err != nil {
- if err == syserror.ENOENT {
+ if linuxerr.Equals(linuxerr.ENOENT, err) {
return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path))
}
return nil, err
@@ -902,7 +906,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
Flags: linux.O_WRONLY | linux.O_APPEND,
})
if err != nil {
- if err == syserror.ENOENT {
+ if linuxerr.Equals(linuxerr.ENOENT, err) {
return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path))
}
return nil, err
@@ -919,7 +923,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
Flags: linux.O_WRONLY | linux.O_APPEND,
})
if err != nil {
- if err == syserror.ENOENT {
+ if linuxerr.Equals(linuxerr.ENOENT, err) {
parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD)
return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath))
}
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index fa7696ad6..2227b542a 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -39,12 +39,14 @@ import (
"encoding/json"
"fmt"
"math"
+ "sort"
"strconv"
"strings"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
@@ -251,7 +253,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
hash, err := hex.DecodeString(encodedRootHash)
if err != nil {
ctx.Warningf("verity.FilesystemType.GetFilesystem: Failed to decode root hash: %v", err)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
rootHash = hash
}
@@ -269,19 +271,19 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// Check for unparsed options.
if len(mopts) != 0 {
ctx.Warningf("verity.FilesystemType.GetFilesystem: unknown options: %v", mopts)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
// Handle internal options.
iopts, ok := opts.InternalData.(InternalFilesystemOptions)
if len(lowerPathname) == 0 && !ok {
ctx.Warningf("verity.FilesystemType.GetFilesystem: missing verity configs")
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
if len(lowerPathname) != 0 {
if ok {
ctx.Warningf("verity.FilesystemType.GetFilesystem: unexpected verity configs with specified lower path")
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
iopts = InternalFilesystemOptions{
AllowRuntimeEnable: len(rootHash) == 0,
@@ -300,7 +302,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
lowerPath := fspath.Parse(lowerPathname)
if !lowerPath.Absolute {
ctx.Infof("verity.FilesystemType.GetFilesystem: lower_path %q must be absolute", lowerPathname)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
var err error
mountedLowerVD, err = vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{
@@ -358,7 +360,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// If runtime enable is allowed, the root merkle tree may be absent. We
// should create the tree file.
- if err == syserror.ENOENT && fs.allowRuntimeEnable {
+ if linuxerr.Equals(linuxerr.ENOENT, err) && fs.allowRuntimeEnable {
lowerMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
Root: lowerVD,
Start: lowerVD,
@@ -439,7 +441,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
if !d.isDir() {
ctx.Warningf("verity root must be a directory")
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
if !fs.allowRuntimeEnable {
@@ -451,7 +453,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
Name: childrenOffsetXattr,
Size: sizeOfStringInt32,
})
- if err == syserror.ENOENT || err == syserror.ENODATA {
+ if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.ENODATA, err) {
return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenOffsetXattr, err))
}
if err != nil {
@@ -470,7 +472,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
Name: childrenSizeXattr,
Size: sizeOfStringInt32,
})
- if err == syserror.ENOENT || err == syserror.ENODATA {
+ if linuxerr.Equals(linuxerr.ENOENT, err) || linuxerr.Equals(linuxerr.ENODATA, err) {
return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenSizeXattr, err))
}
if err != nil {
@@ -487,7 +489,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}, &vfs.OpenOptions{
Flags: linux.O_RDONLY,
})
- if err == syserror.ENOENT {
+ if linuxerr.Equals(linuxerr.ENOENT, err) {
return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to open root Merkle file: %v", err))
}
if err != nil {
@@ -508,6 +510,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
if err := fs.verifyStatAndChildrenLocked(ctx, d, stat); err != nil {
return nil, nil, err
}
+ d.generateChildrenList()
}
d.vfsd.Init(d)
@@ -564,6 +567,11 @@ type dentry struct {
// populated by enableVerity. childrenNames is also protected by dirMu.
childrenNames map[string]struct{}
+ // childrenList is a complete sorted list of childrenNames. This list
+ // is generated when verity is enabled, or the first time the file is
+ // verified in non runtime enable mode.
+ childrenList []string
+
// lowerVD is the VirtualDentry in the underlying file system. It is
// never modified after initialized.
lowerVD vfs.VirtualDentry
@@ -749,6 +757,17 @@ func (d *dentry) verityEnabled() bool {
return !d.fs.allowRuntimeEnable || len(d.hash) != 0
}
+// generateChildrenList generates a sorted childrenList from childrenNames, and
+// caches it in d for hashing.
+func (d *dentry) generateChildrenList() {
+ if len(d.childrenList) == 0 && len(d.childrenNames) != 0 {
+ for child := range d.childrenNames {
+ d.childrenList = append(d.childrenList, child)
+ }
+ sort.Strings(d.childrenList)
+ }
+}
+
// getLowerAt returns the dentry in the underlying file system, which is
// represented by filename relative to d.
func (d *dentry) getLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, filename string) (vfs.VirtualDentry, error) {
@@ -868,6 +887,10 @@ func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCa
fd.mu.Lock()
defer fd.mu.Unlock()
+ if _, err := fd.lowerFD.Seek(ctx, fd.off, linux.SEEK_SET); err != nil {
+ return err
+ }
+
var ds []vfs.Dirent
err := fd.lowerFD.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error {
// Do not include the Merkle tree files.
@@ -890,8 +913,8 @@ func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCa
return err
}
- // The result should contain all children plus "." and "..".
- if fd.d.verityEnabled() && len(ds) != len(fd.d.childrenNames)+2 {
+ // The result should contain the remaining children plus "." and "..", counting from fd.off.
+ if fd.d.verityEnabled() && len(ds) != len(fd.d.childrenNames)+2-int(fd.off) {
return fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Unexpected children number %d", len(ds)))
}
@@ -917,14 +940,14 @@ func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32)
case linux.SEEK_END:
n = int64(fd.d.size)
default:
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if offset > math.MaxInt64-n {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
offset += n
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
fd.off = offset
return offset, nil
@@ -958,10 +981,12 @@ func (fd *fileDescription) generateMerkleLocked(ctx context.Context) ([]byte, ui
return nil, 0, err
}
+ fd.d.generateChildrenList()
+
params := &merkletree.GenerateParams{
TreeReader: &merkleReader,
TreeWriter: &merkleWriter,
- Children: fd.d.childrenNames,
+ Children: fd.d.childrenList,
HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(),
Name: fd.d.name,
Mode: uint32(stat.Mode),
@@ -1003,7 +1028,7 @@ func (fd *fileDescription) generateMerkleLocked(ctx context.Context) ([]byte, ui
default:
// TODO(b/167728857): Investigate whether and how we should
// enable other types of file.
- return nil, 0, syserror.EINVAL
+ return nil, 0, linuxerr.EINVAL
}
hash, err := merkletree.Generate(params)
return hash, uint64(params.Size), err
@@ -1121,7 +1146,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) {
func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest hostarch.Addr) (uintptr, error) {
t := kernel.TaskFromContext(ctx)
if t == nil {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
var metadata linux.DigestMetadata
@@ -1174,7 +1199,7 @@ func (fd *fileDescription) verityFlags(ctx context.Context, flags hostarch.Addr)
t := kernel.TaskFromContext(ctx)
if t == nil {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
_, err := primitive.CopyInt32Out(t, flags, f)
return 0, err
@@ -1223,7 +1248,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
// The Merkle tree file for the child should have been created and
// contains the expected xattrs. If the xattr does not exist, it
// indicates unexpected modifications to the file system.
- if err == syserror.ENODATA {
+ if linuxerr.Equals(linuxerr.ENODATA, err) {
return 0, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err))
}
if err != nil {
@@ -1257,7 +1282,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
Mode: fd.d.mode,
UID: fd.d.uid,
GID: fd.d.gid,
- Children: fd.d.childrenNames,
+ Children: fd.d.childrenList,
HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(),
ReadOffset: offset,
ReadSize: dst.NumBytes(),
@@ -1345,7 +1370,7 @@ func (fd *fileDescription) Translate(ctx context.Context, required, optional mem
// The Merkle tree file for the child should have been created and
// contains the expected xattrs. If the xattr does not exist, it
// indicates unexpected modifications to the file system.
- if err == syserror.ENODATA {
+ if linuxerr.Equals(linuxerr.ENODATA, err) {
return nil, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err))
}
if err != nil {
@@ -1429,7 +1454,7 @@ func (r *mmapReadSeeker) ReadAt(p []byte, off int64) (int, error) {
// mapped region.
readOffset := off - int64(r.Offset)
if readOffset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
bs.DropFirst64(uint64(readOffset))
view := bs.TakeFirst64(uint64(len(p)))
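The verity hunks above introduce childrenList, a sorted copy of the childrenNames map keys: Go's map iteration order is randomized, so Merkle-tree generation and verification need a deterministic child ordering. A minimal standalone sketch of that pattern (the function name is hypothetical):

package main

import (
	"fmt"
	"sort"
)

// sortedChildren copies map keys into a slice and sorts them so repeated
// hashing of the same directory sees the same input order.
func sortedChildren(children map[string]struct{}) []string {
	list := make([]string, 0, len(children))
	for name := range children {
		list = append(list, name)
	}
	sort.Strings(list)
	return list
}

func main() {
	names := map[string]struct{}{"c": {}, "a": {}, "b": {}}
	fmt.Println(sortedChildren(names)) // [a b c]
}

The IterDirents change in the same file builds on this bookkeeping: after seeking the lower FD to fd.off, a verity-enabled directory is expected to yield len(childrenNames)+2-fd.off entries (children plus "." and ".."), and any other count is reported as an integrity violation.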
diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go
index 5c78a0019..65465b814 100644
--- a/pkg/sentry/fsimpl/verity/verity_test.go
+++ b/pkg/sentry/fsimpl/verity/verity_test.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
@@ -31,7 +32,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -476,7 +476,7 @@ func TestOpenNonexistentFile(t *testing.T) {
// Ensure open an unexpected file in the parent directory fails with
// ENOENT rather than verification failure.
- if _, err = openVerityAt(ctx, vfsObj, root, filename+"abc", linux.O_RDONLY, linux.ModeRegular); err != syserror.ENOENT {
+ if _, err = openVerityAt(ctx, vfsObj, root, filename+"abc", linux.O_RDONLY, linux.ModeRegular); !linuxerr.Equals(linuxerr.ENOENT, err) {
t.Errorf("OpenAt unexpected error: %v", err)
}
}
@@ -767,7 +767,7 @@ func TestOpenDeletedFileFails(t *testing.T) {
}
// Ensure reopening the verity enabled file fails.
- if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != syserror.EIO {
+ if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); !linuxerr.Equals(linuxerr.EIO, err) {
t.Errorf("got OpenAt error: %v, expected EIO", err)
}
})
@@ -829,7 +829,7 @@ func TestOpenRenamedFileFails(t *testing.T) {
}
// Ensure reopening the verity enabled file fails.
- if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != syserror.EIO {
+ if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); !linuxerr.Equals(linuxerr.EIO, err) {
t.Errorf("got OpenAt error: %v, expected EIO", err)
}
})
@@ -1063,14 +1063,14 @@ func TestDeletedSymlinkFileReadFails(t *testing.T) {
Root: root,
Start: root,
Path: fspath.Parse(symlink),
- }); err != syserror.EIO {
+ }); !linuxerr.Equals(linuxerr.EIO, err) {
t.Fatalf("ReadlinkAt succeeded with modified symlink: %v", err)
}
if tc.testWalk {
fileInSymlinkDirectory := symlink + "/verity-test-file"
// Ensure opening the verity enabled file in the symlink directory fails.
- if _, err := openVerityAt(ctx, vfsObj, root, fileInSymlinkDirectory, linux.O_RDONLY, linux.ModeRegular); err != syserror.EIO {
+ if _, err := openVerityAt(ctx, vfsObj, root, fileInSymlinkDirectory, linux.O_RDONLY, linux.ModeRegular); !linuxerr.Equals(linuxerr.EIO, err) {
t.Errorf("Open succeeded with modified symlink: %v", err)
}
}
@@ -1195,14 +1195,14 @@ func TestModifiedSymlinkFileReadFails(t *testing.T) {
Root: root,
Start: root,
Path: fspath.Parse(symlink),
- }); err != syserror.EIO {
+ }); !linuxerr.Equals(linuxerr.EIO, err) {
t.Fatalf("ReadlinkAt succeeded with modified symlink: %v", err)
}
if tc.testWalk {
fileInSymlinkDirectory := symlink + "/verity-test-file"
// Ensure opening the verity enabled file in the symlink directory fails.
- if _, err := openVerityAt(ctx, vfsObj, root, fileInSymlinkDirectory, linux.O_RDONLY, linux.ModeRegular); err != syserror.EIO {
+ if _, err := openVerityAt(ctx, vfsObj, root, fileInSymlinkDirectory, linux.O_RDONLY, linux.ModeRegular); !linuxerr.Equals(linuxerr.EIO, err) {
t.Errorf("Open succeeded with modified symlink: %v", err)
}
}
diff --git a/pkg/sentry/fsmetric/fsmetric.go b/pkg/sentry/fsmetric/fsmetric.go
index 7e535b527..17d0d5025 100644
--- a/pkg/sentry/fsmetric/fsmetric.go
+++ b/pkg/sentry/fsmetric/fsmetric.go
@@ -42,7 +42,6 @@ var (
// Metrics that only apply to fs/gofer and fsimpl/gofer.
var (
- GoferOpensWX = metric.MustCreateNewUint64Metric("/gofer/opened_write_execute_file", true /* sync */, "Number of times a executable file was opened writably from a gofer.")
GoferOpens9P = metric.MustCreateNewUint64Metric("/gofer/opens_9p", false /* sync */, "Number of times a file was opened from a gofer and did not have a host file descriptor.")
GoferOpensHost = metric.MustCreateNewUint64Metric("/gofer/opens_host", false /* sync */, "Number of times a file was opened from a gofer and did have a host file descriptor.")
GoferReads9P = metric.MustCreateNewUint64Metric("/gofer/reads_9p", false /* sync */, "Number of 9P file reads from a gofer.")
diff --git a/pkg/sentry/hostfd/hostfd_linux.go b/pkg/sentry/hostfd/hostfd_linux.go
index 1cabc848f..e103e7296 100644
--- a/pkg/sentry/hostfd/hostfd_linux.go
+++ b/pkg/sentry/hostfd/hostfd_linux.go
@@ -14,5 +14,10 @@
package hostfd
-// maxIov is the maximum permitted size of a struct iovec array.
-const maxIov = 1024 // UIO_MAXIOV
+// MaxReadWriteIov is the maximum permitted size of a struct iovec array in a
+// readv, writev, preadv, or pwritev host syscall.
+const MaxReadWriteIov = 1024 // UIO_MAXIOV
+
+// MaxSendRecvMsgIov is the maximum permitted size of a struct iovec array in a
+// sendmsg or recvmsg host syscall.
+const MaxSendRecvMsgIov = 1024 // UIO_MAXIOV
diff --git a/pkg/sentry/hostfd/hostfd_unsafe.go b/pkg/sentry/hostfd/hostfd_unsafe.go
index 03c6d2a16..a43311eb4 100644
--- a/pkg/sentry/hostfd/hostfd_unsafe.go
+++ b/pkg/sentry/hostfd/hostfd_unsafe.go
@@ -23,6 +23,11 @@ import (
"gvisor.dev/gvisor/pkg/safemem"
)
+const (
+ sizeofIovec = unsafe.Sizeof(unix.Iovec{})
+ sizeofMsghdr = unsafe.Sizeof(unix.Msghdr{})
+)
+
// Preadv2 reads up to dsts.NumBytes() bytes from host file descriptor fd into
// dsts. offset and flags are interpreted as for preadv2(2).
//
@@ -44,9 +49,9 @@ func Preadv2(fd int32, dsts safemem.BlockSeq, offset int64, flags uint32) (uint6
}
} else {
iovs := safemem.IovecsFromBlockSeq(dsts)
- if len(iovs) > maxIov {
- log.Debugf("hostfd.Preadv2: truncating from %d iovecs to %d", len(iovs), maxIov)
- iovs = iovs[:maxIov]
+ if len(iovs) > MaxReadWriteIov {
+ log.Debugf("hostfd.Preadv2: truncating from %d iovecs to %d", len(iovs), MaxReadWriteIov)
+ iovs = iovs[:MaxReadWriteIov]
}
n, _, e = unix.Syscall6(unix.SYS_PREADV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags))
}
@@ -80,9 +85,9 @@ func Pwritev2(fd int32, srcs safemem.BlockSeq, offset int64, flags uint32) (uint
}
} else {
iovs := safemem.IovecsFromBlockSeq(srcs)
- if len(iovs) > maxIov {
- log.Debugf("hostfd.Preadv2: truncating from %d iovecs to %d", len(iovs), maxIov)
- iovs = iovs[:maxIov]
+ if len(iovs) > MaxReadWriteIov {
+ log.Debugf("hostfd.Preadv2: truncating from %d iovecs to %d", len(iovs), MaxReadWriteIov)
+ iovs = iovs[:MaxReadWriteIov]
}
n, _, e = unix.Syscall6(unix.SYS_PWRITEV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags))
}
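The hostfd hunks above rename maxIov and cap the iovec slice before issuing the host syscall, since readv/writev-family syscalls reject more than UIO_MAXIOV entries. A small sketch of that truncation, assuming only that callers tolerate a short read or write covering the first 1024 buffers; the iovec type here is illustrative, not the unix.Iovec used by the real code.

package main

import "fmt"

const maxReadWriteIov = 1024 // UIO_MAXIOV

// iovec is a simplified stand-in for a host struct iovec.
type iovec struct {
	base *byte
	len  uint64
}

// capIovecs keeps at most maxReadWriteIov entries for a single host call.
func capIovecs(iovs []iovec) []iovec {
	if len(iovs) > maxReadWriteIov {
		iovs = iovs[:maxReadWriteIov]
	}
	return iovs
}

func main() {
	iovs := make([]iovec, 2000)
	fmt.Println(len(capIovecs(iovs))) // 1024
}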
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index a1ec6daab..26614b029 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -32,7 +32,7 @@ go_template_instance(
out = "seqatomic_taskgoroutineschedinfo_unsafe.go",
package = "kernel",
suffix = "TaskGoroutineSchedInfo",
- template = "//pkg/sync:generic_seqatomic",
+ template = "//pkg/sync/seqatomic:generic_seqatomic",
types = {
"Value": "TaskGoroutineSchedInfo",
},
@@ -218,6 +218,7 @@ go_library(
":uncaught_signal_go_proto",
"//pkg/abi",
"//pkg/abi/linux",
+ "//pkg/abi/linux/errno",
"//pkg/amutex",
"//pkg/bits",
"//pkg/bpf",
@@ -225,6 +226,8 @@ go_library(
"//pkg/context",
"//pkg/coverage",
"//pkg/cpuid",
+ "//pkg/errors",
+ "//pkg/errors/linuxerr",
"//pkg/eventchannel",
"//pkg/fspath",
"//pkg/goid",
@@ -298,6 +301,7 @@ go_test(
deps = [
"//pkg/abi",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/sentry/arch",
"//pkg/sentry/contexttest",
@@ -309,6 +313,5 @@ go_test(
"//pkg/sentry/time",
"//pkg/sentry/usage",
"//pkg/sync",
- "//pkg/syserror",
],
)
diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD
index 869e49ebc..7a1a36454 100644
--- a/pkg/sentry/kernel/auth/BUILD
+++ b/pkg/sentry/kernel/auth/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "atomicptr_credentials_unsafe.go",
package = "auth",
suffix = "Credentials",
- template = "//pkg/sync:generic_atomicptr",
+ template = "//pkg/sync/atomicptr:generic_atomicptr",
types = {
"Value": "Credentials",
},
@@ -63,6 +63,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/bits",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/log",
"//pkg/sync",
"//pkg/syserror",
diff --git a/pkg/sentry/kernel/auth/credentials.go b/pkg/sentry/kernel/auth/credentials.go
index 6862f2ef5..32c344399 100644
--- a/pkg/sentry/kernel/auth/credentials.go
+++ b/pkg/sentry/kernel/auth/credentials.go
@@ -16,6 +16,7 @@ package auth
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -125,7 +126,7 @@ func NewUserCredentials(kuid KUID, kgid KGID, extraKGIDs []KGID, capabilities *T
creds.EffectiveCaps = capabilities.EffectiveCaps
creds.BoundingCaps = capabilities.BoundingCaps
creds.InheritableCaps = capabilities.InheritableCaps
- // TODO(nlacasse): Support ambient capabilities.
+ // TODO(gvisor.dev/issue/3166): Support ambient capabilities.
} else {
// If no capabilities are specified, grant capabilities consistent with
// setresuid + setresgid from NewRootCredentials to the given uid and
@@ -203,7 +204,7 @@ func (c *Credentials) UseUID(uid UID) (KUID, error) {
// uid must be mapped.
kuid := c.UserNamespace.MapToKUID(uid)
if !kuid.Ok() {
- return NoID, syserror.EINVAL
+ return NoID, linuxerr.EINVAL
}
// If c has CAP_SETUID, then it can use any UID in its user namespace.
if c.HasCapability(linux.CAP_SETUID) {
@@ -222,7 +223,7 @@ func (c *Credentials) UseUID(uid UID) (KUID, error) {
func (c *Credentials) UseGID(gid GID) (KGID, error) {
kgid := c.UserNamespace.MapToKGID(gid)
if !kgid.Ok() {
- return NoID, syserror.EINVAL
+ return NoID, linuxerr.EINVAL
}
if c.HasCapability(linux.CAP_SETGID) {
return kgid, nil
@@ -239,7 +240,7 @@ func (c *Credentials) UseGID(gid GID) (KGID, error) {
func (c *Credentials) SetUID(uid UID) error {
kuid := c.UserNamespace.MapToKUID(uid)
if !kuid.Ok() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
c.RealKUID = kuid
c.EffectiveKUID = kuid
@@ -253,7 +254,7 @@ func (c *Credentials) SetUID(uid UID) error {
func (c *Credentials) SetGID(gid GID) error {
kgid := c.UserNamespace.MapToKGID(gid)
if !kgid.Ok() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
c.RealKGID = kgid
c.EffectiveKGID = kgid
diff --git a/pkg/sentry/kernel/auth/id_map.go b/pkg/sentry/kernel/auth/id_map.go
index 28cbe159d..955b6d40b 100644
--- a/pkg/sentry/kernel/auth/id_map.go
+++ b/pkg/sentry/kernel/auth/id_map.go
@@ -17,6 +17,7 @@ package auth
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -110,7 +111,7 @@ func (ns *UserNamespace) SetUIDMap(ctx context.Context, entries []IDMapEntry) er
}
// "At least one line must be written to the file."
if len(entries) == 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// """
// In order for a process to write to the /proc/[pid]/uid_map
@@ -170,11 +171,11 @@ func (ns *UserNamespace) trySetUIDMap(entries []IDMapEntry) error {
// checks for NoID.
lastID := e.FirstID + e.Length
if lastID <= e.FirstID {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
lastParentID := e.FirstParentID + e.Length
if lastParentID <= e.FirstParentID {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// "3. The mapped user IDs (group IDs) must in turn have a mapping in
// the parent user namespace."
@@ -186,10 +187,10 @@ func (ns *UserNamespace) trySetUIDMap(entries []IDMapEntry) error {
}
// If either of these Adds fail, we have an overlapping range.
if !ns.uidMapFromParent.Add(idMapRange{e.FirstParentID, lastParentID}, e.FirstID) {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if !ns.uidMapToParent.Add(idMapRange{e.FirstID, lastID}, e.FirstParentID) {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
}
return nil
@@ -205,7 +206,7 @@ func (ns *UserNamespace) SetGIDMap(ctx context.Context, entries []IDMapEntry) er
return syserror.EPERM
}
if len(entries) == 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if !c.HasCapabilityIn(linux.CAP_SETGID, ns) {
return syserror.EPERM
@@ -239,20 +240,20 @@ func (ns *UserNamespace) trySetGIDMap(entries []IDMapEntry) error {
for _, e := range entries {
lastID := e.FirstID + e.Length
if lastID <= e.FirstID {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
lastParentID := e.FirstParentID + e.Length
if lastParentID <= e.FirstParentID {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if !ns.parent.allIDsMapped(&ns.parent.gidMapToParent, e.FirstParentID, lastParentID) {
return syserror.EPERM
}
if !ns.gidMapFromParent.Add(idMapRange{e.FirstParentID, lastParentID}, e.FirstID) {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if !ns.gidMapToParent.Add(idMapRange{e.FirstID, lastID}, e.FirstParentID) {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
}
return nil
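trySetUIDMap and trySetGIDMap above reject any entry whose FirstID+Length wraps around by checking lastID <= e.FirstID. The following sketch isolates that unsigned-overflow check with uint32 IDs, matching the style of the code; the helper name is illustrative only.

package main

import "fmt"

// rangeEnd returns the exclusive end of an ID range, or false if the range is
// empty or wraps past the end of the 32-bit ID space.
func rangeEnd(first, length uint32) (uint32, bool) {
	last := first + length
	if last <= first {
		return 0, false // zero length, or first+length overflowed
	}
	return last, true
}

func main() {
	fmt.Println(rangeEnd(1000, 65536))    // 66536 true
	fmt.Println(rangeEnd(4294967295, 10)) // 0 false: wrapped around
	fmt.Println(rangeEnd(42, 0))          // 0 false: empty range
}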
diff --git a/pkg/sentry/kernel/cgroup.go b/pkg/sentry/kernel/cgroup.go
index 0fbf27f64..c93ef6ac1 100644
--- a/pkg/sentry/kernel/cgroup.go
+++ b/pkg/sentry/kernel/cgroup.go
@@ -181,7 +181,23 @@ func (r *CgroupRegistry) FindHierarchy(ctypes []CgroupControllerType) *vfs.Files
for _, h := range r.hierarchies {
if h.match(ctypes) {
- h.fs.IncRef()
+ if !h.fs.TryIncRef() {
+ // Racing with filesystem destruction, namely h.fs.Release.
+ // Since we hold r.mu, we know the hierarchy hasn't been
+ // unregistered yet, but its associated filesystem is tearing
+ // down.
+ //
+ // If we simply indicate the hierarchy wasn't found without
+ // cleaning up the registry, the caller can race with the
+ // unregister and find itself temporarily unable to create a new
+ // hierarchy with a subset of the relevant controllers.
+ //
+ // To keep the result of FindHierarchy consistent with the
+ // uniqueness of controllers enforced by Register, drop the
+ // dying hierarchy now. The eventual unregister by the FS
+ // teardown will become a no-op.
+ return nil
+ }
return h.fs
}
}
@@ -230,12 +246,17 @@ func (r *CgroupRegistry) Register(cs []CgroupController, fs cgroupFS) error {
return nil
}
-// Unregister removes a previously registered hierarchy from the registry. If
-// the controller was not previously registered, Unregister is a no-op.
+// Unregister removes a previously registered hierarchy from the registry. If no
+// such hierarchy is registered, Unregister is a no-op.
func (r *CgroupRegistry) Unregister(hid uint32) {
r.mu.Lock()
- defer r.mu.Unlock()
+ r.unregisterLocked(hid)
+ r.mu.Unlock()
+}
+// Precondition: Caller must hold r.mu.
+// +checklocks:r.mu
+func (r *CgroupRegistry) unregisterLocked(hid uint32) {
if h, ok := r.hierarchies[hid]; ok {
for name, _ := range h.controllers {
delete(r.controllers, name)
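The cgroup hunk above switches FindHierarchy from IncRef to TryIncRef because the matching filesystem may be concurrently releasing its last reference: the lookup must atomically move from "refcount > 0" to "refcount + 1" and give up if the object is already dying. A hedged, self-contained sketch of that pattern (not the sentry's refs package):

package main

import (
	"fmt"
	"sync/atomic"
)

type refCounted struct {
	refs int64 // > 0 while alive; 0 once released
}

// tryIncRef takes a new reference only if at least one reference still
// exists, returning false if the object is already being torn down.
func (r *refCounted) tryIncRef() bool {
	for {
		old := atomic.LoadInt64(&r.refs)
		if old == 0 {
			return false
		}
		if atomic.CompareAndSwapInt64(&r.refs, old, old+1) {
			return true
		}
	}
}

func (r *refCounted) decRef() { atomic.AddInt64(&r.refs, -1) }

func main() {
	fs := &refCounted{refs: 1}
	fmt.Println(fs.tryIncRef()) // true: the lookup now holds a reference
	fs.decRef()
	fs.decRef()                 // the last reference is dropped; object is dead
	fmt.Println(fs.tryIncRef()) // false: a racing lookup must treat it as gone
}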
diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD
index f855f038b..6b2dd09da 100644
--- a/pkg/sentry/kernel/fasync/BUILD
+++ b/pkg/sentry/kernel/fasync/BUILD
@@ -8,13 +8,12 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
- "//pkg/sentry/arch",
+ "//pkg/errors/linuxerr",
"//pkg/sentry/fs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go
index dbbbaeeb0..473987a79 100644
--- a/pkg/sentry/kernel/fasync/fasync.go
+++ b/pkg/sentry/kernel/fasync/fasync.go
@@ -17,13 +17,12 @@ package fasync
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -125,9 +124,9 @@ func (a *FileAsync) Callback(e *waiter.Entry, mask waiter.EventMask) {
if !permCheck {
return
}
- signalInfo := &arch.SignalInfo{
+ signalInfo := &linux.SignalInfo{
Signo: int32(linux.SIGIO),
- Code: arch.SignalInfoKernel,
+ Code: linux.SI_KERNEL,
}
if a.signal != 0 {
signalInfo.Signo = int32(a.signal)
@@ -249,7 +248,7 @@ func (a *FileAsync) Signal() linux.Signal {
// to send SIGIO.
func (a *FileAsync) SetSignal(signal linux.Signal) error {
if signal != 0 && !signal.IsValid() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
a.mu.Lock()
defer a.mu.Unlock()
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 62777faa8..8786a70b5 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -23,12 +23,12 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
// FDFlags define flags for an individual descriptor.
@@ -156,7 +156,7 @@ func (f *FDTable) dropVFS2(ctx context.Context, file *vfs.FileDescription) {
// Release any POSIX lock possibly held by the FDTable.
if file.SupportsLocks() {
err := file.UnlockPOSIX(ctx, f, lock.LockRange{0, lock.LockEOF})
- if err != nil && err != syserror.ENOLCK {
+ if err != nil && !linuxerr.Equals(linuxerr.ENOLCK, err) {
panic(fmt.Sprintf("UnlockPOSIX failed: %v", err))
}
}
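Illustrative sketch (not part of the change): in dropVFS2 above, the direct comparison err != syserror.ENOLCK becomes !linuxerr.Equals(linuxerr.ENOLCK, err). During the syserror-to-linuxerr migration the same errno can surface as different error values, so comparison is by errno rather than by identity. A rough, self-contained sketch of that idea (the real linuxerr.Equals also recognizes the legacy syserror values):

package main

import (
    "fmt"

    "golang.org/x/sys/unix"
)

// Error is a stand-in for the richer linuxerr error type: an errno plus a
// message, compared by errno rather than by pointer identity.
type Error struct {
    errno unix.Errno
    msg   string
}

func (e *Error) Error() string { return e.msg }

var ENOLCK = &Error{unix.ENOLCK, "no locks available"}

// equals reports whether err carries the same errno as e, whether err is a
// value of this type or a raw unix.Errno.
func equals(e *Error, err error) bool {
    switch other := err.(type) {
    case *Error:
        return other.errno == e.errno
    case unix.Errno:
        return other == e.errno
    default:
        return false
    }
}

func main() {
    var got error = unix.ENOLCK       // how the errno might arrive from a lower layer
    fmt.Println(got == error(ENOLCK)) // false: identity comparison misses it
    fmt.Println(equals(ENOLCK, got))  // true: errno comparison matches
}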
diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD
index a75686cf3..0606d32a8 100644
--- a/pkg/sentry/kernel/futex/BUILD
+++ b/pkg/sentry/kernel/futex/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "atomicptr_bucket_unsafe.go",
package = "futex",
suffix = "Bucket",
- template = "//pkg/sync:generic_atomicptr",
+ template = "//pkg/sync/atomicptr:generic_atomicptr",
types = {
"Value": "bucket",
},
@@ -37,6 +37,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/log",
"//pkg/sentry/memmap",
diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go
index 0427cf3f4..5c64ce11e 100644
--- a/pkg/sentry/kernel/futex/futex.go
+++ b/pkg/sentry/kernel/futex/futex.go
@@ -20,6 +20,7 @@ package futex
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sync"
@@ -332,7 +333,7 @@ func getKey(t Target, addr hostarch.Addr, private bool) (Key, error) {
// Ensure the address is aligned.
// It must be a DWORD boundary.
if addr&0x3 != 0 {
- return Key{}, syserror.EINVAL
+ return Key{}, linuxerr.EINVAL
}
if private {
return Key{Kind: KindPrivate, Offset: uint64(addr)}, nil
@@ -790,7 +791,7 @@ func (m *Manager) unlockPILocked(t Target, addr hostarch.Addr, tid uint32, b *bu
return err
}
if prev != cur {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
b.wakeWaiterLocked(next)
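Illustrative sketch (not part of the change): getKey above now returns linuxerr.EINVAL for futex addresses that are not 4-byte aligned; the low two address bits must be clear, matching Linux's requirement that a futex word sit on a DWORD boundary. A tiny illustration of that mask check:

package main

import "fmt"

// aligned4 performs the same test as getKey's addr&0x3 != 0 check: the
// address must be a multiple of 4.
func aligned4(addr uintptr) bool {
    return addr&0x3 == 0
}

func main() {
    for _, a := range []uintptr{0x1000, 0x1002, 0x1003, 0x1004} {
        fmt.Printf("%#x aligned: %v\n", a, aligned4(a))
    }
}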
diff --git a/pkg/sentry/kernel/kcov.go b/pkg/sentry/kernel/kcov.go
index 4b943106b..941cc373f 100644
--- a/pkg/sentry/kernel/kcov.go
+++ b/pkg/sentry/kernel/kcov.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/coverage"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -131,13 +132,13 @@ func (kcov *Kcov) InitTrace(size uint64) error {
// To simplify all the logic around mapping, we require that the length of the
// shared region is a multiple of the system page size.
if (8*size)&(hostarch.PageSize-1) != 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// We need space for at least two uint64s to hold current position and a
// single PC.
if size < 2 || size > kcovAreaSizeMax {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
kcov.size = size
@@ -157,7 +158,7 @@ func (kcov *Kcov) EnableTrace(ctx context.Context, traceKind uint8) error {
// KCOV_ENABLE must be preceded by KCOV_INIT_TRACE and an mmap call.
if kcov.mode != linux.KCOV_MODE_INIT || kcov.mappable == nil {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
switch traceKind {
@@ -167,7 +168,7 @@ func (kcov *Kcov) EnableTrace(ctx context.Context, traceKind uint8) error {
// We do not support KCOV_MODE_TRACE_CMP.
return syserror.ENOTSUP
default:
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if kcov.owningTask != nil && kcov.owningTask != t {
@@ -195,7 +196,7 @@ func (kcov *Kcov) DisableTrace(ctx context.Context) error {
}
if t != kcov.owningTask {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
kcov.mode = linux.KCOV_MODE_INIT
kcov.owningTask = nil
@@ -237,7 +238,7 @@ func (kcov *Kcov) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) erro
defer kcov.mu.Unlock()
if kcov.mode != linux.KCOV_MODE_INIT {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if kcov.mappable == nil {
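Illustrative sketch (not part of the change): InitTrace above validates the requested kcov buffer, which holds 8-byte entries, so 8*size must be page-aligned and size must cover at least the position counter plus one PC. A standalone sketch of the same two checks, assuming a 4KiB page size and a hypothetical cap on the entry count:

package main

import "fmt"

const (
    pageSize        = 4096    // assumed host page size
    kcovAreaSizeMax = 1 << 24 // hypothetical upper bound on entries
)

// validateKcovSize mirrors the checks in InitTrace: the byte length 8*size
// must be a multiple of the page size, and size must fit a position word
// plus at least one PC.
func validateKcovSize(size uint64) error {
    if (8*size)&(pageSize-1) != 0 {
        return fmt.Errorf("EINVAL: %d bytes is not page-aligned", 8*size)
    }
    if size < 2 || size > kcovAreaSizeMax {
        return fmt.Errorf("EINVAL: size %d out of range", size)
    }
    return nil
}

func main() {
    fmt.Println(validateKcovSize(512)) // <nil>: exactly one 4KiB page
    fmt.Println(validateKcovSize(100)) // error: not page-aligned
}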
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index e6e9da898..352c36ba9 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -143,12 +143,6 @@ type Kernel struct {
// to CreateProcess, and is protected by extMu.
globalInit *ThreadGroup
- // realtimeClock is a ktime.Clock based on timekeeper's Realtime.
- realtimeClock *timekeeperClock
-
- // monotonicClock is a ktime.Clock based on timekeeper's Monotonic.
- monotonicClock *timekeeperClock
-
// syslog is the kernel log.
syslog syslog
@@ -354,19 +348,19 @@ type InitKernelArgs struct {
// before calling Init.
func (k *Kernel) Init(args InitKernelArgs) error {
if args.FeatureSet == nil {
- return fmt.Errorf("FeatureSet is nil")
+ return fmt.Errorf("args.FeatureSet is nil")
}
if args.Timekeeper == nil {
- return fmt.Errorf("Timekeeper is nil")
+ return fmt.Errorf("args.Timekeeper is nil")
}
if args.Timekeeper.clocks == nil {
return fmt.Errorf("must call Timekeeper.SetClocks() before Kernel.Init()")
}
if args.RootUserNamespace == nil {
- return fmt.Errorf("RootUserNamespace is nil")
+ return fmt.Errorf("args.RootUserNamespace is nil")
}
if args.ApplicationCores == 0 {
- return fmt.Errorf("ApplicationCores is 0")
+ return fmt.Errorf("args.ApplicationCores is 0")
}
k.featureSet = args.FeatureSet
@@ -395,8 +389,6 @@ func (k *Kernel) Init(args InitKernelArgs) error {
}
k.extraAuxv = args.ExtraAuxv
k.vdso = args.Vdso
- k.realtimeClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Realtime}
- k.monotonicClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Monotonic}
k.futexes = futex.NewManager()
k.netlinkPorts = port.New()
k.ptraceExceptions = make(map[*Task]*Task)
@@ -529,6 +521,8 @@ func (k *Kernel) SaveTo(ctx context.Context, w wire.Writer) error {
}
log.Infof("CPUID save took [%s].", time.Since(cpuidStart))
+ // Save the timekeeper's state.
+
// Save the kernel state.
kernelStart := time.Now()
stats, err := state.Save(ctx, w, k)
@@ -654,12 +648,12 @@ func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error {
defer k.tasks.mu.RUnlock()
for t := range k.tasks.Root.tids {
// We can skip locking Task.mu here since the kernel is paused.
- if mm := t.image.MemoryManager; mm != nil {
- if _, ok := invalidated[mm]; !ok {
- if err := mm.InvalidateUnsavable(ctx); err != nil {
+ if memMgr := t.image.MemoryManager; memMgr != nil {
+ if _, ok := invalidated[memMgr]; !ok {
+ if err := memMgr.InvalidateUnsavable(ctx); err != nil {
return err
}
- invalidated[mm] = struct{}{}
+ invalidated[memMgr] = struct{}{}
}
}
// I really wish we just had a sync.Map of all MMs...
@@ -673,7 +667,7 @@ func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error {
}
// LoadFrom returns a new Kernel loaded from args.
-func (k *Kernel) LoadFrom(ctx context.Context, r wire.Reader, net inet.Stack, clocks sentrytime.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error {
+func (k *Kernel) LoadFrom(ctx context.Context, r wire.Reader, timeReady chan struct{}, net inet.Stack, clocks sentrytime.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error {
loadStart := time.Now()
initAppCores := k.applicationCores
@@ -720,6 +714,11 @@ func (k *Kernel) LoadFrom(ctx context.Context, r wire.Reader, net inet.Stack, cl
log.Infof("Overall load took [%s]", time.Since(loadStart))
k.Timekeeper().SetClocks(clocks)
+
+ if timeReady != nil {
+ close(timeReady)
+ }
+
if net != nil {
net.Resume()
}
@@ -1101,7 +1100,7 @@ func (k *Kernel) Start() error {
}
k.started = true
- k.cpuClockTicker = ktime.NewTimer(k.monotonicClock, newKernelCPUClockTicker(k))
+ k.cpuClockTicker = ktime.NewTimer(k.timekeeper.monotonicClock, newKernelCPUClockTicker(k))
k.cpuClockTicker.Swap(ktime.Setting{
Enabled: true,
Period: linux.ClockTick,
@@ -1256,7 +1255,7 @@ func (k *Kernel) incRunningTasks() {
// These cause very different value of cpuClock. But again, since
// nothing was running while the ticker was disabled, those differences
// don't matter.
- setting, exp := k.cpuClockTickerSetting.At(k.monotonicClock.Now())
+ setting, exp := k.cpuClockTickerSetting.At(k.timekeeper.monotonicClock.Now())
if exp > 0 {
atomic.AddUint64(&k.cpuClock, exp)
}
@@ -1339,7 +1338,7 @@ func (k *Kernel) Unpause() {
// context is used only for debugging to describe how the signal was received.
//
// Preconditions: Kernel must have an init process.
-func (k *Kernel) SendExternalSignal(info *arch.SignalInfo, context string) {
+func (k *Kernel) SendExternalSignal(info *linux.SignalInfo, context string) {
k.extMu.Lock()
defer k.extMu.Unlock()
k.sendExternalSignal(info, context)
@@ -1347,7 +1346,7 @@ func (k *Kernel) SendExternalSignal(info *arch.SignalInfo, context string) {
// SendExternalSignalThreadGroup injects a signal into an specific ThreadGroup.
// This function doesn't skip signals like SendExternalSignal does.
-func (k *Kernel) SendExternalSignalThreadGroup(tg *ThreadGroup, info *arch.SignalInfo) error {
+func (k *Kernel) SendExternalSignalThreadGroup(tg *ThreadGroup, info *linux.SignalInfo) error {
k.extMu.Lock()
defer k.extMu.Unlock()
return tg.SendSignal(info)
@@ -1355,7 +1354,7 @@ func (k *Kernel) SendExternalSignalThreadGroup(tg *ThreadGroup, info *arch.Signa
// SendContainerSignal sends the given signal to all processes inside the
// namespace that match the given container ID.
-func (k *Kernel) SendContainerSignal(cid string, info *arch.SignalInfo) error {
+func (k *Kernel) SendContainerSignal(cid string, info *linux.SignalInfo) error {
k.extMu.Lock()
defer k.extMu.Unlock()
k.tasks.mu.RLock()
@@ -1466,12 +1465,12 @@ func (k *Kernel) ApplicationCores() uint {
// RealtimeClock returns the application CLOCK_REALTIME clock.
func (k *Kernel) RealtimeClock() ktime.Clock {
- return k.realtimeClock
+ return k.timekeeper.realtimeClock
}
// MonotonicClock returns the application CLOCK_MONOTONIC clock.
func (k *Kernel) MonotonicClock() ktime.Clock {
- return k.monotonicClock
+ return k.timekeeper.monotonicClock
}
// CPUClockNow returns the current value of k.cpuClock.
@@ -1551,31 +1550,6 @@ func (k *Kernel) SetSaveError(err error) {
}
}
-var _ tcpip.Clock = (*Kernel)(nil)
-
-// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
-func (k *Kernel) NowNanoseconds() int64 {
- now, err := k.timekeeper.GetTime(sentrytime.Realtime)
- if err != nil {
- panic("Kernel.NowNanoseconds: " + err.Error())
- }
- return now
-}
-
-// NowMonotonic implements tcpip.Clock.NowMonotonic.
-func (k *Kernel) NowMonotonic() int64 {
- now, err := k.timekeeper.GetTime(sentrytime.Monotonic)
- if err != nil {
- panic("Kernel.NowMonotonic: " + err.Error())
- }
- return now
-}
-
-// AfterFunc implements tcpip.Clock.AfterFunc.
-func (k *Kernel) AfterFunc(d time.Duration, f func()) tcpip.Timer {
- return ktime.TcpipAfterFunc(k.realtimeClock, d, f)
-}
-
// SetMemoryFile sets Kernel.mf. SetMemoryFile must be called before Init or
// LoadFrom.
func (k *Kernel) SetMemoryFile(mf *pgalloc.MemoryFile) {
@@ -1783,7 +1757,7 @@ func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) {
})
t := TaskFromContext(ctx)
- k.unimplementedSyscallEmitter.Emit(&uspb.UnimplementedSyscall{
+ _, _ = k.unimplementedSyscallEmitter.Emit(&uspb.UnimplementedSyscall{
Tid: int32(t.ThreadID()),
Registers: t.Arch().StateData().Proto(),
})
@@ -1858,7 +1832,9 @@ func (k *Kernel) PopulateNewCgroupHierarchy(root Cgroup) {
return
}
t.mu.Lock()
- t.enterCgroupLocked(root)
+ // A task may already be in the cgroup if it was created after the
+ // cgroup hierarchy was registered.
+ t.enterCgroupIfNotYetLocked(root)
t.mu.Unlock()
})
k.tasks.mu.RUnlock()
@@ -1874,7 +1850,7 @@ func (k *Kernel) ReleaseCgroupHierarchy(hid uint32) {
return
}
t.mu.Lock()
- for cg, _ := range t.cgroups {
+ for cg := range t.cgroups {
if cg.HierarchyID() == hid {
t.leaveCgroupLocked(cg)
}
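Illustrative sketch (not part of the change): LoadFrom above gains a timeReady channel that is closed (when non-nil) once the restored Timekeeper's clocks have been set, letting restore-time consumers block until time is usable. Closing a channel is the idiomatic Go broadcast for a one-shot "ready" event, since every receiver unblocks at once. A minimal sketch of that handshake, outside of any gVisor types:

package main

import (
    "fmt"
    "time"
)

func main() {
    timeReady := make(chan struct{})

    // A waiter that must not use the clocks until restore finishes.
    done := make(chan struct{})
    go func() {
        <-timeReady // unblocks for every waiter once the channel is closed
        fmt.Println("clocks are ready")
        close(done)
    }()

    // The restore path: configure clocks, then broadcast readiness.
    time.Sleep(10 * time.Millisecond) // stands in for SetClocks during restore
    close(timeReady)

    <-done
}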
diff --git a/pkg/sentry/kernel/pending_signals.go b/pkg/sentry/kernel/pending_signals.go
index 77a35b788..af455c434 100644
--- a/pkg/sentry/kernel/pending_signals.go
+++ b/pkg/sentry/kernel/pending_signals.go
@@ -17,7 +17,6 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bits"
- "gvisor.dev/gvisor/pkg/sentry/arch"
)
const (
@@ -65,7 +64,7 @@ type pendingSignalQueue struct {
type pendingSignal struct {
// pendingSignalEntry links into a pendingSignalList.
pendingSignalEntry
- *arch.SignalInfo
+ *linux.SignalInfo
// If timer is not nil, it is the IntervalTimer which sent this signal.
timer *IntervalTimer
@@ -75,7 +74,7 @@ type pendingSignal struct {
// on failure (if the given signal's queue is full).
//
// Preconditions: info represents a valid signal.
-func (p *pendingSignals) enqueue(info *arch.SignalInfo, timer *IntervalTimer) bool {
+func (p *pendingSignals) enqueue(info *linux.SignalInfo, timer *IntervalTimer) bool {
sig := linux.Signal(info.Signo)
q := &p.signals[sig.Index()]
if sig.IsStandard() {
@@ -93,7 +92,7 @@ func (p *pendingSignals) enqueue(info *arch.SignalInfo, timer *IntervalTimer) bo
// dequeue dequeues and returns any pending signal not masked by mask. If no
// unmasked signals are pending, dequeue returns nil.
-func (p *pendingSignals) dequeue(mask linux.SignalSet) *arch.SignalInfo {
+func (p *pendingSignals) dequeue(mask linux.SignalSet) *linux.SignalInfo {
// "Real-time signals are delivered in a guaranteed order. Multiple
// real-time signals of the same type are delivered in the order they were
// sent. If different real-time signals are sent to a process, they are
@@ -111,7 +110,7 @@ func (p *pendingSignals) dequeue(mask linux.SignalSet) *arch.SignalInfo {
return p.dequeueSpecific(linux.Signal(lowestPendingUnblockedBit + 1))
}
-func (p *pendingSignals) dequeueSpecific(sig linux.Signal) *arch.SignalInfo {
+func (p *pendingSignals) dequeueSpecific(sig linux.Signal) *linux.SignalInfo {
q := &p.signals[sig.Index()]
ps := q.pendingSignalList.Front()
if ps == nil {
diff --git a/pkg/sentry/kernel/pending_signals_state.go b/pkg/sentry/kernel/pending_signals_state.go
index ca8b4e164..e77f1a254 100644
--- a/pkg/sentry/kernel/pending_signals_state.go
+++ b/pkg/sentry/kernel/pending_signals_state.go
@@ -14,13 +14,11 @@
package kernel
-import (
- "gvisor.dev/gvisor/pkg/sentry/arch"
-)
+import "gvisor.dev/gvisor/pkg/abi/linux"
// +stateify savable
type savedPendingSignal struct {
- si *arch.SignalInfo
+ si *linux.SignalInfo
timer *IntervalTimer
}
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index 34c617b08..94ebac7c5 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -21,6 +21,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/amutex",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/marshal/primitive",
"//pkg/safemem",
@@ -47,6 +48,7 @@ go_test(
library = ":pipe",
deps = [
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/sentry/contexttest",
"//pkg/sentry/fs",
"//pkg/syserror",
diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go
index 6497dc4ba..2321d26dc 100644
--- a/pkg/sentry/kernel/pipe/node.go
+++ b/pkg/sentry/kernel/pipe/node.go
@@ -17,6 +17,7 @@ package pipe
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sync"
@@ -130,7 +131,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
return rw, nil
default:
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
}
diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go
index d6fb0fdb8..d25cf658e 100644
--- a/pkg/sentry/kernel/pipe/node_test.go
+++ b/pkg/sentry/kernel/pipe/node_test.go
@@ -19,6 +19,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/syserror"
@@ -258,7 +259,7 @@ func TestNonblockingWriteOpenFileNoReaders(t *testing.T) {
ctx := newSleeperContext(t)
f := NewInodeOperations(ctx, perms, newNamedPipe(t))
- if _, err := testOpen(ctx, t, f, fs.FileFlags{Write: true, NonBlocking: true}, nil); err != syserror.ENXIO {
+ if _, err := testOpen(ctx, t, f, fs.FileFlags{Write: true, NonBlocking: true}, nil); !linuxerr.Equals(linuxerr.ENXIO, err) {
t.Fatalf("Nonblocking open for write failed unexpected error %v.", err)
}
}
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 06769931a..979ea10bf 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -22,6 +22,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -428,7 +429,7 @@ func (p *Pipe) FifoSize(context.Context, *fs.File) (int64, error) {
// SetFifoSize implements fs.FifoSizer.SetFifoSize.
func (p *Pipe) SetFifoSize(size int64) (int64, error) {
if size < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if size < MinimumPipeSize {
size = MinimumPipeSize // Per spec.
diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go
index 2d89b9ccd..3fa5d1d2f 100644
--- a/pkg/sentry/kernel/pipe/pipe_util.go
+++ b/pkg/sentry/kernel/pipe/pipe_util.go
@@ -86,6 +86,12 @@ func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error)
if n > 0 {
p.Notify(waiter.ReadableEvents)
}
+ if err == unix.EPIPE {
+ // If we are returning EPIPE, send SIGPIPE to the task.
+ if sendSig := linux.SignalNoInfoFuncFromContext(ctx); sendSig != nil {
+ sendSig(linux.SIGPIPE)
+ }
+ }
return n, err
}
@@ -129,7 +135,7 @@ func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArgume
v = math.MaxInt32 // Silently truncate.
}
// Copy result to userspace.
- iocc := primitive.IOCopyContext{
+ iocc := usermem.IOCopyContext{
IO: io,
Ctx: ctx,
Opts: usermem.IOOpts{
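Illustrative sketch (not part of the change): Pipe.Write above now looks up a signal-sending callback in the context (linux.SignalNoInfoFuncFromContext) and delivers SIGPIPE to the calling task when the write fails with EPIPE; the task_context.go change later in this diff is what installs that callback for task contexts. A rough sketch of the pattern with a hypothetical context key and callback type (the real code uses linux.CtxSignalNoInfoFunc and linux.Signal):

package main

import (
    "context"
    "fmt"

    "golang.org/x/sys/unix"
)

// ctxSignalFuncKey is a hypothetical context key standing in for
// linux.CtxSignalNoInfoFunc.
type ctxSignalFuncKey struct{}

type signalFunc func(sig unix.Signal) error

// signalFuncFromContext returns the installed callback, or nil if the
// context does not belong to a task.
func signalFuncFromContext(ctx context.Context) signalFunc {
    if f, ok := ctx.Value(ctxSignalFuncKey{}).(signalFunc); ok {
        return f
    }
    return nil
}

// write stands in for Pipe.Write: on EPIPE, also notify the caller via
// SIGPIPE if a callback is available.
func write(ctx context.Context, err error) error {
    if err == unix.EPIPE {
        if sendSig := signalFuncFromContext(ctx); sendSig != nil {
            _ = sendSig(unix.SIGPIPE)
        }
    }
    return err
}

func main() {
    ctx := context.WithValue(context.Background(), ctxSignalFuncKey{},
        signalFunc(func(sig unix.Signal) error {
            fmt.Println("would deliver signal", sig)
            return nil
        }))
    _ = write(ctx, unix.EPIPE)
}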
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
index 95b948edb..623375417 100644
--- a/pkg/sentry/kernel/pipe/vfs.go
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -17,6 +17,7 @@ package pipe
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -90,7 +91,7 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s
readable := vfs.MayReadFileWithOpenFlags(statusFlags)
writable := vfs.MayWriteFileWithOpenFlags(statusFlags)
if !readable && !writable {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
fd, err := vp.newFD(mnt, vfsd, statusFlags, locks)
@@ -415,7 +416,7 @@ func Tee(ctx context.Context, dst, src *VFSPipeFD, count int64) (int64, error) {
// Preconditions: count > 0.
func spliceOrTee(ctx context.Context, dst, src *VFSPipeFD, count int64, removeFromSrc bool) (int64, error) {
if dst.pipe == src.pipe {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
lockTwoPipes(dst.pipe, src.pipe)
diff --git a/pkg/sentry/kernel/posixtimer.go b/pkg/sentry/kernel/posixtimer.go
index 2e861a5a8..049cc07df 100644
--- a/pkg/sentry/kernel/posixtimer.go
+++ b/pkg/sentry/kernel/posixtimer.go
@@ -18,7 +18,7 @@ import (
"math"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -97,7 +97,7 @@ func (it *IntervalTimer) ResumeTimer() {
}
// Preconditions: it.target's signal mutex must be locked.
-func (it *IntervalTimer) updateDequeuedSignalLocked(si *arch.SignalInfo) {
+func (it *IntervalTimer) updateDequeuedSignalLocked(si *linux.SignalInfo) {
it.sigpending = false
if it.sigorphan {
return
@@ -138,9 +138,9 @@ func (it *IntervalTimer) Notify(exp uint64, setting ktime.Setting) (ktime.Settin
it.sigpending = true
it.sigorphan = false
it.overrunCur += exp - 1
- si := &arch.SignalInfo{
+ si := &linux.SignalInfo{
Signo: int32(it.signo),
- Code: arch.SignalInfoTimer,
+ Code: linux.SI_TIMER,
}
si.SetTimerID(it.id)
si.SetSigval(it.sigval)
@@ -215,16 +215,16 @@ func (t *Task) IntervalTimerCreate(c ktime.Clock, sigev *linux.Sigevent) (linux.
target, ok := t.tg.pidns.tasks[ThreadID(sigev.Tid)]
t.tg.pidns.owner.mu.RUnlock()
if !ok || target.tg != t.tg {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
it.target = target
default:
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if sigev.Notify != linux.SIGEV_NONE {
it.signo = linux.Signal(sigev.Signo)
if !it.signo.IsValid() {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
}
it.timer = ktime.NewTimer(c, it)
@@ -239,7 +239,7 @@ func (t *Task) IntervalTimerDelete(id linux.TimerID) error {
defer t.tg.timerMu.Unlock()
it := t.tg.timers[id]
if it == nil {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
delete(t.tg.timers, id)
it.DestroyTimer()
@@ -252,7 +252,7 @@ func (t *Task) IntervalTimerSettime(id linux.TimerID, its linux.Itimerspec, abs
defer t.tg.timerMu.Unlock()
it := t.tg.timers[id]
if it == nil {
- return linux.Itimerspec{}, syserror.EINVAL
+ return linux.Itimerspec{}, linuxerr.EINVAL
}
newS, err := ktime.SettingFromItimerspec(its, abs, it.timer.Clock())
@@ -270,7 +270,7 @@ func (t *Task) IntervalTimerGettime(id linux.TimerID) (linux.Itimerspec, error)
defer t.tg.timerMu.Unlock()
it := t.tg.timers[id]
if it == nil {
- return linux.Itimerspec{}, syserror.EINVAL
+ return linux.Itimerspec{}, linuxerr.EINVAL
}
tm, s := it.timer.Get()
@@ -286,7 +286,7 @@ func (t *Task) IntervalTimerGetoverrun(id linux.TimerID) (int32, error) {
defer t.tg.timerMu.Unlock()
it := t.tg.timers[id]
if it == nil {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// By timer_create(2) invariant, either it.target == nil (in which case
// it.overrunLast is immutably 0) or t.tg == it.target.tg; and the fact
diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go
index 57c7659e7..1c6100efe 100644
--- a/pkg/sentry/kernel/ptrace.go
+++ b/pkg/sentry/kernel/ptrace.go
@@ -19,9 +19,9 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -295,7 +295,7 @@ func (t *Task) isYAMADescendantOfLocked(ancestor *Task) bool {
// Precondition: the TaskSet mutex must be locked (for reading or writing).
func (t *Task) hasYAMAExceptionForLocked(tracer *Task) bool {
- allowed, ok := t.k.ptraceExceptions[t]
+ allowed, ok := t.k.ptraceExceptions[t.tg.leader]
if !ok {
return false
}
@@ -394,7 +394,7 @@ func (t *Task) ptraceTrapLocked(code int32) {
t.trapStopPending = false
t.tg.signalHandlers.mu.Unlock()
t.ptraceCode = code
- t.ptraceSiginfo = &arch.SignalInfo{
+ t.ptraceSiginfo = &linux.SignalInfo{
Signo: int32(linux.SIGTRAP),
Code: code,
}
@@ -402,7 +402,7 @@ func (t *Task) ptraceTrapLocked(code int32) {
t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
if t.beginPtraceStopLocked() {
tracer := t.Tracer()
- tracer.signalStop(t, arch.CLD_TRAPPED, int32(linux.SIGTRAP))
+ tracer.signalStop(t, linux.CLD_TRAPPED, int32(linux.SIGTRAP))
tracer.tg.eventQueue.Notify(EventTraceeStop)
}
}
@@ -542,9 +542,9 @@ func (t *Task) ptraceAttach(target *Task, seize bool, opts uintptr) error {
// "Unlike PTRACE_ATTACH, PTRACE_SEIZE does not stop the process." -
// ptrace(2)
if !seize {
- target.sendSignalLocked(&arch.SignalInfo{
+ target.sendSignalLocked(&linux.SignalInfo{
Signo: int32(linux.SIGSTOP),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
}, false /* group */)
}
// Undocumented Linux feature: If the tracee is already group-stopped (and
@@ -586,7 +586,7 @@ func (t *Task) exitPtrace() {
for target := range t.ptraceTracees {
if target.ptraceOpts.ExitKill {
target.tg.signalHandlers.mu.Lock()
- target.sendSignalLocked(&arch.SignalInfo{
+ target.sendSignalLocked(&linux.SignalInfo{
Signo: int32(linux.SIGKILL),
}, false /* group */)
target.tg.signalHandlers.mu.Unlock()
@@ -652,7 +652,7 @@ func (t *Task) forgetTracerLocked() {
// Preconditions:
// * The signal mutex must be locked.
// * The caller must be running on the task goroutine.
-func (t *Task) ptraceSignalLocked(info *arch.SignalInfo) bool {
+func (t *Task) ptraceSignalLocked(info *linux.SignalInfo) bool {
if linux.Signal(info.Signo) == linux.SIGKILL {
return false
}
@@ -678,7 +678,7 @@ func (t *Task) ptraceSignalLocked(info *arch.SignalInfo) bool {
t.ptraceSiginfo = info
t.Debugf("Entering signal-delivery-stop for signal %d", info.Signo)
if t.beginPtraceStopLocked() {
- tracer.signalStop(t, arch.CLD_TRAPPED, info.Signo)
+ tracer.signalStop(t, linux.CLD_TRAPPED, info.Signo)
tracer.tg.eventQueue.Notify(EventTraceeStop)
}
return true
@@ -829,7 +829,7 @@ func (t *Task) ptraceClone(kind ptraceCloneKind, child *Task, opts *CloneOptions
if child.ptraceSeized {
child.trapStopPending = true
} else {
- child.pendingSignals.enqueue(&arch.SignalInfo{
+ child.pendingSignals.enqueue(&linux.SignalInfo{
Signo: int32(linux.SIGSTOP),
}, nil)
}
@@ -893,9 +893,9 @@ func (t *Task) ptraceExec(oldTID ThreadID) {
}
t.tg.signalHandlers.mu.Lock()
defer t.tg.signalHandlers.mu.Unlock()
- t.sendSignalLocked(&arch.SignalInfo{
+ t.sendSignalLocked(&linux.SignalInfo{
Signo: int32(linux.SIGTRAP),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
}, false /* group */)
}
@@ -995,7 +995,7 @@ func (t *Task) ptraceSetOptionsLocked(opts uintptr) error {
linux.PTRACE_O_TRACEVFORK |
linux.PTRACE_O_TRACEVFORKDONE)
if opts&^valid != 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
t.ptraceOpts = ptraceOptions{
ExitKill: opts&linux.PTRACE_O_EXITKILL != 0,
@@ -1222,27 +1222,27 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data hostarch.Addr) error {
t.tg.pidns.owner.mu.RLock()
defer t.tg.pidns.owner.mu.RUnlock()
if target.ptraceSiginfo == nil {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
_, err := target.ptraceSiginfo.CopyOut(t, data)
return err
case linux.PTRACE_SETSIGINFO:
- var info arch.SignalInfo
+ var info linux.SignalInfo
if _, err := info.CopyIn(t, data); err != nil {
return err
}
t.tg.pidns.owner.mu.RLock()
defer t.tg.pidns.owner.mu.RUnlock()
if target.ptraceSiginfo == nil {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
target.ptraceSiginfo = &info
return nil
case linux.PTRACE_GETSIGMASK:
if addr != linux.SignalSetSize {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
mask := target.SignalMask()
_, err := mask.CopyOut(t, data)
@@ -1250,7 +1250,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data hostarch.Addr) error {
case linux.PTRACE_SETSIGMASK:
if addr != linux.SignalSetSize {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
var mask linux.SignalSet
if _, err := mask.CopyIn(t, data); err != nil {
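Illustrative sketch (not part of the change): the hasYAMAExceptionForLocked fix above indexes ptraceExceptions by the tracee's thread-group leader instead of by the individual task, since prctl(PR_SET_PTRACER) records one exception per process and a lookup from a non-leader thread previously missed it. A toy sketch of the keying difference:

package main

import "fmt"

// task is a stand-in with just enough structure to show the lookup; the
// leader of a thread group points to itself.
type task struct {
    name   string
    leader *task
}

func main() {
    leader := &task{name: "tg-leader"}
    leader.leader = leader
    worker := &task{name: "worker-thread", leader: leader}

    // Exceptions are recorded against the leader, since PR_SET_PTRACER
    // applies to the whole process.
    exceptions := map[*task]string{leader: "allowed-tracer"}

    _, byTask := exceptions[worker]          // old lookup: misses for non-leader threads
    _, byLeader := exceptions[worker.leader] // fixed lookup
    fmt.Println(byTask, byLeader)            // false true
}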
diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go
index 4bc5bca44..2344565cd 100644
--- a/pkg/sentry/kernel/rseq.go
+++ b/pkg/sentry/kernel/rseq.go
@@ -18,6 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/hostcpu"
"gvisor.dev/gvisor/pkg/syserror"
@@ -59,20 +60,20 @@ func (t *Task) RSeqAvailable() bool {
func (t *Task) SetRSeq(addr hostarch.Addr, length, signature uint32) error {
if t.rseqAddr != 0 {
if t.rseqAddr != addr {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if t.rseqSignature != signature {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
return syserror.EBUSY
}
// rseq must be aligned and correctly sized.
if addr&(linux.AlignOfRSeq-1) != 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if length != linux.SizeOfRSeq {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if _, ok := t.MemoryManager().CheckIORange(addr, linux.SizeOfRSeq); !ok {
return syserror.EFAULT
@@ -103,13 +104,13 @@ func (t *Task) SetRSeq(addr hostarch.Addr, length, signature uint32) error {
// Preconditions: The caller must be running on the task goroutine.
func (t *Task) ClearRSeq(addr hostarch.Addr, length, signature uint32) error {
if t.rseqAddr == 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if t.rseqAddr != addr {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if length != linux.SizeOfRSeq {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if t.rseqSignature != signature {
return syserror.EPERM
@@ -152,10 +153,10 @@ func (t *Task) SetOldRSeqCriticalRegion(r OldRSeqCriticalRegion) error {
return nil
}
if r.CriticalSection.Start >= r.CriticalSection.End {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if r.CriticalSection.Contains(r.Restart) {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// TODO(jamieliu): check that r.CriticalSection and r.Restart are in
// the application address range, for consistency with Linux.
@@ -187,7 +188,7 @@ func (t *Task) SetOldRSeqCPUAddr(addr hostarch.Addr) error {
// unfortunate, but unlikely in a correct program.
if err := t.rseqUpdateCPU(); err != nil {
t.oldRSeqCPUAddr = 0
- return syserror.EINVAL // yes, EINVAL, not err or EFAULT
+ return linuxerr.EINVAL // yes, EINVAL, not err or EFAULT
}
return nil
}
diff --git a/pkg/sentry/kernel/seccomp.go b/pkg/sentry/kernel/seccomp.go
index a95e174a2..54ca43c2e 100644
--- a/pkg/sentry/kernel/seccomp.go
+++ b/pkg/sentry/kernel/seccomp.go
@@ -39,11 +39,11 @@ func dataAsBPFInput(t *Task, d *linux.SeccompData) bpf.Input {
}
}
-func seccompSiginfo(t *Task, errno, sysno int32, ip hostarch.Addr) *arch.SignalInfo {
- si := &arch.SignalInfo{
+func seccompSiginfo(t *Task, errno, sysno int32, ip hostarch.Addr) *linux.SignalInfo {
+ si := &linux.SignalInfo{
Signo: int32(linux.SIGSYS),
Errno: errno,
- Code: arch.SYS_SECCOMP,
+ Code: linux.SYS_SECCOMP,
}
si.SetCallAddr(uint64(ip))
si.SetSyscall(sysno)
diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD
index 65e5427c1..a787c00a8 100644
--- a/pkg/sentry/kernel/semaphore/BUILD
+++ b/pkg/sentry/kernel/semaphore/BUILD
@@ -25,6 +25,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/log",
"//pkg/sentry/fs",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
index fe2ab1662..2dbc8353a 100644
--- a/pkg/sentry/kernel/semaphore/semaphore.go
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -35,10 +36,10 @@ const (
// Maximum number of semaphore sets.
setsMax = linux.SEMMNI
- // Maximum number of semaphroes in a semaphore set.
+ // Maximum number of semaphores in a semaphore set.
semsMax = linux.SEMMSL
- // Maximum number of semaphores in all semaphroe sets.
+ // Maximum number of semaphores in all semaphore sets.
semsTotalMax = linux.SEMMNS
)
@@ -127,7 +128,7 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry {
// exists.
func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linux.FileMode, private, create, exclusive bool) (*Set, error) {
if nsems < 0 || nsems > semsMax {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
r.mu.Lock()
@@ -147,7 +148,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu
// Validate parameters.
if nsems > int32(set.Size()) {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
if create && exclusive {
return nil, syserror.EEXIST
@@ -163,7 +164,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu
// Zero is only valid if an existing set is found.
if nsems == 0 {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
// Apply system limits.
@@ -171,10 +172,10 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu
// Map semaphores and map indexes in a registry are of the same size,
// check map semaphores only here for the system limit.
if len(r.semaphores) >= setsMax {
- return nil, syserror.EINVAL
+ return nil, syserror.ENOSPC
}
if r.totalSems() > int(semsTotalMax-nsems) {
- return nil, syserror.EINVAL
+ return nil, syserror.ENOSPC
}
// Finally create a new set.
@@ -220,7 +221,7 @@ func (r *Registry) HighestIndex() int32 {
defer r.mu.Unlock()
// By default, highest used index is 0 even though
- // there is no semaphroe set.
+ // there is no semaphore set.
var highestIndex int32
for index := range r.indexes {
if index > highestIndex {
@@ -238,7 +239,7 @@ func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error {
set := r.semaphores[id]
if set == nil {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
index, found := r.findIndexByID(id)
if !found {
@@ -702,7 +703,9 @@ func (s *Set) checkPerms(creds *auth.Credentials, reqPerms fs.PermMask) bool {
return s.checkCapability(creds)
}
-// destroy destroys the set. Caller must hold 's.mu'.
+// destroy destroys the set.
+//
+// Preconditions: Caller must hold 's.mu'.
func (s *Set) destroy() {
// Notify all waiters. They will fail on the next attempt to execute
// operations and return error.
diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go
index 0cd9e2533..973d708a3 100644
--- a/pkg/sentry/kernel/sessions.go
+++ b/pkg/sentry/kernel/sessions.go
@@ -16,7 +16,6 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -233,7 +232,7 @@ func (pg *ProcessGroup) Session() *Session {
// SendSignal sends a signal to all processes inside the process group. It is
// analogous to kernel/signal.c:kill_pgrp.
-func (pg *ProcessGroup) SendSignal(info *arch.SignalInfo) error {
+func (pg *ProcessGroup) SendSignal(info *linux.SignalInfo) error {
tasks := pg.originator.TaskSet()
tasks.mu.RLock()
defer tasks.mu.RUnlock()
@@ -370,6 +369,11 @@ func (tg *ThreadGroup) CreateProcessGroup() error {
// Get the ID for this thread in the current namespace.
id := tg.pidns.tgids[tg]
+ // A zero ID means the thread group has already exited.
+ if id == 0 {
+ return syserror.ESRCH
+ }
+
// Per above, check for a Session leader or existing group.
for s := tg.pidns.owner.sessions.Front(); s != nil; s = s.Next() {
if s.leader.pidns != tg.pidns {
diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD
index 1c3c0794f..5b69333fe 100644
--- a/pkg/sentry/kernel/shm/BUILD
+++ b/pkg/sentry/kernel/shm/BUILD
@@ -28,6 +28,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/log",
"//pkg/refs",
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go
index a73f1bdca..7a6e91004 100644
--- a/pkg/sentry/kernel/shm/shm.go
+++ b/pkg/sentry/kernel/shm/shm.go
@@ -38,6 +38,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -145,7 +146,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size ui
//
// Note that 'private' always implies the creation of a new segment
// whether IPC_CREAT is specified or not.
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
r.mu.Lock()
@@ -175,7 +176,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size ui
if size > shm.size {
// "A segment for the given key exists, but size is greater than
// the size of that segment." - man shmget(2)
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
if create && exclusive {
@@ -200,7 +201,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size ui
if val, ok := hostarch.Addr(size).RoundUp(); ok {
sizeAligned = uint64(val)
} else {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
if numPages := sizeAligned / hostarch.PageSize; r.totalPages+numPages > linux.SHMALL {
@@ -652,7 +653,7 @@ func (s *Shm) Set(ctx context.Context, ds *linux.ShmidDS) error {
uid := creds.UserNamespace.MapToKUID(auth.UID(ds.ShmPerm.UID))
gid := creds.UserNamespace.MapToKGID(auth.GID(ds.ShmPerm.GID))
if !uid.Ok() || !gid.Ok() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// User may only modify the lower 9 bits of the mode. All the other bits are
diff --git a/pkg/sentry/kernel/signal.go b/pkg/sentry/kernel/signal.go
index 2488ae7d5..e08474d25 100644
--- a/pkg/sentry/kernel/signal.go
+++ b/pkg/sentry/kernel/signal.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
)
@@ -36,7 +35,7 @@ const SignalPanic = linux.SIGUSR2
// context is used only for debugging to differentiate these cases.
//
// Preconditions: Kernel must have an init process.
-func (k *Kernel) sendExternalSignal(info *arch.SignalInfo, context string) {
+func (k *Kernel) sendExternalSignal(info *linux.SignalInfo, context string) {
switch linux.Signal(info.Signo) {
case linux.SIGURG:
// Sent by the Go 1.14+ runtime for asynchronous goroutine preemption.
@@ -60,18 +59,18 @@ func (k *Kernel) sendExternalSignal(info *arch.SignalInfo, context string) {
}
// SignalInfoPriv returns a SignalInfo equivalent to Linux's SEND_SIG_PRIV.
-func SignalInfoPriv(sig linux.Signal) *arch.SignalInfo {
- return &arch.SignalInfo{
+func SignalInfoPriv(sig linux.Signal) *linux.SignalInfo {
+ return &linux.SignalInfo{
Signo: int32(sig),
- Code: arch.SignalInfoKernel,
+ Code: linux.SI_KERNEL,
}
}
// SignalInfoNoInfo returns a SignalInfo equivalent to Linux's SEND_SIG_NOINFO.
-func SignalInfoNoInfo(sig linux.Signal, sender, receiver *Task) *arch.SignalInfo {
- info := &arch.SignalInfo{
+func SignalInfoNoInfo(sig linux.Signal, sender, receiver *Task) *linux.SignalInfo {
+ info := &linux.SignalInfo{
Signo: int32(sig),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
}
info.SetPID(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg)))
info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
diff --git a/pkg/sentry/kernel/signal_handlers.go b/pkg/sentry/kernel/signal_handlers.go
index 768fda220..147cc41bb 100644
--- a/pkg/sentry/kernel/signal_handlers.go
+++ b/pkg/sentry/kernel/signal_handlers.go
@@ -16,7 +16,6 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -30,14 +29,14 @@ type SignalHandlers struct {
mu sync.Mutex `state:"nosave"`
// actions is the action to be taken upon receiving each signal.
- actions map[linux.Signal]arch.SignalAct
+ actions map[linux.Signal]linux.SigAction
}
// NewSignalHandlers returns a new SignalHandlers specifying all default
// actions.
func NewSignalHandlers() *SignalHandlers {
return &SignalHandlers{
- actions: make(map[linux.Signal]arch.SignalAct),
+ actions: make(map[linux.Signal]linux.SigAction),
}
}
@@ -59,9 +58,9 @@ func (sh *SignalHandlers) CopyForExec() *SignalHandlers {
sh.mu.Lock()
defer sh.mu.Unlock()
for sig, act := range sh.actions {
- if act.Handler == arch.SignalActIgnore {
- sh2.actions[sig] = arch.SignalAct{
- Handler: arch.SignalActIgnore,
+ if act.Handler == linux.SIG_IGN {
+ sh2.actions[sig] = linux.SigAction{
+ Handler: linux.SIG_IGN,
}
}
}
@@ -73,15 +72,15 @@ func (sh *SignalHandlers) IsIgnored(sig linux.Signal) bool {
sh.mu.Lock()
defer sh.mu.Unlock()
sa, ok := sh.actions[sig]
- return ok && sa.Handler == arch.SignalActIgnore
+ return ok && sa.Handler == linux.SIG_IGN
}
// dequeueActionLocked returns the SignalAct that should be used to handle sig.
//
// Preconditions: sh.mu must be locked.
-func (sh *SignalHandlers) dequeueAction(sig linux.Signal) arch.SignalAct {
+func (sh *SignalHandlers) dequeueAction(sig linux.Signal) linux.SigAction {
act := sh.actions[sig]
- if act.IsResetHandler() {
+ if act.Flags&linux.SA_RESETHAND != 0 {
delete(sh.actions, sig)
}
return act
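Illustrative sketch (not part of the change): with the handler table now holding linux.SigAction, the two behaviors above read directly off the Linux ABI: CopyForExec keeps only actions whose handler is SIG_IGN (execve resets every caught signal to its default), and dequeueAction drops an action whose SA_RESETHAND flag is set after its first delivery. A condensed sketch of both, with the ABI values assumed to be SIG_DFL=0, SIG_IGN=1, SA_RESETHAND=0x80000000:

package main

import "fmt"

// Assumed Linux ABI values.
const (
    sigDFL      = 0
    sigIGN      = 1
    saResetHand = 0x80000000
)

type sigAction struct {
    Handler uint64
    Flags   uint64
}

type handlers map[int]sigAction

// copyForExec keeps only ignored signals; execve(2) resets every handled
// signal back to its default action.
func (h handlers) copyForExec() handlers {
    h2 := make(handlers)
    for sig, act := range h {
        if act.Handler == sigIGN {
            h2[sig] = sigAction{Handler: sigIGN}
        }
    }
    return h2
}

// dequeueAction returns the action for sig and, if SA_RESETHAND is set,
// forgets it so the next delivery takes the default action.
func (h handlers) dequeueAction(sig int) sigAction {
    act := h[sig]
    if act.Flags&saResetHand != 0 {
        delete(h, sig)
    }
    return act
}

func main() {
    h := handlers{
        1: {Handler: sigIGN},                         // SIGHUP ignored
        2: {Handler: 0xdeadbeef, Flags: saResetHand}, // SIGINT, one-shot handler
    }
    fmt.Println(len(h.copyForExec()))   // 1: only the ignored SIGHUP survives exec
    h.dequeueAction(2)
    fmt.Println(h[2].Handler == sigDFL) // true: reset after first delivery
}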
diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD
index 76d472292..1110ecca5 100644
--- a/pkg/sentry/kernel/signalfd/BUILD
+++ b/pkg/sentry/kernel/signalfd/BUILD
@@ -9,6 +9,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/sentry/fs",
"//pkg/sentry/fs/anon",
"//pkg/sentry/fs/fsutil",
diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go
index f58ec4194..47958e2d4 100644
--- a/pkg/sentry/kernel/signalfd/signalfd.go
+++ b/pkg/sentry/kernel/signalfd/signalfd.go
@@ -18,6 +18,7 @@ package signalfd
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -64,7 +65,7 @@ func New(ctx context.Context, mask linux.SignalSet) (*fs.File, error) {
t := kernel.TaskFromContext(ctx)
if t == nil {
// No task context? Not valid.
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
// name matches fs/signalfd.c:signalfd4.
dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[signalfd]")
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index be1371855..d211e4d82 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -21,8 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -33,7 +33,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -151,7 +150,7 @@ type Task struct {
// which the SA_ONSTACK flag is set.
//
// signalStack is exclusive to the task goroutine.
- signalStack arch.SignalStack
+ signalStack linux.SignalStack
// signalQueue is a set of registered waiters for signal-related events.
//
@@ -395,7 +394,7 @@ type Task struct {
// ptraceSiginfo is analogous to Linux's task_struct::last_siginfo.
//
// ptraceSiginfo is protected by the TaskSet mutex.
- ptraceSiginfo *arch.SignalInfo
+ ptraceSiginfo *linux.SignalInfo
// ptraceEventMsg is the value set by PTRACE_EVENT stops and returned to
// the tracer by ptrace(PTRACE_GETEVENTMSG).
@@ -847,21 +846,19 @@ func (t *Task) OOMScoreAdj() int32 {
// value should be between -1000 and 1000 inclusive.
func (t *Task) SetOOMScoreAdj(adj int32) error {
if adj > 1000 || adj < -1000 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
atomic.StoreInt32(&t.tg.oomScoreAdj, adj)
return nil
}
-// UID returns t's uid.
-// TODO(gvisor.dev/issue/170): This method is not namespaced yet.
-func (t *Task) UID() uint32 {
+// KUID returns t's kuid.
+func (t *Task) KUID() uint32 {
return uint32(t.Credentials().EffectiveKUID)
}
-// GID returns t's gid.
-// TODO(gvisor.dev/issue/170): This method is not namespaced yet.
-func (t *Task) GID() uint32 {
+// KGID returns t's kgid.
+func (t *Task) KGID() uint32 {
return uint32(t.Credentials().EffectiveKGID)
}
diff --git a/pkg/sentry/kernel/task_acct.go b/pkg/sentry/kernel/task_acct.go
index e574997f7..dd364ae50 100644
--- a/pkg/sentry/kernel/task_acct.go
+++ b/pkg/sentry/kernel/task_acct.go
@@ -18,10 +18,10 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/syserror"
)
// Getitimer implements getitimer(2).
@@ -44,7 +44,7 @@ func (t *Task) Getitimer(id int32) (linux.ItimerVal, error) {
s, _ = t.tg.itimerProfSetting.At(tm)
t.tg.signalHandlers.mu.Unlock()
default:
- return linux.ItimerVal{}, syserror.EINVAL
+ return linux.ItimerVal{}, linuxerr.EINVAL
}
val, iv := ktime.SpecFromSetting(tm, s)
return linux.ItimerVal{
@@ -105,7 +105,7 @@ func (t *Task) Setitimer(id int32, newitv linux.ItimerVal) (linux.ItimerVal, err
return linux.ItimerVal{}, err
}
default:
- return linux.ItimerVal{}, syserror.EINVAL
+ return linux.ItimerVal{}, linuxerr.EINVAL
}
oldval, oldiv := ktime.SpecFromSetting(tm, olds)
return linux.ItimerVal{
diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go
index ecbe8f920..07533d982 100644
--- a/pkg/sentry/kernel/task_block.go
+++ b/pkg/sentry/kernel/task_block.go
@@ -19,6 +19,7 @@ import (
"runtime/trace"
"time"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -45,7 +46,7 @@ func (t *Task) BlockWithTimeout(C chan struct{}, haveTimeout bool, timeout time.
err := t.BlockWithDeadline(C, true, deadline)
// Timeout, explicitly return a remaining duration of 0.
- if err == syserror.ETIMEDOUT {
+ if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
return 0, err
}
diff --git a/pkg/sentry/kernel/task_cgroup.go b/pkg/sentry/kernel/task_cgroup.go
index 25d2504fa..7c138e80f 100644
--- a/pkg/sentry/kernel/task_cgroup.go
+++ b/pkg/sentry/kernel/task_cgroup.go
@@ -85,6 +85,14 @@ func (t *Task) enterCgroupLocked(c Cgroup) {
c.Enter(t)
}
+// +checklocks:t.mu
+func (t *Task) enterCgroupIfNotYetLocked(c Cgroup) {
+ if _, ok := t.cgroups[c]; ok {
+ return
+ }
+ t.enterCgroupLocked(c)
+}
+
// LeaveCgroups removes t out from all its cgroups.
func (t *Task) LeaveCgroups() {
t.mu.Lock()
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index 405771f3f..76fb0e2cb 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
"gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/syserror"
@@ -142,25 +143,25 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
// address, any set of signal handlers must refer to the same address
// space.
if !opts.NewSignalHandlers && opts.NewAddressSpace {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// In order for the behavior of thread-group-directed signals to be sane,
// all tasks in a thread group must share signal handlers.
if !opts.NewThreadGroup && opts.NewSignalHandlers {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// All tasks in a thread group must be in the same PID namespace.
if !opts.NewThreadGroup && (opts.NewPIDNamespace || t.childPIDNamespace != nil) {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// The two different ways of specifying a new PID namespace are
// incompatible.
if opts.NewPIDNamespace && t.childPIDNamespace != nil {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Thread groups and FS contexts cannot span user namespaces.
if opts.NewUserNamespace && (!opts.NewThreadGroup || !opts.NewFSContext) {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Pull task registers and FPU state, a cloned task will inherit the
@@ -463,14 +464,14 @@ func (t *Task) Unshare(opts *SharingOptions) error {
// sense that clone(2) allows a task to share signal handlers and address
// spaces with tasks in other thread groups.
if opts.NewAddressSpace || opts.NewSignalHandlers {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
creds := t.Credentials()
if opts.NewThreadGroup {
t.tg.signalHandlers.mu.Lock()
if t.tg.tasksCount != 1 {
t.tg.signalHandlers.mu.Unlock()
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
t.tg.signalHandlers.mu.Unlock()
// This isn't racy because we're the only living task, and therefore
diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go
index 70b0699dc..c82d9e82b 100644
--- a/pkg/sentry/kernel/task_context.go
+++ b/pkg/sentry/kernel/task_context.go
@@ -17,6 +17,7 @@ package kernel
import (
"time"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -113,6 +114,10 @@ func (t *Task) contextValue(key interface{}, isTaskGoroutine bool) interface{} {
return t.k.RealtimeClock()
case limits.CtxLimits:
return t.tg.limits
+ case linux.CtxSignalNoInfoFunc:
+ return func(sig linux.Signal) error {
+ return t.SendSignal(SignalInfoNoInfo(sig, t, t))
+ }
case pgalloc.CtxMemoryFile:
return t.k.mf
case pgalloc.CtxMemoryFileProvider:
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index d9897e802..cf8571262 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -66,7 +66,6 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -181,7 +180,7 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
t.tg.signalHandlers = t.tg.signalHandlers.CopyForExec()
t.endStopCond.L = &t.tg.signalHandlers.mu
// "Any alternate signal stack is not preserved (sigaltstack(2))." - execve(2)
- t.signalStack = arch.SignalStack{Flags: arch.SignalStackFlagDisable}
+ t.signalStack = linux.SignalStack{Flags: linux.SS_DISABLE}
// "The termination signal is reset to SIGCHLD (see clone(2))."
t.tg.terminationSignal = linux.SIGCHLD
// execed indicates that the process can no longer join a process group
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index b1af1a7ef..d115b8783 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -28,9 +28,9 @@ import (
"errors"
"fmt"
"strconv"
+ "strings"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
@@ -50,6 +50,23 @@ type ExitStatus struct {
Signo int
}
+func (es ExitStatus) String() string {
+ var b strings.Builder
+ if code := es.Code; code != 0 {
+ if b.Len() != 0 {
+ b.WriteByte(' ')
+ }
+ _, _ = fmt.Fprintf(&b, "Code=%d", code)
+ }
+ if signal := es.Signo; signal != 0 {
+ if b.Len() != 0 {
+ b.WriteByte(' ')
+ }
+ _, _ = fmt.Fprintf(&b, "Signal=%d", signal)
+ }
+ return b.String()
+}
+
// Signaled returns true if the ExitStatus indicates that the exiting task or
// thread group was killed by a signal.
func (es ExitStatus) Signaled() bool {
@@ -122,12 +139,12 @@ func (t *Task) killLocked() {
if t.stop != nil && t.stop.Killable() {
t.endInternalStopLocked()
}
- t.pendingSignals.enqueue(&arch.SignalInfo{
+ t.pendingSignals.enqueue(&linux.SignalInfo{
Signo: int32(linux.SIGKILL),
// Linux just sets SIGKILL in the pending signal bitmask without
// enqueueing an actual siginfo, such that
// kernel/signal.c:collect_signal() initializes si_code to SI_USER.
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
}, nil)
t.interrupt()
}
@@ -332,7 +349,7 @@ func (t *Task) exitThreadGroup() bool {
// signalStop must be called with t's signal mutex unlocked.
t.tg.signalHandlers.mu.Unlock()
if notifyParent && t.tg.leader.parent != nil {
- t.tg.leader.parent.signalStop(t, arch.CLD_STOPPED, int32(sig))
+ t.tg.leader.parent.signalStop(t, linux.CLD_STOPPED, int32(sig))
t.tg.leader.parent.tg.eventQueue.Notify(EventChildGroupStop)
}
return last
@@ -353,7 +370,7 @@ func (t *Task) exitChildren() {
continue
}
other.signalHandlers.mu.Lock()
- other.leader.sendSignalLocked(&arch.SignalInfo{
+ other.leader.sendSignalLocked(&linux.SignalInfo{
Signo: int32(linux.SIGKILL),
}, true /* group */)
other.signalHandlers.mu.Unlock()
@@ -368,9 +385,9 @@ func (t *Task) exitChildren() {
// wait for a parent to reap them.)
for c := range t.children {
if sig := c.ParentDeathSignal(); sig != 0 {
- siginfo := &arch.SignalInfo{
+ siginfo := &linux.SignalInfo{
Signo: int32(sig),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
}
siginfo.SetPID(int32(c.tg.pidns.tids[t]))
siginfo.SetUID(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow()))
@@ -652,10 +669,10 @@ func (t *Task) exitNotifyLocked(fromPtraceDetach bool) {
t.parent.tg.signalHandlers.mu.Lock()
if t.tg.terminationSignal == linux.SIGCHLD || fromPtraceDetach {
if act, ok := t.parent.tg.signalHandlers.actions[linux.SIGCHLD]; ok {
- if act.Handler == arch.SignalActIgnore {
+ if act.Handler == linux.SIG_IGN {
t.exitParentAcked = true
signalParent = false
- } else if act.Flags&arch.SignalFlagNoCldWait != 0 {
+ } else if act.Flags&linux.SA_NOCLDWAIT != 0 {
t.exitParentAcked = true
}
}
@@ -705,17 +722,17 @@ func (t *Task) exitNotifyLocked(fromPtraceDetach bool) {
}
// Preconditions: The TaskSet mutex must be locked.
-func (t *Task) exitNotificationSignal(sig linux.Signal, receiver *Task) *arch.SignalInfo {
- info := &arch.SignalInfo{
+func (t *Task) exitNotificationSignal(sig linux.Signal, receiver *Task) *linux.SignalInfo {
+ info := &linux.SignalInfo{
Signo: int32(sig),
}
info.SetPID(int32(receiver.tg.pidns.tids[t]))
info.SetUID(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
if t.exitStatus.Signaled() {
- info.Code = arch.CLD_KILLED
+ info.Code = linux.CLD_KILLED
info.SetStatus(int32(t.exitStatus.Signo))
} else {
- info.Code = arch.CLD_EXITED
+ info.Code = linux.CLD_EXITED
info.SetStatus(int32(t.exitStatus.Code))
}
// TODO(b/72102453): Set utime, stime.
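Illustrative sketch (not part of the change): ExitStatus gains a String method above for readable logging of exit codes and fatal signals. A self-contained copy of the behavior it produces (illustration only; the real type lives in this package):

package main

import (
    "fmt"
    "strings"
)

// ExitStatus mirrors the fields used by the String method added above.
type ExitStatus struct {
    Code  int // exit code, if the task called exit(2)
    Signo int // signal number, if the task was killed by a signal
}

func (es ExitStatus) String() string {
    var b strings.Builder
    if es.Code != 0 {
        fmt.Fprintf(&b, "Code=%d", es.Code)
    }
    if es.Signo != 0 {
        if b.Len() != 0 {
            b.WriteByte(' ')
        }
        fmt.Fprintf(&b, "Signal=%d", es.Signo)
    }
    return b.String()
}

func main() {
    fmt.Printf("%q\n", ExitStatus{Code: 1})           // "Code=1"
    fmt.Printf("%q\n", ExitStatus{Signo: 9})          // "Signal=9"
    fmt.Printf("%q\n", ExitStatus{Code: 1, Signo: 9}) // "Code=1 Signal=9"
}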
diff --git a/pkg/sentry/kernel/task_identity.go b/pkg/sentry/kernel/task_identity.go
index 0325967e4..29f154ebd 100644
--- a/pkg/sentry/kernel/task_identity.go
+++ b/pkg/sentry/kernel/task_identity.go
@@ -16,6 +16,7 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/syserror"
@@ -47,7 +48,7 @@ func (t *Task) HasCapability(cp linux.Capability) bool {
func (t *Task) SetUID(uid auth.UID) error {
// setuid considers -1 to be invalid.
if !uid.Ok() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
t.mu.Lock()
@@ -56,7 +57,7 @@ func (t *Task) SetUID(uid auth.UID) error {
creds := t.Credentials()
kuid := creds.UserNamespace.MapToKUID(uid)
if !kuid.Ok() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// "setuid() sets the effective user ID of the calling process. If the
// effective UID of the caller is root (more precisely: if the caller has
@@ -87,14 +88,14 @@ func (t *Task) SetREUID(r, e auth.UID) error {
if r.Ok() {
newR = creds.UserNamespace.MapToKUID(r)
if !newR.Ok() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
}
newE := creds.EffectiveKUID
if e.Ok() {
newE = creds.UserNamespace.MapToKUID(e)
if !newE.Ok() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
}
if !creds.HasCapability(linux.CAP_SETUID) {
@@ -223,7 +224,7 @@ func (t *Task) setKUIDsUncheckedLocked(newR, newE, newS auth.KUID) {
// SetGID implements the semantics of setgid(2).
func (t *Task) SetGID(gid auth.GID) error {
if !gid.Ok() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
t.mu.Lock()
@@ -232,7 +233,7 @@ func (t *Task) SetGID(gid auth.GID) error {
creds := t.Credentials()
kgid := creds.UserNamespace.MapToKGID(gid)
if !kgid.Ok() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if creds.HasCapability(linux.CAP_SETGID) {
t.setKGIDsUncheckedLocked(kgid, kgid, kgid)
@@ -255,14 +256,14 @@ func (t *Task) SetREGID(r, e auth.GID) error {
if r.Ok() {
newR = creds.UserNamespace.MapToKGID(r)
if !newR.Ok() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
}
newE := creds.EffectiveKGID
if e.Ok() {
newE = creds.UserNamespace.MapToKGID(e)
if !newE.Ok() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
}
if !creds.HasCapability(linux.CAP_SETGID) {
@@ -349,7 +350,7 @@ func (t *Task) SetExtraGIDs(gids []auth.GID) error {
for i, gid := range gids {
kgid := creds.UserNamespace.MapToKGID(gid)
if !kgid.Ok() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
kgids[i] = kgid
}
diff --git a/pkg/sentry/kernel/task_image.go b/pkg/sentry/kernel/task_image.go
index bd5543d4e..c132c27ef 100644
--- a/pkg/sentry/kernel/task_image.go
+++ b/pkg/sentry/kernel/task_image.go
@@ -17,7 +17,7 @@ package kernel
import (
"fmt"
- "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/abi/linux/errno"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -27,7 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
)
-var errNoSyscalls = syserr.New("no syscall table found", linux.ENOEXEC)
+var errNoSyscalls = syserr.New("no syscall table found", errno.ENOEXEC)
// Auxmap contains miscellaneous data for the task.
type Auxmap map[string]interface{}
diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go
index 9ba5f8d78..9d9fa76a6 100644
--- a/pkg/sentry/kernel/task_sched.go
+++ b/pkg/sentry/kernel/task_sched.go
@@ -23,12 +23,12 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/hostcpu"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/syserror"
)
// TaskGoroutineState is a coarse representation of the current execution
@@ -536,7 +536,7 @@ func (tg *ThreadGroup) updateCPUTimersEnabledLocked() {
// appropriate for /proc/[pid]/status.
func (t *Task) StateStatus() string {
switch s := t.TaskGoroutineSchedInfo().State; s {
- case TaskGoroutineNonexistent:
+ case TaskGoroutineNonexistent, TaskGoroutineRunningSys:
t.tg.pidns.owner.mu.RLock()
defer t.tg.pidns.owner.mu.RUnlock()
switch t.exitState {
@@ -546,16 +546,16 @@ func (t *Task) StateStatus() string {
return "X (dead)"
default:
// The task goroutine can't exit before passing through
- // runExitNotify, so this indicates that the task has been created,
- // but the task goroutine hasn't yet started. The Linux equivalent
- // is struct task_struct::state == TASK_NEW
+ // runExitNotify, so if s == TaskGoroutineNonexistent, the task has
+ // been created but the task goroutine hasn't yet started. The
+ // Linux equivalent is struct task_struct::state == TASK_NEW
// (kernel/fork.c:copy_process() =>
// kernel/sched/core.c:sched_fork()), but the TASK_NEW bit is
// masked out by TASK_REPORT for /proc/[pid]/status, leaving only
// TASK_RUNNING.
return "R (running)"
}
- case TaskGoroutineRunningSys, TaskGoroutineRunningApp:
+ case TaskGoroutineRunningApp:
return "R (running)"
case TaskGoroutineBlockedInterruptible:
return "S (sleeping)"
@@ -601,7 +601,7 @@ func (t *Task) SetCPUMask(mask sched.CPUSet) error {
// Ensure that at least 1 CPU is still allowed.
if mask.NumCPUs() == 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if t.k.useHostCores {
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
index c2b9fc08f..f54c774cb 100644
--- a/pkg/sentry/kernel/task_signals.go
+++ b/pkg/sentry/kernel/task_signals.go
@@ -22,6 +22,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/eventchannel"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -86,7 +87,7 @@ var defaultActions = map[linux.Signal]SignalAction{
}
// computeAction figures out what to do given a signal number
-// and an arch.SignalAct. SIGSTOP always results in a SignalActionStop,
+// and a linux.SigAction. SIGSTOP always results in a SignalActionStop,
// and SIGKILL always results in a SignalActionTerm.
// Signal 0 is always ignored as many programs use it for various internal functions
// and don't expect it to do anything.
@@ -97,7 +98,7 @@ var defaultActions = map[linux.Signal]SignalAction{
// 0, the default action is taken;
// 1, the signal is ignored;
// anything else, the function returns SignalActionHandler.
-func computeAction(sig linux.Signal, act arch.SignalAct) SignalAction {
+func computeAction(sig linux.Signal, act linux.SigAction) SignalAction {
switch sig {
case linux.SIGSTOP:
return SignalActionStop
@@ -108,9 +109,9 @@ func computeAction(sig linux.Signal, act arch.SignalAct) SignalAction {
}
switch act.Handler {
- case arch.SignalActDefault:
+ case linux.SIG_DFL:
return defaultActions[sig]
- case arch.SignalActIgnore:
+ case linux.SIG_IGN:
return SignalActionIgnore
default:
return SignalActionHandler
@@ -127,7 +128,7 @@ var StopSignals = linux.MakeSignalSet(linux.SIGSTOP, linux.SIGTSTP, linux.SIGTTI
// If there are no pending unmasked signals, dequeueSignalLocked returns nil.
//
// Preconditions: t.tg.signalHandlers.mu must be locked.
-func (t *Task) dequeueSignalLocked(mask linux.SignalSet) *arch.SignalInfo {
+func (t *Task) dequeueSignalLocked(mask linux.SignalSet) *linux.SignalInfo {
if info := t.pendingSignals.dequeue(mask); info != nil {
return info
}
@@ -155,7 +156,7 @@ func (t *Task) PendingSignals() linux.SignalSet {
}
// deliverSignal delivers the given signal and returns the following run state.
-func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunState {
+func (t *Task) deliverSignal(info *linux.SignalInfo, act linux.SigAction) taskRunState {
sigact := computeAction(linux.Signal(info.Signo), act)
if t.haveSyscallReturn {
@@ -172,7 +173,7 @@ func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunS
fallthrough
case sre == syserror.ERESTART_RESTARTBLOCK:
fallthrough
- case (sre == syserror.ERESTARTSYS && !act.IsRestart()):
+ case (sre == syserror.ERESTARTSYS && act.Flags&linux.SA_RESTART == 0):
t.Debugf("Not restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo)
t.Arch().SetReturn(uintptr(-ExtractErrno(syserror.EINTR, -1)))
default:
@@ -236,7 +237,7 @@ func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunS
// deliverSignalToHandler changes the task's userspace state to enter the given
// user-configured handler for the given signal.
-func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct) error {
+func (t *Task) deliverSignalToHandler(info *linux.SignalInfo, act linux.SigAction) error {
// Signal delivery to an application handler interrupts restartable
// sequences.
t.rseqInterrupt()
@@ -248,8 +249,8 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct)
// N.B. This is a *copy* of the alternate stack that the user's signal
// handler expects to see in its ucontext (even if it's not in use).
alt := t.signalStack
- if act.IsOnStack() && alt.IsEnabled() {
- alt.SetOnStack()
+ if act.Flags&linux.SA_ONSTACK != 0 && alt.IsEnabled() {
+ alt.Flags |= linux.SS_ONSTACK
if !alt.Contains(sp) {
sp = hostarch.Addr(alt.Top())
}
@@ -289,7 +290,7 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct)
// Add our signal mask.
newMask := t.signalMask | act.Mask
- if !act.IsNoDefer() {
+ if act.Flags&linux.SA_NODEFER == 0 {
newMask |= linux.SignalSetOf(linux.Signal(info.Signo))
}
t.SetSignalMask(newMask)
@@ -326,7 +327,7 @@ func (t *Task) SignalReturn(rt bool) (*SyscallControl, error) {
// Preconditions:
// * The caller must be running on the task goroutine.
// * t.exitState < TaskExitZombie.
-func (t *Task) Sigtimedwait(set linux.SignalSet, timeout time.Duration) (*arch.SignalInfo, error) {
+func (t *Task) Sigtimedwait(set linux.SignalSet, timeout time.Duration) (*linux.SignalInfo, error) {
// set is the set of signals we're interested in; invert it to get the set
// of signals to block.
mask := ^(set &^ UnblockableSignals)
@@ -370,10 +371,10 @@ func (t *Task) Sigtimedwait(set linux.SignalSet, timeout time.Duration) (*arch.S
// The following errors may be returned:
//
// syserror.ESRCH - The task has exited.
-// syserror.EINVAL - The signal is not valid.
+// linuxerr.EINVAL - The signal is not valid.
// syserror.EAGAIN - The signal is realtime, and cannot be queued.
//
-func (t *Task) SendSignal(info *arch.SignalInfo) error {
+func (t *Task) SendSignal(info *linux.SignalInfo) error {
t.tg.pidns.owner.mu.RLock()
defer t.tg.pidns.owner.mu.RUnlock()
t.tg.signalHandlers.mu.Lock()
@@ -382,7 +383,7 @@ func (t *Task) SendSignal(info *arch.SignalInfo) error {
}
// SendGroupSignal sends the given signal to t's thread group.
-func (t *Task) SendGroupSignal(info *arch.SignalInfo) error {
+func (t *Task) SendGroupSignal(info *linux.SignalInfo) error {
t.tg.pidns.owner.mu.RLock()
defer t.tg.pidns.owner.mu.RUnlock()
t.tg.signalHandlers.mu.Lock()
@@ -392,7 +393,7 @@ func (t *Task) SendGroupSignal(info *arch.SignalInfo) error {
// SendSignal sends the given signal to tg, using tg's leader to determine if
// the signal is blocked.
-func (tg *ThreadGroup) SendSignal(info *arch.SignalInfo) error {
+func (tg *ThreadGroup) SendSignal(info *linux.SignalInfo) error {
tg.pidns.owner.mu.RLock()
defer tg.pidns.owner.mu.RUnlock()
tg.signalHandlers.mu.Lock()
@@ -400,11 +401,11 @@ func (tg *ThreadGroup) SendSignal(info *arch.SignalInfo) error {
return tg.leader.sendSignalLocked(info, true /* group */)
}
-func (t *Task) sendSignalLocked(info *arch.SignalInfo, group bool) error {
+func (t *Task) sendSignalLocked(info *linux.SignalInfo, group bool) error {
return t.sendSignalTimerLocked(info, group, nil)
}
-func (t *Task) sendSignalTimerLocked(info *arch.SignalInfo, group bool, timer *IntervalTimer) error {
+func (t *Task) sendSignalTimerLocked(info *linux.SignalInfo, group bool, timer *IntervalTimer) error {
if t.exitState == TaskExitDead {
return syserror.ESRCH
}
@@ -413,7 +414,7 @@ func (t *Task) sendSignalTimerLocked(info *arch.SignalInfo, group bool, timer *I
return nil
}
if !sig.IsValid() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// Signal side effects apply even if the signal is ultimately discarded.
@@ -572,9 +573,9 @@ func (t *Task) forceSignal(sig linux.Signal, unconditional bool) {
func (t *Task) forceSignalLocked(sig linux.Signal, unconditional bool) {
blocked := linux.SignalSetOf(sig)&t.signalMask != 0
act := t.tg.signalHandlers.actions[sig]
- ignored := act.Handler == arch.SignalActIgnore
+ ignored := act.Handler == linux.SIG_IGN
if blocked || ignored || unconditional {
- act.Handler = arch.SignalActDefault
+ act.Handler = linux.SIG_DFL
t.tg.signalHandlers.actions[sig] = act
if blocked {
t.setSignalMaskLocked(t.signalMask &^ linux.SignalSetOf(sig))
@@ -641,17 +642,17 @@ func (t *Task) SetSavedSignalMask(mask linux.SignalSet) {
}
// SignalStack returns the task-private signal stack.
-func (t *Task) SignalStack() arch.SignalStack {
+func (t *Task) SignalStack() linux.SignalStack {
t.p.PullFullState(t.MemoryManager().AddressSpace(), t.Arch())
alt := t.signalStack
if t.onSignalStack(alt) {
- alt.Flags |= arch.SignalStackFlagOnStack
+ alt.Flags |= linux.SS_ONSTACK
}
return alt
}
// onSignalStack returns true if the task is executing on the given signal stack.
-func (t *Task) onSignalStack(alt arch.SignalStack) bool {
+func (t *Task) onSignalStack(alt linux.SignalStack) bool {
sp := hostarch.Addr(t.Arch().Stack())
return alt.Contains(sp)
}
@@ -661,30 +662,30 @@ func (t *Task) onSignalStack(alt arch.SignalStack) bool {
// This value may not be changed if the task is currently executing on the
// signal stack, i.e. if t.onSignalStack returns true. In this case, this
// function will return false. Otherwise, true is returned.
-func (t *Task) SetSignalStack(alt arch.SignalStack) bool {
+func (t *Task) SetSignalStack(alt linux.SignalStack) bool {
// Check that we're not executing on the stack.
if t.onSignalStack(t.signalStack) {
return false
}
- if alt.Flags&arch.SignalStackFlagDisable != 0 {
+ if alt.Flags&linux.SS_DISABLE != 0 {
// Don't record anything beyond the flags.
- t.signalStack = arch.SignalStack{
- Flags: arch.SignalStackFlagDisable,
+ t.signalStack = linux.SignalStack{
+ Flags: linux.SS_DISABLE,
}
} else {
// Mask out irrelevant parts: only disable matters.
- alt.Flags &= arch.SignalStackFlagDisable
+ alt.Flags &= linux.SS_DISABLE
t.signalStack = alt
}
return true
}
-// SetSignalAct atomically sets the thread group's signal action for signal sig
+// SetSigAction atomically sets the thread group's signal action for signal sig
// to *actptr (if actptr is not nil) and returns the old signal action.
-func (tg *ThreadGroup) SetSignalAct(sig linux.Signal, actptr *arch.SignalAct) (arch.SignalAct, error) {
+func (tg *ThreadGroup) SetSigAction(sig linux.Signal, actptr *linux.SigAction) (linux.SigAction, error) {
if !sig.IsValid() {
- return arch.SignalAct{}, syserror.EINVAL
+ return linux.SigAction{}, linuxerr.EINVAL
}
tg.pidns.owner.mu.RLock()
@@ -695,7 +696,7 @@ func (tg *ThreadGroup) SetSignalAct(sig linux.Signal, actptr *arch.SignalAct) (a
oldact := sh.actions[sig]
if actptr != nil {
if sig == linux.SIGKILL || sig == linux.SIGSTOP {
- return oldact, syserror.EINVAL
+ return oldact, linuxerr.EINVAL
}
act := *actptr
@@ -718,48 +719,6 @@ func (tg *ThreadGroup) SetSignalAct(sig linux.Signal, actptr *arch.SignalAct) (a
return oldact, nil
}
-// CopyOutSignalAct converts the given SignalAct into an architecture-specific
-// type and then copies it out to task memory.
-func (t *Task) CopyOutSignalAct(addr hostarch.Addr, s *arch.SignalAct) error {
- n := t.Arch().NewSignalAct()
- n.SerializeFrom(s)
- _, err := n.CopyOut(t, addr)
- return err
-}
-
-// CopyInSignalAct copies an architecture-specific sigaction type from task
-// memory and then converts it into a SignalAct.
-func (t *Task) CopyInSignalAct(addr hostarch.Addr) (arch.SignalAct, error) {
- n := t.Arch().NewSignalAct()
- var s arch.SignalAct
- if _, err := n.CopyIn(t, addr); err != nil {
- return s, err
- }
- n.DeserializeTo(&s)
- return s, nil
-}
-
-// CopyOutSignalStack converts the given SignalStack into an
-// architecture-specific type and then copies it out to task memory.
-func (t *Task) CopyOutSignalStack(addr hostarch.Addr, s *arch.SignalStack) error {
- n := t.Arch().NewSignalStack()
- n.SerializeFrom(s)
- _, err := n.CopyOut(t, addr)
- return err
-}
-
-// CopyInSignalStack copies an architecture-specific stack_t from task memory
-// and then converts it into a SignalStack.
-func (t *Task) CopyInSignalStack(addr hostarch.Addr) (arch.SignalStack, error) {
- n := t.Arch().NewSignalStack()
- var s arch.SignalStack
- if _, err := n.CopyIn(t, addr); err != nil {
- return s, err
- }
- n.DeserializeTo(&s)
- return s, nil
-}
-
// groupStop is a TaskStop placed on tasks that have received a stop signal
// (SIGSTOP, SIGTSTP, SIGTTIN, SIGTTOU). (The term "group-stop" originates from
// the ptrace man page.)
@@ -774,7 +733,7 @@ func (*groupStop) Killable() bool { return true }
// previously-dequeued stop signal.
//
// Preconditions: The caller must be running on the task goroutine.
-func (t *Task) initiateGroupStop(info *arch.SignalInfo) {
+func (t *Task) initiateGroupStop(info *linux.SignalInfo) {
t.tg.pidns.owner.mu.RLock()
defer t.tg.pidns.owner.mu.RUnlock()
t.tg.signalHandlers.mu.Lock()
@@ -909,8 +868,8 @@ func (t *Task) signalStop(target *Task, code int32, status int32) {
t.tg.signalHandlers.mu.Lock()
defer t.tg.signalHandlers.mu.Unlock()
act, ok := t.tg.signalHandlers.actions[linux.SIGCHLD]
- if !ok || (act.Handler != arch.SignalActIgnore && act.Flags&arch.SignalFlagNoCldStop == 0) {
- sigchld := &arch.SignalInfo{
+ if !ok || (act.Handler != linux.SIG_IGN && act.Flags&linux.SA_NOCLDSTOP == 0) {
+ sigchld := &linux.SignalInfo{
Signo: int32(linux.SIGCHLD),
Code: code,
}
@@ -955,14 +914,14 @@ func (*runInterrupt) execute(t *Task) taskRunState {
// notified its tracer accordingly. But it's consistent with
// Linux...
if intr {
- tracer.signalStop(t.tg.leader, arch.CLD_STOPPED, int32(sig))
+ tracer.signalStop(t.tg.leader, linux.CLD_STOPPED, int32(sig))
if !notifyParent {
tracer.tg.eventQueue.Notify(EventGroupContinue | EventTraceeStop | EventChildGroupStop)
} else {
tracer.tg.eventQueue.Notify(EventGroupContinue | EventTraceeStop)
}
} else {
- tracer.signalStop(t.tg.leader, arch.CLD_CONTINUED, int32(sig))
+ tracer.signalStop(t.tg.leader, linux.CLD_CONTINUED, int32(sig))
tracer.tg.eventQueue.Notify(EventGroupContinue)
}
}
@@ -974,10 +933,10 @@ func (*runInterrupt) execute(t *Task) taskRunState {
// SIGCHLD is a standard signal, so the latter would always be
// dropped. Hence sending only the former is equivalent.
if intr {
- t.tg.leader.parent.signalStop(t.tg.leader, arch.CLD_STOPPED, int32(sig))
+ t.tg.leader.parent.signalStop(t.tg.leader, linux.CLD_STOPPED, int32(sig))
t.tg.leader.parent.tg.eventQueue.Notify(EventGroupContinue | EventChildGroupStop)
} else {
- t.tg.leader.parent.signalStop(t.tg.leader, arch.CLD_CONTINUED, int32(sig))
+ t.tg.leader.parent.signalStop(t.tg.leader, linux.CLD_CONTINUED, int32(sig))
t.tg.leader.parent.tg.eventQueue.Notify(EventGroupContinue)
}
}
@@ -1018,7 +977,7 @@ func (*runInterrupt) execute(t *Task) taskRunState {
// without requiring an extra PTRACE_GETSIGINFO call." -
// "Group-stop", ptrace(2)
t.ptraceCode = int32(sig) | linux.PTRACE_EVENT_STOP<<8
- t.ptraceSiginfo = &arch.SignalInfo{
+ t.ptraceSiginfo = &linux.SignalInfo{
Signo: int32(sig),
Code: t.ptraceCode,
}
@@ -1029,7 +988,7 @@ func (*runInterrupt) execute(t *Task) taskRunState {
t.ptraceSiginfo = nil
}
if t.beginPtraceStopLocked() {
- tracer.signalStop(t, arch.CLD_STOPPED, int32(sig))
+ tracer.signalStop(t, linux.CLD_STOPPED, int32(sig))
// For consistency with Linux, if the parent and tracer are in the
// same thread group, deduplicate notification signals.
if notifyParent && tracer.tg == t.tg.leader.parent.tg {
@@ -1047,7 +1006,7 @@ func (*runInterrupt) execute(t *Task) taskRunState {
t.tg.signalHandlers.mu.Unlock()
}
if notifyParent {
- t.tg.leader.parent.signalStop(t.tg.leader, arch.CLD_STOPPED, int32(sig))
+ t.tg.leader.parent.signalStop(t.tg.leader, linux.CLD_STOPPED, int32(sig))
t.tg.leader.parent.tg.eventQueue.Notify(EventChildGroupStop)
}
t.tg.pidns.owner.mu.RUnlock()
@@ -1101,7 +1060,7 @@ func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState {
if sig != linux.Signal(info.Signo) {
info.Signo = int32(sig)
info.Errno = 0
- info.Code = arch.SignalInfoUser
+ info.Code = linux.SI_USER
// pid isn't a valid field for all signal numbers, but Linux
// doesn't care (kernel/signal.c:ptrace_signal()).
//
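
Throughout this file the arch.SignalAct/arch.SignalStack wrappers give way to the linux ABI types, so the old helper predicates become direct handler and flag checks. A minimal sketch of the correspondence, using only constants and fields already shown in the hunks above; the helper below is illustrative, not part of the change:

package kernel

import "gvisor.dev/gvisor/pkg/abi/linux"

// describeAction is a hypothetical helper showing the flag checks that
// replace the old arch.SignalAct methods in the hunks above.
func describeAction(act linux.SigAction) (restart, onStack, noDefer, ignored bool) {
	restart = act.Flags&linux.SA_RESTART != 0 // was act.IsRestart()
	onStack = act.Flags&linux.SA_ONSTACK != 0 // was act.IsOnStack()
	noDefer = act.Flags&linux.SA_NODEFER != 0 // was act.IsNoDefer()
	ignored = act.Handler == linux.SIG_IGN    // was arch.SignalActIgnore
	return restart, onStack, noDefer, ignored
}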
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index 32031cd70..41fd2d471 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -18,7 +18,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/hostarch"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
@@ -131,7 +130,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
runState: (*runApp)(nil),
interruptChan: make(chan struct{}, 1),
signalMask: cfg.SignalMask,
- signalStack: arch.SignalStack{Flags: arch.SignalStackFlagDisable},
+ signalStack: linux.SignalStack{Flags: linux.SS_DISABLE},
image: *image,
fsContext: cfg.FSContext,
fdTable: cfg.FDTable,
diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go
index 601fc0d3a..409b712d8 100644
--- a/pkg/sentry/kernel/task_syscall.go
+++ b/pkg/sentry/kernel/task_syscall.go
@@ -22,6 +22,8 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/errors"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/metric"
@@ -357,7 +359,7 @@ func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, calle
t.Arch().SetReturn(uintptr(rval))
} else {
t.Debugf("vsyscall %d, caller %x: emulated syscall returned error: %v", sysno, t.Arch().Value(caller), err)
- if err == syserror.EFAULT {
+ if linuxerr.Equals(linuxerr.EFAULT, err) {
t.forceSignal(linux.SIGSEGV, false /* unconditional */)
t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
// A return is not emulated in this case.
@@ -379,6 +381,8 @@ func ExtractErrno(err error, sysno int) int {
return 0
case unix.Errno:
return int(err)
+ case *errors.Error:
+ return int(err.Errno())
case syserror.SyscallRestartErrno:
return int(err)
case *memmap.BusError:
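
ExtractErrno now recognizes the new *errors.Error type, and comparisons elsewhere in this diff go through linuxerr.Equals rather than ==. A minimal sketch of the comparison pattern, assuming only the linuxerr API imported above; the helper is hypothetical:

package kernel

import "gvisor.dev/gvisor/pkg/errors/linuxerr"

// isFault reports whether err represents EFAULT. linuxerr.Equals is used
// instead of ==, matching the comparisons introduced elsewhere in this diff.
func isFault(err error) bool {
	return linuxerr.Equals(linuxerr.EFAULT, err)
}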
diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go
index fc6d9438a..7935d15a6 100644
--- a/pkg/sentry/kernel/task_usermem.go
+++ b/pkg/sentry/kernel/task_usermem.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/syserror"
@@ -202,7 +203,7 @@ func (t *Task) CopyInIovecs(addr hostarch.Addr, numIovecs int) (hostarch.AddrRan
base := hostarch.Addr(hostarch.ByteOrder.Uint64(b[0:8]))
length := hostarch.ByteOrder.Uint64(b[8:16])
if length > math.MaxInt64 {
- return hostarch.AddrRangeSeq{}, syserror.EINVAL
+ return hostarch.AddrRangeSeq{}, linuxerr.EINVAL
}
ar, ok := t.MemoryManager().CheckIORange(base, int64(length))
if !ok {
@@ -270,7 +271,7 @@ func (t *Task) SingleIOSequence(addr hostarch.Addr, length int, opts usermem.IOO
// Preconditions: Same as Task.CopyInIovecs.
func (t *Task) IovecsIOSequence(addr hostarch.Addr, iovcnt int, opts usermem.IOOpts) (usermem.IOSequence, error) {
if iovcnt < 0 || iovcnt > linux.UIO_MAXIOV {
- return usermem.IOSequence{}, syserror.EINVAL
+ return usermem.IOSequence{}, linuxerr.EINVAL
}
ars, err := t.CopyInIovecs(addr, iovcnt)
if err != nil {
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index b92e98fa1..8ae00c649 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -19,7 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -279,7 +279,7 @@ func (k *Kernel) NewThreadGroup(mntns *fs.MountNamespace, pidns *PIDNamespace, s
limits: limits,
mounts: mntns,
}
- tg.itimerRealTimer = ktime.NewTimer(k.monotonicClock, &itimerRealListener{tg: tg})
+ tg.itimerRealTimer = ktime.NewTimer(k.timekeeper.monotonicClock, &itimerRealListener{tg: tg})
tg.timers = make(map[linux.TimerID]*IntervalTimer)
tg.oldRSeqCritical.Store(&OldRSeqCriticalRegion{})
return tg
@@ -358,7 +358,7 @@ func (tg *ThreadGroup) SetControllingTTY(tty *TTY, steal bool, isReadable bool)
// "The calling process must be a session leader and not have a
// controlling terminal already." - tty_ioctl(4)
if tg.processGroup.session.leader != tg || tg.tty != nil {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
creds := auth.CredentialsFromContext(tg.leader)
@@ -446,10 +446,10 @@ func (tg *ThreadGroup) ReleaseControllingTTY(tty *TTY) error {
othertg.signalHandlers.mu.Lock()
othertg.tty = nil
if othertg.processGroup == tg.processGroup.session.foreground {
- if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGHUP)}, true /* group */); err != nil {
+ if err := othertg.leader.sendSignalLocked(&linux.SignalInfo{Signo: int32(linux.SIGHUP)}, true /* group */); err != nil {
lastErr = err
}
- if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGCONT)}, true /* group */); err != nil {
+ if err := othertg.leader.sendSignalLocked(&linux.SignalInfo{Signo: int32(linux.SIGCONT)}, true /* group */); err != nil {
lastErr = err
}
}
@@ -490,10 +490,10 @@ func (tg *ThreadGroup) SetForegroundProcessGroup(tty *TTY, pgid ProcessGroupID)
tg.signalHandlers.mu.Lock()
defer tg.signalHandlers.mu.Unlock()
- // TODO(b/129283598): "If tcsetpgrp() is called by a member of a
- // background process group in its session, and the calling process is
- // not blocking or ignoring SIGTTOU, a SIGTTOU signal is sent to all
- // members of this background process group."
+ // TODO(gvisor.dev/issue/6148): "If tcsetpgrp() is called by a member of a
+ // background process group in its session, and the calling process is not
+ // blocking or ignoring SIGTTOU, a SIGTTOU signal is sent to all members of
+ // this background process group."
// tty must be the controlling terminal.
if tg.tty != tty {
@@ -502,7 +502,7 @@ func (tg *ThreadGroup) SetForegroundProcessGroup(tty *TTY, pgid ProcessGroupID)
// pgid must be positive.
if pgid < 0 {
- return -1, syserror.EINVAL
+ return -1, linuxerr.EINVAL
}
// pg must not be empty. Empty process groups are removed from their
diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD
index 2817aa3ba..e293d9a0f 100644
--- a/pkg/sentry/kernel/time/BUILD
+++ b/pkg/sentry/kernel/time/BUILD
@@ -13,8 +13,8 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go
index f61a8e164..191b92811 100644
--- a/pkg/sentry/kernel/time/time.go
+++ b/pkg/sentry/kernel/time/time.go
@@ -22,8 +22,8 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -322,7 +322,7 @@ func SettingFromSpec(value time.Duration, interval time.Duration, c Clock) (Sett
// interpreted as a time relative to now.
func SettingFromSpecAt(value time.Duration, interval time.Duration, now Time) (Setting, error) {
if value < 0 {
- return Setting{}, syserror.EINVAL
+ return Setting{}, linuxerr.EINVAL
}
if value == 0 {
return Setting{Period: interval}, nil
@@ -338,7 +338,7 @@ func SettingFromSpecAt(value time.Duration, interval time.Duration, now Time) (S
// interpreted as an absolute time.
func SettingFromAbsSpec(value Time, interval time.Duration) (Setting, error) {
if value.Before(ZeroTime) {
- return Setting{}, syserror.EINVAL
+ return Setting{}, linuxerr.EINVAL
}
if value.IsZero() {
return Setting{Period: interval}, nil
@@ -458,25 +458,6 @@ func NewTimer(clock Clock, listener TimerListener) *Timer {
return t
}
-// After waits for the duration to elapse according to clock and then sends a
-// notification on the returned channel. The timer is started immediately and
-// will fire exactly once. The second return value is the start time used with
-// the duration.
-//
-// Callers must call Timer.Destroy.
-func After(clock Clock, duration time.Duration) (*Timer, Time, <-chan struct{}) {
- notifier, tchan := NewChannelNotifier()
- t := NewTimer(clock, notifier)
- now := clock.Now()
-
- t.Swap(Setting{
- Enabled: true,
- Period: 0,
- Next: now.Add(duration),
- })
- return t, now, tchan
-}
-
// init initializes Timer state that is not preserved across save/restore. If
// init has already been called, calling it again is a no-op.
//
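
The one-shot After helper is removed above; callers that still need that behavior can open-code it with the remaining API. A minimal sketch, assuming NewChannelNotifier, NewTimer, Setting, and Clock behave as in the deleted code; this function is illustrative and not part of the package:

package time

import "time"

// afterOnce mirrors the deleted After helper: start a one-shot timer on clock
// that fires once duration elapses. The caller must still call t.Destroy.
func afterOnce(clock Clock, duration time.Duration) (*Timer, Time, <-chan struct{}) {
	notifier, tchan := NewChannelNotifier()
	t := NewTimer(clock, notifier)
	now := clock.Now()
	t.Swap(Setting{
		Enabled: true,
		Period:  0,                 // one-shot: no period
		Next:    now.Add(duration), // fire at now + duration
	})
	return t, now, tchan
}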
diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go
index 7c4fefb16..6255bae7a 100644
--- a/pkg/sentry/kernel/timekeeper.go
+++ b/pkg/sentry/kernel/timekeeper.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
sentrytime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
)
// Timekeeper manages all of the kernel clocks.
@@ -39,6 +40,12 @@ type Timekeeper struct {
// It is set only once, by SetClocks.
clocks sentrytime.Clocks `state:"nosave"`
+ // realtimeClock is a ktime.Clock based on timekeeper's Realtime.
+ realtimeClock *timekeeperClock
+
+ // monotonicClock is a ktime.Clock based on timekeeper's Monotonic.
+ monotonicClock *timekeeperClock
+
// bootTime is the realtime when the system "booted". i.e., when
// SetClocks was called in the initial (not restored) run.
bootTime ktime.Time
@@ -90,10 +97,13 @@ type Timekeeper struct {
// NewTimekeeper does not take ownership of paramPage.
//
// SetClocks must be called on the returned Timekeeper before it is usable.
-func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage memmap.FileRange) (*Timekeeper, error) {
- return &Timekeeper{
+func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage memmap.FileRange) *Timekeeper {
+ t := Timekeeper{
params: NewVDSOParamPage(mfp, paramPage),
- }, nil
+ }
+ t.realtimeClock = &timekeeperClock{tk: &t, c: sentrytime.Realtime}
+ t.monotonicClock = &timekeeperClock{tk: &t, c: sentrytime.Monotonic}
+ return &t
}
// SetClocks the backing clock source.
@@ -167,6 +177,32 @@ func (t *Timekeeper) SetClocks(c sentrytime.Clocks) {
}
}
+var _ tcpip.Clock = (*Timekeeper)(nil)
+
+// Now implements tcpip.Clock.
+func (t *Timekeeper) Now() time.Time {
+ nsec, err := t.GetTime(sentrytime.Realtime)
+ if err != nil {
+ panic("timekeeper.GetTime(sentrytime.Realtime): " + err.Error())
+ }
+ return time.Unix(0, nsec)
+}
+
+// NowMonotonic implements tcpip.Clock.
+func (t *Timekeeper) NowMonotonic() tcpip.MonotonicTime {
+ nsec, err := t.GetTime(sentrytime.Monotonic)
+ if err != nil {
+ panic("timekeeper.GetTime(sentrytime.Monotonic): " + err.Error())
+ }
+ var mt tcpip.MonotonicTime
+ return mt.Add(time.Duration(nsec) * time.Nanosecond)
+}
+
+// AfterFunc implements tcpip.Clock.
+func (t *Timekeeper) AfterFunc(d time.Duration, f func()) tcpip.Timer {
+ return ktime.TcpipAfterFunc(t.realtimeClock, d, f)
+}
+
// startUpdater starts an update goroutine that keeps the clocks updated.
//
// mu must be held.
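
With the Now, NowMonotonic, and AfterFunc methods above, the Timekeeper satisfies tcpip.Clock directly, so netstack components can be handed the sentry's clock. A minimal usage sketch, assuming a Timekeeper that has already had SetClocks called; the function name is illustrative:

package kernel

import (
	"time"

	"gvisor.dev/gvisor/pkg/log"
	"gvisor.dev/gvisor/pkg/tcpip"
)

// scheduleDemo is a hypothetical illustration of using the Timekeeper as a
// tcpip.Clock now that it implements Now, NowMonotonic, and AfterFunc.
func scheduleDemo(tk *Timekeeper) tcpip.Timer {
	var clock tcpip.Clock = tk // satisfies the interface asserted above
	return clock.AfterFunc(500*time.Millisecond, func() {
		log.Infof("fired at %v (monotonic %v)", clock.Now(), clock.NowMonotonic())
	})
}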
diff --git a/pkg/sentry/kernel/timekeeper_test.go b/pkg/sentry/kernel/timekeeper_test.go
index dfc3c0719..b6039505a 100644
--- a/pkg/sentry/kernel/timekeeper_test.go
+++ b/pkg/sentry/kernel/timekeeper_test.go
@@ -17,12 +17,12 @@ package kernel
import (
"testing"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
sentrytime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/syserror"
)
// mockClocks is a sentrytime.Clocks that simply returns the times in the
@@ -45,7 +45,7 @@ func (c *mockClocks) GetTime(id sentrytime.ClockID) (int64, error) {
case sentrytime.Realtime:
return c.realtime, nil
default:
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
}
diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD
index 4c65215fa..54bfed644 100644
--- a/pkg/sentry/loader/BUILD
+++ b/pkg/sentry/loader/BUILD
@@ -17,8 +17,10 @@ go_library(
deps = [
"//pkg/abi",
"//pkg/abi/linux",
+ "//pkg/abi/linux/errno",
"//pkg/context",
"//pkg/cpuid",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/log",
"//pkg/rand",
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index 8fc3e2a79..13ab7ea23 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -517,13 +518,13 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, in
start, ok = start.AddLength(uint64(offset))
if !ok {
ctx.Infof(fmt.Sprintf("Start %#x + offset %#x overflows?", start, offset))
- return loadedELF{}, syserror.EINVAL
+ return loadedELF{}, linuxerr.EINVAL
}
end, ok = end.AddLength(uint64(offset))
if !ok {
ctx.Infof(fmt.Sprintf("End %#x + offset %#x overflows?", end, offset))
- return loadedELF{}, syserror.EINVAL
+ return loadedELF{}, linuxerr.EINVAL
}
info.entry, ok = info.entry.AddLength(uint64(offset))
@@ -621,7 +622,7 @@ func loadInitialELF(ctx context.Context, m *mm.MemoryManager, fs *cpuid.FeatureS
func loadInterpreterELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, initial loadedELF) (loadedELF, error) {
info, err := parseHeader(ctx, f)
if err != nil {
- if err == syserror.ENOEXEC {
+ if linuxerr.Equals(linuxerr.ENOEXEC, err) {
// Bad interpreter.
err = syserror.ELIBBAD
}
diff --git a/pkg/sentry/loader/interpreter.go b/pkg/sentry/loader/interpreter.go
index 3886b4d33..3e302d92c 100644
--- a/pkg/sentry/loader/interpreter.go
+++ b/pkg/sentry/loader/interpreter.go
@@ -59,7 +59,7 @@ func parseInterpreterScript(ctx context.Context, filename string, f fsbridge.Fil
// Linux silently truncates the remainder of the line if it exceeds
// interpMaxLineLength.
i := bytes.IndexByte(line, '\n')
- if i > 0 {
+ if i >= 0 {
line = line[:i]
}
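
The change from i > 0 to i >= 0 matters because bytes.IndexByte returns 0 when the newline is the very first byte, and that position was previously treated like "not found" (-1), leaving the line untruncated. A small standalone illustration of the boundary case:

package main

import (
	"bytes"
	"fmt"
)

func main() {
	line := []byte("\n/bin/sh") // newline at index 0
	i := bytes.IndexByte(line, '\n')
	fmt.Println(i)      // 0
	fmt.Println(i > 0)  // false: old check would keep the whole buffer
	fmt.Println(i >= 0) // true: new check truncates at the newline
	if i >= 0 {
		line = line[:i]
	}
	fmt.Printf("%q\n", line) // ""
}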
diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go
index 47e3775a3..8240173ae 100644
--- a/pkg/sentry/loader/loader.go
+++ b/pkg/sentry/loader/loader.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/abi/linux/errno"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/hostarch"
@@ -237,7 +238,7 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V
// loaded.end is available for its use.
e, ok := loaded.end.RoundUp()
if !ok {
- return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("brk overflows: %#x", loaded.end), linux.ENOEXEC)
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("brk overflows: %#x", loaded.end), errno.ENOEXEC)
}
args.MemoryManager.BrkSetup(ctx, e)
diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go
index fd54261fd..054ef1723 100644
--- a/pkg/sentry/loader/vdso.go
+++ b/pkg/sentry/loader/vdso.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/safemem"
@@ -58,7 +59,7 @@ type byteFullReader struct {
func (b *byteFullReader) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) {
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if offset >= int64(len(b.data)) {
return 0, io.EOF
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index b417c2da7..69aff21b6 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -125,6 +125,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/atomicbitops",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/log",
"//pkg/refs",
@@ -156,6 +157,7 @@ go_test(
library = ":mm",
deps = [
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/sentry/arch",
"//pkg/sentry/contexttest",
@@ -163,7 +165,6 @@ go_test(
"//pkg/sentry/memmap",
"//pkg/sentry/pgalloc",
"//pkg/sentry/platform",
- "//pkg/syserror",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go
index 346866d3c..8426fc90e 100644
--- a/pkg/sentry/mm/aio_context.go
+++ b/pkg/sentry/mm/aio_context.go
@@ -17,6 +17,7 @@ package mm
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -158,7 +159,7 @@ func (ctx *AIOContext) Prepare() error {
defer ctx.mu.Unlock()
if ctx.dead {
// Context died after the caller looked it up.
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if ctx.outstanding >= ctx.maxOutstanding {
// Context is busy.
@@ -297,7 +298,7 @@ func (m *aioMappable) InodeID() uint64 {
// Msync implements memmap.MappingIdentity.Msync.
func (m *aioMappable) Msync(ctx context.Context, mr memmap.MappableRange) error {
// Linux: aio_ring_fops.fsync == NULL
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// AddMapping implements memmap.Mappable.AddMapping.
@@ -325,7 +326,7 @@ func (m *aioMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, s
// Linux's fs/aio.c:aio_ring_mremap().
mm, ok := ms.(*MemoryManager)
if !ok {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
am := &mm.aioManager
am.mu.Lock()
@@ -333,12 +334,12 @@ func (m *aioMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, s
oldID := uint64(srcAR.Start)
aioCtx, ok := am.contexts[oldID]
if !ok {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
aioCtx.mu.Lock()
defer aioCtx.mu.Unlock()
if aioCtx.dead {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// Use the new ID for the AIOContext.
am.contexts[uint64(dstAR.Start)] = aioCtx
@@ -399,7 +400,7 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint
id := uint64(addr)
if !mm.aioManager.newAIOContext(events, id) {
mm.MUnmap(ctx, addr, aioRingBufferSize)
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
return id, nil
}
diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go
index 1304b0a2f..84cb8158d 100644
--- a/pkg/sentry/mm/mm_test.go
+++ b/pkg/sentry/mm/mm_test.go
@@ -18,6 +18,7 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
@@ -25,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -171,7 +171,7 @@ func TestIOAfterUnmap(t *testing.T) {
}
n, err = mm.CopyIn(ctx, addr, b, usermem.IOOpts{})
- if err != syserror.EFAULT {
+ if !linuxerr.Equals(linuxerr.EFAULT, err) {
t.Errorf("CopyIn got err %v want EFAULT", err)
}
if n != 0 {
@@ -212,7 +212,7 @@ func TestIOAfterMProtect(t *testing.T) {
// Without IgnorePermissions, CopyOut should no longer succeed.
n, err = mm.CopyOut(ctx, addr, b, usermem.IOOpts{})
- if err != syserror.EFAULT {
+ if !linuxerr.Equals(linuxerr.EFAULT, err) {
t.Errorf("CopyOut got err %v want EFAULT", err)
}
if n != 0 {
@@ -249,7 +249,7 @@ func TestAIOPrepareAfterDestroy(t *testing.T) {
mm.DestroyAIOContext(ctx, id)
// Prepare should fail because aioCtx should be destroyed.
- if err := aioCtx.Prepare(); err != syserror.EINVAL {
+ if err := aioCtx.Prepare(); !linuxerr.Equals(linuxerr.EINVAL, err) {
t.Errorf("aioCtx.Prepare got err %v want nil", err)
} else if err == nil {
aioCtx.CancelPendingRequest()
diff --git a/pkg/sentry/mm/shm.go b/pkg/sentry/mm/shm.go
index 3130be80c..94d5112a1 100644
--- a/pkg/sentry/mm/shm.go
+++ b/pkg/sentry/mm/shm.go
@@ -16,16 +16,16 @@ package mm
import (
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/kernel/shm"
- "gvisor.dev/gvisor/pkg/syserror"
)
// DetachShm unmaps a sysv shared memory segment.
func (mm *MemoryManager) DetachShm(ctx context.Context, addr hostarch.Addr) error {
if addr != addr.RoundDown() {
// "... shmaddr is not aligned on a page boundary." - man shmdt(2)
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
var detached *shm.Shm
@@ -48,7 +48,7 @@ func (mm *MemoryManager) DetachShm(ctx context.Context, addr hostarch.Addr) erro
if detached == nil {
// There is no shared memory segment attached at addr.
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// Remove all vmas that could have been created by the same attach.
diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go
index e748b7ff8..feafe76c1 100644
--- a/pkg/sentry/mm/special_mappable.go
+++ b/pkg/sentry/mm/special_mappable.go
@@ -16,6 +16,7 @@ package mm
import (
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -144,11 +145,11 @@ func (m *SpecialMappable) Length() uint64 {
// leak (b/143656263). Delete this function along with VFS1.
func NewSharedAnonMappable(length uint64, mfp pgalloc.MemoryFileProvider) (*SpecialMappable, error) {
if length == 0 {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
alignedLen, ok := hostarch.Addr(length).RoundUp()
if !ok {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
fr, err := mfp.MemoryFile().Allocate(uint64(alignedLen), usage.Anonymous)
if err != nil {
diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go
index 7ad6b7c21..7b6715815 100644
--- a/pkg/sentry/mm/syscalls.go
+++ b/pkg/sentry/mm/syscalls.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
@@ -74,7 +75,7 @@ func (mm *MemoryManager) HandleUserFault(ctx context.Context, addr hostarch.Addr
// MMap establishes a memory mapping.
func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (hostarch.Addr, error) {
if opts.Length == 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
length, ok := hostarch.Addr(opts.Length).RoundUp()
if !ok {
@@ -85,7 +86,7 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (hostar
if opts.Mappable != nil {
// Offset must be aligned.
if hostarch.Addr(opts.Offset).RoundDown() != hostarch.Addr(opts.Offset) {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Offset + length must not overflow.
if end := opts.Offset + opts.Length; end < opts.Offset {
@@ -99,7 +100,7 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (hostar
// MAP_FIXED requires addr to be page-aligned; non-fixed mappings
// don't.
if opts.Fixed {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
opts.Addr = opts.Addr.RoundDown()
}
@@ -108,10 +109,10 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (hostar
return 0, syserror.EACCES
}
if opts.Unmap && !opts.Fixed {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if opts.GrowsDown && opts.Mappable != nil {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Get the new vma.
@@ -281,18 +282,18 @@ func (mm *MemoryManager) MapStack(ctx context.Context) (hostarch.AddrRange, erro
// MUnmap implements the semantics of Linux's munmap(2).
func (mm *MemoryManager) MUnmap(ctx context.Context, addr hostarch.Addr, length uint64) error {
if addr != addr.RoundDown() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if length == 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
la, ok := hostarch.Addr(length).RoundUp()
if !ok {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
ar, ok := addr.ToRange(uint64(la))
if !ok {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
mm.mappingMu.Lock()
@@ -331,7 +332,7 @@ const (
func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr hostarch.Addr, oldSize uint64, newSize uint64, opts MRemapOpts) (hostarch.Addr, error) {
// "Note that old_address has to be page aligned." - mremap(2)
if oldAddr.RoundDown() != oldAddr {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Linux treats an old_size that rounds up to 0 as 0, which is otherwise a
@@ -340,13 +341,13 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr hostarch.Addr, oldS
oldSize = uint64(oldSizeAddr)
newSizeAddr, ok := hostarch.Addr(newSize).RoundUp()
if !ok || newSizeAddr == 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
newSize = uint64(newSizeAddr)
oldEnd, ok := oldAddr.AddLength(oldSize)
if !ok {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
mm.mappingMu.Lock()
@@ -450,15 +451,15 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr hostarch.Addr, oldS
case MRemapMustMove:
newAddr := opts.NewAddr
if newAddr.RoundDown() != newAddr {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
var ok bool
newAR, ok = newAddr.ToRange(newSize)
if !ok {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if (hostarch.AddrRange{oldAddr, oldEnd}).Overlaps(newAR) {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Check that the new region is valid.
@@ -504,7 +505,7 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr hostarch.Addr, oldS
if vma := vseg.ValuePtr(); vma.mappable != nil {
// Check that offset+length does not overflow.
if vma.off+uint64(newAR.Length()) < vma.off {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Inform the Mappable, if any, of the new mapping.
if err := vma.mappable.CopyMapping(ctx, mm, oldAR, newAR, vseg.mappableOffsetAt(oldAR.Start), vma.canWriteMappableLocked()); err != nil {
@@ -590,7 +591,7 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr hostarch.Addr, oldS
// MProtect implements the semantics of Linux's mprotect(2).
func (mm *MemoryManager) MProtect(addr hostarch.Addr, length uint64, realPerms hostarch.AccessType, growsDown bool) error {
if addr.RoundDown() != addr {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if length == 0 {
return nil
@@ -618,7 +619,7 @@ func (mm *MemoryManager) MProtect(addr hostarch.Addr, length uint64, realPerms h
}
if growsDown {
if !vseg.ValuePtr().growsDown {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if ar.End <= vseg.Start() {
return syserror.ENOMEM
@@ -711,7 +712,7 @@ func (mm *MemoryManager) Brk(ctx context.Context, addr hostarch.Addr) (hostarch.
if addr < mm.brk.Start {
addr = mm.brk.End
mm.mappingMu.Unlock()
- return addr, syserror.EINVAL
+ return addr, linuxerr.EINVAL
}
// TODO(gvisor.dev/issue/156): This enforces RLIMIT_DATA, but is
@@ -780,7 +781,7 @@ func (mm *MemoryManager) MLock(ctx context.Context, addr hostarch.Addr, length u
la, _ := hostarch.Addr(length + addr.PageOffset()).RoundUp()
ar, ok := addr.RoundDown().ToRange(uint64(la))
if !ok {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
mm.mappingMu.Lock()
@@ -855,10 +856,10 @@ func (mm *MemoryManager) MLock(ctx context.Context, addr hostarch.Addr, length u
mm.activeMu.Unlock()
mm.mappingMu.RUnlock()
// Linux: mm/mlock.c:__mlock_posix_error_return()
- if err == syserror.EFAULT {
+ if linuxerr.Equals(linuxerr.EFAULT, err) {
return syserror.ENOMEM
}
- if err == syserror.ENOMEM {
+ if linuxerr.Equals(linuxerr.ENOMEM, err) {
return syserror.EAGAIN
}
return err
@@ -898,7 +899,7 @@ type MLockAllOpts struct {
// depending on opts.
func (mm *MemoryManager) MLockAll(ctx context.Context, opts MLockAllOpts) error {
if !opts.Current && !opts.Future {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
mm.mappingMu.Lock()
@@ -979,13 +980,13 @@ func (mm *MemoryManager) NumaPolicy(addr hostarch.Addr) (linux.NumaPolicy, uint6
// SetNumaPolicy implements the semantics of Linux's mbind().
func (mm *MemoryManager) SetNumaPolicy(addr hostarch.Addr, length uint64, policy linux.NumaPolicy, nodemask uint64) error {
if !addr.IsPageAligned() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// Linux allows this to overflow.
la, _ := hostarch.Addr(length).RoundUp()
ar, ok := addr.ToRange(uint64(la))
if !ok {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if ar.Length() == 0 {
return nil
@@ -1021,7 +1022,7 @@ func (mm *MemoryManager) SetNumaPolicy(addr hostarch.Addr, length uint64, policy
func (mm *MemoryManager) SetDontFork(addr hostarch.Addr, length uint64, dontfork bool) error {
ar, ok := addr.ToRange(length)
if !ok {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
mm.mappingMu.Lock()
@@ -1047,7 +1048,7 @@ func (mm *MemoryManager) SetDontFork(addr hostarch.Addr, length uint64, dontfork
func (mm *MemoryManager) Decommit(addr hostarch.Addr, length uint64) error {
ar, ok := addr.ToRange(length)
if !ok {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
mm.mappingMu.RLock()
@@ -1063,7 +1064,7 @@ func (mm *MemoryManager) Decommit(addr hostarch.Addr, length uint64) error {
for vseg := mm.vmas.LowerBoundSegment(ar.Start); vseg.Ok() && vseg.Start() < ar.End; vseg = vseg.NextSegment() {
vma := vseg.ValuePtr()
if vma.mlockMode != memmap.MLockNone {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
vsegAR := vseg.Range().Intersect(ar)
// pseg should already correspond to either this vma or a later one,
@@ -1114,7 +1115,7 @@ type MSyncOpts struct {
// MSync implements the semantics of Linux's msync().
func (mm *MemoryManager) MSync(ctx context.Context, addr hostarch.Addr, length uint64, opts MSyncOpts) error {
if addr != addr.RoundDown() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if length == 0 {
return nil
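
In MLock above, fault errors are translated per Linux's mm/mlock.c:__mlock_posix_error_return(), now matched with linuxerr.Equals instead of identity comparison. A minimal sketch of that translation, assuming the syserror values still used in the surrounding code; the helper is hypothetical:

package mm

import (
	"gvisor.dev/gvisor/pkg/errors/linuxerr"
	"gvisor.dev/gvisor/pkg/syserror"
)

// translateMlockError mirrors the branch above: EFAULT from the fault path
// becomes ENOMEM, and ENOMEM becomes EAGAIN, per Linux
// mm/mlock.c:__mlock_posix_error_return().
func translateMlockError(err error) error {
	if linuxerr.Equals(linuxerr.EFAULT, err) {
		return syserror.ENOMEM
	}
	if linuxerr.Equals(linuxerr.ENOMEM, err) {
		return syserror.EAGAIN
	}
	return err
}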
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
index b81292c46..d1a883da4 100644
--- a/pkg/sentry/pgalloc/pgalloc.go
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -1062,10 +1062,20 @@ func (f *MemoryFile) runReclaim() {
break
}
- // If ManualZeroing is in effect, pages will be zeroed on allocation
- // and may not be freed by decommitFile, so calling decommitFile is
- // unnecessary.
- if !f.opts.ManualZeroing {
+ if f.opts.ManualZeroing {
+ // If ManualZeroing is in effect, only hugepage-aligned regions may
+ // be safely passed to decommitFile. Pages will be zeroed on
+ // reallocation, so we don't need to perform any manual zeroing
+ // here, whether or not decommitFile succeeds.
+ if startAddr, ok := hostarch.Addr(fr.Start).HugeRoundUp(); ok {
+ if endAddr := hostarch.Addr(fr.End).HugeRoundDown(); startAddr < endAddr {
+ decommitFR := memmap.FileRange{uint64(startAddr), uint64(endAddr)}
+ if err := f.decommitFile(decommitFR); err != nil {
+ log.Warningf("Reclaim failed to decommit %v: %v", decommitFR, err)
+ }
+ }
+ }
+ } else {
if err := f.decommitFile(fr); err != nil {
log.Warningf("Reclaim failed to decommit %v: %v", fr, err)
// Zero the pages manually. This won't reduce memory usage, but at
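
The reclaim path above now decommits only the hugepage-aligned interior of a reclaimed range when ManualZeroing is set, since partially covered huge pages cannot be safely decommitted. A minimal sketch of the trimming arithmetic, using the hostarch rounding helpers shown above; the function itself is hypothetical:

package pgalloc

import (
	"gvisor.dev/gvisor/pkg/hostarch"
	"gvisor.dev/gvisor/pkg/sentry/memmap"
)

// hugeAlignedInterior trims fr to its largest hugepage-aligned sub-range, as
// the ManualZeroing branch above does before calling decommitFile. ok is
// false when no aligned interior exists (range smaller than a huge page, or
// rounding overflows).
func hugeAlignedInterior(fr memmap.FileRange) (memmap.FileRange, bool) {
	startAddr, ok := hostarch.Addr(fr.Start).HugeRoundUp()
	if !ok {
		return memmap.FileRange{}, false
	}
	endAddr := hostarch.Addr(fr.End).HugeRoundDown()
	if startAddr >= endAddr {
		return memmap.FileRange{}, false
	}
	return memmap.FileRange{uint64(startAddr), uint64(endAddr)}, true
}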
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index b307832fd..8a490b3de 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -6,6 +6,8 @@ go_library(
name = "kvm",
srcs = [
"address_space.go",
+ "address_space_amd64.go",
+ "address_space_arm64.go",
"bluepill.go",
"bluepill_allocator.go",
"bluepill_amd64.go",
@@ -77,6 +79,7 @@ go_test(
"requires-kvm",
],
deps = [
+ "//pkg/abi/linux",
"//pkg/hostarch",
"//pkg/ring0",
"//pkg/ring0/pagetables",
diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go
index 5524e8727..9929caebb 100644
--- a/pkg/sentry/platform/kvm/address_space.go
+++ b/pkg/sentry/platform/kvm/address_space.go
@@ -85,15 +85,6 @@ type addressSpace struct {
dirtySet *dirtySet
}
-// invalidate is the implementation for Invalidate.
-func (as *addressSpace) invalidate() {
- as.dirtySet.forEach(as.machine, func(c *vCPU) {
- if c.active.get() == as { // If this happens to be active,
- c.BounceToKernel() // ... force a kernel transition.
- }
- })
-}
-
// Invalidate interrupts all dirty contexts.
func (as *addressSpace) Invalidate() {
as.mu.Lock()
diff --git a/pkg/flipcall/packet_window_mmap_arm64.go b/pkg/sentry/platform/kvm/address_space_amd64.go
index 87ad1a4a1..d11d38679 100644
--- a/pkg/flipcall/packet_window_mmap_arm64.go
+++ b/pkg/sentry/platform/kvm/address_space_amd64.go
@@ -1,4 +1,4 @@
-// Copyright 2020 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,14 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build arm64
+package kvm
-package flipcall
-
-import "golang.org/x/sys/unix"
-
-// Return a memory mapping of the pwd in memory that can be shared outside the sandbox.
-func packetWindowMmap(pwd PacketWindowDescriptor) (uintptr, unix.Errno) {
- m, _, err := unix.RawSyscall6(unix.SYS_MMAP, 0, uintptr(pwd.Length), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset))
- return m, err
+// invalidate is the implementation for Invalidate.
+func (as *addressSpace) invalidate() {
+ as.dirtySet.forEach(as.machine, func(c *vCPU) {
+ if c.active.get() == as { // If this happens to be active,
+ c.BounceToKernel() // ... force a kernel transition.
+ }
+ })
}
diff --git a/pkg/tcpip/transport/tcp/rcv_state.go b/pkg/sentry/platform/kvm/address_space_arm64.go
index 2bf21a2e7..fb954418b 100644
--- a/pkg/tcpip/transport/tcp/rcv_state.go
+++ b/pkg/sentry/platform/kvm/address_space_arm64.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,18 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package tcp
+package kvm
import (
- "time"
+ "gvisor.dev/gvisor/pkg/ring0"
)
-// saveLastRcvdAckTime is invoked by stateify.
-func (r *receiver) saveLastRcvdAckTime() unixTime {
- return unixTime{r.lastRcvdAckTime.Unix(), r.lastRcvdAckTime.UnixNano()}
-}
-
-// loadLastRcvdAckTime is invoked by stateify.
-func (r *receiver) loadLastRcvdAckTime(unix unixTime) {
- r.lastRcvdAckTime = time.Unix(unix.second, unix.nano)
+// invalidate is the implementation for Invalidate.
+func (as *addressSpace) invalidate() {
+ bluepill(as.pageTables.Allocator.(*allocator).cpu)
+ ring0.FlushTlbAll()
}
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
index 578852c3f..9e5c52923 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -25,29 +25,6 @@ import (
var (
// The action for bluepillSignal is changed by sigaction().
bluepillSignal = unix.SIGILL
-
- // vcpuSErrBounce is the event of system error for bouncing KVM.
- vcpuSErrBounce = kvmVcpuEvents{
- exception: exception{
- sErrPending: 1,
- },
- }
-
- // vcpuSErrNMI is the event of system error to trigger sigbus.
- vcpuSErrNMI = kvmVcpuEvents{
- exception: exception{
- sErrPending: 1,
- sErrHasEsr: 1,
- sErrEsr: _ESR_ELx_SERR_NMI,
- },
- }
-
- // vcpuExtDabt is the event of ext_dabt.
- vcpuExtDabt = kvmVcpuEvents{
- exception: exception{
- extDabtPending: 1,
- },
- }
)
// getTLS returns the value of TPIDR_EL0 register.
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
index 07fc4f216..f105fdbd0 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
@@ -80,11 +80,18 @@ func getHypercallID(addr uintptr) int {
//
//go:nosplit
func bluepillStopGuest(c *vCPU) {
+ // vcpuSErrBounce is the event of system error for bouncing KVM.
+ vcpuSErrBounce := &kvmVcpuEvents{
+ exception: exception{
+ sErrPending: 1,
+ },
+ }
+
if _, _, errno := unix.RawSyscall( // escapes: no.
unix.SYS_IOCTL,
uintptr(c.fd),
_KVM_SET_VCPU_EVENTS,
- uintptr(unsafe.Pointer(&vcpuSErrBounce))); errno != 0 {
+ uintptr(unsafe.Pointer(vcpuSErrBounce))); errno != 0 {
throw("bounce sErr injection failed")
}
}
@@ -93,12 +100,21 @@ func bluepillStopGuest(c *vCPU) {
//
//go:nosplit
func bluepillSigBus(c *vCPU) {
+ // vcpuSErrNMI is the event of system error to trigger sigbus.
+ vcpuSErrNMI := &kvmVcpuEvents{
+ exception: exception{
+ sErrPending: 1,
+ sErrHasEsr: 1,
+ sErrEsr: _ESR_ELx_SERR_NMI,
+ },
+ }
+
// Host must support ARM64_HAS_RAS_EXTN.
if _, _, errno := unix.RawSyscall( // escapes: no.
unix.SYS_IOCTL,
uintptr(c.fd),
_KVM_SET_VCPU_EVENTS,
- uintptr(unsafe.Pointer(&vcpuSErrNMI))); errno != 0 {
+ uintptr(unsafe.Pointer(vcpuSErrNMI))); errno != 0 {
if errno == unix.EINVAL {
throw("No ARM64_HAS_RAS_EXTN feature in host.")
}
@@ -110,11 +126,18 @@ func bluepillSigBus(c *vCPU) {
//
//go:nosplit
func bluepillExtDabt(c *vCPU) {
+ // vcpuExtDabt is the event of ext_dabt.
+ vcpuExtDabt := &kvmVcpuEvents{
+ exception: exception{
+ extDabtPending: 1,
+ },
+ }
+
if _, _, errno := unix.RawSyscall( // escapes: no.
unix.SYS_IOCTL,
uintptr(c.fd),
_KVM_SET_VCPU_EVENTS,
- uintptr(unsafe.Pointer(&vcpuExtDabt))); errno != 0 {
+ uintptr(unsafe.Pointer(vcpuExtDabt))); errno != 0 {
throw("ext_dabt injection failed")
}
}
diff --git a/pkg/sentry/platform/kvm/context.go b/pkg/sentry/platform/kvm/context.go
index f4d4473a8..183e741ea 100644
--- a/pkg/sentry/platform/kvm/context.go
+++ b/pkg/sentry/platform/kvm/context.go
@@ -17,6 +17,7 @@ package kvm
import (
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/abi/linux"
pkgcontext "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
@@ -32,15 +33,15 @@ type context struct {
// machine is the parent machine, and is immutable.
machine *machine
- // info is the arch.SignalInfo cached for this context.
- info arch.SignalInfo
+ // info is the linux.SignalInfo cached for this context.
+ info linux.SignalInfo
// interrupt is the interrupt context.
interrupt interrupt.Forwarder
}
// Switch runs the provided context in the given address space.
-func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, _ int32) (*arch.SignalInfo, hostarch.AccessType, error) {
+func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, _ int32) (*linux.SignalInfo, hostarch.AccessType, error) {
as := mm.AddressSpace()
localAS := as.(*addressSpace)
diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go
index b8dd1e4a5..b1cab89a0 100644
--- a/pkg/sentry/platform/kvm/kvm_amd64_test.go
+++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go
@@ -19,6 +19,7 @@ package kvm
import (
"testing"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -30,7 +31,7 @@ func TestSegments(t *testing.T) {
applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTestSegments(regs)
for {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -55,7 +56,7 @@ func stmxcsr(addr *uint32)
func TestMXCSR(t *testing.T) {
applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
- var si arch.SignalInfo
+ var si linux.SignalInfo
switchOpts := ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go
index ceff09a60..fe570aff9 100644
--- a/pkg/sentry/platform/kvm/kvm_test.go
+++ b/pkg/sentry/platform/kvm/kvm_test.go
@@ -22,6 +22,7 @@ import (
"time"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
@@ -157,7 +158,7 @@ func applicationTest(t testHarness, useHostMappings bool, target func(), fn func
func TestApplicationSyscall(t *testing.T) {
applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -171,7 +172,7 @@ func TestApplicationSyscall(t *testing.T) {
return false
})
applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -188,7 +189,7 @@ func TestApplicationSyscall(t *testing.T) {
func TestApplicationFault(t *testing.T) {
applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTouchTarget(regs, nil) // Cause fault.
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -203,7 +204,7 @@ func TestApplicationFault(t *testing.T) {
})
applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTouchTarget(regs, nil) // Cause fault.
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -221,7 +222,7 @@ func TestRegistersSyscall(t *testing.T) {
applicationTest(t, true, testutil.TwiddleRegsSyscall, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTestRegs(regs) // Fill values for all registers.
for {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -244,7 +245,7 @@ func TestRegistersFault(t *testing.T) {
applicationTest(t, true, testutil.TwiddleRegsFault, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTestRegs(regs) // Fill values for all registers.
for {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -270,7 +271,7 @@ func TestBounce(t *testing.T) {
time.Sleep(time.Millisecond)
c.BounceToKernel()
}()
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -285,7 +286,7 @@ func TestBounce(t *testing.T) {
time.Sleep(time.Millisecond)
c.BounceToKernel()
}()
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -317,7 +318,7 @@ func TestBounceStress(t *testing.T) {
c.BounceToKernel()
}()
randomSleep()
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -338,7 +339,7 @@ func TestInvalidate(t *testing.T) {
applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTouchTarget(regs, &data) // Read legitimate value.
for {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -353,7 +354,7 @@ func TestInvalidate(t *testing.T) {
// Unmap the page containing data & invalidate.
pt.Unmap(hostarch.Addr(reflect.ValueOf(&data).Pointer() & ^uintptr(hostarch.PageSize-1)), hostarch.PageSize)
for {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -371,13 +372,13 @@ func TestInvalidate(t *testing.T) {
}
// IsFault returns true iff the given signal represents a fault.
-func IsFault(err error, si *arch.SignalInfo) bool {
+func IsFault(err error, si *linux.SignalInfo) bool {
return err == platform.ErrContextSignal && si.Signo == int32(unix.SIGSEGV)
}
func TestEmptyAddressSpace(t *testing.T) {
applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -391,7 +392,7 @@ func TestEmptyAddressSpace(t *testing.T) {
return false
})
applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -467,7 +468,7 @@ func BenchmarkApplicationSyscall(b *testing.B) {
a int // Count for ErrContextInterrupt.
)
applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
@@ -504,7 +505,7 @@ func BenchmarkWorldSwitchToUserRoundtrip(b *testing.B) {
a int
)
applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
- var si arch.SignalInfo
+ var si linux.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
FloatingPointState: &dummyFPState,
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index 9a2337654..7a10fd812 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -23,11 +23,11 @@ import (
"runtime/debug"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
ktime "gvisor.dev/gvisor/pkg/sentry/time"
@@ -264,10 +264,10 @@ func (c *vCPU) setSystemTime() error {
// nonCanonical generates a canonical address return.
//
//go:nosplit
-func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) {
- *info = arch.SignalInfo{
+func nonCanonical(addr uint64, signal int32, info *linux.SignalInfo) (hostarch.AccessType, error) {
+ *info = linux.SignalInfo{
Signo: signal,
- Code: arch.SignalInfoKernel,
+ Code: linux.SI_KERNEL,
}
info.SetAddr(addr) // Include address.
return hostarch.NoAccess, platform.ErrContextSignal
@@ -276,7 +276,7 @@ func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (hostarch.Ac
// fault generates an appropriate fault return.
//
//go:nosplit
-func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) {
+func (c *vCPU) fault(signal int32, info *linux.SignalInfo) (hostarch.AccessType, error) {
bluepill(c) // Probably no-op, but may not be.
faultAddr := ring0.ReadCR2()
code, user := c.ErrorCode()
@@ -287,7 +287,7 @@ func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType,
return hostarch.NoAccess, platform.ErrContextInterrupt
}
// Reset the pointed SignalInfo.
- *info = arch.SignalInfo{Signo: signal}
+ *info = linux.SignalInfo{Signo: signal}
info.SetAddr(uint64(faultAddr))
accessType := hostarch.AccessType{
Read: code&(1<<1) == 0,
@@ -325,7 +325,7 @@ func prefaultFloatingPointState(data *fpu.State) {
}
// SwitchToUser unpacks architectural-details.
-func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (hostarch.AccessType, error) {
+func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) (hostarch.AccessType, error) {
// Check for canonical addresses.
if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Rip) {
return nonCanonical(regs.Rip, int32(unix.SIGSEGV), info)
@@ -371,7 +371,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
return c.fault(int32(unix.SIGSEGV), info)
case ring0.Debug, ring0.Breakpoint:
- *info = arch.SignalInfo{
+ *info = linux.SignalInfo{
Signo: int32(unix.SIGTRAP),
Code: 1, // TRAP_BRKPT (breakpoint).
}
@@ -383,9 +383,9 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
ring0.BoundRangeExceeded,
ring0.InvalidTSS,
ring0.StackSegmentFault:
- *info = arch.SignalInfo{
+ *info = linux.SignalInfo{
Signo: int32(unix.SIGSEGV),
- Code: arch.SignalInfoKernel,
+ Code: linux.SI_KERNEL,
}
info.SetAddr(switchOpts.Registers.Rip) // Include address.
if vector == ring0.GeneralProtectionFault {
@@ -397,7 +397,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
return hostarch.AccessType{}, platform.ErrContextSignal
case ring0.InvalidOpcode:
- *info = arch.SignalInfo{
+ *info = linux.SignalInfo{
Signo: int32(unix.SIGILL),
Code: 1, // ILL_ILLOPC (illegal opcode).
}
@@ -405,7 +405,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
return hostarch.AccessType{}, platform.ErrContextSignal
case ring0.DivideByZero:
- *info = arch.SignalInfo{
+ *info = linux.SignalInfo{
Signo: int32(unix.SIGFPE),
Code: 1, // FPE_INTDIV (divide by zero).
}
@@ -413,7 +413,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
return hostarch.AccessType{}, platform.ErrContextSignal
case ring0.Overflow:
- *info = arch.SignalInfo{
+ *info = linux.SignalInfo{
Signo: int32(unix.SIGFPE),
Code: 2, // FPE_INTOVF (integer overflow).
}
@@ -422,7 +422,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
case ring0.X87FloatingPointException,
ring0.SIMDFloatingPointException:
- *info = arch.SignalInfo{
+ *info = linux.SignalInfo{
Signo: int32(unix.SIGFPE),
Code: 7, // FPE_FLTINV (invalid operation).
}
@@ -433,7 +433,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
return hostarch.NoAccess, platform.ErrContextInterrupt
case ring0.AlignmentCheck:
- *info = arch.SignalInfo{
+ *info = linux.SignalInfo{
Signo: int32(unix.SIGBUS),
Code: 2, // BUS_ADRERR (physical address does not exist).
}
@@ -469,7 +469,7 @@ func availableRegionsForSetMem() (phyRegions []physicalRegion) {
}
func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
- // Map all the executible regions so that all the entry functions
+ // Map all the executable regions so that all the entry functions
// are mapped in the upper half.
applyVirtualRegions(func(vr virtualRegion) {
if excludeVirtualRegion(vr) || vr.filename == "[vsyscall]" {
@@ -485,7 +485,7 @@ func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
pageTable.Map(
hostarch.Addr(ring0.KernelStartAddress|r.virtual),
r.length,
- pagetables.MapOpts{AccessType: hostarch.Execute},
+ pagetables.MapOpts{AccessType: hostarch.Execute, Global: true},
physical)
}
})
@@ -498,7 +498,7 @@ func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
pageTable.Map(
hostarch.Addr(ring0.KernelStartAddress|start),
regionLen,
- pagetables.MapOpts{AccessType: hostarch.ReadWrite},
+ pagetables.MapOpts{AccessType: hostarch.ReadWrite, Global: true},
physical)
}
}
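The hunks above are mechanical: every arch.SignalInfo literal becomes a linux.SignalInfo and arch.SignalInfoKernel becomes linux.SI_KERNEL, while the Signo/Code/SetAddr usage is unchanged. For reference, a minimal sketch of the construction pattern, using only the linux package types shown in the diff:

package kvmsketch

import (
	"golang.org/x/sys/unix"

	"gvisor.dev/gvisor/pkg/abi/linux"
)

// kernelSegv builds the SignalInfo reported for a kernel-originated
// SIGSEGV, in the same way nonCanonical does above.
func kernelSegv(addr uint64) *linux.SignalInfo {
	info := &linux.SignalInfo{
		Signo: int32(unix.SIGSEGV),
		Code:  linux.SI_KERNEL,
	}
	info.SetAddr(addr) // Faulting address, surfaced as si_addr.
	return info
}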
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 8926b1d9f..edaccf9bc 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -21,10 +21,10 @@ import (
"sync/atomic"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
)
@@ -126,10 +126,10 @@ func availableRegionsForSetMem() (phyRegions []physicalRegion) {
// nonCanonical generates a canonical address return.
//
//go:nosplit
-func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) {
- *info = arch.SignalInfo{
+func nonCanonical(addr uint64, signal int32, info *linux.SignalInfo) (hostarch.AccessType, error) {
+ *info = linux.SignalInfo{
Signo: signal,
- Code: arch.SignalInfoKernel,
+ Code: linux.SI_KERNEL,
}
info.SetAddr(addr) // Include address.
return hostarch.NoAccess, platform.ErrContextSignal
@@ -157,7 +157,7 @@ func isWriteFault(code uint64) bool {
// fault generates an appropriate fault return.
//
//go:nosplit
-func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) {
+func (c *vCPU) fault(signal int32, info *linux.SignalInfo) (hostarch.AccessType, error) {
bluepill(c) // Probably no-op, but may not be.
faultAddr := c.GetFaultAddr()
code, user := c.ErrorCode()
@@ -170,7 +170,7 @@ func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType,
}
// Reset the pointed SignalInfo.
- *info = arch.SignalInfo{Signo: signal}
+ *info = linux.SignalInfo{Signo: signal}
info.SetAddr(uint64(faultAddr))
ret := code & _ESR_ELx_FSC
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 92edc992b..f6aa519b1 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -23,10 +23,10 @@ import (
"unsafe"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
- "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
ktime "gvisor.dev/gvisor/pkg/sentry/time"
@@ -140,22 +140,15 @@ func (c *vCPU) initArchState() error {
// vbar_el1
reg.id = _KVM_ARM64_REGS_VBAR_EL1
-
- fromLocation := reflect.ValueOf(ring0.Vectors).Pointer()
- offset := fromLocation & (1<<11 - 1)
- if offset != 0 {
- offset = 1<<11 - offset
- }
-
- toLocation := fromLocation + offset
- data = uint64(ring0.KernelStartAddress | toLocation)
+ vectorLocation := reflect.ValueOf(ring0.Vectors).Pointer()
+ data = uint64(ring0.KernelStartAddress | vectorLocation)
if err := c.setOneRegister(&reg); err != nil {
return err
}
// Use the address of the exception vector table as
// the MMIO address base.
- arm64HypercallMMIOBase = toLocation
+ arm64HypercallMMIOBase = vectorLocation
// Initialize the PCID database.
if hasGuestPCID {
@@ -272,7 +265,7 @@ func (c *vCPU) getOneRegister(reg *kvmOneReg) error {
}
// SwitchToUser unpacks architectural-details.
-func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (hostarch.AccessType, error) {
+func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) (hostarch.AccessType, error) {
// Check for canonical addresses.
if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Pc) {
return nonCanonical(regs.Pc, int32(unix.SIGSEGV), info)
@@ -319,14 +312,14 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
case ring0.El0SyncUndef:
return c.fault(int32(unix.SIGILL), info)
case ring0.El0SyncDbg:
- *info = arch.SignalInfo{
+ *info = linux.SignalInfo{
Signo: int32(unix.SIGTRAP),
Code: 1, // TRAP_BRKPT (breakpoint).
}
info.SetAddr(switchOpts.Registers.Pc) // Include address.
return hostarch.AccessType{}, platform.ErrContextSignal
case ring0.El0SyncSpPc:
- *info = arch.SignalInfo{
+ *info = linux.SignalInfo{
Signo: int32(unix.SIGBUS),
Code: 2, // BUS_ADRERR (physical address does not exist).
}
diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go
index ef7814a6f..a26bc2316 100644
--- a/pkg/sentry/platform/platform.go
+++ b/pkg/sentry/platform/platform.go
@@ -195,8 +195,8 @@ type Context interface {
// - nil: The Context invoked a system call.
//
// - ErrContextSignal: The Context was interrupted by a signal. The
- // returned *arch.SignalInfo contains information about the signal. If
- // arch.SignalInfo.Signo == SIGSEGV, the returned hostarch.AccessType
+ // returned *linux.SignalInfo contains information about the signal. If
+ // linux.SignalInfo.Signo == SIGSEGV, the returned hostarch.AccessType
// contains the access type of the triggering fault. The caller owns
// the returned SignalInfo.
//
@@ -207,7 +207,7 @@ type Context interface {
// concurrent call to Switch().
//
// - ErrContextCPUPreempted: See the definition of that error for details.
- Switch(ctx context.Context, mm MemoryManager, ac arch.Context, cpu int32) (*arch.SignalInfo, hostarch.AccessType, error)
+ Switch(ctx context.Context, mm MemoryManager, ac arch.Context, cpu int32) (*linux.SignalInfo, hostarch.AccessType, error)
// PullFullState() pulls a full state of the application thread.
//
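With this change the platform interface hands back *linux.SignalInfo directly. A sketch of how a caller might consume Switch's results under the contract documented above; runOnce, handleFault, and deliverSignal are hypothetical names, not gvisor APIs.

package platformsketch

import (
	"golang.org/x/sys/unix"

	"gvisor.dev/gvisor/pkg/abi/linux"
	"gvisor.dev/gvisor/pkg/context"
	"gvisor.dev/gvisor/pkg/hostarch"
	"gvisor.dev/gvisor/pkg/sentry/arch"
	"gvisor.dev/gvisor/pkg/sentry/platform"
)

// runOnce performs one Switch round trip. handleFault and deliverSignal are
// hypothetical callbacks standing in for the task's fault and signal paths.
func runOnce(
	ctx context.Context,
	pc platform.Context,
	mm platform.MemoryManager,
	ac arch.Context,
	handleFault func(*linux.SignalInfo, hostarch.AccessType),
	deliverSignal func(*linux.SignalInfo),
) {
	si, at, err := pc.Switch(ctx, mm, ac, 0)
	switch err {
	case nil:
		// The application made a system call; the syscall path runs next.
	case platform.ErrContextSignal:
		if si.Signo == int32(unix.SIGSEGV) {
			handleFault(si, at) // at carries the faulting access type.
		} else {
			deliverSignal(si)
		}
	case platform.ErrContextInterrupt:
		// Interrupted (e.g. BounceToKernel); the caller simply retries.
	case platform.ErrContextCPUPreempted:
		// See that error's definition; typically also a retry.
	}
}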
diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go
index 828458ce2..319b0cf1d 100644
--- a/pkg/sentry/platform/ptrace/ptrace.go
+++ b/pkg/sentry/platform/ptrace/ptrace.go
@@ -73,7 +73,7 @@ var (
type context struct {
// signalInfo is the signal info, if and when a signal is received.
- signalInfo arch.SignalInfo
+ signalInfo linux.SignalInfo
// interrupt is the interrupt context.
interrupt interrupt.Forwarder
@@ -96,7 +96,7 @@ type context struct {
}
// Switch runs the provided context in the given address space.
-func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, cpu int32) (*arch.SignalInfo, hostarch.AccessType, error) {
+func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, cpu int32) (*linux.SignalInfo, hostarch.AccessType, error) {
as := mm.AddressSpace()
s := as.(*subprocess)
isSyscall := s.switchToApp(c, ac)
diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
index facb96011..cc93396a9 100644
--- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go
+++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
@@ -101,7 +101,7 @@ func (t *thread) setFPRegs(fpState *fpu.State, fpLen uint64, useXsave bool) erro
}
// getSignalInfo retrieves information about the signal that caused the stop.
-func (t *thread) getSignalInfo(si *arch.SignalInfo) error {
+func (t *thread) getSignalInfo(si *linux.SignalInfo) error {
_, _, errno := unix.RawSyscall6(
unix.SYS_PTRACE,
unix.PTRACE_GETSIGINFO,
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 9c73a725a..0931795c5 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -20,6 +20,7 @@ import (
"runtime"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/procid"
@@ -524,7 +525,7 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
// Check for interrupts, and ensure that future interrupts will signal t.
if !c.interrupt.Enable(t) {
// Pending interrupt; simulate.
- c.signalInfo = arch.SignalInfo{Signo: int32(platform.SignalInterrupt)}
+ c.signalInfo = linux.SignalInfo{Signo: int32(platform.SignalInterrupt)}
return false
}
defer c.interrupt.Disable()
diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go
index 9252c0bd7..90b1ead56 100644
--- a/pkg/sentry/platform/ptrace/subprocess_amd64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go
@@ -155,7 +155,7 @@ func initChildProcessPPID(initregs *arch.Registers, ppid int32) {
//
// Note that this should only be called after verifying that the signalInfo has
// been generated by the kernel.
-func patchSignalInfo(regs *arch.Registers, signalInfo *arch.SignalInfo) {
+func patchSignalInfo(regs *arch.Registers, signalInfo *linux.SignalInfo) {
if linux.Signal(signalInfo.Signo) == linux.SIGSYS {
signalInfo.Signo = int32(linux.SIGSEGV)
diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go
index c0cbc0686..e4257e3bf 100644
--- a/pkg/sentry/platform/ptrace/subprocess_arm64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go
@@ -138,7 +138,7 @@ func initChildProcessPPID(initregs *arch.Registers, ppid int32) {
//
// Note that this should only be called after verifying that the signalInfo has
// been generated by the kernel.
-func patchSignalInfo(regs *arch.Registers, signalInfo *arch.SignalInfo) {
+func patchSignalInfo(regs *arch.Registers, signalInfo *linux.SignalInfo) {
if linux.Signal(signalInfo.Signo) == linux.SIGSYS {
signalInfo.Signo = int32(linux.SIGSEGV)
diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sentry/sighandling/sighandling_unsafe.go
index d6a2fbe34..3fe5c6770 100644
--- a/pkg/sentry/sighandling/sighandling_unsafe.go
+++ b/pkg/sentry/sighandling/sighandling_unsafe.go
@@ -21,25 +21,16 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
)
-// FIXME(gvisor.dev/issue/214): Move to pkg/abi/linux along with definitions in
-// pkg/sentry/arch.
-type sigaction struct {
- handler uintptr
- flags uint64
- restorer uintptr
- mask uint64
-}
-
// IgnoreChildStop sets the SA_NOCLDSTOP flag, causing child processes to not
// generate SIGCHLD when they stop.
func IgnoreChildStop() error {
- var sa sigaction
+ var sa linux.SigAction
// Get the existing signal handler information, and set the flag.
if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(unix.SIGCHLD), 0, uintptr(unsafe.Pointer(&sa)), linux.SignalSetSize, 0, 0); e != 0 {
return e
}
- sa.flags |= linux.SA_NOCLDSTOP
+ sa.Flags |= linux.SA_NOCLDSTOP
if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(unix.SIGCHLD), uintptr(unsafe.Pointer(&sa)), 0, linux.SignalSetSize, 0, 0); e != 0 {
return e
}
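The helper now reuses linux.SigAction instead of a private struct; the read-modify-write over rt_sigaction is otherwise unchanged. A hypothetical call site, e.g. during sandbox startup:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/sentry/sighandling"
)

func main() {
	// Suppress SIGCHLD for stopped children; only termination is of interest.
	if err := sighandling.IgnoreChildStop(); err != nil {
		fmt.Printf("failed to set SA_NOCLDSTOP on SIGCHLD: %v\n", err)
	}
}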
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
index 2029e7cf4..e1d310b1b 100644
--- a/pkg/sentry/socket/control/BUILD
+++ b/pkg/sentry/socket/control/BUILD
@@ -16,6 +16,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/bits",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/marshal",
"//pkg/marshal/primitive",
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 235b9c306..64958b6ec 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
@@ -473,17 +474,17 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
for i := 0; i < len(buf); {
if i+linux.SizeOfControlMessageHeader > len(buf) {
- return cmsgs, syserror.EINVAL
+ return cmsgs, linuxerr.EINVAL
}
var h linux.ControlMessageHeader
h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader])
if h.Length < uint64(linux.SizeOfControlMessageHeader) {
- return socket.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, linuxerr.EINVAL
}
if h.Length > uint64(len(buf)-i) {
- return socket.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, linuxerr.EINVAL
}
i += linux.SizeOfControlMessageHeader
@@ -497,7 +498,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
numRights := rightsSize / linux.SizeOfControlMessageRight
if len(fds)+numRights > linux.SCM_MAX_FD {
- return socket.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, linuxerr.EINVAL
}
for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight {
@@ -508,7 +509,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
case linux.SCM_CREDENTIALS:
if length < linux.SizeOfControlMessageCredentials {
- return socket.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, linuxerr.EINVAL
}
var creds linux.ControlMessageCredentials
@@ -522,7 +523,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
case linux.SO_TIMESTAMP:
if length < linux.SizeOfTimeval {
- return socket.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, linuxerr.EINVAL
}
var ts linux.Timeval
ts.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval])
@@ -532,13 +533,13 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
default:
// Unknown message type.
- return socket.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, linuxerr.EINVAL
}
case linux.SOL_IP:
switch h.Type {
case linux.IP_TOS:
if length < linux.SizeOfControlMessageTOS {
- return socket.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, linuxerr.EINVAL
}
cmsgs.IP.HasTOS = true
var tos primitive.Uint8
@@ -548,7 +549,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
case linux.IP_PKTINFO:
if length < linux.SizeOfControlMessageIPPacketInfo {
- return socket.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, linuxerr.EINVAL
}
cmsgs.IP.HasIPPacketInfo = true
@@ -561,7 +562,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
case linux.IP_RECVORIGDSTADDR:
var addr linux.SockAddrInet
if length < addr.SizeBytes() {
- return socket.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, linuxerr.EINVAL
}
addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()])
cmsgs.IP.OriginalDstAddress = &addr
@@ -570,7 +571,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
case linux.IP_RECVERR:
var errCmsg linux.SockErrCMsgIPv4
if length < errCmsg.SizeBytes() {
- return socket.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, linuxerr.EINVAL
}
errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()])
@@ -578,13 +579,13 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
i += bits.AlignUp(length, width)
default:
- return socket.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, linuxerr.EINVAL
}
case linux.SOL_IPV6:
switch h.Type {
case linux.IPV6_TCLASS:
if length < linux.SizeOfControlMessageTClass {
- return socket.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, linuxerr.EINVAL
}
cmsgs.IP.HasTClass = true
var tclass primitive.Uint32
@@ -595,7 +596,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
case linux.IPV6_RECVORIGDSTADDR:
var addr linux.SockAddrInet6
if length < addr.SizeBytes() {
- return socket.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, linuxerr.EINVAL
}
addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()])
cmsgs.IP.OriginalDstAddress = &addr
@@ -604,7 +605,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
case linux.IPV6_RECVERR:
var errCmsg linux.SockErrCMsgIPv6
if length < errCmsg.SizeBytes() {
- return socket.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, linuxerr.EINVAL
}
errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()])
@@ -612,10 +613,10 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
i += bits.AlignUp(length, width)
default:
- return socket.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, linuxerr.EINVAL
}
default:
- return socket.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, linuxerr.EINVAL
}
}
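Validation errors in Parse now come back as pkg/errors/linuxerr sentinels rather than syserror values. A small sketch of how such a sentinel is returned and compared; checkHeader and its size constant are illustrative only:

package controlsketch

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/errors/linuxerr"
)

// sizeOfHeader is an illustrative stand-in for linux.SizeOfControlMessageHeader.
const sizeOfHeader = 16

// checkHeader mimics the validation above: malformed input is reported with
// a linuxerr sentinel instead of a syserror value.
func checkHeader(buf []byte) error {
	if len(buf) < sizeOfHeader {
		return linuxerr.EINVAL
	}
	return nil
}

func Example() {
	err := checkHeader(nil)
	// Sentinels compare directly; linuxerr.Equals is what this change uses
	// elsewhere (e.g. for ETIMEDOUT) when the error may arrive in another form.
	fmt.Println(err == linuxerr.EINVAL, linuxerr.Equals(linuxerr.EINVAL, err))
}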
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index 3c6511ead..3950caa0f 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -18,6 +18,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fdnotifier",
"//pkg/hostarch",
"//pkg/log",
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 52ae4bc9c..38cb2c99c 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -20,6 +20,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
@@ -67,23 +68,6 @@ type socketOperations struct {
socketOpsCommon
}
-// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
-//
-// +stateify savable
-type socketOpsCommon struct {
- socket.SendReceiveTimeout
-
- family int // Read-only.
- stype linux.SockType // Read-only.
- protocol int // Read-only.
- queue waiter.Queue
-
- // fd is the host socket fd. It must have O_NONBLOCK, so that operations
- // will return EWOULDBLOCK instead of blocking on the host. This allows us to
- // handle blocking behavior independently in the sentry.
- fd int
-}
-
var _ = socket.Socket(&socketOperations{})
func newSocketFile(ctx context.Context, family int, stype linux.SockType, protocol int, fd int, nonblock bool) (*fs.File, *syserr.Error) {
@@ -103,29 +87,6 @@ func newSocketFile(ctx context.Context, family int, stype linux.SockType, protoc
return fs.NewFile(ctx, dirent, fs.FileFlags{NonBlocking: nonblock, Read: true, Write: true, NonSeekable: true}, s), nil
}
-// Release implements fs.FileOperations.Release.
-func (s *socketOpsCommon) Release(context.Context) {
- fdnotifier.RemoveFD(int32(s.fd))
- unix.Close(s.fd)
-}
-
-// Readiness implements waiter.Waitable.Readiness.
-func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
- return fdnotifier.NonBlockingPoll(int32(s.fd), mask)
-}
-
-// EventRegister implements waiter.Waitable.EventRegister.
-func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
- s.queue.EventRegister(e, mask)
- fdnotifier.UpdateFD(int32(s.fd))
-}
-
-// EventUnregister implements waiter.Waitable.EventUnregister.
-func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
- s.queue.EventUnregister(e)
- fdnotifier.UpdateFD(int32(s.fd))
-}
-
// Ioctl implements fs.FileOperations.Ioctl.
func (s *socketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
return ioctl(ctx, s.fd, io, args)
@@ -177,6 +138,96 @@ func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
return int64(n), err
}
+// Socket implements socket.Provider.Socket.
+func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // Check that we are using the host network stack.
+ stack := t.NetworkContext()
+ if stack == nil {
+ return nil, nil
+ }
+ if _, ok := stack.(*Stack); !ok {
+ return nil, nil
+ }
+
+ // Only accept TCP and UDP.
+ stype := stypeflags & linux.SOCK_TYPE_MASK
+ switch stype {
+ case unix.SOCK_STREAM:
+ switch protocol {
+ case 0, unix.IPPROTO_TCP:
+ // ok
+ default:
+ return nil, nil
+ }
+ case unix.SOCK_DGRAM:
+ switch protocol {
+ case 0, unix.IPPROTO_UDP:
+ // ok
+ default:
+ return nil, nil
+ }
+ default:
+ return nil, nil
+ }
+
+ // Conservatively ignore all flags specified by the application and add
+ // SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0
+ // to simplify the syscall filters, since 0 and IPPROTO_* are equivalent.
+ fd, err := unix.Socket(p.family, int(stype)|unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC, 0)
+ if err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return newSocketFile(t, p.family, stype, protocol, fd, stypeflags&unix.SOCK_NONBLOCK != 0)
+}
+
+// Pair implements socket.Provider.Pair.
+func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+ // Not supported by AF_INET/AF_INET6.
+ return nil, nil, nil
+}
+
+// LINT.ThenChange(./socket_vfs2.go)
+
+// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
+//
+// +stateify savable
+type socketOpsCommon struct {
+ socket.SendReceiveTimeout
+
+ family int // Read-only.
+ stype linux.SockType // Read-only.
+ protocol int // Read-only.
+ queue waiter.Queue
+
+ // fd is the host socket fd. It must have O_NONBLOCK, so that operations
+ // will return EWOULDBLOCK instead of blocking on the host. This allows us to
+ // handle blocking behavior independently in the sentry.
+ fd int
+}
+
+// Release implements fs.FileOperations.Release.
+func (s *socketOpsCommon) Release(context.Context) {
+ fdnotifier.RemoveFD(int32(s.fd))
+ unix.Close(s.fd)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fdnotifier.NonBlockingPoll(int32(s.fd), mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.queue.EventRegister(e, mask)
+ fdnotifier.UpdateFD(int32(s.fd))
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
+ s.queue.EventUnregister(e)
+ fdnotifier.UpdateFD(int32(s.fd))
+}
+
// Connect implements socket.Socket.Connect.
func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
if len(sockaddr) > sizeofSockaddr {
@@ -596,6 +647,17 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
return 0, syserr.ErrInvalidArgument
}
+ // If the src is zero-length, call SENDTO directly with a null buffer in
+ // order to generate poll/epoll notifications.
+ if src.NumBytes() == 0 {
+ sysflags := flags | unix.MSG_DONTWAIT
+ n, _, errno := unix.Syscall6(unix.SYS_SENDTO, uintptr(s.fd), 0, 0, uintptr(sysflags), uintptr(firstBytePtr(to)), uintptr(len(to)))
+ if errno != 0 {
+ return 0, syserr.FromError(errno)
+ }
+ return int(n), nil
+ }
+
space := uint64(control.CmsgsSpace(t, controlMessages))
if space > maxControlLen {
space = maxControlLen
@@ -653,7 +715,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
}
if ch != nil {
if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
- if err == syserror.ETIMEDOUT {
+ if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
err = syserror.ErrWouldBlock
}
break
@@ -709,56 +771,6 @@ type socketProvider struct {
family int
}
-// Socket implements socket.Provider.Socket.
-func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) {
- // Check that we are using the host network stack.
- stack := t.NetworkContext()
- if stack == nil {
- return nil, nil
- }
- if _, ok := stack.(*Stack); !ok {
- return nil, nil
- }
-
- // Only accept TCP and UDP.
- stype := stypeflags & linux.SOCK_TYPE_MASK
- switch stype {
- case unix.SOCK_STREAM:
- switch protocol {
- case 0, unix.IPPROTO_TCP:
- // ok
- default:
- return nil, nil
- }
- case unix.SOCK_DGRAM:
- switch protocol {
- case 0, unix.IPPROTO_UDP:
- // ok
- default:
- return nil, nil
- }
- default:
- return nil, nil
- }
-
- // Conservatively ignore all flags specified by the application and add
- // SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0
- // to simplify the syscall filters, since 0 and IPPROTO_* are equivalent.
- fd, err := unix.Socket(p.family, int(stype)|unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC, 0)
- if err != nil {
- return nil, syserr.FromError(err)
- }
- return newSocketFile(t, p.family, stype, protocol, fd, stypeflags&unix.SOCK_NONBLOCK != 0)
-}
-
-// Pair implements socket.Provider.Pair.
-func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
- // Not supported by AF_INET/AF_INET6.
- return nil, nil, nil
-}
-
-// LINT.ThenChange(./socket_vfs2.go)
-
func init() {
for _, family := range []int{unix.AF_INET, unix.AF_INET6} {
socket.RegisterProvider(family, &socketProvider{family})
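The zero-length SendMsg fast path above forwards the call to the host so the peer still receives an (empty) datagram and gets its poll/epoll wakeup. A host-side illustration of that behavior, using an AF_UNIX datagram socketpair as a stand-in for a connected UDP socket:

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

// A zero-length datagram is still a datagram: the peer's recv returns 0
// bytes with no error, and its poll/epoll sees the socket become readable.
func main() {
	sv, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_DGRAM, 0)
	if err != nil {
		panic(err)
	}
	defer unix.Close(sv[0])
	defer unix.Close(sv[1])

	// Rough equivalent of the sentry's SENDTO with a null buffer.
	if err := unix.Sendto(sv[0], nil, unix.MSG_DONTWAIT, nil); err != nil {
		panic(err)
	}
	buf := make([]byte, 16)
	n, _, err := unix.Recvfrom(sv[1], buf, unix.MSG_DONTWAIT)
	fmt.Println(n, err) // 0 <nil>: the empty datagram was delivered.
}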
diff --git a/pkg/sentry/socket/hostinet/socket_unsafe.go b/pkg/sentry/socket/hostinet/socket_unsafe.go
index d3be2d825..86dc879d5 100644
--- a/pkg/sentry/socket/hostinet/socket_unsafe.go
+++ b/pkg/sentry/socket/hostinet/socket_unsafe.go
@@ -67,7 +67,23 @@ func ioctl(ctx context.Context, fd int, io usermem.IO, args arch.SyscallArgument
AddressSpaceActive: true,
})
return 0, err
-
+ case unix.SIOCGIFFLAGS:
+ cc := &usermem.IOCopyContext{
+ Ctx: ctx,
+ IO: io,
+ Opts: usermem.IOOpts{
+ AddressSpaceActive: true,
+ },
+ }
+ var ifr linux.IFReq
+ if _, err := ifr.CopyIn(cc, args[2].Pointer()); err != nil {
+ return 0, err
+ }
+ if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), cmd, uintptr(unsafe.Pointer(&ifr))); errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ _, err := ifr.CopyOut(cc, args[2].Pointer())
+ return 0, err
default:
return 0, syserror.ENOTTY
}
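For reference, SIOCGIFFLAGS is the classic ifreq-based query; this is roughly what an application inside the sandbox issues and what now gets copied in, forwarded to the host, and copied back. The ifreq layout below is hand-rolled for illustration:

package main

import (
	"fmt"
	"unsafe"

	"golang.org/x/sys/unix"
)

// ifreq mirrors the start of struct ifreq for SIOCGIFFLAGS: an interface
// name followed by the flags field (the C union is larger; padding covers it).
type ifreq struct {
	Name  [unix.IFNAMSIZ]byte
	Flags uint16
	_     [22]byte
}

func main() {
	fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
	if err != nil {
		panic(err)
	}
	defer unix.Close(fd)

	var req ifreq
	copy(req.Name[:], "lo")
	if _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&req))); errno != 0 {
		panic(errno)
	}
	fmt.Printf("lo flags: %#x (IFF_UP: %v)\n", req.Flags, req.Flags&unix.IFF_UP != 0)
}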
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index 61b2c9755..608474fa1 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -25,6 +25,7 @@ go_library(
"//pkg/log",
"//pkg/marshal",
"//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/tcpip/header",
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
index 6fc7781ad..3f1b4a17b 100644
--- a/pkg/sentry/socket/netfilter/extensions.go
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -19,20 +19,12 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
-// TODO(gvisor.dev/issue/170): The following per-matcher params should be
-// supported:
-// - Table name
-// - Match size
-// - User size
-// - Hooks
-// - Proto
-// - Family
-
// matchMaker knows how to (un)marshal the matcher named name().
type matchMaker interface {
// name is the matcher name as stored in the xt_entry_match struct.
@@ -43,7 +35,7 @@ type matchMaker interface {
// unmarshal converts from the ABI matcher struct to an
// stack.Matcher.
- unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error)
+ unmarshal(task *kernel.Task, buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error)
}
type matcher interface {
@@ -94,12 +86,12 @@ func marshalEntryMatch(name string, data []byte) []byte {
return buf
}
-func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) {
+func unmarshalMatcher(task *kernel.Task, match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) {
matchMaker, ok := matchMakers[match.Name.String()]
if !ok {
return nil, fmt.Errorf("unsupported matcher with name %q", match.Name.String())
}
- return matchMaker.unmarshal(buf, filter)
+ return matchMaker.unmarshal(task, buf, filter)
}
// targetMaker knows how to (un)marshal a target. Once registered,
diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go
index cb78ef60b..d8bd86292 100644
--- a/pkg/sentry/socket/netfilter/ipv4.go
+++ b/pkg/sentry/socket/netfilter/ipv4.go
@@ -18,6 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -123,7 +124,7 @@ func getEntries4(table stack.Table, tablename linux.TableName) (linux.KernelIPTG
return entries, info
}
-func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) {
+func modifyEntries4(task *kernel.Task, stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) {
nflog("set entries: setting entries in table %q", replace.Name.String())
// Convert input into a list of rules and their offsets.
@@ -148,23 +149,19 @@ func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace,
return nil, syserr.ErrInvalidArgument
}
- // TODO(gvisor.dev/issue/170): We should support more IPTIP
- // filtering fields.
filter, err := filterFromIPTIP(entry.IP)
if err != nil {
nflog("bad iptip: %v", err)
return nil, syserr.ErrInvalidArgument
}
- // TODO(gvisor.dev/issue/170): Matchers and targets can specify
- // that they only work for certain protocols, hooks, tables.
// Get matchers.
matchersSize := entry.TargetOffset - linux.SizeOfIPTEntry
if len(optVal) < int(matchersSize) {
nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal))
return nil, syserr.ErrInvalidArgument
}
- matchers, err := parseMatchers(filter, optVal[:matchersSize])
+ matchers, err := parseMatchers(task, filter, optVal[:matchersSize])
if err != nil {
nflog("failed to parse matchers: %v", err)
return nil, syserr.ErrInvalidArgument
diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go
index 5cb7fe4aa..c68230847 100644
--- a/pkg/sentry/socket/netfilter/ipv6.go
+++ b/pkg/sentry/socket/netfilter/ipv6.go
@@ -18,6 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -126,7 +127,7 @@ func getEntries6(table stack.Table, tablename linux.TableName) (linux.KernelIP6T
return entries, info
}
-func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) {
+func modifyEntries6(task *kernel.Task, stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) {
nflog("set entries: setting entries in table %q", replace.Name.String())
// Convert input into a list of rules and their offsets.
@@ -151,23 +152,19 @@ func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace,
return nil, syserr.ErrInvalidArgument
}
- // TODO(gvisor.dev/issue/170): We should support more IPTIP
- // filtering fields.
filter, err := filterFromIP6TIP(entry.IPv6)
if err != nil {
nflog("bad iptip: %v", err)
return nil, syserr.ErrInvalidArgument
}
- // TODO(gvisor.dev/issue/170): Matchers and targets can specify
- // that they only work for certain protocols, hooks, tables.
// Get matchers.
matchersSize := entry.TargetOffset - linux.SizeOfIP6TEntry
if len(optVal) < int(matchersSize) {
nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal))
return nil, syserr.ErrInvalidArgument
}
- matchers, err := parseMatchers(filter, optVal[:matchersSize])
+ matchers, err := parseMatchers(task, filter, optVal[:matchersSize])
if err != nil {
nflog("failed to parse matchers: %v", err)
return nil, syserr.ErrInvalidArgument
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index f42d73178..e3eade180 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -58,8 +58,8 @@ var nameToID = map[string]stack.TableID{
// DefaultLinuxTables returns the rules of stack.DefaultTables() wrapped for
// compatibility with netfilter extensions.
-func DefaultLinuxTables() *stack.IPTables {
- tables := stack.DefaultTables()
+func DefaultLinuxTables(seed uint32) *stack.IPTables {
+ tables := stack.DefaultTables(seed)
tables.VisitTargets(func(oldTarget stack.Target) stack.Target {
switch val := oldTarget.(type) {
case *stack.AcceptTarget:
@@ -174,13 +174,12 @@ func setHooksAndUnderflow(info *linux.IPTGetinfo, table stack.Table, offset uint
// SetEntries sets iptables rules for a single table. See
// net/ipv4/netfilter/ip_tables.c:translate_table for reference.
-func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
+func SetEntries(task *kernel.Task, stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
var replace linux.IPTReplace
replaceBuf := optVal[:linux.SizeOfIPTReplace]
optVal = optVal[linux.SizeOfIPTReplace:]
replace.UnmarshalBytes(replaceBuf)
- // TODO(gvisor.dev/issue/170): Support other tables.
var table stack.Table
switch replace.Name.String() {
case filterTable:
@@ -188,16 +187,16 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
case natTable:
table = stack.EmptyNATTable()
default:
- nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
+ nflog("unknown iptables table %q", replace.Name.String())
return syserr.ErrInvalidArgument
}
var err *syserr.Error
var offsets map[uint32]int
if ipv6 {
- offsets, err = modifyEntries6(stk, optVal, &replace, &table)
+ offsets, err = modifyEntries6(task, stk, optVal, &replace, &table)
} else {
- offsets, err = modifyEntries4(stk, optVal, &replace, &table)
+ offsets, err = modifyEntries4(task, stk, optVal, &replace, &table)
}
if err != nil {
return err
@@ -272,7 +271,6 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
table.Rules[ruleIdx] = rule
}
- // TODO(gvisor.dev/issue/170): Support other chains.
// Since we don't support FORWARD, yet, make sure all other chains point to
// ACCEPT rules.
for hook, ruleIdx := range table.BuiltinChains {
@@ -287,7 +285,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
}
}
- // TODO(gvisor.dev/issue/170): Check the following conditions:
+ // TODO(gvisor.dev/issue/6167): Check the following conditions:
// - There are no loops.
// - There are no chains without an unconditional final rule.
// - There are no chains without an unconditional underflow rule.
@@ -297,7 +295,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
// parseMatchers parses 0 or more matchers from optVal. optVal should contain
// only the matchers.
-func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, error) {
+func parseMatchers(task *kernel.Task, filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, error) {
nflog("set entries: parsing matchers of size %d", len(optVal))
var matchers []stack.Matcher
for len(optVal) > 0 {
@@ -321,13 +319,13 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher,
}
// Parse the specific matcher.
- matcher, err := unmarshalMatcher(match, filter, optVal[linux.SizeOfXTEntryMatch:match.MatchSize])
+ matcher, err := unmarshalMatcher(task, match, filter, optVal[linux.SizeOfXTEntryMatch:match.MatchSize])
if err != nil {
return nil, fmt.Errorf("failed to create matcher: %v", err)
}
matchers = append(matchers, matcher)
- // TODO(gvisor.dev/issue/170): Check the revision field.
+ // TODO(gvisor.dev/issue/6167): Check the revision field.
optVal = optVal[match.MatchSize:]
}
diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go
index 60845cab3..6eff2ae65 100644
--- a/pkg/sentry/socket/netfilter/owner_matcher.go
+++ b/pkg/sentry/socket/netfilter/owner_matcher.go
@@ -19,6 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -40,8 +42,8 @@ func (ownerMarshaler) name() string {
func (ownerMarshaler) marshal(mr matcher) []byte {
matcher := mr.(*OwnerMatcher)
iptOwnerInfo := linux.IPTOwnerInfo{
- UID: matcher.uid,
- GID: matcher.gid,
+ UID: uint32(matcher.uid),
+ GID: uint32(matcher.gid),
}
// Support for UID and GID match.
@@ -63,7 +65,7 @@ func (ownerMarshaler) marshal(mr matcher) []byte {
}
// unmarshal implements matchMaker.unmarshal.
-func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
+func (ownerMarshaler) unmarshal(task *kernel.Task, buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
if len(buf) < linux.SizeOfIPTOwnerInfo {
return nil, fmt.Errorf("buf has insufficient size for owner match: %d", len(buf))
}
@@ -72,11 +74,12 @@ func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.
// exceed what's strictly necessary to hold matchData.
var matchData linux.IPTOwnerInfo
matchData.UnmarshalUnsafe(buf[:linux.SizeOfIPTOwnerInfo])
- nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData)
+ nflog("parsed IPTOwnerInfo: %+v", matchData)
var owner OwnerMatcher
- owner.uid = matchData.UID
- owner.gid = matchData.GID
+ creds := task.Credentials()
+ owner.uid = creds.UserNamespace.MapToKUID(auth.UID(matchData.UID))
+ owner.gid = creds.UserNamespace.MapToKGID(auth.GID(matchData.GID))
// Check flags.
if matchData.Match&linux.XT_OWNER_UID != 0 {
@@ -97,8 +100,8 @@ func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.
// OwnerMatcher matches against a UID and/or GID.
type OwnerMatcher struct {
- uid uint32
- gid uint32
+ uid auth.KUID
+ gid auth.KGID
matchUID bool
matchGID bool
invertUID bool
@@ -113,7 +116,6 @@ func (*OwnerMatcher) name() string {
// Match implements Matcher.Match.
func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) {
// Support only for OUTPUT chain.
- // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also.
if hook != stack.Output {
return false, true
}
@@ -126,7 +128,7 @@ func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ str
var matches bool
// Check for UID match.
if om.matchUID {
- if pkt.Owner.UID() == om.uid {
+ if auth.KUID(pkt.Owner.KUID()) == om.uid {
matches = true
}
if matches == om.invertUID {
@@ -137,7 +139,7 @@ func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ str
// Check for GID match.
if om.matchGID {
matches = false
- if pkt.Owner.GID() == om.gid {
+ if auth.KGID(pkt.Owner.KGID()) == om.gid {
matches = true
}
if matches == om.invertGID {
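The owner matcher now stores KUIDs/KGIDs, so the UID/GID carried in the iptables rule is interpreted in the user namespace of the task that installed it. A minimal sketch of that mapping, assuming the installing task's credentials are at hand:

package ownersketch

import (
	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)

// mapOwner translates the raw UID/GID carried in IPTOwnerInfo into
// namespace-aware kernel IDs, as the unmarshal above now does with the
// installing task's credentials.
func mapOwner(creds *auth.Credentials, uid, gid uint32) (auth.KUID, auth.KGID) {
	return creds.UserNamespace.MapToKUID(auth.UID(uid)),
		creds.UserNamespace.MapToKGID(auth.GID(gid))
}

Matching against pkt.Owner.KUID()/KGID() then behaves consistently regardless of which user namespace installed the rule.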
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index fa5456eee..ea56f39c1 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -331,7 +331,6 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
return nil, syserr.ErrInvalidArgument
}
- // TODO(gvisor.dev/issue/170): Check if the flags are valid.
// Also check if we need to map ports or IP.
// For now, redirect target only supports destination port change.
// Port range and IP range are not supported yet.
@@ -340,7 +339,6 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
return nil, syserr.ErrInvalidArgument
}
- // TODO(gvisor.dev/issue/170): Port range is not supported yet.
if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
nflog("redirectTargetMaker: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
return nil, syserr.ErrInvalidArgument
@@ -420,7 +418,6 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar
return nil, syserr.ErrInvalidArgument
}
- // TODO(gvisor.dev/issue/3549): Check for other flags.
// For now, redirect target only supports destination change.
if natRange.Flags != linux.NF_NAT_RANGE_PROTO_SPECIFIED {
nflog("nfNATTargetMaker: invalid range flags %d", natRange.Flags)
@@ -502,7 +499,6 @@ func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta
return nil, syserr.ErrInvalidArgument
}
- // TODO(gvisor.dev/issue/170): Port range is not supported yet.
if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
nflog("snatTargetMakerV4: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
return nil, syserr.ErrInvalidArgument
@@ -594,7 +590,6 @@ func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta
// translateToStandardTarget translates from the value in a
// linux.XTStandardTarget to an stack.Verdict.
func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (target, *syserr.Error) {
- // TODO(gvisor.dev/issue/170): Support other verdicts.
switch val {
case -linux.NF_ACCEPT - 1:
return &acceptTarget{stack.AcceptTarget{
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index 95bb9826e..e5b73a976 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -50,7 +51,7 @@ func (tcpMarshaler) marshal(mr matcher) []byte {
}
// unmarshal implements matchMaker.unmarshal.
-func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
+func (tcpMarshaler) unmarshal(_ *kernel.Task, buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
if len(buf) < linux.SizeOfXTTCP {
return nil, fmt.Errorf("buf has insufficient size for TCP match: %d", len(buf))
}
@@ -95,8 +96,6 @@ func (*TCPMatcher) name() string {
// Match implements Matcher.Match.
func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) {
- // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
- // into the stack.Check codepath as matchers are added.
switch pkt.NetworkProtocolNumber {
case header.IPv4ProtocolNumber:
netHeader := header.IPv4(pkt.NetworkHeader().View())
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index fb8be27e6..aa72ee70c 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -50,7 +51,7 @@ func (udpMarshaler) marshal(mr matcher) []byte {
}
// unmarshal implements matchMaker.unmarshal.
-func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
+func (udpMarshaler) unmarshal(_ *kernel.Task, buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
if len(buf) < linux.SizeOfXTUDP {
return nil, fmt.Errorf("buf has insufficient size for UDP match: %d", len(buf))
}
@@ -92,8 +93,6 @@ func (*UDPMatcher) name() string {
// Match implements Matcher.Match.
func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) {
- // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
- // into the stack.Check codepath as matchers are added.
switch pkt.NetworkProtocolNumber {
case header.IPv4ProtocolNumber:
netHeader := header.IPv4(pkt.NetworkHeader().View())
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index 64cd263da..ed85404da 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -14,8 +14,10 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/abi/linux/errno",
"//pkg/bits",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/marshal",
"//pkg/marshal/primitive",
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 280563d09..d53f23a9a 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -20,7 +20,9 @@ import (
"math"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/abi/linux/errno"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
@@ -56,7 +58,7 @@ const (
maxSendBufferSize = 4 << 20 // 4MB
)
-var errNoFilter = syserr.New("no filter attached", linux.ENOENT)
+var errNoFilter = syserr.New("no filter attached", errno.ENOENT)
// netlinkSocketDevice is the netlink socket virtual device.
var netlinkSocketDevice = device.NewAnonDevice()
@@ -558,7 +560,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
- if err == syserror.ETIMEDOUT {
+ if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
}
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index 9561b7c25..e828982eb 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -19,7 +19,9 @@ go_library(
],
deps = [
"//pkg/abi/linux",
+ "//pkg/abi/linux/errno",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/log",
"//pkg/marshal",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 0b64a24c3..11f75628c 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -36,7 +36,9 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/abi/linux/errno"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/marshal"
@@ -77,41 +79,59 @@ func mustCreateGauge(name, description string) *tcpip.StatCounter {
// Metrics contains metrics exported by netstack.
var Metrics = tcpip.Stats{
- UnknownProtocolRcvdPackets: mustCreateMetric("/netstack/unknown_protocol_received_packets", "Number of packets received by netstack that were for an unknown or unsupported protocol."),
- MalformedRcvdPackets: mustCreateMetric("/netstack/malformed_received_packets", "Number of packets received by netstack that were deemed malformed."),
- DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped by netstack due to full queues."),
+ DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped at the transport layer."),
+ NICs: tcpip.NICStats{
+ UnknownL3ProtocolRcvdPackets: mustCreateMetric("/netstack/nic/unknown_l3_protocol_received_packets", "Number of packets received that were for an unknown or unsupported L3 protocol."),
+ UnknownL4ProtocolRcvdPackets: mustCreateMetric("/netstack/nic/unknown_l4_protocol_received_packets", "Number of packets received that were for an unknown or unsupported L4 protocol."),
+ MalformedL4RcvdPackets: mustCreateMetric("/netstack/nic/malformed_l4_received_packets", "Number of packets received that failed L4 header parsing."),
+ Tx: tcpip.NICPacketStats{
+ Packets: mustCreateMetric("/netstack/nic/tx/packets", "Number of packets transmitted."),
+ Bytes: mustCreateMetric("/netstack/nic/tx/bytes", "Number of bytes transmitted."),
+ },
+ Rx: tcpip.NICPacketStats{
+ Packets: mustCreateMetric("/netstack/nic/rx/packets", "Number of packets received."),
+ Bytes: mustCreateMetric("/netstack/nic/rx/bytes", "Number of bytes received."),
+ },
+ DisabledRx: tcpip.NICPacketStats{
+ Packets: mustCreateMetric("/netstack/nic/disabled_rx/packets", "Number of packets received on disabled NICs."),
+ Bytes: mustCreateMetric("/netstack/nic/disabled_rx/bytes", "Number of bytes received on disabled NICs."),
+ },
+ Neighbor: tcpip.NICNeighborStats{
+ UnreachableEntryLookups: mustCreateMetric("/netstack/nic/neighbor/unreachable_entry_loopups", "Number of lookups performed on a neighbor entry in Unreachable state."),
+ },
+ },
ICMP: tcpip.ICMPStats{
V4: tcpip.ICMPv4Stats{
PacketsSent: tcpip.ICMPv4SentPacketStats{
ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
- EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_request", "Number of ICMPv4 echo request packets sent by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Number of ICMPv4 echo reply packets sent by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Number of ICMPv4 destination unreachable packets sent by netstack."),
- SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Number of ICMPv4 source quench packets sent by netstack."),
- Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Number of ICMPv4 redirect packets sent by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Number of ICMPv4 time exceeded packets sent by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Number of ICMPv4 parameter problem packets sent by netstack."),
- Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Number of ICMPv4 timestamp packets sent by netstack."),
- TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Number of ICMPv4 timestamp reply packets sent by netstack."),
- InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Number of ICMPv4 information request packets sent by netstack."),
- InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Number of ICMPv4 information reply packets sent by netstack."),
+ EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_request", "Number of ICMPv4 echo request packets sent."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Number of ICMPv4 echo reply packets sent."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Number of ICMPv4 destination unreachable packets sent."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Number of ICMPv4 source quench packets sent."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Number of ICMPv4 redirect packets sent."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Number of ICMPv4 time exceeded packets sent."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Number of ICMPv4 parameter problem packets sent."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Number of ICMPv4 timestamp packets sent."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Number of ICMPv4 timestamp reply packets sent."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Number of ICMPv4 information request packets sent."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Number of ICMPv4 information reply packets sent."),
},
- Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Number of ICMPv4 packets dropped by netstack due to link layer errors."),
- RateLimited: mustCreateMetric("/netstack/icmp/v4/packets_sent/rate_limited", "Number of ICMPv4 packets dropped by netstack due to rate limit being exceeded."),
+ Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Number of ICMPv4 packets dropped due to link layer errors."),
+ RateLimited: mustCreateMetric("/netstack/icmp/v4/packets_sent/rate_limited", "Number of ICMPv4 packets dropped due to rate limit being exceeded."),
},
PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{
ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
- EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_request", "Number of ICMPv4 echo request packets received by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Number of ICMPv4 echo reply packets received by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Number of ICMPv4 destination unreachable packets received by netstack."),
- SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Number of ICMPv4 source quench packets received by netstack."),
- Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Number of ICMPv4 redirect packets received by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Number of ICMPv4 time exceeded packets received by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Number of ICMPv4 parameter problem packets received by netstack."),
- Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Number of ICMPv4 timestamp packets received by netstack."),
- TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Number of ICMPv4 timestamp reply packets received by netstack."),
- InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Number of ICMPv4 information request packets received by netstack."),
- InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Number of ICMPv4 information reply packets received by netstack."),
+ EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_request", "Number of ICMPv4 echo request packets received."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Number of ICMPv4 echo reply packets received."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Number of ICMPv4 destination unreachable packets received."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Number of ICMPv4 source quench packets received."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Number of ICMPv4 redirect packets received."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Number of ICMPv4 time exceeded packets received."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Number of ICMPv4 parameter problem packets received."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Number of ICMPv4 timestamp packets received."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Number of ICMPv4 timestamp reply packets received."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Number of ICMPv4 information request packets received."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Number of ICMPv4 information reply packets received."),
},
Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Number of ICMPv4 packets received that the transport layer could not parse."),
},
@@ -119,40 +139,40 @@ var Metrics = tcpip.Stats{
V6: tcpip.ICMPv6Stats{
PacketsSent: tcpip.ICMPv6SentPacketStats{
ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
- EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Number of ICMPv6 echo request packets sent by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Number of ICMPv6 echo reply packets sent by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Number of ICMPv6 destination unreachable packets sent by netstack."),
- PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Number of ICMPv6 packet too big packets sent by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Number of ICMPv6 time exceeded packets sent by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Number of ICMPv6 parameter problem packets sent by netstack."),
- RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Number of ICMPv6 router solicit packets sent by netstack."),
- RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Number of ICMPv6 router advert packets sent by netstack."),
- NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets sent by netstack."),
- NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Number of ICMPv6 neighbor advert packets sent by netstack."),
- RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Number of ICMPv6 redirect message packets sent by netstack."),
- MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_query", "Number of ICMPv6 multicast listener query packets sent by netstack."),
- MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent by netstack."),
- MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent by netstack."),
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Number of ICMPv6 echo request packets sent."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Number of ICMPv6 echo reply packets sent."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Number of ICMPv6 destination unreachable packets sent."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Number of ICMPv6 packet too big packets sent."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Number of ICMPv6 time exceeded packets sent."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Number of ICMPv6 parameter problem packets sent."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Number of ICMPv6 router solicit packets sent."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Number of ICMPv6 router advert packets sent."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets sent."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Number of ICMPv6 neighbor advert packets sent."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Number of ICMPv6 redirect message packets sent."),
+ MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_query", "Number of ICMPv6 multicast listener query packets sent."),
+ MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent."),
+ MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent."),
},
- Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Number of ICMPv6 packets dropped by netstack due to link layer errors."),
- RateLimited: mustCreateMetric("/netstack/icmp/v6/packets_sent/rate_limited", "Number of ICMPv6 packets dropped by netstack due to rate limit being exceeded."),
+ Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Number of ICMPv6 packets dropped due to link layer errors."),
+ RateLimited: mustCreateMetric("/netstack/icmp/v6/packets_sent/rate_limited", "Number of ICMPv6 packets dropped due to rate limit being exceeded."),
},
PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{
ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
- EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Number of ICMPv6 echo request packets received by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Number of ICMPv6 echo reply packets received by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Number of ICMPv6 destination unreachable packets received by netstack."),
- PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Number of ICMPv6 packet too big packets received by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Number of ICMPv6 time exceeded packets received by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Number of ICMPv6 parameter problem packets received by netstack."),
- RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Number of ICMPv6 router solicit packets received by netstack."),
- RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Number of ICMPv6 router advert packets received by netstack."),
- NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets received by netstack."),
- NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Number of ICMPv6 neighbor advert packets received by netstack."),
- RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Number of ICMPv6 redirect message packets received by netstack."),
- MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_query", "Number of ICMPv6 multicast listener query packets received by netstack."),
- MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent by netstack."),
- MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent by netstack."),
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Number of ICMPv6 echo request packets received."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Number of ICMPv6 echo reply packets received."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Number of ICMPv6 destination unreachable packets received."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Number of ICMPv6 packet too big packets received."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Number of ICMPv6 time exceeded packets received."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Number of ICMPv6 parameter problem packets received."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Number of ICMPv6 router solicit packets received."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Number of ICMPv6 router advert packets received."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets received."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Number of ICMPv6 neighbor advert packets received."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Number of ICMPv6 redirect message packets received."),
+ MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_query", "Number of ICMPv6 multicast listener query packets received."),
+ MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_report", "Number of ICMPv6 multicast listener report packets received."),
+ MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_done", "Number of ICMPv6 multicast listener done packets received."),
},
Unrecognized: mustCreateMetric("/netstack/icmp/v6/packets_received/unrecognized", "Number of ICMPv6 packets received that the transport layer does not know how to parse."),
Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Number of ICMPv6 packets received that the transport layer could not parse."),
@@ -163,23 +183,23 @@ var Metrics = tcpip.Stats{
IGMP: tcpip.IGMPStats{
PacketsSent: tcpip.IGMPSentPacketStats{
IGMPPacketStats: tcpip.IGMPPacketStats{
- MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Number of IGMP Membership Query messages sent by netstack."),
- V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Number of IGMPv1 Membership Report messages sent by netstack."),
- V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Number of IGMPv2 Membership Report messages sent by netstack."),
- LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Number of IGMP Leave Group messages sent by netstack."),
+ MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Number of IGMP Membership Query messages sent."),
+ V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Number of IGMPv1 Membership Report messages sent."),
+ V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Number of IGMPv2 Membership Report messages sent."),
+ LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Number of IGMP Leave Group messages sent."),
},
- Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Number of IGMP packets dropped by netstack due to link layer errors."),
+ Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Number of IGMP packets dropped due to link layer errors."),
},
PacketsReceived: tcpip.IGMPReceivedPacketStats{
IGMPPacketStats: tcpip.IGMPPacketStats{
- MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Number of IGMP Membership Query messages received by netstack."),
- V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Number of IGMPv1 Membership Report messages received by netstack."),
- V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Number of IGMPv2 Membership Report messages received by netstack."),
- LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Number of IGMP Leave Group messages received by netstack."),
+ MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Number of IGMP Membership Query messages received."),
+ V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Number of IGMPv1 Membership Report messages received."),
+ V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Number of IGMPv2 Membership Report messages received."),
+ LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Number of IGMP Leave Group messages received."),
},
- Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Number of IGMP packets received by netstack that could not be parsed."),
+ Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Number of IGMP packets received that could not be parsed."),
ChecksumErrors: mustCreateMetric("/netstack/igmp/packets_received/checksum_errors", "Number of received IGMP packets with bad checksums."),
- Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Number of unrecognized IGMP packets received by netstack."),
+ Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Number of unrecognized IGMP packets received."),
},
},
IP: tcpip.IPStats{
@@ -205,7 +225,8 @@ var Metrics = tcpip.Stats{
LinkLocalSource: mustCreateMetric("/netstack/ip/forwarding/link_local_source_address", "Number of IP packets received which could not be forwarded due to a link-local source address."),
LinkLocalDestination: mustCreateMetric("/netstack/ip/forwarding/link_local_destination_address", "Number of IP packets received which could not be forwarded due to a link-local destination address."),
ExtensionHeaderProblem: mustCreateMetric("/netstack/ip/forwarding/extension_header_problem", "Number of IP packets received which could not be forwarded due to a problem processing their IPv6 extension headers."),
- PacketTooBig: mustCreateMetric("/netstack/ip/forwarding/packet_too_big", "Number of IP packets received which could not fit within the outgoing MTU."),
+ PacketTooBig: mustCreateMetric("/netstack/ip/forwarding/packet_too_big", "Number of IP packets received which could not be forwarded because they could not fit within the outgoing MTU."),
+ HostUnreachable: mustCreateMetric("/netstack/ip/forwarding/host_unreachable", "Number of IP packets received which could not be forwarded due to unresolvable next hop."),
Errors: mustCreateMetric("/netstack/ip/forwarding/errors", "Number of IP packets which couldn't be forwarded."),
},
},
@@ -270,7 +291,7 @@ const DefaultTTL = 64
const sizeOfInt32 int = 4
-var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL)
+var errStackType = syserr.New("expected but did not receive a netstack.Stack", errno.EINVAL)
// commonEndpoint represents the intersection of a tcpip.Endpoint and a
// transport.Endpoint.
@@ -1126,7 +1147,14 @@ func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name,
// TODO(b/64800844): Translate fields once they are added to
// tcpip.TCPInfoOption.
- info := linux.TCPInfo{}
+ info := linux.TCPInfo{
+ State: uint8(v.State),
+ RTO: uint32(v.RTO / time.Microsecond),
+ RTT: uint32(v.RTT / time.Microsecond),
+ RTTVar: uint32(v.RTTVar / time.Microsecond),
+ SndSsthresh: v.SndSsthresh,
+ SndCwnd: v.SndCwnd,
+ }
switch v.CcState {
case tcpip.RTORecovery:
info.CaState = linux.TCP_CA_Loss
@@ -1137,11 +1165,6 @@ func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name,
case tcpip.Open:
info.CaState = linux.TCP_CA_Open
}
- info.RTO = uint32(v.RTO / time.Microsecond)
- info.RTT = uint32(v.RTT / time.Microsecond)
- info.RTTVar = uint32(v.RTTVar / time.Microsecond)
- info.SndSsthresh = v.SndSsthresh
- info.SndCwnd = v.SndCwnd
// In netstack reorderSeen is updated only when RACK is enabled.
// We only track whether the reordering is seen, which is
@@ -1777,11 +1800,6 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := hostarch.ByteOrder.Uint32(optVal)
-
- if v == 0 {
- socket.SetSockOptEmitUnimplementedEvent(t, name)
- }
-
ep.SocketOptions().SetOutOfBandInline(v != 0)
return nil
@@ -2089,10 +2107,10 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return syserr.ErrNoDevice
}
// Stack must be a netstack stack.
- return netfilter.SetEntries(stack.(*Stack).Stack, optVal, true)
+ return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, true)
case linux.IP6T_SO_SET_ADD_COUNTERS:
- // TODO(gvisor.dev/issue/170): Counter support.
+ log.Infof("IP6T_SO_SET_ADD_COUNTERS is not supported")
return nil
default:
@@ -2332,10 +2350,10 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return syserr.ErrNoDevice
}
// Stack must be a netstack stack.
- return netfilter.SetEntries(stack.(*Stack).Stack, optVal, false)
+ return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, false)
case linux.IPT_SO_SET_ADD_COUNTERS:
- // TODO(gvisor.dev/issue/170): Counter support.
+ log.Infof("IPT_SO_SET_ADD_COUNTERS is not supported")
return nil
case linux.IP_ADD_SOURCE_MEMBERSHIP,
@@ -2792,7 +2810,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if n > 0 {
return n, msgFlags, senderAddr, senderAddrLen, controlMessages, nil
}
- if err == syserror.ETIMEDOUT {
+ if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
}
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
@@ -2860,7 +2878,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
// became available between when we last checked and when we setup
// the notification.
if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
- if err == syserror.ETIMEDOUT {
+ if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
return int(total), syserr.ErrTryAgain
}
// handleIOError will consume errors from t.Block if needed.
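Editor's note: earlier in this file the TCP_INFO handler now fills linux.TCPInfo in a single composite literal instead of assigning the fields after the CaState switch; the durations still have to be scaled to microseconds by hand. A small sketch of that conversion with a trimmed stand-in struct (field names follow the diff; the type itself is simplified and not the real linux.TCPInfo).

package main

import (
	"fmt"
	"time"
)

// tcpInfo is a trimmed stand-in for linux.TCPInfo.
type tcpInfo struct {
	State                uint8
	RTO, RTT, RTTVar     uint32 // microseconds, as Linux reports them
	SndSsthresh, SndCwnd uint32
}

func makeTCPInfo(state uint8, rto, rtt, rttVar time.Duration, ssthresh, cwnd uint32) tcpInfo {
	// Populate everything up front; time.Duration is in nanoseconds, so divide
	// by time.Microsecond to get the units Linux expects.
	return tcpInfo{
		State:       state,
		RTO:         uint32(rto / time.Microsecond),
		RTT:         uint32(rtt / time.Microsecond),
		RTTVar:      uint32(rttVar / time.Microsecond),
		SndSsthresh: ssthresh,
		SndCwnd:     cwnd,
	}
}

func main() {
	info := makeTCPInfo(1, 200*time.Millisecond, 30*time.Millisecond, 5*time.Millisecond, 64, 10)
	fmt.Printf("%+v\n", info) // RTO:200000 RTT:30000 RTTVar:5000 ...
}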
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index eef5e6519..9d343b671 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -18,6 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/syserr"
@@ -110,19 +111,19 @@ func convertAddr(addr inet.InterfaceAddr) (tcpip.ProtocolAddress, error) {
switch addr.Family {
case linux.AF_INET:
if len(addr.Addr) != header.IPv4AddressSize {
- return protocolAddress, syserror.EINVAL
+ return protocolAddress, linuxerr.EINVAL
}
if addr.PrefixLen > header.IPv4AddressSize*8 {
- return protocolAddress, syserror.EINVAL
+ return protocolAddress, linuxerr.EINVAL
}
protocol = ipv4.ProtocolNumber
address = tcpip.Address(addr.Addr)
case linux.AF_INET6:
if len(addr.Addr) != header.IPv6AddressSize {
- return protocolAddress, syserror.EINVAL
+ return protocolAddress, linuxerr.EINVAL
}
if addr.PrefixLen > header.IPv6AddressSize*8 {
- return protocolAddress, syserror.EINVAL
+ return protocolAddress, linuxerr.EINVAL
}
protocol = ipv6.ProtocolNumber
address = tcpip.Address(addr.Addr)
diff --git a/pkg/sentry/socket/netstack/tun.go b/pkg/sentry/socket/netstack/tun.go
index 288dd0c9e..e67fe9700 100644
--- a/pkg/sentry/socket/netstack/tun.go
+++ b/pkg/sentry/socket/netstack/tun.go
@@ -16,7 +16,7 @@ package netstack
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/tcpip/link/tun"
)
@@ -40,8 +40,8 @@ func LinuxToTUNFlags(flags uint16) (tun.Flags, error) {
// Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately)
// when there is no sk_filter. See __tun_chr_ioctl() in
// drivers/net/tun.c.
- if flags&^uint16(linux.IFF_TUN|linux.IFF_TAP|linux.IFF_NO_PI) != 0 {
- return tun.Flags{}, syserror.EINVAL
+ if flags&^uint16(linux.IFF_TUN|linux.IFF_TAP|linux.IFF_NO_PI|linux.IFF_ONE_QUEUE) != 0 {
+ return tun.Flags{}, linuxerr.EINVAL
}
return tun.Flags{
TUN: flags&linux.IFF_TUN != 0,
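Editor's note: LinuxToTUNFlags now also tolerates IFF_ONE_QUEUE in the ioctl flags (a legacy flag that modern Linux effectively ignores) and returns linuxerr.EINVAL for anything else. The check is just "clear the accepted bits and see whether anything is left". A standalone sketch with the standard Linux flag values (0x0001, 0x0002, 0x1000, 0x2000); the helper and type names are invented.

package main

import (
	"errors"
	"fmt"
)

// Standard Linux TUN/TAP ioctl flag values.
const (
	iffTun      = 0x0001
	iffTap      = 0x0002
	iffNoPI     = 0x1000
	iffOneQueue = 0x2000
)

var errInvalid = errors.New("EINVAL")

type tunFlags struct{ TUN, TAP, NoPacketInfo bool }

func linuxToTUNFlags(flags uint16) (tunFlags, error) {
	// Any bit outside the accepted set is rejected; IFF_ONE_QUEUE is
	// accepted but not reflected in the returned flags.
	if flags&^uint16(iffTun|iffTap|iffNoPI|iffOneQueue) != 0 {
		return tunFlags{}, errInvalid
	}
	return tunFlags{
		TUN:          flags&iffTun != 0,
		TAP:          flags&iffTap != 0,
		NoPacketInfo: flags&iffNoPI != 0,
	}, nil
}

func main() {
	fmt.Println(linuxToTUNFlags(iffTap | iffNoPI | iffOneQueue)) // accepted
	fmt.Println(linuxToTUNFlags(0x4000))                         // EINVAL
}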
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 353f4ade0..f5da3c509 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -659,7 +659,6 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32)
return &out, uint32(sockAddrInet6Size)
case linux.AF_PACKET:
- // TODO(gvisor.dev/issue/173): Return protocol too.
var out linux.SockAddrLink
out.Family = linux.AF_PACKET
out.InterfaceIndex = int32(addr.NIC)
@@ -749,7 +748,6 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
- // TODO(gvisor.dev/issue/173): Return protocol too.
return tcpip.FullAddress{
NIC: tcpip.NICID(a.InterfaceIndex),
Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index c9cbefb3a..5c3cdef6a 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -39,6 +39,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fspath",
"//pkg/hostarch",
"//pkg/log",
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index db7b1affe..8ccdadae9 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -23,6 +23,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
@@ -518,7 +519,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
}
if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
- if err == syserror.ETIMEDOUT {
+ if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
err = syserror.ErrWouldBlock
}
break
@@ -719,7 +720,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if total > 0 {
err = nil
}
- if err == syserror.ETIMEDOUT {
+ if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
}
return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index c39e317ff..08a00a12f 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -17,6 +17,7 @@ package unix
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
@@ -236,7 +237,7 @@ func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
Mode: linux.FileMode(linux.S_IFSOCK | uint(stat.Mode)&^t.FSContext().Umask()),
Endpoint: bep,
})
- if err == syserror.EEXIST {
+ if linuxerr.Equals(linuxerr.EEXIST, err) {
return syserr.ErrAddressInUse
}
return syserr.FromError(err)
diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD
index 3e801182c..7f02807c5 100644
--- a/pkg/sentry/state/BUILD
+++ b/pkg/sentry/state/BUILD
@@ -13,6 +13,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/log",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
@@ -20,7 +21,6 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/sentry/watchdog",
"//pkg/state/statefile",
- "//pkg/syserror",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go
index 167754537..e9d544f3d 100644
--- a/pkg/sentry/state/state.go
+++ b/pkg/sentry/state/state.go
@@ -20,6 +20,7 @@ import (
"io"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -27,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sentry/watchdog"
"gvisor.dev/gvisor/pkg/state/statefile"
- "gvisor.dev/gvisor/pkg/syserror"
)
var previousMetadata map[string]string
@@ -88,7 +88,7 @@ func (opts SaveOpts) Save(ctx context.Context, k *kernel.Kernel, w *watchdog.Wat
// ENOSPC is a state file error. This error can only come from
// writing the state file, and not from fs.FileOperations.Fsync
// because we wrap those in kernel.TaskSet.flushWritesToFiles.
- if err == syserror.ENOSPC {
+ if linuxerr.Equals(linuxerr.ENOSPC, err) {
err = ErrStateFile{err}
}
@@ -110,7 +110,7 @@ type LoadOpts struct {
}
// Load loads the given kernel, setting the provided platform and stack.
-func (opts LoadOpts) Load(ctx context.Context, k *kernel.Kernel, n inet.Stack, clocks time.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error {
+func (opts LoadOpts) Load(ctx context.Context, k *kernel.Kernel, timeReady chan struct{}, n inet.Stack, clocks time.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error {
// Open the file.
r, m, err := statefile.NewReader(opts.Source, opts.Key)
if err != nil {
@@ -120,5 +120,5 @@ func (opts LoadOpts) Load(ctx context.Context, k *kernel.Kernel, n inet.Stack, c
previousMetadata = m
// Restore the Kernel object graph.
- return k.LoadFrom(ctx, r, n, clocks, vfsOpts)
+ return k.LoadFrom(ctx, r, timeReady, n, clocks, vfsOpts)
}
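Editor's note: LoadOpts.Load gains a timeReady channel that is threaded straight through to Kernel.LoadFrom; this diff does not show how the channel is consumed, so the sketch below only illustrates the plumbing pattern (the caller creates the channel, signals it once time-keeping state is usable, and the restore path may block on it). All names here are invented stand-ins, not gVisor's real API.

package main

import (
	"fmt"
	"time"
)

// loadFrom is a stand-in for kernel.LoadFrom: it waits until the caller
// signals that time-related state is ready before finishing the restore.
func loadFrom(timeReady <-chan struct{}) error {
	fmt.Println("restoring kernel object graph...")
	<-timeReady // block until the caller closes the channel
	fmt.Println("time state ready, finishing restore")
	return nil
}

// load mirrors LoadOpts.Load: it just forwards the channel.
func load(timeReady chan struct{}) error {
	return loadFrom(timeReady)
}

func main() {
	timeReady := make(chan struct{})
	go func() {
		time.Sleep(10 * time.Millisecond) // pretend host time setup takes a moment
		close(timeReady)
	}()
	if err := load(timeReady); err != nil {
		fmt.Println("restore failed:", err)
	}
}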
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 1fbbd133c..369541c7a 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -11,6 +11,7 @@ go_library(
"futex.go",
"linux64_amd64.go",
"linux64_arm64.go",
+ "mmap.go",
"open.go",
"poll.go",
"ptrace.go",
@@ -35,7 +36,6 @@ go_library(
"//pkg/sentry/socket",
"//pkg/sentry/socket/netlink",
"//pkg/sentry/syscalls/linux",
- "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/strace/clone.go b/pkg/sentry/strace/clone.go
index ab1060426..bfb4d7f5c 100644
--- a/pkg/sentry/strace/clone.go
+++ b/pkg/sentry/strace/clone.go
@@ -15,98 +15,98 @@
package strace
import (
- "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
)
// CloneFlagSet is the set of clone(2) flags.
var CloneFlagSet = abi.FlagSet{
{
- Flag: unix.CLONE_VM,
+ Flag: linux.CLONE_VM,
Name: "CLONE_VM",
},
{
- Flag: unix.CLONE_FS,
+ Flag: linux.CLONE_FS,
Name: "CLONE_FS",
},
{
- Flag: unix.CLONE_FILES,
+ Flag: linux.CLONE_FILES,
Name: "CLONE_FILES",
},
{
- Flag: unix.CLONE_SIGHAND,
+ Flag: linux.CLONE_SIGHAND,
Name: "CLONE_SIGHAND",
},
{
- Flag: unix.CLONE_PTRACE,
+ Flag: linux.CLONE_PTRACE,
Name: "CLONE_PTRACE",
},
{
- Flag: unix.CLONE_VFORK,
+ Flag: linux.CLONE_VFORK,
Name: "CLONE_VFORK",
},
{
- Flag: unix.CLONE_PARENT,
+ Flag: linux.CLONE_PARENT,
Name: "CLONE_PARENT",
},
{
- Flag: unix.CLONE_THREAD,
+ Flag: linux.CLONE_THREAD,
Name: "CLONE_THREAD",
},
{
- Flag: unix.CLONE_NEWNS,
+ Flag: linux.CLONE_NEWNS,
Name: "CLONE_NEWNS",
},
{
- Flag: unix.CLONE_SYSVSEM,
+ Flag: linux.CLONE_SYSVSEM,
Name: "CLONE_SYSVSEM",
},
{
- Flag: unix.CLONE_SETTLS,
+ Flag: linux.CLONE_SETTLS,
Name: "CLONE_SETTLS",
},
{
- Flag: unix.CLONE_PARENT_SETTID,
+ Flag: linux.CLONE_PARENT_SETTID,
Name: "CLONE_PARENT_SETTID",
},
{
- Flag: unix.CLONE_CHILD_CLEARTID,
+ Flag: linux.CLONE_CHILD_CLEARTID,
Name: "CLONE_CHILD_CLEARTID",
},
{
- Flag: unix.CLONE_DETACHED,
+ Flag: linux.CLONE_DETACHED,
Name: "CLONE_DETACHED",
},
{
- Flag: unix.CLONE_UNTRACED,
+ Flag: linux.CLONE_UNTRACED,
Name: "CLONE_UNTRACED",
},
{
- Flag: unix.CLONE_CHILD_SETTID,
+ Flag: linux.CLONE_CHILD_SETTID,
Name: "CLONE_CHILD_SETTID",
},
{
- Flag: unix.CLONE_NEWUTS,
+ Flag: linux.CLONE_NEWUTS,
Name: "CLONE_NEWUTS",
},
{
- Flag: unix.CLONE_NEWIPC,
+ Flag: linux.CLONE_NEWIPC,
Name: "CLONE_NEWIPC",
},
{
- Flag: unix.CLONE_NEWUSER,
+ Flag: linux.CLONE_NEWUSER,
Name: "CLONE_NEWUSER",
},
{
- Flag: unix.CLONE_NEWPID,
+ Flag: linux.CLONE_NEWPID,
Name: "CLONE_NEWPID",
},
{
- Flag: unix.CLONE_NEWNET,
+ Flag: linux.CLONE_NEWNET,
Name: "CLONE_NEWNET",
},
{
- Flag: unix.CLONE_IO,
+ Flag: linux.CLONE_IO,
Name: "CLONE_IO",
},
}
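Editor's note: the strace tables stop importing golang.org/x/sys/unix and read the CLONE_* values from gVisor's own abi/linux package, so the decoder depends only on the guest ABI the sentry implements. The numeric values are the Linux ABI either way; a tiny sketch of the kind of assertion that justifies the swap, assuming a Linux host (x/sys/unix only defines these constants there).

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
	"gvisor.dev/gvisor/pkg/abi/linux"
)

func main() {
	// Both packages spell out the same Linux ABI values.
	fmt.Println(uint64(unix.CLONE_VM) == uint64(linux.CLONE_VM))         // true
	fmt.Println(uint64(unix.CLONE_NEWNET) == uint64(linux.CLONE_NEWNET)) // true
}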
diff --git a/pkg/sentry/strace/linux64_amd64.go b/pkg/sentry/strace/linux64_amd64.go
index d66befe81..6ce1bb592 100644
--- a/pkg/sentry/strace/linux64_amd64.go
+++ b/pkg/sentry/strace/linux64_amd64.go
@@ -33,7 +33,7 @@ var linuxAMD64 = SyscallMap{
6: makeSyscallInfo("lstat", Path, Stat),
7: makeSyscallInfo("poll", PollFDs, Hex, Hex),
8: makeSyscallInfo("lseek", Hex, Hex, Hex),
- 9: makeSyscallInfo("mmap", Hex, Hex, Hex, Hex, FD, Hex),
+ 9: makeSyscallInfo("mmap", Hex, Hex, MmapProt, MmapFlags, FD, Hex),
10: makeSyscallInfo("mprotect", Hex, Hex, Hex),
11: makeSyscallInfo("munmap", Hex, Hex),
12: makeSyscallInfo("brk", Hex),
diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go
index 1a2d7d75f..ce5594301 100644
--- a/pkg/sentry/strace/linux64_arm64.go
+++ b/pkg/sentry/strace/linux64_arm64.go
@@ -246,7 +246,7 @@ var linuxARM64 = SyscallMap{
219: makeSyscallInfo("keyctl", Hex, Hex, Hex, Hex, Hex),
220: makeSyscallInfo("clone", CloneFlags, Hex, Hex, Hex, Hex),
221: makeSyscallInfo("execve", Path, ExecveStringVector, ExecveStringVector),
- 222: makeSyscallInfo("mmap", Hex, Hex, Hex, Hex, FD, Hex),
+ 222: makeSyscallInfo("mmap", Hex, Hex, MmapProt, MmapFlags, FD, Hex),
223: makeSyscallInfo("fadvise64", FD, Hex, Hex, Hex),
224: makeSyscallInfo("swapon", Hex, Hex),
225: makeSyscallInfo("swapoff", Hex),
diff --git a/pkg/sentry/strace/mmap.go b/pkg/sentry/strace/mmap.go
new file mode 100644
index 000000000..0035be586
--- /dev/null
+++ b/pkg/sentry/strace/mmap.go
@@ -0,0 +1,92 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// ProtectionFlagSet represents the protection to mmap(2).
+var ProtectionFlagSet = abi.FlagSet{
+ {
+ Flag: linux.PROT_READ,
+ Name: "PROT_READ",
+ },
+ {
+ Flag: linux.PROT_WRITE,
+ Name: "PROT_WRITE",
+ },
+ {
+ Flag: linux.PROT_EXEC,
+ Name: "PROT_EXEC",
+ },
+}
+
+// MmapFlagSet is the set of mmap(2) flags.
+var MmapFlagSet = abi.FlagSet{
+ {
+ Flag: linux.MAP_SHARED,
+ Name: "MAP_SHARED",
+ },
+ {
+ Flag: linux.MAP_PRIVATE,
+ Name: "MAP_PRIVATE",
+ },
+ {
+ Flag: linux.MAP_FIXED,
+ Name: "MAP_FIXED",
+ },
+ {
+ Flag: linux.MAP_ANONYMOUS,
+ Name: "MAP_ANONYMOUS",
+ },
+ {
+ Flag: linux.MAP_GROWSDOWN,
+ Name: "MAP_GROWSDOWN",
+ },
+ {
+ Flag: linux.MAP_DENYWRITE,
+ Name: "MAP_DENYWRITE",
+ },
+ {
+ Flag: linux.MAP_EXECUTABLE,
+ Name: "MAP_EXECUTABLE",
+ },
+ {
+ Flag: linux.MAP_LOCKED,
+ Name: "MAP_LOCKED",
+ },
+ {
+ Flag: linux.MAP_NORESERVE,
+ Name: "MAP_NORESERVE",
+ },
+ {
+ Flag: linux.MAP_POPULATE,
+ Name: "MAP_POPULATE",
+ },
+ {
+ Flag: linux.MAP_NONBLOCK,
+ Name: "MAP_NONBLOCK",
+ },
+ {
+ Flag: linux.MAP_STACK,
+ Name: "MAP_STACK",
+ },
+ {
+ Flag: linux.MAP_HUGETLB,
+ Name: "MAP_HUGETLB",
+ },
+}
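Editor's note: the new ProtectionFlagSet and MmapFlagSet tables let strace print mmap's third and fourth arguments symbolically instead of as raw hex. abi.FlagSet.Parse walks the table and joins the names of the set bits; below is a self-contained approximation of that behavior using local types rather than the real abi package, so details such as how leftover bits are rendered are assumptions.

package main

import (
	"fmt"
	"strings"
)

// Standard Linux mmap protection bits.
const (
	protRead  = 0x1
	protWrite = 0x2
	protExec  = 0x4
)

type flag struct {
	value uint64
	name  string
}

type flagSet []flag

// parse joins the names of all set bits and appends any leftover bits as hex,
// roughly what abi.FlagSet.Parse does for strace output.
func (fs flagSet) parse(val uint64) string {
	var names []string
	for _, f := range fs {
		if f.value != 0 && val&f.value == f.value {
			names = append(names, f.name)
			val &^= f.value
		}
	}
	if val != 0 || len(names) == 0 {
		names = append(names, fmt.Sprintf("%#x", val))
	}
	return strings.Join(names, "|")
}

var protectionFlagSet = flagSet{
	{protRead, "PROT_READ"},
	{protWrite, "PROT_WRITE"},
	{protExec, "PROT_EXEC"},
}

func main() {
	fmt.Println(protectionFlagSet.parse(protRead | protWrite)) // PROT_READ|PROT_WRITE
	fmt.Println(protectionFlagSet.parse(0))                    // 0x0
}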
diff --git a/pkg/sentry/strace/open.go b/pkg/sentry/strace/open.go
index 5769360da..e7c7649f4 100644
--- a/pkg/sentry/strace/open.go
+++ b/pkg/sentry/strace/open.go
@@ -15,61 +15,61 @@
package strace
import (
- "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
)
// OpenMode represents the mode to open(2) a file.
var OpenMode = abi.ValueSet{
- unix.O_RDWR: "O_RDWR",
- unix.O_WRONLY: "O_WRONLY",
- unix.O_RDONLY: "O_RDONLY",
+ linux.O_RDWR: "O_RDWR",
+ linux.O_WRONLY: "O_WRONLY",
+ linux.O_RDONLY: "O_RDONLY",
}
// OpenFlagSet is the set of open(2) flags.
var OpenFlagSet = abi.FlagSet{
{
- Flag: unix.O_APPEND,
+ Flag: linux.O_APPEND,
Name: "O_APPEND",
},
{
- Flag: unix.O_ASYNC,
+ Flag: linux.O_ASYNC,
Name: "O_ASYNC",
},
{
- Flag: unix.O_CLOEXEC,
+ Flag: linux.O_CLOEXEC,
Name: "O_CLOEXEC",
},
{
- Flag: unix.O_CREAT,
+ Flag: linux.O_CREAT,
Name: "O_CREAT",
},
{
- Flag: unix.O_DIRECT,
+ Flag: linux.O_DIRECT,
Name: "O_DIRECT",
},
{
- Flag: unix.O_DIRECTORY,
+ Flag: linux.O_DIRECTORY,
Name: "O_DIRECTORY",
},
{
- Flag: unix.O_EXCL,
+ Flag: linux.O_EXCL,
Name: "O_EXCL",
},
{
- Flag: unix.O_NOATIME,
+ Flag: linux.O_NOATIME,
Name: "O_NOATIME",
},
{
- Flag: unix.O_NOCTTY,
+ Flag: linux.O_NOCTTY,
Name: "O_NOCTTY",
},
{
- Flag: unix.O_NOFOLLOW,
+ Flag: linux.O_NOFOLLOW,
Name: "O_NOFOLLOW",
},
{
- Flag: unix.O_NONBLOCK,
+ Flag: linux.O_NONBLOCK,
Name: "O_NONBLOCK",
},
{
@@ -77,18 +77,22 @@ var OpenFlagSet = abi.FlagSet{
Name: "O_PATH",
},
{
- Flag: unix.O_SYNC,
+ Flag: linux.O_SYNC,
Name: "O_SYNC",
},
{
- Flag: unix.O_TRUNC,
+ Flag: linux.O_TMPFILE,
+ Name: "O_TMPFILE",
+ },
+ {
+ Flag: linux.O_TRUNC,
Name: "O_TRUNC",
},
}
func open(val uint64) string {
- s := OpenMode.Parse(val & unix.O_ACCMODE)
- if flags := OpenFlagSet.Parse(val &^ unix.O_ACCMODE); flags != "" {
+ s := OpenMode.Parse(val & linux.O_ACCMODE)
+ if flags := OpenFlagSet.Parse(val &^ linux.O_ACCMODE); flags != "" {
s += "|" + flags
}
return s
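Editor's note: open(2) decoding combines two tables: the access mode (selected with O_ACCMODE) is looked up in OpenMode, and the remaining bits go through OpenFlagSet, so a typical create call renders as "O_WRONLY|O_CREAT|O_TRUNC". A standalone sketch with the usual Linux values (O_ACCMODE=0x3, O_CREAT=0x40, O_TRUNC=0x200); the helpers here are hypothetical simplifications, not the strace package itself.

package main

import (
	"fmt"
	"strings"
)

// Common Linux open(2) constants.
const (
	oRDONLY  = 0x0
	oWRONLY  = 0x1
	oRDWR    = 0x2
	oACCMODE = 0x3
	oCREAT   = 0x40
	oTRUNC   = 0x200
)

var openMode = map[uint64]string{oRDONLY: "O_RDONLY", oWRONLY: "O_WRONLY", oRDWR: "O_RDWR"}

var openFlags = []struct {
	value uint64
	name  string
}{
	{oCREAT, "O_CREAT"},
	{oTRUNC, "O_TRUNC"},
}

// open mirrors strace's open(): decode the access mode, then the flag bits.
func open(val uint64) string {
	s := openMode[val&oACCMODE]
	var extra []string
	for _, f := range openFlags {
		if val&f.value != 0 {
			extra = append(extra, f.name)
		}
	}
	if len(extra) > 0 {
		s += "|" + strings.Join(extra, "|")
	}
	return s
}

func main() {
	fmt.Println(open(oWRONLY | oCREAT | oTRUNC)) // O_WRONLY|O_CREAT|O_TRUNC
}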
diff --git a/pkg/sentry/strace/signal.go b/pkg/sentry/strace/signal.go
index e5b379a20..5afc9525b 100644
--- a/pkg/sentry/strace/signal.go
+++ b/pkg/sentry/strace/signal.go
@@ -130,8 +130,8 @@ func sigAction(t *kernel.Task, addr hostarch.Addr) string {
return "null"
}
- sa, err := t.CopyInSignalAct(addr)
- if err != nil {
+ var sa linux.SigAction
+ if _, err := sa.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error copying sigaction: %v)", addr, err)
}
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
index ec5d5f846..af7088847 100644
--- a/pkg/sentry/strace/strace.go
+++ b/pkg/sentry/strace/strace.go
@@ -489,6 +489,10 @@ func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlo
output = append(output, epollEvents(t, args[arg].Pointer(), 0 /* numEvents */, uint64(maximumBlobSize)))
case SelectFDSet:
output = append(output, fdSet(t, int(args[0].Int()), args[arg].Pointer()))
+ case MmapProt:
+ output = append(output, ProtectionFlagSet.Parse(uint64(args[arg].Uint())))
+ case MmapFlags:
+ output = append(output, MmapFlagSet.Parse(uint64(args[arg].Uint())))
case Oct:
output = append(output, "0o"+strconv.FormatUint(args[arg].Uint64(), 8))
case Hex:
diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go
index 7e69b9279..5893443a7 100644
--- a/pkg/sentry/strace/syscalls.go
+++ b/pkg/sentry/strace/syscalls.go
@@ -238,6 +238,12 @@ const (
// EpollEvents is an array of struct epoll_event. It is the events
// argument in epoll_wait(2)/epoll_pwait(2).
EpollEvents
+
+ // MmapProt is the protection argument in mmap(2).
+ MmapProt
+
+ // MmapFlags is the flags argument in mmap(2).
+ MmapFlags
)
// defaultFormat is the syscall argument format to use if the actual format is
diff --git a/pkg/sentry/syscalls/BUILD b/pkg/sentry/syscalls/BUILD
index b8d1bd415..f2c55588f 100644
--- a/pkg/sentry/syscalls/BUILD
+++ b/pkg/sentry/syscalls/BUILD
@@ -11,6 +11,7 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
+ "//pkg/errors/linuxerr",
"//pkg/sentry/arch",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/epoll",
diff --git a/pkg/sentry/syscalls/epoll.go b/pkg/sentry/syscalls/epoll.go
index 3b4d79889..02debfc7e 100644
--- a/pkg/sentry/syscalls/epoll.go
+++ b/pkg/sentry/syscalls/epoll.go
@@ -18,6 +18,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/epoll"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -163,7 +164,7 @@ func WaitEpoll(t *kernel.Task, fd int32, max int, timeoutInNanos int64) ([]linux
}
if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
- if err == syserror.ETIMEDOUT {
+ if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
return nil, nil
}
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index 408a6c422..a2f612f45 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -64,6 +64,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/bpf",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/log",
"//pkg/marshal",
diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go
index 6eabfd219..165922332 100644
--- a/pkg/sentry/syscalls/linux/error.go
+++ b/pkg/sentry/syscalls/linux/error.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -94,13 +95,13 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, errOrig, intr er
if errno, ok := syserror.TranslateError(errOrig); ok {
translatedErr = errno
}
- switch translatedErr {
- case io.EOF:
+ switch {
+ case translatedErr == io.EOF:
// EOF is always consumed. If this is a partial read/write
// (result != 0), the application will see that, otherwise
// they will see 0.
return true, nil
- case syserror.EFBIG:
+ case linuxerr.Equals(linuxerr.EFBIG, translatedErr):
t := kernel.TaskFromContext(ctx)
if t == nil {
panic("I/O error should only occur from a context associated with a Task")
@@ -113,7 +114,7 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, errOrig, intr er
// Simultaneously send a SIGXFSZ per setrlimit(2).
t.SendSignal(kernel.SignalInfoNoInfo(linux.SIGXFSZ, t, t))
return true, syserror.EFBIG
- case syserror.EINTR:
+ case linuxerr.Equals(linuxerr.EINTR, translatedErr):
// The syscall was interrupted. Return nil if it completed
// partially, otherwise return the error code that the syscall
// needs (to indicate to the kernel what it should do).
@@ -128,21 +129,21 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, errOrig, intr er
return true, errOrig
}
- switch translatedErr {
- case syserror.EINTR:
+ switch {
+ case linuxerr.Equals(linuxerr.EINTR, translatedErr):
// Syscall interrupted, but completed a partial
// read/write. Like ErrWouldBlock, since we have a
// partial read/write, we consume the error and return
// the partial result.
return true, nil
- case syserror.EFAULT:
+ case linuxerr.Equals(linuxerr.EFAULT, translatedErr):
// EFAULT is only shown the user if nothing was
// read/written. If we read something (this case), they see
// a partial read/write. They will then presumably try again
// with an incremented buffer, which will EFAULT with
// result == 0.
return true, nil
- case syserror.EPIPE:
+ case linuxerr.Equals(linuxerr.EPIPE, translatedErr):
// Writes to a pipe or socket will return EPIPE if the other
// side is gone. The partial write is returned. EPIPE will be
// returned on the next call.
@@ -150,15 +151,17 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, errOrig, intr er
// TODO(gvisor.dev/issue/161): In some cases SIGPIPE should
// also be sent to the application.
return true, nil
- case syserror.ENOSPC:
+ case linuxerr.Equals(linuxerr.ENOSPC, translatedErr):
// Similar to EPIPE. Return what we wrote this time, and let
// ENOSPC be returned on the next call.
return true, nil
- case syserror.ECONNRESET, syserror.ETIMEDOUT:
+ case linuxerr.Equals(linuxerr.ECONNRESET, translatedErr):
+ fallthrough
+ case linuxerr.Equals(linuxerr.ETIMEDOUT, translatedErr):
// For TCP sendfile connections, we may have a reset or timeout. But we
// should just return n as the result.
return true, nil
- case syserror.EWOULDBLOCK:
+ case linuxerr.Equals(linuxerr.EWOULDBLOCK, translatedErr):
// Syscall would block, but completed a partial read/write.
// This case should only be returned by IssueIO for nonblocking
// files. Since we have a partial read/write, we consume
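Editor's note: the error-translation switch changes shape because the comparison is no longer an identity check against a sentinel: "switch translatedErr { case a, b: }" becomes a tag-less "switch { case linuxerr.Equals(...): }", and the old two-value case for ECONNRESET/ETIMEDOUT is rebuilt with an explicit fallthrough. A compact illustration of that rewrite with local stand-ins for the errno values (equals here is just errors.Is, not the real linuxerr.Equals).

package main

import (
	"errors"
	"fmt"
)

var (
	errConnReset  = errors.New("ECONNRESET")
	errTimedOut   = errors.New("ETIMEDOUT")
	errWouldBlock = errors.New("EWOULDBLOCK")
)

// equals is a stand-in for linuxerr.Equals.
func equals(want, got error) bool { return errors.Is(got, want) }

// consumePartialError mirrors the rewritten switch: function-call cases need a
// tag-less switch, and grouping two errnos needs fallthrough.
func consumePartialError(err error) bool {
	switch {
	case equals(errConnReset, err):
		fallthrough
	case equals(errTimedOut, err):
		// Reset or timeout after a partial transfer: report the bytes written.
		return true
	case equals(errWouldBlock, err):
		// Partial nonblocking I/O: also consumed.
		return true
	default:
		return false
	}
}

func main() {
	fmt.Println(consumePartialError(fmt.Errorf("send: %w", errTimedOut))) // true
	fmt.Println(consumePartialError(errors.New("EIO")))                   // false
}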
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index 090c5ffcb..1732064ef 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -18,6 +18,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -187,7 +188,7 @@ var AMD64 = &kernel.SyscallTable{
132: syscalls.Supported("utime", Utime),
133: syscalls.PartiallySupported("mknod", Mknod, "Device creation is not generally supported. Only regular file and FIFO creation are supported.", nil),
134: syscalls.Error("uselib", syserror.ENOSYS, "Obsolete", nil),
- 135: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil),
+ 135: syscalls.ErrorWithEvent("personality", linuxerr.EINVAL, "Unable to change personality.", nil),
136: syscalls.ErrorWithEvent("ustat", syserror.ENOSYS, "Needs filesystem support.", nil),
137: syscalls.PartiallySupported("statfs", Statfs, "Depends on the backing file system implementation.", nil),
138: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil),
@@ -521,7 +522,7 @@ var ARM64 = &kernel.SyscallTable{
89: syscalls.CapError("acct", linux.CAP_SYS_PACCT, "", nil),
90: syscalls.Supported("capget", Capget),
91: syscalls.Supported("capset", Capset),
- 92: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil),
+ 92: syscalls.ErrorWithEvent("personality", linuxerr.EINVAL, "Unable to change personality.", nil),
93: syscalls.Supported("exit", Exit),
94: syscalls.Supported("exit_group", ExitGroup),
95: syscalls.Supported("waitid", Waitid),
diff --git a/pkg/sentry/syscalls/linux/sigset.go b/pkg/sentry/syscalls/linux/sigset.go
index e8c2d8f9e..9dea78085 100644
--- a/pkg/sentry/syscalls/linux/sigset.go
+++ b/pkg/sentry/syscalls/linux/sigset.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
@@ -29,7 +30,7 @@ import (
// syscalls are moved into this package, then they can be unexported.
func CopyInSigSet(t *kernel.Task, sigSetAddr hostarch.Addr, size uint) (linux.SignalSet, error) {
if size != linux.SignalSetSize {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
b := t.CopyScratchBuffer(8)
if _, err := t.CopyInBytes(sigSetAddr, b); err != nil {
diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go
index 70e8569a8..a93fc635b 100644
--- a/pkg/sentry/syscalls/linux/sys_aio.go
+++ b/pkg/sentry/syscalls/linux/sys_aio.go
@@ -17,6 +17,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -42,7 +43,7 @@ func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
return 0, nil, err
}
if idIn != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
id, err := t.MemoryManager().NewAIOContext(t, uint32(nrEvents))
@@ -66,7 +67,7 @@ func IoDestroy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
ctx := t.MemoryManager().DestroyAIOContext(t, id)
if ctx == nil {
// Does not exist.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Drain completed requests and wait for pending requests until there are no
@@ -97,12 +98,12 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
// Sanity check arguments.
if minEvents < 0 || minEvents > events {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
ctx, ok := t.MemoryManager().LookupAIOContext(t, id)
if !ok {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Setup the timeout.
@@ -114,7 +115,7 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
return 0, nil, err
}
if !d.Valid() {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
deadline = t.Kernel().MonotonicClock().Now().Add(d.ToDuration())
haveDeadline = true
@@ -134,7 +135,7 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
var err error
v, err = waitForRequest(ctx, t, haveDeadline, deadline)
if err != nil {
- if count > 0 || err == syserror.ETIMEDOUT {
+ if count > 0 || linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
return uintptr(count), nil, nil
}
return 0, nil, syserror.ConvertIntr(err, syserror.EINTR)
@@ -171,7 +172,7 @@ func waitForRequest(ctx *mm.AIOContext, t *kernel.Task, haveDeadline bool, deadl
done := ctx.WaitChannel()
if done == nil {
// Context has been destroyed.
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
if err := t.BlockWithDeadline(done, haveDeadline, deadline); err != nil {
return nil, err
@@ -184,7 +185,7 @@ func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error)
bytes := int(cb.Bytes)
if bytes < 0 {
// Linux also requires that this field fit in ssize_t.
- return usermem.IOSequence{}, syserror.EINVAL
+ return usermem.IOSequence{}, linuxerr.EINVAL
}
// Since this I/O will be asynchronous with respect to t's task goroutine,
@@ -206,7 +207,7 @@ func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error)
default:
// Not a supported command.
- return usermem.IOSequence{}, syserror.EINVAL
+ return usermem.IOSequence{}, linuxerr.EINVAL
}
}
@@ -286,7 +287,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr host
// Check that it is an eventfd.
if _, ok := eventFile.FileOperations.(*eventfd.EventOperations); !ok {
// Not an event FD.
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
}
@@ -299,14 +300,14 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr host
switch cb.OpCode {
case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITE, linux.IOCB_CMD_PWRITEV:
if cb.Offset < 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
}
// Prepare the request.
ctx, ok := t.MemoryManager().LookupAIOContext(t, id)
if !ok {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if err := ctx.Prepare(); err != nil {
return err
@@ -335,7 +336,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
addr := args[2].Pointer()
if nrEvents < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
for i := int32(0); i < nrEvents; i++ {
diff --git a/pkg/sentry/syscalls/linux/sys_capability.go b/pkg/sentry/syscalls/linux/sys_capability.go
index d3b85e11b..782bcb94f 100644
--- a/pkg/sentry/syscalls/linux/sys_capability.go
+++ b/pkg/sentry/syscalls/linux/sys_capability.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -24,7 +25,7 @@ import (
func lookupCaps(t *kernel.Task, tid kernel.ThreadID) (permitted, inheritable, effective auth.CapabilitySet, err error) {
if tid < 0 {
- err = syserror.EINVAL
+ err = linuxerr.EINVAL
return
}
if tid > 0 {
@@ -97,7 +98,7 @@ func Capget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, err
}
if dataAddr != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
return 0, nil, nil
}
@@ -144,6 +145,6 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if _, err := hdr.CopyOut(t, hdrAddr); err != nil {
return 0, nil, err
}
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go
index 69cbc98d0..daa151bb4 100644
--- a/pkg/sentry/syscalls/linux/sys_epoll.go
+++ b/pkg/sentry/syscalls/linux/sys_epoll.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -31,7 +32,7 @@ import (
func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
flags := args[0].Int()
if flags & ^linux.EPOLL_CLOEXEC != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
closeOnExec := flags&linux.EPOLL_CLOEXEC != 0
@@ -48,7 +49,7 @@ func EpollCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
size := args[0].Int()
if size <= 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
fd, err := syscalls.CreateEpoll(t, false)
@@ -101,7 +102,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
mask |= waiter.EventHUp | waiter.EventErr
return 0, nil, syscalls.UpdateEpoll(t, epfd, fd, flags, mask, data)
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_eventfd.go b/pkg/sentry/syscalls/linux/sys_eventfd.go
index 3b4f879e4..7ba9a755e 100644
--- a/pkg/sentry/syscalls/linux/sys_eventfd.go
+++ b/pkg/sentry/syscalls/linux/sys_eventfd.go
@@ -16,11 +16,11 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/eventfd"
- "gvisor.dev/gvisor/pkg/syserror"
)
// Eventfd2 implements linux syscall eventfd2(2).
@@ -30,7 +30,7 @@ func Eventfd2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
allOps := uint(linux.EFD_SEMAPHORE | linux.EFD_NONBLOCK | linux.EFD_CLOEXEC)
if flags & ^allOps != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
event := eventfd.New(t, uint64(initVal), flags&linux.EFD_SEMAPHORE != 0)
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 9cd238efd..3d45341f2 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -18,6 +18,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -275,7 +276,7 @@ func mknodAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, mode linux.FileMod
default:
// "EINVAL - mode requested creation of something other than a
// regular file, device special file, FIFO or socket." - mknod(2)
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
})
}
@@ -394,8 +395,8 @@ func createAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, flags uint, mode
}
var newFile *fs.File
- switch err {
- case nil:
+ switch {
+ case err == nil:
// Like sys_open, check for a few things about the
// filesystem before trying to get a reference to the
// fs.File. The same constraints on Check apply.
@@ -418,7 +419,7 @@ func createAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, flags uint, mode
return syserror.ConvertIntr(err, syserror.ERESTARTSYS)
}
defer newFile.DecRef(t)
- case syserror.ENOENT:
+ case linuxerr.Equals(linuxerr.ENOENT, err):
// File does not exist. Proceed with creation.
// Do we have write permissions on the parent?
@@ -527,7 +528,7 @@ func accessAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, mode uint) error
// Sanity check the mode.
if mode&^(rOK|wOK|xOK) != 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
return fileOpOn(t, dirFD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
@@ -843,7 +844,7 @@ func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
flags := args[2].Uint()
if oldfd == newfd {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
oldFile := t.GetFile(oldfd)
@@ -905,7 +906,7 @@ func fSetOwn(t *kernel.Task, fd int, file *fs.File, who int32) error {
if who < 0 {
// Check for overflow before flipping the sign.
if who-1 > who {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(-who))
a.SetOwnerProcessGroup(t, pg)
@@ -976,7 +977,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case 2:
sw = fs.SeekEnd
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Compute the lock offset.
@@ -995,7 +996,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
off = uattr.Size
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Compute the lock range.
@@ -1043,7 +1044,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
file.Dirent.Inode.LockCtx.Posix.UnlockRegion(t.FDTable(), rng)
return 0, nil, nil
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
case linux.F_GETOWN:
return uintptr(fGetOwn(t, file)), nil, nil
@@ -1085,7 +1086,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
a.SetOwnerProcessGroup(t, pg)
return 0, nil, nil
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
case linux.F_GET_SEALS:
val, err := tmpfs.GetSeals(file.Dirent.Inode)
@@ -1099,14 +1100,14 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.F_GETPIPE_SZ:
sz, ok := file.FileOperations.(fs.FifoSizer)
if !ok {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
size, err := sz.FifoSize(t, file)
return uintptr(size), nil, err
case linux.F_SETPIPE_SZ:
sz, ok := file.FileOperations.(fs.FifoSizer)
if !ok {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
n, err := sz.SetFifoSize(int64(args[2].Int()))
return uintptr(n), nil, err
@@ -1118,7 +1119,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, a.SetSignal(linux.Signal(args[2].Int()))
default:
// Everything else is not yet supported.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
}
@@ -1131,7 +1132,7 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
// Note: offset is allowed to be negative.
if length < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
file := t.GetFile(fd)
@@ -1153,7 +1154,7 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
case linux.POSIX_FADV_DONTNEED:
case linux.POSIX_FADV_NOREUSE:
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Sure, whatever.
@@ -1178,12 +1179,12 @@ func mkdirAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, mode linux.FileMod
// Does this directory exist already?
remainingTraversals := uint(linux.MaxSymlinkTraversals)
f, err := t.MountNamespace().FindInode(t, root, d, name, &remainingTraversals)
- switch err {
- case nil:
+ switch {
+ case err == nil:
// The directory existed.
defer f.DecRef(t)
return syserror.EEXIST
- case syserror.EACCES:
+ case linuxerr.Equals(linuxerr.EACCES, err):
// Permission denied while walking to the directory.
return err
default:
@@ -1236,7 +1237,7 @@ func rmdirAt(t *kernel.Task, dirFD int32, addr hostarch.Addr) error {
// dot vs. double dots.
switch name {
case ".":
- return syserror.EINVAL
+ return linuxerr.EINVAL
case "..":
return syserror.ENOTEMPTY
}
@@ -1431,7 +1432,7 @@ func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// Sanity check flags.
if flags&^(linux.AT_SYMLINK_FOLLOW|linux.AT_EMPTY_PATH) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
resolve := flags&linux.AT_SYMLINK_FOLLOW == linux.AT_SYMLINK_FOLLOW
@@ -1464,8 +1465,8 @@ func readlinkAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, bufAddr hostarc
}
s, err := d.Inode.Readlink(t)
- if err == syserror.ENOLINK {
- return syserror.EINVAL
+ if linuxerr.Equals(linuxerr.ENOLINK, err) {
+ return linuxerr.EINVAL
}
if err != nil {
return err
@@ -1557,7 +1558,7 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
length := args[1].Int64()
if length < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
@@ -1565,13 +1566,13 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, err
}
if dirPath {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if uint64(length) >= t.ThreadGroup().Limits().Get(limits.FileSize).Cur {
- t.SendSignal(&arch.SignalInfo{
+ t.SendSignal(&linux.SignalInfo{
Signo: int32(linux.SIGXFSZ),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
})
return 0, nil, syserror.EFBIG
}
@@ -1583,7 +1584,7 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// In contrast to open(O_TRUNC), truncate(2) is only valid for file
// types.
if !fs.IsFile(d.Inode.StableAttr) {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// Reject truncation if the access permissions do not allow truncation.
@@ -1617,24 +1618,24 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
// Reject truncation if the file flags do not permit this operation.
// This is different from truncate(2) above.
if !file.Flags().Write {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// In contrast to open(O_TRUNC), truncate(2) is only valid for file
// types. Note that this is different from truncate(2) above, where a
// directory returns EISDIR.
if !fs.IsFile(file.Dirent.Inode.StableAttr) {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if length < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if uint64(length) >= t.ThreadGroup().Limits().Get(limits.FileSize).Cur {
- t.SendSignal(&arch.SignalInfo{
+ t.SendSignal(&linux.SignalInfo{
Signo: int32(linux.SIGXFSZ),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
})
return 0, nil, syserror.EFBIG
}
@@ -1673,14 +1674,16 @@ func chown(t *kernel.Task, d *fs.Dirent, uid auth.UID, gid auth.GID) error {
if err != nil {
return err
}
+
c := t.Credentials()
hasCap := d.Inode.CheckCapability(t, linux.CAP_CHOWN)
isOwner := uattr.Owner.UID == c.EffectiveKUID
+ var clearPrivilege bool
if uid.Ok() {
kuid := c.UserNamespace.MapToKUID(uid)
// Valid UID must be supplied if UID is to be changed.
if !kuid.Ok() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// "Only a privileged process (CAP_CHOWN) may change the owner
@@ -1693,13 +1696,18 @@ func chown(t *kernel.Task, d *fs.Dirent, uid auth.UID, gid auth.GID) error {
return syserror.EPERM
}
+ // The setuid and setgid bits are cleared during a chown.
+ if uattr.Owner.UID != kuid {
+ clearPrivilege = true
+ }
+
owner.UID = kuid
}
if gid.Ok() {
kgid := c.UserNamespace.MapToKGID(gid)
// Valid GID must be supplied if GID is to be changed.
if !kgid.Ok() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// "The owner of a file may change the group of the file to any
@@ -1711,6 +1719,11 @@ func chown(t *kernel.Task, d *fs.Dirent, uid auth.UID, gid auth.GID) error {
return syserror.EPERM
}
+ // The setuid and setgid bits are cleared during a chown.
+ if uattr.Owner.GID != kgid {
+ clearPrivilege = true
+ }
+
owner.GID = kgid
}
@@ -1721,10 +1734,14 @@ func chown(t *kernel.Task, d *fs.Dirent, uid auth.UID, gid auth.GID) error {
if err := d.Inode.SetOwner(t, d, owner); err != nil {
return err
}
+ // Clear privilege bits if needed and they are set.
+ if clearPrivilege && uattr.Perms.HasSetUIDOrGID() && !fs.IsDir(d.Inode.StableAttr) {
+ uattr.Perms.DropSetUIDAndMaybeGID()
+ if !d.Inode.SetPermissions(t, d, uattr.Perms) {
+ return syserror.EPERM
+ }
+ }
- // When the owner or group are changed by an unprivileged user,
- // chown(2) also clears the set-user-ID and set-group-ID bits, but
- // we do not support them.
return nil
}
@@ -1792,7 +1809,7 @@ func Fchownat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
flags := args[4].Int()
if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
return 0, nil, chownAt(t, dirFD, addr, flags&linux.AT_SYMLINK_NOFOLLOW == 0, flags&linux.AT_EMPTY_PATH != 0, uid, gid)
@@ -1897,7 +1914,7 @@ func utimes(t *kernel.Task, dirFD int32, addr hostarch.Addr, ts fs.TimeSpec, res
if addr == 0 && dirFD != linux.AT_FDCWD {
if !resolve {
// Linux returns EINVAL in this case. See utimes.c.
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
f := t.GetFile(dirFD)
if f == nil {
@@ -1980,7 +1997,7 @@ func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, err
}
if !timespecIsValid(times[0]) || !timespecIsValid(times[1]) {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// If both are UTIME_OMIT, this is a noop.
@@ -2015,7 +2032,7 @@ func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
if times[0].Usec >= 1e6 || times[0].Usec < 0 ||
times[1].Usec >= 1e6 || times[1].Usec < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
ts = fs.TimeSpec{
@@ -2101,7 +2118,7 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
defer file.DecRef(t)
if offset < 0 || length <= 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if mode != 0 {
t.Kernel().EmitUnimplementedEvent(t)
@@ -2124,9 +2141,9 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, syserror.EFBIG
}
if uint64(size) >= t.ThreadGroup().Limits().Get(limits.FileSize).Cur {
- t.SendSignal(&arch.SignalInfo{
+ t.SendSignal(&linux.SignalInfo{
Signo: int32(linux.SIGXFSZ),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
})
return 0, nil, syserror.EFBIG
}
@@ -2191,7 +2208,7 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
file.Dirent.Inode.LockCtx.BSD.UnlockRegion(file, rng)
default:
// flock(2): EINVAL operation is invalid.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
return 0, nil, nil
@@ -2210,7 +2227,7 @@ func MemfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
if flags&^memfdAllFlags != 0 {
// Unknown bits in flags.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
allowSeals := flags&linux.MFD_ALLOW_SEALING != 0
@@ -2221,7 +2238,7 @@ func MemfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
return 0, nil, err
}
if len(name) > memfdMaxNameLen {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
name = memfdPrefix + name
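
Besides the error-package swap, the sys_file.go hunks above change chown: it now records whether the owner or group actually changed and, for non-directories, drops the setuid/setgid bits via uattr.Perms.DropSetUIDAndMaybeGID(). A minimal sketch of that rule on plain mode bits follows, assuming DropSetUIDAndMaybeGID follows the usual Linux semantics where setgid is preserved on files that are not group-executable; the helper itself is illustrative only.

package example

// clearPrivilegeBits sketches the rule applied by the chown hunk above:
// on an ownership change, setuid is cleared, and setgid is cleared only
// when the file is group-executable (a non-executable setgid bit marks
// mandatory locking and is left alone). Directories are skipped, as in
// the !fs.IsDir check above.
func clearPrivilegeBits(mode uint32, isDir bool) uint32 {
	const (
		setuid  = 0o4000
		setgid  = 0o2000
		grpExec = 0o0010
	)
	if isDir {
		return mode
	}
	mode &^= setuid
	if mode&grpExec != 0 {
		mode &^= setgid
	}
	return mode
}
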
diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go
index eeea1613b..a9e9126cc 100644
--- a/pkg/sentry/syscalls/linux/sys_futex.go
+++ b/pkg/sentry/syscalls/linux/sys_futex.go
@@ -18,6 +18,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -210,7 +211,7 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// WAIT_BITSET uses an absolute timeout which is either
// CLOCK_MONOTONIC or CLOCK_REALTIME.
if mask == 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
n, err := futexWaitAbsolute(t, clockRealtime, timespec, forever, addr, private, uint32(val), mask)
return n, nil, err
@@ -224,7 +225,7 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.FUTEX_WAKE_BITSET:
if mask == 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if val <= 0 {
// The Linux kernel wakes one waiter even if val is
@@ -295,7 +296,7 @@ func SetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
length := args[1].SizeT()
if length != uint(linux.SizeOfRobustListHead) {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
t.SetRobustList(head)
return 0, nil, nil
@@ -310,7 +311,7 @@ func GetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
sizeAddr := args[2].Pointer()
if tid < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
ot := t
diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go
index bbba71d8f..355fbd766 100644
--- a/pkg/sentry/syscalls/linux/sys_getdents.go
+++ b/pkg/sentry/syscalls/linux/sys_getdents.go
@@ -19,6 +19,7 @@ import (
"io"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -38,7 +39,7 @@ func Getdents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
minSize := int(smallestDirent(t.Arch()))
if size < minSize {
// size is smaller than smallest possible dirent.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
n, err := getdents(t, fd, addr, size, (*dirent).Serialize)
@@ -54,7 +55,7 @@ func Getdents64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
minSize := int(smallestDirent64(t.Arch()))
if size < minSize {
// size is smaller than smallest possible dirent.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
n, err := getdents(t, fd, addr, size, (*dirent).Serialize64)
diff --git a/pkg/sentry/syscalls/linux/sys_identity.go b/pkg/sentry/syscalls/linux/sys_identity.go
index a29d307e5..50fcadb58 100644
--- a/pkg/sentry/syscalls/linux/sys_identity.go
+++ b/pkg/sentry/syscalls/linux/sys_identity.go
@@ -15,10 +15,10 @@
package linux
import (
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/syserror"
)
const (
@@ -142,7 +142,7 @@ func Setresgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
func Getgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
size := int(args[0].Int())
if size < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
kgids := t.Credentials().ExtraKGIDs
// "If size is zero, list is not modified, but the total number of
@@ -151,7 +151,7 @@ func Getgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return uintptr(len(kgids)), nil, nil
}
if size < len(kgids) {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
gids := make([]auth.GID, len(kgids))
for i, kgid := range kgids {
@@ -167,7 +167,7 @@ func Getgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
func Setgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
size := args[0].Int()
if size < 0 || size > maxNGroups {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if size == 0 {
return 0, nil, t.SetExtraGIDs(nil)
diff --git a/pkg/sentry/syscalls/linux/sys_inotify.go b/pkg/sentry/syscalls/linux/sys_inotify.go
index cf47bb9dd..48c8dbdca 100644
--- a/pkg/sentry/syscalls/linux/sys_inotify.go
+++ b/pkg/sentry/syscalls/linux/sys_inotify.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
@@ -30,7 +31,7 @@ func InotifyInit1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
flags := int(args[0].Int())
if flags&^allFlags != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
dirent := fs.NewDirent(t, anon.NewInode(t), "inotify")
@@ -72,7 +73,7 @@ func fdToInotify(t *kernel.Task, fd int32) (*fs.Inotify, *fs.File, error) {
if !ok {
// Not an inotify fd.
file.DecRef(t)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
return ino, file, nil
@@ -91,7 +92,7 @@ func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kern
// "EINVAL: The given event mask contains no valid events."
// -- inotify_add_watch(2)
if validBits := mask & linux.ALL_INOTIFY_BITS; validBits == 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
ino, file, err := fdToInotify(t, fd)
diff --git a/pkg/sentry/syscalls/linux/sys_lseek.go b/pkg/sentry/syscalls/linux/sys_lseek.go
index 0046347cb..c16c63ecc 100644
--- a/pkg/sentry/syscalls/linux/sys_lseek.go
+++ b/pkg/sentry/syscalls/linux/sys_lseek.go
@@ -15,6 +15,7 @@
package linux
import (
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -44,7 +45,7 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case 2:
sw = fs.SeekEnd
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
offset, serr := file.Seek(t, sw, offset)
diff --git a/pkg/sentry/syscalls/linux/sys_membarrier.go b/pkg/sentry/syscalls/linux/sys_membarrier.go
index 63ee5d435..4b67f2536 100644
--- a/pkg/sentry/syscalls/linux/sys_membarrier.go
+++ b/pkg/sentry/syscalls/linux/sys_membarrier.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
@@ -29,7 +30,7 @@ func Membarrier(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
switch cmd {
case linux.MEMBARRIER_CMD_QUERY:
if flags != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
var supportedCommands uintptr
if t.Kernel().Platform.HaveGlobalMemoryBarrier() {
@@ -46,10 +47,10 @@ func Membarrier(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return supportedCommands, nil, nil
case linux.MEMBARRIER_CMD_GLOBAL, linux.MEMBARRIER_CMD_GLOBAL_EXPEDITED, linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED:
if flags != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if !t.Kernel().Platform.HaveGlobalMemoryBarrier() {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if cmd == linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED && !t.MemoryManager().IsMembarrierPrivateEnabled() {
return 0, nil, syserror.EPERM
@@ -57,28 +58,28 @@ func Membarrier(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, t.Kernel().Platform.GlobalMemoryBarrier()
case linux.MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED:
if flags != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if !t.Kernel().Platform.HaveGlobalMemoryBarrier() {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// no-op
return 0, nil, nil
case linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED:
if flags != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if !t.Kernel().Platform.HaveGlobalMemoryBarrier() {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
t.MemoryManager().EnableMembarrierPrivate()
return 0, nil, nil
case linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED_RSEQ:
if flags&^linux.MEMBARRIER_CMD_FLAG_CPU != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if !t.RSeqAvailable() {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if !t.MemoryManager().IsMembarrierRSeqEnabled() {
return 0, nil, syserror.EPERM
@@ -88,16 +89,16 @@ func Membarrier(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, t.Kernel().Platform.PreemptAllCPUs()
case linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_RSEQ:
if flags != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if !t.RSeqAvailable() {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
t.MemoryManager().EnableMembarrierRSeq()
return 0, nil, nil
default:
// Probably a command we don't implement.
t.Kernel().EmitUnimplementedEvent(t)
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_mempolicy.go b/pkg/sentry/syscalls/linux/sys_mempolicy.go
index 6d27f4292..62ec3e27f 100644
--- a/pkg/sentry/syscalls/linux/sys_mempolicy.go
+++ b/pkg/sentry/syscalls/linux/sys_mempolicy.go
@@ -18,6 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -43,7 +44,7 @@ func copyInNodemask(t *kernel.Task, addr hostarch.Addr, maxnode uint32) (uint64,
// maxnode-1, not maxnode, as the number of bits.
bits := maxnode - 1
if bits > hostarch.PageSize*8 { // also handles overflow from maxnode == 0
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if bits == 0 {
return 0, nil
@@ -58,12 +59,12 @@ func copyInNodemask(t *kernel.Task, addr hostarch.Addr, maxnode uint32) (uint64,
// Check that only allowed bits in the first unsigned long in the nodemask
// are set.
if val&^allowedNodemask != 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Check that all remaining bits in the nodemask are 0.
for i := 8; i < len(buf); i++ {
if buf[i] != 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
}
return val, nil
@@ -74,7 +75,7 @@ func copyOutNodemask(t *kernel.Task, addr hostarch.Addr, maxnode uint32, val uin
// bits.
bits := maxnode - 1
if bits > hostarch.PageSize*8 { // also handles overflow from maxnode == 0
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if bits == 0 {
return nil
@@ -110,7 +111,7 @@ func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
flags := args[4].Uint()
if flags&^(linux.MPOL_F_NODE|linux.MPOL_F_ADDR|linux.MPOL_F_MEMS_ALLOWED) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
nodeFlag := flags&linux.MPOL_F_NODE != 0
addrFlag := flags&linux.MPOL_F_ADDR != 0
@@ -119,7 +120,7 @@ func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
// "EINVAL: The value specified by maxnode is less than the number of node
// IDs supported by the system." - get_mempolicy(2)
if nodemask != 0 && maxnode < maxNodes {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// "If flags specifies MPOL_F_MEMS_ALLOWED [...], the mode argument is
@@ -130,7 +131,7 @@ func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
// "It is not permitted to combine MPOL_F_MEMS_ALLOWED with either
// MPOL_F_ADDR or MPOL_F_NODE."
if nodeFlag || addrFlag {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if err := copyOutNodemask(t, nodemask, maxnode, allowedNodemask); err != nil {
return 0, nil, err
@@ -184,7 +185,7 @@ func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
// mm/mempolicy.c:do_get_mempolicy() doesn't special-case NULL; it will
// just (usually) fail to find a VMA at address 0 and return EFAULT.
if addr != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// "If flags is specified as 0, then information about the calling thread's
@@ -198,7 +199,7 @@ func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
policy, nodemaskVal := t.NumaPolicy()
if nodeFlag {
if policy&^linux.MPOL_MODE_FLAGS != linux.MPOL_INTERLEAVE {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
policy = linux.MPOL_DEFAULT // maxNodes == 1
}
@@ -240,7 +241,7 @@ func Mbind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
flags := args[5].Uint()
if flags&^linux.MPOL_MF_VALID != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// "If MPOL_MF_MOVE_ALL is passed in flags ... [the] calling thread must be
// privileged (CAP_SYS_NICE) to use this flag." - mbind(2)
@@ -264,11 +265,11 @@ func copyInMempolicyNodemask(t *kernel.Task, modeWithFlags linux.NumaPolicy, nod
mode := linux.NumaPolicy(modeWithFlags &^ linux.MPOL_MODE_FLAGS)
if flags == linux.MPOL_MODE_FLAGS {
// Can't specify both mode flags simultaneously.
- return 0, 0, syserror.EINVAL
+ return 0, 0, linuxerr.EINVAL
}
if mode < 0 || mode >= linux.MPOL_MAX {
// Must specify a valid mode.
- return 0, 0, syserror.EINVAL
+ return 0, 0, linuxerr.EINVAL
}
var nodemaskVal uint64
@@ -285,22 +286,22 @@ func copyInMempolicyNodemask(t *kernel.Task, modeWithFlags linux.NumaPolicy, nod
// "nodemask must be specified as NULL." - set_mempolicy(2). This is inaccurate;
// Linux allows a nodemask to be specified, as long as it is empty.
if nodemaskVal != 0 {
- return 0, 0, syserror.EINVAL
+ return 0, 0, linuxerr.EINVAL
}
case linux.MPOL_BIND, linux.MPOL_INTERLEAVE:
// These require a non-empty nodemask.
if nodemaskVal == 0 {
- return 0, 0, syserror.EINVAL
+ return 0, 0, linuxerr.EINVAL
}
case linux.MPOL_PREFERRED:
// This permits an empty nodemask, as long as no flags are set.
if nodemaskVal == 0 && flags != 0 {
- return 0, 0, syserror.EINVAL
+ return 0, 0, linuxerr.EINVAL
}
case linux.MPOL_LOCAL:
// This requires an empty nodemask and no flags set ...
if nodemaskVal != 0 || flags != 0 {
- return 0, 0, syserror.EINVAL
+ return 0, 0, linuxerr.EINVAL
}
// ... and is implemented as MPOL_PREFERRED.
mode = linux.MPOL_PREFERRED
diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go
index 70da0707d..74279c82b 100644
--- a/pkg/sentry/syscalls/linux/sys_mmap.go
+++ b/pkg/sentry/syscalls/linux/sys_mmap.go
@@ -18,13 +18,13 @@ import (
"bytes"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/syserror"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
// Brk implements linux syscall brk(2).
@@ -51,7 +51,7 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
// Require exactly one of MAP_PRIVATE and MAP_SHARED.
if private == shared {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
opts := memmap.MMapOpts{
@@ -132,7 +132,7 @@ func Mremap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
newAddr := args[4].Pointer()
if flags&^(linux.MREMAP_MAYMOVE|linux.MREMAP_FIXED) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
mayMove := flags&linux.MREMAP_MAYMOVE != 0
fixed := flags&linux.MREMAP_FIXED != 0
@@ -147,7 +147,7 @@ func Mremap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
case !mayMove && fixed:
// "If MREMAP_FIXED is specified, then MREMAP_MAYMOVE must also be
// specified." - mremap(2)
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
rv, err := t.MemoryManager().MRemap(t, oldAddr, oldSize, newSize, mm.MRemapOpts{
@@ -178,7 +178,7 @@ func Madvise(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// "The Linux implementation requires that the address addr be
// page-aligned, and allows length to be zero." - madvise(2)
if addr.RoundDown() != addr {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if length == 0 {
return 0, nil, nil
@@ -186,7 +186,7 @@ func Madvise(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// Not explicitly stated: length need not be page-aligned.
lenAddr, ok := hostarch.Addr(length).RoundUp()
if !ok {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
length = uint64(lenAddr)
@@ -217,7 +217,7 @@ func Madvise(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
return 0, nil, syserror.EPERM
default:
// If adv is not a valid value tell the caller.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
}
@@ -228,7 +228,7 @@ func Mincore(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
vec := args[2].Pointer()
if addr != addr.RoundDown() {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// "The length argument need not be a multiple of the page size, but since
// residency information is returned for whole pages, length is effectively
@@ -265,11 +265,11 @@ func Msync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// semantics that are (currently) equivalent to specifying MS_ASYNC." -
// msync(2)
if flags&^(linux.MS_ASYNC|linux.MS_SYNC|linux.MS_INVALIDATE) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
sync := flags&linux.MS_SYNC != 0
if sync && flags&linux.MS_ASYNC != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
err := t.MemoryManager().MSync(t, addr, uint64(length), mm.MSyncOpts{
Sync: sync,
@@ -295,7 +295,7 @@ func Mlock2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
flags := args[2].Int()
if flags&^(linux.MLOCK_ONFAULT) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
mode := memmap.MLockEager
@@ -318,7 +318,7 @@ func Mlockall(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
flags := args[0].Int()
if flags&^(linux.MCL_CURRENT|linux.MCL_FUTURE|linux.MCL_ONFAULT) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
mode := memmap.MLockEager
diff --git a/pkg/sentry/syscalls/linux/sys_mount.go b/pkg/sentry/syscalls/linux/sys_mount.go
index 864d2138c..8bf4e9f06 100644
--- a/pkg/sentry/syscalls/linux/sys_mount.go
+++ b/pkg/sentry/syscalls/linux/sys_mount.go
@@ -16,12 +16,12 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
// Mount implements Linux syscall mount(2).
@@ -83,7 +83,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// unknown or unsupported flags are passed. Since we don't implement
// everything, we fail explicitly on flags that are unimplemented.
if flags&(unsupportedOps|unsupportedFlags) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
rsys, ok := fs.FindFilesystem(fsType)
@@ -107,7 +107,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
rootInode, err := rsys.Mount(t, sourcePath, superFlags, data, nil)
if err != nil {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if err := fileOpOn(t, linux.AT_FDCWD, targetPath, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
@@ -130,7 +130,7 @@ func Umount2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
const unsupported = linux.MNT_FORCE | linux.MNT_EXPIRE
if flags&unsupported != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
path, _, err := copyInPath(t, addr, false /* allowEmpty */)
diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go
index d95034347..5925c2263 100644
--- a/pkg/sentry/syscalls/linux/sys_pipe.go
+++ b/pkg/sentry/syscalls/linux/sys_pipe.go
@@ -16,13 +16,13 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
- "gvisor.dev/gvisor/pkg/syserror"
)
// LINT.IfChange
@@ -30,7 +30,7 @@ import (
// pipe2 implements the actual system call with flags.
func pipe2(t *kernel.Task, addr hostarch.Addr, flags uint) (uintptr, error) {
if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize)
diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go
index da548a14a..f2056d850 100644
--- a/pkg/sentry/syscalls/linux/sys_poll.go
+++ b/pkg/sentry/syscalls/linux/sys_poll.go
@@ -18,6 +18,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -128,7 +129,7 @@ func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.
// Wait for a notification.
timeout, err = t.BlockWithTimeout(ch, !forever, timeout)
if err != nil {
- if err == syserror.ETIMEDOUT {
+ if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
err = nil
}
return timeout, 0, err
@@ -157,7 +158,7 @@ func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.
// CopyInPollFDs copies an array of struct pollfd unless nfds exceeds the max.
func CopyInPollFDs(t *kernel.Task, addr hostarch.Addr, nfds uint) ([]linux.PollFD, error) {
if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
pfd := make([]linux.PollFD, nfds)
@@ -217,7 +218,7 @@ func CopyInFDSet(t *kernel.Task, addr hostarch.Addr, nBytes, nBitsInLastPartialB
func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs hostarch.Addr, timeout time.Duration) (uintptr, error) {
if nfds < 0 || nfds > fileCap {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Calculate the size of the fd sets (one bit per fd).
@@ -404,7 +405,7 @@ func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
func poll(t *kernel.Task, pfdAddr hostarch.Addr, nfds uint, timeout time.Duration) (uintptr, error) {
remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout)
// On an interrupt poll(2) is restarted with the remaining timeout.
- if err == syserror.EINTR {
+ if linuxerr.Equals(linuxerr.EINTR, err) {
t.SetSyscallRestartBlock(&pollRestartBlock{
pfdAddr: pfdAddr,
nfds: nfds,
@@ -463,7 +464,7 @@ func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
//
// Note that this means that if err is nil but copyErr is not, copyErr is
// ignored. This is consistent with Linux.
- if err == syserror.EINTR && copyErr == nil {
+ if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil {
err = syserror.ERESTARTNOHAND
}
return n, nil, err
@@ -485,7 +486,7 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, err
}
if timeval.Sec < 0 || timeval.Usec < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
timeout = time.Duration(timeval.ToNsecCapped())
}
@@ -493,7 +494,7 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr)
// See comment in Ppoll.
- if err == syserror.EINTR && copyErr == nil {
+ if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil {
err = syserror.ERESTARTNOHAND
}
return n, nil, err
@@ -538,7 +539,7 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
// See comment in Ppoll.
- if err == syserror.EINTR && copyErr == nil {
+ if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil {
err = syserror.ERESTARTNOHAND
}
return n, nil, err
diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go
index 9890dd946..534f1e632 100644
--- a/pkg/sentry/syscalls/linux/sys_prctl.go
+++ b/pkg/sentry/syscalls/linux/sys_prctl.go
@@ -18,6 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -38,7 +39,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.PR_SET_PDEATHSIG:
sig := linux.Signal(args[1].Int())
if sig != 0 && !sig.IsValid() {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
t.SetParentDeathSignal(sig)
return 0, nil, nil
@@ -69,7 +70,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
d = mm.UserDumpable
default:
// N.B. Userspace may not pass SUID_DUMP_ROOT.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
t.MemoryManager().SetDumpability(d)
return 0, nil, nil
@@ -90,7 +91,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
} else if val == 1 {
t.SetKeepCaps(true)
} else {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
return 0, nil, nil
@@ -98,7 +99,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.PR_SET_NAME:
addr := args[1].Pointer()
name, err := t.CopyInString(addr, linux.TASK_COMM_LEN-1)
- if err != nil && err != syserror.ENAMETOOLONG {
+ if err != nil && !linuxerr.Equals(linuxerr.ENAMETOOLONG, err) {
return 0, nil, err
}
t.SetName(name)
@@ -155,12 +156,12 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
t.Kernel().EmitUnimplementedEvent(t)
fallthrough
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
case linux.PR_SET_NO_NEW_PRIVS:
if args[1].Int() != 1 || args[2].Int() != 0 || args[3].Int() != 0 || args[4].Int() != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// PR_SET_NO_NEW_PRIVS is assumed to always be set.
// See kernel.Task.updateCredsForExecLocked.
@@ -168,7 +169,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.PR_GET_NO_NEW_PRIVS:
if args[1].Int() != 0 || args[2].Int() != 0 || args[3].Int() != 0 || args[4].Int() != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
return 1, nil, nil
@@ -184,7 +185,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
default:
tracer := t.PIDNamespace().TaskWithID(kernel.ThreadID(pid))
if tracer == nil {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
t.SetYAMAException(tracer)
return 0, nil, nil
@@ -193,7 +194,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.PR_SET_SECCOMP:
if args[1].Int() != linux.SECCOMP_MODE_FILTER {
// Unsupported mode.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
return 0, nil, seccomp(t, linux.SECCOMP_SET_MODE_FILTER, 0, args[2].Pointer())
@@ -204,7 +205,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.PR_CAPBSET_READ:
cp := linux.Capability(args[1].Uint64())
if !cp.Ok() {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
var rv uintptr
if auth.CapabilitySetOf(cp)&t.Credentials().BoundingCaps != 0 {
@@ -215,7 +216,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.PR_CAPBSET_DROP:
cp := linux.Capability(args[1].Uint64())
if !cp.Ok() {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
return 0, nil, t.DropBoundingCapability(cp)
@@ -240,7 +241,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
t.Kernel().EmitUnimplementedEvent(t)
fallthrough
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
return 0, nil, nil
diff --git a/pkg/sentry/syscalls/linux/sys_random.go b/pkg/sentry/syscalls/linux/sys_random.go
index ae545f80f..ec6c80de5 100644
--- a/pkg/sentry/syscalls/linux/sys_random.go
+++ b/pkg/sentry/syscalls/linux/sys_random.go
@@ -18,14 +18,14 @@ import (
"io"
"math"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
const (
@@ -47,7 +47,7 @@ func GetRandom(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
// Flags are checked for validity but otherwise ignored. See above.
if flags & ^(_GRND_NONBLOCK|_GRND_RANDOM) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if length > math.MaxInt32 {
diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go
index 13e5e3a51..4064467a9 100644
--- a/pkg/sentry/syscalls/linux/sys_read.go
+++ b/pkg/sentry/syscalls/linux/sys_read.go
@@ -18,6 +18,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -58,7 +59,7 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
// Check that the size is legitimate.
si := int(size)
if si < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get the destination of the read.
@@ -93,18 +94,18 @@ func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
// Check that the size is valid.
if int(size) < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Check that the offset is legitimate and does not overflow.
if offset < 0 || offset+int64(size) < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Return EINVAL; if the underlying file type does not support readahead,
// then Linux will return EINVAL to indicate as much. In the future, we
// may extend this function to actually support readahead hints.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Pread64 implements linux syscall pread64(2).
@@ -122,7 +123,7 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// Check that the offset is legitimate and does not overflow.
if offset < 0 || offset+int64(size) < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Is reading at an offset supported?
@@ -138,7 +139,7 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// Check that the size is legitimate.
si := int(size)
if si < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get the destination of the read.
@@ -199,7 +200,7 @@ func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// Check that the offset is legitimate.
if offset < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Is reading at an offset supported?
@@ -248,7 +249,7 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// Check that the offset is legitimate.
if offset < -1 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Is reading at an offset supported?
@@ -331,7 +332,7 @@ func readv(t *kernel.Task, f *fs.File, dst usermem.IOSequence) (int64, error) {
// Wait for a notification that we should retry.
if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
- if err == syserror.ETIMEDOUT {
+ if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
err = syserror.ErrWouldBlock
}
break
diff --git a/pkg/sentry/syscalls/linux/sys_rlimit.go b/pkg/sentry/syscalls/linux/sys_rlimit.go
index e64246d57..ca78c2ab2 100644
--- a/pkg/sentry/syscalls/linux/sys_rlimit.go
+++ b/pkg/sentry/syscalls/linux/sys_rlimit.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -129,7 +130,7 @@ func Getrlimit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
resource, ok := limits.FromLinuxResource[int(args[0].Int())]
if !ok {
// Return err; unknown limit.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
addr := args[1].Pointer()
rlim, err := newRlimit(t)
@@ -150,7 +151,7 @@ func Setrlimit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
resource, ok := limits.FromLinuxResource[int(args[0].Int())]
if !ok {
// Return err; unknown limit.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
addr := args[1].Pointer()
rlim, err := newRlimit(t)
@@ -170,7 +171,7 @@ func Prlimit64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
resource, ok := limits.FromLinuxResource[int(args[1].Int())]
if !ok {
// Return err; unknown limit.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
newRlimAddr := args[2].Pointer()
oldRlimAddr := args[3].Pointer()
@@ -185,7 +186,7 @@ func Prlimit64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
if tid < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
ot := t
if tid > 0 {
diff --git a/pkg/sentry/syscalls/linux/sys_rseq.go b/pkg/sentry/syscalls/linux/sys_rseq.go
index 90db10ea6..5fe196647 100644
--- a/pkg/sentry/syscalls/linux/sys_rseq.go
+++ b/pkg/sentry/syscalls/linux/sys_rseq.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
@@ -43,6 +44,6 @@ func RSeq(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
return 0, nil, t.ClearRSeq(addr, length, signature)
default:
// Unknown flag.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_rusage.go b/pkg/sentry/syscalls/linux/sys_rusage.go
index ac5c98a54..a689abcc9 100644
--- a/pkg/sentry/syscalls/linux/sys_rusage.go
+++ b/pkg/sentry/syscalls/linux/sys_rusage.go
@@ -16,11 +16,11 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/syserror"
)
func getrusage(t *kernel.Task, which int32) linux.Rusage {
@@ -76,7 +76,7 @@ func Getrusage(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
addr := args[1].Pointer()
if which != linux.RUSAGE_SELF && which != linux.RUSAGE_CHILDREN && which != linux.RUSAGE_THREAD {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
ru := getrusage(t, which)
diff --git a/pkg/sentry/syscalls/linux/sys_sched.go b/pkg/sentry/syscalls/linux/sys_sched.go
index bfcf44b6f..b5e7b70b5 100644
--- a/pkg/sentry/syscalls/linux/sys_sched.go
+++ b/pkg/sentry/syscalls/linux/sys_sched.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
@@ -38,10 +39,10 @@ func SchedGetparam(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
pid := args[0].Int()
param := args[1].Pointer()
if param == 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if pid < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if pid != 0 && t.PIDNamespace().TaskWithID(kernel.ThreadID(pid)) == nil {
return 0, nil, syserror.ESRCH
@@ -58,7 +59,7 @@ func SchedGetparam(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
func SchedGetscheduler(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
pid := args[0].Int()
if pid < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if pid != 0 && t.PIDNamespace().TaskWithID(kernel.ThreadID(pid)) == nil {
return 0, nil, syserror.ESRCH
@@ -72,20 +73,20 @@ func SchedSetscheduler(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ke
policy := args[1].Int()
param := args[2].Pointer()
if pid < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if policy != onlyScheduler {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if pid != 0 && t.PIDNamespace().TaskWithID(kernel.ThreadID(pid)) == nil {
return 0, nil, syserror.ESRCH
}
var r SchedParam
if _, err := r.CopyIn(t, param); err != nil {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if r.schedPriority != onlyPriority {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
return 0, nil, nil
}
diff --git a/pkg/sentry/syscalls/linux/sys_seccomp.go b/pkg/sentry/syscalls/linux/sys_seccomp.go
index e16d6ff3f..b0dc84b8d 100644
--- a/pkg/sentry/syscalls/linux/sys_seccomp.go
+++ b/pkg/sentry/syscalls/linux/sys_seccomp.go
@@ -17,10 +17,10 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/syserror"
)
// userSockFprog is equivalent to Linux's struct sock_fprog on amd64.
@@ -44,7 +44,7 @@ func seccomp(t *kernel.Task, mode, flags uint64, addr hostarch.Addr) error {
// We only support SECCOMP_SET_MODE_FILTER at the moment.
if mode != linux.SECCOMP_SET_MODE_FILTER {
// Unsupported mode.
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
tsync := flags&linux.SECCOMP_FILTER_FLAG_TSYNC != 0
@@ -52,7 +52,7 @@ func seccomp(t *kernel.Task, mode, flags uint64, addr hostarch.Addr) error {
// The only flag we support now is SECCOMP_FILTER_FLAG_TSYNC.
if flags&^linux.SECCOMP_FILTER_FLAG_TSYNC != 0 {
// Unsupported flag.
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
var fprog userSockFprog
@@ -66,7 +66,7 @@ func seccomp(t *kernel.Task, mode, flags uint64, addr hostarch.Addr) error {
compiledFilter, err := bpf.Compile(filter)
if err != nil {
t.Debugf("Invalid seccomp-bpf filter: %v", err)
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
return t.AppendSyscallFilter(compiledFilter, tsync)
diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go
index c84260080..d12a4303b 100644
--- a/pkg/sentry/syscalls/linux/sys_sem.go
+++ b/pkg/sentry/syscalls/linux/sys_sem.go
@@ -19,6 +19,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -61,10 +62,10 @@ func Semtimedop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
nsops := args[2].SizeT()
timespecAddr := args[3].Pointer()
if nsops <= 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if nsops > opsMax {
- return 0, nil, syserror.E2BIG
+ return 0, nil, linuxerr.E2BIG
}
ops := make([]linux.Sembuf, nsops)
@@ -77,11 +78,11 @@ func Semtimedop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, err
}
if timeout.Sec < 0 || timeout.Nsec < 0 || timeout.Nsec >= 1e9 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if err := semTimedOp(t, id, ops, true, timeout.ToDuration()); err != nil {
- if err == syserror.ETIMEDOUT {
+ if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
return 0, nil, syserror.EAGAIN
}
return 0, nil, err
@@ -96,10 +97,10 @@ func Semop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
nsops := args[2].SizeT()
if nsops <= 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if nsops > opsMax {
- return 0, nil, syserror.E2BIG
+ return 0, nil, linuxerr.E2BIG
}
ops := make([]linux.Sembuf, nsops)
@@ -113,7 +114,7 @@ func semTimedOp(t *kernel.Task, id int32, ops []linux.Sembuf, haveTimeout bool,
set := t.IPCNamespace().SemaphoreRegistry().FindByID(id)
if set == nil {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
creds := auth.CredentialsFromContext(t)
pid := t.Kernel().GlobalInit().PIDNamespace().IDOfThreadGroup(t.ThreadGroup())
@@ -232,7 +233,7 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return uintptr(semid), nil, err
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
}
@@ -246,17 +247,17 @@ func ipcSet(t *kernel.Task, id int32, uid auth.UID, gid auth.GID, perms fs.FileP
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
creds := auth.CredentialsFromContext(t)
kuid := creds.UserNamespace.MapToKUID(uid)
if !kuid.Ok() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
kgid := creds.UserNamespace.MapToKGID(gid)
if !kgid.Ok() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
owner := fs.FileOwner{UID: kuid, GID: kgid}
return set.Change(t, creds, owner, perms)
@@ -266,7 +267,7 @@ func ipcStat(t *kernel.Task, id int32) (*linux.SemidDS, error) {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
creds := auth.CredentialsFromContext(t)
return set.GetStat(creds)
@@ -276,7 +277,7 @@ func semStat(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByIndex(index)
if set == nil {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
creds := auth.CredentialsFromContext(t)
ds, err := set.GetStat(creds)
@@ -289,7 +290,7 @@ func semStat(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) {
func semStatAny(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) {
set := t.IPCNamespace().SemaphoreRegistry().FindByIndex(index)
if set == nil {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
creds := auth.CredentialsFromContext(t)
ds, err := set.GetStatAny(creds)
@@ -303,7 +304,7 @@ func setVal(t *kernel.Task, id int32, num int32, val int16) error {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
creds := auth.CredentialsFromContext(t)
pid := t.Kernel().GlobalInit().PIDNamespace().IDOfThreadGroup(t.ThreadGroup())
@@ -314,7 +315,7 @@ func setValAll(t *kernel.Task, id int32, array hostarch.Addr) error {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
vals := make([]uint16, set.Size())
if _, err := primitive.CopyUint16SliceIn(t, array, vals); err != nil {
@@ -329,7 +330,7 @@ func getVal(t *kernel.Task, id int32, num int32) (int16, error) {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
creds := auth.CredentialsFromContext(t)
return set.GetVal(num, creds)
@@ -339,7 +340,7 @@ func getValAll(t *kernel.Task, id int32, array hostarch.Addr) error {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
creds := auth.CredentialsFromContext(t)
vals, err := set.GetValAll(creds)
@@ -354,7 +355,7 @@ func getPID(t *kernel.Task, id int32, num int32) (int32, error) {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
creds := auth.CredentialsFromContext(t)
gpid, err := set.GetPID(num, creds)
@@ -373,7 +374,7 @@ func getZCnt(t *kernel.Task, id int32, num int32) (uint16, error) {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
creds := auth.CredentialsFromContext(t)
return set.CountZeroWaiters(num, creds)
@@ -383,7 +384,7 @@ func getNCnt(t *kernel.Task, id int32, num int32) (uint16, error) {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
creds := auth.CredentialsFromContext(t)
return set.CountNegativeWaiters(num, creds)
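
Editor's note: two shapes of the migration are visible in this file. Plain error returns swap syserror.EINVAL for linuxerr.EINVAL, while error comparisons change from `err == syserror.ETIMEDOUT` to `linuxerr.Equals(linuxerr.ETIMEDOUT, err)`, presumably because the same errno can reach the caller as more than one concrete Go error type during the transition. The sketch below is a self-contained toy, not gvisor's implementation; sentryError, hostErrno and equals are hypothetical stand-ins for the real types behind linuxerr.Equals.

package main

import "fmt"

// Hypothetical stand-ins for two representations of the same errno: a
// sentry-internal error value and a raw host errno.
type sentryError struct{ errno int }
type hostErrno int

func (e *sentryError) Error() string { return fmt.Sprintf("errno %d", e.errno) }
func (e hostErrno) Error() string    { return fmt.Sprintf("errno %d", int(e)) }

var ETIMEDOUT = &sentryError{errno: 110}

// equals reports whether err carries the same errno as the canonical error,
// regardless of which concrete type returned it. This is the role that
// linuxerr.Equals plays in the hunks above.
func equals(canonical *sentryError, err error) bool {
	switch e := err.(type) {
	case *sentryError:
		return e.errno == canonical.errno
	case hostErrno:
		return int(e) == canonical.errno
	default:
		return false
	}
}

func main() {
	var err error = hostErrno(110)           // same errno, different concrete type
	fmt.Println(err == error(ETIMEDOUT))     // false: identity comparison fails
	fmt.Println(equals(ETIMEDOUT, err))      // true: errno-based comparison
}
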
diff --git a/pkg/sentry/syscalls/linux/sys_shm.go b/pkg/sentry/syscalls/linux/sys_shm.go
index 584064143..3e3a952ce 100644
--- a/pkg/sentry/syscalls/linux/sys_shm.go
+++ b/pkg/sentry/syscalls/linux/sys_shm.go
@@ -16,10 +16,10 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/shm"
- "gvisor.dev/gvisor/pkg/syserror"
)
// Shmget implements shmget(2).
@@ -51,7 +51,7 @@ func findSegment(t *kernel.Task, id shm.ID) (*shm.Shm, error) {
segment := r.FindByID(id)
if segment == nil {
// No segment with provided id.
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
return segment, nil
}
@@ -64,7 +64,7 @@ func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
segment, err := findSegment(t, id)
if err != nil {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
defer segment.DecRef(t)
@@ -106,7 +106,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
case linux.IPC_STAT:
segment, err := findSegment(t, id)
if err != nil {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
defer segment.DecRef(t)
@@ -130,7 +130,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// Remaining commands refer to a specific segment.
segment, err := findSegment(t, id)
if err != nil {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
defer segment.DecRef(t)
@@ -155,6 +155,6 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, nil
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go
index 53b12dc41..4d659e5cf 100644
--- a/pkg/sentry/syscalls/linux/sys_signal.go
+++ b/pkg/sentry/syscalls/linux/sys_signal.go
@@ -19,6 +19,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -84,13 +85,13 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
if !mayKill(t, target, sig) {
return 0, nil, syserror.EPERM
}
- info := &arch.SignalInfo{
+ info := &linux.SignalInfo{
Signo: int32(sig),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
}
info.SetPID(int32(target.PIDNamespace().IDOfTask(t)))
info.SetUID(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow()))
- if err := target.SendGroupSignal(info); err != syserror.ESRCH {
+ if err := target.SendGroupSignal(info); !linuxerr.Equals(linuxerr.ESRCH, err) {
return 0, nil, err
}
}
@@ -123,14 +124,14 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
// depend on the iteration order. We at least implement the
// semantics documented by the man page: "On success (at least
// one signal was sent), zero is returned."
- info := &arch.SignalInfo{
+ info := &linux.SignalInfo{
Signo: int32(sig),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
}
info.SetPID(int32(tg.PIDNamespace().IDOfTask(t)))
info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
err := tg.SendSignal(info)
- if err == syserror.ESRCH {
+ if linuxerr.Equals(linuxerr.ESRCH, err) {
// ESRCH is ignored because it means the task
// exited while we were iterating. This is a
// race which would not normally exist on
@@ -167,14 +168,14 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
continue
}
- info := &arch.SignalInfo{
+ info := &linux.SignalInfo{
Signo: int32(sig),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
}
info.SetPID(int32(tg.PIDNamespace().IDOfTask(t)))
info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
// See the note above regarding the ESRCH race.
- if err := tg.SendSignal(info); err != syserror.ESRCH {
+ if err := tg.SendSignal(info); !linuxerr.Equals(linuxerr.ESRCH, err) {
lastErr = err
}
}
@@ -184,10 +185,10 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
}
}
-func tkillSigInfo(sender, receiver *kernel.Task, sig linux.Signal) *arch.SignalInfo {
- info := &arch.SignalInfo{
+func tkillSigInfo(sender, receiver *kernel.Task, sig linux.Signal) *linux.SignalInfo {
+ info := &linux.SignalInfo{
Signo: int32(sig),
- Code: arch.SignalInfoTkill,
+ Code: linux.SI_TKILL,
}
info.SetPID(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup())))
info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
@@ -202,7 +203,7 @@ func Tkill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// N.B. Inconsistent with man page, linux actually rejects calls with
// tid <=0 by EINVAL. This isn't the same for all signal calls.
if tid <= 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
target := t.PIDNamespace().TaskWithID(tid)
@@ -225,7 +226,7 @@ func Tgkill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// N.B. Inconsistent with man page, linux actually rejects calls with
// tgid/tid <=0 by EINVAL. This isn't the same for all signal calls.
if tgid <= 0 || tid <= 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
targetTG := t.PIDNamespace().ThreadGroupWithID(tgid)
@@ -248,23 +249,23 @@ func RtSigaction(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
sigsetsize := args[3].SizeT()
if sigsetsize != linux.SignalSetSize {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
- var newactptr *arch.SignalAct
+ var newactptr *linux.SigAction
if newactarg != 0 {
- newact, err := t.CopyInSignalAct(newactarg)
- if err != nil {
+ var newact linux.SigAction
+ if _, err := newact.CopyIn(t, newactarg); err != nil {
return 0, nil, err
}
newactptr = &newact
}
- oldact, err := t.ThreadGroup().SetSignalAct(sig, newactptr)
+ oldact, err := t.ThreadGroup().SetSigAction(sig, newactptr)
if err != nil {
return 0, nil, err
}
if oldactarg != 0 {
- if err := t.CopyOutSignalAct(oldactarg, &oldact); err != nil {
+ if _, err := oldact.CopyOut(t, oldactarg); err != nil {
return 0, nil, err
}
}
@@ -291,7 +292,7 @@ func RtSigprocmask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
sigsetsize := args[3].SizeT()
if sigsetsize != linux.SignalSetSize {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
oldmask := t.SignalMask()
if setaddr != 0 {
@@ -308,7 +309,7 @@ func RtSigprocmask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
case linux.SIG_SETMASK:
t.SetSignalMask(mask)
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
}
if oldaddr != 0 {
@@ -325,13 +326,12 @@ func Sigaltstack(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
alt := t.SignalStack()
if oldaddr != 0 {
- if err := t.CopyOutSignalStack(oldaddr, &alt); err != nil {
+ if _, err := alt.CopyOut(t, oldaddr); err != nil {
return 0, nil, err
}
}
if setaddr != 0 {
- alt, err := t.CopyInSignalStack(setaddr)
- if err != nil {
+ if _, err := alt.CopyIn(t, setaddr); err != nil {
return 0, nil, err
}
// The signal stack cannot be changed if the task is currently
@@ -378,7 +378,7 @@ func RtSigtimedwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
return 0, nil, err
}
if !d.Valid() {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
timeout = time.Duration(d.ToNsecCapped())
} else {
@@ -410,7 +410,7 @@ func RtSigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
// We must ensure that the Signo is set (Linux overrides this in the
// same way), and that the code is in the allowed set. This same logic
// appears below in RtSigtgqueueinfo and should be kept in sync.
- var info arch.SignalInfo
+ var info linux.SignalInfo
if _, err := info.CopyIn(t, infoAddr); err != nil {
return 0, nil, err
}
@@ -426,7 +426,7 @@ func RtSigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
// If the sender is not the receiver, it can't use si_codes used by the
// kernel or SI_TKILL.
- if (info.Code >= 0 || info.Code == arch.SignalInfoTkill) && target != t {
+ if (info.Code >= 0 || info.Code == linux.SI_TKILL) && target != t {
return 0, nil, syserror.EPERM
}
@@ -434,7 +434,7 @@ func RtSigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
return 0, nil, syserror.EPERM
}
- if err := target.SendGroupSignal(&info); err != syserror.ESRCH {
+ if err := target.SendGroupSignal(&info); !linuxerr.Equals(linuxerr.ESRCH, err) {
return 0, nil, err
}
}
@@ -450,11 +450,11 @@ func RtTgsigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ker
// N.B. Inconsistent with man page, linux actually rejects calls with
// tgid/tid <=0 by EINVAL. This isn't the same for all signal calls.
if tgid <= 0 || tid <= 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Copy in the info. See RtSigqueueinfo above.
- var info arch.SignalInfo
+ var info linux.SignalInfo
if _, err := info.CopyIn(t, infoAddr); err != nil {
return 0, nil, err
}
@@ -469,7 +469,7 @@ func RtTgsigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ker
// If the sender is not the receiver, it can't use si_codes used by the
// kernel or SI_TKILL.
- if (info.Code >= 0 || info.Code == arch.SignalInfoTkill) && target != t {
+ if (info.Code >= 0 || info.Code == linux.SI_TKILL) && target != t {
return 0, nil, syserror.EPERM
}
@@ -525,7 +525,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset hostarch.Addr, sigsetsize u
// Always check for valid flags, even if not creating.
if flags&^(linux.SFD_NONBLOCK|linux.SFD_CLOEXEC) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Is this a change to an existing signalfd?
@@ -545,7 +545,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset hostarch.Addr, sigsetsize u
}
// Not a signalfd.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Create a new file.
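
Editor's note: the Kill hunks above rely on a broadcast pattern spelled out in their comments: when signalling many thread groups (kill(-1), kill(0)), a target that exits mid-iteration surfaces as ESRCH and is deliberately ignored, and the call succeeds as long as at least one signal was delivered. The toy below is a loose sketch of that idea only, not the exact delivery-count semantics of the code above; ESRCH and signalAll are hypothetical stand-ins.

package main

import (
	"errors"
	"fmt"
)

// ESRCH is a hypothetical stand-in for linuxerr.ESRCH.
var ESRCH = errors.New("no such process")

// signalAll mirrors the broadcast logic described in the Kill hunks above:
// a target that disappears mid-iteration (ESRCH) is not treated as a
// failure, and only other errors are reported back.
func signalAll(targets []func() error) error {
	var lastErr error
	for _, send := range targets {
		if err := send(); err != nil && !errors.Is(err, ESRCH) {
			lastErr = err
		}
	}
	return lastErr
}

func main() {
	err := signalAll([]func() error{
		func() error { return nil },   // delivered
		func() error { return ESRCH }, // target exited while iterating: ignored
	})
	fmt.Println(err) // <nil>
}
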
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index e07917613..6638ad60f 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -18,6 +18,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
@@ -117,7 +118,7 @@ type multipleMessageHeader64 struct {
// from the untrusted address space range.
func CaptureAddress(t *kernel.Task, addr hostarch.Addr, addrlen uint32) ([]byte, error) {
if addrlen > maxAddrLen {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
addrBuf := make([]byte, addrlen)
@@ -139,7 +140,7 @@ func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr h
}
if int32(bufLen) < 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// Write the length unconditionally.
@@ -173,7 +174,7 @@ func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// Check and initialize the flags.
if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Create the new socket.
@@ -205,7 +206,7 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// Check and initialize the flags.
if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
fileFlags := fs.SettableFileFlags{
@@ -277,7 +278,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr, flags int) (uintptr, error) {
// Check that no unsupported flags are passed in.
if flags & ^(linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Get socket from the file descriptor.
@@ -305,7 +306,7 @@ func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr,
if peerRequested {
// NOTE(magi): Linux does not give you an error if it can't
// write the data back out so neither do we.
- if err := writeAddress(t, peer, peerLen, addr, addrLen); err == syserror.EINVAL {
+ if err := writeAddress(t, peer, peerLen, addr, addrLen); linuxerr.Equals(linuxerr.EINVAL, err) {
return 0, err
}
}
@@ -421,7 +422,7 @@ func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
switch how {
case linux.SHUT_RD, linux.SHUT_WR, linux.SHUT_RDWR:
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
return 0, nil, s.Shutdown(t, int(how)).ToError()
@@ -454,7 +455,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, err
}
if optLen < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Call syscall implementation then copy both value and value len out.
@@ -530,10 +531,10 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
if optLen < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if optLen > maxOptLen {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
buf := t.CopyScratchBuffer(int(optLen))
if _, err := t.CopyInBytes(optValAddr, buf); err != nil {
@@ -612,7 +613,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if t.Arch().Width() != 8 {
// We only handle 64-bit for now.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get socket from the file descriptor.
@@ -630,7 +631,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// Reject flags that we don't handle yet.
if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if file.Flags().NonBlocking {
@@ -660,7 +661,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if t.Arch().Width() != 8 {
// We only handle 64-bit for now.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if vlen > linux.UIO_MAXIOV {
@@ -669,7 +670,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Reject flags that we don't handle yet.
if flags & ^(baseRecvFlags|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get socket from the file descriptor.
@@ -697,7 +698,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, err
}
if !ts.Valid() {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
deadline = t.Kernel().MonotonicClock().Now().Add(ts.ToDuration())
haveDeadline = true
@@ -829,12 +830,12 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr hostarch.Addr, flags
// recvfrom and recv syscall handlers.
func recvFrom(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags int32, namePtr hostarch.Addr, nameLenPtr hostarch.Addr) (uintptr, error) {
if int(bufLen) < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Reject flags that we don't handle yet.
if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CONFIRM) != 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Get socket from the file descriptor.
@@ -907,7 +908,7 @@ func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if t.Arch().Width() != 8 {
// We only handle 64-bit for now.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get socket from the file descriptor.
@@ -925,7 +926,7 @@ func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// Reject flags that we don't handle yet.
if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if file.Flags().NonBlocking {
@@ -945,7 +946,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if t.Arch().Width() != 8 {
// We only handle 64-bit for now.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if vlen > linux.UIO_MAXIOV {
@@ -967,7 +968,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Reject flags that we don't handle yet.
if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if file.Flags().NonBlocking {
@@ -1073,7 +1074,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr hostar
func sendTo(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags int32, namePtr hostarch.Addr, nameLen uint32) (uintptr, error) {
bl := int(bufLen)
if bl < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Get socket from the file descriptor.
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index 134051124..88bee61ef 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -27,7 +28,7 @@ import (
// doSplice implements a blocking splice operation.
func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonBlocking bool) (int64, error) {
if opts.Length < 0 || opts.SrcStart < 0 || opts.DstStart < 0 || (opts.SrcStart+opts.Length < 0) {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if opts.Length == 0 {
return 0, nil
@@ -125,13 +126,13 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Verify that the outfile Append flag is not set.
if outFile.Flags().Append {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Verify that we have a regular infile. This is a requirement; the
// same check appears in Linux (fs/splice.c:splice_direct_to_actor).
if !fs.IsRegular(inFile.Dirent.Inode.StableAttr) {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
var (
@@ -190,7 +191,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// Check for invalid flags.
if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get files.
@@ -230,7 +231,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
if outOffset != 0 {
if !outFile.Flags().Pwrite {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
var offset int64
@@ -248,7 +249,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
if inOffset != 0 {
if !inFile.Flags().Pread {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
var offset int64
@@ -267,10 +268,10 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// We may not refer to the same pipe; otherwise it's a continuous loop.
if inFileAttr.InodeID == outFileAttr.InodeID {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Splice data.
@@ -298,7 +299,7 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
// Check for invalid flags.
if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get files.
@@ -316,12 +317,12 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
// All files must be pipes.
if !fs.IsPipe(inFile.Dirent.Inode.StableAttr) || !fs.IsPipe(outFile.Dirent.Inode.StableAttr) {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// We may not refer to the same pipe; see above.
if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// The operation is non-blocking if anything is non-blocking.
diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go
index 2338ba44b..103b13c10 100644
--- a/pkg/sentry/syscalls/linux/sys_stat.go
+++ b/pkg/sentry/syscalls/linux/sys_stat.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -139,13 +140,13 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
statxAddr := args[4].Pointer()
if mask&linux.STATX__RESERVED != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if flags&^(linux.AT_SYMLINK_NOFOLLOW|linux.AT_EMPTY_PATH|linux.AT_STATX_SYNC_TYPE) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if flags&linux.AT_STATX_SYNC_TYPE == linux.AT_STATX_SYNC_TYPE {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
path, dirPath, err := copyInPath(t, pathAddr, flags&linux.AT_EMPTY_PATH != 0)
diff --git a/pkg/sentry/syscalls/linux/sys_sync.go b/pkg/sentry/syscalls/linux/sys_sync.go
index 5ebd4461f..3f0e6c02e 100644
--- a/pkg/sentry/syscalls/linux/sys_sync.go
+++ b/pkg/sentry/syscalls/linux/sys_sync.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -86,13 +87,13 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
uflags := args[3].Uint()
if offset < 0 || offset+nbytes < offset {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if uflags&^(linux.SYNC_FILE_RANGE_WAIT_BEFORE|
linux.SYNC_FILE_RANGE_WRITE|
linux.SYNC_FILE_RANGE_WAIT_AFTER) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if nbytes == 0 {
diff --git a/pkg/sentry/syscalls/linux/sys_syslog.go b/pkg/sentry/syscalls/linux/sys_syslog.go
index 40c8bb061..ba372f9e3 100644
--- a/pkg/sentry/syscalls/linux/sys_syslog.go
+++ b/pkg/sentry/syscalls/linux/sys_syslog.go
@@ -15,6 +15,7 @@
package linux
import (
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
@@ -40,7 +41,7 @@ func Syslog(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
switch command {
case _SYSLOG_ACTION_READ_ALL:
if size < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if size > logBufLen {
size = logBufLen
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
index 3185ea527..d99dd5131 100644
--- a/pkg/sentry/syscalls/linux/sys_thread.go
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -19,6 +19,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -112,7 +113,7 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr host
}
if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
atEmptyPath := flags&linux.AT_EMPTY_PATH != 0
if !atEmptyPath && len(pathname) == 0 {
@@ -260,7 +261,7 @@ func parseCommonWaitOptions(wopts *kernel.WaitOptions, options int) error {
wopts.NonCloneTasks = true
wopts.CloneTasks = true
default:
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if options&linux.WCONTINUED != 0 {
wopts.Events |= kernel.EventGroupContinue
@@ -277,7 +278,7 @@ func parseCommonWaitOptions(wopts *kernel.WaitOptions, options int) error {
// wait4 waits for the given child process to exit.
func wait4(t *kernel.Task, pid int, statusAddr hostarch.Addr, options int, rusageAddr hostarch.Addr) (uintptr, error) {
if options&^(linux.WNOHANG|linux.WUNTRACED|linux.WCONTINUED|linux.WNOTHREAD|linux.WALL|linux.WCLONE) != 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
wopts := kernel.WaitOptions{
Events: kernel.EventExit | kernel.EventTraceeStop,
@@ -358,10 +359,10 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
rusageAddr := args[4].Pointer()
if options&^(linux.WNOHANG|linux.WEXITED|linux.WSTOPPED|linux.WCONTINUED|linux.WNOWAIT|linux.WNOTHREAD|linux.WALL|linux.WCLONE) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if options&(linux.WEXITED|linux.WSTOPPED|linux.WCONTINUED) == 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
wopts := kernel.WaitOptions{
Events: kernel.EventTraceeStop,
@@ -374,7 +375,7 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
case linux.P_PGID:
wopts.SpecificPGID = kernel.ProcessGroupID(id)
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if err := parseCommonWaitOptions(&wopts, options); err != nil {
@@ -398,7 +399,7 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// out the fields it would set for a successful waitid in this case
// as well.
if infop != 0 {
- var si arch.SignalInfo
+ var si linux.SignalInfo
_, err = si.CopyOut(t, infop)
}
}
@@ -413,7 +414,7 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if infop == 0 {
return 0, nil, nil
}
- si := arch.SignalInfo{
+ si := linux.SignalInfo{
Signo: int32(linux.SIGCHLD),
}
si.SetPID(int32(wr.TID))
@@ -423,24 +424,24 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
s := unix.WaitStatus(wr.Status)
switch {
case s.Exited():
- si.Code = arch.CLD_EXITED
+ si.Code = linux.CLD_EXITED
si.SetStatus(int32(s.ExitStatus()))
case s.Signaled():
- si.Code = arch.CLD_KILLED
+ si.Code = linux.CLD_KILLED
si.SetStatus(int32(s.Signal()))
case s.CoreDump():
- si.Code = arch.CLD_DUMPED
+ si.Code = linux.CLD_DUMPED
si.SetStatus(int32(s.Signal()))
case s.Stopped():
if wr.Event == kernel.EventTraceeStop {
- si.Code = arch.CLD_TRAPPED
+ si.Code = linux.CLD_TRAPPED
si.SetStatus(int32(s.TrapCause()))
} else {
- si.Code = arch.CLD_STOPPED
+ si.Code = linux.CLD_STOPPED
si.SetStatus(int32(s.StopSignal()))
}
case s.Continued():
- si.Code = arch.CLD_CONTINUED
+ si.Code = linux.CLD_CONTINUED
si.SetStatus(int32(linux.SIGCONT))
default:
t.Warningf("waitid got incomprehensible wait status %d", s)
@@ -528,7 +529,7 @@ func SchedGetaffinity(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ker
// in an array of "unsigned long" so the buffer needs to
// be a multiple of the word size.
if size&(t.Arch().Width()-1) > 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
var task *kernel.Task
@@ -545,7 +546,7 @@ func SchedGetaffinity(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ker
// The buffer needs to be big enough to hold a cpumask with
// all possible cpus.
if size < mask.Size() {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
_, err := t.CopyOutBytes(maskAddr, mask)
@@ -594,7 +595,7 @@ func Setpgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
}
tg = ot.ThreadGroup()
if tg.Leader() != ot {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Setpgid only operates on child threadgroups.
@@ -609,7 +610,7 @@ func Setpgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if pgid == 0 {
pgid = defaultPGID
} else if pgid < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// If the pgid is the same as the group, then create a new one. Otherwise,
@@ -712,7 +713,7 @@ func Getpriority(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
// PRIO_USER and PRIO_PGRP have no further implementation yet.
return 0, nil, nil
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
}
@@ -754,7 +755,7 @@ func Setpriority(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
// PRIO_USER and PRIO_PGRP have no further implementation yet.
return 0, nil, nil
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
return 0, nil, nil

diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go
index 83b777bbd..d75bb9c4f 100644
--- a/pkg/sentry/syscalls/linux/sys_time.go
+++ b/pkg/sentry/syscalls/linux/sys_time.go
@@ -19,6 +19,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -75,7 +76,7 @@ func ClockGetres(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
}
if _, err := getClock(t, clockID); err != nil {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if addr == 0 {
@@ -94,12 +95,12 @@ type cpuClocker interface {
func getClock(t *kernel.Task, clockID int32) (ktime.Clock, error) {
if clockID < 0 {
if !isValidCPUClock(clockID) {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
targetTask := targetTask(t, clockID)
if targetTask == nil {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
var target cpuClocker
@@ -116,7 +117,7 @@ func getClock(t *kernel.Task, clockID int32) (ktime.Clock, error) {
// CPUCLOCK_SCHED is approximated by CPUCLOCK_PROF.
return target.CPUClock(), nil
default:
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
}
@@ -138,7 +139,7 @@ func getClock(t *kernel.Task, clockID int32) (ktime.Clock, error) {
case linux.CLOCK_THREAD_CPUTIME_ID:
return t.CPUClock(), nil
default:
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
}
@@ -180,21 +181,21 @@ func Time(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
//
// +stateify savable
type clockNanosleepRestartBlock struct {
- c ktime.Clock
- duration time.Duration
- rem hostarch.Addr
+ c ktime.Clock
+ end ktime.Time
+ rem hostarch.Addr
}
// Restart implements kernel.SyscallRestartBlock.Restart.
func (n *clockNanosleepRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
- return 0, clockNanosleepFor(t, n.c, n.duration, n.rem)
+ return 0, clockNanosleepUntil(t, n.c, n.end, n.rem, true)
}
// clockNanosleepUntil blocks until a specified time.
//
// If blocking is interrupted, the syscall is restarted with the original
// arguments.
-func clockNanosleepUntil(t *kernel.Task, c ktime.Clock, ts linux.Timespec) error {
+func clockNanosleepUntil(t *kernel.Task, c ktime.Clock, end ktime.Time, rem hostarch.Addr, needRestartBlock bool) error {
notifier, tchan := ktime.NewChannelNotifier()
timer := ktime.NewTimer(c, notifier)
@@ -202,43 +203,22 @@ func clockNanosleepUntil(t *kernel.Task, c ktime.Clock, ts linux.Timespec) error
timer.Swap(ktime.Setting{
Period: 0,
Enabled: true,
- Next: ktime.FromTimespec(ts),
+ Next: end,
})
err := t.BlockWithTimer(nil, tchan)
timer.Destroy()
- // Did we just block until the timeout happened?
- if err == syserror.ETIMEDOUT {
- return nil
- }
-
- return syserror.ConvertIntr(err, syserror.ERESTARTNOHAND)
-}
-
-// clockNanosleepFor blocks for a specified duration.
-//
-// If blocking is interrupted, the syscall is restarted with the remaining
-// duration timeout.
-func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem hostarch.Addr) error {
- timer, start, tchan := ktime.After(c, dur)
-
- err := t.BlockWithTimer(nil, tchan)
-
- after := c.Now()
-
- timer.Destroy()
-
- switch err {
- case syserror.ETIMEDOUT:
+ switch {
+ case linuxerr.Equals(linuxerr.ETIMEDOUT, err):
// Slept for entire timeout.
return nil
- case syserror.ErrInterrupted:
+ case err == syserror.ErrInterrupted:
// Interrupted.
- remaining := dur - after.Sub(start)
- if remaining < 0 {
- remaining = time.Duration(0)
+ remaining := end.Sub(c.Now())
+ if remaining <= 0 {
+ return nil
}
// Copy out remaining time.
@@ -248,14 +228,16 @@ func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem hos
return err
}
}
-
- // Arrange for a restart with the remaining duration.
- t.SetSyscallRestartBlock(&clockNanosleepRestartBlock{
- c: c,
- duration: remaining,
- rem: rem,
- })
- return syserror.ERESTART_RESTARTBLOCK
+ if needRestartBlock {
+ // Arrange for a restart at the same absolute end time.
+ t.SetSyscallRestartBlock(&clockNanosleepRestartBlock{
+ c: c,
+ end: end,
+ rem: rem,
+ })
+ return syserror.ERESTART_RESTARTBLOCK
+ }
+ return syserror.ERESTARTNOHAND
default:
panic(fmt.Sprintf("Impossible BlockWithTimer error %v", err))
}
@@ -272,13 +254,14 @@ func Nanosleep(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
if !ts.Valid() {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Just like linux, we cap the timeout with the max number that int64 can
// represent which is roughly 292 years.
dur := time.Duration(ts.ToNsecCapped()) * time.Nanosecond
- return 0, nil, clockNanosleepFor(t, t.Kernel().MonotonicClock(), dur, rem)
+ c := t.Kernel().MonotonicClock()
+ return 0, nil, clockNanosleepUntil(t, c, c.Now().Add(dur), rem, true)
}
// ClockNanosleep implements linux syscall clock_nanosleep(2).
@@ -294,7 +277,7 @@ func ClockNanosleep(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
}
if !req.Valid() {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Only allow clock constants also allowed by Linux.
@@ -302,7 +285,7 @@ func ClockNanosleep(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
if clockID != linux.CLOCK_REALTIME &&
clockID != linux.CLOCK_MONOTONIC &&
clockID != linux.CLOCK_PROCESS_CPUTIME_ID {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
}
@@ -312,11 +295,11 @@ func ClockNanosleep(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
}
if flags&linux.TIMER_ABSTIME != 0 {
- return 0, nil, clockNanosleepUntil(t, c, req)
+ return 0, nil, clockNanosleepUntil(t, c, ktime.FromTimespec(req), 0, false)
}
dur := time.Duration(req.ToNsecCapped()) * time.Nanosecond
- return 0, nil, clockNanosleepFor(t, c, dur, rem)
+ return 0, nil, clockNanosleepUntil(t, c, c.Now().Add(dur), rem, true)
}
// Gettimeofday implements linux syscall gettimeofday(2).
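
Editor's note: as I read the sys_time.go hunks, the commit folds clockNanosleepFor into clockNanosleepUntil by converting relative sleeps into an absolute end time up front. On interruption the remaining time is simply end.Sub(c.Now()): it is copied out to rem and stored in the restart block for relative sleeps, while absolute sleeps restart with ERESTARTNOHAND and never write rem. Keeping the absolute deadline means the remaining time cannot drift across restarts. The sketch below is a toy model of that shape using the standard-library clock and a fake interrupt channel, not the sentry's ktime machinery.

package main

import (
	"fmt"
	"time"
)

// sleepUntil is a toy model of the consolidated clockNanosleepUntil: it
// sleeps toward an absolute deadline and, if "interrupted", reports how much
// time is left so the caller can restart with the same end time.
func sleepUntil(end time.Time, interrupt <-chan struct{}) (remaining time.Duration, interrupted bool) {
	timer := time.NewTimer(time.Until(end))
	defer timer.Stop()
	select {
	case <-timer.C:
		return 0, false // slept for the entire timeout
	case <-interrupt:
		remaining = time.Until(end)
		if remaining <= 0 {
			return 0, false // deadline already passed: treat as complete
		}
		return remaining, true // caller restarts with the same absolute end
	}
}

func main() {
	interrupt := make(chan struct{})
	go func() {
		time.Sleep(10 * time.Millisecond)
		close(interrupt)
	}()
	rem, again := sleepUntil(time.Now().Add(50*time.Millisecond), interrupt)
	fmt.Println(again, rem > 0) // true true: restart with the remaining time
}
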
diff --git a/pkg/sentry/syscalls/linux/sys_timerfd.go b/pkg/sentry/syscalls/linux/sys_timerfd.go
index cadd9d348..a8e88b814 100644
--- a/pkg/sentry/syscalls/linux/sys_timerfd.go
+++ b/pkg/sentry/syscalls/linux/sys_timerfd.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/timerfd"
@@ -30,7 +31,7 @@ func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
flags := args[1].Int()
if flags&^(linux.TFD_CLOEXEC|linux.TFD_NONBLOCK) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
var c ktime.Clock
@@ -40,7 +41,7 @@ func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
case linux.CLOCK_MONOTONIC, linux.CLOCK_BOOTTIME:
c = t.Kernel().MonotonicClock()
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
f := timerfd.NewFile(t, c)
defer f.DecRef(t)
@@ -66,7 +67,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
oldValAddr := args[3].Pointer()
if flags&^(linux.TFD_TIMER_ABSTIME) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
f := t.GetFile(fd)
@@ -77,7 +78,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
tf, ok := f.FileOperations.(*timerfd.TimerOperations)
if !ok {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
var newVal linux.Itimerspec
@@ -111,7 +112,7 @@ func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
tf, ok := f.FileOperations.(*timerfd.TimerOperations)
if !ok {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
tm, s := tf.GetTime()
diff --git a/pkg/sentry/syscalls/linux/sys_tls_amd64.go b/pkg/sentry/syscalls/linux/sys_tls_amd64.go
index 6ddd30d5c..32272c267 100644
--- a/pkg/sentry/syscalls/linux/sys_tls_amd64.go
+++ b/pkg/sentry/syscalls/linux/sys_tls_amd64.go
@@ -18,6 +18,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -48,7 +49,7 @@ func ArchPrctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
t.Kernel().EmitUnimplementedEvent(t)
fallthrough
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
return 0, nil, nil
diff --git a/pkg/sentry/syscalls/linux/sys_utsname.go b/pkg/sentry/syscalls/linux/sys_utsname.go
index 66c5974f5..7fffb189e 100644
--- a/pkg/sentry/syscalls/linux/sys_utsname.go
+++ b/pkg/sentry/syscalls/linux/sys_utsname.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
@@ -60,7 +61,7 @@ func Setdomainname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
return 0, nil, syserror.EPERM
}
if size < 0 || size > linux.UTSLen {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
name, err := t.CopyInString(nameAddr, int(size))
@@ -82,7 +83,7 @@ func Sethostname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
return 0, nil, syserror.EPERM
}
if size < 0 || size > linux.UTSLen {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
name := make([]byte, size)
diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go
index 95bfe6606..998b5fde6 100644
--- a/pkg/sentry/syscalls/linux/sys_write.go
+++ b/pkg/sentry/syscalls/linux/sys_write.go
@@ -18,6 +18,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -58,7 +59,7 @@ func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// Check that the size is legitimate.
si := int(size)
if si < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get the source of the write.
@@ -89,7 +90,7 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Check that the offset is legitimate and does not overflow.
if offset < 0 || offset+int64(size) < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Is writing at an offset supported?
@@ -105,7 +106,7 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Check that the size is legitimate.
si := int(size)
if si < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get the source of the write.
@@ -166,7 +167,7 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// Check that the offset is legitimate.
if offset < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Is writing at an offset supported?
@@ -219,7 +220,7 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Check that the offset is legitimate.
if offset < -1 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Is writing at an offset supported?
@@ -301,7 +302,7 @@ func writev(t *kernel.Task, f *fs.File, src usermem.IOSequence) (int64, error) {
// Wait for a notification that we should retry.
if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
- if err == syserror.ETIMEDOUT {
+ if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
err = syserror.ErrWouldBlock
}
break
diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go
index 28ad6a60e..da6651062 100644
--- a/pkg/sentry/syscalls/linux/sys_xattr.go
+++ b/pkg/sentry/syscalls/linux/sys_xattr.go
@@ -18,6 +18,7 @@ import (
"strings"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -182,7 +183,7 @@ func setXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSymlink
// setXattr implements setxattr(2) from the given *fs.Dirent.
func setXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr hostarch.Addr, size uint64, flags uint32) error {
if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
name, err := copyInXattrName(t, nameAddr)
@@ -195,7 +196,7 @@ func setXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr hostarch.Addr, s
}
if size > linux.XATTR_SIZE_MAX {
- return syserror.E2BIG
+ return linuxerr.E2BIG
}
buf := make([]byte, size)
if _, err := t.CopyInBytes(valueAddr, buf); err != nil {
@@ -217,7 +218,7 @@ func setXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr hostarch.Addr, s
func copyInXattrName(t *kernel.Task, nameAddr hostarch.Addr) (string, error) {
name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1)
if err != nil {
- if err == syserror.ENAMETOOLONG {
+ if linuxerr.Equals(linuxerr.ENAMETOOLONG, err) {
return "", syserror.ERANGE
}
return "", err
@@ -333,7 +334,7 @@ func listXattr(t *kernel.Task, d *fs.Dirent, addr hostarch.Addr, size uint64) (i
listSize := xattrListSize(xattrs)
if listSize > linux.XATTR_SIZE_MAX {
- return 0, syserror.E2BIG
+ return 0, linuxerr.E2BIG
}
if uint64(listSize) > requestedSize {
return 0, syserror.ERANGE
diff --git a/pkg/sentry/syscalls/linux/timespec.go b/pkg/sentry/syscalls/linux/timespec.go
index 3edc922eb..b327e27d6 100644
--- a/pkg/sentry/syscalls/linux/timespec.go
+++ b/pkg/sentry/syscalls/linux/timespec.go
@@ -18,6 +18,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
@@ -103,7 +104,7 @@ func copyTimespecInToDuration(t *kernel.Task, timespecAddr hostarch.Addr) (time.
return 0, err
}
if !timespec.Valid() {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
timeout = time.Duration(timespec.ToNsecCapped())
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index 5ce0bc714..a73f096ff 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -41,6 +41,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/bits",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fspath",
"//pkg/gohacks",
"//pkg/hostarch",
diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go
index fd1863ef3..d81df637f 100644
--- a/pkg/sentry/syscalls/linux/vfs2/aio.go
+++ b/pkg/sentry/syscalls/linux/vfs2/aio.go
@@ -17,6 +17,8 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/eventfd"
@@ -26,8 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
// IoSubmit implements linux syscall io_submit(2).
@@ -37,7 +37,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
addr := args[2].Pointer()
if nrEvents < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
for i := int32(0); i < nrEvents; i++ {
@@ -90,7 +90,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// submitCallback processes a single callback.
func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr hostarch.Addr) error {
if cb.Reserved2 != 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
fd := t.GetFileVFS2(cb.FD)
@@ -110,7 +110,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr host
// Check that it is an eventfd.
if _, ok := eventFD.Impl().(*eventfd.EventFileDescription); !ok {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
}
@@ -123,14 +123,14 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr host
switch cb.OpCode {
case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITE, linux.IOCB_CMD_PWRITEV:
if cb.Offset < 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
}
// Prepare the request.
aioCtx, ok := t.MemoryManager().LookupAIOContext(t, id)
if !ok {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if err := aioCtx.Prepare(); err != nil {
return err
@@ -200,7 +200,7 @@ func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error)
bytes := int(cb.Bytes)
if bytes < 0 {
// Linux also requires that this field fit in ssize_t.
- return usermem.IOSequence{}, syserror.EINVAL
+ return usermem.IOSequence{}, linuxerr.EINVAL
}
// Since this I/O will be asynchronous with respect to t's task goroutine,
@@ -222,6 +222,6 @@ func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error)
default:
// Not a supported command.
- return usermem.IOSequence{}, syserror.EINVAL
+ return usermem.IOSequence{}, linuxerr.EINVAL
}
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go
index 047d955b6..d3bb3a3e1 100644
--- a/pkg/sentry/syscalls/linux/vfs2/epoll.go
+++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go
@@ -19,6 +19,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -34,7 +35,7 @@ var sizeofEpollEvent = (*linux.EpollEvent)(nil).SizeBytes()
func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
flags := args[0].Int()
if flags&^linux.EPOLL_CLOEXEC != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
file, err := t.Kernel().VFS().NewEpollInstanceFD(t)
@@ -59,7 +60,7 @@ func EpollCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
// "Since Linux 2.6.8, the size argument is ignored, but must be greater
// than zero" - epoll_create(2)
if size <= 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
file, err := t.Kernel().VFS().NewEpollInstanceFD(t)
@@ -89,7 +90,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
defer epfile.DecRef(t)
ep, ok := epfile.Impl().(*vfs.EpollInstance)
if !ok {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
file := t.GetFileVFS2(fd)
if file == nil {
@@ -97,7 +98,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
defer file.DecRef(t)
if epfile == file {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
var event linux.EpollEvent
@@ -115,14 +116,14 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
return 0, nil, ep.ModifyInterest(file, fd, event)
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
}
func waitEpoll(t *kernel.Task, epfd int32, eventsAddr hostarch.Addr, maxEvents int, timeoutInNanos int64) (uintptr, *kernel.SyscallControl, error) {
var _EP_MAX_EVENTS = math.MaxInt32 / sizeofEpollEvent // Linux: fs/eventpoll.c:EP_MAX_EVENTS
if maxEvents <= 0 || maxEvents > _EP_MAX_EVENTS {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
epfile := t.GetFileVFS2(epfd)
@@ -132,7 +133,7 @@ func waitEpoll(t *kernel.Task, epfd int32, eventsAddr hostarch.Addr, maxEvents i
defer epfile.DecRef(t)
ep, ok := epfile.Impl().(*vfs.EpollInstance)
if !ok {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Allocate space for a few events on the stack for the common case in
@@ -174,7 +175,7 @@ func waitEpoll(t *kernel.Task, epfd int32, eventsAddr hostarch.Addr, maxEvents i
haveDeadline = true
}
if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
- if err == syserror.ETIMEDOUT {
+ if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
err = nil
}
return 0, nil, err
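
Editor's note: a cross-cutting pattern in this section is that the blocking layer reports timer expiry as ETIMEDOUT and each syscall translates it into its own result: epoll_wait (above) turns it into a successful return with zero events, writev in sys_write.go maps it to ErrWouldBlock, and semtimedop in sys_sem.go maps it to EAGAIN. A minimal sketch of that translation step, with hypothetical stand-ins rather than the real linuxerr values:

package main

import (
	"errors"
	"fmt"
)

// ETIMEDOUT is a hypothetical stand-in for linuxerr.ETIMEDOUT.
var ETIMEDOUT = errors.New("timed out")

// blockResult converts a timer expiry from the blocking layer into whatever
// a particular syscall wants to report; the timeout itself is not an error.
func blockResult(err error, onTimeout error) error {
	if errors.Is(err, ETIMEDOUT) {
		return onTimeout // e.g. nil for epoll_wait, EAGAIN for semtimedop
	}
	return err
}

func main() {
	fmt.Println(blockResult(ETIMEDOUT, nil))                     // <nil>
	fmt.Println(blockResult(ETIMEDOUT, errors.New("try again"))) // try again
}
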
diff --git a/pkg/sentry/syscalls/linux/vfs2/eventfd.go b/pkg/sentry/syscalls/linux/vfs2/eventfd.go
index 807f909da..0dcf1fbff 100644
--- a/pkg/sentry/syscalls/linux/vfs2/eventfd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/eventfd.go
@@ -16,10 +16,10 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/eventfd"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/syserror"
)
// Eventfd2 implements linux syscall eventfd2(2).
@@ -29,7 +29,7 @@ func Eventfd2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
allOps := uint(linux.EFD_SEMAPHORE | linux.EFD_NONBLOCK | linux.EFD_CLOEXEC)
if flags & ^allOps != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
vfsObj := t.Kernel().VFS()
diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go
index 3315398a4..7b1e1da78 100644
--- a/pkg/sentry/syscalls/linux/vfs2/execve.go
+++ b/pkg/sentry/syscalls/linux/vfs2/execve.go
@@ -16,7 +16,9 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -24,8 +26,6 @@ import (
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
// Execve implements linux syscall execve(2).
@@ -48,7 +48,7 @@ func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr hostarch.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) {
if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
pathname, err := t.CopyInString(pathnameAddr, linux.PATH_MAX)
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
index 1a31898e8..ea34ff471 100644
--- a/pkg/sentry/syscalls/linux/vfs2/fd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -16,6 +16,7 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
@@ -86,7 +87,7 @@ func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
flags := args[2].Uint()
if oldfd == newfd {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
return dup3(t, oldfd, newfd, flags)
@@ -94,7 +95,7 @@ func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
func dup3(t *kernel.Task, oldfd, newfd int32, flags uint32) (uintptr, *kernel.SyscallControl, error) {
if flags&^linux.O_CLOEXEC != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
file := t.GetFileVFS2(oldfd)
@@ -169,7 +170,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if who < 0 {
// Check for overflow before flipping the sign.
if who-1 > who {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
ownerType = linux.F_OWNER_PGRP
who = -who
@@ -232,7 +233,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, a.SetSignal(linux.Signal(args[2].Int()))
default:
// Everything else is not yet supported.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
}
@@ -269,7 +270,7 @@ func setAsyncOwner(t *kernel.Task, fd int, file *vfs.FileDescription, ownerType,
case linux.F_OWNER_TID, linux.F_OWNER_PID, linux.F_OWNER_PGRP:
// Acceptable type.
default:
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
a := file.SetAsyncHandler(fasync.NewVFS2(fd)).(*fasync.FileAsync)
@@ -301,7 +302,7 @@ func setAsyncOwner(t *kernel.Task, fd int, file *vfs.FileDescription, ownerType,
a.SetOwnerProcessGroup(t, pg)
return nil
default:
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
}
@@ -319,7 +320,7 @@ func posixTestLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDes
case linux.F_WRLCK:
typ = lock.WriteLock
default:
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
r, err := file.ComputeLockRange(t, uint64(flock.Start), uint64(flock.Len), flock.Whence)
if err != nil {
@@ -382,7 +383,7 @@ func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescrip
return file.UnlockPOSIX(t, t.FDTable(), r)
default:
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
}
@@ -395,7 +396,7 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
// Note: offset is allowed to be negative.
if length < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
file := t.GetFileVFS2(fd)
@@ -421,7 +422,7 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
case linux.POSIX_FADV_DONTNEED:
case linux.POSIX_FADV_NOREUSE:
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Sure, whatever.
diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
index 36aa1d3ae..534355237 100644
--- a/pkg/sentry/syscalls/linux/vfs2/filesystem.go
+++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
@@ -16,12 +16,12 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
// Link implements Linux syscall link(2).
@@ -43,7 +43,7 @@ func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
func linkat(t *kernel.Task, olddirfd int32, oldpathAddr hostarch.Addr, newdirfd int32, newpathAddr hostarch.Addr, flags int32) error {
if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_FOLLOW) != 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if flags&linux.AT_EMPTY_PATH != 0 && !t.HasCapability(linux.CAP_DAC_READ_SEARCH) {
return syserror.ENOENT
@@ -290,7 +290,7 @@ func Unlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
flags := args[2].Int()
if flags&^linux.AT_REMOVEDIR != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if flags&linux.AT_REMOVEDIR != 0 {
diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go
index b41a3056a..8ace31af3 100644
--- a/pkg/sentry/syscalls/linux/vfs2/getdents.go
+++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go
@@ -17,13 +17,13 @@ package vfs2
import (
"fmt"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
// Getdents implements Linux syscall getdents(2).
@@ -100,7 +100,7 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
size := 8 + 8 + 2 + 1 + 1 + len(dirent.Name)
size = (size + 7) &^ 7 // round up to multiple of 8
if size > cb.remaining {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
buf = cb.t.CopyScratchBuffer(size)
hostarch.ByteOrder.PutUint64(buf[0:8], dirent.Ino)
@@ -134,7 +134,7 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
size := 8 + 8 + 2 + 1 + 1 + len(dirent.Name)
size = (size + 7) &^ 7 // round up to multiple of sizeof(long)
if size > cb.remaining {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
buf = cb.t.CopyScratchBuffer(size)
hostarch.ByteOrder.PutUint64(buf[0:8], dirent.Ino)
diff --git a/pkg/sentry/syscalls/linux/vfs2/inotify.go b/pkg/sentry/syscalls/linux/vfs2/inotify.go
index 11753d8e5..7a2e9e75d 100644
--- a/pkg/sentry/syscalls/linux/vfs2/inotify.go
+++ b/pkg/sentry/syscalls/linux/vfs2/inotify.go
@@ -16,6 +16,7 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -28,7 +29,7 @@ const allFlags = linux.IN_NONBLOCK | linux.IN_CLOEXEC
func InotifyInit1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
flags := args[0].Int()
if flags&^allFlags != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
ino, err := vfs.NewInotifyFD(t, t.Kernel().VFS(), uint32(flags))
@@ -67,7 +68,7 @@ func fdToInotify(t *kernel.Task, fd int32) (*vfs.Inotify, *vfs.FileDescription,
if !ok {
// Not an inotify fd.
f.DecRef(t)
- return nil, nil, syserror.EINVAL
+ return nil, nil, linuxerr.EINVAL
}
return ino, f, nil
@@ -82,7 +83,7 @@ func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kern
// "EINVAL: The given event mask contains no valid events."
// -- inotify_add_watch(2)
if mask&linux.ALL_INOTIFY_BITS == 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// "IN_DONT_FOLLOW: Don't dereference pathname if it is a symbolic link."
diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
index c7c3fed57..9852e3fe4 100644
--- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go
+++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
@@ -16,6 +16,7 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -99,7 +100,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if who < 0 {
// Check for overflow before flipping the sign.
if who-1 > who {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
ownerType = linux.F_OWNER_PGRP
who = -who
diff --git a/pkg/sentry/syscalls/linux/vfs2/lock.go b/pkg/sentry/syscalls/linux/vfs2/lock.go
index d1452a04d..80cb3ba09 100644
--- a/pkg/sentry/syscalls/linux/vfs2/lock.go
+++ b/pkg/sentry/syscalls/linux/vfs2/lock.go
@@ -16,6 +16,7 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -57,7 +58,7 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
default:
// flock(2): EINVAL operation is invalid.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
return 0, nil, nil
diff --git a/pkg/sentry/syscalls/linux/vfs2/memfd.go b/pkg/sentry/syscalls/linux/vfs2/memfd.go
index c4c0f9e0a..70c2cf5a5 100644
--- a/pkg/sentry/syscalls/linux/vfs2/memfd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/memfd.go
@@ -16,10 +16,10 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/syserror"
)
const (
@@ -35,7 +35,7 @@ func MemfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
if flags&^memfdAllFlags != 0 {
// Unknown bits in flags.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
allowSeals := flags&linux.MFD_ALLOW_SEALING != 0
diff --git a/pkg/sentry/syscalls/linux/vfs2/mmap.go b/pkg/sentry/syscalls/linux/vfs2/mmap.go
index c961545f6..db8d59899 100644
--- a/pkg/sentry/syscalls/linux/vfs2/mmap.go
+++ b/pkg/sentry/syscalls/linux/vfs2/mmap.go
@@ -16,13 +16,13 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/syserror"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
// Mmap implements Linux syscall mmap(2).
@@ -38,7 +38,7 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
// Require exactly one of MAP_PRIVATE and MAP_SHARED.
if private == shared {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
opts := memmap.MMapOpts{
diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go
index dd93430e2..667e48744 100644
--- a/pkg/sentry/syscalls/linux/vfs2/mount.go
+++ b/pkg/sentry/syscalls/linux/vfs2/mount.go
@@ -16,12 +16,12 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
// Mount implements Linux syscall mount(2).
@@ -84,7 +84,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// unknown or unsupported flags are passed. Since we don't implement
// everything, we fail explicitly on flags that are unimplemented.
if flags&(unsupportedOps|unsupportedFlags) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
var opts vfs.MountOptions
@@ -130,7 +130,7 @@ func Umount2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
const unsupported = linux.MNT_FORCE | linux.MNT_EXPIRE
if flags&unsupported != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
path, err := copyInPath(t, addr)
diff --git a/pkg/sentry/syscalls/linux/vfs2/pipe.go b/pkg/sentry/syscalls/linux/vfs2/pipe.go
index c6fc1954c..07a89cf4e 100644
--- a/pkg/sentry/syscalls/linux/vfs2/pipe.go
+++ b/pkg/sentry/syscalls/linux/vfs2/pipe.go
@@ -16,14 +16,13 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/pipefs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
// Pipe implements Linux syscall pipe(2).
@@ -41,7 +40,7 @@ func Pipe2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
func pipe2(t *kernel.Task, addr hostarch.Addr, flags int32) error {
if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
r, w, err := pipefs.NewConnectedPipeFDs(t, t.Kernel().PipeMount(), uint32(flags&linux.O_NONBLOCK))
if err != nil {
diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go
index a69c80edd..ea95dd78c 100644
--- a/pkg/sentry/syscalls/linux/vfs2/poll.go
+++ b/pkg/sentry/syscalls/linux/vfs2/poll.go
@@ -19,6 +19,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -132,7 +133,7 @@ func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.
// Wait for a notification.
timeout, err = t.BlockWithTimeout(ch, haveTimeout, timeout)
if err != nil {
- if err == syserror.ETIMEDOUT {
+ if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
err = nil
}
return timeout, 0, err
@@ -161,7 +162,7 @@ func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.
// copyInPollFDs copies an array of struct pollfd unless nfds exceeds the max.
func copyInPollFDs(t *kernel.Task, addr hostarch.Addr, nfds uint) ([]linux.PollFD, error) {
if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
pfd := make([]linux.PollFD, nfds)
@@ -221,7 +222,7 @@ func CopyInFDSet(t *kernel.Task, addr hostarch.Addr, nBytes, nBitsInLastPartialB
func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs hostarch.Addr, timeout time.Duration) (uintptr, error) {
if nfds < 0 || nfds > fileCap {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Calculate the size of the fd sets (one bit per fd).
@@ -410,7 +411,7 @@ func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
func poll(t *kernel.Task, pfdAddr hostarch.Addr, nfds uint, timeout time.Duration) (uintptr, error) {
remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout)
// On an interrupt poll(2) is restarted with the remaining timeout.
- if err == syserror.EINTR {
+ if linuxerr.Equals(linuxerr.EINTR, err) {
t.SetSyscallRestartBlock(&pollRestartBlock{
pfdAddr: pfdAddr,
nfds: nfds,
@@ -462,7 +463,7 @@ func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
//
// Note that this means that if err is nil but copyErr is not, copyErr is
// ignored. This is consistent with Linux.
- if err == syserror.EINTR && copyErr == nil {
+ if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil {
err = syserror.ERESTARTNOHAND
}
return n, nil, err
@@ -484,7 +485,7 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, err
}
if timeval.Sec < 0 || timeval.Usec < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
timeout = time.Duration(timeval.ToNsecCapped())
}
@@ -492,7 +493,7 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr)
// See comment in Ppoll.
- if err == syserror.EINTR && copyErr == nil {
+ if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil {
err = syserror.ERESTARTNOHAND
}
return n, nil, err
@@ -539,7 +540,7 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
// See comment in Ppoll.
- if err == syserror.EINTR && copyErr == nil {
+ if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil {
err = syserror.ERESTARTNOHAND
}
return n, nil, err
@@ -561,7 +562,7 @@ func copyTimespecInToDuration(t *kernel.Task, timespecAddr hostarch.Addr) (time.
return 0, err
}
if !timespec.Valid() {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
timeout = time.Duration(timespec.ToNsecCapped())
}
@@ -573,7 +574,7 @@ func setTempSignalSet(t *kernel.Task, maskAddr hostarch.Addr, maskSize uint) err
return nil
}
if maskSize != linux.SignalSetSize {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
var mask linux.SignalSet
if _, err := mask.CopyIn(t, maskAddr); err != nil {
diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go
index b863d7b84..3e515f6fd 100644
--- a/pkg/sentry/syscalls/linux/vfs2/read_write.go
+++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go
@@ -18,6 +18,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -49,7 +50,7 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
// Check that the size is legitimate.
si := int(size)
if si < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get the destination of the read.
@@ -120,7 +121,7 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt
// Wait for a notification that we should retry.
if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil {
- if err == syserror.ETIMEDOUT {
+ if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
err = syserror.ErrWouldBlock
}
break
@@ -146,13 +147,13 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// Check that the offset is legitimate and does not overflow.
if offset < 0 || offset+int64(size) < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Check that the size is legitimate.
si := int(size)
if si < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get the destination of the read.
@@ -183,7 +184,7 @@ func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// Check that the offset is legitimate.
if offset < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get the destination of the read.
@@ -221,7 +222,7 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// Check that the offset is legitimate.
if offset < -1 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get the destination of the read.
@@ -275,7 +276,7 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of
// Wait for a notification that we should retry.
if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil {
- if err == syserror.ETIMEDOUT {
+ if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
err = syserror.ErrWouldBlock
}
break
@@ -300,7 +301,7 @@ func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// Check that the size is legitimate.
si := int(size)
if si < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get the source of the write.
@@ -371,7 +372,7 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op
// Wait for a notification that we should retry.
if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil {
- if err == syserror.ETIMEDOUT {
+ if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
err = syserror.ErrWouldBlock
}
break
@@ -396,13 +397,13 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Check that the offset is legitimate and does not overflow.
if offset < 0 || offset+int64(size) < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Check that the size is legitimate.
si := int(size)
if si < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get the source of the write.
@@ -433,7 +434,7 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// Check that the offset is legitimate.
if offset < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get the source of the write.
@@ -471,7 +472,7 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Check that the offset is legitimate.
if offset < -1 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get the source of the write.
@@ -525,7 +526,7 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o
// Wait for a notification that we should retry.
if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil {
- if err == syserror.ETIMEDOUT {
+ if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
err = syserror.ErrWouldBlock
}
break
@@ -587,16 +588,16 @@ func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
// Check that the size is valid.
if int(size) < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Check that the offset is legitimate and does not overflow.
if offset < 0 || offset+int64(size) < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Return EINVAL; if the underlying file type does not support readahead,
// then Linux will return EINVAL to indicate as much. In the future, we
// may extend this function to actually support readahead hints.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
index c6330c21a..0fbafd6f6 100644
--- a/pkg/sentry/syscalls/linux/vfs2/setstat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -16,15 +16,15 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
const chmodMask = 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX
@@ -105,7 +105,7 @@ func Fchownat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
func fchownat(t *kernel.Task, dirfd int32, pathAddr hostarch.Addr, owner, group, flags int32) error {
if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
path, err := copyInPath(t, pathAddr)
@@ -126,7 +126,7 @@ func populateSetStatOptionsForChown(t *kernel.Task, owner, group int32, opts *vf
if owner != -1 {
kuid := userns.MapToKUID(auth.UID(owner))
if !kuid.Ok() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
opts.Stat.Mask |= linux.STATX_UID
opts.Stat.UID = uint32(kuid)
@@ -134,7 +134,7 @@ func populateSetStatOptionsForChown(t *kernel.Task, owner, group int32, opts *vf
if group != -1 {
kgid := userns.MapToKGID(auth.GID(group))
if !kgid.Ok() {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
opts.Stat.Mask |= linux.STATX_GID
opts.Stat.GID = uint32(kgid)
@@ -167,7 +167,7 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
length := args[1].Int64()
if length < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
path, err := copyInPath(t, addr)
@@ -191,7 +191,7 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
length := args[1].Int64()
if length < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
file := t.GetFileVFS2(fd)
@@ -201,7 +201,7 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
defer file.DecRef(t)
if !file.IsWritable() {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
err := file.SetStat(t, vfs.SetStatOptions{
@@ -233,7 +233,7 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, syserror.ENOTSUP
}
if offset < 0 || length <= 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
size := offset + length
@@ -242,9 +242,9 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
limit := limits.FromContext(t).Get(limits.FileSize).Cur
if uint64(size) >= limit {
- t.SendSignal(&arch.SignalInfo{
+ t.SendSignal(&linux.SignalInfo{
Signo: int32(linux.SIGXFSZ),
- Code: arch.SignalInfoUser,
+ Code: linux.SI_USER,
})
return 0, nil, syserror.EFBIG
}
@@ -340,7 +340,7 @@ func populateSetStatOptionsForUtimes(t *kernel.Task, timesAddr hostarch.Addr, op
return err
}
if times[0].Usec < 0 || times[0].Usec > 999999 || times[1].Usec < 0 || times[1].Usec > 999999 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
opts.Stat.Mask = linux.STATX_ATIME | linux.STATX_MTIME
opts.Stat.Atime = linux.StatxTimestamp{
@@ -372,7 +372,7 @@ func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
if flags&^linux.AT_SYMLINK_NOFOLLOW != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// "If filename is NULL and dfd refers to an open file, then operate on the
@@ -405,7 +405,7 @@ func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr hostarch.Addr, o
}
if times[0].Nsec != linux.UTIME_OMIT {
if times[0].Nsec != linux.UTIME_NOW && (times[0].Nsec < 0 || times[0].Nsec > 999999999) {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
opts.Stat.Mask |= linux.STATX_ATIME
opts.Stat.Atime = linux.StatxTimestamp{
@@ -415,7 +415,7 @@ func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr hostarch.Addr, o
}
if times[1].Nsec != linux.UTIME_OMIT {
if times[1].Nsec != linux.UTIME_NOW && (times[1].Nsec < 0 || times[1].Nsec > 999999999) {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
opts.Stat.Mask |= linux.STATX_MTIME
opts.Stat.Mtime = linux.StatxTimestamp{
diff --git a/pkg/sentry/syscalls/linux/vfs2/signal.go b/pkg/sentry/syscalls/linux/vfs2/signal.go
index 6163da103..8b219cba7 100644
--- a/pkg/sentry/syscalls/linux/vfs2/signal.go
+++ b/pkg/sentry/syscalls/linux/vfs2/signal.go
@@ -16,13 +16,13 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/signalfd"
"gvisor.dev/gvisor/pkg/sentry/kernel"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
"gvisor.dev/gvisor/pkg/syserror"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
// sharedSignalfd is shared between the two calls.
@@ -35,7 +35,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset hostarch.Addr, sigsetsize u
// Always check for valid flags, even if not creating.
if flags&^(linux.SFD_NONBLOCK|linux.SFD_CLOEXEC) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Is this a change to an existing signalfd?
@@ -55,7 +55,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset hostarch.Addr, sigsetsize u
}
// Not a signalfd.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
fileFlags := uint32(linux.O_RDWR)
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
index 69f69e3af..c78c7d951 100644
--- a/pkg/sentry/syscalls/linux/vfs2/socket.go
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -18,6 +18,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -117,7 +118,7 @@ type multipleMessageHeader64 struct {
// from the untrusted address space range.
func CaptureAddress(t *kernel.Task, addr hostarch.Addr, addrlen uint32) ([]byte, error) {
if addrlen > maxAddrLen {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
addrBuf := make([]byte, addrlen)
@@ -139,7 +140,7 @@ func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr h
}
if int32(bufLen) < 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// Write the length unconditionally.
@@ -173,7 +174,7 @@ func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// Check and initialize the flags.
if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Create the new socket.
@@ -206,7 +207,7 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// Check and initialize the flags.
if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Create the socket pair.
@@ -281,7 +282,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr, flags int) (uintptr, error) {
// Check that no unsupported flags are passed in.
if flags & ^(linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Get socket from the file descriptor.
@@ -309,7 +310,7 @@ func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr,
if peerRequested {
// NOTE(magi): Linux does not give you an error if it can't
// write the data back out so neither do we.
- if err := writeAddress(t, peer, peerLen, addr, addrLen); err == syserror.EINVAL {
+ if err := writeAddress(t, peer, peerLen, addr, addrLen); linuxerr.Equals(linuxerr.EINVAL, err) {
return 0, err
}
}
@@ -425,7 +426,7 @@ func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
switch how {
case linux.SHUT_RD, linux.SHUT_WR, linux.SHUT_RDWR:
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
return 0, nil, s.Shutdown(t, int(how)).ToError()
@@ -458,7 +459,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, err
}
if optLen < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Call syscall implementation then copy both value and value len out.
@@ -534,10 +535,10 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
if optLen < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if optLen > maxOptLen {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
buf := t.CopyScratchBuffer(int(optLen))
if _, err := t.CopyInBytes(optValAddr, buf); err != nil {
@@ -616,7 +617,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if t.Arch().Width() != 8 {
// We only handle 64-bit for now.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get socket from the file descriptor.
@@ -634,7 +635,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// Reject flags that we don't handle yet.
if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
@@ -664,7 +665,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if t.Arch().Width() != 8 {
// We only handle 64-bit for now.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if vlen > linux.UIO_MAXIOV {
@@ -673,7 +674,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Reject flags that we don't handle yet.
if flags & ^(baseRecvFlags|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get socket from the file descriptor.
@@ -701,7 +702,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, err
}
if !ts.Valid() {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
deadline = t.Kernel().MonotonicClock().Now().Add(ts.ToDuration())
haveDeadline = true
@@ -833,12 +834,12 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr hostarch.Addr, fl
// recvfrom and recv syscall handlers.
func recvFrom(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags int32, namePtr hostarch.Addr, nameLenPtr hostarch.Addr) (uintptr, error) {
if int(bufLen) < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Reject flags that we don't handle yet.
if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CONFIRM) != 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Get socket from the file descriptor.
@@ -911,7 +912,7 @@ func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if t.Arch().Width() != 8 {
// We only handle 64-bit for now.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get socket from the file descriptor.
@@ -929,7 +930,7 @@ func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// Reject flags that we don't handle yet.
if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
@@ -949,7 +950,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if t.Arch().Width() != 8 {
// We only handle 64-bit for now.
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if vlen > linux.UIO_MAXIOV {
@@ -971,7 +972,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Reject flags that we don't handle yet.
if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
@@ -1077,7 +1078,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio
func sendTo(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags int32, namePtr hostarch.Addr, nameLen uint32) (uintptr, error) {
bl := int(bufLen)
if bl < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Get socket from the file descriptor.
diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go
index 19e175203..6ddc72999 100644
--- a/pkg/sentry/syscalls/linux/vfs2/splice.go
+++ b/pkg/sentry/syscalls/linux/vfs2/splice.go
@@ -18,6 +18,7 @@ import (
"io"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -46,12 +47,12 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
count = int64(kernel.MAX_RW_COUNT)
}
if count < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Check for invalid flags.
if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get file descriptions.
@@ -82,7 +83,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
inPipeFD, inIsPipe := inFile.Impl().(*pipe.VFSPipeFD)
outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD)
if !inIsPipe && !outIsPipe {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Copy in offsets.
@@ -92,13 +93,13 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, syserror.ESPIPE
}
if inFile.Options().DenyPRead {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if _, err := primitive.CopyInt64In(t, inOffsetPtr, &inOffset); err != nil {
return 0, nil, err
}
if inOffset < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
}
outOffset := int64(-1)
@@ -107,13 +108,13 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, syserror.ESPIPE
}
if outFile.Options().DenyPWrite {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if _, err := primitive.CopyInt64In(t, outOffsetPtr, &outOffset); err != nil {
return 0, nil, err
}
if outOffset < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
}
@@ -189,12 +190,12 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
count = int64(kernel.MAX_RW_COUNT)
}
if count < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Check for invalid flags.
if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Get file descriptions.
@@ -225,7 +226,7 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
inPipeFD, inIsPipe := inFile.Impl().(*pipe.VFSPipeFD)
outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD)
if !inIsPipe || !outIsPipe {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Copy data.
@@ -288,7 +289,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Verify that the outFile Append flag is not set.
if outFile.StatusFlags()&linux.O_APPEND != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Verify that inFile is a regular file or block device. This is a
@@ -298,7 +299,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, err
} else if stat.Mask&linux.STATX_TYPE == 0 ||
(stat.Mode&linux.S_IFMT != linux.S_IFREG && stat.Mode&linux.S_IFMT != linux.S_IFBLK) {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Copy offset if it exists.
@@ -314,16 +315,16 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
offset = int64(offsetP)
if offset < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if offset+count < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
}
// Validate count. This must come after offset checks.
if count < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if count == 0 {
return 0, nil, nil
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go
index 69e77fa99..8a22ed8a5 100644
--- a/pkg/sentry/syscalls/linux/vfs2/stat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/stat.go
@@ -17,15 +17,15 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
// Stat implements Linux syscall stat(2).
@@ -53,7 +53,7 @@ func Newfstatat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr hostarch.Addr, flags int32) error {
if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
opts := vfs.StatOptions{
@@ -156,15 +156,15 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
statxAddr := args[4].Pointer()
if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW|linux.AT_STATX_SYNC_TYPE) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Make sure that only one sync type option is set.
syncType := uint32(flags & linux.AT_STATX_SYNC_TYPE)
if syncType != 0 && !bits.IsPowerOfTwo32(syncType) {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if mask&linux.STATX__RESERVED != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
opts := vfs.StatOptions{
@@ -272,7 +272,7 @@ func accessAt(t *kernel.Task, dirfd int32, pathAddr hostarch.Addr, mode uint) er
// Sanity check the mode.
if mode&^(rOK|wOK|xOK) != 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
path, err := copyInPath(t, pathAddr)
@@ -315,7 +315,7 @@ func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
func readlinkat(t *kernel.Task, dirfd int32, pathAddr, bufAddr hostarch.Addr, size uint) (uintptr, *kernel.SyscallControl, error) {
if int(size) <= 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
path, err := copyInPath(t, pathAddr)
diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go
index 1f8a5878c..9344a81ce 100644
--- a/pkg/sentry/syscalls/linux/vfs2/sync.go
+++ b/pkg/sentry/syscalls/linux/vfs2/sync.go
@@ -16,6 +16,7 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
@@ -71,10 +72,10 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
// Check for negative values and overflow.
if offset < 0 || offset+nbytes < 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
if flags&^(linux.SYNC_FILE_RANGE_WAIT_BEFORE|linux.SYNC_FILE_RANGE_WRITE|linux.SYNC_FILE_RANGE_WAIT_AFTER) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
file := t.GetFileVFS2(fd)
diff --git a/pkg/sentry/syscalls/linux/vfs2/timerfd.go b/pkg/sentry/syscalls/linux/vfs2/timerfd.go
index 250870c03..0794330c6 100644
--- a/pkg/sentry/syscalls/linux/vfs2/timerfd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/timerfd.go
@@ -16,6 +16,7 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/timerfd"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -29,7 +30,7 @@ func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
flags := args[1].Int()
if flags&^(linux.TFD_CLOEXEC|linux.TFD_NONBLOCK) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
// Timerfds aren't writable per se (their implementation of Write just
@@ -47,7 +48,7 @@ func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
case linux.CLOCK_MONOTONIC, linux.CLOCK_BOOTTIME:
clock = t.Kernel().MonotonicClock()
default:
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
vfsObj := t.Kernel().VFS()
file, err := timerfd.New(t, vfsObj, clock, fileFlags)
@@ -72,7 +73,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
oldValAddr := args[3].Pointer()
if flags&^(linux.TFD_TIMER_ABSTIME) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
file := t.GetFileVFS2(fd)
@@ -83,7 +84,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
tfd, ok := file.Impl().(*timerfd.TimerFileDescription)
if !ok {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
var newVal linux.Itimerspec
@@ -117,7 +118,7 @@ func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
tfd, ok := file.Impl().(*timerfd.TimerFileDescription)
if !ok {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
tm, s := tfd.GetTime()
diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go
index c261050c6..33209a8d0 100644
--- a/pkg/sentry/syscalls/linux/vfs2/xattr.go
+++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go
@@ -18,6 +18,7 @@ import (
"bytes"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -178,7 +179,7 @@ func setxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymli
flags := args[4].Int()
if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
path, err := copyInPath(t, pathAddr)
@@ -216,7 +217,7 @@ func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
flags := args[4].Int()
if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 {
- return 0, nil, syserror.EINVAL
+ return 0, nil, linuxerr.EINVAL
}
file := t.GetFileVFS2(fd)
@@ -295,7 +296,7 @@ func Fremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
func copyInXattrName(t *kernel.Task, nameAddr hostarch.Addr) (string, error) {
name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1)
if err != nil {
- if err == syserror.ENAMETOOLONG {
+ if linuxerr.Equals(linuxerr.ENAMETOOLONG, err) {
return "", syserror.ERANGE
}
return "", err
@@ -321,7 +322,7 @@ func copyOutXattrNameList(t *kernel.Task, listAddr hostarch.Addr, size uint, nam
}
if buf.Len() > int(size) {
if size >= linux.XATTR_LIST_MAX {
- return 0, syserror.E2BIG
+ return 0, linuxerr.E2BIG
}
return 0, syserror.ERANGE
}
@@ -330,7 +331,7 @@ func copyOutXattrNameList(t *kernel.Task, listAddr hostarch.Addr, size uint, nam
func copyInXattrValue(t *kernel.Task, valueAddr hostarch.Addr, size uint) (string, error) {
if size > linux.XATTR_SIZE_MAX {
- return "", syserror.E2BIG
+ return "", linuxerr.E2BIG
}
buf := make([]byte, size)
if _, err := t.CopyInBytes(valueAddr, buf); err != nil {
@@ -349,7 +350,7 @@ func copyOutXattrValue(t *kernel.Task, valueAddr hostarch.Addr, size uint, value
}
if len(value) > int(size) {
if size >= linux.XATTR_SIZE_MAX {
- return 0, syserror.E2BIG
+ return 0, linuxerr.E2BIG
}
return 0, syserror.ERANGE
}
diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD
index 1f617ca8f..36d999c47 100644
--- a/pkg/sentry/time/BUILD
+++ b/pkg/sentry/time/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "seqatomic_parameters_unsafe.go",
package = "time",
suffix = "Parameters",
- template = "//pkg/sync:generic_seqatomic",
+ template = "//pkg/sync/seqatomic:generic_seqatomic",
types = {
"Value": "Parameters",
},
@@ -25,6 +25,8 @@ go_library(
"muldiv_arm64.s",
"parameters.go",
"sampler.go",
+ "sampler_amd64.go",
+ "sampler_arm64.go",
"sampler_unsafe.go",
"seqatomic_parameters_unsafe.go",
"tsc_amd64.s",
@@ -32,6 +34,7 @@ go_library(
],
visibility = ["//:sandbox"],
deps = [
+ "//pkg/errors/linuxerr",
"//pkg/gohacks",
"//pkg/log",
"//pkg/metric",
diff --git a/pkg/sentry/time/calibrated_clock.go b/pkg/sentry/time/calibrated_clock.go
index 39bf1e0de..eed74f6bd 100644
--- a/pkg/sentry/time/calibrated_clock.go
+++ b/pkg/sentry/time/calibrated_clock.go
@@ -19,10 +19,10 @@ package time
import (
"time"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
// CalibratedClock implements a clock that tracks a reference clock.
@@ -259,6 +259,6 @@ func (c *CalibratedClocks) GetTime(id ClockID) (int64, error) {
case Realtime:
return c.realtime.GetTime()
default:
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
}
diff --git a/pkg/sentry/time/sampler.go b/pkg/sentry/time/sampler.go
index 4ac9c4474..24a47f5d5 100644
--- a/pkg/sentry/time/sampler.go
+++ b/pkg/sentry/time/sampler.go
@@ -21,13 +21,6 @@ import (
)
const (
- // defaultOverheadTSC is the default estimated syscall overhead in TSC cycles.
- // It is further refined as syscalls are made.
- defaultOverheadCycles = 1 * 1000
-
- // maxOverheadCycles is the maximum allowed syscall overhead in TSC cycles.
- maxOverheadCycles = 100 * defaultOverheadCycles
-
// maxSampleLoops is the maximum number of times to try to get a clock sample
// under the expected overhead.
maxSampleLoops = 5
diff --git a/pkg/sentry/time/sampler_amd64.go b/pkg/sentry/time/sampler_amd64.go
new file mode 100644
index 000000000..9f1b4b2fb
--- /dev/null
+++ b/pkg/sentry/time/sampler_amd64.go
@@ -0,0 +1,26 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//+build amd64
+
+package time
+
+const (
+ // defaultOverheadCycles is the default estimated syscall overhead in TSC cycles.
+ // It is further refined as syscalls are made.
+ defaultOverheadCycles = 1 * 1000
+
+ // maxOverheadCycles is the maximum allowed syscall overhead in TSC cycles.
+ maxOverheadCycles = 100 * defaultOverheadCycles
+)
diff --git a/pkg/sentry/time/sampler_arm64.go b/pkg/sentry/time/sampler_arm64.go
new file mode 100644
index 000000000..4c8d33ae4
--- /dev/null
+++ b/pkg/sentry/time/sampler_arm64.go
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//+build arm64
+
+package time
+
+// getCNTFRQ returns the ARM counter-timer frequency.
+func getCNTFRQ() TSCValue
+
+// getDefaultArchOverheadCycles returns the default overhead cycles based on
+// the ARM counter-timer frequency. The ARM counter-timer frequency usually
+// ranges from 1-50MHz, which is much lower than on x86, so we calibrate
+// defaultOverheadCycles for ARM accordingly.
+func getDefaultArchOverheadCycles() TSCValue {
+ // The estimated clock frequency on x86 is 1GHz. Dividing 1GHz by the
+ // ARM counter-timer frequency gives frqRatio; defaultOverheadCycles on
+ // ARM is the x86 value divided by frqRatio.
+ cntfrq := getCNTFRQ()
+ frqRatio := 1000000000 / cntfrq
+ overheadCycles := (1 * 1000) / frqRatio
+ return overheadCycles
+}
+
+// defaultOverheadCycles is the default estimated syscall overhead in TSC cycles.
+// It is further refined as syscalls are made.
+var defaultOverheadCycles = getDefaultArchOverheadCycles()
+
+// maxOverheadCycles is the maximum allowed syscall overhead in TSC cycles.
+var maxOverheadCycles = 100 * defaultOverheadCycles
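The calibration above scales the x86 overhead estimate by the ratio of an assumed 1GHz x86 clock to the ARM counter-timer frequency. A small worked example follows, assuming a hypothetical 25MHz counter-timer; on real hardware the value is read from CNTFRQ_EL0.

    package main

    import "fmt"

    // Worked example of the arm64 calibration in the hunk above, using an
    // assumed counter-timer frequency.
    func main() {
    	const x86EstimatedHz = 1000000000  // 1GHz, the estimate used in the comment
    	const x86OverheadCycles = 1 * 1000 // defaultOverheadCycles on x86
    	cntfrq := int64(25000000)          // assume a 25MHz ARM counter-timer

    	frqRatio := int64(x86EstimatedHz) / cntfrq            // 40
    	overheadCycles := int64(x86OverheadCycles) / frqRatio // 25
    	fmt.Println("default:", overheadCycles, "max:", 100*overheadCycles)
    }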
diff --git a/pkg/sentry/time/tsc_arm64.s b/pkg/sentry/time/tsc_arm64.s
index da9fa4112..711349fa1 100644
--- a/pkg/sentry/time/tsc_arm64.s
+++ b/pkg/sentry/time/tsc_arm64.s
@@ -20,3 +20,9 @@ TEXT ·Rdtsc(SB),NOSPLIT,$0-8
WORD $0xd53be040 //MRS CNTVCT_EL0, R0
MOVD R0, ret+0(FP)
RET
+
+TEXT ·getCNTFRQ(SB),NOSPLIT,$0-8
+ // Get the virtual counter frequency.
+ WORD $0xd53be000 //MRS CNTFRQ_EL0, R0
+ MOVD R0, ret+0(FP)
+ RET
diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go
index 581862ee2..e7073ec87 100644
--- a/pkg/sentry/usage/memory.go
+++ b/pkg/sentry/usage/memory.go
@@ -132,7 +132,7 @@ func Init() error {
// always be the case for a newly mapped page from /dev/shm. If we obtain
// the shared memory through some other means in the future, we may have to
// explicitly zero the page.
- mmap, err := unix.Mmap(int(file.Fd()), 0, int(RTMemoryStatsSize), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED)
+ mmap, err := memutil.MapFile(0, RTMemoryStatsSize, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED, file.Fd(), 0)
if err != nil {
return fmt.Errorf("error mapping usage file: %v", err)
}
diff --git a/pkg/sentry/usage/memory_unsafe.go b/pkg/sentry/usage/memory_unsafe.go
index 9e0014ca0..bc1531b91 100644
--- a/pkg/sentry/usage/memory_unsafe.go
+++ b/pkg/sentry/usage/memory_unsafe.go
@@ -21,7 +21,7 @@ import (
// RTMemoryStatsSize is the size of the RTMemoryStats struct.
var RTMemoryStatsSize = unsafe.Sizeof(RTMemoryStats{})
-// RTMemoryStatsPointer casts the address of the byte slice into a RTMemoryStats pointer.
-func RTMemoryStatsPointer(b []byte) *RTMemoryStats {
- return (*RTMemoryStats)(unsafe.Pointer(&b[0]))
+// RTMemoryStatsPointer casts addr to an RTMemoryStats pointer.
+func RTMemoryStatsPointer(addr uintptr) *RTMemoryStats {
+ return (*RTMemoryStats)(unsafe.Pointer(addr))
}
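The two usage-package hunks above move from unix.Mmap, which returns a []byte, to an address-returning mapping that RTMemoryStatsPointer casts directly. The standalone sketch below illustrates that address-based pattern; the mapFile helper here is a stand-in built on unix.Syscall6 and is not the real pkg/memutil.MapFile implementation.

    package main

    import (
    	"fmt"
    	"os"
    	"unsafe"

    	"golang.org/x/sys/unix"
    )

    // rtMemoryStats mirrors the idea of a stats struct living in a shared mapping.
    type rtMemoryStats struct {
    	RTMapped uint64
    }

    // mapFile is a stand-in for an address-returning mapping helper: it mmaps
    // fd at offset and hands back the raw address for the caller to cast.
    func mapFile(size uintptr, prot, flags int, fd uintptr, offset int64) (uintptr, error) {
    	addr, _, errno := unix.Syscall6(unix.SYS_MMAP, 0, size, uintptr(prot), uintptr(flags), fd, uintptr(offset))
    	if errno != 0 {
    		return 0, errno
    	}
    	return addr, nil
    }

    func main() {
    	f, err := os.CreateTemp("", "stats")
    	if err != nil {
    		panic(err)
    	}
    	defer os.Remove(f.Name())
    	defer f.Close()
    	size := unsafe.Sizeof(rtMemoryStats{})
    	if err := f.Truncate(int64(size)); err != nil {
    		panic(err)
    	}
    	addr, err := mapFile(size, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED, f.Fd(), 0)
    	if err != nil {
    		panic(err)
    	}
    	stats := (*rtMemoryStats)(unsafe.Pointer(addr))
    	stats.RTMapped = 42
    	fmt.Println(stats.RTMapped)
    }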
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index ac60fe8bf..a2032162d 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -95,6 +95,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fd",
"//pkg/fdnotifier",
"//pkg/fspath",
@@ -133,6 +134,7 @@ go_test(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/sentry/contexttest",
"//pkg/sync",
"//pkg/syserror",
diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go
index f48817132..8b3612200 100644
--- a/pkg/sentry/vfs/anonfs.go
+++ b/pkg/sentry/vfs/anonfs.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -164,7 +165,7 @@ func (fs *anonFilesystem) ReadlinkAt(ctx context.Context, rp *ResolvingPath) (st
if !rp.Done() {
return "", syserror.ENOTDIR
}
- return "", syserror.EINVAL
+ return "", linuxerr.EINVAL
}
// RenameAt implements FilesystemImpl.RenameAt.
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index ef8d8a813..63ee0aab3 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsmetric"
@@ -273,7 +274,7 @@ func (fd *FileDescription) SetStatusFlags(ctx context.Context, creds *auth.Crede
}
}
if flags&linux.O_DIRECT != 0 && !fd.opts.AllowDirectIO {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// TODO(gvisor.dev/issue/1035): FileDescriptionImpl.SetOAsync()?
const settableFlags = linux.O_APPEND | linux.O_ASYNC | linux.O_DIRECT | linux.O_NOATIME | linux.O_NONBLOCK
@@ -708,8 +709,8 @@ func (fd *FileDescription) ListXattr(ctx context.Context, size uint64) ([]string
return names, err
}
names, err := fd.impl.ListXattr(ctx, size)
- if err == syserror.ENOTSUP {
- // Linux doesn't actually return ENOTSUP in this case; instead,
+ if linuxerr.Equals(linuxerr.EOPNOTSUPP, err) {
+ // Linux doesn't actually return EOPNOTSUPP in this case; instead,
// fs/xattr.c:vfs_listxattr() falls back to allowing the security
// subsystem to return security extended attributes, which by default
// don't exist.
@@ -873,7 +874,7 @@ func (fd *FileDescription) ComputeLockRange(ctx context.Context, start uint64, l
}
off = int64(stat.Size)
default:
- return lock.LockRange{}, syserror.EINVAL
+ return lock.LockRange{}, linuxerr.EINVAL
}
return lock.ComputeRange(int64(start), int64(length), off)
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index 2b6f47b4b..cffb46aab 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -88,25 +89,25 @@ func (FileDescriptionDefaultImpl) EventUnregister(e *waiter.Entry) {
// PRead implements FileDescriptionImpl.PRead analogously to
// file_operations::read == file_operations::read_iter == NULL in Linux.
func (FileDescriptionDefaultImpl) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Read implements FileDescriptionImpl.Read analogously to
// file_operations::read == file_operations::read_iter == NULL in Linux.
func (FileDescriptionDefaultImpl) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// PWrite implements FileDescriptionImpl.PWrite analogously to
// file_operations::write == file_operations::write_iter == NULL in Linux.
func (FileDescriptionDefaultImpl) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Write implements FileDescriptionImpl.Write analogously to
// file_operations::write == file_operations::write_iter == NULL in Linux.
func (FileDescriptionDefaultImpl) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// IterDirents implements FileDescriptionImpl.IterDirents analogously to
@@ -125,7 +126,7 @@ func (FileDescriptionDefaultImpl) Seek(ctx context.Context, offset int64, whence
// Sync implements FileDescriptionImpl.Sync analogously to
// file_operations::fsync == NULL in Linux.
func (FileDescriptionDefaultImpl) Sync(ctx context.Context) error {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// ConfigureMMap implements FileDescriptionImpl.ConfigureMMap analogously to
@@ -333,10 +334,10 @@ func (fd *DynamicBytesFileDescriptionImpl) Seek(ctx context.Context, offset int6
offset += fd.off
default:
// fs/seq_file:seq_lseek() rejects SEEK_END etc.
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
if offset != fd.lastRead {
// Regenerate the file's contents immediately. Compare
diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go
index 1cd607c0a..566ad856a 100644
--- a/pkg/sentry/vfs/file_description_impl_util_test.go
+++ b/pkg/sentry/vfs/file_description_impl_util_test.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -155,10 +156,10 @@ func TestGenCountFD(t *testing.T) {
}
// Write and PWrite fails.
- if _, err := fd.Write(ctx, ioseq, WriteOptions{}); err != syserror.EIO {
+ if _, err := fd.Write(ctx, ioseq, WriteOptions{}); !linuxerr.Equals(linuxerr.EIO, err) {
t.Errorf("Write: got err %v, wanted %v", err, syserror.EIO)
}
- if _, err := fd.PWrite(ctx, ioseq, 0, WriteOptions{}); err != syserror.EIO {
+ if _, err := fd.PWrite(ctx, ioseq, 0, WriteOptions{}); !linuxerr.Equals(linuxerr.EIO, err) {
t.Errorf("Write: got err %v, wanted %v", err, syserror.EIO)
}
}
@@ -215,10 +216,10 @@ func TestWritable(t *testing.T) {
if n, err := fd.Seek(ctx, 1, linux.SEEK_SET); n != 0 && err != nil {
t.Errorf("Seek: got err (%v, %v), wanted (0, nil)", n, err)
}
- if n, err := fd.Write(ctx, writeIOSeq, WriteOptions{}); n != 0 && err != syserror.EINVAL {
+ if n, err := fd.Write(ctx, writeIOSeq, WriteOptions{}); n != 0 && !linuxerr.Equals(linuxerr.EINVAL, err) {
t.Errorf("Write: got err (%v, %v), wanted (0, EINVAL)", n, err)
}
- if n, err := fd.PWrite(ctx, writeIOSeq, 2, WriteOptions{}); n != 0 && err != syserror.EINVAL {
+ if n, err := fd.PWrite(ctx, writeIOSeq, 2, WriteOptions{}); n != 0 && !linuxerr.Equals(linuxerr.EINVAL, err) {
t.Errorf("PWrite: got err (%v, %v), wanted (0, EINVAL)", n, err)
}
}
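
These test changes are where the error migration is most visible: direct equality against syserror sentinels is replaced with linuxerr.Equals, which is meant to match errors by errno (covering both the new linuxerr values and the legacy errno-based ones) rather than by pointer identity. Below is a minimal, self-contained sketch of that comparison, assuming the gvisor.dev/gvisor module is importable; it is an illustration, not code from this change:

    package main

    import (
        "fmt"

        "gvisor.dev/gvisor/pkg/errors/linuxerr"
    )

    func main() {
        // linuxerr errors are singleton values, but callers compare by errno
        // via Equals rather than with ==, since the same errno can reach them
        // in more than one representation during the migration.
        var err error = linuxerr.EIO
        fmt.Println(linuxerr.Equals(linuxerr.EIO, err))    // true: same errno
        fmt.Println(linuxerr.Equals(linuxerr.EINVAL, err)) // false: different errno
    }
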
diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go
index 49d29e20b..a7655bbb5 100644
--- a/pkg/sentry/vfs/inotify.go
+++ b/pkg/sentry/vfs/inotify.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
@@ -98,7 +99,7 @@ func NewInotifyFD(ctx context.Context, vfsObj *VirtualFilesystem, flags uint32)
// O_CLOEXEC affects file descriptors, so it must be handled outside of vfs.
flags &^= linux.O_CLOEXEC
if flags&^linux.O_NONBLOCK != 0 {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
id := uniqueid.GlobalFromContext(ctx)
@@ -200,7 +201,7 @@ func (*Inotify) Write(ctx context.Context, src usermem.IOSequence, opts WriteOpt
// Read implements FileDescriptionImpl.Read.
func (i *Inotify) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) {
if dst.NumBytes() < inotifyEventBaseSize {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
i.evMu.Lock()
@@ -226,7 +227,7 @@ func (i *Inotify) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOpt
// write some events out.
return writeLen, nil
}
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
// Linux always dequeues an available event as long as there's enough
@@ -360,7 +361,7 @@ func (i *Inotify) RmWatch(ctx context.Context, wd int32) error {
w, ok := i.watches[wd]
if !ok {
i.mu.Unlock()
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// Remove the watch from this instance.
diff --git a/pkg/sentry/vfs/memxattr/BUILD b/pkg/sentry/vfs/memxattr/BUILD
index d8c4d27b9..ea82f4987 100644
--- a/pkg/sentry/vfs/memxattr/BUILD
+++ b/pkg/sentry/vfs/memxattr/BUILD
@@ -8,6 +8,7 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
"//pkg/sync",
"//pkg/syserror",
diff --git a/pkg/sentry/vfs/memxattr/xattr.go b/pkg/sentry/vfs/memxattr/xattr.go
index 638b5d830..9b7953fa3 100644
--- a/pkg/sentry/vfs/memxattr/xattr.go
+++ b/pkg/sentry/vfs/memxattr/xattr.go
@@ -17,7 +17,10 @@
package memxattr
import (
+ "strings"
+
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -26,6 +29,9 @@ import (
// SimpleExtendedAttributes implements extended attributes using a map of
// names to values.
//
+// SimpleExtendedAttributes calls vfs.CheckXattrPermissions, so callers are not
+// required to do so.
+//
// +stateify savable
type SimpleExtendedAttributes struct {
// mu protects the below fields.
@@ -34,7 +40,11 @@ type SimpleExtendedAttributes struct {
}
// GetXattr returns the value at 'name'.
-func (x *SimpleExtendedAttributes) GetXattr(opts *vfs.GetXattrOptions) (string, error) {
+func (x *SimpleExtendedAttributes) GetXattr(creds *auth.Credentials, mode linux.FileMode, kuid auth.KUID, opts *vfs.GetXattrOptions) (string, error) {
+ if err := vfs.CheckXattrPermissions(creds, vfs.MayRead, mode, kuid, opts.Name); err != nil {
+ return "", err
+ }
+
x.mu.RLock()
value, ok := x.xattrs[opts.Name]
x.mu.RUnlock()
@@ -50,7 +60,11 @@ func (x *SimpleExtendedAttributes) GetXattr(opts *vfs.GetXattrOptions) (string,
}
// SetXattr sets 'value' at 'name'.
-func (x *SimpleExtendedAttributes) SetXattr(opts *vfs.SetXattrOptions) error {
+func (x *SimpleExtendedAttributes) SetXattr(creds *auth.Credentials, mode linux.FileMode, kuid auth.KUID, opts *vfs.SetXattrOptions) error {
+ if err := vfs.CheckXattrPermissions(creds, vfs.MayWrite, mode, kuid, opts.Name); err != nil {
+ return err
+ }
+
x.mu.Lock()
defer x.mu.Unlock()
if x.xattrs == nil {
@@ -73,12 +87,19 @@ func (x *SimpleExtendedAttributes) SetXattr(opts *vfs.SetXattrOptions) error {
}
// ListXattr returns all names in xattrs.
-func (x *SimpleExtendedAttributes) ListXattr(size uint64) ([]string, error) {
+func (x *SimpleExtendedAttributes) ListXattr(creds *auth.Credentials, size uint64) ([]string, error) {
// Keep track of the size of the buffer needed in listxattr(2) for the list.
listSize := 0
x.mu.RLock()
names := make([]string, 0, len(x.xattrs))
+ haveCap := creds.HasCapability(linux.CAP_SYS_ADMIN)
for n := range x.xattrs {
+ // Hide extended attributes in the "trusted" namespace from
+ // non-privileged users. This is consistent with Linux's
+ // fs/xattr.c:simple_xattr_list().
+ if !haveCap && strings.HasPrefix(n, linux.XATTR_TRUSTED_PREFIX) {
+ continue
+ }
names = append(names, n)
// Add one byte per null terminator.
listSize += len(n) + 1
@@ -91,7 +112,11 @@ func (x *SimpleExtendedAttributes) ListXattr(size uint64) ([]string, error) {
}
// RemoveXattr removes the xattr at 'name'.
-func (x *SimpleExtendedAttributes) RemoveXattr(name string) error {
+func (x *SimpleExtendedAttributes) RemoveXattr(creds *auth.Credentials, mode linux.FileMode, kuid auth.KUID, name string) error {
+ if err := vfs.CheckXattrPermissions(creds, vfs.MayWrite, mode, kuid, name); err != nil {
+ return err
+ }
+
x.mu.Lock()
defer x.mu.Unlock()
if _, ok := x.xattrs[name]; !ok {
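
For reference, the "trusted" namespace filtering added to ListXattr above can be exercised on its own. The sketch below reimplements just that rule, with a local constant standing in for linux.XATTR_TRUSTED_PREFIX ("trusted."); it illustrates the behavior and is not code from this change:

    package main

    import (
        "fmt"
        "strings"
    )

    // xattrTrustedPrefix stands in for linux.XATTR_TRUSTED_PREFIX.
    const xattrTrustedPrefix = "trusted."

    // visibleNames applies the same rule as ListXattr: names in the "trusted"
    // namespace are hidden unless the caller holds CAP_SYS_ADMIN.
    func visibleNames(names []string, hasSysAdmin bool) []string {
        out := make([]string, 0, len(names))
        for _, n := range names {
            if !hasSysAdmin && strings.HasPrefix(n, xattrTrustedPrefix) {
                continue
            }
            out = append(out, n)
        }
        return out
    }

    func main() {
        names := []string{"user.note", "trusted.overlay.opaque"}
        fmt.Println(visibleNames(names, false)) // [user.note]
        fmt.Println(visibleNames(names, true))  // [user.note trusted.overlay.opaque]
    }
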
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 82fd382c2..03857dfc8 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/syserror"
@@ -220,7 +221,7 @@ func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Cr
vdDentry := vd.dentry
vdDentry.mu.Lock()
for {
- if vdDentry.dead {
+ if vd.mount.umounted || vdDentry.dead {
vdDentry.mu.Unlock()
vfs.mountMu.Unlock()
vd.DecRef(ctx)
@@ -284,7 +285,7 @@ func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentia
// UmountAt removes the Mount at the given path.
func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *UmountOptions) error {
if opts.Flags&^(linux.MNT_FORCE|linux.MNT_DETACH) != 0 {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// MNT_FORCE is currently unimplemented except for the permission check.
@@ -301,19 +302,19 @@ func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credenti
}
defer vd.DecRef(ctx)
if vd.dentry != vd.mount.root {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
vfs.mountMu.Lock()
if mntns := MountNamespaceFromContext(ctx); mntns != nil {
defer mntns.DecRef(ctx)
if mntns != vd.mount.ns {
vfs.mountMu.Unlock()
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
if vd.mount == vd.mount.ns.root {
vfs.mountMu.Unlock()
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
}
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 87fdcf403..f2aabb905 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -42,6 +42,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -285,7 +286,7 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential
if newpop.FollowFinalSymlink {
oldVD.DecRef(ctx)
ctx.Warningf("VirtualFilesystem.LinkAt: file creation paths can't follow final symlink")
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
rp := vfs.getResolvingPath(creds, newpop)
@@ -321,7 +322,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia
}
if pop.FollowFinalSymlink {
ctx.Warningf("VirtualFilesystem.MkdirAt: file creation paths can't follow final symlink")
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// "Under Linux, apart from the permission bits, the S_ISVTX mode bit is
// also honored." - mkdir(2)
@@ -359,7 +360,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia
}
if pop.FollowFinalSymlink {
ctx.Warningf("VirtualFilesystem.MknodAt: file creation paths can't follow final symlink")
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
rp := vfs.getResolvingPath(creds, pop)
@@ -402,13 +403,13 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
// filesystem implementations that do not support it).
if opts.Flags&linux.O_TMPFILE != 0 {
if opts.Flags&linux.O_DIRECTORY == 0 {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
if opts.Flags&linux.O_CREAT != 0 {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
if opts.Flags&linux.O_ACCMODE == linux.O_RDONLY {
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
}
// O_PATH causes most other flags to be ignored.
@@ -499,7 +500,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti
}
if oldpop.FollowFinalSymlink {
ctx.Warningf("VirtualFilesystem.RenameAt: source path can't follow final symlink")
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
oldParentVD, oldName, err := vfs.getParentDirAndName(ctx, creds, oldpop)
@@ -521,7 +522,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti
if newpop.FollowFinalSymlink {
oldParentVD.DecRef(ctx)
ctx.Warningf("VirtualFilesystem.RenameAt: destination path can't follow final symlink")
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
rp := vfs.getResolvingPath(creds, newpop)
@@ -561,7 +562,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia
}
if pop.FollowFinalSymlink {
ctx.Warningf("VirtualFilesystem.RmdirAt: file deletion paths can't follow final symlink")
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
rp := vfs.getResolvingPath(creds, pop)
@@ -644,7 +645,7 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent
}
if pop.FollowFinalSymlink {
ctx.Warningf("VirtualFilesystem.SymlinkAt: file creation paths can't follow final symlink")
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
rp := vfs.getResolvingPath(creds, pop)
@@ -678,7 +679,7 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti
}
if pop.FollowFinalSymlink {
ctx.Warningf("VirtualFilesystem.UnlinkAt: file deletion paths can't follow final symlink")
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
rp := vfs.getResolvingPath(creds, pop)
@@ -731,8 +732,8 @@ func (vfs *VirtualFilesystem) ListXattrAt(ctx context.Context, creds *auth.Crede
rp.Release(ctx)
return names, nil
}
- if err == syserror.ENOTSUP {
- // Linux doesn't actually return ENOTSUP in this case; instead,
+ if linuxerr.Equals(linuxerr.EOPNOTSUPP, err) {
+ // Linux doesn't actually return EOPNOTSUPP in this case; instead,
// fs/xattr.c:vfs_listxattr() falls back to allowing the security
// subsystem to return security extended attributes, which by
// default don't exist.
@@ -830,14 +831,14 @@ func (vfs *VirtualFilesystem) MkdirAllAt(ctx context.Context, currentPath string
Path: fspath.Parse(currentPath),
}
stat, err := vfs.StatAt(ctx, creds, pop, &StatOptions{Mask: linux.STATX_TYPE})
- switch err {
- case nil:
+ switch {
+ case err == nil:
if stat.Mask&linux.STATX_TYPE == 0 || stat.Mode&linux.FileTypeMask != linux.ModeDirectory {
return syserror.ENOTDIR
}
// Directory already exists.
return nil
- case syserror.ENOENT:
+ case linuxerr.Equals(linuxerr.ENOENT, err):
// Expected, we will create the dir.
default:
return fmt.Errorf("stat failed for %q during directory creation: %w", currentPath, err)
@@ -871,7 +872,7 @@ func (vfs *VirtualFilesystem) MakeSyntheticMountpoint(ctx context.Context, targe
Root: root,
Start: root,
Path: fspath.Parse(target),
- }, mkdirOpts); err != nil && err != syserror.EEXIST {
+ }, mkdirOpts); err != nil && !linuxerr.Equals(linuxerr.EEXIST, err) {
return fmt.Errorf("failed to create mountpoint %q: %w", target, err)
}
return nil
diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go
index dfe85f31d..e8f7d1f01 100644
--- a/pkg/sentry/watchdog/watchdog.go
+++ b/pkg/sentry/watchdog/watchdog.go
@@ -77,11 +77,6 @@ var DefaultOpts = Opts{
// trigger it.
const descheduleThreshold = 1 * time.Second
-var (
- stuckStartup = metric.MustCreateNewUint64Metric("/watchdog/stuck_startup_detected", true /* sync */, "Incremented once on startup watchdog timeout")
- stuckTasks = metric.MustCreateNewUint64Metric("/watchdog/stuck_tasks_detected", true /* sync */, "Cumulative count of stuck tasks detected")
-)
-
// Amount of time to wait before dumping the stack to the log again when the same task(s) remains stuck.
var stackDumpSameTaskPeriod = time.Minute
@@ -115,14 +110,14 @@ func (a *Action) Get() interface{} {
}
// String returns Action's string representation.
-func (a *Action) String() string {
- switch *a {
+func (a Action) String() string {
+ switch a {
case LogWarning:
return "logWarning"
case Panic:
return "panic"
default:
- panic(fmt.Sprintf("Invalid watchdog action: %d", *a))
+ panic(fmt.Sprintf("Invalid watchdog action: %d", a))
}
}
@@ -242,7 +237,6 @@ func (w *Watchdog) waitForStart() {
return
}
- stuckStartup.Increment()
metric.WeirdnessMetric.Increment("watchdog_stuck_startup")
var buf bytes.Buffer
@@ -316,7 +310,6 @@ func (w *Watchdog) runTurn() {
// unless they are surrounded by
// Task.UninterruptibleSleepStart/Finish.
tc = &offender{lastUpdateTime: lastUpdateTime}
- stuckTasks.Increment()
metric.WeirdnessMetric.Increment("watchdog_stuck_tasks")
newTaskFound = true
}
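
The receiver change on Action.String is subtle but deliberate: with a pointer receiver, only *Action satisfies fmt.Stringer, so formatting an Action value with %v prints the raw integer instead of its name. A standalone illustration with a toy type (not the watchdog package itself):

    package main

    import "fmt"

    // action is a toy stand-in for the watchdog's Action type.
    type action int

    const (
        logWarning action = iota
        panicAction
    )

    // String uses a value receiver, so both action values and *action pointers
    // satisfy fmt.Stringer. With a pointer receiver, fmt.Println(logWarning)
    // would fall back to printing the underlying integer (0).
    func (a action) String() string {
        switch a {
        case logWarning:
            return "logWarning"
        case panicAction:
            return "panic"
        default:
            return fmt.Sprintf("action(%d)", int(a))
        }
    }

    func main() {
        a := logWarning
        fmt.Println(a)  // logWarning
        fmt.Println(&a) // logWarning: pointers can call value-receiver methods too
    }
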
diff --git a/pkg/shim/BUILD b/pkg/shim/BUILD
index fd6127b97..367765209 100644
--- a/pkg/shim/BUILD
+++ b/pkg/shim/BUILD
@@ -6,6 +6,7 @@ go_library(
name = "shim",
srcs = [
"api.go",
+ "debug.go",
"epoll.go",
"options.go",
"service.go",
diff --git a/pkg/shim/debug.go b/pkg/shim/debug.go
new file mode 100644
index 000000000..49f01990e
--- /dev/null
+++ b/pkg/shim/debug.go
@@ -0,0 +1,48 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package shim
+
+import (
+ "os"
+ "os/signal"
+ "runtime"
+ "sync"
+ "syscall"
+
+ "github.com/containerd/containerd/log"
+)
+
+var once sync.Once
+
+func setDebugSigHandler() {
+ once.Do(func() {
+ dumpCh := make(chan os.Signal, 1)
+ signal.Notify(dumpCh, syscall.SIGUSR2)
+ go func() {
+ buf := make([]byte, 10240)
+ for range dumpCh {
+ for {
+ n := runtime.Stack(buf, true)
+ if n >= len(buf) {
+ buf = make([]byte, 2*len(buf))
+ continue
+ }
+ log.L.Debugf("User requested stack trace:\n%s", buf[:n])
+ }
+ }
+ }()
+ log.L.Debugf("For full process dump run: kill -%d %d", syscall.SIGUSR2, os.Getpid())
+ })
+}
diff --git a/pkg/shim/proc/BUILD b/pkg/shim/proc/BUILD
index 544bdc170..c8527a6d9 100644
--- a/pkg/shim/proc/BUILD
+++ b/pkg/shim/proc/BUILD
@@ -20,7 +20,9 @@ go_library(
"//shim:__subpackages__",
],
deps = [
+ "//pkg/cleanup",
"//pkg/shim/runsc",
+ "//pkg/shim/utils",
"@com_github_containerd_console//:go_default_library",
"@com_github_containerd_containerd//errdefs:go_default_library",
"@com_github_containerd_containerd//log:go_default_library",
diff --git a/pkg/shim/proc/deleted_state.go b/pkg/shim/proc/deleted_state.go
index d9b970c4d..b0bbe4d7e 100644
--- a/pkg/shim/proc/deleted_state.go
+++ b/pkg/shim/proc/deleted_state.go
@@ -22,28 +22,38 @@ import (
"github.com/containerd/console"
"github.com/containerd/containerd/errdefs"
"github.com/containerd/containerd/pkg/process"
+ runc "github.com/containerd/go-runc"
)
type deletedState struct{}
-func (*deletedState) Resize(ws console.WinSize) error {
- return fmt.Errorf("cannot resize a deleted process.ss")
+func (*deletedState) Resize(console.WinSize) error {
+ return fmt.Errorf("cannot resize a deleted container/process")
}
-func (*deletedState) Start(ctx context.Context) error {
- return fmt.Errorf("cannot start a deleted process.ss")
+func (*deletedState) Start(context.Context) error {
+ return fmt.Errorf("cannot start a deleted container/process")
}
-func (*deletedState) Delete(ctx context.Context) error {
- return fmt.Errorf("cannot delete a deleted process.ss: %w", errdefs.ErrNotFound)
+func (*deletedState) Delete(context.Context) error {
+ return fmt.Errorf("cannot delete a deleted container/process: %w", errdefs.ErrNotFound)
}
-func (*deletedState) Kill(ctx context.Context, sig uint32, all bool) error {
- return fmt.Errorf("cannot kill a deleted process.ss: %w", errdefs.ErrNotFound)
+func (*deletedState) Kill(_ context.Context, signal uint32, _ bool) error {
+ return handleStoppedKill(signal)
}
-func (*deletedState) SetExited(status int) {}
+func (*deletedState) SetExited(int) {}
-func (*deletedState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
+func (*deletedState) Exec(context.Context, string, *ExecConfig) (process.Process, error) {
return nil, fmt.Errorf("cannot exec in a deleted state")
}
+
+func (s *deletedState) State(context.Context) (string, error) {
+ // There is no "deleted" state, closest one is stopped.
+ return "stopped", nil
+}
+
+func (s *deletedState) Stats(context.Context, string) (*runc.Stats, error) {
+ return nil, fmt.Errorf("cannot stat a stopped container/process")
+}
diff --git a/pkg/shim/proc/exec.go b/pkg/shim/proc/exec.go
index e7968d9d5..e0f2ae6fa 100644
--- a/pkg/shim/proc/exec.go
+++ b/pkg/shim/proc/exec.go
@@ -26,11 +26,13 @@ import (
"github.com/containerd/console"
"github.com/containerd/containerd/errdefs"
+ "github.com/containerd/containerd/log"
"github.com/containerd/containerd/pkg/stdio"
"github.com/containerd/fifo"
runc "github.com/containerd/go-runc"
specs "github.com/opencontainers/runtime-spec/specs-go"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/cleanup"
"gvisor.dev/gvisor/pkg/shim/runsc"
)
@@ -92,6 +94,12 @@ func (e *execProcess) SetExited(status int) {
}
func (e *execProcess) setExited(status int) {
+ if !e.exited.IsZero() {
+ log.L.Debugf("Exec: status already set to %d, ignoring status: %d", e.status, status)
+ return
+ }
+
+ log.L.Debugf("Exec: setting status: %d", status)
e.status = status
e.exited = time.Now()
e.parent.Platform.ShutdownConsole(context.Background(), e.console)
@@ -105,7 +113,7 @@ func (e *execProcess) Delete(ctx context.Context) error {
return e.execState.Delete(ctx)
}
-func (e *execProcess) delete(ctx context.Context) error {
+func (e *execProcess) delete() error {
e.wg.Wait()
if e.io != nil {
for _, c := range e.closers {
@@ -113,12 +121,6 @@ func (e *execProcess) delete(ctx context.Context) error {
}
e.io.Close()
}
- pidfile := filepath.Join(e.path, fmt.Sprintf("%s.pid", e.id))
- // silently ignore error
- os.Remove(pidfile)
- internalPidfile := filepath.Join(e.path, fmt.Sprintf("%s-internal.pid", e.id))
- // silently ignore error
- os.Remove(internalPidfile)
return nil
}
@@ -145,16 +147,13 @@ func (e *execProcess) Kill(ctx context.Context, sig uint32, _ bool) error {
func (e *execProcess) kill(ctx context.Context, sig uint32, _ bool) error {
internalPid := e.internalPid
- if internalPid != 0 {
- if err := e.parent.runtime.Kill(ctx, e.parent.id, int(sig), &runsc.KillOpts{
- Pid: internalPid,
- }); err != nil {
- // If this returns error, consider the process has
- // already stopped.
- //
- // TODO: Fix after signal handling is fixed.
- return fmt.Errorf("%s: %w", err.Error(), errdefs.ErrNotFound)
- }
+ if internalPid == 0 {
+ return nil
+ }
+
+ opts := runsc.KillOpts{Pid: internalPid}
+ if err := e.parent.runtime.Kill(ctx, e.parent.id, int(sig), &opts); err != nil {
+ return fmt.Errorf("%s: %w", err.Error(), errdefs.ErrNotFound)
}
return nil
}
@@ -174,42 +173,53 @@ func (e *execProcess) Start(ctx context.Context) error {
return e.execState.Start(ctx)
}
-func (e *execProcess) start(ctx context.Context) (err error) {
- var (
- socket *runc.Socket
- pidfile = filepath.Join(e.path, fmt.Sprintf("%s.pid", e.id))
- internalPidfile = filepath.Join(e.path, fmt.Sprintf("%s-internal.pid", e.id))
- )
- if e.stdio.Terminal {
- if socket, err = runc.NewTempConsoleSocket(); err != nil {
+func (e *execProcess) start(ctx context.Context) error {
+ var socket *runc.Socket
+
+ switch {
+ case e.stdio.Terminal:
+ s, err := runc.NewTempConsoleSocket()
+ if err != nil {
return fmt.Errorf("failed to create runc console socket: %w", err)
}
- defer socket.Close()
- } else if e.stdio.IsNull() {
- if e.io, err = runc.NewNullIO(); err != nil {
+ defer s.Close()
+ socket = s
+
+ case e.stdio.IsNull():
+ io, err := runc.NewNullIO()
+ if err != nil {
return fmt.Errorf("creating new NULL IO: %w", err)
}
- } else {
- if e.io, err = runc.NewPipeIO(e.parent.IoUID, e.parent.IoGID, withConditionalIO(e.stdio)); err != nil {
+ e.io = io
+
+ default:
+ io, err := runc.NewPipeIO(e.parent.IoUID, e.parent.IoGID, withConditionalIO(e.stdio))
+ if err != nil {
return fmt.Errorf("failed to create runc io pipes: %w", err)
}
+ e.io = io
}
+
opts := &runsc.ExecOpts{
- PidFile: pidfile,
- InternalPidFile: internalPidfile,
+ PidFile: filepath.Join(e.path, fmt.Sprintf("%s.pid", e.id)),
+ InternalPidFile: filepath.Join(e.path, fmt.Sprintf("%s-internal.pid", e.id)),
IO: e.io,
Detach: true,
}
+ defer func() {
+ _ = os.Remove(opts.PidFile)
+ _ = os.Remove(opts.InternalPidFile)
+ }()
if socket != nil {
opts.ConsoleSocket = socket
}
+
eventCh := e.parent.Monitor.Subscribe()
- defer func() {
- // Unsubscribe if an error is returned.
- if err != nil {
- e.parent.Monitor.Unsubscribe(eventCh)
- }
- }()
+ cu := cleanup.Make(func() {
+ e.parent.Monitor.Unsubscribe(eventCh)
+ })
+ defer cu.Clean()
+
if err := e.parent.runtime.Exec(ctx, e.parent.id, e.spec, opts); err != nil {
close(e.waitBlock)
return e.parent.runtimeError(err, "OCI runtime exec failed")
@@ -237,6 +247,7 @@ func (e *execProcess) start(ctx context.Context) (err error) {
return fmt.Errorf("failed to start io pipe copy: %w", err)
}
}
+
pid, err := runc.ReadPidFile(opts.PidFile)
if err != nil {
return fmt.Errorf("failed to retrieve OCI runtime exec pid: %w", err)
@@ -247,6 +258,7 @@ func (e *execProcess) start(ctx context.Context) (err error) {
return fmt.Errorf("failed to retrieve OCI runtime exec internal pid: %w", err)
}
e.internalPid = internalPid
+
go func() {
defer e.parent.Monitor.Unsubscribe(eventCh)
for event := range eventCh {
@@ -260,21 +272,25 @@ func (e *execProcess) start(ctx context.Context) (err error) {
}
}
}()
+
+ cu.Release() // cancel cleanup on success.
return nil
}
-func (e *execProcess) Status(ctx context.Context) (string, error) {
+func (e *execProcess) Status(context.Context) (string, error) {
e.mu.Lock()
defer e.mu.Unlock()
// if we don't have a pid then the exec process has just been created
if e.pid == 0 {
return "created", nil
}
- // if we have a pid and it can be signaled, the process is running
- // TODO(random-liu): Use `runsc kill --pid`.
- if err := unix.Kill(e.pid, 0); err == nil {
- return "running", nil
+ // This checks that the `runsc exec` process is still running. This process has
+ // the same lifetime as the process executing inside the container. So instead
+ // of calling `runsc kill --pid`, just do a quick check that `runsc exec` is
+ // still running.
+ if err := unix.Kill(e.pid, 0); err != nil {
+ // Can't signal the process, it must have exited.
+ return "stopped", nil
}
- // else if we have a pid but it can nolonger be signaled, it has stopped
- return "stopped", nil
+ return "running", nil
}
diff --git a/pkg/shim/proc/exec_state.go b/pkg/shim/proc/exec_state.go
index 4dcda8b44..8d8ecf541 100644
--- a/pkg/shim/proc/exec_state.go
+++ b/pkg/shim/proc/exec_state.go
@@ -34,18 +34,21 @@ type execCreatedState struct {
p *execProcess
}
-func (s *execCreatedState) transition(name string) error {
- switch name {
- case "running":
+func (s *execCreatedState) name() string {
+ return "created"
+}
+
+func (s *execCreatedState) transition(transition stateTransition) {
+ switch transition {
+ case running:
s.p.execState = &execRunningState{p: s.p}
- case "stopped":
+ case stopped:
s.p.execState = &execStoppedState{p: s.p}
- case "deleted":
+ case deleted:
s.p.execState = &deletedState{}
default:
- return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ panic(fmt.Sprintf("invalid state transition %q to %q", s.name(), transition))
}
- return nil
}
func (s *execCreatedState) Resize(ws console.WinSize) error {
@@ -56,14 +59,16 @@ func (s *execCreatedState) Start(ctx context.Context) error {
if err := s.p.start(ctx); err != nil {
return err
}
- return s.transition("running")
+ s.transition(running)
+ return nil
}
-func (s *execCreatedState) Delete(ctx context.Context) error {
- if err := s.p.delete(ctx); err != nil {
+func (s *execCreatedState) Delete(context.Context) error {
+ if err := s.p.delete(); err != nil {
return err
}
- return s.transition("deleted")
+ s.transition(deleted)
+ return nil
}
func (s *execCreatedState) Kill(ctx context.Context, sig uint32, all bool) error {
@@ -72,35 +77,35 @@ func (s *execCreatedState) Kill(ctx context.Context, sig uint32, all bool) error
func (s *execCreatedState) SetExited(status int) {
s.p.setExited(status)
-
- if err := s.transition("stopped"); err != nil {
- panic(err)
- }
+ s.transition(stopped)
}
type execRunningState struct {
p *execProcess
}
-func (s *execRunningState) transition(name string) error {
- switch name {
- case "stopped":
+func (s *execRunningState) name() string {
+ return "running"
+}
+
+func (s *execRunningState) transition(transition stateTransition) {
+ switch transition {
+ case stopped:
s.p.execState = &execStoppedState{p: s.p}
default:
- return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ panic(fmt.Sprintf("invalid state transition %q to %q", s.name(), transition))
}
- return nil
}
func (s *execRunningState) Resize(ws console.WinSize) error {
return s.p.resize(ws)
}
-func (s *execRunningState) Start(ctx context.Context) error {
+func (s *execRunningState) Start(context.Context) error {
return fmt.Errorf("cannot start a running process")
}
-func (s *execRunningState) Delete(ctx context.Context) error {
+func (s *execRunningState) Delete(context.Context) error {
return fmt.Errorf("cannot delete a running process")
}
@@ -110,45 +115,46 @@ func (s *execRunningState) Kill(ctx context.Context, sig uint32, all bool) error
func (s *execRunningState) SetExited(status int) {
s.p.setExited(status)
-
- if err := s.transition("stopped"); err != nil {
- panic(err)
- }
+ s.transition(stopped)
}
type execStoppedState struct {
p *execProcess
}
-func (s *execStoppedState) transition(name string) error {
- switch name {
- case "deleted":
+func (s *execStoppedState) name() string {
+ return "stopped"
+}
+
+func (s *execStoppedState) transition(transition stateTransition) {
+ switch transition {
+ case deleted:
s.p.execState = &deletedState{}
default:
- return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ panic(fmt.Sprintf("invalid state transition %q to %q", s.name(), transition))
}
- return nil
}
-func (s *execStoppedState) Resize(ws console.WinSize) error {
+func (s *execStoppedState) Resize(console.WinSize) error {
return fmt.Errorf("cannot resize a stopped container")
}
-func (s *execStoppedState) Start(ctx context.Context) error {
+func (s *execStoppedState) Start(context.Context) error {
return fmt.Errorf("cannot start a stopped process")
}
-func (s *execStoppedState) Delete(ctx context.Context) error {
- if err := s.p.delete(ctx); err != nil {
+func (s *execStoppedState) Delete(context.Context) error {
+ if err := s.p.delete(); err != nil {
return err
}
- return s.transition("deleted")
+ s.transition(deleted)
+ return nil
}
-func (s *execStoppedState) Kill(ctx context.Context, sig uint32, all bool) error {
- return s.p.kill(ctx, sig, all)
+func (s *execStoppedState) Kill(_ context.Context, sig uint32, _ bool) error {
+ return handleStoppedKill(sig)
}
-func (s *execStoppedState) SetExited(status int) {
+func (s *execStoppedState) SetExited(int) {
// no op
}
diff --git a/pkg/shim/proc/init.go b/pkg/shim/proc/init.go
index 664465e0d..6bf090813 100644
--- a/pkg/shim/proc/init.go
+++ b/pkg/shim/proc/init.go
@@ -39,6 +39,8 @@ import (
"gvisor.dev/gvisor/pkg/shim/runsc"
)
+const statusStopped = "stopped"
+
// Init represents an initial process for a container.
type Init struct {
wg sync.WaitGroup
@@ -201,10 +203,15 @@ func (p *Init) ExitedAt() time.Time {
func (p *Init) Status(ctx context.Context) (string, error) {
p.mu.Lock()
defer p.mu.Unlock()
+
+ return p.initState.State(ctx)
+}
+
+func (p *Init) state(ctx context.Context) (string, error) {
c, err := p.runtime.State(ctx, p.id)
if err != nil {
if strings.Contains(err.Error(), "does not exist") {
- return "stopped", nil
+ return statusStopped, nil
}
return "", p.runtimeError(err, "OCI runtime state failed")
}
@@ -231,10 +238,7 @@ func (p *Init) start(ctx context.Context) error {
status, err := p.runtime.Wait(context.Background(), p.id)
if err != nil {
log.G(ctx).WithError(err).Errorf("Failed to wait for container %q", p.id)
- // TODO(random-liu): Handle runsc kill error.
- if err := p.killAll(ctx); err != nil {
- log.G(ctx).WithError(err).Errorf("Failed to kill container %q", p.id)
- }
+ p.killAllLocked(ctx)
status = internalErrorCode
}
ExitCh <- Exit{
@@ -255,6 +259,12 @@ func (p *Init) SetExited(status int) {
}
func (p *Init) setExited(status int) {
+ if !p.exited.IsZero() {
+ log.L.Debugf("Status already set to %d, ignoring status: %d", p.status, status)
+ return
+ }
+
+ log.L.Debugf("Setting status: %d", status)
p.exited = time.Now()
p.status = status
p.Platform.ShutdownConsole(context.Background(), p.console)
@@ -270,15 +280,16 @@ func (p *Init) Delete(ctx context.Context) error {
}
func (p *Init) delete(ctx context.Context) error {
- p.killAll(ctx)
+ p.killAllLocked(ctx)
p.wg.Wait()
+
err := p.runtime.Delete(ctx, p.id, nil)
- // ignore errors if a runtime has already deleted the process
- // but we still hold metadata and pipes
- //
- // this is common during a checkpoint, runc will delete the container state
- // after a checkpoint and the container will no longer exist within runc
if err != nil {
+ // ignore errors if a runtime has already deleted the process
+ // but we still hold metadata and pipes
+ //
+ // this is common during a checkpoint, runc will delete the container state
+ // after a checkpoint and the container will no longer exist within runc
if strings.Contains(err.Error(), "does not exist") {
err = nil
} else {
@@ -326,29 +337,24 @@ func (p *Init) Kill(ctx context.Context, signal uint32, all bool) error {
return p.initState.Kill(ctx, signal, all)
}
-func (p *Init) kill(context context.Context, signal uint32, all bool) error {
+func (p *Init) kill(ctx context.Context, signal uint32, all bool) error {
var (
killErr error
backoff = 100 * time.Millisecond
)
- timeout := 1 * time.Second
- for start := time.Now(); time.Now().Sub(start) < timeout; {
- c, err := p.runtime.State(context, p.id)
+ const timeout = time.Second
+ for start := time.Now(); time.Since(start) < timeout; {
+ state, err := p.initState.State(ctx)
if err != nil {
- if strings.Contains(err.Error(), "does not exist") {
- return fmt.Errorf("no such process: %w", errdefs.ErrNotFound)
- }
return p.runtimeError(err, "OCI runtime state failed")
}
// For runsc, signal only works when container is running state.
// If the container is not in running state, directly return
// "no such process"
- if p.convertStatus(c.Status) == "stopped" {
+ if state == statusStopped {
return fmt.Errorf("no such process: %w", errdefs.ErrNotFound)
}
- killErr = p.runtime.Kill(context, p.id, int(signal), &runsc.KillOpts{
- All: all,
- })
+ killErr = p.runtime.Kill(ctx, p.id, int(signal), &runsc.KillOpts{All: all})
if killErr == nil {
return nil
}
@@ -358,22 +364,18 @@ func (p *Init) kill(context context.Context, signal uint32, all bool) error {
return p.runtimeError(killErr, "kill timeout")
}
-// KillAll kills all processes belonging to the init process.
-func (p *Init) KillAll(context context.Context) error {
+// KillAll kills all processes belonging to the init process. If
+// `runsc kill --all` returns an error, assume the container has already stopped.
+func (p *Init) KillAll(context context.Context) {
p.mu.Lock()
defer p.mu.Unlock()
- return p.killAll(context)
+ p.killAllLocked(context)
}
-func (p *Init) killAll(context context.Context) error {
- p.runtime.Kill(context, p.id, int(unix.SIGKILL), &runsc.KillOpts{
- All: true,
- })
- // Ignore error handling for `runsc kill --all` for now.
- // * If it doesn't return error, it is good;
- // * If it returns error, consider the container has already stopped.
- // TODO: Fix `runsc kill --all` error handling.
- return nil
+func (p *Init) killAllLocked(context context.Context) {
+ if err := p.runtime.Kill(context, p.id, int(unix.SIGKILL), &runsc.KillOpts{All: true}); err != nil {
+ log.L.Warningf("Ignoring error killing container %q: %v", p.id, err)
+ }
}
// Stdin returns the stdin of the process.
@@ -396,7 +398,6 @@ func (p *Init) Exec(ctx context.Context, path string, r *ExecConfig) (process.Pr
// exec returns a new exec'd process.
func (p *Init) exec(path string, r *ExecConfig) (process.Process, error) {
- // process exec request
var spec specs.Process
if err := json.Unmarshal(r.Spec.Value, &spec); err != nil {
return nil, err
@@ -420,6 +421,17 @@ func (p *Init) exec(path string, r *ExecConfig) (process.Process, error) {
return e, nil
}
+func (p *Init) Stats(ctx context.Context, id string) (*runc.Stats, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ return p.initState.Stats(ctx, id)
+}
+
+func (p *Init) stats(ctx context.Context, id string) (*runc.Stats, error) {
+ return p.Runtime().Stats(ctx, id)
+}
+
// Stdio returns the stdio of the process.
func (p *Init) Stdio() stdio.Stdio {
return p.stdio
@@ -444,7 +456,7 @@ func (p *Init) runtimeError(rErr error, msg string) error {
func (p *Init) convertStatus(status string) string {
if status == "created" && !p.Sandbox && p.status == internalErrorCode {
// Treat start failure state for non-root container as stopped.
- return "stopped"
+ return statusStopped
}
return status
}
diff --git a/pkg/shim/proc/init_state.go b/pkg/shim/proc/init_state.go
index 0065fc385..5347ddefe 100644
--- a/pkg/shim/proc/init_state.go
+++ b/pkg/shim/proc/init_state.go
@@ -19,16 +19,40 @@ import (
"context"
"fmt"
- "github.com/containerd/console"
"github.com/containerd/containerd/errdefs"
"github.com/containerd/containerd/pkg/process"
+ runc "github.com/containerd/go-runc"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/shim/utils"
)
+type stateTransition int
+
+const (
+ running stateTransition = iota
+ stopped
+ deleted
+)
+
+func (s stateTransition) String() string {
+ switch s {
+ case running:
+ return "running"
+ case stopped:
+ return "stopped"
+ case deleted:
+ return "deleted"
+ default:
+ panic(fmt.Sprintf("unknown state: %d", s))
+ }
+}
+
type initState interface {
- Resize(console.WinSize) error
Start(context.Context) error
Delete(context.Context) error
Exec(context.Context, string, *ExecConfig) (process.Process, error)
+ State(ctx context.Context) (string, error)
+ Stats(context.Context, string) (*runc.Stats, error)
Kill(context.Context, uint32, bool) error
SetExited(int)
}
@@ -37,22 +61,21 @@ type createdState struct {
p *Init
}
-func (s *createdState) transition(name string) error {
- switch name {
- case "running":
+func (s *createdState) name() string {
+ return "created"
+}
+
+func (s *createdState) transition(transition stateTransition) {
+ switch transition {
+ case running:
s.p.initState = &runningState{p: s.p}
- case "stopped":
- s.p.initState = &stoppedState{p: s.p}
- case "deleted":
+ case stopped:
+ s.p.initState = &stoppedState{process: s.p}
+ case deleted:
s.p.initState = &deletedState{}
default:
- return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ panic(fmt.Sprintf("invalid state transition %q to %q", s.name(), transition))
}
- return nil
-}
-
-func (s *createdState) Resize(ws console.WinSize) error {
- return s.p.resize(ws)
}
func (s *createdState) Start(ctx context.Context) error {
@@ -66,20 +89,20 @@ func (s *createdState) Start(ctx context.Context) error {
if !s.p.Sandbox {
s.p.io.Close()
s.p.setExited(internalErrorCode)
- if err := s.transition("stopped"); err != nil {
- panic(err)
- }
+ s.transition(stopped)
}
return err
}
- return s.transition("running")
+ s.transition(running)
+ return nil
}
func (s *createdState) Delete(ctx context.Context) error {
if err := s.p.delete(ctx); err != nil {
return err
}
- return s.transition("deleted")
+ s.transition(deleted)
+ return nil
}
func (s *createdState) Kill(ctx context.Context, sig uint32, all bool) error {
@@ -88,40 +111,48 @@ func (s *createdState) Kill(ctx context.Context, sig uint32, all bool) error {
func (s *createdState) SetExited(status int) {
s.p.setExited(status)
-
- if err := s.transition("stopped"); err != nil {
- panic(err)
- }
+ s.transition(stopped)
}
func (s *createdState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
return s.p.exec(path, r)
}
+func (s *createdState) State(ctx context.Context) (string, error) {
+ state, err := s.p.state(ctx)
+ if err == nil && state == statusStopped {
+ s.transition(stopped)
+ }
+ return state, err
+}
+
+func (s *createdState) Stats(ctx context.Context, id string) (*runc.Stats, error) {
+ return s.p.stats(ctx, id)
+}
+
type runningState struct {
p *Init
}
-func (s *runningState) transition(name string) error {
- switch name {
- case "stopped":
- s.p.initState = &stoppedState{p: s.p}
- default:
- return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
- }
- return nil
+func (s *runningState) name() string {
+ return "running"
}
-func (s *runningState) Resize(ws console.WinSize) error {
- return s.p.resize(ws)
+func (s *runningState) transition(transition stateTransition) {
+ switch transition {
+ case stopped:
+ s.p.initState = &stoppedState{process: s.p}
+ default:
+ panic(fmt.Sprintf("invalid state transition %q to %q", s.name(), transition))
+ }
}
func (s *runningState) Start(ctx context.Context) error {
- return fmt.Errorf("cannot start a running process.ss")
+ return fmt.Errorf("cannot start a running container")
}
func (s *runningState) Delete(ctx context.Context) error {
- return fmt.Errorf("cannot delete a running process.ss")
+ return fmt.Errorf("cannot delete a running container")
}
func (s *runningState) Kill(ctx context.Context, sig uint32, all bool) error {
@@ -130,53 +161,81 @@ func (s *runningState) Kill(ctx context.Context, sig uint32, all bool) error {
func (s *runningState) SetExited(status int) {
s.p.setExited(status)
+ s.transition(stopped)
+}
+
+func (s *runningState) Exec(_ context.Context, path string, r *ExecConfig) (process.Process, error) {
+ return s.p.exec(path, r)
+}
- if err := s.transition("stopped"); err != nil {
- panic(err)
+func (s *runningState) State(ctx context.Context) (string, error) {
+ state, err := s.p.state(ctx)
+ if err == nil && state == "stopped" {
+ s.transition(stopped)
}
+ return state, err
}
-func (s *runningState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
- return s.p.exec(path, r)
+func (s *runningState) Stats(ctx context.Context, id string) (*runc.Stats, error) {
+ return s.p.stats(ctx, id)
}
type stoppedState struct {
- p *Init
+ process *Init
}
-func (s *stoppedState) transition(name string) error {
- switch name {
- case "deleted":
- s.p.initState = &deletedState{}
- default:
- return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
- }
- return nil
+func (s *stoppedState) name() string {
+ return "stopped"
}
-func (s *stoppedState) Resize(ws console.WinSize) error {
- return fmt.Errorf("cannot resize a stopped container")
+func (s *stoppedState) transition(transition stateTransition) {
+ switch transition {
+ case deleted:
+ s.process.initState = &deletedState{}
+ default:
+ panic(fmt.Sprintf("invalid state transition %q to %q", s.name(), transition))
+ }
}
-func (s *stoppedState) Start(ctx context.Context) error {
- return fmt.Errorf("cannot start a stopped process.ss")
+func (s *stoppedState) Start(context.Context) error {
+ return fmt.Errorf("cannot start a stopped container")
}
func (s *stoppedState) Delete(ctx context.Context) error {
- if err := s.p.delete(ctx); err != nil {
+ if err := s.process.delete(ctx); err != nil {
return err
}
- return s.transition("deleted")
+ s.transition(deleted)
+ return nil
}
-func (s *stoppedState) Kill(ctx context.Context, sig uint32, all bool) error {
- return errdefs.ToGRPCf(errdefs.ErrNotFound, "process.ss %s not found", s.p.id)
+func (s *stoppedState) Kill(_ context.Context, signal uint32, _ bool) error {
+ return handleStoppedKill(signal)
}
func (s *stoppedState) SetExited(status int) {
- // no op
+ s.process.setExited(status)
}
-func (s *stoppedState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
+func (s *stoppedState) Exec(context.Context, string, *ExecConfig) (process.Process, error) {
return nil, fmt.Errorf("cannot exec in a stopped state")
}
+
+func (s *stoppedState) State(context.Context) (string, error) {
+ return "stopped", nil
+}
+
+func (s *stoppedState) Stats(context.Context, string) (*runc.Stats, error) {
+ return nil, fmt.Errorf("cannot stat a stopped container")
+}
+
+func handleStoppedKill(signal uint32) error {
+ switch unix.Signal(signal) {
+ case unix.SIGTERM, unix.SIGKILL:
+ // Container is already stopped, so everything inside the container has
+ // already been killed.
+ return nil
+ default:
+ return utils.ErrToGRPCf(errdefs.ErrNotFound, "process not found")
+ }
+}
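
The refactor above replaces stringly-typed state transitions with a small enum: misspelled transitions no longer compile, and an impossible transition is treated as a programmer error that panics rather than an error every caller must propagate. A toy version of the pattern (the names mirror the diff, but this is not the shim package):

    package main

    import "fmt"

    type stateTransition int

    const (
        running stateTransition = iota
        stopped
        deleted
    )

    func (t stateTransition) String() string {
        switch t {
        case running:
            return "running"
        case stopped:
            return "stopped"
        case deleted:
            return "deleted"
        default:
            panic(fmt.Sprintf("unknown transition: %d", int(t)))
        }
    }

    // createdState stands in for one state of the machine; each state accepts
    // only the transitions that make sense for it and panics on the rest.
    type createdState struct{}

    func (createdState) transition(t stateTransition) {
        switch t {
        case running, stopped, deleted:
            fmt.Printf("created -> %s\n", t)
        default:
            panic(fmt.Sprintf("invalid state transition %q to %q", "created", t))
        }
    }

    func main() {
        createdState{}.transition(running) // created -> running
    }
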
diff --git a/pkg/shim/proc/proc.go b/pkg/shim/proc/proc.go
index edba3fca5..89ad3f505 100644
--- a/pkg/shim/proc/proc.go
+++ b/pkg/shim/proc/proc.go
@@ -17,23 +17,5 @@
// the sandbox process running the container.
package proc
-import (
- "fmt"
-)
-
// RunscRoot is the path to the root runsc state directory.
const RunscRoot = "/run/containerd/runsc"
-
-func stateName(v interface{}) string {
- switch v.(type) {
- case *runningState, *execRunningState:
- return "running"
- case *createdState, *execCreatedState:
- return "created"
- case *deletedState:
- return "deleted"
- case *stoppedState:
- return "stopped"
- }
- panic(fmt.Errorf("invalid state %v", v))
-}
diff --git a/pkg/shim/runsc/runsc.go b/pkg/shim/runsc/runsc.go
index ff0521d73..888cb0bcb 100644
--- a/pkg/shim/runsc/runsc.go
+++ b/pkg/shim/runsc/runsc.go
@@ -17,6 +17,7 @@
package runsc
import (
+ "bytes"
"context"
"encoding/json"
"fmt"
@@ -73,9 +74,9 @@ type Runsc struct {
// List returns all containers created inside the provided runsc root directory.
func (r *Runsc) List(context context.Context) ([]*runc.Container, error) {
- data, err := cmdOutput(r.command(context, "list", "--format=json"), false)
+ data, stderr, err := cmdOutput(r.command(context, "list", "--format=json"), false)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("%w: %s", err, stderr)
}
var out []*runc.Container
if err := json.Unmarshal(data, &out); err != nil {
@@ -86,9 +87,9 @@ func (r *Runsc) List(context context.Context) ([]*runc.Container, error) {
// State returns the state for the container provided by id.
func (r *Runsc) State(context context.Context, id string) (*runc.Container, error) {
- data, err := cmdOutput(r.command(context, "state", id), true)
+ data, stderr, err := cmdOutput(r.command(context, "state", id), false)
if err != nil {
- return nil, fmt.Errorf("%s: %s", err, data)
+ return nil, fmt.Errorf("%w: %s", err, stderr)
}
var c runc.Container
if err := json.Unmarshal(data, &c); err != nil {
@@ -142,9 +143,9 @@ func (r *Runsc) Create(context context.Context, id, bundle string, opts *CreateO
}
if cmd.Stdout == nil && cmd.Stderr == nil {
- data, err := cmdOutput(cmd, true)
+ out, _, err := cmdOutput(cmd, true)
if err != nil {
- return fmt.Errorf("%s: %s", err, data)
+ return fmt.Errorf("%w: %s", err, out)
}
return nil
}
@@ -168,15 +169,15 @@ func (r *Runsc) Create(context context.Context, id, bundle string, opts *CreateO
}
func (r *Runsc) Pause(context context.Context, id string) error {
- if _, err := cmdOutput(r.command(context, "pause", id), true); err != nil {
- return fmt.Errorf("unable to pause: %w", err)
+ if out, _, err := cmdOutput(r.command(context, "pause", id), true); err != nil {
+ return fmt.Errorf("unable to pause: %w: %s", err, out)
}
return nil
}
func (r *Runsc) Resume(context context.Context, id string) error {
- if _, err := cmdOutput(r.command(context, "resume", id), true); err != nil {
- return fmt.Errorf("unable to resume: %w", err)
+ if out, _, err := cmdOutput(r.command(context, "resume", id), true); err != nil {
+ return fmt.Errorf("unable to resume: %w: %s", err, out)
}
return nil
}
@@ -189,9 +190,9 @@ func (r *Runsc) Start(context context.Context, id string, cio runc.IO) error {
}
if cmd.Stdout == nil && cmd.Stderr == nil {
- data, err := cmdOutput(cmd, true)
+ out, _, err := cmdOutput(cmd, true)
if err != nil {
- return fmt.Errorf("%s: %s", err, data)
+ return fmt.Errorf("%w: %s", err, out)
}
return nil
}
@@ -221,12 +222,10 @@ type waitResult struct {
}
// Wait will wait for a running container, and return its exit status.
-//
-// TODO(random-liu): Add exec process support.
func (r *Runsc) Wait(context context.Context, id string) (int, error) {
- data, err := cmdOutput(r.command(context, "wait", id), true)
+ data, stderr, err := cmdOutput(r.command(context, "wait", id), false)
if err != nil {
- return 0, fmt.Errorf("%s: %s", err, data)
+ return 0, fmt.Errorf("%w: %s", err, stderr)
}
var res waitResult
if err := json.Unmarshal(data, &res); err != nil {
@@ -294,9 +293,9 @@ func (r *Runsc) Exec(context context.Context, id string, spec specs.Process, opt
opts.Set(cmd)
}
if cmd.Stdout == nil && cmd.Stderr == nil {
- data, err := cmdOutput(cmd, true)
+ out, _, err := cmdOutput(cmd, true)
if err != nil {
- return fmt.Errorf("%s: %s", err, data)
+ return fmt.Errorf("%w: %s", err, out)
}
return nil
}
@@ -391,20 +390,12 @@ func (r *Runsc) Kill(context context.Context, id string, sig int, opts *KillOpts
// Stats return the stats for a container like cpu, memory, and I/O.
func (r *Runsc) Stats(context context.Context, id string) (*runc.Stats, error) {
cmd := r.command(context, "events", "--stats", id)
- rd, err := cmd.StdoutPipe()
- if err != nil {
- return nil, err
- }
- ec, err := Monitor.Start(cmd)
+ data, stderr, err := cmdOutput(cmd, false)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("%w: %s", err, stderr)
}
- defer func() {
- rd.Close()
- Monitor.Wait(cmd, ec)
- }()
var e runc.Event
- if err := json.NewDecoder(rd).Decode(&e); err != nil {
+ if err := json.Unmarshal(data, &e); err != nil {
log.L.Debugf("Parsing events error: %v", err)
return nil, err
}
@@ -459,9 +450,9 @@ func (r *Runsc) Events(context context.Context, id string, interval time.Duratio
// Ps lists all the processes inside the container returning their pids.
func (r *Runsc) Ps(context context.Context, id string) ([]int, error) {
- data, err := cmdOutput(r.command(context, "ps", "--format", "json", id), true)
+ data, stderr, err := cmdOutput(r.command(context, "ps", "--format", "json", id), false)
if err != nil {
- return nil, fmt.Errorf("%s: %s", err, data)
+ return nil, fmt.Errorf("%w: %s", err, stderr)
}
var pids []int
if err := json.Unmarshal(data, &pids); err != nil {
@@ -472,9 +463,9 @@ func (r *Runsc) Ps(context context.Context, id string) ([]int, error) {
// Top lists all the processes inside the container returning the full ps data.
func (r *Runsc) Top(context context.Context, id string) (*runc.TopResults, error) {
- data, err := cmdOutput(r.command(context, "ps", "--format", "table", id), true)
+ data, stderr, err := cmdOutput(r.command(context, "ps", "--format", "table", id), false)
if err != nil {
- return nil, fmt.Errorf("%s: %s", err, data)
+ return nil, fmt.Errorf("%w: %s", err, stderr)
}
topResults, err := runc.ParsePSOutput(data)
@@ -517,9 +508,9 @@ func (r *Runsc) runOrError(cmd *exec.Cmd) error {
}
return err
}
- data, err := cmdOutput(cmd, true)
+ out, _, err := cmdOutput(cmd, true)
if err != nil {
- return fmt.Errorf("%s: %s", err, data)
+ return fmt.Errorf("%w: %s", err, out)
}
return nil
}
@@ -540,23 +531,29 @@ func (r *Runsc) command(context context.Context, args ...string) *exec.Cmd {
return cmd
}
-func cmdOutput(cmd *exec.Cmd, combined bool) ([]byte, error) {
- b := getBuf()
- defer putBuf(b)
+func cmdOutput(cmd *exec.Cmd, combined bool) ([]byte, []byte, error) {
+ stdout := getBuf()
+ defer putBuf(stdout)
+ cmd.Stdout = stdout
+ cmd.Stderr = stdout
- cmd.Stdout = b
- if combined {
- cmd.Stderr = b
+ var stderr *bytes.Buffer
+ if !combined {
+ stderr = getBuf()
+ defer putBuf(stderr)
+ cmd.Stderr = stderr
}
ec, err := Monitor.Start(cmd)
if err != nil {
- return nil, err
+ return nil, nil, err
}
status, err := Monitor.Wait(cmd, ec)
if err == nil && status != 0 {
- err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0])
+ err = fmt.Errorf("%q did not terminate sucessfully", cmd.Args[0])
}
-
- return b.Bytes(), err
+ if stderr == nil {
+ return stdout.Bytes(), nil, err
+ }
+ return stdout.Bytes(), stderr.Bytes(), err
}
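
The cmdOutput rewrite changes the helper's contract: unless combined output is requested, stdout and stderr now come back separately, so callers can attach only stderr to the returned error. A rough equivalent built directly on os/exec (the real helper reuses pooled buffers and containerd's process monitor, which this sketch omits):

    package main

    import (
        "bytes"
        "fmt"
        "os/exec"
    )

    // cmdOutput captures stdout and, when combined is false, stderr separately.
    // With combined=true, stderr is folded into the stdout buffer.
    func cmdOutput(cmd *exec.Cmd, combined bool) (stdout, stderr []byte, err error) {
        var outBuf, errBuf bytes.Buffer
        cmd.Stdout = &outBuf
        cmd.Stderr = &outBuf
        if !combined {
            cmd.Stderr = &errBuf
        }
        err = cmd.Run()
        return outBuf.Bytes(), errBuf.Bytes(), err
    }

    func main() {
        _, stderr, err := cmdOutput(exec.Command("ls", "/no/such/path"), false)
        if err != nil {
            fmt.Printf("%v: %s", err, stderr) // exit status plus the captured stderr
        }
    }
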
diff --git a/pkg/shim/service.go b/pkg/shim/service.go
index 1f9adcb65..24e3b7a82 100644
--- a/pkg/shim/service.go
+++ b/pkg/shim/service.go
@@ -81,8 +81,6 @@ const (
// New returns a new shim service that can be used via GRPC.
func New(ctx context.Context, id string, publisher shim.Publisher, cancel func()) (shim.Shim, error) {
- log.L.Debugf("service.New, id: %s", id)
-
var opts shim.Opts
if ctxOpts := ctx.Value(shim.OptsKey{}); ctxOpts != nil {
opts = ctxOpts.(shim.Opts)
@@ -304,8 +302,6 @@ func (s *service) Cleanup(ctx context.Context) (*taskAPI.DeleteResponse, error)
// Create creates a new initial process and container with the underlying OCI
// runtime.
func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (*taskAPI.CreateTaskResponse, error) {
- log.L.Debugf("Create, id: %s, bundle: %q", r.ID, r.Bundle)
-
s.mu.Lock()
defer s.mu.Unlock()
@@ -396,6 +392,9 @@ func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (*ta
log.L.Debugf("stdout: %s", r.Stdout)
log.L.Debugf("stderr: %s", r.Stderr)
log.L.Debugf("***************************")
+ if log.L.Logger.IsLevelEnabled(logrus.DebugLevel) {
+ setDebugSigHandler()
+ }
}
// Save state before any action is taken to ensure Cleanup() will have all
@@ -453,10 +452,10 @@ func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (*ta
}
process, err := newInit(r.Bundle, filepath.Join(r.Bundle, "work"), ns, s.platform, config, &s.opts, st.Rootfs)
if err != nil {
- return nil, errdefs.ToGRPC(err)
+ return nil, utils.ErrToGRPC(err)
}
if err := process.Create(ctx, config); err != nil {
- return nil, errdefs.ToGRPC(err)
+ return nil, utils.ErrToGRPC(err)
}
// Set up OOM notification on the sandbox's cgroup. This is done on
@@ -506,9 +505,6 @@ func (s *service) Delete(ctx context.Context, r *taskAPI.DeleteRequest) (*taskAP
if err != nil {
return nil, err
}
- if p == nil {
- return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
- }
if err := p.Delete(ctx); err != nil {
return nil, err
}
@@ -534,10 +530,10 @@ func (s *service) Exec(ctx context.Context, r *taskAPI.ExecProcessRequest) (*typ
p := s.processes[r.ExecID]
s.mu.Unlock()
if p != nil {
- return nil, errdefs.ToGRPCf(errdefs.ErrAlreadyExists, "id %s", r.ExecID)
+ return nil, utils.ErrToGRPCf(errdefs.ErrAlreadyExists, "id %s", r.ExecID)
}
if s.task == nil {
- return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ return nil, utils.ErrToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
}
process, err := s.task.Exec(ctx, s.bundle, &proc.ExecConfig{
ID: r.ExecID,
@@ -548,7 +544,7 @@ func (s *service) Exec(ctx context.Context, r *taskAPI.ExecProcessRequest) (*typ
Spec: r.Spec,
})
if err != nil {
- return nil, errdefs.ToGRPC(err)
+ return nil, utils.ErrToGRPC(err)
}
s.mu.Lock()
s.processes[r.ExecID] = process
@@ -569,7 +565,7 @@ func (s *service) ResizePty(ctx context.Context, r *taskAPI.ResizePtyRequest) (*
Height: uint16(r.Height),
}
if err := p.Resize(ws); err != nil {
- return nil, errdefs.ToGRPC(err)
+ return nil, utils.ErrToGRPC(err)
}
return empty, nil
}
@@ -580,10 +576,12 @@ func (s *service) State(ctx context.Context, r *taskAPI.StateRequest) (*taskAPI.
p, err := s.getProcess(r.ExecID)
if err != nil {
+ log.L.Debugf("State failed to find process: %v", err)
return nil, err
}
st, err := p.Status(ctx)
if err != nil {
+ log.L.Debugf("State failed: %v", err)
return nil, err
}
status := task.StatusUnknown
@@ -596,7 +594,7 @@ func (s *service) State(ctx context.Context, r *taskAPI.StateRequest) (*taskAPI.
status = task.StatusStopped
}
sio := p.Stdio()
- return &taskAPI.StateResponse{
+ res := &taskAPI.StateResponse{
ID: p.ID(),
Bundle: s.bundle,
Pid: uint32(p.Pid()),
@@ -607,7 +605,9 @@ func (s *service) State(ctx context.Context, r *taskAPI.StateRequest) (*taskAPI.
Terminal: sio.Terminal,
ExitStatus: uint32(p.ExitStatus()),
ExitedAt: p.ExitedAt(),
- }, nil
+ }
+ log.L.Debugf("State succeeded, response: %+v", res)
+ return res, nil
}
// Pause the container.
@@ -615,7 +615,7 @@ func (s *service) Pause(ctx context.Context, r *taskAPI.PauseRequest) (*types.Em
log.L.Debugf("Pause, id: %s", r.ID)
if s.task == nil {
log.L.Debugf("Pause error, id: %s: container not created", r.ID)
- return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ return nil, utils.ErrToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
}
err := s.task.Runtime().Pause(ctx, r.ID)
if err != nil {
@@ -629,7 +629,7 @@ func (s *service) Resume(ctx context.Context, r *taskAPI.ResumeRequest) (*types.
log.L.Debugf("Resume, id: %s", r.ID)
if s.task == nil {
log.L.Debugf("Resume error, id: %s: container not created", r.ID)
- return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ return nil, utils.ErrToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
}
err := s.task.Runtime().Resume(ctx, r.ID)
if err != nil {
@@ -646,12 +646,11 @@ func (s *service) Kill(ctx context.Context, r *taskAPI.KillRequest) (*types.Empt
if err != nil {
return nil, err
}
- if p == nil {
- return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
- }
if err := p.Kill(ctx, r.Signal, r.All); err != nil {
- return nil, errdefs.ToGRPC(err)
+ log.L.Debugf("Kill failed: %v", err)
+ return nil, utils.ErrToGRPC(err)
}
+ log.L.Debugf("Kill succeeded")
return empty, nil
}
@@ -661,7 +660,7 @@ func (s *service) Pids(ctx context.Context, r *taskAPI.PidsRequest) (*taskAPI.Pi
pids, err := s.getContainerPids(ctx, r.ID)
if err != nil {
- return nil, errdefs.ToGRPC(err)
+ return nil, utils.ErrToGRPC(err)
}
var processes []*task.ProcessInfo
for _, pid := range pids {
@@ -707,7 +706,7 @@ func (s *service) CloseIO(ctx context.Context, r *taskAPI.CloseIORequest) (*type
// Checkpoint checkpoints the container.
func (s *service) Checkpoint(ctx context.Context, r *taskAPI.CheckpointTaskRequest) (*types.Empty, error) {
log.L.Debugf("Checkpoint, id: %s", r.ID)
- return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+ return empty, utils.ErrToGRPC(errdefs.ErrNotImplemented)
}
// Connect returns shim information such as the shim's pid.
@@ -738,9 +737,9 @@ func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI.
log.L.Debugf("Stats, id: %s", r.ID)
if s.task == nil {
log.L.Debugf("Stats error, id: %s: container not created", r.ID)
- return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ return nil, utils.ErrToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
}
- stats, err := s.task.Runtime().Stats(ctx, s.id)
+ stats, err := s.task.Stats(ctx, s.id)
if err != nil {
log.L.Debugf("Stats error, id: %s: %v", r.ID, err)
return nil, err
@@ -812,7 +811,7 @@ func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI.
// Update updates a running container.
func (s *service) Update(ctx context.Context, r *taskAPI.UpdateTaskRequest) (*types.Empty, error) {
- return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+ return empty, utils.ErrToGRPC(errdefs.ErrNotImplemented)
}
// Wait waits for a process to exit.
@@ -821,17 +820,17 @@ func (s *service) Wait(ctx context.Context, r *taskAPI.WaitRequest) (*taskAPI.Wa
p, err := s.getProcess(r.ExecID)
if err != nil {
+ log.L.Debugf("Wait failed to find process: %v", err)
return nil, err
}
- if p == nil {
- return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
- }
p.Wait()
- return &taskAPI.WaitResponse{
+ res := &taskAPI.WaitResponse{
ExitStatus: uint32(p.ExitStatus()),
ExitedAt: p.ExitedAt(),
- }, nil
+ }
+ log.L.Debugf("Wait succeeded, response: %+v", res)
+ return res, nil
}
func (s *service) processExits(ctx context.Context) {
@@ -848,10 +847,7 @@ func (s *service) checkProcesses(ctx context.Context, e proc.Exit) {
if ip, ok := p.(*proc.Init); ok {
// Ensure all children are killed.
log.L.Debugf("Container init process exited, killing all container processes")
- if err := ip.KillAll(ctx); err != nil {
- log.G(ctx).WithError(err).WithField("id", ip.ID()).
- Error("failed to kill init's children")
- }
+ ip.KillAll(ctx)
}
p.SetExited(e.Status)
s.events <- &events.TaskExit{
@@ -909,12 +905,17 @@ func (s *service) forward(ctx context.Context, publisher shim.Publisher) {
func (s *service) getProcess(execID string) (process.Process, error) {
s.mu.Lock()
defer s.mu.Unlock()
+
if execID == "" {
+ if s.task == nil {
+ return nil, utils.ErrToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ }
return s.task, nil
}
+
p := s.processes[execID]
if p == nil {
- return nil, errdefs.ToGRPCf(errdefs.ErrNotFound, "process does not exist %s", execID)
+ return nil, utils.ErrToGRPCf(errdefs.ErrNotFound, "process does not exist %s", execID)
}
return p, nil
}
diff --git a/pkg/shim/utils/BUILD b/pkg/shim/utils/BUILD
index 54a0aabb7..2eb82f63c 100644
--- a/pkg/shim/utils/BUILD
+++ b/pkg/shim/utils/BUILD
@@ -6,6 +6,7 @@ go_library(
name = "utils",
srcs = [
"annotations.go",
+ "errors.go",
"utils.go",
"volumes.go",
],
@@ -14,14 +15,23 @@ go_library(
"//shim:__subpackages__",
],
deps = [
+ "@com_github_containerd_containerd//errdefs:go_default_library",
"@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
+ "@org_golang_google_grpc//codes:go_default_library",
+ "@org_golang_google_grpc//status:go_default_library",
],
)
go_test(
name = "utils_test",
size = "small",
- srcs = ["volumes_test.go"],
+ srcs = [
+ "errors_test.go",
+ "volumes_test.go",
+ ],
library = ":utils",
- deps = ["@com_github_opencontainers_runtime_spec//specs-go:go_default_library"],
+ deps = [
+ "@com_github_containerd_containerd//errdefs:go_default_library",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
+ ],
)
diff --git a/pkg/shim/utils/errors.go b/pkg/shim/utils/errors.go
new file mode 100644
index 000000000..971d68c36
--- /dev/null
+++ b/pkg/shim/utils/errors.go
@@ -0,0 +1,74 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package utils
+
+import (
+ "context"
+ "errors"
+ "fmt"
+
+ "github.com/containerd/containerd/errdefs"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+)
+
+// ErrToGRPC wraps containerd's ToGRPC error mapper which depends on
+// github.com/pkg/errors to work correctly. Once we upgrade to containerd v1.4,
+// this function can go away and we can use errdefs.ToGRPC directly instead.
+//
+// TODO(gvisor.dev/issue/6232): Remove after upgrading to containerd v1.4
+func ErrToGRPC(err error) error {
+ return errToGRPCMsg(err, err.Error())
+}
+
+// ErrToGRPCf maps the error to grpc error codes, assembling the formatting
+// string and combining it with the target error string.
+//
+// TODO(gvisor.dev/issue/6232): Remove after upgrading to containerd v1.4
+func ErrToGRPCf(err error, format string, args ...interface{}) error {
+ formatted := fmt.Sprintf(format, args...)
+ msg := fmt.Sprintf("%s: %s", formatted, err.Error())
+ return errToGRPCMsg(err, msg)
+}
+
+func errToGRPCMsg(err error, msg string) error {
+ if err == nil {
+ return nil
+ }
+ if _, ok := status.FromError(err); ok {
+ return err
+ }
+
+ switch {
+ case errors.Is(err, errdefs.ErrInvalidArgument):
+ return status.Errorf(codes.InvalidArgument, msg)
+ case errors.Is(err, errdefs.ErrNotFound):
+ return status.Errorf(codes.NotFound, msg)
+ case errors.Is(err, errdefs.ErrAlreadyExists):
+ return status.Errorf(codes.AlreadyExists, msg)
+ case errors.Is(err, errdefs.ErrFailedPrecondition):
+ return status.Errorf(codes.FailedPrecondition, msg)
+ case errors.Is(err, errdefs.ErrUnavailable):
+ return status.Errorf(codes.Unavailable, msg)
+ case errors.Is(err, errdefs.ErrNotImplemented):
+ return status.Errorf(codes.Unimplemented, msg)
+ case errors.Is(err, context.Canceled):
+ return status.Errorf(codes.Canceled, msg)
+ case errors.Is(err, context.DeadlineExceeded):
+ return status.Errorf(codes.DeadlineExceeded, msg)
+ }
+
+ return errdefs.ToGRPC(err)
+}
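
For context, a minimal sketch of the round-trip behavior the new wrappers provide. Only ErrToGRPC, errdefs.FromGRPC, and errdefs.IsNotFound from this diff are relied on; the wrapped error itself is illustrative.

package main

import (
	"fmt"

	"github.com/containerd/containerd/errdefs"
	"gvisor.dev/gvisor/pkg/shim/utils"
)

func main() {
	// An errdefs error wrapped with Go 1.13-style %w, which pre-1.4
	// errdefs.ToGRPC would map to codes.Unknown.
	err := fmt.Errorf("lookup exec: %w", errdefs.ErrNotFound)

	// ErrToGRPC unwraps with errors.Is and picks the matching gRPC code,
	// so the errdefs kind survives the gRPC round trip.
	st := utils.ErrToGRPC(err)
	fmt.Println(errdefs.IsNotFound(errdefs.FromGRPC(st))) // true
}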
diff --git a/pkg/shim/utils/errors_test.go b/pkg/shim/utils/errors_test.go
new file mode 100644
index 000000000..0a8fe34c8
--- /dev/null
+++ b/pkg/shim/utils/errors_test.go
@@ -0,0 +1,50 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package utils
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/containerd/containerd/errdefs"
+)
+
+func TestGRPCRoundTripsErrors(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ err error
+ test func(err error) bool
+ }{
+ {
+ name: "passthrough",
+ err: errdefs.ErrNotFound,
+ test: errdefs.IsNotFound,
+ },
+ {
+ name: "wrapped",
+ err: fmt.Errorf("oh no: %w", errdefs.ErrNotFound),
+ test: errdefs.IsNotFound,
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ if err := errdefs.FromGRPC(ErrToGRPC(tc.err)); !tc.test(err) {
+ t.Errorf("errToGRPC got %+v", err)
+ }
+ if err := errdefs.FromGRPC(ErrToGRPCf(tc.err, "testing %s", "123")); !tc.test(err) {
+ t.Errorf("errToGRPCf got %+v", err)
+ }
+ })
+ }
+}
diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD
index 8b3a11c64..73791b456 100644
--- a/pkg/sync/BUILD
+++ b/pkg/sync/BUILD
@@ -1,5 +1,4 @@
load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template")
package(
default_visibility = ["//:sandbox"],
@@ -8,45 +7,6 @@ package(
exports_files(["LICENSE"])
-go_template(
- name = "generic_atomicptr",
- srcs = ["generic_atomicptr_unsafe.go"],
- types = [
- "Value",
- ],
-)
-
-go_template(
- name = "generic_atomicptrmap",
- srcs = ["generic_atomicptrmap_unsafe.go"],
- opt_consts = [
- "ShardOrder",
- ],
- opt_types = [
- "Hasher",
- ],
- types = [
- "Key",
- "Value",
- ],
- deps = [
- ":sync",
- "//pkg/gohacks",
- ],
-)
-
-go_template(
- name = "generic_seqatomic",
- srcs = ["generic_seqatomic_unsafe.go"],
- types = [
- "Value",
- ],
- deps = [
- ":sync",
- "//pkg/gohacks",
- ],
-)
-
go_library(
name = "sync",
srcs = [
diff --git a/pkg/sync/README.md b/pkg/sync/README.md
index 2183c4e20..be1a01f08 100644
--- a/pkg/sync/README.md
+++ b/pkg/sync/README.md
@@ -1,4 +1,4 @@
-# Syncutil
+# sync
This package provides additional synchronization primitives not provided by the
Go stdlib 'sync' package. It is partially derived from the upstream 'sync'
diff --git a/pkg/sync/atomicptrtest/BUILD b/pkg/sync/atomicptr/BUILD
index e97553254..a6a7f01ac 100644
--- a/pkg/sync/atomicptrtest/BUILD
+++ b/pkg/sync/atomicptr/BUILD
@@ -1,14 +1,23 @@
load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
+load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(licenses = ["notice"])
+go_template(
+ name = "generic_atomicptr",
+ srcs = ["generic_atomicptr_unsafe.go"],
+ types = [
+ "Value",
+ ],
+ visibility = ["//:sandbox"],
+)
+
go_template_instance(
name = "atomicptr_int",
out = "atomicptr_int_unsafe.go",
package = "atomicptr",
suffix = "Int",
- template = "//pkg/sync:generic_atomicptr",
+ template = ":generic_atomicptr",
types = {
"Value": "int",
},
diff --git a/pkg/sync/atomicptrtest/atomicptr_test.go b/pkg/sync/atomicptr/atomicptr_test.go
index 8fdc5112e..8fdc5112e 100644
--- a/pkg/sync/atomicptrtest/atomicptr_test.go
+++ b/pkg/sync/atomicptr/atomicptr_test.go
diff --git a/pkg/sync/generic_atomicptr_unsafe.go b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go
index 82b6df18c..82b6df18c 100644
--- a/pkg/sync/generic_atomicptr_unsafe.go
+++ b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go
diff --git a/pkg/sync/atomicptrmaptest/BUILD b/pkg/sync/atomicptrmap/BUILD
index 3f71ae97d..b0e218c79 100644
--- a/pkg/sync/atomicptrmaptest/BUILD
+++ b/pkg/sync/atomicptrmap/BUILD
@@ -1,17 +1,36 @@
load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
+load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(
default_visibility = ["//visibility:private"],
licenses = ["notice"],
)
+go_template(
+ name = "generic_atomicptrmap",
+ srcs = ["generic_atomicptrmap_unsafe.go"],
+ opt_consts = [
+ "ShardOrder",
+ ],
+ opt_types = [
+ "Hasher",
+ ],
+ types = [
+ "Key",
+ "Value",
+ ],
+ deps = [
+ "//pkg/gohacks",
+ "//pkg/sync",
+ ],
+)
+
go_template_instance(
name = "test_atomicptrmap",
out = "test_atomicptrmap_unsafe.go",
package = "atomicptrmap",
prefix = "test",
- template = "//pkg/sync:generic_atomicptrmap",
+ template = ":generic_atomicptrmap",
types = {
"Key": "int64",
"Value": "testValue",
@@ -27,7 +46,7 @@ go_template_instance(
package = "atomicptrmap",
prefix = "test",
suffix = "Sharded",
- template = "//pkg/sync:generic_atomicptrmap",
+ template = ":generic_atomicptrmap",
types = {
"Key": "int64",
"Value": "testValue",
diff --git a/pkg/sync/atomicptrmaptest/atomicptrmap.go b/pkg/sync/atomicptrmap/atomicptrmap.go
index 867821ce9..867821ce9 100644
--- a/pkg/sync/atomicptrmaptest/atomicptrmap.go
+++ b/pkg/sync/atomicptrmap/atomicptrmap.go
diff --git a/pkg/sync/atomicptrmaptest/atomicptrmap_test.go b/pkg/sync/atomicptrmap/atomicptrmap_test.go
index 75a9997ef..75a9997ef 100644
--- a/pkg/sync/atomicptrmaptest/atomicptrmap_test.go
+++ b/pkg/sync/atomicptrmap/atomicptrmap_test.go
diff --git a/pkg/sync/generic_atomicptrmap_unsafe.go b/pkg/sync/atomicptrmap/generic_atomicptrmap_unsafe.go
index 3e98cb309..3e98cb309 100644
--- a/pkg/sync/generic_atomicptrmap_unsafe.go
+++ b/pkg/sync/atomicptrmap/generic_atomicptrmap_unsafe.go
diff --git a/pkg/sync/seqatomictest/BUILD b/pkg/sync/seqatomic/BUILD
index 5f9164117..60f79ab54 100644
--- a/pkg/sync/seqatomictest/BUILD
+++ b/pkg/sync/seqatomic/BUILD
@@ -1,14 +1,27 @@
load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
+load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(licenses = ["notice"])
+go_template(
+ name = "generic_seqatomic",
+ srcs = ["generic_seqatomic_unsafe.go"],
+ types = [
+ "Value",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ ":sync",
+ "//pkg/gohacks",
+ ],
+)
+
go_template_instance(
name = "seqatomic_int",
out = "seqatomic_int_unsafe.go",
package = "seqatomic",
suffix = "Int",
- template = "//pkg/sync:generic_seqatomic",
+ template = ":generic_seqatomic",
types = {
"Value": "int",
},
diff --git a/pkg/sync/generic_seqatomic_unsafe.go b/pkg/sync/seqatomic/generic_seqatomic_unsafe.go
index 9578c9c52..9578c9c52 100644
--- a/pkg/sync/generic_seqatomic_unsafe.go
+++ b/pkg/sync/seqatomic/generic_seqatomic_unsafe.go
diff --git a/pkg/sync/seqatomictest/seqatomic_test.go b/pkg/sync/seqatomic/seqatomic_test.go
index 2c4568b07..2c4568b07 100644
--- a/pkg/sync/seqatomictest/seqatomic_test.go
+++ b/pkg/sync/seqatomic/seqatomic_test.go
diff --git a/pkg/syserr/BUILD b/pkg/syserr/BUILD
index 9cc9e3bf2..ceee494fc 100644
--- a/pkg/syserr/BUILD
+++ b/pkg/syserr/BUILD
@@ -11,7 +11,9 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
- "//pkg/abi/linux",
+ "//pkg/abi/linux/errno",
+ "//pkg/errors",
+ "//pkg/errors/linuxerr",
"//pkg/syserror",
"//pkg/tcpip",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go
index 90be24e15..eb44f1254 100644
--- a/pkg/syserr/netstack.go
+++ b/pkg/syserr/netstack.go
@@ -17,7 +17,7 @@ package syserr
import (
"fmt"
- "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/abi/linux/errno"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -25,33 +25,33 @@ import (
// Mapping for tcpip.Error types.
var (
- ErrUnknownProtocol = New((&tcpip.ErrUnknownProtocol{}).String(), linux.EINVAL)
- ErrUnknownNICID = New((&tcpip.ErrUnknownNICID{}).String(), linux.ENODEV)
- ErrUnknownDevice = New((&tcpip.ErrUnknownDevice{}).String(), linux.ENODEV)
- ErrUnknownProtocolOption = New((&tcpip.ErrUnknownProtocolOption{}).String(), linux.ENOPROTOOPT)
- ErrDuplicateNICID = New((&tcpip.ErrDuplicateNICID{}).String(), linux.EEXIST)
- ErrDuplicateAddress = New((&tcpip.ErrDuplicateAddress{}).String(), linux.EEXIST)
- ErrAlreadyBound = New((&tcpip.ErrAlreadyBound{}).String(), linux.EINVAL)
- ErrInvalidEndpointState = New((&tcpip.ErrInvalidEndpointState{}).String(), linux.EINVAL)
- ErrAlreadyConnecting = New((&tcpip.ErrAlreadyConnecting{}).String(), linux.EALREADY)
- ErrNoPortAvailable = New((&tcpip.ErrNoPortAvailable{}).String(), linux.EAGAIN)
- ErrPortInUse = New((&tcpip.ErrPortInUse{}).String(), linux.EADDRINUSE)
- ErrBadLocalAddress = New((&tcpip.ErrBadLocalAddress{}).String(), linux.EADDRNOTAVAIL)
- ErrClosedForSend = New((&tcpip.ErrClosedForSend{}).String(), linux.EPIPE)
- ErrClosedForReceive = New((&tcpip.ErrClosedForReceive{}).String(), linux.NOERRNO)
- ErrTimeout = New((&tcpip.ErrTimeout{}).String(), linux.ETIMEDOUT)
- ErrAborted = New((&tcpip.ErrAborted{}).String(), linux.EPIPE)
- ErrConnectStarted = New((&tcpip.ErrConnectStarted{}).String(), linux.EINPROGRESS)
- ErrDestinationRequired = New((&tcpip.ErrDestinationRequired{}).String(), linux.EDESTADDRREQ)
- ErrNotSupported = New((&tcpip.ErrNotSupported{}).String(), linux.EOPNOTSUPP)
- ErrQueueSizeNotSupported = New((&tcpip.ErrQueueSizeNotSupported{}).String(), linux.ENOTTY)
- ErrNoSuchFile = New((&tcpip.ErrNoSuchFile{}).String(), linux.ENOENT)
- ErrInvalidOptionValue = New((&tcpip.ErrInvalidOptionValue{}).String(), linux.EINVAL)
- ErrBroadcastDisabled = New((&tcpip.ErrBroadcastDisabled{}).String(), linux.EACCES)
- ErrNotPermittedNet = New((&tcpip.ErrNotPermitted{}).String(), linux.EPERM)
- ErrBadBuffer = New((&tcpip.ErrBadBuffer{}).String(), linux.EFAULT)
- ErrMalformedHeader = New((&tcpip.ErrMalformedHeader{}).String(), linux.EINVAL)
- ErrInvalidPortRange = New((&tcpip.ErrInvalidPortRange{}).String(), linux.EINVAL)
+ ErrUnknownProtocol = New((&tcpip.ErrUnknownProtocol{}).String(), errno.EINVAL)
+ ErrUnknownNICID = New((&tcpip.ErrUnknownNICID{}).String(), errno.ENODEV)
+ ErrUnknownDevice = New((&tcpip.ErrUnknownDevice{}).String(), errno.ENODEV)
+ ErrUnknownProtocolOption = New((&tcpip.ErrUnknownProtocolOption{}).String(), errno.ENOPROTOOPT)
+ ErrDuplicateNICID = New((&tcpip.ErrDuplicateNICID{}).String(), errno.EEXIST)
+ ErrDuplicateAddress = New((&tcpip.ErrDuplicateAddress{}).String(), errno.EEXIST)
+ ErrAlreadyBound = New((&tcpip.ErrAlreadyBound{}).String(), errno.EINVAL)
+ ErrInvalidEndpointState = New((&tcpip.ErrInvalidEndpointState{}).String(), errno.EINVAL)
+ ErrAlreadyConnecting = New((&tcpip.ErrAlreadyConnecting{}).String(), errno.EALREADY)
+ ErrNoPortAvailable = New((&tcpip.ErrNoPortAvailable{}).String(), errno.EAGAIN)
+ ErrPortInUse = New((&tcpip.ErrPortInUse{}).String(), errno.EADDRINUSE)
+ ErrBadLocalAddress = New((&tcpip.ErrBadLocalAddress{}).String(), errno.EADDRNOTAVAIL)
+ ErrClosedForSend = New((&tcpip.ErrClosedForSend{}).String(), errno.EPIPE)
+ ErrClosedForReceive = New((&tcpip.ErrClosedForReceive{}).String(), errno.NOERRNO)
+ ErrTimeout = New((&tcpip.ErrTimeout{}).String(), errno.ETIMEDOUT)
+ ErrAborted = New((&tcpip.ErrAborted{}).String(), errno.EPIPE)
+ ErrConnectStarted = New((&tcpip.ErrConnectStarted{}).String(), errno.EINPROGRESS)
+ ErrDestinationRequired = New((&tcpip.ErrDestinationRequired{}).String(), errno.EDESTADDRREQ)
+ ErrNotSupported = New((&tcpip.ErrNotSupported{}).String(), errno.EOPNOTSUPP)
+ ErrQueueSizeNotSupported = New((&tcpip.ErrQueueSizeNotSupported{}).String(), errno.ENOTTY)
+ ErrNoSuchFile = New((&tcpip.ErrNoSuchFile{}).String(), errno.ENOENT)
+ ErrInvalidOptionValue = New((&tcpip.ErrInvalidOptionValue{}).String(), errno.EINVAL)
+ ErrBroadcastDisabled = New((&tcpip.ErrBroadcastDisabled{}).String(), errno.EACCES)
+ ErrNotPermittedNet = New((&tcpip.ErrNotPermitted{}).String(), errno.EPERM)
+ ErrBadBuffer = New((&tcpip.ErrBadBuffer{}).String(), errno.EFAULT)
+ ErrMalformedHeader = New((&tcpip.ErrMalformedHeader{}).String(), errno.EINVAL)
+ ErrInvalidPortRange = New((&tcpip.ErrInvalidPortRange{}).String(), errno.EINVAL)
)
// TranslateNetstackError converts an error from the tcpip package to a sentry
diff --git a/pkg/syserr/syserr.go b/pkg/syserr/syserr.go
index d70521f32..765128ff0 100644
--- a/pkg/syserr/syserr.go
+++ b/pkg/syserr/syserr.go
@@ -21,7 +21,9 @@ import (
"fmt"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/abi/linux/errno"
+ "gvisor.dev/gvisor/pkg/errors"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -31,34 +33,33 @@ type Error struct {
message string
// noTranslation indicates that this Error cannot be translated to a
- // linux.Errno.
+ // errno.Errno.
noTranslation bool
- // errno is the linux.Errno this Error should be translated to.
- errno linux.Errno
+ // errno is the errno.Errno this Error should be translated to.
+ errno errno.Errno
}
// New creates a new Error and adds a translation for it.
//
// New must only be called at init.
-func New(message string, linuxTranslation linux.Errno) *Error {
+func New(message string, linuxTranslation errno.Errno) *Error {
err := &Error{message: message, errno: linuxTranslation}
// TODO(b/34162363): Remove this.
- errno := linuxTranslation
- if errno < 0 || int(errno) >= len(linuxBackwardsTranslations) {
- panic(fmt.Sprint("invalid errno: ", errno))
+ if int(err.errno) >= len(linuxBackwardsTranslations) {
+ panic(fmt.Sprint("invalid errno: ", err.errno))
}
- e := error(unix.Errno(errno))
+ e := error(unix.Errno(err.errno))
// syserror.ErrWouldBlock gets translated to syserror.EWOULDBLOCK and
 // enables proper blocking semantics. This should temporarily address the
// class of blocking bugs that keep popping up with the current state of
// the error space.
- if e == syserror.EWOULDBLOCK {
+ if err.errno == linuxerr.EWOULDBLOCK.Errno() {
e = syserror.ErrWouldBlock
}
- linuxBackwardsTranslations[errno] = linuxBackwardsTranslation{err: e, ok: true}
+ linuxBackwardsTranslations[err.errno] = linuxBackwardsTranslation{err: e, ok: true}
return err
}
@@ -69,7 +70,7 @@ func New(message string, linuxTranslation linux.Errno) *Error {
// NewDynamic should only be used sparingly and not be used for static error
// messages. Errors with static error messages should be declared with New as
// global variables.
-func NewDynamic(message string, linuxTranslation linux.Errno) *Error {
+func NewDynamic(message string, linuxTranslation errno.Errno) *Error {
return &Error{message: message, errno: linuxTranslation}
}
@@ -82,7 +83,7 @@ func NewWithoutTranslation(message string) *Error {
return &Error{message: message, noTranslation: true}
}
-func newWithHost(message string, linuxTranslation linux.Errno, hostErrno unix.Errno) *Error {
+func newWithHost(message string, linuxTranslation errno.Errno, hostErrno unix.Errno) *Error {
e := New(message, linuxTranslation)
addLinuxHostTranslation(hostErrno, e)
return e
@@ -114,19 +115,19 @@ func (e *Error) ToError() error {
if e.noTranslation {
panic(fmt.Sprintf("error %q does not support translation", e.message))
}
- errno := int(e.errno)
- if errno == linux.NOERRNO {
+ err := int(e.errno)
+ if err == errno.NOERRNO {
return nil
}
- if errno <= 0 || errno >= len(linuxBackwardsTranslations) || !linuxBackwardsTranslations[errno].ok {
- panic(fmt.Sprintf("unknown error %q (%d)", e.message, errno))
+ if err >= len(linuxBackwardsTranslations) || !linuxBackwardsTranslations[err].ok {
+ panic(fmt.Sprintf("unknown error %q (%d)", e.message, err))
}
- return linuxBackwardsTranslations[errno].err
+ return linuxBackwardsTranslations[err].err
}
// ToLinux converts the Error to a Linux ABI error that can be returned to the
// application.
-func (e *Error) ToLinux() linux.Errno {
+func (e *Error) ToLinux() errno.Errno {
if e.noTranslation {
panic(fmt.Sprintf("No Linux ABI translation available for %q", e.message))
}
@@ -138,137 +139,137 @@ func (e *Error) ToLinux() linux.Errno {
// Some of the errors should be replaced with package specific errors and
// others should be removed entirely.
var (
- ErrNotPermitted = newWithHost("operation not permitted", linux.EPERM, unix.EPERM)
- ErrNoFileOrDir = newWithHost("no such file or directory", linux.ENOENT, unix.ENOENT)
- ErrNoProcess = newWithHost("no such process", linux.ESRCH, unix.ESRCH)
- ErrInterrupted = newWithHost("interrupted system call", linux.EINTR, unix.EINTR)
- ErrIO = newWithHost("I/O error", linux.EIO, unix.EIO)
- ErrDeviceOrAddress = newWithHost("no such device or address", linux.ENXIO, unix.ENXIO)
- ErrTooManyArgs = newWithHost("argument list too long", linux.E2BIG, unix.E2BIG)
- ErrEcec = newWithHost("exec format error", linux.ENOEXEC, unix.ENOEXEC)
- ErrBadFD = newWithHost("bad file number", linux.EBADF, unix.EBADF)
- ErrNoChild = newWithHost("no child processes", linux.ECHILD, unix.ECHILD)
- ErrTryAgain = newWithHost("try again", linux.EAGAIN, unix.EAGAIN)
- ErrNoMemory = newWithHost("out of memory", linux.ENOMEM, unix.ENOMEM)
- ErrPermissionDenied = newWithHost("permission denied", linux.EACCES, unix.EACCES)
- ErrBadAddress = newWithHost("bad address", linux.EFAULT, unix.EFAULT)
- ErrNotBlockDevice = newWithHost("block device required", linux.ENOTBLK, unix.ENOTBLK)
- ErrBusy = newWithHost("device or resource busy", linux.EBUSY, unix.EBUSY)
- ErrExists = newWithHost("file exists", linux.EEXIST, unix.EEXIST)
- ErrCrossDeviceLink = newWithHost("cross-device link", linux.EXDEV, unix.EXDEV)
- ErrNoDevice = newWithHost("no such device", linux.ENODEV, unix.ENODEV)
- ErrNotDir = newWithHost("not a directory", linux.ENOTDIR, unix.ENOTDIR)
- ErrIsDir = newWithHost("is a directory", linux.EISDIR, unix.EISDIR)
- ErrInvalidArgument = newWithHost("invalid argument", linux.EINVAL, unix.EINVAL)
- ErrFileTableOverflow = newWithHost("file table overflow", linux.ENFILE, unix.ENFILE)
- ErrTooManyOpenFiles = newWithHost("too many open files", linux.EMFILE, unix.EMFILE)
- ErrNotTTY = newWithHost("not a typewriter", linux.ENOTTY, unix.ENOTTY)
- ErrTestFileBusy = newWithHost("text file busy", linux.ETXTBSY, unix.ETXTBSY)
- ErrFileTooBig = newWithHost("file too large", linux.EFBIG, unix.EFBIG)
- ErrNoSpace = newWithHost("no space left on device", linux.ENOSPC, unix.ENOSPC)
- ErrIllegalSeek = newWithHost("illegal seek", linux.ESPIPE, unix.ESPIPE)
- ErrReadOnlyFS = newWithHost("read-only file system", linux.EROFS, unix.EROFS)
- ErrTooManyLinks = newWithHost("too many links", linux.EMLINK, unix.EMLINK)
- ErrBrokenPipe = newWithHost("broken pipe", linux.EPIPE, unix.EPIPE)
- ErrDomain = newWithHost("math argument out of domain of func", linux.EDOM, unix.EDOM)
- ErrRange = newWithHost("math result not representable", linux.ERANGE, unix.ERANGE)
- ErrDeadlock = newWithHost("resource deadlock would occur", linux.EDEADLOCK, unix.EDEADLOCK)
- ErrNameTooLong = newWithHost("file name too long", linux.ENAMETOOLONG, unix.ENAMETOOLONG)
- ErrNoLocksAvailable = newWithHost("no record locks available", linux.ENOLCK, unix.ENOLCK)
- ErrInvalidSyscall = newWithHost("invalid system call number", linux.ENOSYS, unix.ENOSYS)
- ErrDirNotEmpty = newWithHost("directory not empty", linux.ENOTEMPTY, unix.ENOTEMPTY)
- ErrLinkLoop = newWithHost("too many symbolic links encountered", linux.ELOOP, unix.ELOOP)
- ErrNoMessage = newWithHost("no message of desired type", linux.ENOMSG, unix.ENOMSG)
- ErrIdentifierRemoved = newWithHost("identifier removed", linux.EIDRM, unix.EIDRM)
- ErrChannelOutOfRange = newWithHost("channel number out of range", linux.ECHRNG, unix.ECHRNG)
- ErrLevelTwoNotSynced = newWithHost("level 2 not synchronized", linux.EL2NSYNC, unix.EL2NSYNC)
- ErrLevelThreeHalted = newWithHost("level 3 halted", linux.EL3HLT, unix.EL3HLT)
- ErrLevelThreeReset = newWithHost("level 3 reset", linux.EL3RST, unix.EL3RST)
- ErrLinkNumberOutOfRange = newWithHost("link number out of range", linux.ELNRNG, unix.ELNRNG)
- ErrProtocolDriverNotAttached = newWithHost("protocol driver not attached", linux.EUNATCH, unix.EUNATCH)
- ErrNoCSIAvailable = newWithHost("no CSI structure available", linux.ENOCSI, unix.ENOCSI)
- ErrLevelTwoHalted = newWithHost("level 2 halted", linux.EL2HLT, unix.EL2HLT)
- ErrInvalidExchange = newWithHost("invalid exchange", linux.EBADE, unix.EBADE)
- ErrInvalidRequestDescriptor = newWithHost("invalid request descriptor", linux.EBADR, unix.EBADR)
- ErrExchangeFull = newWithHost("exchange full", linux.EXFULL, unix.EXFULL)
- ErrNoAnode = newWithHost("no anode", linux.ENOANO, unix.ENOANO)
- ErrInvalidRequestCode = newWithHost("invalid request code", linux.EBADRQC, unix.EBADRQC)
- ErrInvalidSlot = newWithHost("invalid slot", linux.EBADSLT, unix.EBADSLT)
- ErrBadFontFile = newWithHost("bad font file format", linux.EBFONT, unix.EBFONT)
- ErrNotStream = newWithHost("device not a stream", linux.ENOSTR, unix.ENOSTR)
- ErrNoDataAvailable = newWithHost("no data available", linux.ENODATA, unix.ENODATA)
- ErrTimerExpired = newWithHost("timer expired", linux.ETIME, unix.ETIME)
- ErrStreamsResourceDepleted = newWithHost("out of streams resources", linux.ENOSR, unix.ENOSR)
- ErrMachineNotOnNetwork = newWithHost("machine is not on the network", linux.ENONET, unix.ENONET)
- ErrPackageNotInstalled = newWithHost("package not installed", linux.ENOPKG, unix.ENOPKG)
- ErrIsRemote = newWithHost("object is remote", linux.EREMOTE, unix.EREMOTE)
- ErrNoLink = newWithHost("link has been severed", linux.ENOLINK, unix.ENOLINK)
- ErrAdvertise = newWithHost("advertise error", linux.EADV, unix.EADV)
- ErrSRMount = newWithHost("srmount error", linux.ESRMNT, unix.ESRMNT)
- ErrSendCommunication = newWithHost("communication error on send", linux.ECOMM, unix.ECOMM)
- ErrProtocol = newWithHost("protocol error", linux.EPROTO, unix.EPROTO)
- ErrMultihopAttempted = newWithHost("multihop attempted", linux.EMULTIHOP, unix.EMULTIHOP)
- ErrRFS = newWithHost("RFS specific error", linux.EDOTDOT, unix.EDOTDOT)
- ErrInvalidDataMessage = newWithHost("not a data message", linux.EBADMSG, unix.EBADMSG)
- ErrOverflow = newWithHost("value too large for defined data type", linux.EOVERFLOW, unix.EOVERFLOW)
- ErrNetworkNameNotUnique = newWithHost("name not unique on network", linux.ENOTUNIQ, unix.ENOTUNIQ)
- ErrFDInBadState = newWithHost("file descriptor in bad state", linux.EBADFD, unix.EBADFD)
- ErrRemoteAddressChanged = newWithHost("remote address changed", linux.EREMCHG, unix.EREMCHG)
- ErrSharedLibraryInaccessible = newWithHost("can not access a needed shared library", linux.ELIBACC, unix.ELIBACC)
- ErrCorruptedSharedLibrary = newWithHost("accessing a corrupted shared library", linux.ELIBBAD, unix.ELIBBAD)
- ErrLibSectionCorrupted = newWithHost(".lib section in a.out corrupted", linux.ELIBSCN, unix.ELIBSCN)
- ErrTooManySharedLibraries = newWithHost("attempting to link in too many shared libraries", linux.ELIBMAX, unix.ELIBMAX)
- ErrSharedLibraryExeced = newWithHost("cannot exec a shared library directly", linux.ELIBEXEC, unix.ELIBEXEC)
- ErrIllegalByteSequence = newWithHost("illegal byte sequence", linux.EILSEQ, unix.EILSEQ)
- ErrShouldRestart = newWithHost("interrupted system call should be restarted", linux.ERESTART, unix.ERESTART)
- ErrStreamPipe = newWithHost("streams pipe error", linux.ESTRPIPE, unix.ESTRPIPE)
- ErrTooManyUsers = newWithHost("too many users", linux.EUSERS, unix.EUSERS)
- ErrNotASocket = newWithHost("socket operation on non-socket", linux.ENOTSOCK, unix.ENOTSOCK)
- ErrDestinationAddressRequired = newWithHost("destination address required", linux.EDESTADDRREQ, unix.EDESTADDRREQ)
- ErrMessageTooLong = newWithHost("message too long", linux.EMSGSIZE, unix.EMSGSIZE)
- ErrWrongProtocolForSocket = newWithHost("protocol wrong type for socket", linux.EPROTOTYPE, unix.EPROTOTYPE)
- ErrProtocolNotAvailable = newWithHost("protocol not available", linux.ENOPROTOOPT, unix.ENOPROTOOPT)
- ErrProtocolNotSupported = newWithHost("protocol not supported", linux.EPROTONOSUPPORT, unix.EPROTONOSUPPORT)
- ErrSocketNotSupported = newWithHost("socket type not supported", linux.ESOCKTNOSUPPORT, unix.ESOCKTNOSUPPORT)
- ErrEndpointOperation = newWithHost("operation not supported on transport endpoint", linux.EOPNOTSUPP, unix.EOPNOTSUPP)
- ErrProtocolFamilyNotSupported = newWithHost("protocol family not supported", linux.EPFNOSUPPORT, unix.EPFNOSUPPORT)
- ErrAddressFamilyNotSupported = newWithHost("address family not supported by protocol", linux.EAFNOSUPPORT, unix.EAFNOSUPPORT)
- ErrAddressInUse = newWithHost("address already in use", linux.EADDRINUSE, unix.EADDRINUSE)
- ErrAddressNotAvailable = newWithHost("cannot assign requested address", linux.EADDRNOTAVAIL, unix.EADDRNOTAVAIL)
- ErrNetworkDown = newWithHost("network is down", linux.ENETDOWN, unix.ENETDOWN)
- ErrNetworkUnreachable = newWithHost("network is unreachable", linux.ENETUNREACH, unix.ENETUNREACH)
- ErrNetworkReset = newWithHost("network dropped connection because of reset", linux.ENETRESET, unix.ENETRESET)
- ErrConnectionAborted = newWithHost("software caused connection abort", linux.ECONNABORTED, unix.ECONNABORTED)
- ErrConnectionReset = newWithHost("connection reset by peer", linux.ECONNRESET, unix.ECONNRESET)
- ErrNoBufferSpace = newWithHost("no buffer space available", linux.ENOBUFS, unix.ENOBUFS)
- ErrAlreadyConnected = newWithHost("transport endpoint is already connected", linux.EISCONN, unix.EISCONN)
- ErrNotConnected = newWithHost("transport endpoint is not connected", linux.ENOTCONN, unix.ENOTCONN)
- ErrShutdown = newWithHost("cannot send after transport endpoint shutdown", linux.ESHUTDOWN, unix.ESHUTDOWN)
- ErrTooManyRefs = newWithHost("too many references: cannot splice", linux.ETOOMANYREFS, unix.ETOOMANYREFS)
- ErrTimedOut = newWithHost("connection timed out", linux.ETIMEDOUT, unix.ETIMEDOUT)
- ErrConnectionRefused = newWithHost("connection refused", linux.ECONNREFUSED, unix.ECONNREFUSED)
- ErrHostDown = newWithHost("host is down", linux.EHOSTDOWN, unix.EHOSTDOWN)
- ErrNoRoute = newWithHost("no route to host", linux.EHOSTUNREACH, unix.EHOSTUNREACH)
- ErrAlreadyInProgress = newWithHost("operation already in progress", linux.EALREADY, unix.EALREADY)
- ErrInProgress = newWithHost("operation now in progress", linux.EINPROGRESS, unix.EINPROGRESS)
- ErrStaleFileHandle = newWithHost("stale file handle", linux.ESTALE, unix.ESTALE)
- ErrStructureNeedsCleaning = newWithHost("structure needs cleaning", linux.EUCLEAN, unix.EUCLEAN)
- ErrIsNamedFile = newWithHost("is a named type file", linux.ENOTNAM, unix.ENOTNAM)
- ErrRemoteIO = newWithHost("remote I/O error", linux.EREMOTEIO, unix.EREMOTEIO)
- ErrQuotaExceeded = newWithHost("quota exceeded", linux.EDQUOT, unix.EDQUOT)
- ErrNoMedium = newWithHost("no medium found", linux.ENOMEDIUM, unix.ENOMEDIUM)
- ErrWrongMediumType = newWithHost("wrong medium type", linux.EMEDIUMTYPE, unix.EMEDIUMTYPE)
- ErrCanceled = newWithHost("operation canceled", linux.ECANCELED, unix.ECANCELED)
- ErrNoKey = newWithHost("required key not available", linux.ENOKEY, unix.ENOKEY)
- ErrKeyExpired = newWithHost("key has expired", linux.EKEYEXPIRED, unix.EKEYEXPIRED)
- ErrKeyRevoked = newWithHost("key has been revoked", linux.EKEYREVOKED, unix.EKEYREVOKED)
- ErrKeyRejected = newWithHost("key was rejected by service", linux.EKEYREJECTED, unix.EKEYREJECTED)
- ErrOwnerDied = newWithHost("owner died", linux.EOWNERDEAD, unix.EOWNERDEAD)
- ErrNotRecoverable = newWithHost("state not recoverable", linux.ENOTRECOVERABLE, unix.ENOTRECOVERABLE)
+ ErrNotPermitted = newWithHost("operation not permitted", errno.EPERM, unix.EPERM)
+ ErrNoFileOrDir = newWithHost("no such file or directory", errno.ENOENT, unix.ENOENT)
+ ErrNoProcess = newWithHost("no such process", errno.ESRCH, unix.ESRCH)
+ ErrInterrupted = newWithHost("interrupted system call", errno.EINTR, unix.EINTR)
+ ErrIO = newWithHost("I/O error", errno.EIO, unix.EIO)
+ ErrDeviceOrAddress = newWithHost("no such device or address", errno.ENXIO, unix.ENXIO)
+ ErrTooManyArgs = newWithHost("argument list too long", errno.E2BIG, unix.E2BIG)
+ ErrEcec = newWithHost("exec format error", errno.ENOEXEC, unix.ENOEXEC)
+ ErrBadFD = newWithHost("bad file number", errno.EBADF, unix.EBADF)
+ ErrNoChild = newWithHost("no child processes", errno.ECHILD, unix.ECHILD)
+ ErrTryAgain = newWithHost("try again", errno.EAGAIN, unix.EAGAIN)
+ ErrNoMemory = newWithHost("out of memory", errno.ENOMEM, unix.ENOMEM)
+ ErrPermissionDenied = newWithHost("permission denied", errno.EACCES, unix.EACCES)
+ ErrBadAddress = newWithHost("bad address", errno.EFAULT, unix.EFAULT)
+ ErrNotBlockDevice = newWithHost("block device required", errno.ENOTBLK, unix.ENOTBLK)
+ ErrBusy = newWithHost("device or resource busy", errno.EBUSY, unix.EBUSY)
+ ErrExists = newWithHost("file exists", errno.EEXIST, unix.EEXIST)
+ ErrCrossDeviceLink = newWithHost("cross-device link", errno.EXDEV, unix.EXDEV)
+ ErrNoDevice = newWithHost("no such device", errno.ENODEV, unix.ENODEV)
+ ErrNotDir = newWithHost("not a directory", errno.ENOTDIR, unix.ENOTDIR)
+ ErrIsDir = newWithHost("is a directory", errno.EISDIR, unix.EISDIR)
+ ErrInvalidArgument = newWithHost("invalid argument", errno.EINVAL, unix.EINVAL)
+ ErrFileTableOverflow = newWithHost("file table overflow", errno.ENFILE, unix.ENFILE)
+ ErrTooManyOpenFiles = newWithHost("too many open files", errno.EMFILE, unix.EMFILE)
+ ErrNotTTY = newWithHost("not a typewriter", errno.ENOTTY, unix.ENOTTY)
+ ErrTestFileBusy = newWithHost("text file busy", errno.ETXTBSY, unix.ETXTBSY)
+ ErrFileTooBig = newWithHost("file too large", errno.EFBIG, unix.EFBIG)
+ ErrNoSpace = newWithHost("no space left on device", errno.ENOSPC, unix.ENOSPC)
+ ErrIllegalSeek = newWithHost("illegal seek", errno.ESPIPE, unix.ESPIPE)
+ ErrReadOnlyFS = newWithHost("read-only file system", errno.EROFS, unix.EROFS)
+ ErrTooManyLinks = newWithHost("too many links", errno.EMLINK, unix.EMLINK)
+ ErrBrokenPipe = newWithHost("broken pipe", errno.EPIPE, unix.EPIPE)
+ ErrDomain = newWithHost("math argument out of domain of func", errno.EDOM, unix.EDOM)
+ ErrRange = newWithHost("math result not representable", errno.ERANGE, unix.ERANGE)
+ ErrDeadlock = newWithHost("resource deadlock would occur", errno.EDEADLOCK, unix.EDEADLOCK)
+ ErrNameTooLong = newWithHost("file name too long", errno.ENAMETOOLONG, unix.ENAMETOOLONG)
+ ErrNoLocksAvailable = newWithHost("no record locks available", errno.ENOLCK, unix.ENOLCK)
+ ErrInvalidSyscall = newWithHost("invalid system call number", errno.ENOSYS, unix.ENOSYS)
+ ErrDirNotEmpty = newWithHost("directory not empty", errno.ENOTEMPTY, unix.ENOTEMPTY)
+ ErrLinkLoop = newWithHost("too many symbolic links encountered", errno.ELOOP, unix.ELOOP)
+ ErrNoMessage = newWithHost("no message of desired type", errno.ENOMSG, unix.ENOMSG)
+ ErrIdentifierRemoved = newWithHost("identifier removed", errno.EIDRM, unix.EIDRM)
+ ErrChannelOutOfRange = newWithHost("channel number out of range", errno.ECHRNG, unix.ECHRNG)
+ ErrLevelTwoNotSynced = newWithHost("level 2 not synchronized", errno.EL2NSYNC, unix.EL2NSYNC)
+ ErrLevelThreeHalted = newWithHost("level 3 halted", errno.EL3HLT, unix.EL3HLT)
+ ErrLevelThreeReset = newWithHost("level 3 reset", errno.EL3RST, unix.EL3RST)
+ ErrLinkNumberOutOfRange = newWithHost("link number out of range", errno.ELNRNG, unix.ELNRNG)
+ ErrProtocolDriverNotAttached = newWithHost("protocol driver not attached", errno.EUNATCH, unix.EUNATCH)
+ ErrNoCSIAvailable = newWithHost("no CSI structure available", errno.ENOCSI, unix.ENOCSI)
+ ErrLevelTwoHalted = newWithHost("level 2 halted", errno.EL2HLT, unix.EL2HLT)
+ ErrInvalidExchange = newWithHost("invalid exchange", errno.EBADE, unix.EBADE)
+ ErrInvalidRequestDescriptor = newWithHost("invalid request descriptor", errno.EBADR, unix.EBADR)
+ ErrExchangeFull = newWithHost("exchange full", errno.EXFULL, unix.EXFULL)
+ ErrNoAnode = newWithHost("no anode", errno.ENOANO, unix.ENOANO)
+ ErrInvalidRequestCode = newWithHost("invalid request code", errno.EBADRQC, unix.EBADRQC)
+ ErrInvalidSlot = newWithHost("invalid slot", errno.EBADSLT, unix.EBADSLT)
+ ErrBadFontFile = newWithHost("bad font file format", errno.EBFONT, unix.EBFONT)
+ ErrNotStream = newWithHost("device not a stream", errno.ENOSTR, unix.ENOSTR)
+ ErrNoDataAvailable = newWithHost("no data available", errno.ENODATA, unix.ENODATA)
+ ErrTimerExpired = newWithHost("timer expired", errno.ETIME, unix.ETIME)
+ ErrStreamsResourceDepleted = newWithHost("out of streams resources", errno.ENOSR, unix.ENOSR)
+ ErrMachineNotOnNetwork = newWithHost("machine is not on the network", errno.ENONET, unix.ENONET)
+ ErrPackageNotInstalled = newWithHost("package not installed", errno.ENOPKG, unix.ENOPKG)
+ ErrIsRemote = newWithHost("object is remote", errno.EREMOTE, unix.EREMOTE)
+ ErrNoLink = newWithHost("link has been severed", errno.ENOLINK, unix.ENOLINK)
+ ErrAdvertise = newWithHost("advertise error", errno.EADV, unix.EADV)
+ ErrSRMount = newWithHost("srmount error", errno.ESRMNT, unix.ESRMNT)
+ ErrSendCommunication = newWithHost("communication error on send", errno.ECOMM, unix.ECOMM)
+ ErrProtocol = newWithHost("protocol error", errno.EPROTO, unix.EPROTO)
+ ErrMultihopAttempted = newWithHost("multihop attempted", errno.EMULTIHOP, unix.EMULTIHOP)
+ ErrRFS = newWithHost("RFS specific error", errno.EDOTDOT, unix.EDOTDOT)
+ ErrInvalidDataMessage = newWithHost("not a data message", errno.EBADMSG, unix.EBADMSG)
+ ErrOverflow = newWithHost("value too large for defined data type", errno.EOVERFLOW, unix.EOVERFLOW)
+ ErrNetworkNameNotUnique = newWithHost("name not unique on network", errno.ENOTUNIQ, unix.ENOTUNIQ)
+ ErrFDInBadState = newWithHost("file descriptor in bad state", errno.EBADFD, unix.EBADFD)
+ ErrRemoteAddressChanged = newWithHost("remote address changed", errno.EREMCHG, unix.EREMCHG)
+ ErrSharedLibraryInaccessible = newWithHost("can not access a needed shared library", errno.ELIBACC, unix.ELIBACC)
+ ErrCorruptedSharedLibrary = newWithHost("accessing a corrupted shared library", errno.ELIBBAD, unix.ELIBBAD)
+ ErrLibSectionCorrupted = newWithHost(".lib section in a.out corrupted", errno.ELIBSCN, unix.ELIBSCN)
+ ErrTooManySharedLibraries = newWithHost("attempting to link in too many shared libraries", errno.ELIBMAX, unix.ELIBMAX)
+ ErrSharedLibraryExeced = newWithHost("cannot exec a shared library directly", errno.ELIBEXEC, unix.ELIBEXEC)
+ ErrIllegalByteSequence = newWithHost("illegal byte sequence", errno.EILSEQ, unix.EILSEQ)
+ ErrShouldRestart = newWithHost("interrupted system call should be restarted", errno.ERESTART, unix.ERESTART)
+ ErrStreamPipe = newWithHost("streams pipe error", errno.ESTRPIPE, unix.ESTRPIPE)
+ ErrTooManyUsers = newWithHost("too many users", errno.EUSERS, unix.EUSERS)
+ ErrNotASocket = newWithHost("socket operation on non-socket", errno.ENOTSOCK, unix.ENOTSOCK)
+ ErrDestinationAddressRequired = newWithHost("destination address required", errno.EDESTADDRREQ, unix.EDESTADDRREQ)
+ ErrMessageTooLong = newWithHost("message too long", errno.EMSGSIZE, unix.EMSGSIZE)
+ ErrWrongProtocolForSocket = newWithHost("protocol wrong type for socket", errno.EPROTOTYPE, unix.EPROTOTYPE)
+ ErrProtocolNotAvailable = newWithHost("protocol not available", errno.ENOPROTOOPT, unix.ENOPROTOOPT)
+ ErrProtocolNotSupported = newWithHost("protocol not supported", errno.EPROTONOSUPPORT, unix.EPROTONOSUPPORT)
+ ErrSocketNotSupported = newWithHost("socket type not supported", errno.ESOCKTNOSUPPORT, unix.ESOCKTNOSUPPORT)
+ ErrEndpointOperation = newWithHost("operation not supported on transport endpoint", errno.EOPNOTSUPP, unix.EOPNOTSUPP)
+ ErrProtocolFamilyNotSupported = newWithHost("protocol family not supported", errno.EPFNOSUPPORT, unix.EPFNOSUPPORT)
+ ErrAddressFamilyNotSupported = newWithHost("address family not supported by protocol", errno.EAFNOSUPPORT, unix.EAFNOSUPPORT)
+ ErrAddressInUse = newWithHost("address already in use", errno.EADDRINUSE, unix.EADDRINUSE)
+ ErrAddressNotAvailable = newWithHost("cannot assign requested address", errno.EADDRNOTAVAIL, unix.EADDRNOTAVAIL)
+ ErrNetworkDown = newWithHost("network is down", errno.ENETDOWN, unix.ENETDOWN)
+ ErrNetworkUnreachable = newWithHost("network is unreachable", errno.ENETUNREACH, unix.ENETUNREACH)
+ ErrNetworkReset = newWithHost("network dropped connection because of reset", errno.ENETRESET, unix.ENETRESET)
+ ErrConnectionAborted = newWithHost("software caused connection abort", errno.ECONNABORTED, unix.ECONNABORTED)
+ ErrConnectionReset = newWithHost("connection reset by peer", errno.ECONNRESET, unix.ECONNRESET)
+ ErrNoBufferSpace = newWithHost("no buffer space available", errno.ENOBUFS, unix.ENOBUFS)
+ ErrAlreadyConnected = newWithHost("transport endpoint is already connected", errno.EISCONN, unix.EISCONN)
+ ErrNotConnected = newWithHost("transport endpoint is not connected", errno.ENOTCONN, unix.ENOTCONN)
+ ErrShutdown = newWithHost("cannot send after transport endpoint shutdown", errno.ESHUTDOWN, unix.ESHUTDOWN)
+ ErrTooManyRefs = newWithHost("too many references: cannot splice", errno.ETOOMANYREFS, unix.ETOOMANYREFS)
+ ErrTimedOut = newWithHost("connection timed out", errno.ETIMEDOUT, unix.ETIMEDOUT)
+ ErrConnectionRefused = newWithHost("connection refused", errno.ECONNREFUSED, unix.ECONNREFUSED)
+ ErrHostDown = newWithHost("host is down", errno.EHOSTDOWN, unix.EHOSTDOWN)
+ ErrNoRoute = newWithHost("no route to host", errno.EHOSTUNREACH, unix.EHOSTUNREACH)
+ ErrAlreadyInProgress = newWithHost("operation already in progress", errno.EALREADY, unix.EALREADY)
+ ErrInProgress = newWithHost("operation now in progress", errno.EINPROGRESS, unix.EINPROGRESS)
+ ErrStaleFileHandle = newWithHost("stale file handle", errno.ESTALE, unix.ESTALE)
+ ErrStructureNeedsCleaning = newWithHost("structure needs cleaning", errno.EUCLEAN, unix.EUCLEAN)
+ ErrIsNamedFile = newWithHost("is a named type file", errno.ENOTNAM, unix.ENOTNAM)
+ ErrRemoteIO = newWithHost("remote I/O error", errno.EREMOTEIO, unix.EREMOTEIO)
+ ErrQuotaExceeded = newWithHost("quota exceeded", errno.EDQUOT, unix.EDQUOT)
+ ErrNoMedium = newWithHost("no medium found", errno.ENOMEDIUM, unix.ENOMEDIUM)
+ ErrWrongMediumType = newWithHost("wrong medium type", errno.EMEDIUMTYPE, unix.EMEDIUMTYPE)
+ ErrCanceled = newWithHost("operation canceled", errno.ECANCELED, unix.ECANCELED)
+ ErrNoKey = newWithHost("required key not available", errno.ENOKEY, unix.ENOKEY)
+ ErrKeyExpired = newWithHost("key has expired", errno.EKEYEXPIRED, unix.EKEYEXPIRED)
+ ErrKeyRevoked = newWithHost("key has been revoked", errno.EKEYREVOKED, unix.EKEYREVOKED)
+ ErrKeyRejected = newWithHost("key was rejected by service", errno.EKEYREJECTED, unix.EKEYREJECTED)
+ ErrOwnerDied = newWithHost("owner died", errno.EOWNERDEAD, unix.EOWNERDEAD)
+ ErrNotRecoverable = newWithHost("state not recoverable", errno.ENOTRECOVERABLE, unix.ENOTRECOVERABLE)
// ErrWouldBlock translates to EWOULDBLOCK which is the same as EAGAIN
// on Linux.
- ErrWouldBlock = New("operation would block", linux.EWOULDBLOCK)
+ ErrWouldBlock = New("operation would block", errno.EWOULDBLOCK)
)
// FromError converts a generic error to an *Error.
@@ -281,6 +282,11 @@ func FromError(err error) *Error {
if errno, ok := err.(unix.Errno); ok {
return FromHost(errno)
}
+
+ if linuxErr, ok := err.(*errors.Error); ok {
+ return FromHost(unix.Errno(linuxErr.Errno()))
+ }
+
if errno, ok := syserror.TranslateError(err); ok {
return FromHost(errno)
}
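
A rough sketch of the new branch above. The linuxerr.EWOULDBLOCK value and the FromError/ToLinux signatures are taken from hunks earlier in this diff; the exact mapping chosen is an assumption, not asserted here.

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/errors/linuxerr"
	"gvisor.dev/gvisor/pkg/syserr"
)

func main() {
	// *errors.Error values (e.g. the linuxerr constants) are now accepted
	// directly: FromError converts them via their errno instead of falling
	// through to the syserror translation table.
	if e := syserr.FromError(linuxerr.EWOULDBLOCK); e != nil {
		fmt.Println(e.ToLinux()) // the errno.Errno returned to the application
	}
}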
diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD
index 7d2f5adf6..76bee5a64 100644
--- a/pkg/syserror/BUILD
+++ b/pkg/syserror/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -8,12 +8,3 @@ go_library(
visibility = ["//visibility:public"],
deps = ["@org_golang_x_sys//unix:go_default_library"],
)
-
-go_test(
- name = "syserror_test",
- srcs = ["syserror_test.go"],
- deps = [
- ":syserror",
- "@org_golang_x_sys//unix:go_default_library",
- ],
-)
diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go
index 56b621357..721f62903 100644
--- a/pkg/syserror/syserror.go
+++ b/pkg/syserror/syserror.go
@@ -26,9 +26,7 @@ import (
// The following variables have the same meaning as their syscall equivalent.
var (
- E2BIG = error(unix.E2BIG)
EACCES = error(unix.EACCES)
- EADDRINUSE = error(unix.EADDRINUSE)
EAGAIN = error(unix.EAGAIN)
EBADF = error(unix.EBADF)
EBADFD = error(unix.EBADFD)
@@ -43,7 +41,6 @@ var (
EFBIG = error(unix.EFBIG)
EIDRM = error(unix.EIDRM)
EINTR = error(unix.EINTR)
- EINVAL = error(unix.EINVAL)
EIO = error(unix.EIO)
EISDIR = error(unix.EISDIR)
ELIBBAD = error(unix.ELIBBAD)
diff --git a/pkg/syserror/syserror_test.go b/pkg/syserror/syserror_test.go
deleted file mode 100644
index c141e5f6e..000000000
--- a/pkg/syserror/syserror_test.go
+++ /dev/null
@@ -1,132 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package syserror_test
-
-import (
- "errors"
- "testing"
-
- "golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/syserror"
-)
-
-var globalError error
-
-func BenchmarkAssignErrno(b *testing.B) {
- for i := b.N; i > 0; i-- {
- globalError = unix.EINVAL
- }
-}
-
-func BenchmarkAssignError(b *testing.B) {
- for i := b.N; i > 0; i-- {
- globalError = syserror.EINVAL
- }
-}
-
-func BenchmarkCompareErrno(b *testing.B) {
- globalError = unix.EAGAIN
- j := 0
- for i := b.N; i > 0; i-- {
- if globalError == unix.EINVAL {
- j++
- }
- }
-}
-
-func BenchmarkCompareError(b *testing.B) {
- globalError = syserror.EAGAIN
- j := 0
- for i := b.N; i > 0; i-- {
- if globalError == syserror.EINVAL {
- j++
- }
- }
-}
-
-func BenchmarkSwitchErrno(b *testing.B) {
- globalError = unix.EPERM
- j := 0
- for i := b.N; i > 0; i-- {
- switch globalError {
- case unix.EINVAL:
- j += 1
- case unix.EINTR:
- j += 2
- case unix.EAGAIN:
- j += 3
- }
- }
-}
-
-func BenchmarkSwitchError(b *testing.B) {
- globalError = syserror.EPERM
- j := 0
- for i := b.N; i > 0; i-- {
- switch globalError {
- case syserror.EINVAL:
- j += 1
- case syserror.EINTR:
- j += 2
- case syserror.EAGAIN:
- j += 3
- }
- }
-}
-
-type translationTestTable struct {
- fn string
- errIn error
- syscallErrorIn unix.Errno
- expectedBool bool
- expectedTranslation unix.Errno
-}
-
-func TestErrorTranslation(t *testing.T) {
- myError := errors.New("My test error")
- myError2 := errors.New("Another test error")
- testTable := []translationTestTable{
- {"TranslateError", myError, 0, false, 0},
- {"TranslateError", myError2, 0, false, 0},
- {"AddErrorTranslation", myError, unix.EAGAIN, true, 0},
- {"AddErrorTranslation", myError, unix.EAGAIN, false, 0},
- {"AddErrorTranslation", myError, unix.EPERM, false, 0},
- {"TranslateError", myError, 0, true, unix.EAGAIN},
- {"TranslateError", myError2, 0, false, 0},
- {"AddErrorTranslation", myError2, unix.EPERM, true, 0},
- {"AddErrorTranslation", myError2, unix.EPERM, false, 0},
- {"AddErrorTranslation", myError2, unix.EAGAIN, false, 0},
- {"TranslateError", myError, 0, true, unix.EAGAIN},
- {"TranslateError", myError2, 0, true, unix.EPERM},
- }
- for _, tt := range testTable {
- switch tt.fn {
- case "TranslateError":
- err, ok := syserror.TranslateError(tt.errIn)
- if ok != tt.expectedBool {
- t.Fatalf("%v(%v) => %v expected %v", tt.fn, tt.errIn, ok, tt.expectedBool)
- } else if err != tt.expectedTranslation {
- t.Fatalf("%v(%v) (error) => %v expected %v", tt.fn, tt.errIn, err, tt.expectedTranslation)
- }
- case "AddErrorTranslation":
- ok := syserror.AddErrorTranslation(tt.errIn, tt.syscallErrorIn)
- if ok != tt.expectedBool {
- t.Fatalf("%v(%v) => %v expected %v", tt.fn, tt.errIn, ok, tt.expectedBool)
- }
- default:
- t.Fatalf("Unknown function %v", tt.fn)
- }
- }
-}
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index ea46c30da..f00cfd0f5 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -39,12 +39,29 @@ go_library(
deps_test(
name = "netstack_deps_test",
allowed = [
+ # gVisor deps.
+ "//pkg/atomicbitops",
+ "//pkg/buffer",
+ "//pkg/context",
+ "//pkg/gohacks",
+ "//pkg/goid",
+ "//pkg/ilist",
+ "//pkg/linewriter",
+ "//pkg/log",
+ "//pkg/rand",
+ "//pkg/sleep",
+ "//pkg/state",
+ "//pkg/state/wire",
+ "//pkg/sync",
+ "//pkg/waiter",
+
+ # Other deps.
"@com_github_google_btree//:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
"@org_golang_x_time//rate:go_default_library",
],
allowed_prefixes = [
- "//",
+ "//pkg/tcpip",
"@org_golang_x_sys//internal/unsafeheader",
],
targets = [
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 18e6cc3cd..e0dfe5813 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -50,7 +50,7 @@ func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) {
ipv4 := header.IPv4(b)
if !ipv4.IsValid(len(b)) {
- t.Error("Not a valid IPv4 packet")
+ t.Fatalf("Not a valid IPv4 packet: %x", ipv4)
}
if !ipv4.IsChecksumValid() {
@@ -72,7 +72,7 @@ func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) {
ipv6 := header.IPv6(b)
if !ipv6.IsValid(len(b)) {
- t.Error("Not a valid IPv6 packet")
+ t.Fatalf("Not a valid IPv6 packet: %x", ipv6)
}
for _, f := range checkers {
@@ -701,7 +701,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
if !ok {
return
}
- opts := []byte(tcp.Options())
+ opts := tcp.Options()
limit := len(opts)
foundTS := false
tsVal := uint32(0)
@@ -748,12 +748,6 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
}
}
-// TCPNoSACKBlockChecker creates a checker that verifies that the segment does
-// not contain any SACK blocks in the TCP options.
-func TCPNoSACKBlockChecker() TransportChecker {
- return TCPSACKBlockChecker(nil)
-}
-
// TCPSACKBlockChecker creates a checker that verifies that the segment does
// contain the specified SACK blocks in the TCP options.
func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
@@ -765,7 +759,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
}
var gotSACKBlocks []header.SACKBlock
- opts := []byte(tcp.Options())
+ opts := tcp.Options()
limit := len(opts)
for i := 0; i < limit; {
switch opts[i] {
diff --git a/pkg/tcpip/faketime/faketime.go b/pkg/tcpip/faketime/faketime.go
index fb819d7a8..9f8f51647 100644
--- a/pkg/tcpip/faketime/faketime.go
+++ b/pkg/tcpip/faketime/faketime.go
@@ -29,14 +29,14 @@ type NullClock struct{}
var _ tcpip.Clock = (*NullClock)(nil)
-// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
-func (*NullClock) NowNanoseconds() int64 {
- return 0
+// Now implements tcpip.Clock.Now.
+func (*NullClock) Now() time.Time {
+ return time.Time{}
}
// NowMonotonic implements tcpip.Clock.NowMonotonic.
-func (*NullClock) NowMonotonic() int64 {
- return 0
+func (*NullClock) NowMonotonic() tcpip.MonotonicTime {
+ return tcpip.MonotonicTime{}
}
// AfterFunc implements tcpip.Clock.AfterFunc.
@@ -118,16 +118,17 @@ func NewManualClock() *ManualClock {
var _ tcpip.Clock = (*ManualClock)(nil)
-// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
-func (mc *ManualClock) NowNanoseconds() int64 {
+// Now implements tcpip.Clock.Now.
+func (mc *ManualClock) Now() time.Time {
mc.mu.RLock()
defer mc.mu.RUnlock()
- return mc.mu.now.UnixNano()
+ return mc.mu.now
}
// NowMonotonic implements tcpip.Clock.NowMonotonic.
-func (mc *ManualClock) NowMonotonic() int64 {
- return mc.NowNanoseconds()
+func (mc *ManualClock) NowMonotonic() tcpip.MonotonicTime {
+ var mt tcpip.MonotonicTime
+ return mt.Add(mc.Now().Sub(time.Unix(0, 0)))
}
// AfterFunc implements tcpip.Clock.AfterFunc.
@@ -218,6 +219,12 @@ func (mc *ManualClock) stopTimerLocked(mt *manualTimer) {
}
}
+// RunImmediatelyScheduledJobs runs all jobs scheduled to run at the current
+// time.
+func (mc *ManualClock) RunImmediatelyScheduledJobs() {
+ mc.Advance(0)
+}
+
// Advance executes all work that has been scheduled to execute within d from
// the current time. Blocks until all work has completed execution.
func (mc *ManualClock) Advance(d time.Duration) {
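
For illustration, a minimal sketch of driving the fake clock after this change. The AfterFunc(d, f) call assumes the standard tcpip.Clock signature; everything else appears in the hunk above.

package main

import (
	"fmt"
	"time"

	"gvisor.dev/gvisor/pkg/tcpip/faketime"
)

func main() {
	clock := faketime.NewManualClock()
	fired := false

	// Schedule work one virtual second in the future.
	clock.AfterFunc(time.Second, func() { fired = true })

	// Only jobs due at the current fake time run here, so nothing fires yet.
	clock.RunImmediatelyScheduledJobs()
	fmt.Println(fired) // false

	// Advance blocks until all work due within the window has completed.
	clock.Advance(time.Second)
	fmt.Println(fired) // true
}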
diff --git a/pkg/tcpip/faketime/faketime_test.go b/pkg/tcpip/faketime/faketime_test.go
index c2704df2c..fd2bb470a 100644
--- a/pkg/tcpip/faketime/faketime_test.go
+++ b/pkg/tcpip/faketime/faketime_test.go
@@ -26,7 +26,7 @@ func TestManualClockAdvance(t *testing.T) {
clock := faketime.NewManualClock()
start := clock.NowMonotonic()
clock.Advance(timeout)
- if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, timeout; got != want {
+ if got, want := clock.NowMonotonic().Sub(start), timeout; got != want {
t.Errorf("got = %d, want = %d", got, want)
}
}
@@ -87,7 +87,7 @@ func TestManualClockAfterFunc(t *testing.T) {
if got, want := counter2, test.wantCounter2; got != want {
t.Errorf("got counter2 = %d, want = %d", got, want)
}
- if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, test.advance; got != want {
+ if got, want := clock.NowMonotonic().Sub(start), test.advance; got != want {
t.Errorf("got elapsed = %d, want = %d", got, want)
}
})
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
index 6aa9acfa8..e2c85e220 100644
--- a/pkg/tcpip/header/checksum.go
+++ b/pkg/tcpip/header/checksum.go
@@ -18,6 +18,7 @@ package header
import (
"encoding/binary"
+ "fmt"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -234,3 +235,64 @@ func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip.
return Checksum([]byte{0, uint8(protocol)}, xsum)
}
+
+// checksumUpdate2ByteAlignedUint16 updates a uint16 value in a calculated
+// checksum.
+//
+// The value MUST begin at a 2-byte boundary in the original buffer.
+func checksumUpdate2ByteAlignedUint16(xsum, old, new uint16) uint16 {
+ // As per RFC 1071 page 4,
+ // (4) Incremental Update
+ //
+ // ...
+ //
+ // To update the checksum, simply add the differences of the
+ // sixteen bit integers that have been changed. To see why this
+ // works, observe that every 16-bit integer has an additive inverse
+ // and that addition is associative. From this it follows that
+ // given the original value m, the new value m', and the old
+ // checksum C, the new checksum C' is:
+ //
+ // C' = C + (-m) + m' = C + (m' - m)
+ return ChecksumCombine(xsum, ChecksumCombine(new, ^old))
+}
+
+// checksumUpdate2ByteAlignedAddress updates an address in a calculated
+// checksum.
+//
+// The addresses must have the same length and must contain an even number
+// of bytes. The addresses MUST begin at 2-byte boundaries in the original buffer.
+func checksumUpdate2ByteAlignedAddress(xsum uint16, old, new tcpip.Address) uint16 {
+ const uint16Bytes = 2
+
+ if len(old) != len(new) {
+ panic(fmt.Sprintf("buffer lengths are different; old = %d, new = %d", len(old), len(new)))
+ }
+
+ if len(old)%uint16Bytes != 0 {
+ panic(fmt.Sprintf("buffer has an odd number of bytes; got = %d", len(old)))
+ }
+
+ // As per RFC 1071 page 4,
+ // (4) Incremental Update
+ //
+ // ...
+ //
+ // To update the checksum, simply add the differences of the
+ // sixteen bit integers that have been changed. To see why this
+ // works, observe that every 16-bit integer has an additive inverse
+ // and that addition is associative. From this it follows that
+ // given the original value m, the new value m', and the old
+ // checksum C, the new checksum C' is:
+ //
+ // C' = C + (-m) + m' = C + (m' - m)
+ for len(old) != 0 {
+ // Convert the 2-byte sequences to uint16 values, then apply the incremental
+ // update.
+ xsum = checksumUpdate2ByteAlignedUint16(xsum, (uint16(old[0])<<8)+uint16(old[1]), (uint16(new[0])<<8)+uint16(new[1]))
+ old = old[uint16Bytes:]
+ new = new[uint16Bytes:]
+ }
+
+ return xsum
+}
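The RFC 1071 incremental-update identity these helpers rely on can be checked with a few lines of standalone arithmetic. The helper below is my own illustration (the package's checksumUpdate* functions are unexported), assuming the usual one's-complement carry fold:

package main

import "fmt"

// onesComplementAdd adds two 16-bit values and folds the carry back in,
// as in RFC 1071.
func onesComplementAdd(a, b uint16) uint16 {
	sum := uint32(a) + uint32(b)
	return uint16(sum&0xffff) + uint16(sum>>16)
}

func main() {
	// Old and new 16-bit field values (illustrative numbers).
	old, new := uint16(0x1234), uint16(0xabcd)
	// Suppose xsum is the existing (not yet inverted) checksum over the buffer.
	xsum := uint16(0x55aa)
	// Incremental update: C' = C + (m' + ^m), i.e. add the difference of the
	// changed 16-bit integer to the running checksum.
	updated := onesComplementAdd(xsum, onesComplementAdd(new, ^old))
	fmt.Printf("old checksum %#04x -> updated %#04x\n", xsum, updated)
}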
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
index d267dabd0..3445511f4 100644
--- a/pkg/tcpip/header/checksum_test.go
+++ b/pkg/tcpip/header/checksum_test.go
@@ -23,6 +23,7 @@ import (
"sync"
"testing"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -256,3 +257,205 @@ func TestICMPv6Checksum(t *testing.T) {
})
}, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView()))
}
+
+func randomAddress(size int) tcpip.Address {
+ s := make([]byte, size)
+ for i := 0; i < size; i++ {
+ s[i] = byte(rand.Uint32())
+ }
+ return tcpip.Address(s)
+}
+
+func TestChecksummableNetworkUpdateAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ update func(header.IPv4, tcpip.Address)
+ }{
+ {
+ name: "SetSourceAddressWithChecksumUpdate",
+ update: header.IPv4.SetSourceAddressWithChecksumUpdate,
+ },
+ {
+ name: "SetDestinationAddressWithChecksumUpdate",
+ update: header.IPv4.SetDestinationAddressWithChecksumUpdate,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for i := 0; i < 1000; i++ {
+ var origBytes [header.IPv4MinimumSize]byte
+ header.IPv4(origBytes[:]).Encode(&header.IPv4Fields{
+ TOS: 1,
+ TotalLength: header.IPv4MinimumSize,
+ ID: 2,
+ Flags: 3,
+ FragmentOffset: 4,
+ TTL: 5,
+ Protocol: 6,
+ Checksum: 0,
+ SrcAddr: randomAddress(header.IPv4AddressSize),
+ DstAddr: randomAddress(header.IPv4AddressSize),
+ })
+
+ addr := randomAddress(header.IPv4AddressSize)
+
+ bytesCopy := origBytes
+ h := header.IPv4(bytesCopy[:])
+ origXSum := h.CalculateChecksum()
+ h.SetChecksum(^origXSum)
+
+ test.update(h, addr)
+ got := ^h.Checksum()
+ h.SetChecksum(0)
+ want := h.CalculateChecksum()
+ if got != want {
+ t.Errorf("got h.Checksum() = 0x%x, want = 0x%x; originalBytes = 0x%x, new addr = %s", got, want, origBytes, addr)
+ }
+ }
+ })
+ }
+}
+
+func TestChecksummableTransportUpdatePort(t *testing.T) {
+ // The fields in the pseudo header are not tested here, so we just use 0.
+ const pseudoHeaderXSum = 0
+
+ tests := []struct {
+ name string
+ transportHdr func(_, _ uint16) (header.ChecksummableTransport, func(uint16) uint16)
+ proto tcpip.TransportProtocolNumber
+ }{
+ {
+ name: "TCP",
+ transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) {
+ h := header.TCP(make([]byte, header.TCPMinimumSize))
+ h.Encode(&header.TCPFields{
+ SrcPort: src,
+ DstPort: dst,
+ SeqNum: 1,
+ AckNum: 2,
+ DataOffset: header.TCPMinimumSize,
+ Flags: 3,
+ WindowSize: 4,
+ Checksum: 0,
+ UrgentPointer: 5,
+ })
+ h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum))
+ return h, h.CalculateChecksum
+ },
+ proto: header.TCPProtocolNumber,
+ },
+ {
+ name: "UDP",
+ transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) {
+ h := header.UDP(make([]byte, header.UDPMinimumSize))
+ h.Encode(&header.UDPFields{
+ SrcPort: src,
+ DstPort: dst,
+ Length: 0,
+ Checksum: 0,
+ })
+ h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum))
+ return h, h.CalculateChecksum
+ },
+ proto: header.UDPProtocolNumber,
+ },
+ }
+
+ for i := 0; i < 1000; i++ {
+ origSrcPort := uint16(rand.Uint32())
+ origDstPort := uint16(rand.Uint32())
+ newPort := uint16(rand.Uint32())
+
+ t.Run(fmt.Sprintf("OrigSrcPort=%d,OrigDstPort=%d,NewPort=%d", origSrcPort, origDstPort, newPort), func(*testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range []struct {
+ name string
+ update func(header.ChecksummableTransport)
+ }{
+ {
+ name: "Source port",
+ update: func(h header.ChecksummableTransport) { h.SetSourcePortWithChecksumUpdate(newPort) },
+ },
+ {
+ name: "Destination port",
+ update: func(h header.ChecksummableTransport) { h.SetDestinationPortWithChecksumUpdate(newPort) },
+ },
+ } {
+ t.Run(subTest.name, func(t *testing.T) {
+ h, calcXSum := test.transportHdr(origSrcPort, origDstPort)
+ subTest.update(h)
+ // TCP and UDP hold the 1s complement of the fully calculated
+ // checksum.
+ got := ^h.Checksum()
+ h.SetChecksum(0)
+
+ if want := calcXSum(pseudoHeaderXSum); got != want {
+ h, _ := test.transportHdr(origSrcPort, origDstPort)
+ t.Errorf("got Checksum() = 0x%x, want = 0x%x; originalBytes = %#v, new port = %d", got, want, h, newPort)
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+}
+
+func TestChecksummableTransportUpdatePseudoHeaderAddress(t *testing.T) {
+ const addressSize = 6
+
+ tests := []struct {
+ name string
+ transportHdr func() header.ChecksummableTransport
+ proto tcpip.TransportProtocolNumber
+ }{
+ {
+ name: "TCP",
+ transportHdr: func() header.ChecksummableTransport { return header.TCP(make([]byte, header.TCPMinimumSize)) },
+ proto: header.TCPProtocolNumber,
+ },
+ {
+ name: "UDP",
+ transportHdr: func() header.ChecksummableTransport { return header.UDP(make([]byte, header.UDPMinimumSize)) },
+ proto: header.UDPProtocolNumber,
+ },
+ }
+
+ for i := 0; i < 1000; i++ {
+ permanent := randomAddress(addressSize)
+ old := randomAddress(addressSize)
+ new := randomAddress(addressSize)
+
+ t.Run(fmt.Sprintf("Permanent=%q,Old=%q,New=%q", permanent, old, new), func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, fullChecksum := range []bool{true, false} {
+ t.Run(fmt.Sprintf("FullChecksum=%t", fullChecksum), func(t *testing.T) {
+ initialXSum := header.PseudoHeaderChecksum(test.proto, permanent, old, 0)
+ if fullChecksum {
+ // TCP and UDP hold the 1s complement of the fully calculated
+ // checksum.
+ initialXSum = ^initialXSum
+ }
+
+ h := test.transportHdr()
+ h.SetChecksum(initialXSum)
+ h.UpdateChecksumPseudoHeaderAddress(old, new, fullChecksum)
+
+ got := h.Checksum()
+ if fullChecksum {
+ got = ^got
+ }
+ if want := header.PseudoHeaderChecksum(test.proto, permanent, new, 0); got != want {
+ t.Errorf("got Checksum() = 0x%x, want = 0x%x; h = %#v", got, want, h)
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/interfaces.go b/pkg/tcpip/header/interfaces.go
index 861cbbb70..3a41adfc4 100644
--- a/pkg/tcpip/header/interfaces.go
+++ b/pkg/tcpip/header/interfaces.go
@@ -53,6 +53,31 @@ type Transport interface {
Payload() []byte
}
+// ChecksummableTransport is a Transport that supports checksumming.
+type ChecksummableTransport interface {
+ Transport
+
+ // SetSourcePortWithChecksumUpdate sets the source port and updates
+ // the checksum.
+ //
+ // The receiver's checksum must be a fully calculated checksum.
+ SetSourcePortWithChecksumUpdate(port uint16)
+
+ // SetDestinationPortWithChecksumUpdate sets the destination port and updates
+ // the checksum.
+ //
+ // The receiver's checksum must be a fully calculated checksum.
+ SetDestinationPortWithChecksumUpdate(port uint16)
+
+ // UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an
+ // updated address in the pseudo header.
+ //
+ // If fullChecksum is true, the receiver's checksum field is assumed to hold a
+ // fully calculated checksum. Otherwise, it is assumed to hold a partially
+ // calculated checksum which only reflects the pseudo header.
+ UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool)
+}
+
// Network offers generic methods to query and/or update the fields of the
// header of a network protocol buffer.
type Network interface {
@@ -90,3 +115,16 @@ type Network interface {
// SetTOS sets the values of the "type of service" and "flow label" fields.
SetTOS(t uint8, l uint32)
}
+
+// ChecksummableNetwork is a Network that supports checksumming.
+type ChecksummableNetwork interface {
+ Network
+
+ // SetSourceAddressWithChecksumUpdate sets the source address and updates the
+ // checksum to reflect the new address.
+ SetSourceAddressWithChecksumUpdate(tcpip.Address)
+
+ // SetDestinationAddressWithChecksumUpdate sets the destination address and
+ // updates the checksum to reflect the new address.
+ SetDestinationAddressWithChecksumUpdate(tcpip.Address)
+}
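A hedged sketch of how a caller such as an address/port rewrite might combine the two new interfaces, assuming the headers are already parsed and the transport checksum is fully calculated. The function and its parameters are illustrative, not part of this change:

package rewritedemo

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/header"
)

// rewriteDestination rewrites the destination address and port of a packet
// whose network and transport headers implement the new interfaces.
func rewriteDestination(net header.ChecksummableNetwork, tport header.ChecksummableTransport, oldAddr, newAddr tcpip.Address, newPort uint16) {
	// The network-layer address and its header checksum are updated together.
	net.SetDestinationAddressWithChecksumUpdate(newAddr)
	// The transport checksum covers the pseudo header, so the address delta
	// must be applied there too; true means the checksum is fully calculated.
	tport.UpdateChecksumPseudoHeaderAddress(oldAddr, newAddr, true /* fullChecksum */)
	// Port rewrites assume a fully calculated transport checksum.
	tport.SetDestinationPortWithChecksumUpdate(newPort)
}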
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 2be21ec75..dcc549c7b 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -17,6 +17,7 @@ package header
import (
"encoding/binary"
"fmt"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -181,7 +182,7 @@ const (
// ipv4LinkLocalUnicastSubnet is the IPv4 link local unicast subnet as defined
// by RFC 3927 section 1.
var ipv4LinkLocalUnicastSubnet = func() tcpip.Subnet {
- subnet, err := tcpip.NewSubnet("\xa9\xfe\x00\x00", tcpip.AddressMask("\xff\xff\x00\x00"))
+ subnet, err := tcpip.NewSubnet("\xa9\xfe\x00\x00", "\xff\xff\x00\x00")
if err != nil {
panic(err)
}
@@ -191,7 +192,7 @@ var ipv4LinkLocalUnicastSubnet = func() tcpip.Subnet {
// ipv4LinkLocalMulticastSubnet is the IPv4 link local multicast subnet as
// defined by RFC 5771 section 4.
var ipv4LinkLocalMulticastSubnet = func() tcpip.Subnet {
- subnet, err := tcpip.NewSubnet("\xe0\x00\x00\x00", tcpip.AddressMask("\xff\xff\xff\x00"))
+ subnet, err := tcpip.NewSubnet("\xe0\x00\x00\x00", "\xff\xff\xff\x00")
if err != nil {
panic(err)
}
@@ -304,6 +305,18 @@ func (b IPv4) DestinationAddress() tcpip.Address {
return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
}
+// SetSourceAddressWithChecksumUpdate implements ChecksummableNetwork.
+func (b IPv4) SetSourceAddressWithChecksumUpdate(new tcpip.Address) {
+ b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.SourceAddress(), new))
+ b.SetSourceAddress(new)
+}
+
+// SetDestinationAddressWithChecksumUpdate implements ChecksummableNetwork.
+func (b IPv4) SetDestinationAddressWithChecksumUpdate(new tcpip.Address) {
+ b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.DestinationAddress(), new))
+ b.SetDestinationAddress(new)
+}
+
// padIPv4OptionsLength returns the total length for IPv4 options of length l
// after applying padding according to RFC 791:
// The internet header padding is used to ensure that the internet
@@ -572,7 +585,7 @@ func (o *IPv4OptionGeneric) Type() IPv4OptionType {
func (o *IPv4OptionGeneric) Size() uint8 { return uint8(len(*o)) }
// Contents implements IPv4Option.
-func (o *IPv4OptionGeneric) Contents() []byte { return []byte(*o) }
+func (o *IPv4OptionGeneric) Contents() []byte { return *o }
// IPv4OptionIterator is an iterator pointing to a specific IP option
// at any point of time. It also holds information as to a new options buffer
@@ -610,7 +623,7 @@ func (i *IPv4OptionIterator) InitReplacement(option IPv4Option) IPv4Options {
// RemainingBuffer returns the remaining (unused) part of the new option buffer,
// into which a new option may be written.
func (i *IPv4OptionIterator) RemainingBuffer() IPv4Options {
- return IPv4Options(i.newOptions[i.writePoint:])
+ return i.newOptions[i.writePoint:]
}
// ConsumeBuffer marks a portion of the new buffer as used.
@@ -813,9 +826,12 @@ const (
// ipv4TimestampTime provides the current time as specified in RFC 791.
func ipv4TimestampTime(clock tcpip.Clock) uint32 {
- const millisecondsPerDay = 24 * 3600 * 1000
- const nanoPerMilli = 1000000
- return uint32((clock.NowNanoseconds() / nanoPerMilli) % millisecondsPerDay)
+ // Per RFC 791 page 21:
+ // The Timestamp is a right-justified, 32-bit timestamp in
+ // milliseconds since midnight UT.
+ now := clock.Now().UTC()
+ midnight := now.Truncate(24 * time.Hour)
+ return uint32(now.Sub(midnight).Milliseconds())
}
// IP Timestamp option fields.
@@ -843,7 +859,7 @@ func (ts *IPv4OptionTimestamp) Type() IPv4OptionType { return IPv4OptionTimestam
func (ts *IPv4OptionTimestamp) Size() uint8 { return uint8(len(*ts)) }
// Contents implements IPv4Option.
-func (ts *IPv4OptionTimestamp) Contents() []byte { return []byte(*ts) }
+func (ts *IPv4OptionTimestamp) Contents() []byte { return *ts }
// Pointer returns the pointer field in the IP Timestamp option.
func (ts *IPv4OptionTimestamp) Pointer() uint8 {
@@ -947,7 +963,7 @@ func (rr *IPv4OptionRecordRoute) Type() IPv4OptionType { return IPv4OptionRecord
func (rr *IPv4OptionRecordRoute) Size() uint8 { return uint8(len(*rr)) }
// Contents implements IPv4Option.
-func (rr *IPv4OptionRecordRoute) Contents() []byte { return []byte(*rr) }
+func (rr *IPv4OptionRecordRoute) Contents() []byte { return *rr }
// Router Alert option specific related constants.
//
@@ -992,7 +1008,7 @@ func (*IPv4OptionRouterAlert) Type() IPv4OptionType { return IPv4OptionRouterAle
func (ra *IPv4OptionRouterAlert) Size() uint8 { return uint8(len(*ra)) }
// Contents implements IPv4Option.
-func (ra *IPv4OptionRouterAlert) Contents() []byte { return []byte(*ra) }
+func (ra *IPv4OptionRouterAlert) Contents() []byte { return *ra }
// Value returns the value of the IPv4OptionRouterAlert.
func (ra *IPv4OptionRouterAlert) Value() uint16 {
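The new ipv4TimestampTime returns milliseconds since midnight UT, per RFC 791. A standalone sketch of the same computation; note that time.Truncate(24 * time.Hour) lands on UTC midnights because Go's zero time is itself midnight UTC:

package main

import (
	"fmt"
	"time"
)

// millisSinceMidnightUT mirrors the new ipv4TimestampTime logic for a given
// wall-clock instant (the real function takes a tcpip.Clock).
func millisSinceMidnightUT(now time.Time) uint32 {
	now = now.UTC()
	midnight := now.Truncate(24 * time.Hour) // aligned to UTC midnight
	return uint32(now.Sub(midnight).Milliseconds())
}

func main() {
	at := time.Date(2021, time.June, 1, 1, 2, 3, 0, time.UTC)
	fmt.Println(millisSinceMidnightUT(at)) // 3723000: 1h2m3s expressed in milliseconds
}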
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
index 3d1bccd15..b1f39e6e6 100644
--- a/pkg/tcpip/header/ndp_options.go
+++ b/pkg/tcpip/header/ndp_options.go
@@ -77,12 +77,12 @@ const (
// ndpPrefixInformationOnLinkFlagMask is the mask of the On-Link Flag
// field in the flags byte within an NDPPrefixInformation.
- ndpPrefixInformationOnLinkFlagMask = (1 << 7)
+ ndpPrefixInformationOnLinkFlagMask = 1 << 7
// ndpPrefixInformationAutoAddrConfFlagMask is the mask of the
// Autonomous Address-Configuration flag field in the flags byte within
// an NDPPrefixInformation.
- ndpPrefixInformationAutoAddrConfFlagMask = (1 << 6)
+ ndpPrefixInformationAutoAddrConfFlagMask = 1 << 6
// ndpPrefixInformationReserved1FlagsMask is the mask of the Reserved1
// field in the flags byte within an NDPPrefixInformation.
@@ -148,15 +148,10 @@ const (
// NDP option. That is, the length field for NDP options is in units of
// 8 octets, as per RFC 4861 section 4.6.
lengthByteUnits = 8
-)
-var (
// NDPInfiniteLifetime is a value that represents infinity for the
// 4-byte lifetime fields found in various NDP options. Its value is
// (2^32 - 1)s = 4294967295s.
- //
- // This is a variable instead of a constant so that tests can change
- // this value to a smaller value. It should only be modified by tests.
NDPInfiniteLifetime = time.Second * math.MaxUint32
)
@@ -451,7 +446,7 @@ func (o NDPNonceOption) String() string {
// Nonce returns the nonce value this option holds.
func (o NDPNonceOption) Nonce() []byte {
- return []byte(o)
+ return o
}
// NDPSourceLinkLayerAddressOption is the NDP Source Link Layer Option
diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go
index bf7610863..7e2f0c797 100644
--- a/pkg/tcpip/header/ndp_router_advert.go
+++ b/pkg/tcpip/header/ndp_router_advert.go
@@ -19,12 +19,72 @@ import (
"time"
)
+// NDPRoutePreference is the preference value for default routers or
+// more-specific routes.
+//
+// As per RFC 4191 section 2.1,
+//
+// Default router preferences and preferences for more-specific routes
+// are encoded the same way.
+//
+// Preference values are encoded as a two-bit signed integer, as
+// follows:
+//
+// 01 High
+// 00 Medium (default)
+// 11 Low
+// 10 Reserved - MUST NOT be sent
+//
+// Note that implementations can treat the value as a two-bit signed
+// integer.
+//
+// Having just three values reinforces that they are not metrics and
+// more values do not appear to be necessary for reasonable scenarios.
+type NDPRoutePreference uint8
+
+const (
+ // HighRoutePreference indicates a high preference, as per
+ // RFC 4191 section 2.1.
+ HighRoutePreference NDPRoutePreference = 0b01
+
+ // MediumRoutePreference indicates a medium preference, as per
+ // RFC 4191 section 2.1.
+ //
+ // This is the default preference value.
+ MediumRoutePreference = 0b00
+
+ // LowRoutePreference indicates a low preference, as per
+ // RFC 4191 section 2.1.
+ LowRoutePreference = 0b11
+
+ // ReservedRoutePreference is a reserved preference value, as per
+ // RFC 4191 section 2.1.
+ //
+ // It MUST NOT be sent.
+ ReservedRoutePreference = 0b10
+)
+
// NDPRouterAdvert is an NDP Router Advertisement message. It will only contain
// the body of an ICMPv6 packet.
//
-// See RFC 4861 section 4.2 for more details.
+// See RFC 4861 section 4.2 and RFC 4191 section 2.2 for more details.
type NDPRouterAdvert []byte
+// As per RFC 4191 section 2.2,
+//
+// 0 1 2 3
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Type | Code | Checksum |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Cur Hop Limit |M|O|H|Prf|Resvd| Router Lifetime |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Reachable Time |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Retrans Timer |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Options ...
+// +-+-+-+-+-+-+-+-+-+-+-+-
const (
// NDPRAMinimumSize is the minimum size of a valid NDP Router
// Advertisement message (body of an ICMPv6 packet).
@@ -47,6 +107,14 @@ const (
// within the bit-field/flags byte of an NDPRouterAdvert.
ndpRAOtherConfFlagMask = (1 << 6)
+ // ndpDefaultRouterPreferenceShift is the shift of the Prf (Default Router
+ // Preference) field within the flags byte of an NDPRouterAdvert.
+ ndpDefaultRouterPreferenceShift = 3
+
+ // ndpDefaultRouterPreferenceMask is the mask of the Prf (Default Router
+ // Preference) field within the flags byte of an NDPRouterAdvert.
+ ndpDefaultRouterPreferenceMask = (0b11 << ndpDefaultRouterPreferenceShift)
+
// ndpRARouterLifetimeOffset is the start of the 2-byte Router Lifetime
// field within an NDPRouterAdvert.
ndpRARouterLifetimeOffset = 2
@@ -80,6 +148,11 @@ func (b NDPRouterAdvert) OtherConfFlag() bool {
return b[ndpRAFlagsOffset]&ndpRAOtherConfFlagMask != 0
}
+// DefaultRouterPreference returns the Default Router Preference field.
+func (b NDPRouterAdvert) DefaultRouterPreference() NDPRoutePreference {
+ return NDPRoutePreference((b[ndpRAFlagsOffset] & ndpDefaultRouterPreferenceMask) >> ndpDefaultRouterPreferenceShift)
+}
+
// RouterLifetime returns the lifetime associated with the default router. A
// value of 0 means the source of the Router Advertisement is not a default
// router and SHOULD NOT appear on the default router list. Note, a value of 0
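A small sketch of how the two-bit Prf field is packed into and pulled back out of the RA flags byte. The shift and mask are restated locally because the real constants (ndpDefaultRouterPreferenceShift/Mask) are unexported:

package main

import "fmt"

const (
	prfShift = 3               // same value as ndpDefaultRouterPreferenceShift
	prfMask  = 0b11 << prfShift // same value as ndpDefaultRouterPreferenceMask
)

func main() {
	// Pack M=1, O=0, Prf=Low (0b11) into the RA flags byte.
	flags := uint8(1<<7) | uint8(0b11)<<prfShift
	// Unpack the preference the same way DefaultRouterPreference does.
	prf := (flags & prfMask) >> prfShift
	fmt.Printf("flags=%08b prf=%02b\n", flags, prf) // flags=10011000 prf=11
}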
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
index 1b5093e58..8fd1f7d13 100644
--- a/pkg/tcpip/header/ndp_test.go
+++ b/pkg/tcpip/header/ndp_test.go
@@ -126,36 +126,83 @@ func TestNDPNeighborAdvert(t *testing.T) {
}
func TestNDPRouterAdvert(t *testing.T) {
- b := []byte{
- 64, 128, 1, 2,
- 3, 4, 5, 6,
- 7, 8, 9, 10,
+ tests := []struct {
+ hopLimit uint8
+ managedFlag, otherConfFlag bool
+ prf NDPRoutePreference
+ routerLifetimeS uint16
+ reachableTimeMS, retransTimerMS uint32
+ }{
+ {
+ hopLimit: 1,
+ managedFlag: false,
+ otherConfFlag: true,
+ prf: HighRoutePreference,
+ routerLifetimeS: 2,
+ reachableTimeMS: 3,
+ retransTimerMS: 4,
+ },
+ {
+ hopLimit: 64,
+ managedFlag: true,
+ otherConfFlag: false,
+ prf: LowRoutePreference,
+ routerLifetimeS: 258,
+ reachableTimeMS: 78492,
+ retransTimerMS: 13213,
+ },
}
- ra := NDPRouterAdvert(b)
+ for i, test := range tests {
+ t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
+ flags := uint8(0)
+ if test.managedFlag {
+ flags |= 1 << 7
+ }
+ if test.otherConfFlag {
+ flags |= 1 << 6
+ }
+ flags |= uint8(test.prf) << 3
- if got := ra.CurrHopLimit(); got != 64 {
- t.Errorf("got ra.CurrHopLimit = %d, want = 64", got)
- }
+ b := []byte{
+ test.hopLimit, flags, 1, 2,
+ 3, 4, 5, 6,
+ 7, 8, 9, 10,
+ }
+ binary.BigEndian.PutUint16(b[2:], test.routerLifetimeS)
+ binary.BigEndian.PutUint32(b[4:], test.reachableTimeMS)
+ binary.BigEndian.PutUint32(b[8:], test.retransTimerMS)
- if got := ra.ManagedAddrConfFlag(); !got {
- t.Errorf("got ManagedAddrConfFlag = false, want = true")
- }
+ ra := NDPRouterAdvert(b)
- if got := ra.OtherConfFlag(); got {
- t.Errorf("got OtherConfFlag = true, want = false")
- }
+ if got := ra.CurrHopLimit(); got != test.hopLimit {
+ t.Errorf("got ra.CurrHopLimit() = %d, want = %d", got, test.hopLimit)
+ }
- if got, want := ra.RouterLifetime(), time.Second*258; got != want {
- t.Errorf("got ra.RouterLifetime = %d, want = %d", got, want)
- }
+ if got := ra.ManagedAddrConfFlag(); got != test.managedFlag {
+ t.Errorf("got ManagedAddrConfFlag() = %t, want = %t", got, test.managedFlag)
+ }
- if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want {
- t.Errorf("got ra.ReachableTime = %d, want = %d", got, want)
- }
+ if got := ra.OtherConfFlag(); got != test.otherConfFlag {
+ t.Errorf("got OtherConfFlag() = %t, want = %t", got, test.otherConfFlag)
+ }
+
+ if got := ra.DefaultRouterPreference(); got != test.prf {
+ t.Errorf("got DefaultRouterPreference() = %d, want = %d", got, test.prf)
+ }
- if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want {
- t.Errorf("got ra.RetransTimer = %d, want = %d", got, want)
+ if got, want := ra.RouterLifetime(), time.Second*time.Duration(test.routerLifetimeS); got != want {
+ t.Errorf("got ra.RouterLifetime() = %d, want = %d", got, want)
+ }
+
+ if got, want := ra.ReachableTime(), time.Millisecond*time.Duration(test.reachableTimeMS); got != want {
+ t.Errorf("got ra.ReachableTime() = %d, want = %d", got, want)
+ }
+
+ if got, want := ra.RetransTimer(), time.Millisecond*time.Duration(test.retransTimerMS); got != want {
+ t.Errorf("got ra.RetransTimer() = %d, want = %d", got, want)
+ }
+ })
}
}
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
index 0df517000..a75e51a28 100644
--- a/pkg/tcpip/header/tcp.go
+++ b/pkg/tcpip/header/tcp.go
@@ -48,6 +48,16 @@ const (
// TCPFlags is the dedicated type for TCP flags.
type TCPFlags uint8
+// Intersects returns true iff there are flags common to both f and o.
+func (f TCPFlags) Intersects(o TCPFlags) bool {
+ return f&o != 0
+}
+
+// Contains returns true iff all the flags in o are contained within f.
+func (f TCPFlags) Contains(o TCPFlags) bool {
+ return f&o == o
+}
+
// String implements Stringer.String.
func (f TCPFlags) String() string {
flagsStr := []byte("FSRPAU")
@@ -380,6 +390,35 @@ func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32
b.SetChecksum(^checksum)
}
+// SetSourcePortWithChecksumUpdate implements ChecksummableTransport.
+func (b TCP) SetSourcePortWithChecksumUpdate(new uint16) {
+ old := b.SourcePort()
+ b.SetSourcePort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport.
+func (b TCP) SetDestinationPortWithChecksumUpdate(new uint16) {
+ old := b.DestinationPort()
+ b.SetDestinationPort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport.
+func (b TCP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) {
+ xsum := b.Checksum()
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new)
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ b.SetChecksum(xsum)
+}
+
// ParseSynOptions parses the options received in a SYN segment and returns the
// relevant ones. opts should point to the option part of the TCP header.
func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions {
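For illustration, the new TCPFlags helpers behave as follows; a hypothetical snippet assuming the package's existing TCPFlagSyn/TCPFlagAck/TCPFlagRst constants:

package flagsdemo

import "gvisor.dev/gvisor/pkg/tcpip/header"

func tcpFlagsDemo() {
	flags := header.TCPFlagSyn | header.TCPFlagAck

	bothSet := flags.Contains(header.TCPFlagSyn | header.TCPFlagAck)    // true: SYN and ACK are both set
	hasRst := flags.Contains(header.TCPFlagSyn | header.TCPFlagRst)     // false: RST is not set
	overlaps := flags.Intersects(header.TCPFlagRst | header.TCPFlagAck) // true: ACK is common to both

	_, _, _ = bothSet, hasRst, overlaps
}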
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index ae9d167ff..f69d53314 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -130,3 +130,32 @@ func (b UDP) Encode(u *UDPFields) {
binary.BigEndian.PutUint16(b[udpLength:], u.Length)
binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum)
}
+
+// SetSourcePortWithChecksumUpdate implements ChecksummableTransport.
+func (b UDP) SetSourcePortWithChecksumUpdate(new uint16) {
+ old := b.SourcePort()
+ b.SetSourcePort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport.
+func (b UDP) SetDestinationPortWithChecksumUpdate(new uint16) {
+ old := b.DestinationPort()
+ b.SetDestinationPort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport.
+func (b UDP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) {
+ xsum := b.Checksum()
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new)
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ b.SetChecksum(xsum)
+}
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index ef9126deb..f26c857eb 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -288,5 +288,5 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
}
// AddHeader implements stack.LinkEndpoint.AddHeader.
-func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+func (*Endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
}
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index d971194e6..1d0163823 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -14,7 +14,6 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
- "//pkg/iovec",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index bddb1d0a2..1b56d2b72 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -41,11 +41,9 @@ package fdbased
import (
"fmt"
- "math"
"sync/atomic"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/iovec"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -139,6 +137,20 @@ type endpoint struct {
// gsoKind is the supported kind of GSO.
gsoKind stack.SupportedGSO
+
+ // maxSyscallHeaderBytes has the same meaning as
+ // Options.MaxSyscallHeaderBytes.
+ maxSyscallHeaderBytes uintptr
+
+ // writevMaxIovs is the maximum number of iovecs that may be passed to
+ // rawfile.NonBlockingWriteIovec, as possibly limited by
+ // maxSyscallHeaderBytes. (No analogous limit is defined for
+ // rawfile.NonBlockingSendMMsg, since in that case the maximum number of
+ // iovecs also depends on the number of mmsghdrs. Instead, if sendBatch
+ // encounters a packet whose iovec count is limited by
+ // maxSyscallHeaderBytes, it falls back to writing the packet using writev
+ // via WritePacket.)
+ writevMaxIovs int
}
// Options specify the details about the fd-based endpoint to be created.
@@ -187,6 +199,11 @@ type Options struct {
// RXChecksumOffload if true, indicates that this endpoints capability
// set should include CapabilityRXChecksumOffload.
RXChecksumOffload bool
+
+ // If MaxSyscallHeaderBytes is non-zero, it is the maximum number of bytes
+ // of struct iovec, msghdr, and mmsghdr that may be passed by each host
+ // system call.
+ MaxSyscallHeaderBytes int
}
// fanoutID is used for AF_PACKET based endpoints to enable PACKET_FANOUT
@@ -196,8 +213,12 @@ type Options struct {
// option for an FD with a fanoutID already in use by another FD for a different
// NIC will return an EINVAL.
//
+// Since fanoutID must be unique within the network namespace, we start with
+// the PID to avoid collisions. The only way to be sure of avoiding collisions
+// is to run in a new network namespace.
+//
// Must be accessed using atomic operations.
-var fanoutID int32 = 0
+var fanoutID int32 = int32(unix.Getpid())
// New creates a new fd-based endpoint.
//
@@ -232,14 +253,25 @@ func New(opts *Options) (stack.LinkEndpoint, error) {
return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified")
}
+ if opts.MaxSyscallHeaderBytes < 0 {
+ return nil, fmt.Errorf("opts.MaxSyscallHeaderBytes is negative")
+ }
+
e := &endpoint{
- fds: opts.FDs,
- mtu: opts.MTU,
- caps: caps,
- closed: opts.ClosedFunc,
- addr: opts.Address,
- hdrSize: hdrSize,
- packetDispatchMode: opts.PacketDispatchMode,
+ fds: opts.FDs,
+ mtu: opts.MTU,
+ caps: caps,
+ closed: opts.ClosedFunc,
+ addr: opts.Address,
+ hdrSize: hdrSize,
+ packetDispatchMode: opts.PacketDispatchMode,
+ maxSyscallHeaderBytes: uintptr(opts.MaxSyscallHeaderBytes),
+ writevMaxIovs: rawfile.MaxIovs,
+ }
+ if e.maxSyscallHeaderBytes != 0 {
+ if max := int(e.maxSyscallHeaderBytes / rawfile.SizeofIovec); max < e.writevMaxIovs {
+ e.writevMaxIovs = max
+ }
}
// Increment fanoutID to ensure that we don't re-use the same fanoutID for
@@ -292,11 +324,6 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool, fID int32) (lin
}
switch sa.(type) {
case *unix.SockaddrLinklayer:
- // See: PACKET_FANOUT_MAX in net/packet/internal.h
- const packetFanoutMax = 1 << 16
- if fID > packetFanoutMax {
- return nil, fmt.Errorf("host fanoutID limit exceeded, fanoutID must be <= %d", math.MaxUint16)
- }
// Enable PACKET_FANOUT mode if the underlying socket is of type
// AF_PACKET. We do not enable PACKET_FANOUT_FLAG_DEFRAG as that will
// prevent gvisor from receiving fragmented packets and the host does the
@@ -317,7 +344,7 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool, fID int32) (lin
//
// See: https://github.com/torvalds/linux/blob/7acac4b3196caee5e21fb5ea53f8bc124e6a16fc/net/packet/af_packet.c#L3881
const fanoutType = unix.PACKET_FANOUT_HASH
- fanoutArg := int(fID) | fanoutType<<16
+ fanoutArg := (int(fID) & 0xffff) | fanoutType<<16
if err := unix.SetsockoptInt(fd, unix.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil {
return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err)
}
@@ -472,9 +499,8 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
}
- var builder iovec.Builder
-
fd := e.fds[pkt.Hash%uint32(len(e.fds))]
+ var vnetHdrBuf []byte
if e.gsoKind == stack.HWGSOSupported {
vnetHdr := virtioNetHdr{}
if pkt.GSOOptions.Type != stack.GSONone {
@@ -496,71 +522,123 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
vnetHdr.gsoSize = pkt.GSOOptions.MSS
}
}
+ vnetHdrBuf = vnetHdr.marshal()
+ }
- vnetHdrBuf := vnetHdr.marshal()
- builder.Add(vnetHdrBuf)
+ views := pkt.Views()
+ numIovecs := len(views)
+ if len(vnetHdrBuf) != 0 {
+ numIovecs++
+ }
+ if numIovecs > e.writevMaxIovs {
+ numIovecs = e.writevMaxIovs
}
- for _, v := range pkt.Views() {
- builder.Add(v)
+ // Allocate small iovec arrays on the stack.
+ var iovecsArr [8]unix.Iovec
+ iovecs := iovecsArr[:0]
+ if numIovecs > len(iovecsArr) {
+ iovecs = make([]unix.Iovec, 0, numIovecs)
}
- return rawfile.NonBlockingWriteIovec(fd, builder.Build())
+ iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs)
+ for _, v := range views {
+ iovecs = rawfile.AppendIovecFromBytes(iovecs, v, numIovecs)
+ }
+ return rawfile.NonBlockingWriteIovec(fd, iovecs)
}
-func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcpip.Error) {
+func (e *endpoint) sendBatch(batchFD int, pkts []*stack.PacketBuffer) (int, tcpip.Error) {
// Send a batch of packets through batchFD.
- mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch))
- for _, pkt := range batch {
- if e.hdrSize > 0 {
- e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt)
- }
+ mmsgHdrsStorage := make([]rawfile.MMsgHdr, 0, len(pkts))
+ packets := 0
+ for packets < len(pkts) {
+ mmsgHdrs := mmsgHdrsStorage
+ batch := pkts[packets:]
+ syscallHeaderBytes := uintptr(0)
+ for _, pkt := range batch {
+ if e.hdrSize > 0 {
+ e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt)
+ }
- var vnetHdrBuf []byte
- if e.gsoKind == stack.HWGSOSupported {
- vnetHdr := virtioNetHdr{}
- if pkt.GSOOptions.Type != stack.GSONone {
- vnetHdr.hdrLen = uint16(pkt.HeaderSize())
- if pkt.GSOOptions.NeedsCsum {
- vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
- vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen
- vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset
- }
- if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS {
- switch pkt.GSOOptions.Type {
- case stack.GSOTCPv4:
- vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4
- case stack.GSOTCPv6:
- vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6
- default:
- panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type))
+ var vnetHdrBuf []byte
+ if e.gsoKind == stack.HWGSOSupported {
+ vnetHdr := virtioNetHdr{}
+ if pkt.GSOOptions.Type != stack.GSONone {
+ vnetHdr.hdrLen = uint16(pkt.HeaderSize())
+ if pkt.GSOOptions.NeedsCsum {
+ vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
+ vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen
+ vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset
+ }
+ if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS {
+ switch pkt.GSOOptions.Type {
+ case stack.GSOTCPv4:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4
+ case stack.GSOTCPv6:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6
+ default:
+ panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type))
+ }
+ vnetHdr.gsoSize = pkt.GSOOptions.MSS
}
- vnetHdr.gsoSize = pkt.GSOOptions.MSS
}
+ vnetHdrBuf = vnetHdr.marshal()
}
- vnetHdrBuf = vnetHdr.marshal()
- }
- var builder iovec.Builder
- builder.Add(vnetHdrBuf)
- for _, v := range pkt.Views() {
- builder.Add(v)
- }
- iovecs := builder.Build()
+ views := pkt.Views()
+ numIovecs := len(views)
+ if len(vnetHdrBuf) != 0 {
+ numIovecs++
+ }
+ if numIovecs > rawfile.MaxIovs {
+ numIovecs = rawfile.MaxIovs
+ }
+ if e.maxSyscallHeaderBytes != 0 {
+ syscallHeaderBytes += rawfile.SizeofMMsgHdr + uintptr(numIovecs)*rawfile.SizeofIovec
+ if syscallHeaderBytes > e.maxSyscallHeaderBytes {
+ // We can't fit this packet into this call to sendmmsg().
+ // We could potentially do so if we reduced numIovecs
+ // further, but this might incur considerable extra
+ // copying. Leave it to the next batch instead.
+ break
+ }
+ }
- var mmsgHdr rawfile.MMsgHdr
- mmsgHdr.Msg.Iov = &iovecs[0]
- mmsgHdr.Msg.SetIovlen((len(iovecs)))
- mmsgHdrs = append(mmsgHdrs, mmsgHdr)
- }
+ // We can't easily allocate iovec arrays on the stack here since
+ // they will escape this loop iteration via mmsgHdrs.
+ iovecs := make([]unix.Iovec, 0, numIovecs)
+ iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs)
+ for _, v := range views {
+ iovecs = rawfile.AppendIovecFromBytes(iovecs, v, numIovecs)
+ }
- packets := 0
- for len(mmsgHdrs) > 0 {
- sent, err := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs)
- if err != nil {
- return packets, err
+ var mmsgHdr rawfile.MMsgHdr
+ mmsgHdr.Msg.Iov = &iovecs[0]
+ mmsgHdr.Msg.SetIovlen(len(iovecs))
+ mmsgHdrs = append(mmsgHdrs, mmsgHdr)
+ }
+
+ if len(mmsgHdrs) == 0 {
+ // We can't fit batch[0] into a mmsghdr while staying under
+ // e.maxSyscallHeaderBytes. Use WritePacket, which will avoid the
+ // mmsghdr (by using writev) and re-buffer iovecs more aggressively
+ // if necessary (by using e.writevMaxIovs instead of
+ // rawfile.MaxIovs).
+ pkt := batch[0]
+ if err := e.WritePacket(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil {
+ return packets, err
+ }
+ packets++
+ } else {
+ for len(mmsgHdrs) > 0 {
+ sent, err := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs)
+ if err != nil {
+ return packets, err
+ }
+ packets += sent
+ mmsgHdrs = mmsgHdrs[sent:]
+ }
}
- packets += sent
- mmsgHdrs = mmsgHdrs[sent:]
}
return packets, nil
@@ -678,8 +756,9 @@ func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabiliti
unix.SetNonblock(fd, true)
return &InjectableEndpoint{endpoint: endpoint{
- fds: []int{fd},
- mtu: mtu,
- caps: capabilities,
+ fds: []int{fd},
+ mtu: mtu,
+ caps: capabilities,
+ writevMaxIovs: rawfile.MaxIovs,
}}
}
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index ba92aedbc..43fe57830 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -19,12 +19,66 @@
package rawfile
import (
+ "reflect"
"unsafe"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip"
)
+// SizeofIovec is the size of a unix.Iovec in bytes.
+const SizeofIovec = unsafe.Sizeof(unix.Iovec{})
+
+// MaxIovs is UIO_MAXIOV, the maximum number of iovecs that may be passed to a
+// host system call in a single array.
+const MaxIovs = 1024
+
+// IovecFromBytes returns a unix.Iovec representing bs.
+//
+// Preconditions: len(bs) > 0.
+func IovecFromBytes(bs []byte) unix.Iovec {
+ iov := unix.Iovec{
+ Base: &bs[0],
+ }
+ iov.SetLen(len(bs))
+ return iov
+}
+
+func bytesFromIovec(iov unix.Iovec) (bs []byte) {
+ sh := (*reflect.SliceHeader)(unsafe.Pointer(&bs))
+ sh.Data = uintptr(unsafe.Pointer(iov.Base))
+ sh.Len = int(iov.Len)
+ sh.Cap = int(iov.Len)
+ return
+}
+
+// AppendIovecFromBytes returns append(iovs, IovecFromBytes(bs)). If len(bs) ==
+// 0, AppendIovecFromBytes returns iovs without modification. If len(iovs) >=
+// max, AppendIovecFromBytes replaces the final iovec in iovs with one that
+// also includes the contents of bs. Note that this implies that
+// AppendIovecFromBytes is only usable when the returned iovec slice is used as
+// the source of a write.
+func AppendIovecFromBytes(iovs []unix.Iovec, bs []byte, max int) []unix.Iovec {
+ if len(bs) == 0 {
+ return iovs
+ }
+ if len(iovs) < max {
+ return append(iovs, IovecFromBytes(bs))
+ }
+ iovs[len(iovs)-1] = IovecFromBytes(append(bytesFromIovec(iovs[len(iovs)-1]), bs...))
+ return iovs
+}
+
+// MMsgHdr represents the mmsg_hdr structure required by recvmmsg() on linux.
+type MMsgHdr struct {
+ Msg unix.Msghdr
+ Len uint32
+ _ [4]byte
+}
+
+// SizeofMMsgHdr is the size of a MMsgHdr in bytes.
+const SizeofMMsgHdr = unsafe.Sizeof(MMsgHdr{})
+
// GetMTU determines the MTU of a network interface device.
func GetMTU(name string) (uint32, error) {
fd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_DGRAM, 0)
@@ -137,13 +191,6 @@ func BlockingReadv(fd int, iovecs []unix.Iovec) (int, tcpip.Error) {
}
}
-// MMsgHdr represents the mmsg_hdr structure required by recvmmsg() on linux.
-type MMsgHdr struct {
- Msg unix.Msghdr
- Len uint32
- _ [4]byte
-}
-
// BlockingRecvMMsg reads from a file descriptor that is set up as non-blocking
// and stores the received messages in a slice of MMsgHdr structures. If no data
// is available, it will block in a poll() syscall until the file descriptor
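A hypothetical sketch of AppendIovecFromBytes' capping behavior (rawfile and unix imports as below): once max iovecs have been appended, further bytes are copied onto the last iovec, which is why the result is only safe as the source of a write:

package iovecdemo

import (
	"golang.org/x/sys/unix"

	"gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
)

func buildIovecs() []unix.Iovec {
	const max = 2
	iovs := make([]unix.Iovec, 0, max)
	iovs = rawfile.AppendIovecFromBytes(iovs, []byte("vnet hdr"), max)
	iovs = rawfile.AppendIovecFromBytes(iovs, []byte("payload a"), max)
	// Already at max: the bytes of "payload b" are appended (copied) onto the
	// buffer behind iovs[1] rather than adding a third iovec.
	iovs = rawfile.AppendIovecFromBytes(iovs, []byte("payload b"), max)
	return iovs // len(iovs) == 2
}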
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index 7656cca6a..4758a99ad 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -26,6 +26,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/log",
"//pkg/refs",
"//pkg/refsvfs2",
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index af9a3d759..59ac5cb8c 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -18,6 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -88,12 +89,12 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags Flags) error {
defer d.mu.Unlock()
if d.endpoint != nil {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// Input validation.
if flags.TAP && flags.TUN || !flags.TAP && !flags.TUN {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
prefix := "tun"
@@ -108,7 +109,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags Flags) error {
endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps)
if err != nil {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
d.endpoint = endpoint
@@ -159,7 +160,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
// Race detected: A NIC has been created in between.
continue
default:
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
}
}
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index a72eb1aad..6fa1aee18 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -28,13 +28,13 @@ go_test(
":arp",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/stack",
"//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/icmp",
"@com_github_google_go_cmp//cmp:go_default_library",
"@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 0efa3a926..6515c31e5 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -278,7 +278,7 @@ func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) {
return "", ""
}
-func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
protocol: p,
nic: nic,
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 94209b026..5fcbfeaa2 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -15,15 +15,14 @@
package arp_test
import (
- "context"
"fmt"
"testing"
- "time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
@@ -31,7 +30,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
)
const (
@@ -39,15 +37,6 @@ const (
stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06")
-
- defaultChannelSize = 1
- defaultMTU = 65536
-
- // eventChanSize defines the size of event channels used by the neighbor
- // cache's event dispatcher. The size chosen here needs to be sufficient to
- // queue all the events received during tests before consumption.
- // If eventChanSize is too small, the tests may deadlock.
- eventChanSize = 32
)
var (
@@ -123,24 +112,6 @@ func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.Neighbo
d.C <- e
}
-func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error {
- select {
- case got := <-d.C:
- if diff := cmp.Diff(want, got, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" {
- return fmt.Errorf("got invalid event (-want +got):\n%s", diff)
- }
- case <-ctx.Done():
- return fmt.Errorf("%s for %s", ctx.Err(), want)
- }
- return nil
-}
-
-func (d *arpDispatcher) waitForEventWithTimeout(want eventInfo, timeout time.Duration) error {
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- return d.waitForEvent(ctx, want)
-}
-
func (d *arpDispatcher) nextEvent() (eventInfo, bool) {
select {
case event := <-d.C:
@@ -153,55 +124,45 @@ func (d *arpDispatcher) nextEvent() (eventInfo, bool) {
type testContext struct {
s *stack.Stack
linkEP *channel.Endpoint
- nudDisp *arpDispatcher
+ nudDisp arpDispatcher
}
-func newTestContext(t *testing.T) *testContext {
- c := stack.DefaultNUDConfigurations()
- // Transition from Reachable to Stale almost immediately to test if receiving
- // probes refreshes positive reachability.
- c.BaseReachableTime = time.Microsecond
-
- d := arpDispatcher{
- // Create an event channel large enough so the neighbor cache doesn't block
- // while dispatching events. Blocking could interfere with the timing of
- // NUD transitions.
- C: make(chan eventInfo, eventChanSize),
+func makeTestContext(t *testing.T, eventDepth int, packetDepth int) testContext {
+ t.Helper()
+
+ tc := testContext{
+ nudDisp: arpDispatcher{
+ C: make(chan eventInfo, eventDepth),
+ },
}
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
- NUDConfigs: c,
- NUDDisp: &d,
+ tc.s = stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
+ NUDDisp: &tc.nudDisp,
+ Clock: &faketime.NullClock{},
})
- ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
- ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired
-
- wep := stack.LinkEndpoint(ep)
+ tc.linkEP = channel.New(packetDepth, header.IPv4MinimumMTU, stackLinkAddr)
+ tc.linkEP.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ wep := stack.LinkEndpoint(tc.linkEP)
if testing.Verbose() {
- wep = sniffer.New(ep)
+ wep = sniffer.New(wep)
}
- if err := s.CreateNIC(nicID, wep); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ if err := tc.s.CreateNIC(nicID, wep); err != nil {
+ t.Fatalf("CreateNIC failed: %s", err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress for ipv4 failed: %v", err)
+ if err := tc.s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
+ t.Fatalf("AddAddress for ipv4 failed: %s", err)
}
- s.SetRouteTable([]tcpip.Route{{
+ tc.s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
NIC: nicID,
}})
- return &testContext{
- s: s,
- linkEP: ep,
- nudDisp: &d,
- }
+ return tc
}
func (c *testContext) cleanup() {
@@ -209,7 +170,7 @@ func (c *testContext) cleanup() {
}
func TestMalformedPacket(t *testing.T) {
- c := newTestContext(t)
+ c := makeTestContext(t, 0, 0)
defer c.cleanup()
v := make(buffer.View, header.ARPSize)
@@ -228,7 +189,7 @@ func TestMalformedPacket(t *testing.T) {
}
func TestDisabledEndpoint(t *testing.T) {
- c := newTestContext(t)
+ c := makeTestContext(t, 0, 0)
defer c.cleanup()
ep, err := c.s.GetNetworkEndpoint(nicID, header.ARPProtocolNumber)
@@ -253,7 +214,7 @@ func TestDisabledEndpoint(t *testing.T) {
}
func TestDirectReply(t *testing.T) {
- c := newTestContext(t)
+ c := makeTestContext(t, 0, 0)
defer c.cleanup()
const senderMAC = "\x01\x02\x03\x04\x05\x06"
@@ -284,7 +245,7 @@ func TestDirectReply(t *testing.T) {
}
func TestDirectRequest(t *testing.T) {
- c := newTestContext(t)
+ c := makeTestContext(t, 1, 1)
defer c.cleanup()
tests := []struct {
@@ -391,17 +352,21 @@ func TestDirectRequest(t *testing.T) {
}
// Verify the sender was saved in the neighbor cache.
- wantEvent := eventInfo{
- eventType: entryAdded,
- nicID: nicID,
- entry: stack.NeighborEntry{
- Addr: test.senderAddr,
- LinkAddr: tcpip.LinkAddress(test.senderLinkAddr),
- State: stack.Stale,
- },
- }
- if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil {
- t.Fatal(err)
+ if got, ok := c.nudDisp.nextEvent(); ok {
+ want := eventInfo{
+ eventType: entryAdded,
+ nicID: nicID,
+ entry: stack.NeighborEntry{
+ Addr: test.senderAddr,
+ LinkAddr: test.senderLinkAddr,
+ State: stack.Stale,
+ },
+ }
+ if diff := cmp.Diff(want, got, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" {
+ t.Errorf("got invalid event (-want +got):\n%s", diff)
+ }
+ } else {
+ t.Fatal("event didn't arrive")
}
neighbors, err := c.s.Neighbors(nicID, ipv4.ProtocolNumber)
@@ -589,7 +554,7 @@ func TestLinkAddressRequest(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
})
- linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
+ linkEP := channel.New(1, header.IPv4MinimumMTU, stackLinkAddr)
if err := s.CreateNIC(nicID, &testLinkEndpoint{LinkEndpoint: linkEP, writeErr: test.linkErr}); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
@@ -663,15 +628,16 @@ func TestLinkAddressRequest(t *testing.T) {
}
func TestDADARPRequestPacket(t *testing.T) {
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocolWithOptions(arp.Options{
DADConfigs: stack.DADConfigurations{
DupAddrDetectTransmits: 1,
- RetransmitTimer: time.Second,
},
}), ipv4.NewProtocol},
+ Clock: clock,
})
- e := channel.New(1, defaultMTU, stackLinkAddr)
+ e := channel.New(1, header.IPv4MinimumMTU, stackLinkAddr)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
@@ -682,7 +648,8 @@ func TestDADARPRequestPacket(t *testing.T) {
t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, header.IPv4ProtocolNumber, remoteAddr, res, stack.DADStarting)
}
- pkt, ok := e.ReadContext(context.Background())
+ clock.RunImmediatelyScheduledJobs()
+ pkt, ok := e.Read()
if !ok {
t.Fatal("expected to send an ARP request")
}
diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation.go b/pkg/tcpip/network/internal/fragmentation/fragmentation.go
index 5168f5361..1ba4d0d36 100644
--- a/pkg/tcpip/network/internal/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/internal/fragmentation/fragmentation.go
@@ -251,7 +251,7 @@ func (f *Fragmentation) releaseReassemblersLocked() {
// The list is empty.
break
}
- elapsed := time.Duration(now-r.creationTime) * time.Nanosecond
+ elapsed := now.Sub(r.createdAt)
if f.timeout > elapsed {
// If the oldest reassembler has not expired, schedule the release
// job so that this function is called back when it has expired.
diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go
index 7daf64b4a..dadfc28cc 100644
--- a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go
@@ -275,15 +275,23 @@ func TestMemoryLimits(t *testing.T) {
highLimit := 3 * lowLimit // Allow at most 3 such packets.
f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil)
// Send first fragment with id = 0.
- f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0"))
+ if _, _, _, err := f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil {
+ t.Fatal(err)
+ }
// Send first fragment with id = 1.
- f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1"))
+ if _, _, _, err := f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")); err != nil {
+ t.Fatal(err)
+ }
// Send first fragment with id = 2.
- f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2"))
+ if _, _, _, err := f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")); err != nil {
+ t.Fatal(err)
+ }
// Send first fragment with id = 3. This should cause id = 0 and id = 1 to be
// evicted.
- f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3"))
+ if _, _, _, err := f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")); err != nil {
+ t.Fatal(err)
+ }
if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok {
t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
@@ -300,9 +308,13 @@ func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
memSize := pkt(1, "0").MemSize()
f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil)
// Send first fragment with id = 0.
- f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0"))
+ if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil {
+ t.Fatal(err)
+ }
// Send the same packet again.
- f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0"))
+ if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil {
+ t.Fatal(err)
+ }
if got, want := f.memSize, memSize; got != want {
t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want)
diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go
index 56b76a284..5b7e4b361 100644
--- a/pkg/tcpip/network/internal/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go
@@ -35,21 +35,21 @@ type hole struct {
type reassembler struct {
reassemblerEntry
- id FragmentID
- memSize int
- proto uint8
- mu sync.Mutex
- holes []hole
- filled int
- done bool
- creationTime int64
- pkt *stack.PacketBuffer
+ id FragmentID
+ memSize int
+ proto uint8
+ mu sync.Mutex
+ holes []hole
+ filled int
+ done bool
+ createdAt tcpip.MonotonicTime
+ pkt *stack.PacketBuffer
}
func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
r := &reassembler{
- id: id,
- creationTime: clock.NowMonotonic(),
+ id: id,
+ createdAt: clock.NowMonotonic(),
}
r.holes = append(r.holes, hole{
first: 0,
diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
index a22b712c6..24687cf06 100644
--- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
+++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
@@ -133,7 +133,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) {
t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADDisabled)
}
// Wait for any initially fired timers to complete.
- clock.Advance(0)
+ clock.RunImmediatelyScheduledJobs()
if diff := dad.check(nil); diff != "" {
t.Errorf("dad check mismatch (-want +got):\n%s", diff)
}
@@ -147,7 +147,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) {
if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting {
t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting)
}
- clock.Advance(0)
+ clock.RunImmediatelyScheduledJobs()
if diff := dad.check([]tcpip.Address{addr1}); diff != "" {
t.Errorf("dad check mismatch (-want +got):\n%s", diff)
}
@@ -156,7 +156,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) {
if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADAlreadyRunning {
t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADAlreadyRunning)
}
- clock.Advance(0)
+ clock.RunImmediatelyScheduledJobs()
if diff := dad.check(nil); diff != "" {
t.Errorf("dad check mismatch (-want +got):\n%s", diff)
}
@@ -170,7 +170,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) {
if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting {
t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting)
}
- clock.Advance(0)
+ clock.RunImmediatelyScheduledJobs()
if diff := dad.check([]tcpip.Address{addr2}); diff != "" {
t.Errorf("dad check mismatch (-want +got):\n%s", diff)
}
@@ -208,7 +208,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) {
if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting {
t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting)
}
- clock.Advance(0)
+ clock.RunImmediatelyScheduledJobs()
if diff := dad.check([]tcpip.Address{addr2, addr2}); diff != "" {
t.Errorf("dad check mismatch (-want +got):\n%s", diff)
}
@@ -247,7 +247,7 @@ func TestDADStop(t *testing.T) {
if res := dad.checkDuplicateAddress(addr3, handler(ch, addr3)); res != stack.DADStarting {
t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting)
}
- clock.Advance(0)
+ clock.RunImmediatelyScheduledJobs()
if diff := dad.check([]tcpip.Address{addr1, addr2, addr3}); diff != "" {
t.Errorf("dad check mismatch (-want +got):\n%s", diff)
}
@@ -272,7 +272,7 @@ func TestDADStop(t *testing.T) {
if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting {
t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting)
}
- clock.Advance(0)
+ clock.RunImmediatelyScheduledJobs()
if diff := dad.check([]tcpip.Address{addr1}); diff != "" {
t.Errorf("dad check mismatch (-want +got):\n%s", diff)
}
@@ -347,7 +347,7 @@ func TestNonce(t *testing.T) {
t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting)
}
- clock.Advance(0)
+ clock.RunImmediatelyScheduledJobs()
for i, want := range test.expectedResults {
if got := dad.extendIfNonceEqual(addr1, test.mockedReceivedNonce); got != want {
t.Errorf("(i=%d) got dad.extendIfNonceEqual(%s, _) = %d, want = %d", i, addr1, got, want)
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
index d22974b12..671dfbf32 100644
--- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
+++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
@@ -611,7 +611,7 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr
// If a timer for any address is already running, it is reset to the new
// random value only if the requested Maximum Response Delay is less than
// the remaining value of the running timer.
- now := time.Unix(0 /* seconds */, g.opts.Clock.NowNanoseconds())
+ now := g.opts.Clock.Now()
if info.state == delayingMember {
if info.delayedReportJobFiresAt.IsZero() {
panic(fmt.Sprintf("delayed report unscheduled while in the delaying member state; group = %s", groupAddress))
diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go
index 0c2b62127..40ab21cb6 100644
--- a/pkg/tcpip/network/internal/ip/stats.go
+++ b/pkg/tcpip/network/internal/ip/stats.go
@@ -42,6 +42,10 @@ type MultiCounterIPForwardingStats struct {
// were too big for the outgoing MTU.
PacketTooBig tcpip.MultiCounterStat
+ // HostUnreachable is the number of IP packets received which could not be
+ // successfully forwarded due to an unresolvable next hop.
+ HostUnreachable tcpip.MultiCounterStat
+
// ExtensionHeaderProblem is the number of IP packets which were dropped
// because of a problem encountered when processing an IPv6 extension
// header.
@@ -61,6 +65,7 @@ func (m *MultiCounterIPForwardingStats) Init(a, b *tcpip.IPForwardingStats) {
m.ExtensionHeaderProblem.Init(a.ExtensionHeaderProblem, b.ExtensionHeaderProblem)
m.PacketTooBig.Init(a.PacketTooBig, b.PacketTooBig)
m.ExhaustedTTL.Init(a.ExhaustedTTL, b.ExhaustedTTL)
+ m.HostUnreachable.Init(a.HostUnreachable, b.HostUnreachable)
}
// LINT.ThenChange(:MultiCounterIPForwardingStats, ../../../tcpip.go:IPForwardingStats)
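The new HostUnreachable counter follows the existing multi-counter pattern: a single logical stat whose Increment bumps two underlying IPForwardingStats counters at once (typically a per-endpoint one and a stack-wide aggregate). A minimal sketch of that pattern, with illustrative types standing in for tcpip.StatCounter and tcpip.MultiCounterStat:

package main

import (
	"fmt"
	"sync/atomic"
)

// counter is a stand-in for tcpip.StatCounter.
type counter struct{ v uint64 }

func (c *counter) Increment()    { atomic.AddUint64(&c.v, 1) }
func (c *counter) Value() uint64 { return atomic.LoadUint64(&c.v) }

// multiCounter mirrors the MultiCounterStat idea: incrementing it bumps every
// underlying counter it was initialized with.
type multiCounter struct{ cs []*counter }

func (m *multiCounter) Init(a, b *counter) { m.cs = []*counter{a, b} }
func (m *multiCounter) Increment() {
	for _, c := range m.cs {
		c.Increment()
	}
}

func main() {
	var perNIC, global counter
	var hostUnreachable multiCounter
	hostUnreachable.Init(&perNIC, &global)
	hostUnreachable.Increment()
	fmt.Println(perNIC.Value(), global.Value()) // 1 1
}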
diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD
index cec3e62c4..a180e5c75 100644
--- a/pkg/tcpip/network/internal/testutil/BUILD
+++ b/pkg/tcpip/network/internal/testutil/BUILD
@@ -10,6 +10,7 @@ go_library(
"//pkg/tcpip/network/internal/fragmentation:__pkg__",
"//pkg/tcpip/network/ipv4:__pkg__",
"//pkg/tcpip/network/ipv6:__pkg__",
+ "//pkg/tcpip/tests/integration:__pkg__",
],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index bd63e0289..771b9173a 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -88,6 +88,7 @@ type testObject struct {
dataCalls int
controlCalls int
+ rawCalls int
}
// checkValues verifies that the transport protocol, data contents, src & dst
@@ -148,6 +149,10 @@ func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpi
t.controlCalls++
}
+func (t *testObject) DeliverRawPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) {
+ t.rawCalls++
+}
+
// Attach is only implemented to satisfy the LinkEndpoint interface.
func (*testObject) Attach(stack.NetworkDispatcher) {}
@@ -717,7 +722,10 @@ func TestReceive(t *testing.T) {
}
test.handlePacket(t, ep, &nic)
if nic.testObject.dataCalls != 1 {
- t.Errorf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
+ t.Errorf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls)
+ }
+ if nic.testObject.rawCalls != 1 {
+ t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls)
}
if got := stat.Value(); got != 1 {
t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got)
@@ -968,7 +976,10 @@ func TestIPv4FragmentationReceive(t *testing.T) {
ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 0 {
- t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls)
+ t.Fatalf("Bad number of data calls: got %d, want 0", nic.testObject.dataCalls)
+ }
+ if nic.testObject.rawCalls != 0 {
+ t.Errorf("Bad number of raw calls: got %d, want 0", nic.testObject.rawCalls)
}
// Send second segment.
@@ -977,7 +988,10 @@ func TestIPv4FragmentationReceive(t *testing.T) {
})
ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
+ t.Fatalf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls)
+ }
+ if nic.testObject.rawCalls != 1 {
+ t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls)
}
}
@@ -1310,7 +1324,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv4Addr,
})
return hdr.View().ToVectorisedView()
},
@@ -1351,7 +1365,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv4Addr,
})
ip.SetHeaderLength(header.IPv4MinimumSize - 1)
return hdr.View().ToVectorisedView()
@@ -1370,7 +1384,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv4Addr,
})
return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
@@ -1388,7 +1402,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv4Addr,
})
return buffer.View(ip).ToVectorisedView()
},
@@ -1430,7 +1444,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv4Addr,
Options: ipv4Options,
})
return hdr.View().ToVectorisedView()
@@ -1469,7 +1483,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv4Addr,
Options: ipv4Options,
})
vv := buffer.View(ip).ToVectorisedView()
@@ -1515,7 +1529,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
TransportProtocol: transportProto,
HopLimit: ipv6.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv6Addr,
})
return hdr.View().ToVectorisedView()
},
@@ -1560,7 +1574,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier),
HopLimit: ipv6.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv6Addr,
})
return hdr.View().ToVectorisedView()
},
@@ -1595,7 +1609,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
TransportProtocol: transportProto,
HopLimit: ipv6.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv6Addr,
})
return buffer.View(ip).ToVectorisedView()
},
@@ -1630,7 +1644,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
TransportProtocol: transportProto,
HopLimit: ipv6.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv4Addr,
})
return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index d1a82b584..2aa38eb98 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -173,9 +173,8 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
received := e.stats.icmp.packetsReceived
- // TODO(gvisor.dev/issue/170): ICMP packets don't have their
- // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
- // full explanation.
+ // ICMP packets don't have their TransportHeader fields set. See
+ // icmp/protocol.go:protocol.Parse for a full explanation.
v, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize)
if !ok {
received.invalid.Increment()
@@ -222,7 +221,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
_ = e.protocol.returnError(&icmpReasonParamProblem{
pointer: optProblem.Pointer,
}, pkt)
- e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
e.stats.ip.MalformedPacketsReceived.Increment()
}
return
@@ -481,6 +479,22 @@ func (*icmpReasonFragmentationNeeded) isForwarding() bool {
return true
}
+// icmpReasonHostUnreachable is an error in which the host specified in the
+// internet destination field of the datagram is unreachable.
+type icmpReasonHostUnreachable struct{}
+
+func (*icmpReasonHostUnreachable) isICMPReason() {}
+func (*icmpReasonHostUnreachable) isForwarding() bool {
+ // If we hit a Host Unreachable error, then we know we are operating as a
+ // router. As per RFC 792 page 5, Destination Unreachable Message,
+ //
+ // In addition, in some networks, the gateway may be able to determine
+ // if the internet destination host is unreachable. Gateways in these
+ // networks may send destination unreachable messages to the source host
+ // when the destination host is unreachable.
+ return true
+}
+
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv4 and sends it back to the remote device that sent
// the problematic packet. It incorporates as much of that packet as
@@ -537,7 +551,12 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
defer route.Release()
p.mu.Lock()
- netEP, ok := p.mu.eps[pkt.NICID]
+ // We retrieve an endpoint using the newly constructed route's NICID rather
+ // than the packet's NICID. The packet's NICID corresponds to the NIC on
+ // which it arrived, which isn't necessarily the same as the NIC on which it
+ // will be transmitted. On the other hand, the route's NIC *is* guaranteed
+ // to be the NIC on which the packet will be transmitted.
+ netEP, ok := p.mu.eps[route.NICID()]
p.mu.Unlock()
if !ok {
return &tcpip.ErrNotConnected{}
@@ -653,6 +672,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4NetUnreachable)
counter = sent.dstUnreachable
+ case *icmpReasonHostUnreachable:
+ icmpHdr.SetType(header.ICMPv4DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv4HostUnreachable)
+ counter = sent.dstUnreachable
case *icmpReasonFragmentationNeeded:
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4FragmentationNeeded)
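The comment in the returnError hunk explains the endpoint lookup change: the ICMP error leaves through the route's NIC, which is not necessarily the NIC the offending packet arrived on, so the endpoint map must be keyed by route.NICID(). A compact sketch of that distinction, with made-up types in place of the protocol's endpoint map:

package main

import "fmt"

type nicID int

type endpoint struct{ nic nicID }

// protocol keeps one network endpoint per NIC, like p.mu.eps above.
type protocol struct {
	eps map[nicID]*endpoint
}

type route struct{ outNIC nicID }

func (r route) NICID() nicID { return r.outNIC }

type packet struct {
	arrivalNIC nicID // the NIC the offending packet came in on
}

// endpointForReply picks the endpoint that will actually transmit the ICMP
// error: the route's NIC, which may differ from the packet's arrival NIC when
// the stack is forwarding.
func (p *protocol) endpointForReply(r route) (*endpoint, bool) {
	ep, ok := p.eps[r.NICID()]
	return ep, ok
}

func main() {
	p := &protocol{eps: map[nicID]*endpoint{1: {nic: 1}, 2: {nic: 2}}}
	pkt := packet{arrivalNIC: 1}
	replyRoute := route{outNIC: 2} // reply leaves through a different NIC
	ep, _ := p.endpointForReply(replyRoute)
	fmt.Println(pkt.arrivalNIC, ep.nic) // 1 2
}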
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 23178277a..44c85bdb8 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -82,7 +82,7 @@ type endpoint struct {
protocol *protocol
stats sharedStats
- // enabled is set to 1 when the enpoint is enabled and 0 when it is
+ // enabled is set to 1 when the endpoint is enabled and 0 when it is
// disabled.
//
// Must be accessed using atomic operations.
@@ -104,6 +104,16 @@ type endpoint struct {
// HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint.
func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) {
+ // If we are operating as a router, return an ICMP error to the original
+ // packet's sender.
+ if pkt.NetworkPacketInfo.IsForwardedPacket {
+ // TODO(gvisor.dev/issue/6005): Propagate asynchronously generated ICMP
+ // errors to local endpoints.
+ e.protocol.returnError(&icmpReasonHostUnreachable{}, pkt)
+ e.stats.ip.Forwarding.Errors.Increment()
+ e.stats.ip.Forwarding.HostUnreachable.Increment()
+ return
+ }
// handleControl expects the entire offending packet to be in the packet
// buffer's data field.
pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{
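When link resolution fails for a packet the stack was forwarding, the router path now answers with an ICMP host-unreachable error and bumps the forwarding stats; locally generated packets still fall through to the handleControl path. A rough sketch of that branch, with placeholder types rather than stack.PacketBuffer:

package main

import "fmt"

// packet stands in for stack.PacketBuffer; isForwarded mirrors the
// NetworkPacketInfo.IsForwardedPacket flag consulted in the hunk above.
type packet struct {
	isForwarded bool
}

type stats struct {
	forwardingErrors, hostUnreachable int
}

// handleLinkResolutionFailure sketches the split: packets we were forwarding
// get an ICMP error back toward their sender, while locally originated
// packets are handed to the local error-delivery path.
func handleLinkResolutionFailure(pkt packet, s *stats, returnICMPError, deliverLocalError func(packet)) {
	if pkt.isForwarded {
		returnICMPError(pkt)
		s.forwardingErrors++
		s.hostUnreachable++
		return
	}
	deliverLocalError(pkt)
}

func main() {
	var s stats
	icmp := func(packet) { fmt.Println("sent ICMP host unreachable") }
	local := func(packet) { fmt.Println("delivered error to local endpoint") }
	handleLinkResolutionFailure(packet{isForwarded: true}, &s, icmp, local)
	handleLinkResolutionFailure(packet{isForwarded: false}, &s, icmp, local)
	fmt.Println(s.forwardingErrors, s.hostUnreachable) // 1 1
}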
@@ -322,8 +332,8 @@ func (e *endpoint) DefaultTTL() uint8 {
return e.protocol.DefaultTTL()
}
-// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
-// the network layer max header length.
+// MTU implements stack.NetworkEndpoint. It returns the link-layer MTU minus the
+// network layer max header length.
func (e *endpoint) MTU() uint32 {
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv4MinimumSize)
if err != nil {
@@ -338,7 +348,7 @@ func (e *endpoint) MaxHeaderLength() uint16 {
return e.nic.MaxHeaderLength() + header.IPv4MaximumHeaderSize
}
-// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
+// NetworkProtocolNumber implements stack.NetworkEndpoint.
func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
@@ -353,7 +363,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
if hdrLen > header.IPv4MaximumHeaderSize {
return &tcpip.ErrMessageTooLong{}
}
- ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen))
+ ipH := header.IPv4(pkt.NetworkHeader().Push(hdrLen))
length := pkt.Size()
if length > math.MaxUint16 {
return &tcpip.ErrMessageTooLong{}
@@ -362,7 +372,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
// datagrams. Since the DF bit is never being set here, all datagrams
// are non-atomic and need an ID.
id := atomic.AddUint32(&e.protocol.ids[hashRoute(srcAddr, dstAddr, params.Protocol, e.protocol.hashIV)%buckets], 1)
- ip.Encode(&header.IPv4Fields{
+ ipH.Encode(&header.IPv4Fields{
TotalLength: uint16(length),
ID: uint16(id),
TTL: params.TTL,
@@ -372,7 +382,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
DstAddr: dstAddr,
Options: options,
})
- ip.SetChecksum(^ip.CalculateChecksum())
+ ipH.SetChecksum(^ipH.CalculateChecksum())
pkt.NetworkProtocolNumber = ProtocolNumber
return nil
}
@@ -381,7 +391,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
// fragment. It returns the number of fragments handled and the number of
// fragments left to be processed. The IP header must already be present in the
// original packet.
-func (e *endpoint) handleFragments(r *stack.Route, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) {
+func (e *endpoint) handleFragments(_ *stack.Route, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) {
// Round the MTU down to align to 8 bytes.
fragmentPayloadSize := networkMTU &^ 7
networkHeader := header.IPv4(pkt.NetworkHeader().View())
@@ -419,9 +429,9 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// based on destination address and do not send the packet to link
// layer.
//
- // TODO(gvisor.dev/issue/170): We should do this for every
- // packet, rather than only NATted packets, but removing this check
- // short circuits broadcasts before they are sent out to other hosts.
+ // We should do this for every packet, rather than only NATted packets, but
+ // removing this check short-circuits broadcasts before they are sent out to
+ // other hosts.
if pkt.NatDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil {
@@ -490,7 +500,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn
return nil
}
-// WritePackets implements stack.NetworkEndpoint.WritePackets.
+// WritePackets implements stack.NetworkEndpoint.
func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
if r.Loop()&stack.PacketLoop != 0 {
panic("multiple packets in local loop")
@@ -593,34 +603,30 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
if !ok {
return &tcpip.ErrMalformedHeader{}
}
- ip := header.IPv4(h)
+ ipH := header.IPv4(h)
// Always set the total length.
pktSize := pkt.Data().Size()
- ip.SetTotalLength(uint16(pktSize))
+ ipH.SetTotalLength(uint16(pktSize))
// Set the source address when zero.
- if ip.SourceAddress() == header.IPv4Any {
- ip.SetSourceAddress(r.LocalAddress())
+ if ipH.SourceAddress() == header.IPv4Any {
+ ipH.SetSourceAddress(r.LocalAddress())
}
- // Set the destination. If the packet already included a destination, it will
- // be part of the route anyways.
- ip.SetDestinationAddress(r.RemoteAddress())
-
// Set the packet ID when zero.
- if ip.ID() == 0 {
+ if ipH.ID() == 0 {
// RFC 6864 section 4.3 mandates uniqueness of ID values for
// non-atomic datagrams, so assign an ID to all such datagrams
// according to the definition given in RFC 6864 section 4.
- if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 {
- ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress(), r.RemoteAddress(), 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
+ if ipH.Flags()&header.IPv4FlagDontFragment == 0 || ipH.Flags()&header.IPv4FlagMoreFragments != 0 || ipH.FragmentOffset() > 0 {
+ ipH.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress(), r.RemoteAddress(), 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
}
}
// Always set the checksum.
- ip.SetChecksum(0)
- ip.SetChecksum(^ip.CalculateChecksum())
+ ipH.SetChecksum(0)
+ ipH.SetChecksum(^ipH.CalculateChecksum())
// Populate the packet buffer's network header and don't allow an invalid
// packet to be sent.
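This hunk drops the unconditional SetDestinationAddress call, so a header-included (IP_HDRINCL-style) write now keeps whatever destination the caller encoded; only the zero source address and, for non-atomic datagrams, the zero ID are filled in. That is also why the ip_test.go cases above switch their DstAddr from header.IPv4Any to the real remote address. A simplified sketch of the remaining fix-up behavior, with illustrative fields and without the atomic-datagram distinction:

package main

import "fmt"

type ipv4Fields struct {
	src, dst string
	id       uint16
}

// fixupHeaderIncluded mirrors the shape of WriteHeaderIncludedPacket after the
// hunk: the zero source and zero ID are filled in, but the caller-supplied
// destination is left alone.
func fixupHeaderIncluded(h *ipv4Fields, routeLocal string, nextID func() uint16) {
	if h.src == "0.0.0.0" {
		h.src = routeLocal
	}
	if h.id == 0 {
		h.id = nextID()
	}
	// Note: h.dst is intentionally no longer touched.
}

func main() {
	id := uint16(41)
	nextID := func() uint16 { id++; return id }
	h := ipv4Fields{src: "0.0.0.0", dst: "10.0.0.2"}
	fixupHeaderIncluded(&h, "192.168.1.1", nextID)
	fmt.Println(h.src, h.dst, h.id) // 192.168.1.1 10.0.0.2 42
}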
@@ -710,7 +716,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
return nil
}
- ep.handleValidatedPacket(h, pkt)
+ // The packet originally arrived on e so provide its NIC as the input NIC.
+ ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
return nil
}
@@ -826,7 +833,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
}
- e.handleValidatedPacket(h, pkt)
+ e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
}
// handleLocalPacket is like HandlePacket except it does not perform the
@@ -845,10 +852,17 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
return
}
- e.handleValidatedPacket(h, pkt)
+ e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
}
-func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) {
+func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, inNICName string) {
+ // Raw socket packets are delivered based solely on the transport protocol
+ // number. We only require that the packet be valid IPv4 and that it not be
+ // fragmented.
+ if !h.More() && h.FragmentOffset() == 0 {
+ e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt)
+ }
+
pkt.NICID = e.nic.ID()
stats := e.stats
stats.ip.ValidPacketsReceived.Increment()
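Raw-socket delivery now happens up front in handleValidatedPacket, but only for packets that are not fragments: no more-fragments flag and a zero fragment offset. Fragments are instead delivered after reassembly in a later hunk. A small sketch of the gating condition, with a toy header type in place of header.IPv4:

package main

import "fmt"

// ipv4Header is a toy view over the two fields the gate cares about.
type ipv4Header struct {
	moreFragments  bool
	fragmentOffset uint16
}

// deliverRawIfWhole mirrors the condition in the hunk: only a packet that is
// not part of a fragment train goes straight to raw sockets; fragments wait
// for reassembly.
func deliverRawIfWhole(h ipv4Header, deliverRaw func()) {
	if !h.moreFragments && h.fragmentOffset == 0 {
		deliverRaw()
	}
}

func main() {
	deliver := func() { fmt.Println("delivered to raw sockets") }
	deliverRawIfWhole(ipv4Header{}, deliver)                    // whole packet: delivered
	deliverRawIfWhole(ipv4Header{moreFragments: true}, deliver) // first fragment: skipped
	deliverRawIfWhole(ipv4Header{fragmentOffset: 64}, deliver)  // later fragment: skipped
}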
@@ -898,7 +912,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
case *ip.ErrNoRoute:
stats.ip.Forwarding.Unrouteable.Increment()
case *ip.ErrParameterProblem:
- e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
stats.ip.MalformedPacketsReceived.Increment()
case *ip.ErrMessageTooLong:
stats.ip.Forwarding.PacketTooBig.Increment()
@@ -911,8 +924,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
stats.ip.IPTablesInputDropped.Increment()
return
@@ -935,7 +947,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
_ = e.protocol.returnError(&icmpReasonParamProblem{
pointer: optProblem.Pointer,
}, pkt)
- e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
e.stats.ip.MalformedPacketsReceived.Increment()
}
return
@@ -985,8 +996,11 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
// The reassembler doesn't take care of fixing up the header, so we need
// to do it here.
- h.SetTotalLength(uint16(pkt.Data().Size() + len((h))))
+ h.SetTotalLength(uint16(pkt.Data().Size() + len(h)))
h.SetFlagsFragmentOffset(0, 0)
+
+ // Now that the packet is reassembled, it can be sent to raw sockets.
+ e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt)
}
stats.ip.PacketsDelivered.Increment()
@@ -1008,7 +1022,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
_ = e.protocol.returnError(&icmpReasonParamProblem{
pointer: optProblem.Pointer,
}, pkt)
- e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
stats.ip.MalformedPacketsReceived.Increment()
}
return
@@ -1217,13 +1230,13 @@ func (p *protocol) DefaultPrefixLen() int {
return header.IPv4AddressSize * 8
}
-// ParseAddresses implements NetworkProtocol.ParseAddresses.
+// ParseAddresses implements stack.NetworkProtocol.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv4(v)
return h.SourceAddress(), h.DestinationAddress()
}
-// SetOption implements NetworkProtocol.SetOption.
+// SetOption implements stack.NetworkProtocol.
func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
@@ -1234,7 +1247,7 @@ func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.E
}
}
-// Option implements NetworkProtocol.Option.
+// Option implements stack.NetworkProtocol.
func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
@@ -1255,10 +1268,10 @@ func (p *protocol) DefaultTTL() uint8 {
return uint8(atomic.LoadUint32(&p.defaultTTL))
}
-// Close implements stack.TransportProtocol.Close.
+// Close implements stack.TransportProtocol.
func (*protocol) Close() {}
-// Wait implements stack.TransportProtocol.Wait.
+// Wait implements stack.TransportProtocol.
func (*protocol) Wait() {}
// parseAndValidate parses the packet (including its transport layer header) and
@@ -1296,7 +1309,7 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool)
return h, true
}
-// Parse implements stack.NetworkProtocol.Parse.
+// Parse implements stack.NetworkProtocol.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
if ok := parse.IPv4(pkt); !ok {
return 0, false, false
@@ -1326,7 +1339,7 @@ func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error
networkMTU = MaxTotalSize
}
- return networkMTU - uint32(networkHeaderSize), nil
+ return networkMTU - networkHeaderSize, nil
}
func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32) bool {
@@ -1735,9 +1748,8 @@ type optionTracker struct {
//
// If there were no errors during parsing, the new set of options is returned as
// a new buffer.
-func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (header.IPv4Options, optionTracker, *header.IPv4OptParameterProblem) {
+func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, opts header.IPv4Options, usage optionsUsage) (header.IPv4Options, optionTracker, *header.IPv4OptParameterProblem) {
stats := e.stats.ip
- opts := header.IPv4Options(orig)
optIter := opts.MakeIterator()
// Except NOP, each option must only appear at most once (RFC 791 section 3.1,
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index da9cc0ae8..4a4448cf9 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -16,7 +16,6 @@ package ipv4_test
import (
"bytes"
- "context"
"encoding/hex"
"fmt"
"io/ioutil"
@@ -36,10 +35,10 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
- "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
+ iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
- tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
@@ -112,10 +111,6 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
-type forwardedPacket struct {
- fragments []fragmentInfo
-}
-
func TestForwarding(t *testing.T) {
const (
incomingNICID = 1
@@ -134,11 +129,11 @@ func TestForwarding(t *testing.T) {
PrefixLen: 8,
}
outgoingLinkAddr := tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
- remoteIPv4Addr1 := tcptestutil.MustParse4("10.0.0.2")
- remoteIPv4Addr2 := tcptestutil.MustParse4("11.0.0.2")
- unreachableIPv4Addr := tcptestutil.MustParse4("12.0.0.2")
- multicastIPv4Addr := tcptestutil.MustParse4("225.0.0.0")
- linkLocalIPv4Addr := tcptestutil.MustParse4("169.254.0.0")
+ remoteIPv4Addr1 := testutil.MustParse4("10.0.0.2")
+ remoteIPv4Addr2 := testutil.MustParse4("11.0.0.2")
+ unreachableIPv4Addr := testutil.MustParse4("12.0.0.2")
+ multicastIPv4Addr := testutil.MustParse4("225.0.0.0")
+ linkLocalIPv4Addr := testutil.MustParse4("169.254.0.0")
tests := []struct {
name string
@@ -402,13 +397,13 @@ func TestForwarding(t *testing.T) {
totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength
hdr := buffer.NewPrependable(totalLength)
hdr.Prepend(test.payloadLength)
- icmp := header.ICMPv4(hdr.Prepend(icmpHeaderLength))
- icmp.SetIdent(randomIdent)
- icmp.SetSequence(randomSequence)
- icmp.SetType(header.ICMPv4Echo)
- icmp.SetCode(header.ICMPv4UnusedCode)
- icmp.SetChecksum(0)
- icmp.SetChecksum(^header.Checksum(icmp, 0))
+ icmpH := header.ICMPv4(hdr.Prepend(icmpHeaderLength))
+ icmpH.SetIdent(randomIdent)
+ icmpH.SetSequence(randomSequence)
+ icmpH.SetType(header.ICMPv4Echo)
+ icmpH.SetCode(header.ICMPv4UnusedCode)
+ icmpH.SetChecksum(0)
+ icmpH.SetChecksum(^header.Checksum(icmpH, 0))
ip := header.IPv4(hdr.Prepend(ipHeaderLength))
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(totalLength),
@@ -453,7 +448,7 @@ func TestForwarding(t *testing.T) {
return len(hdr.View())
}
- checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
checker.SrcAddr(incomingIPv4Addr.Address),
checker.DstAddr(test.sourceAddr),
checker.TTL(ipv4.DefaultTTL),
@@ -461,7 +456,7 @@ func TestForwarding(t *testing.T) {
checker.ICMPv4Checksum(),
checker.ICMPv4Type(test.icmpType),
checker.ICMPv4Code(test.icmpCode),
- checker.ICMPv4Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])),
+ checker.ICMPv4Payload(hdr.View()[:expectedICMPPayloadLength()]),
),
)
} else if ok {
@@ -470,7 +465,7 @@ func TestForwarding(t *testing.T) {
if test.expectPacketForwarded {
if len(test.expectedFragmentsForwarded) != 0 {
- fragmentedPackets := []*stack.PacketBuffer{}
+ var fragmentedPackets []*stack.PacketBuffer
for i := 0; i < len(test.expectedFragmentsForwarded); i++ {
reply, ok = outgoingEndpoint.Read()
if !ok {
@@ -487,7 +482,7 @@ func TestForwarding(t *testing.T) {
// maximum IP header size and the maximum size allocated for link layer
// headers. In this case, no size is allocated for link layer headers.
expectedAvailableHeaderBytes := header.IPv4MaximumHeaderSize
- if err := compareFragments(fragmentedPackets, requestPkt, uint32(test.mtu), test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil {
+ if err := compareFragments(fragmentedPackets, requestPkt, test.mtu, test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil {
t.Error(err)
}
} else {
@@ -496,7 +491,7 @@ func TestForwarding(t *testing.T) {
t.Fatal("expected ICMP Echo packet through outgoing NIC")
}
- checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
checker.SrcAddr(test.sourceAddr),
checker.DstAddr(test.destAddr),
checker.TTL(test.TTL-1),
@@ -1212,15 +1207,15 @@ func TestIPv4Sanity(t *testing.T) {
}
totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize)
hdr := buffer.NewPrependable(int(totalLen))
- icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
// Specify ident/seq to make sure we get the same in the response.
- icmp.SetIdent(randomIdent)
- icmp.SetSequence(randomSequence)
- icmp.SetType(header.ICMPv4Echo)
- icmp.SetCode(header.ICMPv4UnusedCode)
- icmp.SetChecksum(0)
- icmp.SetChecksum(^header.Checksum(icmp, 0))
+ icmpH.SetIdent(randomIdent)
+ icmpH.SetSequence(randomSequence)
+ icmpH.SetType(header.ICMPv4Echo)
+ icmpH.SetCode(header.ICMPv4UnusedCode)
+ icmpH.SetChecksum(0)
+ icmpH.SetChecksum(^header.Checksum(icmpH, 0))
ip := header.IPv4(hdr.Prepend(ipHeaderLength))
if test.maxTotalLength < totalLen {
totalLen = test.maxTotalLength
@@ -1315,7 +1310,7 @@ func TestIPv4Sanity(t *testing.T) {
checker.ICMPv4Type(test.ICMPType),
checker.ICMPv4Code(test.ICMPCode),
checker.ICMPv4Pointer(test.paramProblemPointer),
- checker.ICMPv4Payload([]byte(hdr.View())),
+ checker.ICMPv4Payload(hdr.View()),
),
)
return
@@ -1334,7 +1329,7 @@ func TestIPv4Sanity(t *testing.T) {
checker.ICMPv4(
checker.ICMPv4Type(test.ICMPType),
checker.ICMPv4Code(test.ICMPCode),
- checker.ICMPv4Payload([]byte(hdr.View())),
+ checker.ICMPv4Payload(hdr.View()),
),
)
return
@@ -1546,9 +1541,9 @@ func TestFragmentationWritePacket(t *testing.T) {
for _, ft := range fragmentationTests {
t.Run(ft.description, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
r := buildRoute(t, ep)
- pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
+ pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
source := pkt.Clone()
err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
@@ -1602,7 +1597,7 @@ func TestFragmentationWritePackets(t *testing.T) {
insertAfter: 1,
},
}
- tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber)
+ tinyPacket := iptestutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber)
for _, test := range writePacketsTests {
t.Run(test.description, func(t *testing.T) {
@@ -1612,13 +1607,13 @@ func TestFragmentationWritePackets(t *testing.T) {
for i := 0; i < test.insertBefore; i++ {
pkts.PushBack(tinyPacket.Clone())
}
- pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
+ pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
pkts.PushBack(pkt.Clone())
for i := 0; i < test.insertAfter; i++ {
pkts.PushBack(tinyPacket.Clone())
}
- ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
r := buildRoute(t, ep)
wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter
@@ -1726,8 +1721,8 @@ func TestFragmentationErrors(t *testing.T) {
for _, ft := range tests {
t.Run(ft.description, func(t *testing.T) {
- pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
- ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
+ pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
+ ep := iptestutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
r := buildRoute(t, ep)
err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
@@ -2301,7 +2296,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
checker.ICMPv4Type(header.ICMPv4TimeExceeded),
checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout),
checker.ICMPv4Checksum(),
- checker.ICMPv4Payload([]byte(firstFragmentSent)),
+ checker.ICMPv4Payload(firstFragmentSent),
),
)
})
@@ -2705,7 +2700,7 @@ func TestReceiveFragments(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
RawFactory: raw.EndpointFactory{},
})
- e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00"))
+ e := channel.New(0, 1280, "\xf0\x00")
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -2948,7 +2943,7 @@ func TestWriteStats(t *testing.T) {
t.Run(writer.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets)
+ ep := iptestutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets)
rt := buildRoute(t, ep)
var pkts stack.PacketBufferList
@@ -3035,7 +3030,7 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string)
return false, false
}
-func TestPacketQueing(t *testing.T) {
+func TestPacketQueuing(t *testing.T) {
const nicID = 1
var (
@@ -3074,7 +3069,7 @@ func TestPacketQueing(t *testing.T) {
Length: header.UDPMinimumSize,
})
sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
- sum = header.Checksum(header.UDP([]byte{}), sum)
+ sum = header.Checksum(nil, sum)
u.SetChecksum(^u.CalculateChecksum(sum))
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
ip.Encode(&header.IPv4Fields{
@@ -3090,7 +3085,7 @@ func TestPacketQueing(t *testing.T) {
}))
},
checkResp: func(t *testing.T, e *channel.Endpoint) {
- p, ok := e.ReadContext(context.Background())
+ p, ok := e.Read()
if !ok {
t.Fatalf("timed out waiting for packet")
}
@@ -3133,7 +3128,7 @@ func TestPacketQueing(t *testing.T) {
}))
},
checkResp: func(t *testing.T, e *channel.Endpoint) {
- p, ok := e.ReadContext(context.Background())
+ p, ok := e.Read()
if !ok {
t.Fatalf("timed out waiting for packet")
}
@@ -3157,9 +3152,11 @@ func TestPacketQueing(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
e := channel.New(1, defaultMTU, host1NICLinkAddr)
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: clock,
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -3182,7 +3179,8 @@ func TestPacketQueing(t *testing.T) {
// Wait for a ARP request since link address resolution should be
// performed.
{
- p, ok := e.ReadContext(context.Background())
+ clock.RunImmediatelyScheduledJobs()
+ p, ok := e.Read()
if !ok {
t.Fatalf("timed out waiting for packet")
}
@@ -3223,6 +3221,7 @@ func TestPacketQueing(t *testing.T) {
}
// Expect the response now that the link address has resolved.
+ clock.RunImmediatelyScheduledJobs()
test.checkResp(t, e)
// Since link resolution was already performed, it shouldn't be performed
@@ -3244,8 +3243,8 @@ func TestCloseLocking(t *testing.T) {
)
var (
- src = tcptestutil.MustParse4("16.0.0.1")
- dst = tcptestutil.MustParse4("16.0.0.2")
+ src = testutil.MustParse4("16.0.0.1")
+ dst = testutil.MustParse4("16.0.0.2")
)
s := stack.New(stack.Options{
@@ -3253,7 +3252,7 @@ func TestCloseLocking(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- // Perform NAT so that the endoint tries to search for a sibling endpoint
+ // Perform NAT so that the endpoint tries to search for a sibling endpoint
// which ends up taking the protocol and endpoint lock (in that order).
table := stack.Table{
Rules: []stack.Rule{
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 307e1972d..94caaae6c 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -285,8 +285,8 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, routerAlert *header.IPv6RouterAlertOption) {
sent := e.stats.icmp.packetsSent
received := e.stats.icmp.packetsReceived
- // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
- // fields set. See icmp/protocol.go:protocol.Parse for a full explanation.
+ // ICMP packets don't have their TransportHeader fields set. See
+ // icmp/protocol.go:protocol.Parse for a full explanation.
v, ok := pkt.Data().PullUp(header.ICMPv6HeaderSize)
if !ok {
received.invalid.Increment()
@@ -1029,6 +1029,26 @@ func (*icmpReasonNetUnreachable) respondsToMulticast() bool {
return false
}
+// icmpReasonHostUnreachable is an error in which the host specified in the
+// internet destination field of the datagram is unreachable.
+type icmpReasonHostUnreachable struct{}
+
+func (*icmpReasonHostUnreachable) isICMPReason() {}
+func (*icmpReasonHostUnreachable) isForwarding() bool {
+ // If we hit a Host Unreachable error, then we know we are operating as a
+ // router. As per RFC 4443 page 8, Destination Unreachable Message,
+ //
+ // If the reason for the failure to deliver cannot be mapped to any of
+ // other codes, the Code field is set to 3. Example of such cases are
+ // an inability to resolve the IPv6 destination address into a
+ // corresponding link address, or a link-specific problem of some sort.
+ return true
+}
+
+func (*icmpReasonHostUnreachable) respondsToMulticast() bool {
+ return false
+}
+
// icmpReasonPacketTooBig is an error where a packet is too big to be sent
// out through the outgoing MTU, as per RFC 4443 page 9, Packet Too Big Message.
type icmpReasonPacketTooBig struct{}
@@ -1143,7 +1163,12 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
defer route.Release()
p.mu.Lock()
- netEP, ok := p.mu.eps[pkt.NICID]
+ // We retrieve an endpoint using the newly constructed route's NICID rather
+ // than the packet's NICID. The packet's NICID corresponds to the NIC on
+ // which it arrived, which isn't necessarily the same as the NIC on which it
+ // will be transmitted. On the other hand, the route's NIC *is* guaranteed
+ // to be the NIC on which the packet will be transmitted.
+ netEP, ok := p.mu.eps[route.NICID()]
p.mu.Unlock()
if !ok {
return &tcpip.ErrNotConnected{}
@@ -1222,6 +1247,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
icmpHdr.SetType(header.ICMPv6DstUnreachable)
icmpHdr.SetCode(header.ICMPv6NetworkUnreachable)
counter = sent.dstUnreachable
+ case *icmpReasonHostUnreachable:
+ icmpHdr.SetType(header.ICMPv6DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv6AddressUnreachable)
+ counter = sent.dstUnreachable
case *icmpReasonPacketTooBig:
icmpHdr.SetType(header.ICMPv6PacketTooBig)
icmpHdr.SetCode(header.ICMPv6UnusedCode)
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 040cd4bc8..7c2a3e56b 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -16,17 +16,16 @@ package ipv6
import (
"bytes"
- "context"
"net"
"reflect"
"strings"
"testing"
- "time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
@@ -46,16 +45,12 @@ const (
defaultChannelSize = 1
defaultMTU = 65536
- // Extra time to use when waiting for an async event to occur.
- defaultAsyncPositiveEventTimeout = 30 * time.Second
-
arbitraryHopLimit = 42
)
var (
lladdr0 = header.LinkLocalAddr(linkAddr0)
lladdr1 = header.LinkLocalAddr(linkAddr1)
- lladdr2 = header.LinkLocalAddr(linkAddr2)
)
type stubLinkEndpoint struct {
@@ -95,6 +90,10 @@ func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *st
return stack.TransportPacketHandled
}
+func (*stubDispatcher) DeliverRawPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) {
+ // No-op.
+}
+
var _ stack.NetworkInterface = (*testInterface)(nil)
type testInterface struct {
@@ -371,6 +370,8 @@ type testContext struct {
linkEP0 *channel.Endpoint
linkEP1 *channel.Endpoint
+
+ clock *faketime.ManualClock
}
type endpointWithResolutionCapability struct {
@@ -382,15 +383,19 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab
}
func newTestContext(t *testing.T) *testContext {
+ clock := faketime.NewManualClock()
c := &testContext{
s0: stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ Clock: clock,
}),
s1: stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ Clock: clock,
}),
+ clock: clock,
}
c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0)
@@ -457,10 +462,14 @@ type routeArgs struct {
remoteLinkAddr tcpip.LinkAddress
}
-func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) {
+func routeICMPv6Packet(t *testing.T, clock *faketime.ManualClock, args routeArgs, fn func(*testing.T, header.ICMPv6)) {
t.Helper()
- pi, _ := args.src.ReadContext(context.Background())
+ clock.RunImmediatelyScheduledJobs()
+ pi, ok := args.src.Read()
+ if !ok {
+ t.Fatal("packet didn't arrive")
+ }
{
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -533,7 +542,7 @@ func TestLinkResolution(t *testing.T) {
{src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))},
{src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert},
} {
- routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) {
+ routeICMPv6Packet(t, c.clock, args, func(t *testing.T, icmpv6 header.ICMPv6) {
if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want {
t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want)
}
@@ -544,7 +553,7 @@ func TestLinkResolution(t *testing.T) {
{src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6EchoRequest},
{src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6EchoReply},
} {
- routeICMPv6Packet(t, args, nil)
+ routeICMPv6Packet(t, c.clock, args, nil)
}
}
@@ -1309,7 +1318,7 @@ func TestPacketQueing(t *testing.T) {
Length: header.UDPMinimumSize,
})
sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
- sum = header.Checksum(header.UDP([]byte{}), sum)
+ sum = header.Checksum(nil, sum)
u.SetChecksum(^u.CalculateChecksum(sum))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
@@ -1325,7 +1334,7 @@ func TestPacketQueing(t *testing.T) {
}))
},
checkResp: func(t *testing.T, e *channel.Endpoint) {
- p, ok := e.ReadContext(context.Background())
+ p, ok := e.Read()
if !ok {
t.Fatalf("timed out waiting for packet")
}
@@ -1371,7 +1380,7 @@ func TestPacketQueing(t *testing.T) {
}))
},
checkResp: func(t *testing.T, e *channel.Endpoint) {
- p, ok := e.ReadContext(context.Background())
+ p, ok := e.Read()
if !ok {
t.Fatalf("timed out waiting for packet")
}
@@ -1396,9 +1405,11 @@ func TestPacketQueing(t *testing.T) {
e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr)
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: clock,
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -1421,7 +1432,8 @@ func TestPacketQueing(t *testing.T) {
// Wait for a neighbor solicitation since link address resolution should
// be performed.
{
- p, ok := e.ReadContext(context.Background())
+ clock.RunImmediatelyScheduledJobs()
+ p, ok := e.Read()
if !ok {
t.Fatalf("timed out waiting for packet")
}
@@ -1475,6 +1487,7 @@ func TestPacketQueing(t *testing.T) {
}
// Expect the response now that the link address has resolved.
+ clock.RunImmediatelyScheduledJobs()
test.checkResp(t, e)
// Since link resolution was already performed, it shouldn't be performed
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 95e11ac51..b1aec5312 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -184,7 +184,6 @@ type endpoint struct {
nic stack.NetworkInterface
dispatcher stack.TransportDispatcher
protocol *protocol
- stack *stack.Stack
stats sharedStats
// enabled is set to 1 when the endpoint is enabled and 0 when it is
@@ -231,11 +230,11 @@ type endpoint struct {
// If the NIC was created with a name, it is passed to NICNameFromID.
//
// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are
-// generated for the same prefix on differnt NICs.
+// generated for the same prefix on different NICs.
type NICNameFromID func(tcpip.NICID, string) string
// OpaqueInterfaceIdentifierOptions holds the options related to the generation
-// of opaque interface indentifiers (IIDs) as defined by RFC 7217.
+// of opaque interface identifiers (IIDs) as defined by RFC 7217.
type OpaqueInterfaceIdentifierOptions struct {
// NICNameFromID is a function that returns a stable name for a specified NIC,
// even if the NIC ID changes over time.
@@ -282,6 +281,16 @@ func (*endpoint) DuplicateAddressProtocol() tcpip.NetworkProtocolNumber {
// HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint.
func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) {
+ // If we are operating as a router, we should return an ICMP error to the
+ // original packet's sender.
+ if pkt.NetworkPacketInfo.IsForwardedPacket {
+ // TODO(gvisor.dev/issue/6005): Propagate asynchronously generated ICMP
+ // errors to local endpoints.
+ e.protocol.returnError(&icmpReasonHostUnreachable{}, pkt)
+ e.stats.ip.Forwarding.Errors.Increment()
+ e.stats.ip.Forwarding.HostUnreachable.Increment()
+ return
+ }
// handleControl expects the entire offending packet to be in the packet
// buffer's data field.
pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -335,7 +344,10 @@ func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) {
func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
e.mu.Lock()
defer e.mu.Unlock()
- e.mu.ndp.invalidateDefaultRouter(rtr)
+
+ // We represent default routers with a default (off-link) route through the
+ // router.
+ e.mu.ndp.invalidateOffLinkRoute(offLinkRoute{dest: header.IPv6EmptySubnet, router: rtr})
}
// SetNDPConfigurations implements NDPEndpoint.
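Invalidating a default router is now phrased as invalidating an off-link route whose destination is the empty (::/0) subnet, presumably so that default routers and other off-link routes share the same bookkeeping. A minimal sketch of that representation; offLinkRoute and ndpState here are illustrative stand-ins, not the ndp package's types:

package main

import "fmt"

// offLinkRoute mirrors the idea in the hunk: a default router is just an
// off-link route whose destination is the empty (::/0) subnet.
type offLinkRoute struct {
	dest   string // subnet in CIDR form; "::/0" is the IPv6 empty subnet
	router string
}

type ndpState struct {
	offLinkRoutes map[offLinkRoute]bool
}

func (n *ndpState) invalidateOffLinkRoute(r offLinkRoute) {
	delete(n.offLinkRoutes, r)
}

// invalidateDefaultRouter is expressed in terms of the more general off-link
// route invalidation, which is the shape of the refactor above.
func (n *ndpState) invalidateDefaultRouter(router string) {
	n.invalidateOffLinkRoute(offLinkRoute{dest: "::/0", router: router})
}

func main() {
	n := &ndpState{offLinkRoutes: map[offLinkRoute]bool{
		{dest: "::/0", router: "fe80::1"}: true,
	}}
	n.invalidateDefaultRouter("fe80::1")
	fmt.Println(len(n.offLinkRoutes)) // 0
}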
@@ -536,7 +548,7 @@ func (e *endpoint) Enable() tcpip.Error {
// Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
// state.
//
- // Addresses may have aleady completed DAD but in the time since the endpoint
+ // Addresses may have already completed DAD but in the time since the endpoint
// was last enabled, other devices may have acquired the same addresses.
var err tcpip.Error
e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
@@ -641,8 +653,8 @@ func (e *endpoint) DefaultTTL() uint8 {
return e.protocol.DefaultTTL()
}
-// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
-// the network layer max header length.
+// MTU implements stack.NetworkEndpoint. It returns the link-layer MTU minus the
+// network layer max header length.
func (e *endpoint) MTU() uint32 {
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv6MinimumSize)
if err != nil {
@@ -666,8 +678,7 @@ func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params
if length > math.MaxUint16 {
return &tcpip.ErrMessageTooLong{}
}
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen))
- ip.Encode(&header.IPv6Fields{
+ header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)).Encode(&header.IPv6Fields{
PayloadLength: uint16(length),
TransportProtocol: params.Protocol,
HopLimit: params.TTL,
@@ -747,9 +758,9 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// based on destination address and do not send the packet to link
// layer.
//
- // TODO(gvisor.dev/issue/170): We should do this for every
- // packet, rather than only NATted packets, but removing this check
- // short circuits broadcasts before they are sent out to other hosts.
+ // We should do this for every packet, rather than only NATted packets, but
+ // removing this check short-circuits broadcasts before they are sent out to
+ // other hosts.
if pkt.NatDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil {
@@ -818,7 +829,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol
return nil
}
-// WritePackets implements stack.NetworkEndpoint.WritePackets.
+// WritePackets implements stack.NetworkEndpoint.
func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
if r.Loop()&stack.PacketLoop != 0 {
panic("not implemented")
@@ -909,21 +920,17 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
if !ok {
return &tcpip.ErrMalformedHeader{}
}
- ip := header.IPv6(h)
+ ipH := header.IPv6(h)
// Always set the payload length.
pktSize := pkt.Data().Size()
- ip.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize))
+ ipH.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize))
// Set the source address when zero.
- if ip.SourceAddress() == header.IPv6Any {
- ip.SetSourceAddress(r.LocalAddress())
+ if ipH.SourceAddress() == header.IPv6Any {
+ ipH.SetSourceAddress(r.LocalAddress())
}
- // Set the destination. If the packet already included a destination, it will
- // be part of the route anyways.
- ip.SetDestinationAddress(r.RemoteAddress())
-
// Populate the packet buffer's network header and don't allow an invalid
// packet to be sent.
//
@@ -983,7 +990,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
return nil
}
- ep.handleValidatedPacket(h, pkt)
+ // The packet originally arrived on e so provide its NIC as the input NIC.
+ ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
return nil
}
@@ -1096,7 +1104,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
}
- e.handleValidatedPacket(h, pkt)
+ e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
}
// handleLocalPacket is like HandlePacket except it does not perform the
@@ -1115,10 +1123,14 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
return
}
- e.handleValidatedPacket(h, pkt)
+ e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
}
-func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) {
+func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, inNICName string) {
+ // Raw socket packets are delivered based solely on the transport protocol
+ // number. We only require that the packet be valid IPv6.
+ e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt)
+
pkt.NICID = e.nic.ID()
stats := e.stats.ip
stats.ValidPacketsReceived.Increment()
@@ -1167,8 +1179,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
// iptables filtering. All packets that reach here are intended for
// this machine and need not be forwarded.
- inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesInputDropped.Increment()
return
@@ -1610,7 +1621,7 @@ func (e *endpoint) Close() {
e.protocol.forgetEndpoint(e.nic.ID())
}
-// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
+// NetworkProtocolNumber implements stack.NetworkEndpoint.
func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
@@ -1619,8 +1630,8 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
// TODO(b/169350103): add checks here after making sure we no longer receive
// an empty address.
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.mu.Lock()
+ defer e.mu.Unlock()
return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated)
}
@@ -1661,8 +1672,8 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
// RemovePermanentAddress implements stack.AddressableEndpoint.
func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.mu.Lock()
+ defer e.mu.Unlock()
addressEndpoint := e.getAddressRLocked(addr)
if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() {
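AddAndAcquirePermanentAddress and RemovePermanentAddress both mutate the endpoint's address state, so the hunks above switch them from the read lock to the write lock; mutating shared state under RLock is a data race even when it appears to work. A minimal illustration of the distinction using sync.RWMutex, with illustrative fields rather than the endpoint's actual state:

package main

import (
	"fmt"
	"sync"
)

type endpoint struct {
	mu struct {
		sync.RWMutex
		addrs map[string]struct{}
	}
}

// addAddress mutates the address set, so it must hold the write lock, as the
// hunks above change RLock/RUnlock to Lock/Unlock.
func (e *endpoint) addAddress(a string) {
	e.mu.Lock()
	defer e.mu.Unlock()
	e.mu.addrs[a] = struct{}{}
}

// hasAddress only reads, so the read lock is enough and allows concurrent
// readers.
func (e *endpoint) hasAddress(a string) bool {
	e.mu.RLock()
	defer e.mu.RUnlock()
	_, ok := e.mu.addrs[a]
	return ok
}

func main() {
	var e endpoint
	e.mu.addrs = make(map[string]struct{})
	e.addAddress("fe80::1")
	fmt.Println(e.hasAddress("fe80::1")) // true
}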
@@ -2004,7 +2015,7 @@ func (p *protocol) DefaultPrefixLen() int {
return header.IPv6AddressSize * 8
}
-// ParseAddresses implements NetworkProtocol.ParseAddresses.
+// ParseAddresses implements stack.NetworkProtocol.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv6(v)
return h.SourceAddress(), h.DestinationAddress()
@@ -2081,7 +2092,7 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
delete(p.mu.eps, nicID)
}
-// SetOption implements NetworkProtocol.SetOption.
+// SetOption implements stack.NetworkProtocol.
func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
@@ -2092,7 +2103,7 @@ func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.E
}
}
-// Option implements NetworkProtocol.Option.
+// Option implements stack.NetworkProtocol.
func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
@@ -2113,10 +2124,10 @@ func (p *protocol) DefaultTTL() uint8 {
return uint8(atomic.LoadUint32(&p.defaultTTL))
}
-// Close implements stack.TransportProtocol.Close.
+// Close implements stack.TransportProtocol.
func (*protocol) Close() {}
-// Wait implements stack.TransportProtocol.Wait.
+// Wait implements stack.TransportProtocol.
func (*protocol) Wait() {}
// parseAndValidate parses the packet (including its transport layer header) and
@@ -2150,7 +2161,7 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv6, bool)
return h, true
}
-// Parse implements stack.NetworkProtocol.Parse.
+// Parse implements stack.NetworkProtocol.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt)
if !ok {
@@ -2180,7 +2191,7 @@ func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, tcpip.Error
return 0, &tcpip.ErrMalformedHeader{}
}
- networkMTU := linkMTU - uint32(networkHeadersLen)
+ networkMTU := linkMTU - networkHeadersLen
if networkMTU > maxPayloadSize {
networkMTU = maxPayloadSize
}
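The explicit conversion was dropped because networkHeadersLen is already a uint32; the arithmetic itself is unchanged: subtract the network-layer headers from the link MTU and clamp the result to the protocol's maximum payload size. A self-contained sketch of the same computation (maxPayloadSize and the error condition stand in for the package's own constant and check):

    package sketch

    import "gvisor.dev/gvisor/pkg/tcpip"

    const maxPayloadSize = 65535 // Stand-in for the IPv6 maximum payload size.

    // networkMTU mirrors calculateNetworkMTU: the usable MTU is the link MTU minus
    // the network-layer headers, capped at the protocol maximum.
    func networkMTU(linkMTU, networkHeadersLen uint32) (uint32, tcpip.Error) {
        if linkMTU < networkHeadersLen {
            return 0, &tcpip.ErrMalformedHeader{}
        }
        mtu := linkMTU - networkHeadersLen
        if mtu > maxPayloadSize {
            mtu = maxPayloadSize
        }
        return mtu, nil
    }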
@@ -2199,7 +2210,7 @@ type Options struct {
// Note, setting this to true does not mean that a link-local address is
// assigned right away, or at all. If Duplicate Address Detection is enabled,
// an address is only assigned if it successfully resolves. If it fails, no
- // further attempts are made to auto-generate a link-local adddress.
+ // further attempts are made to auto-generate a link-local address.
//
// The generated link-local address follows RFC 4291 Appendix A guidelines.
AutoGenLinkLocal bool
@@ -2215,7 +2226,7 @@ type Options struct {
// TempIIDSeed is used to seed the initial temporary interface identifier
// history value used to generate IIDs for temporary SLAAC addresses.
//
- // Temporary SLAAC adresses are short-lived addresses which are unpredictable
+ // Temporary SLAAC addresses are short-lived addresses which are unpredictable
// and random from the perspective of other nodes on the network. It is
// recommended that the seed be a random byte buffer of at least
// header.IIDSize bytes to make sure that temporary SLAAC addresses are
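Both options are plumbed through ipv6.NewProtocolWithOptions. A hedged sketch of wiring them into a stack (the 32-byte seed length is only illustrative; anything of at least header.IIDSize bytes works):

    package sketch

    import (
        "crypto/rand"

        "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
        "gvisor.dev/gvisor/pkg/tcpip/stack"
    )

    // newIPv6Stack builds a stack whose IPv6 endpoints auto-generate link-local
    // addresses and seed temporary-IID history from a random buffer.
    func newIPv6Stack() (*stack.Stack, error) {
        seed := make([]byte, 32)
        if _, err := rand.Read(seed); err != nil {
            return nil, err
        }
        return stack.New(stack.Options{
            NetworkProtocols: []stack.NetworkProtocolFactory{
                ipv6.NewProtocolWithOptions(ipv6.Options{
                    AutoGenLinkLocal: true,
                    TempIIDSeed:      seed,
                }),
            },
        }), nil
    }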
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index afc6c3547..d2a23fd4f 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -129,7 +129,7 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize)
// UDP checksum
- sum = header.Checksum(header.UDP([]byte{}), sum)
+ sum = header.Checksum(nil, sum)
u.SetChecksum(^u.CalculateChecksum(sum))
payloadLength := hdr.UsedLength()
@@ -402,7 +402,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}{
{
name: "None",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr },
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return nil, nextHdr },
shouldAccept: true,
expectICMP: false,
},
@@ -612,8 +612,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
{
name: "No next header",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{},
- noNextHdrID
+ return nil, noNextHdrID
},
shouldAccept: false,
expectICMP: false,
@@ -1160,7 +1159,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2),
[]buffer.View{
// Fragment extension header.
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0},
ipv6Payload1Addr1ToAddr2,
},
@@ -1180,7 +1179,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2),
[]buffer.View{
// Fragment extension header.
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0},
ipv6Payload3Addr1ToAddr2,
},
@@ -1202,7 +1201,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1218,7 +1217,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1240,7 +1239,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1256,7 +1255,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1278,7 +1277,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1296,7 +1295,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 8, More = false, ID = 1
// NextHeader value is different than the one in the first fragment, so
// this NextHeader should be ignored.
- buffer.View([]byte{uint8(header.IPv6NoNextHeaderIdentifier), 0, 0, 64, 0, 0, 0, 1}),
+ []byte{uint8(header.IPv6NoNextHeaderIdentifier), 0, 0, 64, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1318,7 +1317,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload3Addr1ToAddr2[:64],
},
@@ -1334,7 +1333,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
ipv6Payload3Addr1ToAddr2[64:],
},
@@ -1356,7 +1355,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload3Addr1ToAddr2[:63],
},
@@ -1372,7 +1371,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
ipv6Payload3Addr1ToAddr2[63:],
},
@@ -1394,7 +1393,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1410,7 +1409,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 2
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1432,7 +1431,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15],
},
@@ -1448,10 +1447,10 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = udpMaximumSizeMinus15/8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
+ []byte{uint8(header.UDPProtocolNumber), 0,
udpMaximumSizeMinus15 >> 8,
udpMaximumSizeMinus15 & 0xff,
- 0, 0, 0, 1}),
+ 0, 0, 0, 1},
ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:],
},
@@ -1473,7 +1472,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15],
},
@@ -1489,10 +1488,10 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = udpMaximumSizeMinus15/8, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
+ []byte{uint8(header.UDPProtocolNumber), 0,
udpMaximumSizeMinus15 >> 8,
(udpMaximumSizeMinus15 & 0xff) + 1,
- 0, 0, 0, 1}),
+ 0, 0, 0, 1},
ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:],
},
@@ -1514,12 +1513,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Routing extension header.
//
// Segments left = 0.
- buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}),
+ []byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5},
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1535,12 +1534,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Routing extension header.
//
// Segments left = 0.
- buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}),
+ []byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5},
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1562,12 +1561,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Routing extension header.
//
// Segments left = 1.
- buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}),
+ []byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5},
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1583,12 +1582,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Routing extension header.
//
// Segments left = 1.
- buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}),
+ []byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5},
// Fragment extension header.
//
// Fragment offset = 9, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1610,12 +1609,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+ []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1},
// Routing extension header.
//
// Segments left = 0.
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1631,7 +1630,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 9, More = false, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
+ []byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1653,12 +1652,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+ []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1},
// Routing extension header.
//
// Segments left = 1.
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1674,7 +1673,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 9, More = false, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
+ []byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1699,12 +1698,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+ []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1},
// Routing extension header (part 1)
//
// Segments left = 0.
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5}),
+ []byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5},
},
),
},
@@ -1721,10 +1720,10 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 1, More = false, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}),
+ []byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1},
// Routing extension header (part 2)
- buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
+ []byte{6, 7, 8, 9, 10, 11, 12, 13},
ipv6Payload1Addr1ToAddr2,
},
@@ -1749,12 +1748,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+ []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1},
// Routing extension header (part 1)
//
// Segments left = 1.
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5}),
+ []byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5},
},
),
},
@@ -1771,10 +1770,10 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 1, More = false, ID = 1
- buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}),
+ []byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1},
// Routing extension header (part 2)
- buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
+ []byte{6, 7, 8, 9, 10, 11, 12, 13},
ipv6Payload1Addr1ToAddr2,
},
@@ -1798,7 +1797,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1816,7 +1815,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1},
ipv6Payload2Addr1ToAddr2,
},
@@ -1832,7 +1831,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1854,7 +1853,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1870,7 +1869,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 2
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2},
ipv6Payload2Addr1ToAddr2[:32],
},
@@ -1886,7 +1885,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1902,7 +1901,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 4, More = false, ID = 2
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2},
ipv6Payload2Addr1ToAddr2[32:],
},
@@ -1924,7 +1923,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[:64],
},
@@ -1940,7 +1939,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
ipv6Payload1Addr3ToAddr2[:32],
},
@@ -1956,7 +1955,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
ipv6Payload1Addr1ToAddr2[64:],
},
@@ -1972,7 +1971,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment extension header.
//
// Fragment offset = 4, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1}),
+ []byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1},
ipv6Payload1Addr3ToAddr2[32:],
},
@@ -2208,7 +2207,7 @@ func TestInvalidIPv6Fragments(t *testing.T) {
checker.ICMPv6Type(test.expectICMPType),
checker.ICMPv6Code(test.expectICMPCode),
checker.ICMPv6TypeSpecific(test.expectICMPTypeSpecific),
- checker.ICMPv6Payload([]byte(expectICMPPayload)),
+ checker.ICMPv6Payload(expectICMPPayload),
),
)
})
@@ -2459,7 +2458,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
checker.ICMPv6(
checker.ICMPv6Type(header.ICMPv6TimeExceeded),
checker.ICMPv6Code(header.ICMPv6ReassemblyTimeout),
- checker.ICMPv6Payload([]byte(firstFragmentSent)),
+ checker.ICMPv6Payload(firstFragmentSent),
),
)
})
@@ -2795,11 +2794,7 @@ var fragmentationTests = []struct {
}
func TestFragmentationWritePacket(t *testing.T) {
- const (
- ttl = 42
- tos = stack.DefaultTOS
- transportProto = tcp.ProtocolNumber
- )
+ const ttl = 42
for _, ft := range fragmentationTests {
t.Run(ft.description, func(t *testing.T) {
@@ -3335,7 +3330,7 @@ func TestForwarding(t *testing.T) {
}
transportProtocol := header.ICMPv6ProtocolNumber
- extHdrBytes := []byte{}
+ var extHdrBytes []byte
extHdrChecker := checker.IPv6ExtHdr()
if test.extHdr != nil {
nextHdrID := hopByHopExtHdrID
@@ -3349,15 +3344,15 @@ func TestForwarding(t *testing.T) {
totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + extHdrLen
hdr := buffer.NewPrependable(totalLength)
hdr.Prepend(test.payloadLength)
- icmp := header.ICMPv6(hdr.Prepend(icmpHeaderLength))
-
- icmp.SetIdent(randomIdent)
- icmp.SetSequence(randomSequence)
- icmp.SetType(header.ICMPv6EchoRequest)
- icmp.SetCode(header.ICMPv6UnusedCode)
- icmp.SetChecksum(0)
- icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmp,
+ icmpH := header.ICMPv6(hdr.Prepend(icmpHeaderLength))
+
+ icmpH.SetIdent(randomIdent)
+ icmpH.SetSequence(randomSequence)
+ icmpH.SetType(header.ICMPv6EchoRequest)
+ icmpH.SetCode(header.ICMPv6UnusedCode)
+ icmpH.SetChecksum(0)
+ icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmpH,
Src: test.sourceAddr,
Dst: test.destAddr,
}))
@@ -3395,14 +3390,14 @@ func TestForwarding(t *testing.T) {
return len(hdr.View())
}
- checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
checker.SrcAddr(incomingIPv6Addr.Address),
checker.DstAddr(test.sourceAddr),
checker.TTL(DefaultTTL),
checker.ICMPv6(
checker.ICMPv6Type(test.icmpType),
checker.ICMPv6Code(test.icmpCode),
- checker.ICMPv6Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])),
+ checker.ICMPv6Payload(hdr.View()[:expectedICMPPayloadLength()]),
),
)
@@ -3419,7 +3414,7 @@ func TestForwarding(t *testing.T) {
t.Fatal("expected ICMP Echo Request packet through outgoing NIC")
}
- checker.IPv6WithExtHdr(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.IPv6WithExtHdr(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
checker.SrcAddr(test.sourceAddr),
checker.DstAddr(test.destAddr),
checker.TTL(test.TTL-1),
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
index 71d1c3e28..bc9cf6999 100644
--- a/pkg/tcpip/network/ipv6/mld_test.go
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -16,6 +16,7 @@ package ipv6_test
import (
"bytes"
+ "math/rand"
"testing"
"time"
@@ -138,7 +139,8 @@ func TestSendQueuedMLDReports(t *testing.T) {
var secureRNG bytes.Reader
secureRNG.Reset(secureRNGBytes[:])
s := stack.New(stack.Options{
- SecureRNG: &secureRNG,
+ SecureRNG: &secureRNG,
+ RandSource: rand.NewSource(time.Now().UnixNano()),
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
DADConfigs: stack.DADConfigurations{
DupAddrDetectTransmits: test.dadTransmits,
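The test now supplies an explicit RandSource alongside the SecureRNG: with the ndp.go change below, NDP draws its insecure randomness from the stack rather than the global math/rand functions, and the updated tests make that source explicit. A minimal sketch of the setup (the time-based seed is only illustrative):

    package sketch

    import (
        "math/rand"
        "time"

        "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
        "gvisor.dev/gvisor/pkg/tcpip/stack"
    )

    // newTestStack constructs a stack with an explicit insecure rand source, as the
    // updated MLD/NDP tests do.
    func newTestStack() *stack.Stack {
        return stack.New(stack.Options{
            RandSource:       rand.NewSource(time.Now().UnixNano()),
            NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
        })
    }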
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index f0ff111c5..9cd283eba 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -16,7 +16,6 @@ package ipv6
import (
"fmt"
- "math/rand"
"time"
"gvisor.dev/gvisor/pkg/sync"
@@ -79,13 +78,13 @@ const (
// we cannot have a negative delay.
minimumMaxRtrSolicitationDelay = 0
- // MaxDiscoveredDefaultRouters is the maximum number of discovered
- // default routers. The stack should stop discovering new routers after
- // discovering MaxDiscoveredDefaultRouters routers.
+ // MaxDiscoveredOffLinkRoutes is the maximum number of discovered off-link
+ // routes. The stack should stop discovering new off-link routes after
+ // this limit is reached.
//
// This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and
// SHOULD be more.
- MaxDiscoveredDefaultRouters = 10
+ MaxDiscoveredOffLinkRoutes = 10
// MaxDiscoveredOnLinkPrefixes is the maximum number of discovered
// on-link prefixes. The stack should stop discovering new on-link
@@ -128,25 +127,17 @@ const (
// maxSLAACAddrLocalRegenAttempts is the maximum number of times to attempt
// SLAAC address regenerations in response to an IPv6 endpoint-local conflict.
maxSLAACAddrLocalRegenAttempts = 10
-)
-var (
// MinPrefixInformationValidLifetimeForUpdate is the minimum Valid
// Lifetime to update the valid lifetime of a generated address by
// SLAAC.
//
- // This is exported as a variable (instead of a constant) so tests
- // can update it to a smaller value.
- //
// Min = 2hrs.
MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour
// MaxDesyncFactor is the upper bound for the preferred lifetime's desync
// factor for temporary SLAAC addresses.
//
- // This is exported as a variable (instead of a constant) so tests
- // can update it to a smaller value.
- //
// Must be greater than 0.
//
// Max = 10m (from RFC 4941 section 5).
@@ -155,9 +146,6 @@ var (
// MinMaxTempAddrPreferredLifetime is the minimum value allowed for the
// maximum preferred lifetime for temporary SLAAC addresses.
//
- // This is exported as a variable (instead of a constant) so tests
- // can update it to a smaller value.
- //
// This value guarantees that a temporary address is preferred for at
// least 1hr if the SLAAC prefix is valid for at least that time.
MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour
@@ -165,9 +153,6 @@ var (
// MinMaxTempAddrValidLifetime is the minimum value allowed for the
// maximum valid lifetime for temporary SLAAC addresses.
//
- // This is exported as a variable (instead of a constant) so tests
- // can update it to a smaller value.
- //
// This value guarantees that a temporary address is valid for at least
// 2hrs if the SLAAC prefix is valid for at least that time.
MinMaxTempAddrValidLifetime = 2 * time.Hour
@@ -215,28 +200,23 @@ type NDPDispatcher interface {
// is also not permitted to call into the stack.
OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult)
- // OnDefaultRouterDiscovered is called when a new default router is
- // discovered. Implementations must return true if the newly discovered
- // router should be remembered.
+ // OnOffLinkRouteUpdated is called when an off-link route is updated.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool
+ OnOffLinkRouteUpdated(tcpip.NICID, tcpip.Subnet, tcpip.Address, header.NDPRoutePreference)
- // OnDefaultRouterInvalidated is called when a discovered default router that
- // was remembered is invalidated.
+ // OnOffLinkRouteInvalidated is called when an off-link route is invalidated.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address)
+ OnOffLinkRouteInvalidated(tcpip.NICID, tcpip.Subnet, tcpip.Address)
// OnOnLinkPrefixDiscovered is called when a new on-link prefix is discovered.
- // Implementations must return true if the newly discovered on-link prefix
- // should be remembered.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool
+ OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet)
// OnOnLinkPrefixInvalidated is called when a discovered on-link prefix that
// was remembered is invalidated.
@@ -246,13 +226,11 @@ type NDPDispatcher interface {
OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet)
// OnAutoGenAddress is called when a new prefix with its autonomous address-
- // configuration flag set is received and SLAAC was performed. Implementations
- // may prevent the stack from assigning the address to the NIC by returning
- // false.
+ // configuration flag set is received and SLAAC was performed.
//
// This function is not permitted to block indefinitely. It must not
// call functions on the stack itself.
- OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool
+ OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix)
// OnAutoGenAddressDeprecated is called when an auto-generated address (SLAAC)
// is deprecated, but is still considered valid. Note, if an address is
@@ -470,6 +448,11 @@ type timer struct {
timer tcpip.Timer
}
+type offLinkRoute struct {
+ dest tcpip.Subnet
+ router tcpip.Address
+}
+
// ndpState is the per-Interface NDP state.
type ndpState struct {
// Do not allow overwriting this state.
@@ -484,8 +467,8 @@ type ndpState struct {
// The DAD timers to send the next NS message, or resolve the address.
dad ip.DAD
- // The default routers discovered through Router Advertisements.
- defaultRouters map[tcpip.Address]defaultRouterState
+ // The off-link routes discovered through Router Advertisements.
+ offLinkRoutes map[offLinkRoute]offLinkRouteState
// rtrSolicitTimer is the timer used to send the next router solicitation
// message.
@@ -513,10 +496,12 @@ type ndpState struct {
temporaryAddressDesyncFactor time.Duration
}
-// defaultRouterState holds data associated with a default router discovered by
+// offLinkRouteState holds data associated with an off-link route discovered by
// a Router Advertisement (RA).
-type defaultRouterState struct {
- // Job to invalidate the default router.
+type offLinkRouteState struct {
+ prf header.NDPRoutePreference
+
+ // Job to invalidate the route.
//
// Must not be nil.
invalidationJob *tcpip.Job
@@ -549,7 +534,7 @@ type tempSLAACAddrState struct {
// Must not be nil.
regenJob *tcpip.Job
- createdAt time.Time
+ createdAt tcpip.MonotonicTime
// The address's endpoint.
//
@@ -572,11 +557,11 @@ type slaacPrefixState struct {
// Must not be nil.
invalidationJob *tcpip.Job
- // Nonzero only when the address is not valid forever.
- validUntil time.Time
+ // nil iff the address is valid forever.
+ validUntil *tcpip.MonotonicTime
- // Nonzero only when the address is not preferred forever.
- preferredUntil time.Time
+ // nil iff the address is preferred forever.
+ preferredUntil *tcpip.MonotonicTime
// State associated with the stable address generated for the prefix.
stableAddr struct {
@@ -734,30 +719,22 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// Is the IPv6 endpoint configured to discover default routers?
if ndp.configs.DiscoverDefaultRouters {
- rtr, ok := ndp.defaultRouters[ip]
- rl := ra.RouterLifetime()
- switch {
- case !ok && rl != 0:
- // This is a new default router we are discovering.
+ prf := ra.DefaultRouterPreference()
+ if prf == header.ReservedRoutePreference {
+ // As per RFC 4191 section 2.2,
//
- // Only remember it if we currently know about less than
- // MaxDiscoveredDefaultRouters routers.
- if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters {
- ndp.rememberDefaultRouter(ip, rl)
- }
-
- case ok && rl != 0:
- // This is an already discovered default router. Update
- // the invalidation job.
- rtr.invalidationJob.Cancel()
- rtr.invalidationJob.Schedule(rl)
- ndp.defaultRouters[ip] = rtr
-
- case ok && rl == 0:
- // We know about the router but it is no longer to be
- // used as a default router so invalidate it.
- ndp.invalidateDefaultRouter(ip)
+ // Prf (Default Router Preference)
+ //
+ // If the Reserved (10) value is received, the receiver MUST treat the
+ // value as if it were (00).
+ //
+ // Note that the value 00 is the medium (default) router preference value.
+ prf = header.MediumRoutePreference
}
+
+ // We represent default routers with a default (off-link) route through the
+ // router.
+ ndp.handleOffLinkRouteDiscovery(offLinkRoute{dest: header.IPv6EmptySubnet, router: ip}, ra.RouterLifetime(), prf)
}
// TODO(b/141556115): Do (RetransTimer, ReachableTime)) Parameter
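Default routers are now recorded as off-link routes to ::/0 through the advertising router, and the route preference is read from the RA, with the reserved value (10) normalized to medium (00) as RFC 4191 section 2.2 requires. A small sketch of that normalization using the header package's constants:

    package sketch

    import "gvisor.dev/gvisor/pkg/tcpip/header"

    // normalizePreference applies the RFC 4191 section 2.2 rule: a reserved
    // preference value is treated as medium (the default).
    func normalizePreference(prf header.NDPRoutePreference) header.NDPRoutePreference {
        if prf == header.ReservedRoutePreference {
            return header.MediumRoutePreference
        }
        return prf
    }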
@@ -815,55 +792,75 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
}
}
-// invalidateDefaultRouter invalidates a discovered default router.
+// invalidateOffLinkRoute invalidates a discovered off-link route.
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
- rtr, ok := ndp.defaultRouters[ip]
-
- // Is the router still discovered?
+func (ndp *ndpState) invalidateOffLinkRoute(route offLinkRoute) {
+ state, ok := ndp.offLinkRoutes[route]
if !ok {
- // ...Nope, do nothing further.
return
}
- rtr.invalidationJob.Cancel()
- delete(ndp.defaultRouters, ip)
+ state.invalidationJob.Cancel()
+ delete(ndp.offLinkRoutes, route)
- // Let the integrator know a discovered default router is invalidated.
+ // Let the integrator know a discovered off-link route is invalidated.
if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
- ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip)
+ ndpDisp.OnOffLinkRouteInvalidated(ndp.ep.nic.ID(), route.dest, route.router)
}
}
-// rememberDefaultRouter remembers a newly discovered default router with IPv6
-// link-local address ip with lifetime rl.
-//
-// The router identified by ip MUST NOT already be known by the IPv6 endpoint.
+// handleOffLinkRouteDiscovery handles the discovery of an off-link route.
//
-// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
+// Precondition: ndp.ep.mu must be locked.
+func (ndp *ndpState) handleOffLinkRouteDiscovery(route offLinkRoute, lifetime time.Duration, prf header.NDPRoutePreference) {
ndpDisp := ndp.ep.protocol.options.NDPDisp
if ndpDisp == nil {
return
}
- // Inform the integrator when we discovered a default router.
- if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.nic.ID(), ip) {
- // Informed by the integrator to not remember the router, do
- // nothing further.
- return
- }
+ state, ok := ndp.offLinkRoutes[route]
+ switch {
+ case !ok && lifetime != 0:
+ // This is a new route we are discovering.
+ //
+ // Only remember it if we currently know about less than
+ // MaxDiscoveredOffLinkRoutes routers.
+ if len(ndp.offLinkRoutes) < MaxDiscoveredOffLinkRoutes {
+ // Inform the integrator when we discovered an off-link route.
+ ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), route.dest, route.router, prf)
+
+ state := offLinkRouteState{
+ prf: prf,
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ ndp.invalidateOffLinkRoute(route)
+ }),
+ }
- state := defaultRouterState{
- invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
- ndp.invalidateDefaultRouter(ip)
- }),
- }
+ state.invalidationJob.Schedule(lifetime)
+
+ ndp.offLinkRoutes[route] = state
+ }
+
+ case ok && lifetime != 0:
+ // This is an already discovered off-link route. Update the lifetime.
+ state.invalidationJob.Cancel()
+ state.invalidationJob.Schedule(lifetime)
+
+ if prf != state.prf {
+ state.prf = prf
- state.invalidationJob.Schedule(rl)
+ // Inform the integrator about route preference updates.
+ ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), route.dest, route.router, prf)
+ }
+
+ ndp.offLinkRoutes[route] = state
- ndp.defaultRouters[ip] = state
+ case ok && lifetime == 0:
+ // The already discovered off-link route is no longer considered valid so we
+ // invalidate it immediately.
+ ndp.invalidateOffLinkRoute(route)
+ }
}
// rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6
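For integrators, OnDefaultRouterDiscovered/OnDefaultRouterInvalidated are replaced by route-level callbacks, and dispatchers can no longer veto discovery by returning false. A hedged sketch of the two new callbacks (only these two methods are shown; a real ipv6.NDPDispatcher must implement the full interface):

    package sketch

    import (
        "log"

        "gvisor.dev/gvisor/pkg/tcpip"
        "gvisor.dev/gvisor/pkg/tcpip/header"
    )

    // routeDispatcher logs off-link route events; dest is header.IPv6EmptySubnet
    // (::/0) when the route represents a default router.
    type routeDispatcher struct{}

    func (routeDispatcher) OnOffLinkRouteUpdated(nicID tcpip.NICID, dest tcpip.Subnet, router tcpip.Address, prf header.NDPRoutePreference) {
        log.Printf("NIC %d: route to %s via %s updated (preference %d)", nicID, dest, router, prf)
    }

    func (routeDispatcher) OnOffLinkRouteInvalidated(nicID tcpip.NICID, dest tcpip.Subnet, router tcpip.Address) {
        log.Printf("NIC %d: route to %s via %s invalidated", nicID, dest, router)
    }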
@@ -879,11 +876,7 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration)
}
// Inform the integrator when we discovered an on-link prefix.
- if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) {
- // Informed by the integrator to not remember the prefix, do
- // nothing further.
- return
- }
+ ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix)
state := onLinkPrefixState{
invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
@@ -1051,12 +1044,13 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1,
}
- now := time.Now()
+ now := ndp.ep.protocol.stack.Clock().NowMonotonic()
// The time an address is preferred until is needed to properly generate the
// address.
if pl < header.NDPInfiniteLifetime {
- state.preferredUntil = now.Add(pl)
+ t := now.Add(pl)
+ state.preferredUntil = &t
}
if !ndp.generateSLAACAddr(prefix, &state) {
@@ -1074,7 +1068,8 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
if vl < header.NDPInfiniteLifetime {
state.invalidationJob.Schedule(vl)
- state.validUntil = now.Add(vl)
+ t := now.Add(vl)
+ state.validUntil = &t
}
// If the address is assigned (DAD resolved), generate a temporary address.
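preferredUntil and validUntil become pointers to monotonic timestamps: nil now means "forever", replacing the old time.Time{} zero-value sentinel, and the values come from the stack clock rather than time.Now. A generic sketch of the nil-pointer sentinel (lease and its field are hypothetical, not stack types):

    package sketch

    import "time"

    // lease uses a nil pointer, rather than a zero time.Time, to mean "no expiry".
    type lease struct {
        validUntil *time.Time // nil iff the lease is valid forever.
    }

    func (l *lease) expired(now time.Time) bool {
        return l.validUntil != nil && !l.validUntil.After(now)
    }

    func (l *lease) extend(now time.Time, d time.Duration) {
        t := now.Add(d)
        l.validUntil = &t
    }

    func (l *lease) makePermanent() {
        l.validUntil = nil
    }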
@@ -1097,16 +1092,13 @@ func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, config
return nil
}
- if !ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) {
- // Informed by the integrator not to add the address.
- return nil
- }
-
addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated)
if err != nil {
panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err))
}
+ ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr)
+
return addressEndpoint
}
@@ -1182,7 +1174,8 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
state.stableAddr.localGenerationFailures++
}
- if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); addressEndpoint != nil {
+ deprecated := state.preferredUntil != nil && !state.preferredUntil.After(ndp.ep.protocol.stack.Clock().NowMonotonic())
+ if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, deprecated); addressEndpoint != nil {
state.stableAddr.addressEndpoint = addressEndpoint
state.generationAttempts++
return true
@@ -1237,13 +1230,13 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
}
stableAddr := prefixState.stableAddr.addressEndpoint.AddressWithPrefix().Address
- now := time.Now()
+ now := ndp.ep.protocol.stack.Clock().NowMonotonic()
// As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary
// address is the lower of the valid lifetime of the stable address or the
// maximum temporary address valid lifetime.
vl := ndp.configs.MaxTempAddrValidLifetime
- if prefixState.validUntil != (time.Time{}) {
+ if prefixState.validUntil != nil {
if prefixVL := prefixState.validUntil.Sub(now); vl > prefixVL {
vl = prefixVL
}
@@ -1259,7 +1252,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
// maximum temporary address preferred lifetime - the temporary address desync
// factor.
pl := ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor
- if prefixState.preferredUntil != (time.Time{}) {
+ if prefixState.preferredUntil != nil {
if prefixPL := prefixState.preferredUntil.Sub(now); pl > prefixPL {
// Respect the preferred lifetime of the prefix, as per RFC 4941 section
// 3.3 step 4.
@@ -1394,16 +1387,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// deprecation job so it can be reset.
prefixState.deprecationJob.Cancel()
- now := time.Now()
+ now := ndp.ep.protocol.stack.Clock().NowMonotonic()
// Schedule the deprecation job if prefix has a finite preferred lifetime.
if pl < header.NDPInfiniteLifetime {
if !deprecated {
prefixState.deprecationJob.Schedule(pl)
}
- prefixState.preferredUntil = now.Add(pl)
+ t := now.Add(pl)
+ prefixState.preferredUntil = &t
} else {
- prefixState.preferredUntil = time.Time{}
+ prefixState.preferredUntil = nil
}
// As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix:
@@ -1421,17 +1415,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// Handle the infinite valid lifetime separately as we do not schedule a
// job in this case.
prefixState.invalidationJob.Cancel()
- prefixState.validUntil = time.Time{}
+ prefixState.validUntil = nil
} else {
var effectiveVl time.Duration
var rl time.Duration
// If the prefix was originally set to be valid forever, assume the
// remaining time to be the maximum possible value.
- if prefixState.validUntil == (time.Time{}) {
+ if prefixState.validUntil == nil {
rl = header.NDPInfiniteLifetime
} else {
- rl = time.Until(prefixState.validUntil)
+ rl = prefixState.validUntil.Sub(now)
}
if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl {
@@ -1443,7 +1437,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
if effectiveVl != 0 {
prefixState.invalidationJob.Cancel()
prefixState.invalidationJob.Schedule(effectiveVl)
- prefixState.validUntil = now.Add(effectiveVl)
+ t := now.Add(effectiveVl)
+ prefixState.validUntil = &t
}
}
@@ -1463,8 +1458,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// maximum temporary address valid lifetime. Note, the valid lifetime of a
// temporary address is relative to the address's creation time.
validUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrValidLifetime)
- if prefixState.validUntil != (time.Time{}) && validUntil.Sub(prefixState.validUntil) > 0 {
- validUntil = prefixState.validUntil
+ if prefixState.validUntil != nil && prefixState.validUntil.Before(validUntil) {
+ validUntil = *prefixState.validUntil
}
// If the address is no longer valid, invalidate it immediately. Otherwise,
@@ -1483,14 +1478,15 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// desync factor. Note, the preferred lifetime of a temporary address is
// relative to the address's creation time.
preferredUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor)
- if prefixState.preferredUntil != (time.Time{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 {
- preferredUntil = prefixState.preferredUntil
+ if prefixState.preferredUntil != nil && prefixState.preferredUntil.Before(preferredUntil) {
+ preferredUntil = *prefixState.preferredUntil
}
// If the address is no longer preferred, deprecate it immediately.
// Otherwise, schedule the deprecation job again.
newPreferredLifetime := preferredUntil.Sub(now)
tempAddrState.deprecationJob.Cancel()
+
if newPreferredLifetime <= 0 {
ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint)
} else {
@@ -1680,12 +1676,12 @@ func (ndp *ndpState) cleanupState() {
panic(fmt.Sprintf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got))
}
- for router := range ndp.defaultRouters {
- ndp.invalidateDefaultRouter(router)
+ for route := range ndp.offLinkRoutes {
+ ndp.invalidateOffLinkRoute(route)
}
- if got := len(ndp.defaultRouters); got != 0 {
- panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got))
+ if got := len(ndp.offLinkRoutes); got != 0 {
+ panic(fmt.Sprintf("ndp: still have discovered off-link routes after cleaning up; found = %d", got))
}
ndp.dhcpv6Configuration = 0
@@ -1718,7 +1714,7 @@ func (ndp *ndpState) startSolicitingRouters() {
// 4861 section 6.3.7.
var delay time.Duration
if ndp.configs.MaxRtrSolicitationDelay > 0 {
- delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay)))
+ delay = time.Duration(ndp.ep.protocol.stack.Rand().Int63n(int64(ndp.configs.MaxRtrSolicitationDelay)))
}
// Protected by ndp.ep.mu.
@@ -1754,7 +1750,7 @@ func (ndp *ndpState) startSolicitingRouters() {
header.NDPSourceLinkLayerAddressOption(linkAddress),
}
}
- payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length())
+ payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + optsSerializer.Length()
icmpData := header.ICMPv6(buffer.NewView(payloadSize))
icmpData.SetType(header.ICMPv6RouterSolicit)
rs := header.NDPRouterSolicit(icmpData.MessageBody())
@@ -1848,21 +1844,19 @@ func (ndp *ndpState) stopSolicitingRouters() {
}
func (ndp *ndpState) init(ep *endpoint, dadOptions ip.DADOptions) {
- if ndp.defaultRouters != nil {
+ if ndp.offLinkRoutes != nil {
panic("attempted to initialize NDP state twice")
}
ndp.ep = ep
ndp.configs = ep.protocol.options.NDPConfigs
ndp.dad.Init(&ndp.ep.mu, ep.protocol.options.DADConfigs, dadOptions)
- ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState)
+ ndp.offLinkRoutes = make(map[offLinkRoute]offLinkRouteState)
ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState)
ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState)
header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID())
- if MaxDesyncFactor != 0 {
- ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor)))
- }
+ ndp.temporaryAddressDesyncFactor = time.Duration(ep.protocol.stack.Rand().Int63n(int64(MaxDesyncFactor)))
}

func (ndp *ndpState) SendDADMessage(addr tcpip.Address, nonce []byte) tcpip.Error {
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 234e34952..f0186c64e 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -16,7 +16,7 @@ package ipv6
import (
"bytes"
- "context"
+ "math/rand"
"strings"
"testing"
"time"
@@ -32,58 +32,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
)
-// setupStackAndEndpoint creates a stack with a single NIC with a link-local
-// address llladdr and an IPv6 endpoint to a remote with link-local address
-// rlladdr
-func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) {
- t.Helper()
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
- })
-
- if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
- {
- subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }},
- )
- }
-
- netProto := s.NetworkProtocolInstance(ProtocolNumber)
- if netProto == nil {
- t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
- }
-
- ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{})
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
- t.Cleanup(ep.Close)
-
- addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
- if !ok {
- t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
- }
- addr := llladdr.WithPrefix()
- if addressEP, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
- } else {
- addressEP.DecRef()
- }
-
- return s, ep
-}
-
var _ NDPDispatcher = (*testNDPDispatcher)(nil)
// testNDPDispatcher is an NDPDispatcher that only allows default router discovery.
@@ -94,24 +42,21 @@ type testNDPDispatcher struct {
func (*testNDPDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) {
}
-func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool {
+func (t *testNDPDispatcher) OnOffLinkRouteUpdated(_ tcpip.NICID, _ tcpip.Subnet, addr tcpip.Address, _ header.NDPRoutePreference) {
t.addr = addr
- return true
}
-func (t *testNDPDispatcher) OnDefaultRouterInvalidated(_ tcpip.NICID, addr tcpip.Address) {
+func (t *testNDPDispatcher) OnOffLinkRouteInvalidated(_ tcpip.NICID, _ tcpip.Subnet, addr tcpip.Address) {
t.addr = addr
}
-func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool {
- return false
+func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) {
}
func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {
}
-func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool {
- return false
+func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) {
}
func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {
@@ -148,7 +93,7 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) {
ipv6EP := ep.(*endpoint)
ipv6EP.mu.Lock()
- ipv6EP.mu.ndp.rememberDefaultRouter(lladdr1, time.Hour)
+ ipv6EP.mu.ndp.handleOffLinkRouteDiscovery(offLinkRoute{dest: header.IPv6EmptySubnet, router: lladdr1}, time.Hour, header.MediumRoutePreference)
ipv6EP.mu.Unlock()
if ndpDisp.addr != lladdr1 {
@@ -163,11 +108,6 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) {
}
}
-type linkResolutionResult struct {
- linkAddr tcpip.LinkAddress
- ok bool
-}
-
// TestNeighborSolicitationWithSourceLinkLayerOption tests that receiving a
// valid NDP NS message with the Source Link Layer Address option results in a
// new entry in the link address cache for the sender of the message.
@@ -456,8 +396,10 @@ func TestNeighborSolicitationResponse(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ Clock: clock,
})
e := channel.New(1, 1280, nicLinkAddr)
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
@@ -527,7 +469,8 @@ func TestNeighborSolicitationResponse(t *testing.T) {
}
if test.performsLinkResolution {
- p, got := e.ReadContext(context.Background())
+ clock.RunImmediatelyScheduledJobs()
+ p, got := e.Read()
if !got {
t.Fatal("expected an NDP NS response")
}
@@ -582,7 +525,8 @@ func TestNeighborSolicitationResponse(t *testing.T) {
}))
}
- p, got := e.ReadContext(context.Background())
+ clock.RunImmediatelyScheduledJobs()
+ p, got := e.Read()
if !got {
t.Fatal("expected an NDP NA response")
}
@@ -906,12 +850,12 @@ func TestNDPValidation(t *testing.T) {
routerOnly := stats.RouterOnlyPacketsDroppedByHost
typStat := typ.statCounter(stats)
- icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
- copy(icmp[typ.size:], typ.extraData)
- icmp.SetType(typ.typ)
- icmp.SetCode(test.code)
- icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmp[:typ.size],
+ icmpH := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmpH[typ.size:], typ.extraData)
+ icmpH.SetType(typ.typ)
+ icmpH.SetCode(test.code)
+ icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmpH[:typ.size],
Src: lladdr0,
Dst: lladdr1,
PayloadCsum: header.Checksum(typ.extraData /* initial */, 0),
@@ -937,7 +881,7 @@ func TestNDPValidation(t *testing.T) {
t.FailNow()
}
- handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep)
+ handleIPv6Payload(buffer.View(icmpH), test.hopLimit, test.atomicFragment, ep)
// Rx count of the NDP packet should have increased.
if got := typStat.Value(); got != 1 {
@@ -1297,8 +1241,9 @@ func TestCheckDuplicateAddress(t *testing.T) {
var secureRNG bytes.Reader
secureRNG.Reset(secureRNGBytes[:])
s := stack.New(stack.Options{
- SecureRNG: &secureRNG,
- Clock: clock,
+ Clock: clock,
+ RandSource: rand.NewSource(time.Now().UnixNano()),
+ SecureRNG: &secureRNG,
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{
DADConfigs: dadConfigs,
})},
@@ -1315,7 +1260,8 @@ func TestCheckDuplicateAddress(t *testing.T) {
snmc := header.SolicitedNodeAddr(lladdr0)
remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc)
checkDADMsg := func() {
- p, ok := e.ReadContext(context.Background())
+ clock.RunImmediatelyScheduledJobs()
+ p, ok := e.Read()
if !ok {
t.Fatalf("expected %d-th DAD message", dadPacketsSent)
}
@@ -1363,7 +1309,7 @@ func TestCheckDuplicateAddress(t *testing.T) {
t.Fatalf("RemoveAddress(%d, %s): %s", nicID, lladdr0, err)
}
// Should not restart DAD since we already requested DAD above - the handler
- // should be called when the original request compeletes so we should not send
+ // should be called when the original request completes so we should not send
// an extra DAD message here.
dadRequestsMade++
if res, err := s.CheckDuplicateAddress(nicID, ProtocolNumber, lladdr0, func(r stack.DADResult) {
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index b5b013b64..854d6a6ba 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -101,7 +101,7 @@ func (dc destToCounter) intersectionFlags(res Reservation) (BitFlags, int) {
// Wildcard destinations affect all destinations for TupleOnly.
if dest.addr == anyIPAddress || res.Dest.Addr == anyIPAddress {
// Only bitwise and the TupleOnlyFlag.
- intersection &= ((^TupleOnlyFlag) | counter.SharedFlags())
+ intersection &= (^TupleOnlyFlag) | counter.SharedFlags()
count++
}
}
@@ -238,13 +238,13 @@ type PortTester func(port uint16) (good bool, err tcpip.Error)
// possible ephemeral ports, allowing the caller to decide whether a given port
// is suitable for its needs, and stopping when a port is found or an error
// occurs.
-func (pm *PortManager) PickEphemeralPort(testPort PortTester) (port uint16, err tcpip.Error) {
+func (pm *PortManager) PickEphemeralPort(rng *rand.Rand, testPort PortTester) (port uint16, err tcpip.Error) {
pm.ephemeralMu.RLock()
firstEphemeral := pm.firstEphemeral
numEphemeral := pm.numEphemeral
pm.ephemeralMu.RUnlock()
- offset := uint32(rand.Int31n(int32(numEphemeral)))
+ offset := uint32(rng.Int31n(int32(numEphemeral)))
return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort)
}
@@ -303,7 +303,7 @@ func pickEphemeralPort(offset uint32, first, count uint16, testPort PortTester)
// An optional PortTester can be passed in which if provided will be used to
// test if the picked port can be used. The function should return true if the
// port is safe to use, false otherwise.
-func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) {
+func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) {
pm.mu.Lock()
defer pm.mu.Unlock()
@@ -328,7 +328,7 @@ func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reserv
}
// A port wasn't specified, so try to find one.
- return pm.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) {
+ return pm.PickEphemeralPort(rng, func(p uint16) (bool, tcpip.Error) {
res.Port = p
if !pm.reserveSpecificPortLocked(res) {
return false, nil
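PickEphemeralPort and ReservePort now take the *rand.Rand from the caller instead of using the global math/rand source, so the stack owns its randomness and callers (and tests) can seed it deterministically. A sketch of calling the updated API (the always-accepting tester is only illustrative):

    package sketch

    import (
        "math/rand"
        "time"

        "gvisor.dev/gvisor/pkg/tcpip"
        "gvisor.dev/gvisor/pkg/tcpip/ports"
    )

    // pickPort demonstrates the caller-supplied RNG. In the real stack the
    // *rand.Rand comes from the stack itself; a seeded source is shown for clarity.
    func pickPort(pm *ports.PortManager) (uint16, tcpip.Error) {
        rng := rand.New(rand.NewSource(time.Now().UnixNano()))
        return pm.PickEphemeralPort(rng, func(port uint16) (bool, tcpip.Error) {
            // Accept the first candidate; a real tester would try to bind it.
            return true, nil
        })
    }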
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index 6c4fb8c68..a91b130df 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -18,6 +18,7 @@ import (
"math"
"math/rand"
"testing"
+ "time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -331,6 +332,7 @@ func TestPortReservation(t *testing.T) {
t.Run(test.tname, func(t *testing.T) {
pm := NewPortManager()
net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber}
+ rng := rand.New(rand.NewSource(time.Now().UnixNano()))
for _, test := range test.actions {
first, _ := pm.PortRange()
@@ -356,7 +358,7 @@ func TestPortReservation(t *testing.T) {
BindToDevice: test.device,
Dest: test.dest,
}
- gotPort, err := pm.ReservePort(portRes, nil /* testPort */)
+ gotPort, err := pm.ReservePort(rng, portRes, nil /* testPort */)
if diff := cmp.Diff(test.want, err); diff != "" {
t.Fatalf("unexpected error from ReservePort(%+v, _), (-want, +got):\n%s", portRes, diff)
}
@@ -417,10 +419,11 @@ func TestPickEphemeralPort(t *testing.T) {
} {
t.Run(test.name, func(t *testing.T) {
pm := NewPortManager()
+ rng := rand.New(rand.NewSource(time.Now().UnixNano()))
if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil {
t.Fatalf("failed to set ephemeral port range: %s", err)
}
- port, err := pm.PickEphemeralPort(test.f)
+ port, err := pm.PickEphemeralPort(rng, test.f)
if diff := cmp.Diff(test.wantErr, err); diff != "" {
t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff)
}
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index b26936b7f..0ea85f9ed 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -222,7 +222,7 @@ type SocketOptions struct {
getReceiveBufferLimits GetReceiveBufferLimits `state:"manual"`
// receiveBufferSize determines the receive buffer size for this socket.
- receiveBufferSize int64
+ receiveBufferSize atomicbitops.AlignedAtomicInt64
// mu protects the access to the below fields.
mu sync.Mutex `state:"nosave"`
@@ -601,9 +601,10 @@ func (so *SocketOptions) GetBindToDevice() int32 {
return atomic.LoadInt32(&so.bindToDevice)
}
-// SetBindToDevice sets value for SO_BINDTODEVICE option.
+// SetBindToDevice sets value for SO_BINDTODEVICE option. If bindToDevice is
+// zero, the socket device binding is removed.
func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error {
- if !so.handler.HasNIC(bindToDevice) {
+ if bindToDevice != 0 && !so.handler.HasNIC(bindToDevice) {
return &ErrUnknownDevice{}
}
@@ -653,13 +654,13 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) {
// GetReceiveBufferSize gets value for SO_RCVBUF option.
func (so *SocketOptions) GetReceiveBufferSize() int64 {
- return atomic.LoadInt64(&so.receiveBufferSize)
+ return so.receiveBufferSize.Load()
}
// SetReceiveBufferSize sets value for SO_RCVBUF option.
func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bool) {
if !notify {
- atomic.StoreInt64(&so.receiveBufferSize, receiveBufferSize)
+ so.receiveBufferSize.Store(receiveBufferSize)
return
}
@@ -684,8 +685,8 @@ func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bo
v = math.MaxInt32
}
- oldSz := atomic.LoadInt64(&so.receiveBufferSize)
+ oldSz := so.receiveBufferSize.Load()
// Notify endpoint about change in buffer size.
newSz := so.handler.OnSetReceiveBufferSize(v, oldSz)
- atomic.StoreInt64(&so.receiveBufferSize, newSz)
+ so.receiveBufferSize.Store(newSz)
}
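Two notes on this file: SetBindToDevice(0) now clears the device binding instead of failing with ErrUnknownDevice, matching Linux's handling of an empty SO_BINDTODEVICE name, and receiveBufferSize moves to an alignment-safe atomic wrapper so its 64-bit loads and stores are valid on 32-bit platforms. A tiny sketch of the clear-binding call (so is assumed to be an already-initialized *tcpip.SocketOptions):

    package sketch

    import "gvisor.dev/gvisor/pkg/tcpip"

    // clearDeviceBinding removes an earlier SO_BINDTODEVICE binding; with the
    // change above this no longer returns ErrUnknownDevice.
    func clearDeviceBinding(so *tcpip.SocketOptions) tcpip.Error {
        return so.SetBindToDevice(0)
    }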
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 84aa6a9e4..e0847e58a 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -56,6 +56,7 @@ go_library(
"neighbor_entry_list.go",
"neighborstate_string.go",
"nic.go",
+ "nic_stats.go",
"nud.go",
"packet_buffer.go",
"packet_buffer_list.go",
@@ -94,7 +95,7 @@ go_library(
go_test(
name = "stack_x_test",
- size = "medium",
+ size = "small",
srcs = [
"addressable_endpoint_state_test.go",
"ndp_test.go",
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 5720e7543..782e74b24 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -35,7 +35,6 @@ import (
// Currently, only TCP tracking is supported.
// Our hash table has 16K buckets.
-// TODO(gvisor.dev/issue/170): These should be tunable.
const numBuckets = 1 << 14
// Direction of the tuple.
@@ -125,6 +124,8 @@ type conn struct {
tcb tcpconntrack.TCB
// lastUsed is the last time the connection saw a relevant packet, and
// is updated by each packet on the connection. It is protected by mu.
+ //
+ // TODO(gvisor.dev/issue/5939): do not use the ambient clock.
lastUsed time.Time `state:".(unixTime)"`
}
@@ -163,8 +164,6 @@ func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
// Update the state of tcb. tcb assumes it's always initialized on the
// client. However, we only need to know whether the connection is
// established or not, so the client/server distinction isn't important.
- // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle
- // other tcp states.
if cn.tcb.IsEmpty() {
cn.tcb.Init(tcpHeader)
} else if hook == cn.tcbHook {
@@ -244,8 +243,7 @@ func (ct *ConnTrack) init() {
// connFor gets the conn for pkt if it exists, or returns nil
// if it does not. It returns an error when pkt does not contain a valid TCP
// header.
-// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support
-// other transport protocols.
+// TODO(gvisor.dev/issue/6168): Support UDP.
func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
tid, err := packetToTupleID(pkt)
if err != nil {
@@ -383,7 +381,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
return false
}
- // TODO(gvisor.dev/issue/170): Support other transport protocols.
+ // TODO(gvisor.dev/issue/6168): Support UDP.
if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
return false
}
@@ -407,16 +405,23 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
// validated if checksum offloading is off. It may require IP defrag if the
// packets are fragmented.
+ var newAddr tcpip.Address
+ var newPort uint16
+
+ updateSRCFields := false
+
switch hook {
case Prerouting, Output:
if conn.manip == manipDestination {
switch dir {
case dirOriginal:
- tcpHeader.SetDestinationPort(conn.reply.srcPort)
- netHeader.SetDestinationAddress(conn.reply.srcAddr)
+ newPort = conn.reply.srcPort
+ newAddr = conn.reply.srcAddr
case dirReply:
- tcpHeader.SetSourcePort(conn.original.dstPort)
- netHeader.SetSourceAddress(conn.original.dstAddr)
+ newPort = conn.original.dstPort
+ newAddr = conn.original.dstAddr
+
+ updateSRCFields = true
}
pkt.NatDone = true
}
@@ -424,11 +429,13 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
if conn.manip == manipSource {
switch dir {
case dirOriginal:
- tcpHeader.SetSourcePort(conn.reply.dstPort)
- netHeader.SetSourceAddress(conn.reply.dstAddr)
+ newPort = conn.reply.dstPort
+ newAddr = conn.reply.dstAddr
+
+ updateSRCFields = true
case dirReply:
- tcpHeader.SetDestinationPort(conn.original.srcPort)
- netHeader.SetDestinationAddress(conn.original.srcAddr)
+ newPort = conn.original.srcPort
+ newAddr = conn.original.srcAddr
}
pkt.NatDone = true
}
@@ -439,33 +446,33 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
return false
}
+ fullChecksum := false
+ updatePseudoHeader := false
switch hook {
case Prerouting, Input:
case Output, Postrouting:
// Calculate the TCP checksum and set it.
- tcpHeader.SetChecksum(0)
- length := uint16(len(tcpHeader) + pkt.Data().Size())
- xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
- tcpHeader.SetChecksum(xsum)
+ updatePseudoHeader = true
} else if r.RequiresTXTransportChecksum() {
- xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
- tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
+ fullChecksum = true
+ updatePseudoHeader = true
}
default:
panic(fmt.Sprintf("unrecognized hook = %s", hook))
}
- // After modification, IPv4 packets need a valid checksum.
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
+ rewritePacket(
+ netHeader,
+ tcpHeader,
+ updateSRCFields,
+ fullChecksum,
+ updatePseudoHeader,
+ newPort,
+ newAddr,
+ )
// Update the state of tcb.
- // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle
- // other tcp states.
conn.mu.Lock()
defer conn.mu.Unlock()
@@ -542,8 +549,6 @@ func (ct *ConnTrack) bucket(id tupleID) int {
// reapUnused returns the next bucket that should be checked and the time after
// which it should be called again.
func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) {
- // TODO(gvisor.dev/issue/170): This can be more finely controlled, as
- // it is in Linux via sysctl.
const fractionPerReaping = 128
const maxExpiredPct = 50
const maxFullTraversal = 60 * time.Second
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 7107d598d..72f66441f 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -114,10 +115,6 @@ func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen
}
-func (*fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
- return 0
-}
-
func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return f.proto.Number()
}
@@ -134,7 +131,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, params NetworkHeaderParam
}
// WritePackets implements LinkEndpoint.WritePackets.
-func (*fwdTestNetworkEndpoint) WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) {
+func (*fwdTestNetworkEndpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) {
panic("not implemented")
}
@@ -224,7 +221,7 @@ func (*fwdTestNetworkProtocol) Wait() {}
func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error {
if fn := f.proto.onLinkAddressResolved; fn != nil {
- time.AfterFunc(f.proto.addrResolveDelay, func() {
+ f.proto.stack.clock.AfterFunc(f.proto.addrResolveDelay, func() {
fn(f.proto.neigh, addr, remoteLinkAddr)
})
}
@@ -319,7 +316,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
-func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
+func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
p := fwdTestPacketInfo{
RemoteLinkAddress: r.RemoteLinkAddress,
LocalLinkAddress: r.LocalLinkAddress,
@@ -354,17 +351,19 @@ func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
}
// AddHeader implements stack.LinkEndpoint.AddHeader.
-func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+func (e *fwdTestLinkEndpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *PacketBuffer) {
panic("not implemented")
}
-func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) {
+func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.ManualClock, *fwdTestLinkEndpoint, *fwdTestLinkEndpoint) {
+ clock := faketime.NewManualClock()
// Create a stack with the network protocol and two NICs.
s := New(Options{
NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol {
proto.stack = s
return proto
}},
+ Clock: clock,
})
protoNum := proto.Number()
@@ -373,7 +372,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f
}
// NIC 1 has the link address "a", and added the network address 1.
- ep1 = &fwdTestLinkEndpoint{
+ ep1 := &fwdTestLinkEndpoint{
C: make(chan fwdTestPacketInfo, 300),
mtu: fwdTestNetDefaultMTU,
linkAddr: "a",
@@ -386,7 +385,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f
}
// NIC 2 has the link address "b", and added the network address 2.
- ep2 = &fwdTestLinkEndpoint{
+ ep2 := &fwdTestLinkEndpoint{
C: make(chan fwdTestPacketInfo, 300),
mtu: fwdTestNetDefaultMTU,
linkAddr: "b",
@@ -416,7 +415,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f
s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}})
}
- return ep1, ep2
+ return clock, ep1, ep2
}
func TestForwardingWithStaticResolver(t *testing.T) {
@@ -432,7 +431,7 @@ func TestForwardingWithStaticResolver(t *testing.T) {
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
+ clock, ep1, ep2 := fwdTestNetFactory(t, proto)
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
@@ -444,6 +443,7 @@ func TestForwardingWithStaticResolver(t *testing.T) {
var p fwdTestPacketInfo
+ clock.Advance(proto.addrResolveDelay)
select {
case p = <-ep2.C:
default:
@@ -475,7 +475,7 @@ func TestForwardingWithFakeResolver(t *testing.T) {
})
},
}
- ep1, ep2 := fwdTestNetFactory(t, &proto)
+ clock, ep1, ep2 := fwdTestNetFactory(t, &proto)
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
@@ -487,9 +487,10 @@ func TestForwardingWithFakeResolver(t *testing.T) {
var p fwdTestPacketInfo
+ clock.Advance(proto.addrResolveDelay)
select {
case p = <-ep2.C:
- case <-time.After(time.Second):
+ default:
t.Fatal("packet not forwarded")
}
@@ -508,7 +509,7 @@ func TestForwardingWithNoResolver(t *testing.T) {
// Whether or not we use the neighbor cache here does not matter since
// neither linkAddrCache nor neighborCache will be used.
- ep1, ep2 := fwdTestNetFactory(t, proto)
+ clock, ep1, ep2 := fwdTestNetFactory(t, proto)
// inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
@@ -518,10 +519,11 @@ func TestForwardingWithNoResolver(t *testing.T) {
Data: buf.ToVectorisedView(),
}))
+ clock.Advance(proto.addrResolveDelay)
select {
case <-ep2.C:
t.Fatal("Packet should not be forwarded")
- case <-time.After(time.Second):
+ default:
}
}
@@ -533,7 +535,7 @@ func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) {
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
+ clock, ep1, ep2 := fwdTestNetFactory(t, proto)
const numPackets int = 5
// These packets will all be enqueued in the packet queue to wait for link
@@ -547,12 +549,12 @@ func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) {
}
// All packets should fail resolution.
- // TODO(gvisor.dev/issue/5141): Use a fake clock.
for i := 0; i < numPackets; i++ {
+ clock.Advance(proto.addrResolveDelay)
select {
case got := <-ep2.C:
t.Fatalf("got %#v; packets should have failed resolution and not been forwarded", got)
- case <-time.After(100 * time.Millisecond):
+ default:
}
}
}
@@ -576,7 +578,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
}
},
}
- ep1, ep2 := fwdTestNetFactory(t, &proto)
+ clock, ep1, ep2 := fwdTestNetFactory(t, &proto)
// Inject an inbound packet to address 4 on NIC 1. This packet should
// not be forwarded.
@@ -596,9 +598,10 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
var p fwdTestPacketInfo
+ clock.Advance(proto.addrResolveDelay)
select {
case p = <-ep2.C:
- case <-time.After(time.Second):
+ default:
t.Fatal("packet not forwarded")
}
@@ -631,7 +634,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
})
},
}
- ep1, ep2 := fwdTestNetFactory(t, &proto)
+ clock, ep1, ep2 := fwdTestNetFactory(t, &proto)
// Inject two inbound packets to address 3 on NIC 1.
for i := 0; i < 2; i++ {
@@ -645,9 +648,10 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
for i := 0; i < 2; i++ {
var p fwdTestPacketInfo
+ clock.Advance(proto.addrResolveDelay)
select {
case p = <-ep2.C:
- case <-time.After(time.Second):
+ default:
t.Fatal("packet not forwarded")
}
@@ -681,7 +685,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
})
},
}
- ep1, ep2 := fwdTestNetFactory(t, &proto)
+ clock, ep1, ep2 := fwdTestNetFactory(t, &proto)
for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
// Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
@@ -697,9 +701,10 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
for i := 0; i < maxPendingPacketsPerResolution; i++ {
var p fwdTestPacketInfo
+ clock.Advance(proto.addrResolveDelay)
select {
case p = <-ep2.C:
- case <-time.After(time.Second):
+ default:
t.Fatal("packet not forwarded")
}
@@ -745,7 +750,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
})
},
}
- ep1, ep2 := fwdTestNetFactory(t, &proto)
+ clock, ep1, ep2 := fwdTestNetFactory(t, &proto)
for i := 0; i < maxPendingResolutions+5; i++ {
// Inject inbound 'maxPendingResolutions + 5' packets on NIC 1.
@@ -761,9 +766,10 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
for i := 0; i < maxPendingResolutions; i++ {
var p fwdTestPacketInfo
+ clock.Advance(proto.addrResolveDelay)
select {
case p = <-ep2.C:
- case <-time.After(time.Second):
+ default:
t.Fatal("packet not forwarded")
}
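The forwarding_test.go changes above replace wall-clock waits (time.After) with a faketime.ManualClock injected through stack.Options: each test advances virtual time by the resolution delay and then does a non-blocking select, which is presumably why the earlier BUILD change can shrink stack_x_test from "medium" to "small". The toy clock below is a hypothetical stand-in, not the faketime package; it only mimics the two calls the tests rely on (AfterFunc and Advance) to show the advance-then-poll pattern.

// Illustrative sketch only: a toy manual clock demonstrating the
// deterministic advance-then-poll style used by the tests above.
package main

import (
	"fmt"
	"time"
)

// timer pairs an absolute (virtual) firing time with a callback.
type timer struct {
	when time.Duration
	fn   func()
}

// manualClock only moves when Advance is called, so the test controls time.
type manualClock struct {
	now    time.Duration
	timers []timer
}

// AfterFunc schedules fn to run once d of virtual time has elapsed.
func (c *manualClock) AfterFunc(d time.Duration, fn func()) {
	c.timers = append(c.timers, timer{when: c.now + d, fn: fn})
}

// Advance moves virtual time forward and fires any timers now due.
func (c *manualClock) Advance(d time.Duration) {
	c.now += d
	remaining := c.timers[:0]
	for _, t := range c.timers {
		if t.when <= c.now {
			t.fn()
		} else {
			remaining = append(remaining, t)
		}
	}
	c.timers = remaining
}

func main() {
	var clock manualClock
	pkts := make(chan string, 1)

	// Delivery is gated behind a link-address resolution delay, much like
	// fwdTestNetworkEndpoint.LinkAddressRequest above.
	const resolveDelay = 500 * time.Millisecond
	clock.AfterFunc(resolveDelay, func() { pkts <- "forwarded" })

	// Deterministic wait: advance virtual time, then poll without blocking.
	clock.Advance(resolveDelay)
	select {
	case p := <-pkts:
		fmt.Println(p)
	default:
		fmt.Println("packet not forwarded")
	}
}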
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 3670d5995..f152c0d83 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -42,7 +42,7 @@ const reaperDelay = 5 * time.Second
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
-func DefaultTables() *IPTables {
+func DefaultTables(seed uint32) *IPTables {
return &IPTables{
v4Tables: [NumTables]Table{
NATID: {
@@ -182,7 +182,7 @@ func DefaultTables() *IPTables {
Postrouting: {MangleID, NATID},
},
connections: ConnTrack{
- seed: generateRandUint32(),
+ seed: seed,
},
reaperDone: make(chan struct{}, 1),
}
@@ -268,10 +268,6 @@ const (
// should continue traversing the network stack and false when it should be
// dropped.
//
-// TODO(gvisor.dev/issue/170): PacketBuffer should hold the route, from
-// which address can be gathered. Currently, address is only needed for
-// prerouting.
-//
// Precondition: pkt.NetworkHeader is set.
func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool {
if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
@@ -371,6 +367,7 @@ func (it *IPTables) startReaper(interval time.Duration) {
select {
case <-it.reaperDone:
return
+ // TODO(gvisor.dev/issue/5939): do not use the ambient clock.
case <-time.After(interval):
bucket, interval = it.connections.reapUnused(bucket, interval)
}
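The iptables.go hunk above makes DefaultTables take the conntrack hash seed as a parameter instead of generating one internally, so the caller decides where the randomness comes from and tests can pass a fixed seed for reproducible bucket placement. The snippet below is a toy illustration of that idea, not the gvisor bucket function; the tupleID shape and the FNV-based mixing are invented for the example.

// Toy sketch (invented names): bucket selection that mixes a
// caller-supplied seed into the tuple hash, so a fixed seed gives
// reproducible placement while production callers can pass a random one.
package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

const numBuckets = 1 << 14

type tupleID struct {
	srcAddr, dstAddr string
	srcPort, dstPort uint16
}

func bucket(seed uint32, id tupleID) int {
	h := fnv.New32a()
	var b [4]byte
	binary.LittleEndian.PutUint32(b[:], seed)
	h.Write(b[:])
	fmt.Fprintf(h, "%s|%d|%s|%d", id.srcAddr, id.srcPort, id.dstAddr, id.dstPort)
	return int(h.Sum32() % numBuckets)
}

func main() {
	id := tupleID{srcAddr: "192.0.2.1", srcPort: 12345, dstAddr: "198.51.100.7", dstPort: 80}
	fmt.Println(bucket(42, id) == bucket(42, id)) // true: same seed, same bucket
	fmt.Println(bucket(42, id), bucket(7, id))    // different seeds may differ
}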
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 2812c89aa..96cc899bb 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -87,9 +87,6 @@ func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Addre
// destination port/IP. Outgoing packets are redirected to the loopback device,
// and incoming packets are redirected to the incoming interface (rather than
// forwarded).
-//
-// TODO(gvisor.dev/issue/170): Other flags need to be added after we support
-// them.
type RedirectTarget struct {
// Port indicates port used to redirect. It is immutable.
Port uint16
@@ -100,9 +97,6 @@ type RedirectTarget struct {
}
// Action implements Target.Action.
-// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
-// implementation only works for Prerouting and calls pkt.Clone(), neither
-// of which should be the case.
func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) {
// Sanity check.
if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
@@ -136,34 +130,26 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
panic("redirect target is supported only on output and prerouting hooks")
}
- // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if
- // we need to change dest address (for OUTPUT chain) or ports.
switch protocol := pkt.TransportProtocolNumber; protocol {
case header.UDPProtocolNumber:
udpHeader := header.UDP(pkt.TransportHeader().View())
- udpHeader.SetDestinationPort(rt.Port)
- // Calculate UDP checksum and set it.
if hook == Output {
- udpHeader.SetChecksum(0)
- netHeader := pkt.Network()
- netHeader.SetDestinationAddress(address)
-
// Only calculate the checksum if offloading isn't supported.
- if r.RequiresTXTransportChecksum() {
- length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
- xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
- xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
- udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
- }
+ requiresChecksum := r.RequiresTXTransportChecksum()
+ rewritePacket(
+ pkt.Network(),
+ udpHeader,
+ false, /* updateSRCFields */
+ requiresChecksum,
+ requiresChecksum,
+ rt.Port,
+ address,
+ )
+ } else {
+ udpHeader.SetDestinationPort(rt.Port)
}
- // After modification, IPv4 packets need a valid checksum.
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
pkt.NatDone = true
case header.TCPProtocolNumber:
if ct == nil {
@@ -222,26 +208,18 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
switch protocol := pkt.TransportProtocolNumber; protocol {
case header.UDPProtocolNumber:
- udpHeader := header.UDP(pkt.TransportHeader().View())
- udpHeader.SetChecksum(0)
- udpHeader.SetSourcePort(st.Port)
- netHeader := pkt.Network()
- netHeader.SetSourceAddress(st.Addr)
-
// Only calculate the checksum if offloading isn't supported.
- if r.RequiresTXTransportChecksum() {
- length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
- xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
- xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
- udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
- }
+ requiresChecksum := r.RequiresTXTransportChecksum()
+ rewritePacket(
+ pkt.Network(),
+ header.UDP(pkt.TransportHeader().View()),
+ true, /* updateSRCFields */
+ requiresChecksum,
+ requiresChecksum,
+ st.Port,
+ st.Addr,
+ )
- // After modification, IPv4 packets need a valid checksum.
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
pkt.NatDone = true
case header.TCPProtocolNumber:
if ct == nil {
@@ -260,3 +238,42 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
return RuleAccept, 0
}
+
+func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) {
+ if updateSRCFields {
+ if fullChecksum {
+ t.SetSourcePortWithChecksumUpdate(newPort)
+ } else {
+ t.SetSourcePort(newPort)
+ }
+ } else {
+ if fullChecksum {
+ t.SetDestinationPortWithChecksumUpdate(newPort)
+ } else {
+ t.SetDestinationPort(newPort)
+ }
+ }
+
+ if updatePseudoHeader {
+ var oldAddr tcpip.Address
+ if updateSRCFields {
+ oldAddr = n.SourceAddress()
+ } else {
+ oldAddr = n.DestinationAddress()
+ }
+
+ t.UpdateChecksumPseudoHeaderAddress(oldAddr, newAddr, fullChecksum)
+ }
+
+ if checksummableNetHeader, ok := n.(header.ChecksummableNetwork); ok {
+ if updateSRCFields {
+ checksummableNetHeader.SetSourceAddressWithChecksumUpdate(newAddr)
+ } else {
+ checksummableNetHeader.SetDestinationAddressWithChecksumUpdate(newAddr)
+ }
+ } else if updateSRCFields {
+ n.SetSourceAddress(newAddr)
+ } else {
+ n.SetDestinationAddress(newAddr)
+ }
+}
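The new rewritePacket helper centralizes the NAT rewrite: it sets the source or destination port (using the incremental checksum-updating setter when a full checksum is required), folds the address change into the transport pseudo-header checksum, and then rewrites the network-layer address through the ChecksummableNetwork interface when available. The ...WithChecksumUpdate helpers live in the header package; the sketch below only illustrates the incremental-update arithmetic (RFC 1624, Eqn. 3) that such helpers are typically built on, with invented names and an arbitrary word list rather than a real packet.

// Minimal sketch, assuming nothing beyond the standard library: update an
// Internet checksum incrementally when one 16-bit word changes, and verify
// the result against a full recomputation.
package main

import "fmt"

// onesComplementAdd adds two 16-bit words with end-around carry.
func onesComplementAdd(a, b uint16) uint16 {
	sum := uint32(a) + uint32(b)
	for sum > 0xffff {
		sum = (sum & 0xffff) + (sum >> 16)
	}
	return uint16(sum)
}

// fullChecksum computes the Internet checksum over a slice of 16-bit words.
func fullChecksum(words []uint16) uint16 {
	var sum uint16
	for _, w := range words {
		sum = onesComplementAdd(sum, w)
	}
	return ^sum
}

// incrementalUpdate applies RFC 1624 (Eqn. 3): given an existing checksum and
// one word changing from oldVal to newVal, return the updated checksum
// without re-summing the whole packet.
func incrementalUpdate(checksum, oldVal, newVal uint16) uint16 {
	return ^onesComplementAdd(onesComplementAdd(^checksum, ^oldVal), newVal)
}

func main() {
	words := []uint16{0x4500, 0x0073, 0x0000, 0x4000, 0x4011, 0xc0a8, 0x0001, 0xc0a8, 0x00c7}
	before := fullChecksum(words)

	// Rewrite one word (think: a NATed port) and update incrementally.
	const oldVal, newVal = uint16(0x0073), uint16(0x1f90)
	words[1] = newVal
	incremental := incrementalUpdate(before, oldVal, newVal)

	fmt.Printf("full=%#04x incremental=%#04x match=%t\n",
		fullChecksum(words), incremental, fullChecksum(words) == incremental)
}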
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 93592e7f5..66e5f22ac 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -242,7 +242,6 @@ type IPHeaderFilter struct {
func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool {
// Extract header fields.
var (
- // TODO(gvisor.dev/issue/170): Support other filter fields.
transProto tcpip.TransportProtocolNumber
dstAddr tcpip.Address
srcAddr tcpip.Address
@@ -291,7 +290,6 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicNa
return true
case Postrouting:
- // TODO(gvisor.dev/issue/170): Add the check for POSTROUTING.
return true
default:
panic(fmt.Sprintf("unknown hook: %d", hook))
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index ac2fa777e..9623d9c28 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -16,14 +16,14 @@ package stack_test
import (
"bytes"
- "context"
"encoding/binary"
"fmt"
+ "math/rand"
"testing"
"time"
"github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/rand"
+ cryptorand "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -52,17 +52,6 @@ const (
linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09")
defaultPrefixLen = 128
-
- // Extra time to use when waiting for an async event to occur.
- defaultAsyncPositiveEventTimeout = 10 * time.Second
-
- // Extra time to use when waiting for an async event to not occur.
- //
- // Since a negative check is used to make sure an event did not happen, it is
- // okay to use a smaller timeout compared to the positive case since execution
- // stall in regards to the monotonic clock will not affect the expected
- // outcome.
- defaultAsyncNegativeEventTimeout = time.Second
)
var (
@@ -112,11 +101,13 @@ type ndpDADEvent struct {
res stack.DADResult
}
-type ndpRouterEvent struct {
- nicID tcpip.NICID
- addr tcpip.Address
- // true if router was discovered, false if invalidated.
- discovered bool
+type ndpOffLinkRouteEvent struct {
+ nicID tcpip.NICID
+ subnet tcpip.Subnet
+ router tcpip.Address
+ prf header.NDPRoutePreference
+ // true if route was updated, false if invalidated.
+ updated bool
}
type ndpPrefixEvent struct {
@@ -140,6 +131,10 @@ type ndpAutoGenAddrEvent struct {
eventType ndpAutoGenAddrEventType
}
+func (e ndpAutoGenAddrEvent) String() string {
+ return fmt.Sprintf("%T{nicID=%d addr=%s eventType=%d}", e, e.nicID, e.addr, e.eventType)
+}
+
type ndpRDNSS struct {
addrs []tcpip.Address
lifetime time.Duration
@@ -167,10 +162,8 @@ var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil)
// related events happen for test purposes.
type ndpDispatcher struct {
dadC chan ndpDADEvent
- routerC chan ndpRouterEvent
- rememberRouter bool
+ offLinkRouteC chan ndpOffLinkRouteEvent
prefixC chan ndpPrefixEvent
- rememberPrefix bool
autoGenAddrC chan ndpAutoGenAddrEvent
rdnssC chan ndpRDNSSEvent
dnsslC chan ndpDNSSLEvent
@@ -189,32 +182,35 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionResult(nicID tcpip.NICID, add
}
}
-// Implements ipv6.NDPDispatcher.OnDefaultRouterDiscovered.
-func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool {
- if c := n.routerC; c != nil {
- c <- ndpRouterEvent{
+// Implements ipv6.NDPDispatcher.OnOffLinkRouteUpdated.
+func (n *ndpDispatcher) OnOffLinkRouteUpdated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address, prf header.NDPRoutePreference) {
+ if c := n.offLinkRouteC; c != nil {
+ c <- ndpOffLinkRouteEvent{
nicID,
- addr,
+ subnet,
+ router,
+ prf,
true,
}
}
-
- return n.rememberRouter
}
-// Implements ipv6.NDPDispatcher.OnDefaultRouterInvalidated.
-func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) {
- if c := n.routerC; c != nil {
- c <- ndpRouterEvent{
+// Implements ipv6.NDPDispatcher.OnOffLinkRouteInvalidated.
+func (n *ndpDispatcher) OnOffLinkRouteInvalidated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address) {
+ if c := n.offLinkRouteC; c != nil {
+ var prf header.NDPRoutePreference
+ c <- ndpOffLinkRouteEvent{
nicID,
- addr,
+ subnet,
+ router,
+ prf,
false,
}
}
}
// Implements ipv6.NDPDispatcher.OnOnLinkPrefixDiscovered.
-func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool {
+func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) {
if c := n.prefixC; c != nil {
c <- ndpPrefixEvent{
nicID,
@@ -222,8 +218,6 @@ func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip
true,
}
}
-
- return n.rememberPrefix
}
// Implements ipv6.NDPDispatcher.OnOnLinkPrefixInvalidated.
@@ -237,7 +231,7 @@ func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpi
}
}
-func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool {
+func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
if c := n.autoGenAddrC; c != nil {
c <- ndpAutoGenAddrEvent{
nicID,
@@ -245,7 +239,6 @@ func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWi
newAddr,
}
}
- return true
}
func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
@@ -497,8 +490,9 @@ func TestDADResolve(t *testing.T) {
clock := faketime.NewManualClock()
s := stack.New(stack.Options{
- Clock: clock,
- SecureRNG: &secureRNG,
+ Clock: clock,
+ RandSource: rand.NewSource(time.Now().UnixNano()),
+ SecureRNG: &secureRNG,
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPDisp: &ndpDisp,
DADConfigs: stack.DADConfigurations{
@@ -605,7 +599,10 @@ func TestDADResolve(t *testing.T) {
// Validate the sent Neighbor Solicitation messages.
for i := uint8(0); i < test.dupAddrDetectTransmits; i++ {
- p, _ := e.ReadContext(context.Background())
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("packet didn't arrive")
+ }
// Make sure its an IPv6 packet.
if p.Proto != header.IPv6ProtocolNumber {
@@ -731,11 +728,13 @@ func TestDADFail(t *testing.T) {
dadConfigs.RetransmitTimer = time.Second * 2
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPDisp: &ndpDisp,
DADConfigs: dadConfigs,
})},
+ Clock: clock,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -761,16 +760,17 @@ func TestDADFail(t *testing.T) {
// Wait for DAD to fail and make sure the address did
// not get resolved.
+ clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer)
select {
- case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second):
- // If we don't get a failure event after the
- // expected resolution time + extra 1s buffer,
- // something is wrong.
- t.Fatal("timed out waiting for DAD failure")
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr1, &stack.DADDupAddrDetected{HolderLinkAddress: test.expectedHolderLinkAddress}); diff != "" {
t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
+ default:
+ // If we don't get a failure event after the
+ // expected resolution time + extra 1s buffer,
+ // something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
}
if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
t.Fatal(err)
@@ -839,11 +839,13 @@ func TestDADStop(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPDisp: &ndpDisp,
DADConfigs: dadConfigs,
})},
+ Clock: clock,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
@@ -861,15 +863,16 @@ func TestDADStop(t *testing.T) {
test.stopFn(t, s)
// Wait for DAD to fail (since the address was removed during DAD).
+ clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer)
select {
- case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second):
- // If we don't get a failure event after the expected resolution
- // time + extra 1s buffer, something is wrong.
- t.Fatal("timed out waiting for DAD failure")
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr1, &stack.DADAborted{}); diff != "" {
t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
+ default:
+ // If we don't get a failure event after the expected resolution
+ // time + extra 1s buffer, something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
}
if !test.skipFinalAddrCheck {
@@ -920,10 +923,12 @@ func TestSetNDPConfigurations(t *testing.T) {
dadC: make(chan ndpDADEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
})
expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) {
@@ -1002,28 +1007,23 @@ func TestSetNDPConfigurations(t *testing.T) {
t.Fatal(err)
}
- // Sleep until right (500ms before) before resolution to
- // make sure the address didn't resolve on NIC(1) yet.
- const delta = 500 * time.Millisecond
- time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta)
+ // Sleep until right before resolution to make sure the address didn't
+ // resolve on NIC(1) yet.
+ const delta = 1
+ clock.Advance(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta)
if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
t.Fatal(err)
}
// Wait for DAD to resolve.
+ clock.Advance(delta)
select {
- case <-time.After(2 * delta):
- // We should get a resolution event after 500ms
- // (delta) since we wait for 500ms less than the
- // expected resolution time above to make sure
- // that the address did not yet resolve. Waiting
- // for 1s (2x delta) without a resolution event
- // means something is wrong.
- t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID1, addr1, &stack.DADSucceeded{}); diff != "" {
t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
+ default:
+ t.Fatal("timed out waiting for DAD resolution")
}
if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil {
t.Fatal(err)
@@ -1032,10 +1032,13 @@ func TestSetNDPConfigurations(t *testing.T) {
}
}
-// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options
-// and DHCPv6 configurations specified.
-func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
- icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length())
+// raBuf returns a valid NDP Router Advertisement with options, router
+// preference and DHCPv6 configurations specified.
+func raBuf(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, prf header.NDPRoutePreference, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
+ const flagsByte = 1
+ const routerLifetimeOffset = 2
+
+ icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + optSer.Length()
hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
pkt.SetType(header.ICMPv6RouterAdvert)
@@ -1043,19 +1046,19 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
raPayload := pkt.MessageBody()
ra := header.NDPRouterAdvert(raPayload)
// Populate the Router Lifetime.
- binary.BigEndian.PutUint16(raPayload[2:], rl)
+ binary.BigEndian.PutUint16(raPayload[routerLifetimeOffset:], rl)
// Populate the Managed Address flag field.
if managedAddress {
- // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing)
- // of the RA payload.
- raPayload[1] |= (1 << 7)
+ // The Managed Addresses flag field is the 7th bit of the flags byte.
+ raPayload[flagsByte] |= 1 << 7
}
// Populate the Other Configurations flag field.
if otherConfigurations {
- // The Other Configurations flag field is the 6th bit of byte #1
- // (0-indexing) of the RA payload.
- raPayload[1] |= (1 << 6)
+ // The Other Configurations flag field is the 6th bit of the flags byte.
+ raPayload[flagsByte] |= 1 << 6
}
+ // The Prf field is held in the flags byte.
+ raPayload[flagsByte] |= byte(prf) << 3
opts := ra.Options()
opts.Serialize(optSer)
pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
@@ -1083,7 +1086,7 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
// Note, raBufWithOpts does not populate any of the RA fields other than the
// Router Lifetime.
func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
- return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer)
+ return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, 0 /* prf */, optSer)
}
// raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related
@@ -1091,18 +1094,26 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ
//
// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the
// DHCPv6 related ones.
-func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) *stack.PacketBuffer {
- return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{})
+func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfigurations bool) *stack.PacketBuffer {
+ return raBuf(ip, 0, managedAddresses, otherConfigurations, 0 /* prf */, header.NDPOptionsSerializer{})
}
// raBuf returns a valid NDP Router Advertisement.
//
// Note, raBuf does not populate any of the RA fields other than the
// Router Lifetime.
-func raBuf(ip tcpip.Address, rl uint16) *stack.PacketBuffer {
+func raBufSimple(ip tcpip.Address, rl uint16) *stack.PacketBuffer {
return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{})
}
+// raBufWithPrf returns a valid NDP Router Advertisement with a preference.
+//
+// Note, raBufWithPrf does not populate any of the RA fields other than the
+// Router Lifetime and Default Router Preference fields.
+func raBufWithPrf(ip tcpip.Address, rl uint16, prf header.NDPRoutePreference) *stack.PacketBuffer {
+ return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, prf, header.NDPOptionsSerializer{})
+}
+
// raBufWithPI returns a valid NDP Router Advertisement with a single Prefix
// Information option.
//
@@ -1162,7 +1173,7 @@ func TestDynamicConfigurationsDisabled(t *testing.T) {
config: func(enable bool) ipv6.NDPConfigurations {
return ipv6.NDPConfigurations{DiscoverDefaultRouters: enable}
},
- ra: raBuf(llAddr2, 1000),
+ ra: raBufSimple(llAddr2, 1000),
},
{
name: "No Prefix Discovery",
@@ -1198,9 +1209,9 @@ func TestDynamicConfigurationsDisabled(t *testing.T) {
t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) {
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- prefixC: make(chan ndpPrefixEvent, 1),
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
+ prefixC: make(chan ndpPrefixEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
ndpConfigs := test.config(enable)
ndpConfigs.HandleRAs = handle
@@ -1270,8 +1281,8 @@ func TestDynamicConfigurationsDisabled(t *testing.T) {
t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want)
}
select {
- case e := <-ndpDisp.routerC:
- t.Errorf("unexpectedly discovered a router when configured not to: %#v", e)
+ case e := <-ndpDisp.offLinkRouteC:
+ t.Errorf("unexpectedly updated an off-link route when configured not to: %#v", e)
default:
}
select {
@@ -1297,10 +1308,8 @@ func boolToUint64(v bool) uint64 {
return 0
}
-// Check e to make sure that the event is for addr on nic with ID 1, and the
-// discovered flag set to discovered.
-func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string {
- return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e))
+func checkOffLinkRouteEvent(e ndpOffLinkRouteEvent, nicID tcpip.NICID, router tcpip.Address, prf header.NDPRoutePreference, updated bool) string {
+ return cmp.Diff(ndpOffLinkRouteEvent{nicID: nicID, subnet: header.IPv6EmptySubnet, router: router, prf: prf, updated: updated}, e, cmp.AllowUnexported(e))
}
func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) {
@@ -1333,56 +1342,15 @@ func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, b
}
}
-// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not
-// remember a discovered router when the dispatcher asks it not to.
-func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Receive an RA for a router we should not remember.
- const lifetimeSeconds = 1
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds))
- select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, llAddr2, true); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected router discovery event")
- }
-
- // Wait for the invalidation time plus some buffer to make sure we do
- // not actually receive any invalidation events as we should not have
- // remembered the router in the first place.
- select {
- case <-ndpDisp.routerC:
- t.Fatal("should not have received any router events")
- case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
- }
-}
-
func TestRouterDiscovery(t *testing.T) {
+ const nicID = 1
+
testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- rememberRouter: true,
+ offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -1391,30 +1359,33 @@ func TestRouterDiscovery(t *testing.T) {
},
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
})
- expectRouterEvent := func(addr tcpip.Address, discovered bool) {
+ expectOffLinkRouteEvent := func(addr tcpip.Address, prf header.NDPRoutePreference, updated bool) {
t.Helper()
select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, addr, discovered); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ case e := <-ndpDisp.offLinkRouteC:
+ if diff := checkOffLinkRouteEvent(e, nicID, addr, prf, updated); diff != "" {
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
}
default:
t.Fatal("expected router discovery event")
}
}
- expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
+ expectAsyncOffLinkRouteInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
t.Helper()
+ clock.Advance(timeout)
select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, addr, false); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ case e := <-ndpDisp.offLinkRouteC:
+ var prf header.NDPRoutePreference
+ if diff := checkOffLinkRouteEvent(e, nicID, addr, prf, false); diff != "" {
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(timeout):
+ default:
t.Fatal("timed out waiting for router discovery event")
}
}
@@ -1423,37 +1394,44 @@ func TestRouterDiscovery(t *testing.T) {
t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
}
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
// Rx an RA from lladdr2 with zero lifetime. It should not be
// remembered.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 0))
select {
- case <-ndpDisp.routerC:
- t.Fatal("unexpectedly discovered a router with 0 lifetime")
+ case <-ndpDisp.offLinkRouteC:
+ t.Fatal("unexpectedly updated an off-link route with 0 lifetime")
default:
}
- // Rx an RA from lladdr2 with a huge lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- expectRouterEvent(llAddr2, true)
+ // Rx an RA from lladdr2 with a huge lifetime and reserved preference value
+ // (which should be interpreted as the default (medium) preference value).
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr2, 1000, header.ReservedRoutePreference))
+ expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true)
- // Rx an RA from another router (lladdr3) with non-zero lifetime.
+ // Rx an RA from another router (lladdr3) with non-zero lifetime and
+ // non-default preference value.
const l3LifetimeSeconds = 6
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds))
- expectRouterEvent(llAddr3, true)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr3, l3LifetimeSeconds, header.HighRoutePreference))
+ expectOffLinkRouteEvent(llAddr3, header.HighRoutePreference, true)
- // Rx an RA from lladdr2 with lesser lifetime.
+ // Rx an RA from lladdr2 with lesser lifetime and default (medium)
+ // preference value.
const l2LifetimeSeconds = 2
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, l2LifetimeSeconds))
select {
- case <-ndpDisp.routerC:
- t.Fatal("Should not receive a router event when updating lifetimes for known routers")
+ case <-ndpDisp.offLinkRouteC:
+ t.Fatal("should not receive a off-link route event when updating lifetimes for known routers")
default:
}
+ // Rx an RA from lladdr2 with a different preference.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr2, l2LifetimeSeconds, header.LowRoutePreference))
+ expectOffLinkRouteEvent(llAddr2, header.LowRoutePreference, true)
+
// Wait for lladdr2's router invalidation job to execute. The lifetime
// of the router should have been updated to the most recent (smaller)
// lifetime.
@@ -1461,15 +1439,15 @@ func TestRouterDiscovery(t *testing.T) {
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
+ expectAsyncOffLinkRouteInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second)
// Rx an RA from lladdr2 with huge lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- expectRouterEvent(llAddr2, true)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 1000))
+ expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true)
// Rx an RA from lladdr2 with zero lifetime. It should be invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
- expectRouterEvent(llAddr2, false)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 0))
+ expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, false)
// Wait for lladdr3's router invalidation job to execute. The lifetime
// of the router should have been updated to the most recent (smaller)
@@ -1478,16 +1456,17 @@ func TestRouterDiscovery(t *testing.T) {
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
+ expectAsyncOffLinkRouteInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second)
})
}
// TestRouterDiscoveryMaxRouters tests that only
-// ipv6.MaxDiscoveredDefaultRouters discovered routers are remembered.
+// ipv6.MaxDiscoveredOffLinkRoutes discovered routers are remembered.
func TestRouterDiscoveryMaxRouters(t *testing.T) {
+ const nicID = 1
+
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- rememberRouter: true,
+ offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
@@ -1500,23 +1479,23 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
})},
})
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
// Receive an RA from 2 more than the max number of discovered routers.
- for i := 1; i <= ipv6.MaxDiscoveredDefaultRouters+2; i++ {
+ for i := 1; i <= ipv6.MaxDiscoveredOffLinkRoutes+2; i++ {
linkAddr := []byte{2, 2, 3, 4, 5, 0}
linkAddr[5] = byte(i)
llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr))
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr, 5))
- if i <= ipv6.MaxDiscoveredDefaultRouters {
+ if i <= ipv6.MaxDiscoveredOffLinkRoutes {
select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, llAddr, true); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ case e := <-ndpDisp.offLinkRouteC:
+ if diff := checkOffLinkRouteEvent(e, nicID, llAddr, header.MediumRoutePreference, true); diff != "" {
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
}
default:
t.Fatal("expected router discovery event")
@@ -1524,7 +1503,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
} else {
select {
- case <-ndpDisp.routerC:
+ case <-ndpDisp.offLinkRouteC:
t.Fatal("should not have discovered a new router after we already discovered the max number of routers")
default:
}
@@ -1538,51 +1517,6 @@ func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) st
return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e))
}
-// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not
-// remember a discovered on-link prefix when the dispatcher asks it not to.
-func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
- prefix, subnet, _ := prefixSubnetAddr(0, "")
-
- ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Receive an RA with prefix that we should not remember.
- const lifetimeSeconds = 1
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0))
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, subnet, true); diff != "" {
- t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected prefix discovery event")
- }
-
- // Wait for the invalidation time plus some buffer to make sure we do
- // not actually receive any invalidation events as we should not have
- // remembered the prefix in the first place.
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("should not have received any prefix events")
- case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
- }
-}
-
func TestPrefixDiscovery(t *testing.T) {
prefix1, subnet1, _ := prefixSubnetAddr(0, "")
prefix2, subnet2, _ := prefixSubnetAddr(1, "")
@@ -1590,10 +1524,10 @@ func TestPrefixDiscovery(t *testing.T) {
testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- rememberPrefix: true,
+ prefixC: make(chan ndpPrefixEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -1602,6 +1536,7 @@ func TestPrefixDiscovery(t *testing.T) {
},
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1662,12 +1597,13 @@ func TestPrefixDiscovery(t *testing.T) {
// Wait for prefix2's most recent invalidation job plus some buffer to
// expire.
+ clock.Advance(time.Duration(lifetime) * time.Second)
select {
case e := <-ndpDisp.prefixC:
if diff := checkPrefixEvent(e, subnet2, false); diff != "" {
t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for prefix discovery event")
}
@@ -1678,17 +1614,6 @@ func TestPrefixDiscovery(t *testing.T) {
}
func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
- // Update the infinite lifetime value to a smaller value so we can test
- // that when we receive a PI with such a lifetime value, we do not
- // invalidate the prefix.
- const testInfiniteLifetimeSeconds = 2
- const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second
- saved := header.NDPInfiniteLifetime
- header.NDPInfiniteLifetime = testInfiniteLifetime
- defer func() {
- header.NDPInfiniteLifetime = saved
- }()
-
prefix := tcpip.AddressWithPrefix{
Address: testutil.MustParse6("102:304:506:708::"),
PrefixLen: 64,
@@ -1696,10 +1621,10 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
subnet := prefix.Subnet()
ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- rememberPrefix: true,
+ prefixC: make(chan ndpPrefixEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -1708,6 +1633,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
},
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1729,46 +1655,39 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
// Receive an RA with prefix in an NDP Prefix Information option (PI)
// with infinite valid lifetime which should not get invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds, 0))
expectPrefixEvent(subnet, true)
+ clock.Advance(header.NDPInfiniteLifetime)
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout):
+ default:
}
// Receive an RA with finite lifetime.
- // The prefix should get invalidated after 1s.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds-1, 0))
+ clock.Advance(header.NDPInfiniteLifetime - time.Second)
select {
case e := <-ndpDisp.prefixC:
if diff := checkPrefixEvent(e, subnet, false); diff != "" {
t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(testInfiniteLifetime):
+ default:
t.Fatal("timed out waiting for prefix discovery event")
}
// Receive an RA with finite lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds-1, 0))
expectPrefixEvent(subnet, true)
// Receive an RA with prefix with an infinite lifetime.
// The prefix should not be invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0))
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout):
- }
-
- // Receive an RA with a prefix with a lifetime value greater than the
- // set infinite lifetime value.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds, 0))
+ clock.Advance(header.NDPInfiniteLifetime)
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultAsyncNegativeEventTimeout):
+ default:
}
// Receive an RA with 0 lifetime.
@@ -1781,8 +1700,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
// ipv6.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered.
func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3),
- rememberPrefix: true,
+ prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3),
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
@@ -1859,17 +1777,12 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix,
return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e))
}
+const minVLSeconds = uint32(ipv6.MinPrefixInformationValidLifetimeForUpdate / time.Second)
+const infiniteLifetimeSeconds = uint32(header.NDPInfiniteLifetime / time.Second)
+
// TestAutoGenAddr tests that an address is properly generated and invalidated
// when configured to do so.
func TestAutoGenAddr(t *testing.T) {
- const newMinVL = 2
- newMinVLDuration := newMinVL * time.Second
- saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
- defer func() {
- ipv6.MinPrefixInformationValidLifetimeForUpdate = saved
- }()
- ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
-
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
@@ -1878,6 +1791,7 @@ func TestAutoGenAddr(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -1886,6 +1800,7 @@ func TestAutoGenAddr(t *testing.T) {
},
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
})
if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil {
@@ -1935,8 +1850,9 @@ func TestAutoGenAddr(t *testing.T) {
default:
}
- // Receive an RA with prefix2 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ // Receive an RA with prefix2 in a PI with a valid lifetime that exceeds
+ // the minimum.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, minVLSeconds+1, 0))
expectAutoGenAddrEvent(addr2, newAddr)
if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
t.Fatalf("Should have %s in the list of addresses", addr1)
@@ -1946,7 +1862,7 @@ func TestAutoGenAddr(t *testing.T) {
}
// Refresh valid lifetime for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix")
@@ -1954,12 +1870,13 @@ func TestAutoGenAddr(t *testing.T) {
}
// Wait for addr of prefix1 to be invalidated.
+ clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate)
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
@@ -1989,20 +1906,7 @@ func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []t
// TestAutoGenTempAddr tests that temporary SLAAC addresses are generated when
// configured to do so as part of IPv6 Privacy Extensions.
func TestAutoGenTempAddr(t *testing.T) {
- const (
- nicID = 1
- newMinVL = 5
- newMinVLDuration = newMinVL * time.Second
- )
-
- savedMinPrefixInformationValidLifetimeForUpdate := ipv6.MinPrefixInformationValidLifetimeForUpdate
- savedMaxDesync := ipv6.MaxDesyncFactor
- defer func() {
- ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate
- ipv6.MaxDesyncFactor = savedMaxDesync
- }()
- ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
- ipv6.MaxDesyncFactor = time.Nanosecond
+ const nicID = 1
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
@@ -2022,218 +1926,211 @@ func TestAutoGenTempAddr(t *testing.T) {
},
}
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for i, test := range tests {
- i := i
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
- seed := []byte{uint8(i)}
- var tempIIDHistory [header.IIDSize]byte
- header.InitialTempIID(tempIIDHistory[:], seed, nicID)
- newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix {
- return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr)
- }
-
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 2),
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- DADConfigs: stack.DADConfigurations{
- DupAddrDetectTransmits: test.dupAddrTransmits,
- RetransmitTimer: test.retransmitTimer,
- },
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- TempIIDSeed: seed,
- })},
- })
-
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
+ for i, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ seed := []byte{uint8(i)}
+ var tempIIDHistory [header.IIDSize]byte
+ header.InitialTempIID(tempIIDHistory[:], seed, nicID)
+ newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix {
+ return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr)
+ }
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- }
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 2),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ DADConfigs: stack.DADConfigurations{
+ DupAddrDetectTransmits: test.dupAddrTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ },
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ MaxTempAddrValidLifetime: 2 * ipv6.MinPrefixInformationValidLifetimeForUpdate,
+ MaxTempAddrPreferredLifetime: 2 * ipv6.MinPrefixInformationValidLifetimeForUpdate,
+ },
+ NDPDisp: &ndpDisp,
+ TempIIDSeed: seed,
+ })},
+ Clock: clock,
+ })
- expectDADEventAsync := func(addr tcpip.Address) {
- t.Helper()
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for DAD event")
- }
- }
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with zero valid lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
select {
case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e)
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
default:
+ t.Fatal("expected addr auto gen event")
}
+ }
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with non-zero valid lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
- expectAutoGenAddrEvent(addr1, newAddr)
- expectDADEventAsync(addr1.Address)
+ expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
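+ // Run any jobs scheduled to fire without delay so that, if the expected
+ // event is pending, it is delivered to the channel before we read it.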
+ clock.RunImmediatelyScheduledJobs()
select {
case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpectedly got an auto gen addr event = %+v", e)
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
default:
+ t.Fatal("timed out waiting for addr auto gen event")
}
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ }
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with non-zero valid & preferred lifetimes.
- tempAddr1 := newTempAddr(addr1.Address)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
- expectAutoGenAddrEvent(tempAddr1, newAddr)
- expectDADEventAsync(tempAddr1.Address)
- if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ expectDADEventAsync := func(addr tcpip.Address) {
+ t.Helper()
- // Receive an RA with prefix2 in an NDP Prefix Information option (PI)
- // with preferred lifetime > valid lifetime
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
+ clock.Advance(time.Duration(test.dupAddrTransmits) * test.retransmitTimer)
select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e)
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
+ }
default:
+ t.Fatal("timed out waiting for DAD event")
}
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ }
- // Receive an RA with prefix2 in a PI w/ non-zero valid and preferred
- // lifetimes.
- tempAddr2 := newTempAddr(addr2.Address)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
- expectAutoGenAddrEvent(addr2, newAddr)
- expectDADEventAsync(addr2.Address)
- expectAutoGenAddrEventAsync(tempAddr2, newAddr)
- expectDADEventAsync(tempAddr2.Address)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e)
+ default:
+ }
- // Deprecate prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
- expectAutoGenAddrEvent(addr1, deprecatedAddr)
- expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ expectDADEventAsync(addr1.Address)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly got an auto gen addr event = %+v", e)
+ default:
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
- // Refresh lifetimes for prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero valid & preferred lifetimes.
+ tempAddr1 := newTempAddr(addr1.Address)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ expectAutoGenAddrEvent(tempAddr1, newAddr)
+ expectDADEventAsync(tempAddr1.Address)
+ if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
- // Reduce valid lifetime and deprecate addresses of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
- expectAutoGenAddrEvent(addr1, deprecatedAddr)
- expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with preferred lifetime > valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e)
+ default:
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
- // Wait for addrs of prefix1 to be invalidated. They should be
- // invalidated at the same time.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- var nextAddr tcpip.AddressWithPrefix
- if e.addr == addr1 {
- if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- nextAddr = tempAddr1
- } else {
- if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- nextAddr = addr1
- }
+ // Receive an RA with prefix2 in a PI with a valid lifetime that exceeds
+ // the minimum and won't be reached in this test.
+ tempAddr2 := newTempAddr(addr2.Address)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 2*minVLSeconds, 2*minVLSeconds))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ expectDADEventAsync(addr2.Address)
+ expectAutoGenAddrEventAsync(tempAddr2, newAddr)
+ expectDADEventAsync(tempAddr2.Address)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
+ // Deprecate prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Refresh lifetimes for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Reduce valid lifetime and deprecate addresses of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, 0))
+ expectAutoGenAddrEvent(addr1, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Wait for addrs of prefix1 to be invalidated. They should be
+ // invalidated at the same time.
+ clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ var nextAddr tcpip.AddressWithPrefix
+ if e.addr == addr1 {
+ if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
- t.Fatal(mismatch)
+ nextAddr = tempAddr1
+ } else {
+ if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ nextAddr = addr1
}
- // Receive an RA with prefix2 in a PI w/ 0 lifetimes.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0))
- expectAutoGenAddrEvent(addr2, deprecatedAddr)
- expectAutoGenAddrEvent(tempAddr2, deprecatedAddr)
select {
case e := <-ndpDisp.autoGenAddrC:
- t.Errorf("got unexpected auto gen addr event = %+v", e)
+ if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
default:
+ t.Fatal("timed out waiting for addr auto gen event")
}
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
- t.Fatal(mismatch)
- }
- })
- }
- })
+ default:
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Receive an RA with prefix2 in a PI w/ 0 lifetimes.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0))
+ expectAutoGenAddrEvent(addr2, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddr2, deprecatedAddr)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Errorf("got unexpected auto gen addr event = %+v", e)
+ default:
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+ })
+ }
}
// TestNoAutoGenTempAddrForLinkLocal test that temporary SLAAC addresses are not
@@ -2241,12 +2138,6 @@ func TestAutoGenTempAddr(t *testing.T) {
func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
const nicID = 1
- savedMaxDesyncFactor := ipv6.MaxDesyncFactor
- defer func() {
- ipv6.MaxDesyncFactor = savedMaxDesyncFactor
- }()
- ipv6.MaxDesyncFactor = time.Nanosecond
-
tests := []struct {
name string
dupAddrTransmits uint8
@@ -2262,66 +2153,56 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
},
}
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- AutoGenTempGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- AutoGenLinkLocal: true,
- })},
- })
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ AutoGenLinkLocal: true,
+ })},
+ Clock: clock,
+ })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
- // The stable link-local address should auto-generate and resolve DAD.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
+ // The stable link-local address should auto-generate and resolve DAD.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for DAD event")
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ clock.Advance(time.Duration(test.dupAddrTransmits) * test.retransmitTimer)
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
+ default:
+ t.Fatal("timed out waiting for DAD event")
+ }
- // No new addresses should be generated.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Errorf("got unxpected auto gen addr event = %+v", e)
- case <-time.After(defaultAsyncNegativeEventTimeout):
- }
- })
- }
- })
+ // No new addresses should be generated.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Errorf("got unxpected auto gen addr event = %+v", e)
+ default:
+ }
+ })
+ }
}
// TestNoAutoGenTempAddrWithoutStableAddr tests that a temporary SLAAC address
@@ -2334,12 +2215,6 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
retransmitTimer = 2 * time.Second
)
- savedMaxDesyncFactor := ipv6.MaxDesyncFactor
- defer func() {
- ipv6.MaxDesyncFactor = savedMaxDesyncFactor
- }()
- ipv6.MaxDesyncFactor = 0
-
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
var tempIIDHistory [header.IIDSize]byte
header.InitialTempIID(tempIIDHistory[:], nil, nicID)
@@ -2350,6 +2225,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
DADConfigs: stack.DADConfigurations{
@@ -2363,6 +2239,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
},
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2392,12 +2269,13 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
// Wait for DAD to complete for the stable address then expect the temporary
// address to be generated.
+ clock.Advance(dadTransmits * retransmitTimer)
select {
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" {
t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for DAD event")
}
select {
@@ -2405,7 +2283,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -2414,46 +2292,44 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
// regenerated.
func TestAutoGenTempAddrRegen(t *testing.T) {
const (
- nicID = 1
- regenAfter = 2 * time.Second
- newMinVL = 10
- newMinVLDuration = newMinVL * time.Second
- )
+ nicID = 1
+ regenAdv = 2 * time.Second
- savedMaxDesyncFactor := ipv6.MaxDesyncFactor
- savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime
- savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime
- defer func() {
- ipv6.MaxDesyncFactor = savedMaxDesyncFactor
- ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
- ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
- }()
- ipv6.MaxDesyncFactor = 0
- ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration
- ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration
+ numTempAddrs = 3
+ maxTempAddrValidLifetime = numTempAddrs * ipv6.MinPrefixInformationValidLifetimeForUpdate
+ )
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
var tempIIDHistory [header.IIDSize]byte
header.InitialTempIID(tempIIDHistory[:], nil, nicID)
- tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
- tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
- tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ var tempAddrs [numTempAddrs]tcpip.AddressWithPrefix
+ for i := 0; i < len(tempAddrs); i++ {
+ tempAddrs[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ }
ndpDisp := ndpDispatcher{
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
}
e := channel.New(0, 1280, linkAddr1)
ndpConfigs := ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- RegenAdvanceDuration: newMinVLDuration - regenAfter,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ RegenAdvanceDuration: regenAdv,
+ MaxTempAddrValidLifetime: maxTempAddrValidLifetime,
+ MaxTempAddrPreferredLifetime: ipv6.MinPrefixInformationValidLifetimeForUpdate,
+ }
+ clock := faketime.NewManualClock()
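+ // savingRandSource remembers the last value handed out by the wrapped
+ // source; the desync factor is recovered from it further below.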
+ randSource := savingRandSource{
+ s: rand.NewSource(time.Now().UnixNano()),
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ndpConfigs,
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
+ RandSource: &randSource,
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2476,36 +2352,43 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
t.Helper()
+ clock.Advance(timeout)
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(timeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
}
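+ // The arithmetic below assumes the stack subtracts a desync factor,
+ // derived from the last value drawn from randSource, from the maximum
+ // temporary address preferred lifetime.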
+ tempDesyncFactor := time.Duration(randSource.lastInt63) % ipv6.MaxDesyncFactor
+ effectiveMaxTempAddrPL := ipv6.MinPrefixInformationValidLifetimeForUpdate - tempDesyncFactor
+ // The time that must pass since the last regeneration before a new
+ // temporary address is generated.
+ tempAddrRegenenerationTime := effectiveMaxTempAddrPL - regenAdv
+
// Receive an RA with prefix1 in an NDP Prefix Information option (PI)
// with non-zero valid & preferred lifetimes.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds))
expectAutoGenAddrEvent(addr, newAddr)
- expectAutoGenAddrEvent(tempAddr1, newAddr)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" {
+ expectAutoGenAddrEvent(tempAddrs[0], newAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0]}, nil); mismatch != "" {
t.Fatal(mismatch)
}
// Wait for regeneration
- expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" {
+ expectAutoGenAddrEventAsync(tempAddrs[1], newAddr, tempAddrRegenenerationTime)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds))
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0], tempAddrs[1]}, nil); mismatch != "" {
t.Fatal(mismatch)
}
+ expectAutoGenAddrEventAsync(tempAddrs[0], deprecatedAddr, regenAdv)
// Wait for regeneration
- expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ expectAutoGenAddrEventAsync(tempAddrs[2], newAddr, tempAddrRegenenerationTime-regenAdv)
+ expectAutoGenAddrEventAsync(tempAddrs[1], deprecatedAddr, regenAdv)
// Stop generating temporary addresses
ndpConfigs.AutoGenTempGlobalAddresses = false
@@ -2516,45 +2399,24 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
ndpEP.SetNDPConfigurations(ndpConfigs)
}
+ // Refresh lifetimes and wait for the last temporary address to be deprecated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds))
+ expectAutoGenAddrEventAsync(tempAddrs[2], deprecatedAddr, effectiveMaxTempAddrPL-regenAdv)
+
+ // Refresh lifetimes such that the prefix is valid and preferred forever.
+ //
+ // This should not affect the lifetimes of temporary addresses because they
+ // are capped by the maximum valid and preferred lifetimes for temporary
+ // addresses.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, infiniteLifetimeSeconds, infiniteLifetimeSeconds))
+
// Wait for all the temporary addresses to get invalidated.
- tempAddrs := []tcpip.AddressWithPrefix{tempAddr1, tempAddr2, tempAddr3}
- invalidateAfter := newMinVLDuration - 2*regenAfter
+ invalidateAfter := maxTempAddrValidLifetime - clock.NowMonotonic().Sub(tcpip.MonotonicTime{})
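+ // Each later temporary address was generated one regeneration interval
+ // after the previous one, so after the first invalidation the loop only
+ // needs to advance by that interval per address.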
for _, addr := range tempAddrs {
- // Wait for a deprecation then invalidation event, or just an invalidation
- // event. We need to cover both cases but cannot deterministically hit both
- // cases because the deprecation and invalidation jobs could execute in any
- // order.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff == "" {
- // If we get a deprecation event first, we should get an invalidation
- // event almost immediately after.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- } else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" {
- // If we get an invalidation event first, we shouldn't get a deprecation
- // event after.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpectedly got an auto-generated event = %+v", e)
- case <-time.After(defaultAsyncNegativeEventTimeout):
- }
- } else {
- t.Fatalf("got unexpected auto-generated event = %+v", e)
- }
- case <-time.After(invalidateAfter + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
-
- invalidateAfter = regenAfter
+ expectAutoGenAddrEventAsync(addr, invalidatedAddr, invalidateAfter)
+ invalidateAfter = tempAddrRegenenerationTime
}
- if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs); mismatch != "" {
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs[:]); mismatch != "" {
t.Fatal(mismatch)
}
}
@@ -2563,52 +2425,54 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
// regeneration job gets updated when refreshing the address's lifetimes.
func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
const (
- nicID = 1
- regenAfter = 2 * time.Second
- newMinVL = 10
- newMinVLDuration = newMinVL * time.Second
- )
+ nicID = 1
+ regenAdv = 2 * time.Second
- savedMaxDesyncFactor := ipv6.MaxDesyncFactor
- savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime
- savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime
- defer func() {
- ipv6.MaxDesyncFactor = savedMaxDesyncFactor
- ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
- ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
- }()
- ipv6.MaxDesyncFactor = 0
- ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration
- ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration
+ numTempAddrs = 3
+ maxTempAddrPreferredLifetime = ipv6.MinPrefixInformationValidLifetimeForUpdate
+ maxTempAddrPreferredLifetimeSeconds = uint32(maxTempAddrPreferredLifetime / time.Second)
+ )
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
var tempIIDHistory [header.IIDSize]byte
header.InitialTempIID(tempIIDHistory[:], nil, nicID)
- tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
- tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
- tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ var tempAddrs [numTempAddrs]tcpip.AddressWithPrefix
+ for i := 0; i < len(tempAddrs); i++ {
+ tempAddrs[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ }
ndpDisp := ndpDispatcher{
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
}
e := channel.New(0, 1280, linkAddr1)
ndpConfigs := ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- RegenAdvanceDuration: newMinVLDuration - regenAfter,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ RegenAdvanceDuration: regenAdv,
+ MaxTempAddrPreferredLifetime: maxTempAddrPreferredLifetime,
+ MaxTempAddrValidLifetime: maxTempAddrPreferredLifetime * 2,
+ }
+ clock := faketime.NewManualClock()
+ initialTime := clock.NowMonotonic()
+ randSource := savingRandSource{
+ s: rand.NewSource(time.Now().UnixNano()),
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ndpConfigs,
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
+ RandSource: &randSource,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
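+ // As in TestAutoGenTempAddrRegen, recover the desync factor the stack is
+ // assumed to have drawn from randSource so the lifetime arithmetic below
+ // lines up with the stack's behavior.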
+ tempDesyncFactor := time.Duration(randSource.lastInt63) % ipv6.MaxDesyncFactor
+
expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
t.Helper()
@@ -2625,22 +2489,23 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
t.Helper()
+ clock.Advance(timeout)
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(timeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
}
// Receive an RA with prefix1 in an NDP Prefix Information option (PI)
// with non-zero valid & preferred lifetimes.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, maxTempAddrPreferredLifetimeSeconds))
expectAutoGenAddrEvent(addr, newAddr)
- expectAutoGenAddrEvent(tempAddr1, newAddr)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" {
+ expectAutoGenAddrEvent(tempAddrs[0], newAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0]}, nil); mismatch != "" {
t.Fatal(mismatch)
}
@@ -2648,13 +2513,27 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
//
// A new temporary address should be generated after the regeneration
// time has passed since the prefix is deprecated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, 0))
expectAutoGenAddrEvent(addr, deprecatedAddr)
- expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddrs[0], deprecatedAddr)
select {
case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpected auto gen addr event = %+v", e)
- case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout):
+ t.Fatalf("unexpected auto gen addr event = %#v", e)
+ default:
+ }
+
+ effectiveMaxTempAddrPL := maxTempAddrPreferredLifetime - tempDesyncFactor
+ // The time that must pass since the last regeneration before a new
+ // temporary address is generated.
+ tempAddrRegenenerationTime := effectiveMaxTempAddrPL - regenAdv
+
+ // Advance the clock by the regeneration time but don't expect a new temporary
+ // address as the prefix is deprecated.
+ clock.Advance(tempAddrRegenenerationTime)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpected auto gen addr event = %#v", e)
+ default:
}
// Prefer the prefix again.
@@ -2662,8 +2541,15 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
// A new temporary address should immediately be generated since the
// regeneration time has already passed since the last address was generated
// - this regeneration does not depend on a job.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
- expectAutoGenAddrEvent(tempAddr2, newAddr)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, maxTempAddrPreferredLifetimeSeconds))
+ expectAutoGenAddrEvent(tempAddrs[1], newAddr)
+ // Wait for the first temporary address to be deprecated.
+ expectAutoGenAddrEventAsync(tempAddrs[0], deprecatedAddr, regenAdv)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpected auto gen addr event = %s", e)
+ default:
+ }
// Increase the maximum lifetimes for temporary addresses to large values
// then refresh the lifetimes of the prefix.
@@ -2674,34 +2560,30 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
// regenerate a new temporary address. Note, new addresses are only
// regenerated after the preferred lifetime - the regenerate advance duration
// has passed.
- ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second
- ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second
+ const largeLifetimeSeconds = minVLSeconds * 2
+ const largeLifetime = time.Duration(largeLifetimeSeconds) * time.Second
+ ndpConfigs.MaxTempAddrValidLifetime = 2 * largeLifetime
+ ndpConfigs.MaxTempAddrPreferredLifetime = largeLifetime
ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber)
if err != nil {
t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
}
ndpEP := ipv6Ep.(ipv6.NDPEndpoint)
ndpEP.SetNDPConfigurations(ndpConfigs)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ timeSinceInitialTime := clock.NowMonotonic().Sub(initialTime)
+ clock.Advance(largeLifetime - timeSinceInitialTime)
+ expectAutoGenAddrEvent(tempAddrs[0], deprecatedAddr)
+ // advLess offsets the clock advancement used above to observe the first
+ // temporary address's deprecation after the second was generated.
+ advLess := regenAdv
+ expectAutoGenAddrEventAsync(tempAddrs[2], newAddr, timeSinceInitialTime-advLess-(tempDesyncFactor+regenAdv))
+ expectAutoGenAddrEventAsync(tempAddrs[1], deprecatedAddr, regenAdv)
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpected auto gen addr event = %+v", e)
- case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout):
+ default:
}
-
- // Set the maximum lifetimes for temporary addresses such that on the next
- // RA, the regeneration job gets scheduled again.
- //
- // The maximum lifetime is the sum of the minimum lifetimes for temporary
- // addresses + the time that has already passed since the last address was
- // generated so that the regeneration job is needed to generate the next
- // address.
- newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout
- ndpConfigs.MaxTempAddrValidLifetime = newLifetimes
- ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes
- ndpEP.SetNDPConfigurations(ndpConfigs)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
- expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
}
// TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response
@@ -2929,13 +2811,14 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
// stack.Stack will have a default route through the router (llAddr3) installed
// and a static link-address (linkAddr3) added to the link address cache for the
// router.
-func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) {
+func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
t.Helper()
ndpDisp := &ndpDispatcher{
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -2945,6 +2828,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd
NDPDisp: ndpDisp,
})},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: clock,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -2958,7 +2842,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd
if err := s.AddStaticNeighbor(nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3); err != nil {
t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3, err)
}
- return ndpDisp, e, s
+ return ndpDisp, e, s, clock
}
// addrForNewConnectionTo returns the local address used when creating a new
@@ -3032,7 +2916,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+ ndpDisp, e, s, _ := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
t.Helper()
@@ -3135,19 +3019,11 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
// when its preferred lifetime expires.
func TestAutoGenAddrJobDeprecation(t *testing.T) {
const nicID = 1
- const newMinVL = 2
- newMinVLDuration := newMinVL * time.Second
-
- saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
- defer func() {
- ipv6.MinPrefixInformationValidLifetimeForUpdate = saved
- }()
- ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+ ndpDisp, e, s, clock := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
t.Helper()
@@ -3165,12 +3041,13 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
t.Helper()
+ clock.Advance(timeout)
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(timeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -3188,7 +3065,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
}
// Receive PI for prefix2.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, infiniteLifetimeSeconds, infiniteLifetimeSeconds))
expectAutoGenAddrEvent(addr2, newAddr)
if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
@@ -3207,7 +3084,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Refresh lifetime for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, minVLSeconds-1))
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly got an auto-generated event")
@@ -3216,7 +3093,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, ipv6.MinPrefixInformationValidLifetimeForUpdate-time.Second)
if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -3226,6 +3103,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
// addr2 should be the primary endpoint now since addr1 is deprecated but
// addr2 is not.
expectPrimaryAddr(addr2)
+
// addr1 is deprecated but if explicitly requested, it should be used.
fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID}
if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
@@ -3234,7 +3112,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
// Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make
// sure we do not get a deprecation event again.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, 0))
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly got an auto-generated event")
@@ -3246,7 +3124,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
}
// Refresh lifetimes for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, minVLSeconds-1))
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly got an auto-generated event")
@@ -3256,7 +3134,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, ipv6.MinPrefixInformationValidLifetimeForUpdate-time.Second)
if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -3270,7 +3148,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
}
// Wait for addr of prefix1 to be invalidated.
- expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout)
+ expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second)
if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -3280,7 +3158,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
expectPrimaryAddr(addr2)
// Refresh both lifetimes for addr of prefix2 to the same value.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, minVLSeconds, minVLSeconds))
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly got an auto-generated event")
@@ -3292,6 +3170,17 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
// cases because the deprecation and invalidation handlers could be handled in
// either deprecation then invalidation, or invalidation then deprecation
// (which should be cancelled by the invalidation handler).
+ //
+ // Since we're about to cause both events to fire, we need the dispatcher
+ // channel to be able to hold both.
+ if got, want := len(ndpDisp.autoGenAddrC), 0; got != want {
+ t.Fatalf("got len(ndpDisp.autoGenAddrC) = %d, want %d", got, want)
+ }
+ if got, want := cap(ndpDisp.autoGenAddrC), 1; got != want {
+ t.Fatalf("got cap(ndpDisp.autoGenAddrC) = %d, want %d", got, want)
+ }
+ ndpDisp.autoGenAddrC = make(chan ndpAutoGenAddrEvent, 2)
+ clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate)
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" {
@@ -3302,21 +3191,21 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
} else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
- // If we get an invalidation event first, we should not get a deprecation
+ // If we get an invalidation event first, we should not get a deprecation
// event after.
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly got an auto-generated event")
- case <-time.After(defaultAsyncNegativeEventTimeout):
+ default:
}
} else {
t.Fatalf("got unexpected auto-generated event")
}
- case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
@@ -3353,15 +3242,6 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
// infinite values.
func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
const infiniteVLSeconds = 2
- const minVLSeconds = 1
- savedIL := header.NDPInfiniteLifetime
- savedMinVL := ipv6.MinPrefixInformationValidLifetimeForUpdate
- defer func() {
- ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinVL
- header.NDPInfiniteLifetime = savedIL
- }()
- ipv6.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second
- header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
@@ -3385,68 +3265,58 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
},
}
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
+ Clock: clock,
+ })
- // Receive an RA with finite prefix.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
- default:
- t.Fatal("expected addr auto gen event")
+ // Receive an RA with finite prefix.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- // Receive an new RA with prefix with infinite VL.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0))
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
- // Receive a new RA with prefix with finite VL.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
+ // Receive a new RA with prefix with infinite VL.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0))
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
+ // Receive a new RA with prefix with finite VL.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
- case <-time.After(minVLSeconds*time.Second + defaultAsyncPositiveEventTimeout):
- t.Fatal("timeout waiting for addr auto gen event")
+ clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- })
- }
- })
+
+ default:
+ t.Fatal("timeout waiting for addr auto gen event")
+ }
+ })
+ }
}
// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an
@@ -3454,12 +3324,6 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
// RFC 4862 section 5.5.3.e.
func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
const infiniteVL = 4294967295
- const newMinVL = 4
- saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
- defer func() {
- ipv6.MinPrefixInformationValidLifetimeForUpdate = saved
- }()
- ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
@@ -3470,137 +3334,129 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
evl uint32
}{
// Should update the VL to the minimum VL for updating if the
- // new VL is less than newMinVL but was originally greater than
+ // new VL is less than minVLSeconds but was originally greater than
// it.
{
"LargeVLToVLLessThanMinVLForUpdate",
9999,
1,
- newMinVL,
+ minVLSeconds,
},
{
"LargeVLTo0",
9999,
0,
- newMinVL,
+ minVLSeconds,
},
{
"InfiniteVLToVLLessThanMinVLForUpdate",
infiniteVL,
1,
- newMinVL,
+ minVLSeconds,
},
{
"InfiniteVLTo0",
infiniteVL,
0,
- newMinVL,
+ minVLSeconds,
},
- // Should not update VL if original VL was less than newMinVL
- // and the new VL is also less than newMinVL.
+ // Should not update VL if original VL was less than minVLSeconds
+ // and the new VL is also less than minVLSeconds.
{
"ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate",
- newMinVL - 1,
- newMinVL - 3,
- newMinVL - 1,
+ minVLSeconds - 1,
+ minVLSeconds - 3,
+ minVLSeconds - 1,
},
// Should take the new VL if the new VL is greater than the
- // remaining time or is greater than newMinVL.
+ // remaining time or is greater than minVLSeconds.
{
"MorethanMinVLToLesserButStillMoreThanMinVLForUpdate",
- newMinVL + 5,
- newMinVL + 3,
- newMinVL + 3,
+ minVLSeconds + 5,
+ minVLSeconds + 3,
+ minVLSeconds + 3,
},
{
"SmallVLToGreaterVLButStillLessThanMinVLForUpdate",
- newMinVL - 3,
- newMinVL - 1,
- newMinVL - 1,
+ minVLSeconds - 3,
+ minVLSeconds - 1,
+ minVLSeconds - 1,
},
{
"SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate",
- newMinVL - 3,
- newMinVL + 1,
- newMinVL + 1,
+ minVLSeconds - 3,
+ minVLSeconds + 1,
+ minVLSeconds + 1,
},
}
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
- }
- e := channel.New(10, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
+ }
+ e := channel.New(10, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
+ Clock: clock,
+ })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
- // Receive an RA with prefix with initial VL,
- // test.ovl.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0))
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
+ // Receive an RA with prefix with initial VL,
+ // test.ovl.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
- // Receive an new RA with prefix with new VL,
- // test.nvl.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0))
+ // Receive a new RA with prefix with new VL,
+ // test.nvl.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0))
- //
- // Validate that the VL for the address got set
- // to test.evl.
- //
+ //
+ // Validate that the VL for the address got set
+ // to test.evl.
+ //
- // The address should not be invalidated until the effective valid
- // lifetime has passed.
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncNegativeEventTimeout):
- }
+ // The address should not be invalidated until the effective valid
+ // lifetime has passed.
+ const delta = 1
+ clock.Advance(time.Duration(test.evl)*time.Second - delta)
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly received an auto gen addr event")
+ default:
+ }
- // Wait for the invalidation event.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(defaultAsyncPositiveEventTimeout):
- t.Fatal("timeout waiting for addr auto gen event")
+ // Wait for the invalidation event.
+ clock.Advance(delta)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- })
- }
- })
+ default:
+ t.Fatal("timeout waiting for addr auto gen event")
+ }
+ })
+ }
}
// TestAutoGenAddrRemoval tests that when auto-generated addresses are removed
@@ -3613,6 +3469,7 @@ func TestAutoGenAddrRemoval(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -3621,6 +3478,7 @@ func TestAutoGenAddrRemoval(t *testing.T) {
},
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
})
if err := s.CreateNIC(1, e); err != nil {
@@ -3654,10 +3512,11 @@ func TestAutoGenAddrRemoval(t *testing.T) {
// Wait for the original valid lifetime to make sure the original job got
// cancelled/cleaned up.
+ clock.Advance(lifetimeSeconds * time.Second)
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
+ default:
}
}
@@ -3668,7 +3527,7 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) {
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+ ndpDisp, e, s, _ := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
t.Helper()
@@ -3779,6 +3638,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -3787,6 +3647,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
},
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
})
if err := s.CreateNIC(1, e); err != nil {
@@ -3816,30 +3677,36 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
// Should not get an invalidation event after the PI's invalidation
// time.
+ clock.Advance(lifetimeSeconds * time.Second)
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
+ default:
}
if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
}
}
+func makeSecretKey(t *testing.T) []byte {
+ secretKey := make([]byte, header.OpaqueIIDSecretKeyMinBytes)
+ n, err := cryptorand.Read(secretKey)
+ if err != nil {
+ t.Fatalf("cryptorand.Read(_): %s", err)
+ }
+ if l := len(secretKey); n != l {
+ t.Fatalf("got cryptorand.Read(_) = (%d, nil), want = (%d, nil)", n, l)
+ }
+ return secretKey
+}
+
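makeSecretKey consolidates the rand.Read boilerplate that each opaque-IID test previously carried inline. A hypothetical caller sketch (parameter types mirror the existing call to header.AppendOpaqueInterfaceIdentifier later in this file; the surrounding stack setup is elided):

func exampleOpaqueAddr(t *testing.T, addrBytes []byte, subnet tcpip.Subnet, nicName string) tcpip.Address {
	// One call replaces the per-test key generation removed below.
	secretKey := makeSecretKey(t)
	// The key is consumed the same way the DAD-retry test further down uses it.
	return tcpip.Address(header.AppendOpaqueInterfaceIdentifier(
		addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 0 /* dadCounter */, secretKey))
}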
// TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use
// opaque interface identifiers when configured to do so.
func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
const nicID = 1
const nicName = "nic1"
- var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
- secretKey := secretKeyBuf[:]
- n, err := rand.Read(secretKey)
- if err != nil {
- t.Fatalf("rand.Read(_): %s", err)
- }
- if n != header.OpaqueIIDSecretKeyMinBytes {
- t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
- }
+
+ secretKey := makeSecretKey(t)
prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1)
prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1)
@@ -3861,6 +3728,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -3875,6 +3743,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
SecretKey: secretKey,
},
})},
+ Clock: clock,
})
opts := stack.NICOptions{Name: nicName}
if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
@@ -3913,12 +3782,13 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
}
// Wait for addr of prefix1 to be invalidated.
+ clock.Advance(validLifetimeSecondPrefix1 * time.Second)
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
@@ -3937,22 +3807,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
const maxMaxRetries = 3
const lifetimeSeconds = 10
- // Needed for the temporary address sub test.
- savedMaxDesync := ipv6.MaxDesyncFactor
- defer func() {
- ipv6.MaxDesyncFactor = savedMaxDesync
- }()
- ipv6.MaxDesyncFactor = time.Nanosecond
-
- var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
- secretKey := secretKeyBuf[:]
- n, err := rand.Read(secretKey)
- if err != nil {
- t.Fatalf("rand.Read(_): %s", err)
- }
- if n != header.OpaqueIIDSecretKeyMinBytes {
- t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
- }
+ secretKey := makeSecretKey(t)
prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
@@ -3977,22 +3832,24 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
}
}
- expectAutoGenAddrEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ expectAutoGenAddrEventAsync := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
t.Helper()
+ clock.RunImmediatelyScheduledJobs()
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
}
- expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) {
+ expectDADEvent := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) {
t.Helper()
+ clock.RunImmediatelyScheduledJobs()
select {
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr, res); diff != "" {
@@ -4003,15 +3860,16 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
}
}
- expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) {
+ expectDADEventAsync := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) {
t.Helper()
+ clock.Advance(dadTransmits * retransmitTimer)
select {
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr, res); diff != "" {
t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for DAD event")
}
}
@@ -4022,7 +3880,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
name string
ndpConfigs ipv6.NDPConfigurations
autoGenLinkLocal bool
- prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix
+ prepareFn func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix
addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix
}{
{
@@ -4031,7 +3889,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
- prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix {
+ prepareFn: func(_ *testing.T, _ *faketime.ManualClock, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix {
// Receive an RA with prefix1 in a PI.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
return nil
@@ -4045,7 +3903,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
name: "LinkLocal address",
ndpConfigs: ipv6.NDPConfigurations{},
autoGenLinkLocal: true,
- prepareFn: func(*testing.T, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix {
+ prepareFn: func(*testing.T, *faketime.ManualClock, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix {
return nil
},
addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix {
@@ -4059,14 +3917,14 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
},
- prepareFn: func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix {
+ prepareFn: func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix {
header.InitialTempIID(tempIIDHistory, nil, nicID)
// Generate a stable SLAAC address so temporary addresses will be
// generated.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr)
- expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{})
+ expectDADEventAsync(t, clock, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{})
// The stable address will be assigned throughout the test.
return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest}
@@ -4078,14 +3936,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
}
for _, addrType := range addrTypes {
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the parallel
- // tests complete and limit the number of parallel tests running at the same
- // time to reduce flakes.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
t.Run(addrType.name, func(t *testing.T) {
for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ {
for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ {
@@ -4094,8 +3944,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
addrType := addrType
t.Run(fmt.Sprintf("%d max retries and %d failures", maxRetries, numFailures), func(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent, 1),
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
@@ -4103,6 +3951,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
e := channel.New(0, 1280, linkAddr1)
ndpConfigs := addrType.ndpConfigs
ndpConfigs.AutoGenAddressConflictRetries = maxRetries
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
AutoGenLinkLocal: addrType.autoGenLinkLocal,
@@ -4119,6 +3968,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
SecretKey: secretKey,
},
})},
+ Clock: clock,
})
opts := stack.NICOptions{Name: nicName}
if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
@@ -4126,12 +3976,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
}
var tempIIDHistory [header.IIDSize]byte
- stableAddrs := addrType.prepareFn(t, &ndpDisp, e, tempIIDHistory[:])
+ stableAddrs := addrType.prepareFn(t, clock, &ndpDisp, e, tempIIDHistory[:])
// Simulate DAD conflicts so the address is regenerated.
for i := uint8(0); i < numFailures; i++ {
addr := addrType.addrGenFn(i, tempIIDHistory[:])
- expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr)
+ expectAutoGenAddrEventAsync(t, clock, &ndpDisp, addr, newAddr)
// Should not have any new addresses assigned to the NIC.
if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" {
@@ -4141,7 +3991,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
// Simulate a DAD conflict.
rxNDPSolicit(e, addr.Address)
expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr)
- expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{})
+ expectDADEvent(t, clock, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{})
// Attempting to add the address manually should not fail if the
// address's state was cleaned up when DAD failed.
@@ -4151,7 +4001,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
if err := s.RemoveAddress(nicID, addr.Address); err != nil {
t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err)
}
- expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADAborted{})
+ expectDADEvent(t, clock, &ndpDisp, addr.Address, &stack.DADAborted{})
}
// Should not have any new addresses assigned to the NIC.
@@ -4163,8 +4013,8 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
// an address after DAD resolves.
if maxRetries+1 > numFailures {
addr := addrType.addrGenFn(numFailures, tempIIDHistory[:])
- expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr)
- expectDADEventAsync(t, &ndpDisp, addr.Address, &stack.DADSucceeded{})
+ expectAutoGenAddrEventAsync(t, clock, &ndpDisp, addr, newAddr)
+ expectDADEventAsync(t, clock, &ndpDisp, addr.Address, &stack.DADSucceeded{})
if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" {
t.Fatal(mismatch)
}
@@ -4174,7 +4024,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
- case <-time.After(defaultAsyncNegativeEventTimeout):
+ default:
}
})
}
@@ -4231,13 +4081,12 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
addrType := addrType
t.Run(addrType.name, func(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent, 1),
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
AutoGenLinkLocal: addrType.autoGenLinkLocal,
@@ -4248,6 +4097,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
RetransmitTimer: retransmitTimer,
},
})},
+ Clock: clock,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -4292,7 +4142,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
- case <-time.After(defaultAsyncNegativeEventTimeout):
+ default:
}
})
}
@@ -4309,15 +4159,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
const maxRetries = 1
const lifetimeSeconds = 5
- var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
- secretKey := secretKeyBuf[:]
- n, err := rand.Read(secretKey)
- if err != nil {
- t.Fatalf("rand.Read(_): %s", err)
- }
- if n != header.OpaqueIIDSecretKeyMinBytes {
- t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
- }
+ secretKey := makeSecretKey(t)
prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
@@ -4326,6 +4168,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
DADConfigs: stack.DADConfigurations{
@@ -4345,6 +4188,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
SecretKey: secretKey,
},
})},
+ Clock: clock,
})
opts := stack.NICOptions{Name: nicName}
if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
@@ -4375,7 +4219,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
expectAutoGenAddrEvent(addr, newAddr)
// Simulate a DAD conflict after some time has passed.
- time.Sleep(failureTimer)
+ clock.Advance(failureTimer)
rxNDPSolicit(e, addr.Address)
expectAutoGenAddrEvent(addr, invalidatedAddr)
select {
@@ -4390,12 +4234,13 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
// Let the next address resolve.
addr.Address = tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 1, secretKey))
expectAutoGenAddrEvent(addr, newAddr)
+ clock.Advance(dadTransmits * retransmitTimer)
select {
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" {
t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for DAD event")
}
@@ -4409,6 +4254,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
//
// We expect either just the invalidation event or the deprecation event
// followed by the invalidation event.
+ clock.Advance(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer)
select {
case e := <-ndpDisp.autoGenAddrC:
if e.eventType == deprecatedAddr {
@@ -4421,7 +4267,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation")
}
} else {
@@ -4429,7 +4275,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
}
- case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for auto gen addr event")
}
}
@@ -4691,11 +4537,9 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) {
)
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- rememberRouter: true,
- prefixC: make(chan ndpPrefixEvent, 1),
- rememberPrefix: true,
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
+ prefixC: make(chan ndpPrefixEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
@@ -4738,17 +4582,17 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) {
),
)
select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, llAddr3, true /* discovered */); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ case e := <-ndpDisp.offLinkRouteC:
+ if diff := checkOffLinkRouteEvent(e, nicID, llAddr3, header.MediumRoutePreference, true /* discovered */); diff != "" {
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
}
default:
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID)
+ t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID)
}
select {
case e := <-ndpDisp.prefixC:
if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
}
default:
t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID)
@@ -4770,8 +4614,8 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) {
t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
}
select {
- case e := <-ndpDisp.routerC:
- t.Errorf("unexpected router event = %#v", e)
+ case e := <-ndpDisp.offLinkRouteC:
+ t.Errorf("unexpected off-link route event = %#v", e)
default:
}
select {
@@ -4857,12 +4701,11 @@ func TestCleanupNDPState(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents),
- rememberRouter: true,
- prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents),
- rememberPrefix: true,
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents),
+ offLinkRouteC: make(chan ndpOffLinkRouteEvent, maxRouterAndPrefixEvents),
+ prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents),
}
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
AutoGenLinkLocal: true,
@@ -4874,16 +4717,17 @@ func TestCleanupNDPState(t *testing.T) {
},
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
})
- expectRouterEvent := func() (bool, ndpRouterEvent) {
+ expectOffLinkRouteEvent := func() (bool, ndpOffLinkRouteEvent) {
select {
- case e := <-ndpDisp.routerC:
+ case e := <-ndpDisp.offLinkRouteC:
return true, e
default:
}
- return false, ndpRouterEvent{}
+ return false, ndpOffLinkRouteEvent{}
}
expectPrefixEvent := func() (bool, ndpPrefixEvent) {
@@ -4928,8 +4772,8 @@ func TestCleanupNDPState(t *testing.T) {
// multiple addresses.
e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1)
+ if ok, _ := expectOffLinkRouteEvent(); !ok {
+ t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID1)
}
if ok, _ := expectPrefixEvent(); !ok {
t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1)
@@ -4939,8 +4783,8 @@ func TestCleanupNDPState(t *testing.T) {
}
e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1)
+ if ok, _ := expectOffLinkRouteEvent(); !ok {
+ t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID1)
}
if ok, _ := expectPrefixEvent(); !ok {
t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1)
@@ -4950,8 +4794,8 @@ func TestCleanupNDPState(t *testing.T) {
}
e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2)
+ if ok, _ := expectOffLinkRouteEvent(); !ok {
+ t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID2)
}
if ok, _ := expectPrefixEvent(); !ok {
t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2)
@@ -4961,8 +4805,8 @@ func TestCleanupNDPState(t *testing.T) {
}
e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2)
+ if ok, _ := expectOffLinkRouteEvent(); !ok {
+ t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID2)
}
if ok, _ := expectPrefixEvent(); !ok {
t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2)
@@ -5003,14 +4847,14 @@ func TestCleanupNDPState(t *testing.T) {
test.cleanupFn(t, s)
// Collect invalidation events after having NDP state cleaned up.
- gotRouterEvents := make(map[ndpRouterEvent]int)
+ gotOffLinkRouteEvents := make(map[ndpOffLinkRouteEvent]int)
for i := 0; i < maxRouterAndPrefixEvents; i++ {
- ok, e := expectRouterEvent()
+ ok, e := expectOffLinkRouteEvent()
if !ok {
- t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
+ t.Errorf("expected %d off-link route events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
break
}
- gotRouterEvents[e]++
+ gotOffLinkRouteEvents[e]++
}
gotPrefixEvents := make(map[ndpPrefixEvent]int)
for i := 0; i < maxRouterAndPrefixEvents; i++ {
@@ -5037,14 +4881,14 @@ func TestCleanupNDPState(t *testing.T) {
t.FailNow()
}
- expectedRouterEvents := map[ndpRouterEvent]int{
- {nicID: nicID1, addr: llAddr3, discovered: false}: 1,
- {nicID: nicID1, addr: llAddr4, discovered: false}: 1,
- {nicID: nicID2, addr: llAddr3, discovered: false}: 1,
- {nicID: nicID2, addr: llAddr4, discovered: false}: 1,
+ expectedOffLinkRouteEvents := map[ndpOffLinkRouteEvent]int{
+ {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1,
+ {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1,
+ {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1,
+ {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1,
}
- if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" {
- t.Errorf("router events mismatch (-want +got):\n%s", diff)
+ if diff := cmp.Diff(expectedOffLinkRouteEvents, gotOffLinkRouteEvents); diff != "" {
+ t.Errorf("off-link route events mismatch (-want +got):\n%s", diff)
}
expectedPrefixEvents := map[ndpPrefixEvent]int{
{nicID: nicID1, prefix: subnet1, discovered: false}: 1,
@@ -5106,10 +4950,10 @@ func TestCleanupNDPState(t *testing.T) {
// Should not get any more events (invalidation timers should have been
// cancelled when the NDP state was cleaned up).
- time.Sleep(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout)
+ clock.Advance(lifetimeSeconds * time.Second)
select {
- case <-ndpDisp.routerC:
- t.Error("unexpected router event")
+ case <-ndpDisp.offLinkRouteC:
+ t.Error("unexpected off-link route event")
default:
}
select {
@@ -5134,7 +4978,6 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
ndpDisp := ndpDispatcher{
dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1),
- rememberRouter: true,
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
@@ -5236,6 +5079,23 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
expectNoDHCPv6Event()
}
+var _ rand.Source = (*savingRandSource)(nil)
+
+type savingRandSource struct {
+ s rand.Source
+
+ lastInt63 int64
+}
+
+func (d *savingRandSource) Int63() int64 {
+ i := d.s.Int63()
+ d.lastInt63 = i
+ return i
+}
+func (d *savingRandSource) Seed(seed int64) {
+ d.s.Seed(seed)
+}
+
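savingRandSource wraps the stack's random source and remembers the last value it handed out, which lets TestRouterSolicitation recompute exactly which randomized solicitation delay the stack drew and advance the manual clock by that amount instead of waiting out the worst case. A condensed sketch of that derivation (it mirrors the logic added to the test below; maxRtrSolicitDelay comes from the test case):

func advanceThroughSolicitDelay(clock *faketime.ManualClock, rs *savingRandSource, maxRtrSolicitDelay time.Duration) {
	var actualDelay time.Duration
	if maxRtrSolicitDelay != 0 {
		// The delay is the last random draw reduced modulo the configured
		// maximum, matching the derivation the test below asserts against.
		actualDelay = time.Duration(rs.lastInt63) % maxRtrSolicitDelay
	}
	clock.Advance(actualDelay) // the first RS should now have been sent
}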
// TestRouterSolicitation tests the initial Router Solicitations that are sent
// when a NIC newly becomes enabled.
func TestRouterSolicitation(t *testing.T) {
@@ -5402,6 +5262,9 @@ func TestRouterSolicitation(t *testing.T) {
t.Fatalf("unexpectedly got a packet = %#v", p)
}
}
+ randSource := savingRandSource{
+ s: rand.NewSource(time.Now().UnixNano()),
+ }
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -5411,8 +5274,10 @@ func TestRouterSolicitation(t *testing.T) {
MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
},
})},
- Clock: clock,
+ Clock: clock,
+ RandSource: &randSource,
})
+
if err := s.CreateNIC(nicID, &e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -5425,19 +5290,27 @@ func TestRouterSolicitation(t *testing.T) {
// Make sure each RS is sent at the right time.
remaining := test.maxRtrSolicit
- if remaining > 0 {
- waitForPkt(test.effectiveMaxRtrSolicitDelay)
+ if remaining != 0 {
+ maxRtrSolicitDelay := test.maxRtrSolicitDelay
+ if maxRtrSolicitDelay < 0 {
+ maxRtrSolicitDelay = ipv6.DefaultNDPConfigurations().MaxRtrSolicitationDelay
+ }
+ var actualRtrSolicitDelay time.Duration
+ if maxRtrSolicitDelay != 0 {
+ actualRtrSolicitDelay = time.Duration(randSource.lastInt63) % maxRtrSolicitDelay
+ }
+ waitForPkt(actualRtrSolicitDelay)
remaining--
}
subTest.afterFirstRS(t, s)
- for ; remaining > 0; remaining-- {
- if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
+ for ; remaining != 0; remaining-- {
+ if test.effectiveRtrSolicitInt != 0 {
waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond)
waitForPkt(time.Nanosecond)
} else {
- waitForPkt(test.effectiveRtrSolicitInt)
+ waitForPkt(0)
}
}
@@ -5533,12 +5406,11 @@ func TestStopStartSolicitingRouters(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
- waitForPkt := func(timeout time.Duration) {
+ waitForPkt := func(clock *faketime.ManualClock, timeout time.Duration) {
t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- p, ok := e.ReadContext(ctx)
+ clock.Advance(timeout)
+ p, ok := e.Read()
if !ok {
t.Fatal("timed out waiting for packet")
}
@@ -5552,6 +5424,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
checker.TTL(header.NDPHopLimit),
checker.NDPRS())
}
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -5561,6 +5434,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
MaxRtrSolicitationDelay: delay,
},
})},
+ Clock: clock,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -5568,13 +5442,11 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Stop soliciting routers.
test.stopFn(t, s, true /* first */)
- ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
- defer cancel()
- if _, ok := e.ReadContext(ctx); ok {
+ clock.Advance(delay)
+ if _, ok := e.Read(); ok {
// A single RS may have been sent before solicitations were stopped.
- ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout)
- defer cancel()
- if _, ok = e.ReadContext(ctx); ok {
+ clock.Advance(interval)
+ if _, ok = e.Read(); ok {
t.Fatal("should not have sent more than one RS message")
}
}
@@ -5582,9 +5454,8 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Stopping router solicitations after it has already been stopped should
// do nothing.
test.stopFn(t, s, false /* first */)
- ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
- defer cancel()
- if _, ok := e.ReadContext(ctx); ok {
+ clock.Advance(delay)
+ if _, ok := e.Read(); ok {
t.Fatal("unexpectedly got a packet after router solicitation has been stopepd")
}
@@ -5595,21 +5466,19 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Start soliciting routers.
test.startFn(t, s)
- waitForPkt(delay + defaultAsyncPositiveEventTimeout)
- waitForPkt(interval + defaultAsyncPositiveEventTimeout)
- waitForPkt(interval + defaultAsyncPositiveEventTimeout)
- ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout)
- defer cancel()
- if _, ok := e.ReadContext(ctx); ok {
+ waitForPkt(clock, delay)
+ waitForPkt(clock, interval)
+ waitForPkt(clock, interval)
+ clock.Advance(interval)
+ if _, ok := e.Read(); ok {
t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
}
// Starting router solicitations after it has already completed should do
// nothing.
test.startFn(t, s)
- ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
- defer cancel()
- if _, ok := e.ReadContext(ctx); ok {
+ clock.Advance(interval)
+ if _, ok := e.Read(); ok {
t.Fatal("unexpectedly got a packet after finishing router solicitations")
}
})
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index 509f5ce5c..08857e1a9 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -310,7 +310,7 @@ func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) {
func (n *neighborCache) init(nic *nic, r LinkAddressResolver) {
*n = neighborCache{
nic: nic,
- state: NewNUDState(nic.stack.nudConfigs, nic.stack.randomGenerator),
+ state: NewNUDState(nic.stack.nudConfigs, nic.stack.clock, nic.stack.randomGenerator),
linkRes: r,
}
n.mu.Lock()
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index 9821a18d3..7de25fe37 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -15,8 +15,6 @@
package stack
import (
- "bytes"
- "encoding/binary"
"fmt"
"math"
"math/rand"
@@ -48,9 +46,6 @@ const (
// be sent to all nodes.
testEntryBroadcastAddr = tcpip.Address("broadcast")
- // testEntryLocalAddr is the source address of neighbor probes.
- testEntryLocalAddr = tcpip.Address("local_addr")
-
// testEntryBroadcastLinkAddr is a special link address sent back to
// multicast neighbor probes.
testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast")
@@ -95,7 +90,7 @@ func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, cl
randomGenerator: rng,
},
id: 1,
- stats: makeNICStats(),
+ stats: makeNICStats(tcpip.NICStats{}.FillIn()),
}, linkRes)
return linkRes
}
@@ -106,20 +101,24 @@ type testEntryStore struct {
entriesMap map[tcpip.Address]NeighborEntry
}
-func toAddress(i int) tcpip.Address {
- buf := new(bytes.Buffer)
- binary.Write(buf, binary.BigEndian, uint8(1))
- binary.Write(buf, binary.BigEndian, uint8(0))
- binary.Write(buf, binary.BigEndian, uint16(i))
- return tcpip.Address(buf.String())
+func toAddress(i uint16) tcpip.Address {
+ return tcpip.Address([]byte{
+ 1,
+ 0,
+ byte(i >> 8),
+ byte(i),
+ })
}
-func toLinkAddress(i int) tcpip.LinkAddress {
- buf := new(bytes.Buffer)
- binary.Write(buf, binary.BigEndian, uint8(1))
- binary.Write(buf, binary.BigEndian, uint8(0))
- binary.Write(buf, binary.BigEndian, uint32(i))
- return tcpip.LinkAddress(buf.String())
+func toLinkAddress(i uint16) tcpip.LinkAddress {
+ return tcpip.LinkAddress([]byte{
+ 1,
+ 0,
+ 0,
+ 0,
+ byte(i >> 8),
+ byte(i),
+ })
}
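The rewritten helpers build the test addresses directly from the index bytes instead of going through encoding/binary, so the mapping can be read off by eye. A small worked example:

func exampleTestAddrs() (tcpip.Address, tcpip.LinkAddress) {
	// Index 0x0102 yields the fixed 1, 0 prefix followed by the big-endian
	// index bytes:
	//   toAddress(0x0102)     == tcpip.Address("\x01\x00\x01\x02")
	//   toLinkAddress(0x0102) == tcpip.LinkAddress("\x01\x00\x00\x00\x01\x02")
	return toAddress(0x0102), toLinkAddress(0x0102)
}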
// newTestEntryStore returns a testEntryStore pre-populated with entries.
@@ -127,7 +126,7 @@ func newTestEntryStore() *testEntryStore {
store := &testEntryStore{
entriesMap: make(map[tcpip.Address]NeighborEntry),
}
- for i := 0; i < entryStoreSize; i++ {
+ for i := uint16(0); i < entryStoreSize; i++ {
addr := toAddress(i)
linkAddr := toLinkAddress(i)
@@ -140,15 +139,15 @@ func newTestEntryStore() *testEntryStore {
}
// size returns the number of entries in the store.
-func (s *testEntryStore) size() int {
+func (s *testEntryStore) size() uint16 {
s.mu.RLock()
defer s.mu.RUnlock()
- return len(s.entriesMap)
+ return uint16(len(s.entriesMap))
}
// entry returns the entry at index i. Returns an empty entry and false if i is
// out of bounds.
-func (s *testEntryStore) entry(i int) (NeighborEntry, bool) {
+func (s *testEntryStore) entry(i uint16) (NeighborEntry, bool) {
return s.entryByAddr(toAddress(i))
}
@@ -166,7 +165,7 @@ func (s *testEntryStore) entries() []NeighborEntry {
entries := make([]NeighborEntry, 0, len(s.entriesMap))
s.mu.RLock()
defer s.mu.RUnlock()
- for i := 0; i < entryStoreSize; i++ {
+ for i := uint16(0); i < entryStoreSize; i++ {
addr := toAddress(i)
if entry, ok := s.entriesMap[addr]; ok {
entries = append(entries, entry)
@@ -176,7 +175,7 @@ func (s *testEntryStore) entries() []NeighborEntry {
}
// set modifies the link addresses of an entry.
-func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) {
+func (s *testEntryStore) set(i uint16, linkAddr tcpip.LinkAddress) {
addr := toAddress(i)
s.mu.Lock()
defer s.mu.Unlock()
@@ -236,13 +235,6 @@ func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return 0
}
-type entryEvent struct {
- nicID tcpip.NICID
- address tcpip.Address
- linkAddr tcpip.LinkAddress
- state NeighborState
-}
-
func TestNeighborCacheGetConfig(t *testing.T) {
nudDisp := testNUDDispatcher{}
c := DefaultNUDConfigurations()
@@ -301,10 +293,10 @@ func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.Ma
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: removedEntry.Addr,
- LinkAddr: removedEntry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: removedEntry.Addr,
+ LinkAddr: removedEntry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
})
}
@@ -313,10 +305,10 @@ func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.Ma
EventType: entryTestAdded,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: "",
- State: Incomplete,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: "",
+ State: Incomplete,
+ UpdatedAt: clock.Now(),
},
})
@@ -347,10 +339,10 @@ func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.Ma
EventType: entryTestChanged,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -419,10 +411,10 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -461,7 +453,7 @@ func newTestContext(c NUDConfigurations) testContext {
}
type overflowOptions struct {
- startAtEntryIndex int
+ startAtEntryIndex uint16
wantStaticEntries []NeighborEntry
}
@@ -500,12 +492,12 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
if !ok {
return fmt.Errorf("got c.linkRes.entries.entry(%d) = _, false, want = true", i)
}
- durationReachableNanos := int64(c.linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds()
+ durationReachableNanos := time.Duration(c.linkRes.entries.size()-i-1) * typicalLatency
wantEntry := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: c.clock.NowNanoseconds() - durationReachableNanos,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: c.clock.Now().Add(-durationReachableNanos),
}
wantUnorderedEntries = append(wantUnorderedEntries, wantEntry)
}
@@ -571,10 +563,10 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -616,10 +608,10 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
EventType: entryTestAdded,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -663,10 +655,10 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
EventType: entryTestAdded,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -689,10 +681,10 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
EventType: entryTestChanged,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -733,10 +725,10 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
EventType: entryTestAdded,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -758,10 +750,10 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -814,20 +806,20 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: c.clock.Now(),
},
},
{
EventType: entryTestAdded,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -844,10 +836,10 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
startAtEntryIndex: 1,
wantStaticEntries: []NeighborEntry{
{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -875,10 +867,10 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
t.Errorf("unexpected error from c.linkRes.neigh.entry(%s, \"\", nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Static,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ UpdatedAt: c.clock.Now(),
}
if diff := cmp.Diff(want, e); diff != "" {
t.Errorf("c.linkRes.neigh.entry(%s, \"\", nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
@@ -890,10 +882,10 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
EventType: entryTestAdded,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Static,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -910,10 +902,10 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
startAtEntryIndex: 1,
wantStaticEntries: []NeighborEntry{
{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Static,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -947,10 +939,10 @@ func TestNeighborCacheClear(t *testing.T) {
EventType: entryTestAdded,
NICID: 1,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Static,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -973,20 +965,20 @@ func TestNeighborCacheClear(t *testing.T) {
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
{
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Static,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1027,10 +1019,10 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
EventType: entryTestRemoved,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: c.clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: c.clock.Now(),
},
},
}
@@ -1062,13 +1054,13 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
clock := faketime.NewManualClock()
linkRes := newTestNeighborResolver(&nudDisp, config, clock)
- startedAt := clock.NowNanoseconds()
+ startedAt := clock.Now()
// The following logic is very similar to overflowCache, but
// periodically refreshes the frequently used entry.
// Fill the neighbor cache to capacity
- for i := 0; i < neighborCacheSize; i++ {
+ for i := uint16(0); i < neighborCacheSize; i++ {
entry, ok := linkRes.entries.entry(i)
if !ok {
t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i)
@@ -1084,7 +1076,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
}
// Keep adding more entries
- for i := neighborCacheSize; i < linkRes.entries.size(); i++ {
+ for i := uint16(neighborCacheSize); i < linkRes.entries.size(); i++ {
// Periodically refresh the frequently used entry
if i%(neighborCacheSize/2) == 0 {
if _, _, err := linkRes.neigh.entry(frequentlyUsedEntry.Addr, "", nil); err != nil {
@@ -1118,7 +1110,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
State: Reachable,
// Can be inferred since the frequently used entry is the first to
// be created and transitioned to Reachable.
- UpdatedAtNanos: startedAt + typicalLatency.Nanoseconds(),
+ UpdatedAt: startedAt.Add(typicalLatency),
},
}
@@ -1127,12 +1119,12 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i)
}
- durationReachableNanos := int64(linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds()
+ durationReachableNanos := time.Duration(linkRes.entries.size()-i-1) * typicalLatency
wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds() - durationReachableNanos,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now().Add(-durationReachableNanos),
})
}
@@ -1190,12 +1182,12 @@ func TestNeighborCacheConcurrent(t *testing.T) {
if !ok {
t.Errorf("got linkRes.entries.entry(%d) = _, false, want = true", i)
}
- durationReachableNanos := int64(linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds()
+ durationReachableNanos := time.Duration(linkRes.entries.size()-i-1) * typicalLatency
wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds() - durationReachableNanos,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now().Add(-durationReachableNanos),
})
}
@@ -1244,10 +1236,10 @@ func TestNeighborCacheReplace(t *testing.T) {
t.Fatalf("linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: updatedLinkAddr,
- State: Delay,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: updatedLinkAddr,
+ State: Delay,
+ UpdatedAt: clock.Now(),
}
if diff := cmp.Diff(want, e); diff != "" {
t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
@@ -1263,10 +1255,10 @@ func TestNeighborCacheReplace(t *testing.T) {
t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: updatedLinkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: updatedLinkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
}
if diff := cmp.Diff(want, e); diff != "" {
t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
@@ -1301,10 +1293,10 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
@@ -1405,10 +1397,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
EventType: entryTestAdded,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: "",
- State: Incomplete,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: "",
+ State: Incomplete,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1436,10 +1428,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
EventType: entryTestChanged,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: "",
- State: Unreachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: "",
+ State: Unreachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1455,10 +1447,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
{
wantEntries := []NeighborEntry{
{
- Addr: entry.Addr,
- LinkAddr: "",
- State: Unreachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: "",
+ State: Unreachable,
+ UpdatedAt: clock.Now(),
},
}
if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, unorderedEntriesDiffOpts()...); diff != "" {
@@ -1488,10 +1480,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
EventType: entryTestChanged,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: "",
- State: Incomplete,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: "",
+ State: Incomplete,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1518,10 +1510,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
EventType: entryTestChanged,
NICID: 1,
Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1541,10 +1533,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
}
wantEntry := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
}
if diff := cmp.Diff(gotEntry, wantEntry); diff != "" {
t.Fatalf("neighbor entry mismatch (-got, +want):\n%s", diff)
@@ -1561,9 +1553,9 @@ func BenchmarkCacheClear(b *testing.B) {
linkRes.delay = 0
// Clear for every possible size of the cache
- for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ {
+ for cacheSize := uint16(0); cacheSize < neighborCacheSize; cacheSize++ {
// Fill the neighbor cache to capacity.
- for i := 0; i < cacheSize; i++ {
+ for i := uint16(0); i < cacheSize; i++ {
entry, ok := linkRes.entries.entry(i)
if !ok {
b.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i)
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 6d95e1664..0a59eecdd 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -31,10 +31,10 @@ const (
// NeighborEntry describes a neighboring device in the local network.
type NeighborEntry struct {
- Addr tcpip.Address
- LinkAddr tcpip.LinkAddress
- State NeighborState
- UpdatedAtNanos int64
+ Addr tcpip.Address
+ LinkAddr tcpip.LinkAddress
+ State NeighborState
+ UpdatedAt time.Time
}
// NeighborState defines the state of a NeighborEntry within the Neighbor
@@ -138,10 +138,10 @@ func newNeighborEntry(cache *neighborCache, remoteAddr tcpip.Address, nudState *
// calling `setStateLocked`.
func newStaticNeighborEntry(cache *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry {
entry := NeighborEntry{
- Addr: addr,
- LinkAddr: linkAddr,
- State: Static,
- UpdatedAtNanos: cache.nic.stack.clock.NowNanoseconds(),
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: Static,
+ UpdatedAt: cache.nic.stack.clock.Now(),
}
n := &neighborEntry{
cache: cache,
@@ -166,14 +166,20 @@ func (e *neighborEntry) notifyCompletionLocked(err tcpip.Error) {
if ch := e.mu.done; ch != nil {
close(ch)
e.mu.done = nil
- // Dequeue the pending packets in a new goroutine to not hold up the current
+ // Dequeue the pending packets asynchronously to not hold up the current
// goroutine as writing packets may be a costly operation.
//
// At the time of writing, when writing packets, a neighbor's link address
// is resolved (which ends up obtaining the entry's lock) while holding the
- // link resolution queue's lock. Dequeuing packets in a new goroutine avoids
- // a lock ordering violation.
- go e.cache.nic.linkResQueue.dequeue(ch, e.mu.neigh.LinkAddr, err)
+ // link resolution queue's lock. Dequeuing packets asynchronously avoids a
+ // lock ordering violation.
+ //
+ // NB: this is equivalent to spawning a goroutine directly using the go
+ // keyword but allows tests that use manual clocks to deterministically
+ // wait for this work to complete.
+ e.cache.nic.stack.clock.AfterFunc(0, func() {
+ e.cache.nic.linkResQueue.dequeue(ch, e.mu.neigh.LinkAddr, err)
+ })
}
}
@@ -224,7 +230,7 @@ func (e *neighborEntry) cancelTimerLocked() {
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) removeLocked() {
- e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds()
+ e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.Now()
e.dispatchRemoveEventLocked()
e.cancelTimerLocked()
// TODO(https://gvisor.dev/issues/5583): test the case where this function is
@@ -246,7 +252,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
prev := e.mu.neigh.State
e.mu.neigh.State = next
- e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds()
+ e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.Now()
config := e.nudState.Config()
switch next {
@@ -307,7 +313,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
// a shared lock.
e.mu.timer = timer{
done: &done,
- timer: e.cache.nic.stack.Clock().AfterFunc(0, func() {
+ timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() {
var err tcpip.Error = &tcpip.ErrTimeout{}
if remaining != 0 {
err = e.cache.linkRes.LinkAddressRequest(addr, "" /* localAddr */, linkAddr)
@@ -354,14 +360,14 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
case Unknown, Unreachable:
prev := e.mu.neigh.State
e.mu.neigh.State = Incomplete
- e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds()
+ e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.Now()
switch prev {
case Unknown:
e.dispatchAddEventLocked()
case Unreachable:
e.dispatchChangeEventLocked()
- e.cache.nic.stats.Neighbor.UnreachableEntryLookups.Increment()
+ e.cache.nic.stats.neighbor.unreachableEntryLookups.Increment()
}
config := e.nudState.Config()
@@ -378,7 +384,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
// a shared lock.
e.mu.timer = timer{
done: &done,
- timer: e.cache.nic.stack.Clock().AfterFunc(0, func() {
+ timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() {
var err tcpip.Error = &tcpip.ErrTimeout{}
if remaining != 0 {
// As per RFC 4861 section 7.2.2:
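
For context on the change above: scheduling the dequeue through the stack clock's AfterFunc with a zero delay behaves like spawning a goroutine in production, but a test that drives a manual clock can run the scheduled callback on demand and know when the work has finished. A minimal sketch of that pattern, using a made-up Clock interface rather than gVisor's tcpip.Clock:

package main

import (
	"fmt"
	"time"
)

// Clock abstracts time so tests can substitute a manual implementation.
// This interface is a simplified stand-in, not the real tcpip.Clock API.
type Clock interface {
	AfterFunc(d time.Duration, f func())
}

// realClock schedules work on the wall clock.
type realClock struct{}

func (realClock) AfterFunc(d time.Duration, f func()) { time.AfterFunc(d, f) }

// manualClock records callbacks and runs them only when the test asks.
type manualClock struct{ pending []func() }

func (c *manualClock) AfterFunc(_ time.Duration, f func()) { c.pending = append(c.pending, f) }

// runAll runs every scheduled callback; a test calls this to deterministically
// wait for "asynchronous" work to complete.
func (c *manualClock) runAll() {
	for _, f := range c.pending {
		f()
	}
	c.pending = nil
}

func dequeuePackets(clock Clock) {
	// Equivalent to `go dequeue()` in production, but observable in tests.
	clock.AfterFunc(0, func() { fmt.Println("packets dequeued") })
}

func main() {
	var clock manualClock
	dequeuePackets(&clock)
	clock.runAll() // The test can now assert that the dequeue happened.
}

In this sketch, runAll plays the role that advancing the fake clock plays in the stack's tests.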
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index 1d39ee73d..59d86d6d4 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -36,11 +36,6 @@ const (
entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01")
entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02")
-
- // entryTestNetDefaultMTU is the MTU, in bytes, used throughout the tests,
- // except where another value is explicitly used. It is chosen to match the
- // MTU of loopback interfaces on Linux systems.
- entryTestNetDefaultMTU = 65536
)
var (
@@ -196,13 +191,13 @@ func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.A
// ResolveStaticAddress attempts to resolve address without sending requests.
// It either resolves the name immediately or returns the empty LinkAddress.
-func (r *entryTestLinkResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+func (*entryTestLinkResolver) ResolveStaticAddress(tcpip.Address) (tcpip.LinkAddress, bool) {
return "", false
}
// LinkAddressProtocol returns the network protocol of the addresses this
// resolver can resolve.
-func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+func (*entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return entryTestNetNumber
}
@@ -219,7 +214,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
nudConfigs: c,
randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())),
},
- stats: makeNICStats(),
+ stats: makeNICStats(tcpip.NICStats{}.FillIn()),
}
netEP := (&testIPv6Protocol{}).NewEndpoint(&nic, nil)
nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{
@@ -354,10 +349,10 @@ func unknownToIncomplete(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *
EventType: entryTestAdded,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -415,10 +410,10 @@ func unknownToStale(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entry
EventType: entryTestAdded,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -446,7 +441,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
// UpdatedAt should remain the same during address resolution.
e.mu.Lock()
- startedAt := e.mu.neigh.UpdatedAtNanos
+ startedAt := e.mu.neigh.UpdatedAt
e.mu.Unlock()
// Wait for the rest of the reachability probe transmissions, signifying
@@ -470,7 +465,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.mu.neigh.UpdatedAtNanos, startedAt; got != want {
+ if got, want := e.mu.neigh.UpdatedAt, startedAt; got != want {
t.Errorf("got e.mu.neigh.UpdatedAt = %q, want = %q", got, want)
}
e.mu.Unlock()
@@ -485,10 +480,10 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Unreachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Unreachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -547,10 +542,10 @@ func incompleteToReachableWithFlags(e *neighborEntry, nudDisp *testNUDDispatcher
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -644,10 +639,10 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -678,10 +673,10 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -757,10 +752,10 @@ func incompleteToUnreachable(c NUDConfigurations, e *neighborEntry, nudDisp *tes
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Unreachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Unreachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -943,10 +938,10 @@ func reachableToStale(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDis
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -998,10 +993,10 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1050,10 +1045,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T)
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1102,10 +1097,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1191,10 +1186,10 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1243,10 +1238,10 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1284,10 +1279,10 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1332,10 +1327,10 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1391,10 +1386,10 @@ func staleToDelay(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTe
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1443,10 +1438,10 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1498,10 +1493,10 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1553,10 +1548,10 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1645,10 +1640,10 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1697,10 +1692,10 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1770,10 +1765,10 @@ func delayToProbe(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDispatc
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1827,10 +1822,10 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -1882,10 +1877,10 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -2003,10 +1998,10 @@ func probeToReachableWithFlags(e *neighborEntry, nudDisp *testNUDDispatcher, lin
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: linkAddr,
- State: Reachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: linkAddr,
+ State: Reachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -2155,10 +2150,10 @@ func probeToUnreachable(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDD
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Unreachable,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Unreachable,
+ UpdatedAt: clock.Now(),
},
},
}
@@ -2227,10 +2222,10 @@ func unreachableToIncomplete(e *neighborEntry, nudDisp *testNUDDispatcher, linkR
EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
- UpdatedAtNanos: clock.NowNanoseconds(),
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ UpdatedAt: clock.Now(),
},
},
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index dbba2c79f..b854d868c 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -51,7 +51,7 @@ type nic struct {
name string
context NICContext
- stats NICStats
+ stats sharedStats
// The network endpoints themselves may be modified by calling the interface's
// methods, but the map reference and entries must be constant.
@@ -78,26 +78,13 @@ type nic struct {
}
}
-// NICStats hold statistics for a NIC.
-type NICStats struct {
- Tx DirectionStats
- Rx DirectionStats
-
- DisabledRx DirectionStats
-
- Neighbor NeighborStats
-}
-
-func makeNICStats() NICStats {
- var s NICStats
- tcpip.InitStatCounters(reflect.ValueOf(&s).Elem())
- return s
-}
-
-// DirectionStats includes packet and byte counts.
-type DirectionStats struct {
- Packets *tcpip.StatCounter
- Bytes *tcpip.StatCounter
+// makeNICStats initializes the NIC statistics and associates them with the global
+// NIC statistics.
+func makeNICStats(global tcpip.NICStats) sharedStats {
+ var stats sharedStats
+ tcpip.InitStatCounters(reflect.ValueOf(&stats.local).Elem())
+ stats.init(&stats.local, &global)
+ return stats
}
type packetEndpointList struct {
@@ -150,7 +137,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
id: id,
name: name,
context: ctx,
- stats: makeNICStats(),
+ stats: makeNICStats(stack.Stats().NICs),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]*linkResolver),
duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector),
@@ -382,8 +369,8 @@ func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt
return err
}
- n.stats.Tx.Packets.Increment()
- n.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
+ n.stats.tx.packets.Increment()
+ n.stats.tx.bytes.IncrementBy(uint64(numBytes))
return nil
}
@@ -399,13 +386,13 @@ func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pk
}
writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol)
- n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets))
+ n.stats.tx.packets.IncrementBy(uint64(writtenPackets))
writtenBytes := 0
for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() {
writtenBytes += pb.Size()
}
- n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
+ n.stats.tx.bytes.IncrementBy(uint64(writtenBytes))
return writtenPackets, err
}
@@ -718,18 +705,18 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
if !enabled {
n.mu.RUnlock()
- n.stats.DisabledRx.Packets.Increment()
- n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data().Size()))
+ n.stats.disabledRx.packets.Increment()
+ n.stats.disabledRx.bytes.IncrementBy(uint64(pkt.Data().Size()))
return
}
- n.stats.Rx.Packets.Increment()
- n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data().Size()))
+ n.stats.rx.packets.Increment()
+ n.stats.rx.bytes.IncrementBy(uint64(pkt.Data().Size()))
networkEndpoint, ok := n.networkEndpoints[protocol]
if !ok {
n.mu.RUnlock()
- n.stack.stats.UnknownProtocolRcvdPackets.Increment()
+ n.stats.unknownL3ProtocolRcvdPackets.Increment()
return
}
@@ -786,41 +773,35 @@ func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
- n.stack.stats.UnknownProtocolRcvdPackets.Increment()
+ n.stats.unknownL4ProtocolRcvdPackets.Increment()
return TransportPacketProtocolUnreachable
}
transProto := state.proto
- // Raw socket packets are delivered based solely on the transport
- // protocol number. We do not inspect the payload to ensure it's
- // validly formed.
- n.stack.demux.deliverRawPacket(protocol, pkt)
-
// TransportHeader is empty only when pkt is an ICMP packet or was reassembled
// from fragments.
if pkt.TransportHeader().View().IsEmpty() {
- // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
- // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
- // full explanation.
+ // ICMP packets don't have their TransportHeader fields set yet, parse it
+ // here. See icmp/protocol.go:protocol.Parse for a full explanation.
if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
// ICMP packets may be longer, but until icmp.Parse is implemented, here
// we parse it using the minimum size.
if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok {
- n.stack.stats.MalformedRcvdPackets.Increment()
+ n.stats.malformedL4RcvdPackets.Increment()
// We consider a malformed transport packet handled because there is
// nothing the caller can do.
return TransportPacketHandled
}
} else if !transProto.Parse(pkt) {
- n.stack.stats.MalformedRcvdPackets.Increment()
+ n.stats.malformedL4RcvdPackets.Increment()
return TransportPacketHandled
}
}
srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View())
if err != nil {
- n.stack.stats.MalformedRcvdPackets.Increment()
+ n.stats.malformedL4RcvdPackets.Increment()
return TransportPacketHandled
}
@@ -852,7 +833,7 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt
// If it doesn't handle it then we should do so.
switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res {
case UnknownDestinationPacketMalformed:
- n.stack.stats.MalformedRcvdPackets.Increment()
+ n.stats.malformedL4RcvdPackets.Increment()
return TransportPacketHandled
case UnknownDestinationPacketUnhandled:
return TransportPacketDestinationPortUnreachable
@@ -891,6 +872,17 @@ func (n *nic) DeliverTransportError(local, remote tcpip.Address, net tcpip.Netwo
}
}
+// DeliverRawPacket implements TransportDispatcher.
+func (n *nic) DeliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) {
+ // For ICMPv4 only we validate the header length for compatibility with
+ // raw(7) ICMP_FILTER. The same check is made in Linux here:
+ // https://github.com/torvalds/linux/blob/70585216/net/ipv4/raw.c#L189.
+ if protocol == header.ICMPv4ProtocolNumber && pkt.TransportHeader().View().Size()+pkt.Data().Size() < header.ICMPv4MinimumSize {
+ return
+ }
+ n.stack.demux.deliverRawPacket(protocol, pkt)
+}
+
// ID implements NetworkInterface.
func (n *nic) ID() tcpip.NICID {
return n.id
diff --git a/pkg/tcpip/stack/nic_stats.go b/pkg/tcpip/stack/nic_stats.go
new file mode 100644
index 000000000..1773d5e8d
--- /dev/null
+++ b/pkg/tcpip/stack/nic_stats.go
@@ -0,0 +1,74 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+type sharedStats struct {
+ local tcpip.NICStats
+ multiCounterNICStats
+}
+
+// LINT.IfChange(multiCounterNICPacketStats)
+
+type multiCounterNICPacketStats struct {
+ packets tcpip.MultiCounterStat
+ bytes tcpip.MultiCounterStat
+}
+
+func (m *multiCounterNICPacketStats) init(a, b *tcpip.NICPacketStats) {
+ m.packets.Init(a.Packets, b.Packets)
+ m.bytes.Init(a.Bytes, b.Bytes)
+}
+
+// LINT.ThenChange(../../tcpip.go:NICPacketStats)
+
+// LINT.IfChange(multiCounterNICNeighborStats)
+
+type multiCounterNICNeighborStats struct {
+ unreachableEntryLookups tcpip.MultiCounterStat
+}
+
+func (m *multiCounterNICNeighborStats) init(a, b *tcpip.NICNeighborStats) {
+ m.unreachableEntryLookups.Init(a.UnreachableEntryLookups, b.UnreachableEntryLookups)
+}
+
+// LINT.ThenChange(../../tcpip.go:NICNeighborStats)
+
+// LINT.IfChange(multiCounterNICStats)
+
+type multiCounterNICStats struct {
+ unknownL3ProtocolRcvdPackets tcpip.MultiCounterStat
+ unknownL4ProtocolRcvdPackets tcpip.MultiCounterStat
+ malformedL4RcvdPackets tcpip.MultiCounterStat
+ tx multiCounterNICPacketStats
+ rx multiCounterNICPacketStats
+ disabledRx multiCounterNICPacketStats
+ neighbor multiCounterNICNeighborStats
+}
+
+func (m *multiCounterNICStats) init(a, b *tcpip.NICStats) {
+ m.unknownL3ProtocolRcvdPackets.Init(a.UnknownL3ProtocolRcvdPackets, b.UnknownL3ProtocolRcvdPackets)
+ m.unknownL4ProtocolRcvdPackets.Init(a.UnknownL4ProtocolRcvdPackets, b.UnknownL4ProtocolRcvdPackets)
+ m.malformedL4RcvdPackets.Init(a.MalformedL4RcvdPackets, b.MalformedL4RcvdPackets)
+ m.tx.init(&a.Tx, &b.Tx)
+ m.rx.init(&a.Rx, &b.Rx)
+ m.disabledRx.init(&a.DisabledRx, &b.DisabledRx)
+ m.neighbor.init(&a.Neighbor, &b.Neighbor)
+}
+
+// LINT.ThenChange(../../tcpip.go:NICStats)
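
The new nic_stats.go follows the multi-counter pattern used elsewhere in netstack: each NIC-local counter is paired with the stack-wide aggregate so a single Increment updates both. A rough sketch of that fan-out with simplified stand-in types (StatCounter and MultiCounterStat here are illustrative, not the tcpip package's real implementations):

package main

import (
	"fmt"
	"sync/atomic"
)

// StatCounter is a single monotonically increasing counter.
type StatCounter struct{ v uint64 }

func (s *StatCounter) Increment()    { atomic.AddUint64(&s.v, 1) }
func (s *StatCounter) Value() uint64 { return atomic.LoadUint64(&s.v) }

// MultiCounterStat increments several underlying counters at once, which is
// how a NIC-local stat and the global NIC stat stay in sync.
type MultiCounterStat struct{ counters []*StatCounter }

func (m *MultiCounterStat) Init(cs ...*StatCounter) { m.counters = cs }

func (m *MultiCounterStat) Increment() {
	for _, c := range m.counters {
		c.Increment()
	}
}

func main() {
	var local, global StatCounter
	var rxPackets MultiCounterStat
	rxPackets.Init(&local, &global)

	rxPackets.Increment()
	fmt.Println(local.Value(), global.Value()) // 1 1
}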
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 8a3005295..5cb342f78 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -15,11 +15,13 @@
package stack
import (
+ "reflect"
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
var _ AddressableEndpoint = (*testIPv6Endpoint)(nil)
@@ -171,19 +173,19 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
// When the NIC is disabled, the only field that matters is the stats field.
// This test is limited to stats counter checks.
nic := nic{
- stats: makeNICStats(),
+ stats: makeNICStats(tcpip.NICStats{}.FillIn()),
}
- if got := nic.stats.DisabledRx.Packets.Value(); got != 0 {
+ if got := nic.stats.local.DisabledRx.Packets.Value(); got != 0 {
t.Errorf("got DisabledRx.Packets = %d, want = 0", got)
}
- if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 {
+ if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 0 {
t.Errorf("got DisabledRx.Bytes = %d, want = 0", got)
}
- if got := nic.stats.Rx.Packets.Value(); got != 0 {
+ if got := nic.stats.local.Rx.Packets.Value(); got != 0 {
t.Errorf("got Rx.Packets = %d, want = 0", got)
}
- if got := nic.stats.Rx.Bytes.Value(); got != 0 {
+ if got := nic.stats.local.Rx.Bytes.Value(); got != 0 {
t.Errorf("got Rx.Bytes = %d, want = 0", got)
}
@@ -195,16 +197,28 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(),
}))
- if got := nic.stats.DisabledRx.Packets.Value(); got != 1 {
+ if got := nic.stats.local.DisabledRx.Packets.Value(); got != 1 {
t.Errorf("got DisabledRx.Packets = %d, want = 1", got)
}
- if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 {
+ if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 4 {
t.Errorf("got DisabledRx.Bytes = %d, want = 4", got)
}
- if got := nic.stats.Rx.Packets.Value(); got != 0 {
+ if got := nic.stats.local.Rx.Packets.Value(); got != 0 {
t.Errorf("got Rx.Packets = %d, want = 0", got)
}
- if got := nic.stats.Rx.Bytes.Value(); got != 0 {
+ if got := nic.stats.local.Rx.Bytes.Value(); got != 0 {
t.Errorf("got Rx.Bytes = %d, want = 0", got)
}
}
+
+func TestMultiCounterStatsInitialization(t *testing.T) {
+ global := tcpip.NICStats{}.FillIn()
+ nic := nic{
+ stats: makeNICStats(global),
+ }
+ multi := nic.stats.multiCounterNICStats
+ local := nic.stats.local
+ if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&multi).Elem(), []reflect.Value{reflect.ValueOf(&local).Elem(), reflect.ValueOf(&global).Elem()}); err != nil {
+ t.Error(err)
+ }
+}
diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go
index 5a94e9ac6..ca9822bca 100644
--- a/pkg/tcpip/stack/nud.go
+++ b/pkg/tcpip/stack/nud.go
@@ -16,6 +16,7 @@ package stack
import (
"math"
+ "math/rand"
"sync"
"time"
@@ -313,45 +314,36 @@ func calcMaxRandomFactor(minRandomFactor float32) float32 {
return defaultMaxRandomFactor
}
-// A Rand is a source of random numbers.
-type Rand interface {
- // Float32 returns, as a float32, a pseudo-random number in [0.0,1.0).
- Float32() float32
-}
-
// NUDState stores states needed for calculating reachable time.
type NUDState struct {
- rng Rand
+ clock tcpip.Clock
+ rng *rand.Rand
- // mu protects the fields below.
- //
- // It is necessary for NUDState to handle its own locking since neighbor
- // entries may access the NUD state from within the goroutine spawned by
- // time.AfterFunc(). This goroutine may run concurrently with the main
- // process for controlling the neighbor cache and would otherwise introduce
- // race conditions if NUDState was not locked properly.
- mu sync.RWMutex
+ mu struct {
+ sync.RWMutex
- config NUDConfigurations
+ config NUDConfigurations
- // reachableTime is the duration to wait for a REACHABLE entry to
- // transition into STALE after inactivity. This value is calculated with
- // the algorithm defined in RFC 4861 section 6.3.2.
- reachableTime time.Duration
+ // reachableTime is the duration to wait for a REACHABLE entry to
+ // transition into STALE after inactivity. This value is calculated with
+ // the algorithm defined in RFC 4861 section 6.3.2.
+ reachableTime time.Duration
- expiration time.Time
- prevBaseReachableTime time.Duration
- prevMinRandomFactor float32
- prevMaxRandomFactor float32
+ expiration time.Time
+ prevBaseReachableTime time.Duration
+ prevMinRandomFactor float32
+ prevMaxRandomFactor float32
+ }
}
// NewNUDState returns new NUDState using c as configuration and the specified
// random number generator for use in recomputing ReachableTime.
-func NewNUDState(c NUDConfigurations, rng Rand) *NUDState {
+func NewNUDState(c NUDConfigurations, clock tcpip.Clock, rng *rand.Rand) *NUDState {
s := &NUDState{
- rng: rng,
+ clock: clock,
+ rng: rng,
}
- s.config = c
+ s.mu.config = c
return s
}
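
One stylistic point in the NUDState rewrite above: every field guarded by the RWMutex now lives inside an embedded mu struct, so guarded accesses read as s.mu.config and the locking contract is visible at each call site. A tiny generic illustration of the idiom (field names here are arbitrary, not the NUD fields):

package main

import (
	"fmt"
	"sync"
	"time"
)

// state keeps all mutex-guarded fields nested under mu, so every read or
// write of a guarded field is spelled s.mu.<field> and stands out in review.
type state struct {
	mu struct {
		sync.RWMutex

		config        string
		reachableTime time.Duration
	}
}

func (s *state) setConfig(c string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.mu.config = c
}

func (s *state) configSnapshot() string {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.mu.config
}

func main() {
	var s state
	s.setConfig("example")
	fmt.Println(s.configSnapshot())
}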
@@ -359,14 +351,14 @@ func NewNUDState(c NUDConfigurations, rng Rand) *NUDState {
func (s *NUDState) Config() NUDConfigurations {
s.mu.RLock()
defer s.mu.RUnlock()
- return s.config
+ return s.mu.config
}
// SetConfig replaces the existing NUD configurations with c.
func (s *NUDState) SetConfig(c NUDConfigurations) {
s.mu.Lock()
defer s.mu.Unlock()
- s.config = c
+ s.mu.config = c
}
// ReachableTime returns the duration to wait for a REACHABLE entry to
@@ -377,13 +369,13 @@ func (s *NUDState) ReachableTime() time.Duration {
s.mu.Lock()
defer s.mu.Unlock()
- if time.Now().After(s.expiration) ||
- s.config.BaseReachableTime != s.prevBaseReachableTime ||
- s.config.MinRandomFactor != s.prevMinRandomFactor ||
- s.config.MaxRandomFactor != s.prevMaxRandomFactor {
+ if s.clock.Now().After(s.mu.expiration) ||
+ s.mu.config.BaseReachableTime != s.mu.prevBaseReachableTime ||
+ s.mu.config.MinRandomFactor != s.mu.prevMinRandomFactor ||
+ s.mu.config.MaxRandomFactor != s.mu.prevMaxRandomFactor {
s.recomputeReachableTimeLocked()
}
- return s.reachableTime
+ return s.mu.reachableTime
}
// recomputeReachableTimeLocked forces a recalculation of ReachableTime using
@@ -408,23 +400,23 @@ func (s *NUDState) ReachableTime() time.Duration {
//
// s.mu MUST be locked for writing.
func (s *NUDState) recomputeReachableTimeLocked() {
- s.prevBaseReachableTime = s.config.BaseReachableTime
- s.prevMinRandomFactor = s.config.MinRandomFactor
- s.prevMaxRandomFactor = s.config.MaxRandomFactor
+ s.mu.prevBaseReachableTime = s.mu.config.BaseReachableTime
+ s.mu.prevMinRandomFactor = s.mu.config.MinRandomFactor
+ s.mu.prevMaxRandomFactor = s.mu.config.MaxRandomFactor
- randomFactor := s.config.MinRandomFactor + s.rng.Float32()*(s.config.MaxRandomFactor-s.config.MinRandomFactor)
+ randomFactor := s.mu.config.MinRandomFactor + s.rng.Float32()*(s.mu.config.MaxRandomFactor-s.mu.config.MinRandomFactor)
// Check for overflow, given that minRandomFactor and maxRandomFactor are
// guaranteed to be positive numbers.
- if float32(math.MaxInt64)/randomFactor < float32(s.config.BaseReachableTime) {
- s.reachableTime = time.Duration(math.MaxInt64)
+ if math.MaxInt64/randomFactor < float32(s.mu.config.BaseReachableTime) {
+ s.mu.reachableTime = time.Duration(math.MaxInt64)
} else if randomFactor == 1 {
// Avoid loss of precision when a large base reachable time is used.
- s.reachableTime = s.config.BaseReachableTime
+ s.mu.reachableTime = s.mu.config.BaseReachableTime
} else {
- reachableTime := int64(float32(s.config.BaseReachableTime) * randomFactor)
- s.reachableTime = time.Duration(reachableTime)
+ reachableTime := int64(float32(s.mu.config.BaseReachableTime) * randomFactor)
+ s.mu.reachableTime = time.Duration(reachableTime)
}
- s.expiration = time.Now().Add(2 * time.Hour)
+ s.mu.expiration = s.clock.Now().Add(2 * time.Hour)
}
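
For reference, the recomputation kept above implements RFC 4861 section 6.3.2: ReachableTime is BaseReachableTime scaled by a uniform random factor drawn from [MinRandomFactor, MaxRandomFactor), with a saturation check so the product cannot overflow a time.Duration. A standalone sketch of that calculation (plain math/rand in place of the stack's injected RNG):

package main

import (
	"fmt"
	"math"
	"math/rand"
	"time"
)

// reachableTime computes the REACHABLE->STALE timeout per RFC 4861 6.3.2.
func reachableTime(base time.Duration, minFactor, maxFactor float32, rng *rand.Rand) time.Duration {
	factor := minFactor + rng.Float32()*(maxFactor-minFactor)
	switch {
	case math.MaxInt64/factor < float32(base):
		// The product would overflow int64 nanoseconds; saturate.
		return time.Duration(math.MaxInt64)
	case factor == 1:
		// Avoid loss of precision when no scaling is needed.
		return base
	default:
		return time.Duration(int64(float32(base) * factor))
	}
}

func main() {
	rng := rand.New(rand.NewSource(42))
	// Defaults from RFC 4861: base 30s, factor in [0.5, 1.5).
	fmt.Println(reachableTime(30*time.Second, 0.5, 1.5, rng))
}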
diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go
index e1253f310..1aeb2f8a5 100644
--- a/pkg/tcpip/stack/nud_test.go
+++ b/pkg/tcpip/stack/nud_test.go
@@ -16,6 +16,7 @@ package stack_test
import (
"math"
+ "math/rand"
"testing"
"time"
@@ -28,17 +29,15 @@ import (
)
const (
- defaultBaseReachableTime = 30 * time.Second
- minimumBaseReachableTime = time.Millisecond
- defaultMinRandomFactor = 0.5
- defaultMaxRandomFactor = 1.5
- defaultRetransmitTimer = time.Second
- minimumRetransmitTimer = time.Millisecond
- defaultDelayFirstProbeTime = 5 * time.Second
- defaultMaxMulticastProbes = 3
- defaultMaxUnicastProbes = 3
- defaultMaxAnycastDelayTime = time.Second
- defaultMaxReachbilityConfirmations = 3
+ defaultBaseReachableTime = 30 * time.Second
+ minimumBaseReachableTime = time.Millisecond
+ defaultMinRandomFactor = 0.5
+ defaultMaxRandomFactor = 1.5
+ defaultRetransmitTimer = time.Second
+ minimumRetransmitTimer = time.Millisecond
+ defaultDelayFirstProbeTime = 5 * time.Second
+ defaultMaxMulticastProbes = 3
+ defaultMaxUnicastProbes = 3
defaultFakeRandomNum = 0.5
)
@@ -48,12 +47,14 @@ type fakeRand struct {
num float32
}
-var _ stack.Rand = (*fakeRand)(nil)
+var _ rand.Source = (*fakeRand)(nil)
-func (f *fakeRand) Float32() float32 {
- return f.num
+func (f *fakeRand) Int63() int64 {
+ return int64(f.num * float32(1<<63))
}
+func (*fakeRand) Seed(int64) {}
+
func TestNUDFunctions(t *testing.T) {
const nicID = 1
@@ -169,7 +170,7 @@ func TestNUDFunctions(t *testing.T) {
t.Errorf("s.Neigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff)
} else if test.expectedErr == nil {
if diff := cmp.Diff(
- []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}},
+ []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAt: clock.Now()}},
neighbors,
); diff != "" {
t.Errorf("neighbors mismatch (-want +got):\n%s", diff)
@@ -710,7 +711,8 @@ func TestNUDStateReachableTime(t *testing.T) {
rng := fakeRand{
num: defaultFakeRandomNum,
}
- s := stack.NewNUDState(c, &rng)
+ var clock faketime.NullClock
+ s := stack.NewNUDState(c, &clock, rand.New(&rng))
if got, want := s.ReachableTime(), test.want; got != want {
t.Errorf("got ReachableTime = %q, want = %q", got, want)
}
@@ -782,7 +784,8 @@ func TestNUDStateRecomputeReachableTime(t *testing.T) {
rng := fakeRand{
num: defaultFakeRandomNum,
}
- s := stack.NewNUDState(c, &rng)
+ var clock faketime.NullClock
+ s := stack.NewNUDState(c, &clock, rand.New(&rng))
old := s.ReachableTime()
if got, want := s.ReachableTime(), old; got != want {
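
The fakeRand change works because math/rand derives Float32 from Int63: Float64 is approximately Int63()/2^63, so a source returning num*2^63 makes rand.New(source).Float32() come out close to num. A quick illustration with a hypothetical constant source (not the test's exact fakeRand type):

package main

import (
	"fmt"
	"math/rand"
)

// constSource is a rand.Source that always yields the same value, so
// rand.New(constSource{...}).Float32() is deterministic.
type constSource struct{ num float32 }

func (s constSource) Int63() int64 { return int64(s.num * float32(1<<63)) }
func (constSource) Seed(int64)     {}

func main() {
	rng := rand.New(constSource{num: 0.5})
	fmt.Println(rng.Float32()) // ~0.5
}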
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 01652fbe7..9192d8433 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -134,7 +134,7 @@ type PacketBuffer struct {
// https://www.man7.org/linux/man-pages/man7/packet.7.html.
PktType tcpip.PacketType
- // NICID is the ID of the interface the network packet was received at.
+ // NICID is the ID of the last interface the network packet was handled at.
NICID tcpip.NICID
// RXTransportChecksumValidated indicates that transport checksum verification
@@ -245,10 +245,10 @@ func (pk *PacketBuffer) dataOffset() int {
func (pk *PacketBuffer) push(typ headerType, size int) tcpipbuffer.View {
h := &pk.headers[typ]
if h.length > 0 {
- panic(fmt.Sprintf("push must not be called twice: type %s", typ))
+ panic(fmt.Sprintf("push(%s, %d) called after previous push", typ, size))
}
if pk.pushed+size > pk.reserved {
- panic("not enough headroom reserved")
+ panic(fmt.Sprintf("push(%s, %d) overflows; pushed=%d reserved=%d", typ, size, pk.pushed, pk.reserved))
}
pk.pushed += size
h.offset = -pk.pushed
diff --git a/pkg/tcpip/stack/rand.go b/pkg/tcpip/stack/rand.go
index 421fb5c15..c8294eb6e 100644
--- a/pkg/tcpip/stack/rand.go
+++ b/pkg/tcpip/stack/rand.go
@@ -15,7 +15,7 @@
package stack
import (
- mathrand "math/rand"
+ "math/rand"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -23,7 +23,7 @@ import (
// lockedRandomSource provides a threadsafe rand.Source.
type lockedRandomSource struct {
mu sync.Mutex
- src mathrand.Source
+ src rand.Source
}
func (r *lockedRandomSource) Int63() (n int64) {
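
rand.go keeps the mutex-wrapped source because the sources returned by rand.NewSource are not safe for concurrent use. A compact, self-contained version of the same wrapper (using the standard sync package rather than gVisor's):

package main

import (
	"math/rand"
	"sync"
)

// lockedSource serializes access to an underlying rand.Source, which is not
// safe for concurrent use on its own.
type lockedSource struct {
	mu  sync.Mutex
	src rand.Source
}

func (s *lockedSource) Int63() int64 {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.src.Int63()
}

func (s *lockedSource) Seed(seed int64) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.src.Seed(seed)
}

func main() {
	rng := rand.New(&lockedSource{src: rand.NewSource(1)})
	_ = rng.Intn(10) // Safe to call from multiple goroutines.
}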
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 85bb87b4b..dfe2c886f 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -265,6 +265,11 @@ type TransportDispatcher interface {
//
// DeliverTransportError takes ownership of the packet buffer.
DeliverTransportError(local, remote tcpip.Address, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, _ TransportError, _ *PacketBuffer)
+
+ // DeliverRawPacket delivers a packet to any subscribed raw sockets.
+ //
+ // DeliverRawPacket does NOT take ownership of the packet buffer.
+ DeliverRawPacket(tcpip.TransportProtocolNumber, *PacketBuffer)
}
// PacketLooping specifies where an outbound packet should be sent.
@@ -420,7 +425,7 @@ const (
PermanentExpired
// Temporary is an endpoint, created on a one-off basis to temporarily
- // consider the NIC bound an an address that it is not explictiy bound to
+ // consider the NIC bound to an address that it is not explicitly bound to
// (such as a permanent address). Its reference count must not be biased by 1
// so that the address is removed immediately when references to it are no
// longer held.
@@ -630,7 +635,7 @@ type NetworkEndpoint interface {
// HandlePacket takes ownership of pkt.
HandlePacket(pkt *PacketBuffer)
- // Close is called when the endpoint is reomved from a stack.
+ // Close is called when the endpoint is removed from a stack.
Close()
// NetworkProtocolNumber returns the tcpip.NetworkProtocolNumber for
@@ -968,7 +973,7 @@ type DuplicateAddressDetector interface {
// called with the result of the original DAD request.
CheckDuplicateAddress(tcpip.Address, DADCompletionHandler) DADCheckAddressDisposition
- // SetDADConfiguations sets the configurations for DAD.
+ // SetDADConfigurations sets the configurations for DAD.
SetDADConfigurations(c DADConfigurations)
// DuplicateAddressProtocol returns the network protocol the receiver can
@@ -979,7 +984,7 @@ type DuplicateAddressDetector interface {
// LinkAddressResolver handles link address resolution for a network protocol.
type LinkAddressResolver interface {
// LinkAddressRequest sends a request for the link address of the target
- // address. The request is broadcasted on the local network if a remote link
+ // address. The request is broadcast on the local network if a remote link
// address is not provided.
LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error
@@ -1072,4 +1077,4 @@ type GSOEndpoint interface {
// SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment.
// This isn't a hard limit, because it is never set into packet headers.
-const SoftwareGSOMaxSize = (1 << 16)
+const SoftwareGSOMaxSize = 1 << 16
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 8814f45a6..81fabe29a 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -20,17 +20,16 @@
package stack
import (
- "bytes"
"encoding/binary"
"fmt"
"io"
- mathrand "math/rand"
+ "math/rand"
"sync/atomic"
"time"
"golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/atomicbitops"
- "gvisor.dev/gvisor/pkg/rand"
+ cryptorand "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -40,13 +39,6 @@ import (
)
const (
- // ageLimit is set to the same cache stale time used in Linux.
- ageLimit = 1 * time.Minute
- // resolutionTimeout is set to the same ARP timeout used in Linux.
- resolutionTimeout = 1 * time.Second
- // resolutionAttempts is set to the same ARP retries used in Linux.
- resolutionAttempts = 3
-
// DefaultTOS is the default type of service value for network endpoints.
DefaultTOS = 0
)
@@ -116,7 +108,7 @@ type Stack struct {
handleLocal bool
// tables are the iptables packet filtering and manipulation rules.
- // TODO(gvisor.dev/issue/170): S/R this field.
+ // TODO(gvisor.dev/issue/4595): S/R this field.
tables *IPTables
// resumableEndpoints is a list of endpoints that need to be resumed if the
@@ -145,7 +137,7 @@ type Stack struct {
// randomGenerator is an injectable pseudo random generator that can be
// used when a random number is required.
- randomGenerator *mathrand.Rand
+ randomGenerator *rand.Rand
// secureRNG is a cryptographically secure random number generator.
secureRNG io.Reader
@@ -196,9 +188,9 @@ type Options struct {
// TransportProtocols lists the transport protocols to enable.
TransportProtocols []TransportProtocolFactory
- // Clock is an optional clock source used for timestampping packets.
+ // Clock is an optional clock used for timekeeping.
//
- // If no Clock is specified, the clock source will be time.Now.
+ // If Clock is nil, tcpip.NewStdClock() will be used.
Clock tcpip.Clock
// Stats are optional statistic counters.
@@ -225,15 +217,21 @@ type Options struct {
// RandSource is an optional source to use to generate random
// numbers. If omitted it defaults to a Source seeded by the data
- // returned by rand.Read().
+ // returned by the stack secure RNG.
//
// RandSource must be thread-safe.
- RandSource mathrand.Source
+ RandSource rand.Source
- // IPTables are the initial iptables rules. If nil, iptables will allow
+ // IPTables are the initial iptables rules. If nil, DefaultIPTables will be
+ // used to construct the initial iptables rules.
// all traffic.
IPTables *IPTables
+ // DefaultIPTables is an optional iptables rules constructor that is called
+ // if IPTables is nil. If both fields are nil, iptables will allow all
+ // traffic.
+ DefaultIPTables func(uint32) *IPTables
+
// SecureRNG is a cryptographically secure random number generator.
SecureRNG io.Reader
}
@@ -331,23 +329,32 @@ func New(opts Options) *Stack {
opts.UniqueID = new(uniqueIDGenerator)
}
+ if opts.SecureRNG == nil {
+ opts.SecureRNG = cryptorand.Reader
+ }
+
randSrc := opts.RandSource
if randSrc == nil {
- // Source provided by mathrand.NewSource is not thread-safe so
+ var v int64
+ if err := binary.Read(opts.SecureRNG, binary.LittleEndian, &v); err != nil {
+ panic(err)
+ }
+ // Source provided by rand.NewSource is not thread-safe so
// we wrap it in a simple thread-safe version.
- randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())}
+ randSrc = &lockedRandomSource{src: rand.NewSource(v)}
}
+ randomGenerator := rand.New(randSrc)
+ seed := randomGenerator.Uint32()
if opts.IPTables == nil {
- opts.IPTables = DefaultTables()
+ if opts.DefaultIPTables == nil {
+ opts.DefaultIPTables = DefaultTables
+ }
+ opts.IPTables = opts.DefaultIPTables(seed)
}
opts.NUDConfigs.resetInvalidFields()
- if opts.SecureRNG == nil {
- opts.SecureRNG = rand.Reader
- }
-
s := &Stack{
transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
@@ -360,11 +367,11 @@ func New(opts Options) *Stack {
handleLocal: opts.HandleLocal,
tables: opts.IPTables,
icmpRateLimiter: NewICMPRateLimiter(),
- seed: generateRandUint32(),
+ seed: seed,
nudConfigs: opts.NUDConfigs,
uniqueIDGenerator: opts.UniqueID,
nudDisp: opts.NUDDisp,
- randomGenerator: mathrand.New(randSrc),
+ randomGenerator: randomGenerator,
secureRNG: opts.SecureRNG,
sendBufferSize: tcpip.SendBufferSizeOption{
Min: MinBufferSize,
@@ -804,7 +811,7 @@ type NICInfo struct {
// MTU is the maximum transmission unit.
MTU uint32
- Stats NICStats
+ Stats tcpip.NICStats
// NetworkStats holds the stats of each NetworkEndpoint bound to the NIC.
NetworkStats map[tcpip.NetworkProtocolNumber]NetworkEndpointStats
@@ -856,7 +863,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
ProtocolAddresses: nic.primaryAddresses(),
Flags: flags,
MTU: nic.LinkEndpoint.MTU(),
- Stats: nic.stats,
+ Stats: nic.stats.local,
NetworkStats: netStats,
Context: nic.context,
ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(),
@@ -1819,7 +1826,7 @@ func (s *Stack) Seed() uint32 {
// Rand returns a reference to a pseudo random generator that can be used
// to generate random numbers as required.
-func (s *Stack) Rand() *mathrand.Rand {
+func (s *Stack) Rand() *rand.Rand {
return s.randomGenerator
}
@@ -1829,27 +1836,6 @@ func (s *Stack) SecureRNG() io.Reader {
return s.secureRNG
}
-func generateRandUint32() uint32 {
- b := make([]byte, 4)
- if _, err := rand.Read(b); err != nil {
- panic(err)
- }
- return binary.LittleEndian.Uint32(b)
-}
-
-func generateRandInt64() int64 {
- b := make([]byte, 8)
- if _, err := rand.Read(b); err != nil {
- panic(err)
- }
- buf := bytes.NewReader(b)
- var v int64
- if err := binary.Read(buf, binary.LittleEndian, &v); err != nil {
- panic(err)
- }
- return v
-}
-
// FindNICNameFromID returns the name of the NIC for the given NICID.
func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
s.mu.RLock()
@@ -1886,9 +1872,8 @@ const (
// ParsePacketBufferTransport parses the provided packet buffer's transport
// header.
func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult {
- // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
- // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
- // full explanation.
+ // ICMP packets don't have their TransportHeader fields set yet, parse it
+ // here. See icmp/protocol.go:protocol.Parse for a full explanation.
if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
return ParsedOK
}
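
The New() change above derives the math/rand seed by reading an int64 straight from the secure RNG with encoding/binary, replacing the deleted generateRandInt64/generateRandUint32 helpers. A self-contained sketch of that seeding step (crypto/rand stands in here for the injected SecureRNG reader):

package main

import (
	cryptorand "crypto/rand"
	"encoding/binary"
	"fmt"
	"math/rand"
)

func main() {
	// Read 8 random bytes directly into an int64 seed.
	var seed int64
	if err := binary.Read(cryptorand.Reader, binary.LittleEndian, &seed); err != nil {
		panic(err)
	}

	// Seed a fast, non-cryptographic generator for per-stack use
	// (e.g. deriving the port-picking seed).
	rng := rand.New(rand.NewSource(seed))
	fmt.Println(rng.Uint32())
}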
diff --git a/pkg/tcpip/stack/stack_global_state.go b/pkg/tcpip/stack/stack_global_state.go
index 33824afd0..dfec4258a 100644
--- a/pkg/tcpip/stack/stack_global_state.go
+++ b/pkg/tcpip/stack/stack_global_state.go
@@ -14,78 +14,6 @@
package stack
-import "time"
-
// StackFromEnv is the global stack created in restore run.
// FIXME(b/36201077)
var StackFromEnv *Stack
-
-// saveT is invoked by stateify.
-func (t *TCPCubicState) saveT() unixTime {
- return unixTime{t.T.Unix(), t.T.UnixNano()}
-}
-
-// loadT is invoked by stateify.
-func (t *TCPCubicState) loadT(unix unixTime) {
- t.T = time.Unix(unix.second, unix.nano)
-}
-
-// saveXmitTime is invoked by stateify.
-func (t *TCPRACKState) saveXmitTime() unixTime {
- return unixTime{t.XmitTime.Unix(), t.XmitTime.UnixNano()}
-}
-
-// loadXmitTime is invoked by stateify.
-func (t *TCPRACKState) loadXmitTime(unix unixTime) {
- t.XmitTime = time.Unix(unix.second, unix.nano)
-}
-
-// saveLastSendTime is invoked by stateify.
-func (t *TCPSenderState) saveLastSendTime() unixTime {
- return unixTime{t.LastSendTime.Unix(), t.LastSendTime.UnixNano()}
-}
-
-// loadLastSendTime is invoked by stateify.
-func (t *TCPSenderState) loadLastSendTime(unix unixTime) {
- t.LastSendTime = time.Unix(unix.second, unix.nano)
-}
-
-// saveRTTMeasureTime is invoked by stateify.
-func (t *TCPSenderState) saveRTTMeasureTime() unixTime {
- return unixTime{t.RTTMeasureTime.Unix(), t.RTTMeasureTime.UnixNano()}
-}
-
-// loadRTTMeasureTime is invoked by stateify.
-func (t *TCPSenderState) loadRTTMeasureTime(unix unixTime) {
- t.RTTMeasureTime = time.Unix(unix.second, unix.nano)
-}
-
-// saveMeasureTime is invoked by stateify.
-func (r *RcvBufAutoTuneParams) saveMeasureTime() unixTime {
- return unixTime{r.MeasureTime.Unix(), r.MeasureTime.UnixNano()}
-}
-
-// loadMeasureTime is invoked by stateify.
-func (r *RcvBufAutoTuneParams) loadMeasureTime(unix unixTime) {
- r.MeasureTime = time.Unix(unix.second, unix.nano)
-}
-
-// saveRTTMeasureTime is invoked by stateify.
-func (r *RcvBufAutoTuneParams) saveRTTMeasureTime() unixTime {
- return unixTime{r.RTTMeasureTime.Unix(), r.RTTMeasureTime.UnixNano()}
-}
-
-// loadRTTMeasureTime is invoked by stateify.
-func (r *RcvBufAutoTuneParams) loadRTTMeasureTime(unix unixTime) {
- r.RTTMeasureTime = time.Unix(unix.second, unix.nano)
-}
-
-// saveSegTime is invoked by stateify.
-func (t *TCPEndpointState) saveSegTime() unixTime {
- return unixTime{t.SegTime.Unix(), t.SegTime.UnixNano()}
-}
-
-// loadSegTime is invoked by stateify.
-func (t *TCPEndpointState) loadSegTime(unix unixTime) {
- t.SegTime = time.Unix(unix.second, unix.nano)
-}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 02d54d29b..21951d05a 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -166,10 +166,6 @@ func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
return f.nic.MaxHeaderLength() + fakeNetHeaderLen
}
-func (*fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
- return 0
-}
-
func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return f.proto.Number()
}
@@ -197,11 +193,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, params stack.NetworkHe
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
+func (*fakeNetworkEndpoint) WritePackets(*stack.Route, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) {
panic("not implemented")
}
-func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error {
+func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) tcpip.Error {
return &tcpip.ErrNotSupported{}
}
@@ -463,14 +459,14 @@ func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer
}
}
-func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) {
+func testFailingSend(t *testing.T, r *stack.Route, payload buffer.View, wantErr tcpip.Error) {
t.Helper()
if gotErr := send(r, payload); gotErr != wantErr {
t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
}
}
-func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) {
+func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View, wantErr tcpip.Error) {
t.Helper()
if gotErr := sendTo(s, addr, payload); gotErr != wantErr {
t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr)
@@ -920,15 +916,15 @@ func TestRouteWithDownNIC(t *testing.T) {
if err := test.downFn(s, nicID1); err != nil {
t.Fatalf("test.downFn(_, %d): %s", nicID1, err)
}
- testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{})
+ testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{})
testSend(t, r2, ep2, buf)
// Writes with Routes that use NIC2 after being brought down should fail.
if err := test.downFn(s, nicID2); err != nil {
t.Fatalf("test.downFn(_, %d): %s", nicID2, err)
}
- testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{})
- testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{})
+ testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{})
+ testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{})
if upFn := test.upFn; upFn != nil {
// Writes with Routes that use NIC1 after being brought up should
@@ -941,7 +937,7 @@ func TestRouteWithDownNIC(t *testing.T) {
t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
}
testSend(t, r1, ep1, buf)
- testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{})
+ testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{})
}
})
}
@@ -1066,7 +1062,7 @@ func TestAddressRemoval(t *testing.T) {
t.Fatal("RemoveAddress failed:", err)
}
testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
- testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
+ testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{})
// Check that removing the same address fails.
err := s.RemoveAddress(1, localAddr)
@@ -1118,8 +1114,8 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
t.Fatal("RemoveAddress failed:", err)
}
testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
- testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{})
- testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
+ testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{})
+ testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{})
// Check that removing the same address fails.
{
@@ -1140,7 +1136,7 @@ func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.A
// No address given, verify that there is no address assigned to the NIC.
for _, a := range info.ProtocolAddresses {
if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) {
- t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{}))
+ t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, tcpip.AddressWithPrefix{})
}
}
return
@@ -1220,7 +1216,7 @@ func TestEndpointExpiration(t *testing.T) {
// FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
// testSendTo(t, s, remoteAddr, ep, nil)
} else {
- testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
+ testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{})
}
// 2. Add Address, everything should work.
@@ -1248,7 +1244,7 @@ func TestEndpointExpiration(t *testing.T) {
// FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
// testSendTo(t, s, remoteAddr, ep, nil)
} else {
- testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
+ testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{})
}
// 4. Add Address back, everything should work again.
@@ -1287,8 +1283,8 @@ func TestEndpointExpiration(t *testing.T) {
testSend(t, r, ep, nil)
testSendTo(t, s, remoteAddr, ep, nil)
} else {
- testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{})
- testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
+ testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{})
+ testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{})
}
// 7. Add Address back, everything should work again.
@@ -1324,7 +1320,7 @@ func TestEndpointExpiration(t *testing.T) {
// FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
// testSendTo(t, s, remoteAddr, ep, nil)
} else {
- testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
+ testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{})
}
})
}
@@ -1574,7 +1570,7 @@ func TestSpoofingNoAddress(t *testing.T) {
t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
}
// Sending a packet fails.
- testFailingSendTo(t, s, dstAddr, ep, nil, &tcpip.ErrNoRoute{})
+ testFailingSendTo(t, s, dstAddr, nil, &tcpip.ErrNoRoute{})
// With address spoofing enabled, FindRoute permits any address to be used
// as the source.
@@ -1615,7 +1611,7 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
}
}
- protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}}
+ protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}}
if err := s.AddProtocolAddress(1, protoAddr); err != nil {
t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err)
}
@@ -1641,12 +1637,12 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
}
func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
- defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0}
+ defaultAddr := tcpip.AddressWithPrefix{Address: header.IPv4Any}
// Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1.
- nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24}
+ nic1Addr := tcpip.AddressWithPrefix{Address: "\xc0\xa8\x01\x3a", PrefixLen: 24}
nic1Gateway := testutil.MustParse4("192.168.1.1")
// Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1.
- nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24}
+ nic2Addr := tcpip.AddressWithPrefix{Address: "\x0a\x0a\x0a\x05", PrefixLen: 24}
nic2Gateway := testutil.MustParse4("10.10.10.1")
// Create a new stack with two NICs.
@@ -1660,12 +1656,12 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err := s.CreateNIC(2, ep); err != nil {
t.Fatalf("CreateNIC failed: %s", err)
}
- nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr}
+ nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr}
if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil {
t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err)
}
- nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr}
+ nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr}
if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil {
t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err)
}
@@ -1709,7 +1705,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
// 2. Case: Having an explicit route for broadcast will select that one.
rt = append(
[]tcpip.Route{
- {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1},
+ {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1},
},
rt...,
)
@@ -2049,7 +2045,7 @@ func TestAddAddress(t *testing.T) {
}
expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen},
+ AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
})
}
@@ -2113,7 +2109,7 @@ func TestAddAddressWithOptions(t *testing.T) {
}
expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen},
+ AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
})
}
}
@@ -2234,7 +2230,7 @@ func TestCreateNICWithOptions(t *testing.T) {
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
s := stack.New(stack.Options{})
- ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"))
+ ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00")
for _, call := range test.calls {
if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want {
t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want)
@@ -2248,46 +2244,87 @@ func TestNICStats(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
- ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatal("CreateNIC failed: ", err)
- }
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+
+ nics := []struct {
+ addr tcpip.Address
+ txByteCount int
+ rxByteCount int
+ }{
+ {
+ addr: "\x01",
+ txByteCount: 30,
+ rxByteCount: 10,
+ },
+ {
+ addr: "\x02",
+ txByteCount: 50,
+ rxByteCount: 20,
+ },
}
- // Route all packets for address \x01 to NIC 1.
- {
- subnet, err := tcpip.NewSubnet("\x01", "\xff")
- if err != nil {
- t.Fatal(err)
+
+ var txBytesTotal, rxBytesTotal, txPacketsTotal, rxPacketsTotal int
+ for i, nic := range nics {
+ nicid := tcpip.NICID(i)
+ ep := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
+ t.Fatal("CreateNIC failed: ", err)
+ }
+ if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil {
+ t.Fatal("AddAddress failed:", err)
}
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
- // Send a packet to address 1.
- buf := buffer.NewView(30)
- ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want {
- t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want)
- }
+ {
+ subnet, err := tcpip.NewSubnet(nic.addr, "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicid}})
+ }
+
+ nicStats := s.NICInfo()[nicid].Stats
- if got, want := s.NICInfo()[1].Stats.Rx.Bytes.Value(), uint64(len(buf)); got != want {
- t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want)
+ // Inbound packet.
+ rxBuffer := buffer.NewView(nic.rxByteCount)
+ ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: rxBuffer.ToVectorisedView(),
+ }))
+ if got, want := nicStats.Rx.Packets.Value(), uint64(1); got != want {
+ t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want)
+ }
+ if got, want := nicStats.Rx.Bytes.Value(), uint64(nic.rxByteCount); got != want {
+ t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want)
+ }
+ rxPacketsTotal++
+ rxBytesTotal += nic.rxByteCount
+
+ // Outbound packet.
+ txBuffer := buffer.NewView(nic.txByteCount)
+ actualTxLength := nic.txByteCount + fakeNetHeaderLen
+ if err := sendTo(s, nic.addr, txBuffer); err != nil {
+ t.Fatal("sendTo failed: ", err)
+ }
+ want := ep.Drain()
+ if got := nicStats.Tx.Packets.Value(); got != uint64(want) {
+ t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want)
+ }
+ if got, want := nicStats.Tx.Bytes.Value(), uint64(actualTxLength); got != want {
+ t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
+ }
+ txPacketsTotal += want
+ txBytesTotal += actualTxLength
}
- payload := buffer.NewView(10)
- // Write a packet out via the address for NIC 1
- if err := sendTo(s, "\x01", payload); err != nil {
- t.Fatal("sendTo failed: ", err)
+ // Now verify that each NIC's stats were correctly aggregated at the stack level.
+ if got, want := s.Stats().NICs.Rx.Packets.Value(), uint64(rxPacketsTotal); got != want {
+ t.Errorf("got s.Stats().NIC.Rx.Packets.Value() = %d, want = %d", got, want)
}
- want := uint64(ep1.Drain())
- if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want)
+ if got, want := s.Stats().NICs.Rx.Bytes.Value(), uint64(rxBytesTotal); got != want {
+ t.Errorf("got s.Stats().Rx.Bytes.Value() = %d, want = %d", got, want)
}
-
- if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want {
+ if got, want := s.Stats().NICs.Tx.Packets.Value(), uint64(txPacketsTotal); got != want {
+ t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want)
+ }
+ if got, want := s.Stats().NICs.Tx.Bytes.Value(), uint64(txBytesTotal); got != want {
t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
}
}
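
Note: the reworked TestNICStats now checks two invariants: each NIC's Rx/Tx counters match the traffic injected on that NIC, and the stack-level Stats().NICs aggregate equals the sum of the per-NIC totals. A minimal, self-contained sketch of the aggregation invariant (plain integers standing in for StatCounter; all names here are illustrative only):

package main

import "fmt"

// packetStats is a stand-in for a per-NIC packet/byte counter pair.
type packetStats struct {
	packets uint64
	bytes   uint64
}

func main() {
	// Hypothetical per-NIC receive counts, mirroring the test's table.
	perNIC := []packetStats{
		{packets: 1, bytes: 10},
		{packets: 1, bytes: 20},
	}

	// The stack-level aggregate must equal the sum of the per-NIC counters.
	var aggregate packetStats
	for _, s := range perNIC {
		aggregate.packets += s.packets
		aggregate.bytes += s.bytes
	}
	fmt.Printf("aggregate: %d packets, %d bytes\n", aggregate.packets, aggregate.bytes)
}
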
@@ -2316,7 +2353,7 @@ func TestNICContextPreservation(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{})
id := tcpip.NICID(1)
- ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"))
+ ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00")
if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil {
t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err)
}
@@ -2603,15 +2640,17 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
const nicID = 1
ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent),
+ dadC: make(chan ndpDADEvent, 1),
}
dadConfigs := stack.DefaultDADConfigurations()
+ clock := faketime.NewManualClock()
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
AutoGenLinkLocal: true,
NDPDisp: &ndpDisp,
DADConfigs: dadConfigs,
})},
+ Clock: clock,
}
e := channel.New(int(dadConfigs.DupAddrDetectTransmits), 1280, linkAddr1)
@@ -2629,17 +2668,18 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
linkLocalAddr := header.LinkLocalAddr(linkAddr1)
// Wait for DAD to resolve.
+ clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer)
select {
- case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second):
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
// With the manual clock advanced by the full DAD resolution time
// (DupAddrDetectTransmits * RetransmitTimer), the resolution event
// must already be queued on the buffered channel; its absence here
// means something is wrong.
t.Fatal("timed out waiting for DAD resolution")
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
- }
}
if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); err != nil {
t.Fatal(err)
@@ -3270,8 +3310,9 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
const nicID = 1
ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent),
+ dadC: make(chan ndpDADEvent, 1),
}
+ clock := faketime.NewManualClock()
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
DADConfigs: stack.DADConfigurations{
@@ -3280,6 +3321,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
},
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
}
e := channel.New(dadTransmits, 1280, linkAddr1)
@@ -3324,13 +3366,14 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
}
// Wait for DAD to resolve.
+ clock.Advance(dadTransmits * retransmitTimer)
select {
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, &stack.DADSucceeded{}); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
+ default:
+ t.Fatal("timed out waiting for DAD resolution")
}
if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
@@ -3837,8 +3880,6 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
// TestAddRoute tests Stack.AddRoute
func TestAddRoute(t *testing.T) {
- const nicID = 1
-
s := stack.New(stack.Options{})
subnet1, err := tcpip.NewSubnet("\x00", "\x00")
@@ -3875,8 +3916,6 @@ func TestAddRoute(t *testing.T) {
// TestRemoveRoutes tests Stack.RemoveRoutes
func TestRemoveRoutes(t *testing.T) {
- const nicID = 1
-
s := stack.New(stack.Options{})
addressToRemove := tcpip.Address("\x01")
@@ -4223,7 +4262,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
- if r != nil {
+ if err == nil {
defer r.Release()
}
if diff := cmp.Diff(test.findRouteErr, err); diff != "" {
@@ -4394,7 +4433,7 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) {
if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil {
t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err)
} else if diff := cmp.Diff(
- []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}},
+ []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAt: clock.Now()}},
neighbors,
); diff != "" {
t.Fatalf("proto=%d neighbors mismatch (-want +got):\n%s", addr.proto, diff)
diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go
index ddff6e2d6..90a8ba6cf 100644
--- a/pkg/tcpip/stack/tcp.go
+++ b/pkg/tcpip/stack/tcp.go
@@ -39,7 +39,7 @@ type TCPCubicState struct {
WMax float64
// T is the time when the current congestion avoidance was entered.
- T time.Time `state:".(unixTime)"`
+ T tcpip.MonotonicTime
// TimeSinceLastCongestion denotes the time since the current
// congestion avoidance was entered.
@@ -78,7 +78,7 @@ type TCPCubicState struct {
type TCPRACKState struct {
// XmitTime is the transmission timestamp of the most recent
// acknowledged segment.
- XmitTime time.Time `state:".(unixTime)"`
+ XmitTime tcpip.MonotonicTime
// EndSequence is the ending TCP sequence number of the most recent
// acknowledged segment.
@@ -216,7 +216,7 @@ type TCPRTTState struct {
// +stateify savable
type TCPSenderState struct {
// LastSendTime is the timestamp at which we sent the last segment.
- LastSendTime time.Time `state:".(unixTime)"`
+ LastSendTime tcpip.MonotonicTime
// DupAckCount is the number of Duplicate ACKs received. It is used for
// fast retransmit.
@@ -256,7 +256,7 @@ type TCPSenderState struct {
RTTMeasureSeqNum seqnum.Value
// RTTMeasureTime is the time when the RTTMeasureSeqNum was sent.
- RTTMeasureTime time.Time `state:".(unixTime)"`
+ RTTMeasureTime tcpip.MonotonicTime
// Closed indicates that the caller has closed the endpoint for
// sending.
@@ -313,7 +313,7 @@ type TCPSACKInfo struct {
type RcvBufAutoTuneParams struct {
// MeasureTime is the time at which the current measurement was
// started.
- MeasureTime time.Time `state:".(unixTime)"`
+ MeasureTime tcpip.MonotonicTime
// CopiedBytes is the number of bytes copied to user space since this
// measure began.
@@ -341,7 +341,7 @@ type RcvBufAutoTuneParams struct {
// RTTMeasureTime is the absolute time at which the current RTT
// measurement period began.
- RTTMeasureTime time.Time `state:".(unixTime)"`
+ RTTMeasureTime tcpip.MonotonicTime
// Disabled is true if an explicit receive buffer is set for the
// endpoint.
@@ -380,9 +380,6 @@ type TCPSndBufState struct {
// SndClosed indicates that the endpoint has been closed for sends.
SndClosed bool
- // SndBufInQueue is the number of bytes in the send queue.
- SndBufInQueue seqnum.Size
-
// PacketTooBigCount is used to notify the main protocol routine how
// many times a "packet too big" control packet is received.
PacketTooBigCount int
@@ -429,7 +426,7 @@ type TCPEndpointState struct {
ID TCPEndpointID
// SegTime denotes the absolute time when this segment was received.
- SegTime time.Time `state:".(unixTime)"`
+ SegTime tcpip.MonotonicTime
// RcvBufState contains information about the state of the endpoint's
// receive socket buffer.
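
Note: the TCP state fields above move from wall-clock time.Time values (which needed the state:".(unixTime)" save hook) to tcpip.MonotonicTime readings; timestamps feeding RTT and congestion-avoidance math should not jump when the host clock is stepped. A small standard-library-only illustration of the difference (not gvisor code):

package main

import (
	"fmt"
	"time"
)

func main() {
	// time.Now() carries a monotonic reading; Sub between two such values
	// measures elapsed time even if the wall clock is adjusted in between.
	sent := time.Now()
	time.Sleep(10 * time.Millisecond)
	acked := time.Now()

	rtt := acked.Sub(sent) // uses the monotonic component
	fmt.Printf("rtt sample: %v\n", rtt)

	// Round-tripping through Unix nanoseconds drops the monotonic reading,
	// which is the failure mode a dedicated monotonic type avoids.
	sentWall := time.Unix(0, sent.UnixNano())
	fmt.Printf("wall-clock-only sample: %v\n", acked.Sub(sentWall))
}
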
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 80ad1a9d4..dda57e225 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -16,7 +16,6 @@ package stack
import (
"fmt"
- "math/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -217,13 +216,20 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t
netProto: netProto,
transProto: transProto,
}
- epsByNIC.endpoints[bindToDevice] = multiPortEp
}
- return multiPortEp.singleRegisterEndpoint(t, flags)
+ if err := multiPortEp.singleRegisterEndpoint(t, flags); err != nil {
+ return err
+ }
+ // Only add the newly created multiPortEndpoint to the map if
+ // singleRegisterEndpoint succeeded.
+ if !ok {
+ epsByNIC.endpoints[bindToDevice] = multiPortEp
+ }
+ return nil
}
-func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
+func (epsByNIC *endpointsByNIC) checkEndpoint(flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
epsByNIC.mu.RLock()
defer epsByNIC.mu.RUnlock()
@@ -407,7 +413,6 @@ func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *Packet
func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
-
bits := flags.Bits() & ports.MultiBindFlagMask
if len(ep.endpoints) != 0 {
@@ -470,17 +475,21 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
eps.mu.Lock()
defer eps.mu.Unlock()
-
epsByNIC, ok := eps.endpoints[id]
if !ok {
epsByNIC = &endpointsByNIC{
endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
- seed: rand.Uint32(),
+ seed: d.stack.Seed(),
}
+ }
+ if err := epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice); err != nil {
+ return err
+ }
+ // Only add this newly created epsByNIC if registerEndpoint succeeded.
+ if !ok {
eps.endpoints[id] = epsByNIC
}
-
- return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice)
+ return nil
}
func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
@@ -502,7 +511,7 @@ func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNum
return nil
}
- return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice)
+ return epsByNIC.checkEndpoint(flags, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
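
Note: registerEndpoint and singleRegisterEndpoint above now publish the freshly constructed map entry only after the nested registration succeeds, so a failed registration no longer leaves a half-initialized endpointsByNIC/multiPortEndpoint behind (the demuxer is also seeded from the stack instead of math/rand). A minimal sketch of the publish-only-on-success shape (hypothetical names, not the netstack types):

package main

import (
	"errors"
	"fmt"
)

var errPortInUse = errors.New("port in use")

type entry struct{ registered bool }

func (e *entry) register() error {
	if e.registered {
		return errPortInUse
	}
	e.registered = true
	return nil
}

type registry struct {
	endpoints map[string]*entry
}

// register creates the entry if needed, but only publishes it in the map
// after the inner registration succeeded.
func (r *registry) register(key string) error {
	ent, ok := r.endpoints[key]
	if !ok {
		ent = &entry{}
	}
	if err := ent.register(); err != nil {
		return err
	}
	if !ok {
		r.endpoints[key] = ent
	}
	return nil
}

func main() {
	r := &registry{endpoints: make(map[string]*entry)}
	fmt.Println(r.register("udp/1"), len(r.endpoints)) // <nil> 1
	fmt.Println(r.register("udp/1"), len(r.endpoints)) // port in use 1
}
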
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 4848495c9..45b09110d 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -18,6 +18,7 @@ import (
"io/ioutil"
"math"
"math/rand"
+ "strconv"
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -84,7 +85,8 @@ func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICI
}
type headers struct {
- srcPort, dstPort uint16
+ srcPort uint16
+ dstPort uint16
}
func newPayload() []byte {
@@ -201,6 +203,56 @@ func TestTransportDemuxerRegister(t *testing.T) {
}
}
+func TestTransportDemuxerRegisterMultiple(t *testing.T) {
+ type test struct {
+ flags ports.Flags
+ want tcpip.Error
+ }
+ for _, subtest := range []struct {
+ name string
+ tests []test
+ }{
+ {"zeroFlags", []test{
+ {ports.Flags{}, nil},
+ {ports.Flags{}, &tcpip.ErrPortInUse{}},
+ }},
+ {"multibindFlags", []test{
+ // Allow multiple registrations of the same TransportEndpointID when multibind flags are set.
+ {ports.Flags{LoadBalanced: true, MostRecent: true}, nil},
+ {ports.Flags{LoadBalanced: true, MostRecent: true}, nil},
+ // Disallow registration with the same ID when a non-multibind flag is used.
+ {ports.Flags{TupleOnly: true}, &tcpip.ErrPortInUse{}},
+ }},
+ } {
+ t.Run(subtest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+ var eps []tcpip.Endpoint
+ for idx, test := range subtest.tests {
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ eps = append(eps, ep)
+ tEP, ok := ep.(stack.TransportEndpoint)
+ if !ok {
+ t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
+ }
+ id := stack.TransportEndpointID{LocalPort: 1}
+ if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber}, udp.ProtocolNumber, id, tEP, test.flags, 0), test.want; got != want {
+ t.Fatalf("test index: %d, s.RegisterTransportEndpoint(ipv4.ProtocolNumber, udp.ProtocolNumber, _, _, %+v, 0) = %s, want %s", idx, test.flags, got, want)
+ }
+ }
+ for _, ep := range eps {
+ ep.Close()
+ }
+ })
+ }
+}
+
// TestBindToDeviceDistribution injects varied packets on input devices and checks that
// the distribution of packets received matches expectations.
func TestBindToDeviceDistribution(t *testing.T) {
@@ -208,7 +260,7 @@ func TestBindToDeviceDistribution(t *testing.T) {
reuse bool
bindToDevice tcpip.NICID
}
- for _, test := range []struct {
+ tcs := []struct {
name string
// endpoints will receive the injected packets.
endpoints []endpointSockopts
@@ -217,29 +269,29 @@ func TestBindToDeviceDistribution(t *testing.T) {
wantDistributions map[tcpip.NICID][]float64
}{
{
- "BindPortReuse",
+ name: "BindPortReuse",
// 5 endpoints that all have reuse set.
- []endpointSockopts{
+ endpoints: []endpointSockopts{
{reuse: true, bindToDevice: 0},
{reuse: true, bindToDevice: 0},
{reuse: true, bindToDevice: 0},
{reuse: true, bindToDevice: 0},
{reuse: true, bindToDevice: 0},
},
- map[tcpip.NICID][]float64{
+ wantDistributions: map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed evenly.
1: {0.2, 0.2, 0.2, 0.2, 0.2},
},
},
{
- "BindToDevice",
+ name: "BindToDevice",
// 3 endpoints with various bindings.
- []endpointSockopts{
+ endpoints: []endpointSockopts{
{reuse: false, bindToDevice: 1},
{reuse: false, bindToDevice: 2},
{reuse: false, bindToDevice: 3},
},
- map[tcpip.NICID][]float64{
+ wantDistributions: map[tcpip.NICID][]float64{
// Injected packets on dev0 go only to the endpoint bound to dev0.
1: {1, 0, 0},
// Injected packets on dev1 go only to the endpoint bound to dev1.
@@ -249,9 +301,9 @@ func TestBindToDeviceDistribution(t *testing.T) {
},
},
{
- "ReuseAndBindToDevice",
+ name: "ReuseAndBindToDevice",
// 6 endpoints with various bindings.
- []endpointSockopts{
+ endpoints: []endpointSockopts{
{reuse: true, bindToDevice: 1},
{reuse: true, bindToDevice: 1},
{reuse: true, bindToDevice: 2},
@@ -259,7 +311,7 @@ func TestBindToDeviceDistribution(t *testing.T) {
{reuse: true, bindToDevice: 2},
{reuse: true, bindToDevice: 0},
},
- map[tcpip.NICID][]float64{
+ wantDistributions: map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed among endpoints bound to
// dev0.
1: {0.5, 0.5, 0, 0, 0, 0},
@@ -270,35 +322,42 @@ func TestBindToDeviceDistribution(t *testing.T) {
1000: {0, 0, 0, 0, 0, 1},
},
},
- } {
- for protoName, netProtoNum := range map[string]tcpip.NetworkProtocolNumber{
- "IPv4": ipv4.ProtocolNumber,
- "IPv6": ipv6.ProtocolNumber,
- } {
+ }
+ protos := map[string]tcpip.NetworkProtocolNumber{
+ "IPv4": ipv4.ProtocolNumber,
+ "IPv6": ipv6.ProtocolNumber,
+ }
+
+ for _, test := range tcs {
+ for protoName, protoNum := range protos {
for device, wantDistribution := range test.wantDistributions {
- t.Run(test.name+protoName+string(device), func(t *testing.T) {
+ t.Run(test.name+protoName+"-"+strconv.Itoa(int(device)), func(t *testing.T) {
+ // Create the NICs.
var devices []tcpip.NICID
for d := range test.wantDistributions {
devices = append(devices, d)
}
c := newDualTestContextMultiNIC(t, defaultMTU, devices)
+ // Create endpoints and bind each to a NIC, sometimes reusing ports.
eps := make(map[tcpip.Endpoint]int)
-
pollChannel := make(chan tcpip.Endpoint)
for i, endpoint := range test.endpoints {
// Try to receive the data.
wq := waiter.Queue{}
we, ch := waiter.NewChannelEntry(nil)
wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
- defer close(ch)
+ t.Cleanup(func() {
+ wq.EventUnregister(&we)
+ close(ch)
+ })
var err tcpip.Error
- ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq)
+ ep, err := c.s.NewEndpoint(udp.ProtocolNumber, protoNum, &wq)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
}
+ t.Cleanup(ep.Close)
eps[ep] = i
go func(ep tcpip.Endpoint) {
@@ -307,32 +366,34 @@ func TestBindToDeviceDistribution(t *testing.T) {
}
}(ep)
- defer ep.Close()
ep.SocketOptions().SetReusePort(endpoint.reuse)
if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil {
t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err)
}
var dstAddr tcpip.Address
- switch netProtoNum {
+ switch protoNum {
case ipv4.ProtocolNumber:
dstAddr = testDstAddrV4
case ipv6.ProtocolNumber:
dstAddr = testDstAddrV6
default:
- t.Fatalf("unexpected protocol number: %d", netProtoNum)
+ t.Fatalf("unexpected protocol number: %d", protoNum)
}
if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil {
t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err)
}
}
- npackets := 100000
- nports := 10000
+ // Send packets across a range of ports, checking that packets from
+ // the same source port are always demultiplexed to the same
+ // destination endpoint.
+ npackets := 10_000
+ nports := 1_000
if got, want := len(test.endpoints), len(wantDistribution); got != want {
t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
}
- ports := make(map[uint16]tcpip.Endpoint)
+ endpoints := make(map[uint16]tcpip.Endpoint)
stats := make(map[tcpip.Endpoint]int)
for i := 0; i < npackets; i++ {
// Send a packet.
@@ -342,13 +403,13 @@ func TestBindToDeviceDistribution(t *testing.T) {
srcPort: testSrcPort + port,
dstPort: testDstPort,
}
- switch netProtoNum {
+ switch protoNum {
case ipv4.ProtocolNumber:
c.sendV4Packet(payload, hdrs, device)
case ipv6.ProtocolNumber:
c.sendV6Packet(payload, hdrs, device)
default:
- t.Fatalf("unexpected protocol number: %d", netProtoNum)
+ t.Fatalf("unexpected protocol number: %d", protoNum)
}
ep := <-pollChannel
@@ -357,11 +418,11 @@ func TestBindToDeviceDistribution(t *testing.T) {
}
stats[ep]++
if i < nports {
- ports[uint16(i)] = ep
+ endpoints[uint16(i)] = ep
} else {
// Check that all packets from one client are handled by the same
// socket.
- if want, got := ports[port], ep; want != got {
+ if want, got := endpoints[port], ep; want != got {
t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got])
}
}
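
Note: the reworked TestBindToDeviceDistribution sends packets across a range of source ports and asserts that every packet from a given port keeps landing on the endpoint that first handled it, which is the stickiness a stable, per-stack hash seed provides. A compact, self-contained sketch of that check (toy FNV hash and plain maps; names are illustrative):

package main

import (
	"fmt"
	"hash/fnv"
)

// pick hashes a source port with a fixed seed so the same port always
// selects the same endpoint index.
func pick(seed uint32, port uint16, n int) int {
	h := fnv.New32a()
	h.Write([]byte{byte(seed >> 24), byte(seed >> 16), byte(seed >> 8), byte(seed), byte(port >> 8), byte(port)})
	return int(h.Sum32() % uint32(n))
}

func main() {
	const (
		seed      = 0x5eed
		endpoints = 5
		nports    = 1000
	)
	firstSeen := make(map[uint16]int)
	for i := 0; i < 10_000; i++ {
		port := uint16(i % nports)
		ep := pick(seed, port, endpoints)
		if prev, ok := firstSeen[port]; ok && prev != ep {
			fmt.Printf("port %d moved from endpoint %d to %d\n", port, prev, ep)
			return
		}
		firstSeen[port] = ep
	}
	fmt.Println("every source port stayed on one endpoint")
}
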
diff --git a/pkg/tcpip/stdclock.go b/pkg/tcpip/stdclock.go
index 7ce43a68e..371da2f40 100644
--- a/pkg/tcpip/stdclock.go
+++ b/pkg/tcpip/stdclock.go
@@ -60,11 +60,11 @@ type stdClock struct {
// monotonicOffset is assigned maxMonotonic after restore so that the
// monotonic time will continue from where it "left off" before saving as part
// of S/R.
- monotonicOffset int64 `state:"nosave"`
+ monotonicOffset MonotonicTime `state:"nosave"`
// monotonicMU protects maxMonotonic.
monotonicMU sync.Mutex `state:"nosave"`
- maxMonotonic int64
+ maxMonotonic MonotonicTime
}
// NewStdClock returns an instance of a clock that uses the time package.
@@ -76,25 +76,25 @@ func NewStdClock() Clock {
var _ Clock = (*stdClock)(nil)
-// NowNanoseconds implements Clock.NowNanoseconds.
-func (*stdClock) NowNanoseconds() int64 {
- return time.Now().UnixNano()
+// Now implements Clock.Now.
+func (*stdClock) Now() time.Time {
+ return time.Now()
}
// NowMonotonic implements Clock.NowMonotonic.
-func (s *stdClock) NowMonotonic() int64 {
+func (s *stdClock) NowMonotonic() MonotonicTime {
sinceBase := time.Since(s.baseTime)
if sinceBase < 0 {
panic(fmt.Sprintf("got negative duration = %s since base time = %s", sinceBase, s.baseTime))
}
- monotonicValue := sinceBase.Nanoseconds() + s.monotonicOffset
+ monotonicValue := s.monotonicOffset.Add(sinceBase)
s.monotonicMU.Lock()
defer s.monotonicMU.Unlock()
// Monotonic time values must never decrease.
- if monotonicValue > s.maxMonotonic {
+ if s.maxMonotonic.Before(monotonicValue) {
s.maxMonotonic = monotonicValue
}
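
Note: stdClock.NowMonotonic above now returns a MonotonicTime and keeps updating maxMonotonic under the mutex so the readings it hands out never decrease, which matters when monotonicOffset is re-derived after save/restore. A stripped-down sketch of a never-decreasing reading (plain int64 nanoseconds, not the real stdClock):

package main

import (
	"fmt"
	"sync"
	"time"
)

// monotonicSource hands out readings that never go backwards, even if the
// underlying source were to regress (e.g. after restoring saved state).
type monotonicSource struct {
	mu   sync.Mutex
	base time.Time
	max  int64 // highest reading returned so far, in nanoseconds
}

func (s *monotonicSource) now() int64 {
	value := time.Since(s.base).Nanoseconds()

	s.mu.Lock()
	defer s.mu.Unlock()
	if value > s.max {
		s.max = value
	}
	return s.max
}

func main() {
	src := &monotonicSource{base: time.Now()}
	a := src.now()
	b := src.now()
	fmt.Println(b >= a) // true: readings are non-decreasing
}
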
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 797778e08..8f2658f64 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -64,17 +64,47 @@ func (e *ErrSaveRejection) Error() string {
return "save rejected due to unsupported networking state: " + e.Err.Error()
}
+// MonotonicTime is a monotonic clock reading.
+//
+// +stateify savable
+type MonotonicTime struct {
+ nanoseconds int64
+}
+
+// Before reports whether the monotonic clock reading mt is before u.
+func (mt MonotonicTime) Before(u MonotonicTime) bool {
+ return mt.nanoseconds < u.nanoseconds
+}
+
+// After reports whether the monotonic clock reading mt is after u.
+func (mt MonotonicTime) After(u MonotonicTime) bool {
+ return mt.nanoseconds > u.nanoseconds
+}
+
+// Add returns the monotonic clock reading mt+d.
+func (mt MonotonicTime) Add(d time.Duration) MonotonicTime {
+ return MonotonicTime{
+ nanoseconds: time.Unix(0, mt.nanoseconds).Add(d).Sub(time.Unix(0, 0)).Nanoseconds(),
+ }
+}
+
+// Sub returns the duration mt-u. If the result exceeds the maximum (or minimum)
+// value that can be stored in a Duration, the maximum (or minimum) duration
+// will be returned. To compute mt-d for a duration d, use mt.Add(-d).
+func (mt MonotonicTime) Sub(u MonotonicTime) time.Duration {
+ return time.Unix(0, mt.nanoseconds).Sub(time.Unix(0, u.nanoseconds))
+}
+
// A Clock provides the current time and schedules work for execution.
//
// Times returned by a Clock should always be used for application-visible
// time. Only monotonic times should be used for netstack internal timekeeping.
type Clock interface {
- // NowNanoseconds returns the current real time as a number of
- // nanoseconds since the Unix epoch.
- NowNanoseconds() int64
+ // Now returns the current local time.
+ Now() time.Time
- // NowMonotonic returns a monotonic time value at nanosecond resolution.
- NowMonotonic() int64
+ // NowMonotonic returns the current monotonic clock reading.
+ NowMonotonic() MonotonicTime
// AfterFunc waits for the duration to elapse and then calls f in its own
// goroutine. It returns a Timer that can be used to cancel the call using
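
Note: MonotonicTime above is an opaque nanosecond reading with Before/After/Add/Sub, and Clock.NowMonotonic now returns it directly instead of a raw int64. A self-contained usage sketch with a local copy of the type's shape (simplified: no saturation on overflow, unlike the real Add/Sub):

package main

import (
	"fmt"
	"time"
)

// monotonicTime mirrors the shape of tcpip.MonotonicTime for illustration.
type monotonicTime struct{ nanoseconds int64 }

func (mt monotonicTime) Before(u monotonicTime) bool { return mt.nanoseconds < u.nanoseconds }

func (mt monotonicTime) Add(d time.Duration) monotonicTime {
	return monotonicTime{nanoseconds: mt.nanoseconds + d.Nanoseconds()}
}

func (mt monotonicTime) Sub(u monotonicTime) time.Duration {
	return time.Duration(mt.nanoseconds - u.nanoseconds)
}

func main() {
	start := monotonicTime{}
	deadline := start.Add(500 * time.Millisecond)
	now := start.Add(200 * time.Millisecond)

	fmt.Println(now.Before(deadline)) // true: deadline not reached yet
	fmt.Println(deadline.Sub(now))    // 300ms remaining
}
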
@@ -435,11 +465,11 @@ type ControlMessages struct {
// PacketOwner is used to get UID and GID of the packet.
type PacketOwner interface {
- // UID returns UID of the packet.
- UID() uint32
+ // KUID returns the KUID of the packet.
+ KUID() uint32
- // GID returns GID of the packet.
- GID() uint32
+ // KGID returns the KGID of the packet.
+ KGID() uint32
}
// ReadOptions contains options for Endpoint.Read.
@@ -861,6 +891,9 @@ type SettableSocketOption interface {
isSettableSocketOption()
}
+// EndpointState represents the state of an endpoint.
+type EndpointState uint8
+
// CongestionControlState indicates the current congestion control state for
// TCP sender.
type CongestionControlState int
@@ -897,6 +930,9 @@ type TCPInfoOption struct {
// RTO is the retransmission timeout for the endpoint.
RTO time.Duration
+ // State is the current endpoint protocol state.
+ State EndpointState
+
// CcState is the congestion control state.
CcState CongestionControlState
@@ -1552,6 +1588,10 @@ type IPForwardingStats struct {
// were too big for the outgoing MTU.
PacketTooBig *StatCounter
+ // HostUnreachable is the number of IP packets received which could not be
+ // successfully forwarded due to an unresolvable next hop.
+ HostUnreachable *StatCounter
+
// ExtensionHeaderProblem is the number of IP packets which were dropped
// because of a problem encountered when processing an IPv6 extension
// header.
@@ -1835,37 +1875,104 @@ type UDPStats struct {
ChecksumErrors *StatCounter
}
+// NICNeighborStats holds metrics for the neighbor table.
+type NICNeighborStats struct {
+ // LINT.IfChange(NICNeighborStats)
+
+ // UnreachableEntryLookups counts the number of lookups performed on an
+ // entry in Unreachable state.
+ UnreachableEntryLookups *StatCounter
+
+ // LINT.ThenChange(stack/nic_stats.go:multiCounterNICNeighborStats)
+}
+
+// NICPacketStats holds basic packet statistics.
+type NICPacketStats struct {
+ // LINT.IfChange(NICPacketStats)
+
+ // Packets is the number of packets counted.
+ Packets *StatCounter
+
+ // Bytes is the number of bytes counted.
+ Bytes *StatCounter
+
+ // LINT.ThenChange(stack/nic_stats.go:multiCounterNICPacketStats)
+}
+
+// NICStats holds NIC statistics.
+type NICStats struct {
+ // LINT.IfChange(NICStats)
+
+ // UnknownL3ProtocolRcvdPackets is the number of packets received that were
+ // for an unknown or unsupported network protocol.
+ UnknownL3ProtocolRcvdPackets *StatCounter
+
+ // UnknownL4ProtocolRcvdPackets is the number of packets received that were
+ // for an unknown or unsupported transport protocol.
+ UnknownL4ProtocolRcvdPackets *StatCounter
+
+ // MalformedL4RcvdPackets is the number of packets received by a NIC that
+ // could not be delivered to a transport endpoint because the L4 header could
+ // not be parsed.
+ MalformedL4RcvdPackets *StatCounter
+
+ // Tx contains statistics about transmitted packets.
+ Tx NICPacketStats
+
+ // Rx contains statistics about received packets.
+ Rx NICPacketStats
+
+ // DisabledRx contains statistics about received packets on disabled NICs.
+ DisabledRx NICPacketStats
+
+ // Neighbor contains statistics about neighbor entries.
+ Neighbor NICNeighborStats
+
+ // LINT.ThenChange(stack/nic_stats.go:multiCounterNICStats)
+}
+
+// FillIn returns a copy of s with nil fields initialized to new StatCounters.
+func (s NICStats) FillIn() NICStats {
+ InitStatCounters(reflect.ValueOf(&s).Elem())
+ return s
+}
+
// Stats holds statistics about the networking stack.
-//
-// All fields are optional.
type Stats struct {
- // UnknownProtocolRcvdPackets is the number of packets received by the
- // stack that were for an unknown or unsupported protocol.
- UnknownProtocolRcvdPackets *StatCounter
+ // TODO(https://gvisor.dev/issues/5986): Make the DroppedPackets stat less
+ // ambiguous.
- // MalformedRcvdPackets is the number of packets received by the stack
- // that were deemed malformed.
- MalformedRcvdPackets *StatCounter
-
- // DroppedPackets is the number of packets dropped due to full queues.
+ // DroppedPackets is the number of packets dropped at the transport layer.
DroppedPackets *StatCounter
- // ICMP breaks out ICMP-specific stats (both v4 and v6).
+ // NICs is an aggregation of every NIC's statistics. These should not be
+ // incremented using this field, but using the relevant NIC multicounters.
+ NICs NICStats
+
+ // ICMP is an aggregation of every NetworkEndpoint's ICMP statistics (both v4
+ // and v6). These should not be incremented using this field, but using the
+ // relevant NetworkEndpoint ICMP multicounters.
ICMP ICMPStats
- // IGMP breaks out IGMP-specific stats.
+ // IGMP is an aggregation of every NetworkEndpoint's IGMP statistics. These
+ // should not be incremented using this field, but using the relevant
+ // NetworkEndpoint IGMP multicounters.
IGMP IGMPStats
- // IP breaks out IP-specific stats (both v4 and v6).
+ // IP is an aggregation of every NetworkEndpoint's IP statistics. These should
+ // not be incremented using this field, but using the relevant NetworkEndpoint
+ // IP multicounters.
IP IPStats
- // ARP breaks out ARP-specific stats.
+ // ARP is an aggregation of every NetworkEndpoint's ARP statistics. These
+ // should not be incremented using this field, but using the relevant
+ // NetworkEndpoint ARP multicounters.
ARP ARPStats
- // TCP breaks out TCP-specific stats.
+ // TCP holds TCP-specific stats.
TCP TCPStats
- // UDP breaks out UDP-specific stats.
+ // UDP holds UDP-specific stats.
UDP UDPStats
}
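
Note: NICStats.FillIn above returns a copy with nil counters allocated by handing the struct to InitStatCounters via reflection, so a zero-valued NICStats becomes immediately usable. A self-contained sketch of that fill-in pattern on a hypothetical stats struct (top-level pointer fields only; the real helper also walks nested structs):

package main

import (
	"fmt"
	"reflect"
)

type statCounter struct{ value uint64 }

type nicStats struct {
	UnknownL3ProtocolRcvdPackets *statCounter
	MalformedL4RcvdPackets       *statCounter
}

// fillIn allocates a fresh counter for every nil pointer field.
func fillIn(v reflect.Value) {
	for i := 0; i < v.NumField(); i++ {
		f := v.Field(i)
		if f.Kind() == reflect.Ptr && f.IsNil() {
			f.Set(reflect.New(f.Type().Elem()))
		}
	}
}

func main() {
	var s nicStats
	fillIn(reflect.ValueOf(&s).Elem())
	fmt.Println(s.UnknownL3ProtocolRcvdPackets != nil, s.MalformedL4RcvdPackets != nil) // true true
}
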
diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go
index 269081ff8..c96ae2f02 100644
--- a/pkg/tcpip/tcpip_test.go
+++ b/pkg/tcpip/tcpip_test.go
@@ -19,7 +19,6 @@ import (
"fmt"
"io"
"net"
- "strings"
"testing"
"github.com/google/go-cmp/cmp"
@@ -210,26 +209,6 @@ func TestAddressString(t *testing.T) {
}
}
-func TestStatsString(t *testing.T) {
- got := fmt.Sprintf("%+v", Stats{}.FillIn())
-
- matchers := []string{
- // Print root-level stats correctly.
- "UnknownProtocolRcvdPackets:0",
- // Print protocol-specific stats correctly.
- "TCP:{ActiveConnectionOpenings:0",
- }
-
- for _, m := range matchers {
- if !strings.Contains(got, m) {
- t.Errorf("string.Contains(got, %q) = false", m)
- }
- }
- if t.Failed() {
- t.Logf(`got = fmt.Sprintf("%%+v", Stats{}.FillIn()) = %q`, got)
- }
-}
-
func TestAddressWithPrefixSubnet(t *testing.T) {
tests := []struct {
addr Address
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index ab2dab60c..181ef799e 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -53,12 +53,14 @@ go_test(
"//pkg/tcpip/checker",
"//pkg/tcpip/faketime",
"//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
"//pkg/tcpip/link/pipe",
"//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index 07ba2b837..f9ab7d0af 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -166,7 +166,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
// Make sure the packet is not dropped by the next rule.
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
}
},
genPacket: genPacketV6,
@@ -187,7 +187,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
}
},
genPacket: genPacketV4,
@@ -207,7 +207,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
}
},
genPacket: genPacketV6,
@@ -227,7 +227,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
}
},
genPacket: genPacketV4,
@@ -250,7 +250,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
}
},
genPacket: genPacketV6,
@@ -273,7 +273,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
}
},
genPacket: genPacketV4,
@@ -293,7 +293,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
}
},
genPacket: genPacketV6,
@@ -313,7 +313,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
}
},
genPacket: genPacketV4,
@@ -465,7 +465,7 @@ func TestIPTableWritePackets(t *testing.T) {
}
if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil {
- t.Fatalf("RelaceTable(%d, _, false): %s", stack.FilterID, err)
+ t.Fatalf("ReplaceTable(%d, _, false): %s", stack.FilterID, err)
}
},
genPacket: func(r *stack.Route) stack.PacketBufferList {
@@ -556,7 +556,7 @@ func TestIPTableWritePackets(t *testing.T) {
}
if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil {
- t.Fatalf("RelaceTable(%d, _, true): %s", stack.FilterID, err)
+ t.Fatalf("ReplaceTable(%d, _, true): %s", stack.FilterID, err)
}
},
genPacket: func(r *stack.Route) stack.PacketBufferList {
@@ -681,6 +681,32 @@ func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Addr
checker.ICMPv6Type(header.ICMPv6EchoReply)))
}
+func boolToInt(v bool) uint64 {
+ if v {
+ return 1
+ }
+ return 0
+}
+
+func setupDropFilter(hook stack.Hook, f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) {
+ return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) {
+ t.Helper()
+
+ ipv6 := netProto == ipv6.ProtocolNumber
+
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, ipv6)
+ ruleIdx := filter.BuiltinChains[hook]
+ filter.Rules[ruleIdx].Filter = f
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto}
+ // Make sure the packet is not dropped by the next rule.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err)
+ }
+ }
+}
+
func TestForwardingHook(t *testing.T) {
const (
nicID1 = 1
@@ -740,32 +766,6 @@ func TestForwardingHook(t *testing.T) {
},
}
- setupDropFilter := func(f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) {
- return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) {
- t.Helper()
-
- ipv6 := netProto == ipv6.ProtocolNumber
-
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, ipv6)
- ruleIdx := filter.BuiltinChains[stack.Forward]
- filter.Rules[ruleIdx].Filter = f
- filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto}
- // Make sure the packet is not dropped by the next rule.
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto}
- if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err)
- }
- }
- }
-
- boolToInt := func(v bool) uint64 {
- if v {
- return 1
- }
- return 0
- }
-
subTests := []struct {
name string
setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
@@ -779,59 +779,59 @@ func TestForwardingHook(t *testing.T) {
{
name: "Drop",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{}),
expectForward: false,
},
{
name: "Drop with input NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name}),
expectForward: false,
},
{
name: "Drop with output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name}),
expectForward: false,
},
{
name: "Drop with input and output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}),
expectForward: false,
},
{
name: "Drop with other input NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName}),
expectForward: true,
},
{
name: "Drop with other output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: otherNICName}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: otherNICName}),
expectForward: true,
},
{
name: "Drop with other input and output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}),
expectForward: true,
},
{
name: "Drop with input and other output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}),
expectForward: true,
},
{
name: "Drop with other input and other output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}),
expectForward: true,
},
{
name: "Drop with inverted input NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}),
expectForward: true,
},
{
name: "Drop with inverted output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}),
expectForward: true,
},
}
@@ -941,3 +941,194 @@ func TestForwardingHook(t *testing.T) {
})
}
}
+
+func TestInputHookWithLocalForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+
+ nic1Name = "nic1"
+ nic2Name = "nic2"
+
+ otherNICName = "otherNIC"
+ )
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ rx func(*channel.Endpoint)
+ checker func(*testing.T, []byte)
+ }{
+ {
+ name: "IPv4",
+ netProto: ipv4.ProtocolNumber,
+ rx: func(e *channel.Endpoint) {
+ utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address, ttl)
+ },
+ checker: func(t *testing.T, b []byte) {
+ checker.IPv4(t, b,
+ checker.SrcAddr(utils.Ipv4Addr2.AddressWithPrefix.Address),
+ checker.DstAddr(utils.RemoteIPv4Addr),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4EchoReply)))
+ },
+ },
+ {
+ name: "IPv6",
+ netProto: ipv6.ProtocolNumber,
+ rx: func(e *channel.Endpoint) {
+ utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address, ttl)
+ },
+ checker: func(t *testing.T, b []byte) {
+ checker.IPv6(t, b,
+ checker.SrcAddr(utils.Ipv6Addr2.AddressWithPrefix.Address),
+ checker.DstAddr(utils.RemoteIPv6Addr),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoReply)))
+ },
+ },
+ }
+
+ subTests := []struct {
+ name string
+ setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
+ expectDrop bool
+ }{
+ {
+ name: "Accept",
+ setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ },
+ expectDrop: false,
+ },
+
+ {
+ name: "Drop",
+ setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{}),
+ expectDrop: true,
+ },
+ {
+ name: "Drop with input NIC filtering on arrival NIC",
+ setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic1Name}),
+ expectDrop: true,
+ },
+ {
+ name: "Drop with input NIC filtering on delivered NIC",
+ setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic2Name}),
+ expectDrop: false,
+ },
+
+ {
+ name: "Drop with input NIC filtering on other NIC",
+ setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: otherNICName}),
+ expectDrop: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ })
+
+ subTest.setupFilter(t, s, test.netProto)
+
+ e1 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
+ t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
+ }
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err)
+ }
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err)
+ }
+
+ e2 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
+ t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
+ }
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err)
+ }
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv6Addr2, err)
+ }
+
+ if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID1,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID1,
+ },
+ })
+
+ test.rx(e1)
+
+ ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err)
+ }
+ ep1Stats := ep1.Stats()
+ ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats)
+ if !ok {
+ t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats)
+ }
+ ip1Stats := ipEP1Stats.IPStats()
+
+ if got := ip1Stats.PacketsReceived.Value(); got != 1 {
+ t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 {
+ t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got, want := ip1Stats.PacketsSent.Value(), boolToInt(!subTest.expectDrop); got != want {
+ t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = %d", got, want)
+ }
+
+ ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err)
+ }
+ ep2Stats := ep2.Stats()
+ ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats)
+ if !ok {
+ t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats)
+ }
+ ip2Stats := ipEP2Stats.IPStats()
+ if got := ip2Stats.PacketsReceived.Value(); got != 0 {
+ t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got)
+ }
+ if got := ip2Stats.ValidPacketsReceived.Value(); got != 1 {
+ t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got, want := ip2Stats.IPTablesInputDropped.Value(), boolToInt(subTest.expectDrop); got != want {
+ t.Errorf("got ip2Stats.IPTablesInputDropped.Value() = %d, want = %d", got, want)
+ }
+ if got := ip2Stats.PacketsSent.Value(); got != 0 {
+ t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = 0", got)
+ }
+
+ if p, ok := e1.Read(); ok == subTest.expectDrop {
+ t.Errorf("got e1.Read() = (%#v, %t), want = (_, %t)", p, ok, !subTest.expectDrop)
+ } else if !subTest.expectDrop {
+ test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
+ }
+ if p, ok := e2.Read(); ok {
+ t.Errorf("got e1.Read() = (%#v, true), want = (_, false)", p)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index c657714ba..27caa0c28 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -17,6 +17,8 @@ package link_resolution_test
import (
"bytes"
"fmt"
+ "net"
+ "runtime"
"testing"
"time"
@@ -27,12 +29,14 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/pipe"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
+ tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -283,9 +287,11 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
stackOpts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ Clock: clock,
}
host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID)
@@ -329,7 +335,17 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
// Wait for an error due to link resolution failing, or the endpoint to be
// writable.
+ if test.expectedWriteErr != nil {
+ nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto)
+ if err != nil {
+ t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err)
+ }
+ clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer)
+ } else {
+ clock.RunImmediatelyScheduledJobs()
+ }
<-ch
+
{
var r bytes.Reader
r.Reset([]byte{0})
@@ -395,6 +411,242 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
}
}
+func TestForwardingWithLinkResolutionFailure(t *testing.T) {
+ const (
+ incomingNICID = 1
+ outgoingNICID = 2
+ ttl = 2
+ expectedHostUnreachableErrorCount = 1
+ )
+ outgoingLinkAddr := tcptestutil.MustParseLink("02:03:03:04:05:06")
+
+ rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv4EchoRequest(e, src, dst, ttl)
+ }
+
+ rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv6EchoRequest(e, src, dst, ttl)
+ }
+
+ arpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) {
+ if request.Proto != arp.ProtocolNumber {
+ t.Errorf("got request.Proto = %d, want = %d", request.Proto, arp.ProtocolNumber)
+ }
+ if request.Route.RemoteLinkAddress != header.EthernetBroadcastAddress {
+ t.Errorf("got request.Route.RemoteLinkAddress = %s, want = %s", request.Route.RemoteLinkAddress, header.EthernetBroadcastAddress)
+ }
+ rep := header.ARP(request.Pkt.NetworkHeader().View())
+ if got := rep.Op(); got != header.ARPRequest {
+ t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest)
+ }
+ if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != outgoingLinkAddr {
+ t.Errorf("got HardwareAddressSender = %s, want = %s", got, outgoingLinkAddr)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressSender()); got != src {
+ t.Errorf("got ProtocolAddressSender = %s, want = %s", got, src)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressTarget()); got != dst {
+ t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, dst)
+ }
+ }
+
+ ndpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) {
+ if request.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", request.Proto, header.IPv6ProtocolNumber)
+ }
+
+ snmc := header.SolicitedNodeAddr(dst)
+ if want := header.EthernetAddressFromMulticastIPv6Address(snmc); request.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", request.Route.RemoteLinkAddress, want)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(request.Pkt.NetworkHeader()),
+ checker.SrcAddr(src),
+ checker.DstAddr(snmc),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(dst),
+ ))
+ }
+
+ icmpv4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv4(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ipv4.DefaultTTL),
+ checker.ICMPv4(
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4HostUnreachable),
+ ),
+ )
+ }
+
+ icmpv6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv6(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ipv6.DefaultTTL),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ checker.ICMPv6Code(header.ICMPv6AddressUnreachable),
+ ),
+ )
+ }
+
+ tests := []struct {
+ name string
+ networkProtocolFactory []stack.NetworkProtocolFactory
+ networkProtocolNumber tcpip.NetworkProtocolNumber
+ sourceAddr tcpip.Address
+ destAddr tcpip.Address
+ incomingAddr tcpip.AddressWithPrefix
+ outgoingAddr tcpip.AddressWithPrefix
+ transportProtocol func(*stack.Stack) stack.TransportProtocol
+ rx func(*channel.Endpoint, tcpip.Address, tcpip.Address)
+ linkResolutionRequestChecker func(*testing.T, channel.PacketInfo, tcpip.Address, tcpip.Address)
+ icmpReplyChecker func(*testing.T, []byte, tcpip.Address, tcpip.Address)
+ mtu uint32
+ }{
+ {
+ name: "IPv4 Host unreachable",
+ networkProtocolFactory: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
+ networkProtocolNumber: header.IPv4ProtocolNumber,
+ sourceAddr: tcptestutil.MustParse4("10.0.0.2"),
+ destAddr: tcptestutil.MustParse4("11.0.0.2"),
+ incomingAddr: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()),
+ PrefixLen: 8,
+ },
+ outgoingAddr: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()),
+ PrefixLen: 8,
+ },
+ transportProtocol: icmp.NewProtocol4,
+ linkResolutionRequestChecker: arpChecker,
+ icmpReplyChecker: icmpv4Checker,
+ rx: rxICMPv4EchoRequest,
+ mtu: ipv4.MaxTotalSize,
+ },
+ {
+ name: "IPv6 Host unreachable",
+ networkProtocolFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ networkProtocolNumber: header.IPv6ProtocolNumber,
+ sourceAddr: tcptestutil.MustParse6("10::2"),
+ destAddr: tcptestutil.MustParse6("11::2"),
+ incomingAddr: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10::1").To16()),
+ PrefixLen: 64,
+ },
+ outgoingAddr: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("11::1").To16()),
+ PrefixLen: 64,
+ },
+ transportProtocol: icmp.NewProtocol6,
+ linkResolutionRequestChecker: ndpChecker,
+ icmpReplyChecker: icmpv6Checker,
+ rx: rxICMPv6EchoRequest,
+ mtu: header.IPv6MinimumMTU,
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: test.networkProtocolFactory,
+ TransportProtocols: []stack.TransportProtocolFactory{test.transportProtocol},
+ Clock: clock,
+ })
+
+ // Set up endpoint through which we will receive packets.
+ incomingEndpoint := channel.New(1, test.mtu, "")
+ if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
+ }
+ incomingProtoAddr := tcpip.ProtocolAddress{
+ Protocol: test.networkProtocolNumber,
+ AddressWithPrefix: test.incomingAddr,
+ }
+ if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err)
+ }
+
+ // Set up endpoint through which we will attempt to forward packets.
+ outgoingEndpoint := channel.New(1, test.mtu, outgoingLinkAddr)
+ outgoingEndpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
+ }
+ outgoingProtoAddr := tcpip.ProtocolAddress{
+ Protocol: test.networkProtocolNumber,
+ AddressWithPrefix: test.outgoingAddr,
+ }
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: test.incomingAddr.Subnet(),
+ NIC: incomingNICID,
+ },
+ {
+ Destination: test.outgoingAddr.Subnet(),
+ NIC: outgoingNICID,
+ },
+ })
+
+ if err := s.SetForwardingDefaultAndAllNICs(test.networkProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", test.networkProtocolNumber, err)
+ }
+
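+ // Inject the echo request on the incoming NIC; the stack should attempt to forward it out the outgoing NIC.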
+ test.rx(incomingEndpoint, test.sourceAddr, test.destAddr)
+
+ nudConfigs, err := s.NUDConfigurations(outgoingNICID, test.networkProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.NUDConfigurations(%d, %d): %s", outgoingNICID, test.networkProtocolNumber, err)
+ }
+ // Run any immediately scheduled jobs to trigger the first link resolution request.
+ clock.RunImmediatelyScheduledJobs()
+
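+ // Link resolution should send one request per multicast probe; read and validate each of them.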
+ for i := 0; i < int(nudConfigs.MaxMulticastProbes); i++ {
+ request, ok := outgoingEndpoint.Read()
+ if !ok {
+ t.Fatal("expected ARP packet through outgoing NIC")
+ }
+
+ test.linkResolutionRequestChecker(t, request, test.outgoingAddr.Address, test.destAddr)
+
+ // Advance the clock the span of one request timeout.
+ clock.Advance(nudConfigs.RetransmitTimer)
+ }
+
+ // Next, read the ICMP error packet from the incoming NIC. Note that
+ // outgoing packets are dequeued asynchronously when link resolution
+ // fails, and this dequeue is what triggers the ICMP error.
+ reply, ok := incomingEndpoint.Read()
+ if !ok {
+ t.Fatal("expected ICMP packet through incoming NIC")
+ }
+
+ test.icmpReplyChecker(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), test.incomingAddr.Address, test.sourceAddr)
+
+ // Since link resolution failed, we don't expect the packet to be
+ // forwarded.
+ forwardedPacket, ok := outgoingEndpoint.Read()
+ if ok {
+ t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", forwardedPacket)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.HostUnreachable.Value(), expectedHostUnreachableErrorCount; int(got) != want {
+ t.Errorf("got rt.Stats().IP.Forwarding.HostUnreachable.Value() = %d, want = %d", got, want)
+ }
+ })
+ }
+}
+
func TestGetLinkAddress(t *testing.T) {
const (
host1NICID = 1
@@ -449,8 +701,10 @@ func TestGetLinkAddress(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
stackOpts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ Clock: clock,
}
host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
@@ -466,8 +720,20 @@ func TestGetLinkAddress(t *testing.T) {
if test.expectedErr == nil {
wantRes.LinkAddress = utils.LinkAddr2
}
- if diff := cmp.Diff(wantRes, <-ch); diff != "" {
- t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
+
+ nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto)
+ if err != nil {
+ t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err)
+ }
+
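+ // Advance the clock through the full probing period so that link resolution has completed, successfully or not, by the time we check the channel.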
+ clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer)
+ select {
+ case got := <-ch:
+ if diff := cmp.Diff(wantRes, got); diff != "" {
+ t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("event didn't arrive")
}
})
}
@@ -544,8 +810,10 @@ func TestRouteResolvedFields(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
stackOpts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ Clock: clock,
}
host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
@@ -575,8 +843,20 @@ func TestRouteResolvedFields(t *testing.T) {
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{})
}
- if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: test.expectedErr}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
- t.Errorf("route resolve result mismatch (-want +got):\n%s", diff)
+
+ nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto)
+ if err != nil {
+ t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err)
+ }
+ clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer)
+
+ select {
+ case got := <-ch:
+ if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: test.expectedErr}, got, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
+ t.Errorf("route resolve result mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatalf("event didn't arrive")
}
if test.expectedErr != nil {
@@ -789,11 +1069,16 @@ func (d *nudDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.Neighbo
d.c <- e
}
-func (d *nudDispatcher) waitForEvent(want eventInfo) error {
- if diff := cmp.Diff(want, <-d.c, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" {
- return fmt.Errorf("got invalid event (-want +got):\n%s", diff)
+func (d *nudDispatcher) expectEvent(want eventInfo) error {
+ select {
+ case got := <-d.c:
+ if diff := cmp.Diff(want, got, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" {
+ return fmt.Errorf("got invalid event (-want +got):\n%s", diff)
+ }
+ return nil
+ default:
+ return fmt.Errorf("event didn't arrive")
}
- return nil
}
// TestTCPConfirmNeighborReachability tests that TCP informs layers beneath it
@@ -804,7 +1089,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto tcpip.NetworkProtocolNumber
remoteAddr tcpip.Address
neighborAddr tcpip.Address
- getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{})
+ getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{})
isHost1Listener bool
}{
{
@@ -812,23 +1097,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv4.ProtocolNumber,
remoteAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
},
{
@@ -836,23 +1123,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv6.ProtocolNumber,
remoteAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
},
{
@@ -860,23 +1149,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv4.ProtocolNumber,
remoteAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
},
{
@@ -884,23 +1175,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv6.ProtocolNumber,
remoteAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
},
{
@@ -908,23 +1201,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv4.ProtocolNumber,
remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
isHost1Listener: true,
},
@@ -933,23 +1228,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv6.ProtocolNumber,
remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
isHost1Listener: true,
},
@@ -958,23 +1255,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv4.ProtocolNumber,
remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
isHost1Listener: true,
},
@@ -983,23 +1282,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
netProto: ipv6.ProtocolNumber,
remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
var listenerWQ waiter.Queue
+ listenerWE, listenerCH := waiter.NewChannelEntry(nil)
+ listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
if err != nil {
t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
+ t.Cleanup(listenerEP.Close)
var clientWQ waiter.Queue
clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.WritableEvents)
+ clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
if err != nil {
- listenerEP.Close()
t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
}
- return listenerEP, clientEP, clientCH
+ return listenerEP, listenerCH, clientEP, clientCH
},
isHost1Listener: true,
},
@@ -1037,14 +1338,14 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
t.Fatalf("link resolution mismatch (-want +got):\n%s", diff)
}
}
- if err := nudDisp.waitForEvent(eventInfo{
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryAdded,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Incomplete, Addr: test.neighborAddr},
}); err != nil {
t.Fatalf("error waiting for initial NUD event: %s", err)
}
- if err := nudDisp.waitForEvent(eventInfo{
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
@@ -1064,7 +1365,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
// See NUDConfigurations.BaseReachableTime for more information.
maxReachableTime := time.Duration(float32(nudConfigs.BaseReachableTime) * nudConfigs.MaxRandomFactor)
clock.Advance(maxReachableTime)
- if err := nudDisp.waitForEvent(eventInfo{
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
@@ -1072,8 +1373,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
t.Fatalf("error waiting for stale NUD event: %s", err)
}
- listenerEP, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack)
- defer listenerEP.Close()
+ listenerEP, listenerCH, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack)
defer clientEP.Close()
listenerAddr := tcpip.FullAddress{Addr: test.remoteAddr, Port: 1234}
if err := listenerEP.Bind(listenerAddr); err != nil {
@@ -1094,14 +1394,15 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
// with confirmation that the neighbor is reachable (indicated by a
// successful 3-way handshake).
<-clientCH
- if err := nudDisp.waitForEvent(eventInfo{
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
}); err != nil {
t.Fatalf("error waiting for delay NUD event: %s", err)
}
- if err := nudDisp.waitForEvent(eventInfo{
+ <-listenerCH
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
@@ -1109,26 +1410,55 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
t.Fatalf("error waiting for reachable NUD event: %s", err)
}
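+ // Accept the established connection so that the listener side can also send and receive data later in the test.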
+ peerEP, peerWQ, err := listenerEP.Accept(nil)
+ if err != nil {
+ t.Fatalf("listenerEP.Accept(): %s", err)
+ }
+ defer peerEP.Close()
+ peerWE, peerCH := waiter.NewChannelEntry(nil)
+ peerWQ.EventRegister(&peerWE, waiter.ReadableEvents)
+
// Wait for the neighbor to become stale again, then send data to the remote.
//
// On successful transmission, the neighbor should become reachable without
// probing, since the TCP ACK received in response is itself an indication
// that the neighbor is reachable.
clock.Advance(maxReachableTime)
- if err := nudDisp.waitForEvent(eventInfo{
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
}); err != nil {
t.Fatalf("error waiting for stale NUD event: %s", err)
}
- var r bytes.Reader
- r.Reset([]byte{0})
- var wOpts tcpip.WriteOptions
- if _, err := clientEP.Write(&r, wOpts); err != nil {
- t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err)
+ {
+ var r bytes.Reader
+ r.Reset([]byte{0})
+ var wOpts tcpip.WriteOptions
+ if _, err := clientEP.Write(&r, wOpts); err != nil {
+ t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err)
+ }
+ }
+ // Heads up, there is a race here.
+ //
+ // Incoming TCP segments are handled in
+ // tcp.(*endpoint).handleSegmentLocked:
+ //
+ // - tcp.(*endpoint).rcv.handleRcvdSegment puts the segment on the
+ // segment queue and notifies waiting readers (such as this channel)
+ //
+ // - tcp.(*endpoint).snd.handleRcvdSegment sends an ACK for the segment
+ // and notifies the NUD machinery that the peer is reachable
+ //
+ // Thus we must permit a delay between the readable signal and the
+ // expected NUD event.
+ //
+ // At the time of writing, this race is reliably hit with gotsan.
+ <-peerCH
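+ // Yield until the next NUD event has been queued (see the race described above).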
+ for len(nudDisp.c) == 0 {
+ runtime.Gosched()
}
- if err := nudDisp.waitForEvent(eventInfo{
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
@@ -1141,7 +1471,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
// TCP should not mark the route reachable and NUD should go through the
// probe state.
clock.Advance(nudConfigs.DelayFirstProbeTime)
- if err := nudDisp.waitForEvent(eventInfo{
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Probe, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
@@ -1149,7 +1479,16 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
t.Fatalf("error waiting for probe NUD event: %s", err)
}
}
- if err := nudDisp.waitForEvent(eventInfo{
+ {
+ var r bytes.Reader
+ r.Reset([]byte{0})
+ var wOpts tcpip.WriteOptions
+ if _, err := peerEP.Write(&r, wOpts); err != nil {
+ t.Errorf("peerEP.Write(_, %#v): %s", wOpts, err)
+ }
+ }
+ <-clientCH
+ if err := nudDisp.expectEvent(eventInfo{
eventType: entryChanged,
nicID: utils.Host1NICID,
entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 87d36e1dd..b2008f0b2 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -44,20 +44,17 @@ type ndpDispatcher struct{}
func (*ndpDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) {
}
-func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool {
- return false
+func (*ndpDispatcher) OnOffLinkRouteUpdated(tcpip.NICID, tcpip.Subnet, tcpip.Address, header.NDPRoutePreference) {
}
-func (*ndpDispatcher) OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) {}
+func (*ndpDispatcher) OnOffLinkRouteInvalidated(tcpip.NICID, tcpip.Subnet, tcpip.Address) {}
-func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool {
- return false
+func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) {
}
func (*ndpDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {}
-func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool {
- return true
+func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) {
}
func (*ndpDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {}
diff --git a/pkg/tcpip/testutil/testutil.go b/pkg/tcpip/testutil/testutil.go
index f84d399fb..94b580a70 100644
--- a/pkg/tcpip/testutil/testutil.go
+++ b/pkg/tcpip/testutil/testutil.go
@@ -109,3 +109,15 @@ func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) er
return nil
}
+
+// MustParseLink parses a Link string into a tcpip.LinkAddress, panicking on
+// error.
+//
+// The string must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff.
+func MustParseLink(addr string) tcpip.LinkAddress {
+ parsed, err := tcpip.ParseMACAddress(addr)
+ if err != nil {
+ panic(fmt.Sprintf("tcpip.ParseMACAddress(%s): %s", addr, err))
+ }
+ return parsed
+}
diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go
index 1633d0aeb..ed1ed8ac6 100644
--- a/pkg/tcpip/timer_test.go
+++ b/pkg/tcpip/timer_test.go
@@ -15,6 +15,7 @@
package tcpip_test
import (
+ "math"
"sync"
"testing"
"time"
@@ -22,10 +23,85 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
+func TestMonotonicTimeBefore(t *testing.T) {
+ var mt tcpip.MonotonicTime
+ if mt.Before(mt) {
+ t.Errorf("%#v.Before(%#v)", mt, mt)
+ }
+
+ one := mt.Add(1)
+ if one.Before(mt) {
+ t.Errorf("%#v.Before(%#v)", one, mt)
+ }
+ if !mt.Before(one) {
+ t.Errorf("!%#v.Before(%#v)", mt, one)
+ }
+}
+
+func TestMonotonicTimeAfter(t *testing.T) {
+ var mt tcpip.MonotonicTime
+ if mt.After(mt) {
+ t.Errorf("%#v.After(%#v)", mt, mt)
+ }
+
+ one := mt.Add(1)
+ if mt.After(one) {
+ t.Errorf("%#v.After(%#v)", mt, one)
+ }
+ if !one.After(mt) {
+ t.Errorf("!%#v.After(%#v)", one, mt)
+ }
+}
+
+func TestMonotonicTimeAddSub(t *testing.T) {
+ var mt tcpip.MonotonicTime
+ if one, two := mt.Add(2), mt.Add(1).Add(1); one != two {
+ t.Errorf("mt.Add(2) != mt.Add(1).Add(1) (%#v != %#v)", one, two)
+ }
+
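+ // The following additions are expected to saturate at the int64 bounds rather than wrap.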
+ min := mt.Add(math.MinInt64)
+ max := mt.Add(math.MaxInt64)
+
+ if overflow := mt.Add(1).Add(math.MaxInt64); overflow != max {
+ t.Errorf("mt.Add(math.MaxInt64) != mt.Add(1).Add(math.MaxInt64) (%#v != %#v)", max, overflow)
+ }
+ if underflow := mt.Add(-1).Add(math.MinInt64); underflow != min {
+ t.Errorf("mt.Add(math.MinInt64) != mt.Add(-1).Add(math.MinInt64) (%#v != %#v)", min, underflow)
+ }
+
+ if got, want := min.Sub(min), time.Duration(0); want != got {
+ t.Errorf("got min.Sub(min) = %d, want %d", got, want)
+ }
+ if got, want := max.Sub(max), time.Duration(0); want != got {
+ t.Errorf("got max.Sub(max) = %d, want %d", got, want)
+ }
+
+ if overflow, want := max.Sub(min), time.Duration(math.MaxInt64); overflow != want {
+ t.Errorf("mt.Add(math.MaxInt64).Sub(mt.Add(math.MinInt64) != %s (%#v)", want, overflow)
+ }
+ if underflow, want := min.Sub(max), time.Duration(math.MinInt64); underflow != want {
+ t.Errorf("mt.Add(math.MinInt64).Sub(mt.Add(math.MaxInt64) != %s (%#v)", want, underflow)
+ }
+}
+
+func TestMonotonicTimeSub(t *testing.T) {
+ var mt tcpip.MonotonicTime
+
+ if one, two := mt.Add(2), mt.Add(1).Add(1); one != two {
+ t.Errorf("mt.Add(2) != mt.Add(1).Add(1) (%#v != %#v)", one, two)
+ }
+
+ if max, overflow := mt.Add(math.MaxInt64), mt.Add(1).Add(math.MaxInt64); max != overflow {
+ t.Errorf("mt.Add(math.MaxInt64) != mt.Add(1).Add(math.MaxInt64) (%#v != %#v)", max, overflow)
+ }
+ if min, underflow := mt.Add(math.MinInt64), mt.Add(-1).Add(math.MinInt64); min != underflow {
+ t.Errorf("mt.Add(math.MinInt64) != mt.Add(-1).Add(math.MinInt64) (%#v != %#v)", min, underflow)
+ }
+}
+
const (
shortDuration = 1 * time.Nanosecond
middleDuration = 100 * time.Millisecond
- longDuration = 1 * time.Second
)
func TestJobReschedule(t *testing.T) {
@@ -53,10 +129,14 @@ func TestJobReschedule(t *testing.T) {
wg.Wait()
}
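+// stdClockWithAfter returns a standard clock together with the real-time After function that tests should pair with it.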
+func stdClockWithAfter() (tcpip.Clock, func(time.Duration) <-chan time.Time) {
+ return tcpip.NewStdClock(), time.After
+}
+
func TestJobExecution(t *testing.T) {
t.Parallel()
- clock := tcpip.NewStdClock()
+ clock, after := stdClockWithAfter()
var lock sync.Mutex
ch := make(chan struct{})
@@ -68,7 +148,7 @@ func TestJobExecution(t *testing.T) {
// Wait for timer to fire.
select {
case <-ch:
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
t.Fatal("timed out waiting for timer to fire")
}
@@ -76,14 +156,14 @@ func TestJobExecution(t *testing.T) {
select {
case <-ch:
t.Fatal("no other timers should have fired")
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
}
}
func TestCancellableTimerResetFromLongDuration(t *testing.T) {
t.Parallel()
- clock := tcpip.NewStdClock()
+ clock, after := stdClockWithAfter()
var lock sync.Mutex
ch := make(chan struct{})
@@ -99,7 +179,7 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) {
// Wait for timer to fire.
select {
case <-ch:
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
t.Fatal("timed out waiting for timer to fire")
}
@@ -107,14 +187,14 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) {
select {
case <-ch:
t.Fatal("no other timers should have fired")
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
}
}
func TestJobRescheduleFromShortDuration(t *testing.T) {
t.Parallel()
- clock := tcpip.NewStdClock()
+ clock, after := stdClockWithAfter()
var lock sync.Mutex
ch := make(chan struct{})
@@ -128,7 +208,7 @@ func TestJobRescheduleFromShortDuration(t *testing.T) {
select {
case <-ch:
t.Fatal("timer fired after being stopped")
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
}
job.Schedule(shortDuration)
@@ -136,7 +216,7 @@ func TestJobRescheduleFromShortDuration(t *testing.T) {
// Wait for timer to fire.
select {
case <-ch:
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
t.Fatal("timed out waiting for timer to fire")
}
@@ -144,14 +224,14 @@ func TestJobRescheduleFromShortDuration(t *testing.T) {
select {
case <-ch:
t.Fatal("no other timers should have fired")
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
}
}
func TestJobImmediatelyCancel(t *testing.T) {
t.Parallel()
- clock := tcpip.NewStdClock()
+ clock, after := stdClockWithAfter()
var lock sync.Mutex
ch := make(chan struct{})
@@ -167,14 +247,19 @@ func TestJobImmediatelyCancel(t *testing.T) {
select {
case <-ch:
t.Fatal("timer fired after being stopped")
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
}
}
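+// stdClockWithAfterAndSleep returns a standard clock together with the real-time After and Sleep functions that tests should pair with it.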
+func stdClockWithAfterAndSleep() (tcpip.Clock, func(time.Duration) <-chan time.Time, func(time.Duration)) {
+ clock, after := stdClockWithAfter()
+ return clock, after, time.Sleep
+}
+
func TestJobCancelledRescheduleWithoutLock(t *testing.T) {
t.Parallel()
- clock := tcpip.NewStdClock()
+ clock, after, sleep := stdClockWithAfterAndSleep()
var lock sync.Mutex
ch := make(chan struct{})
@@ -189,7 +274,7 @@ func TestJobCancelledRescheduleWithoutLock(t *testing.T) {
lock.Lock()
// Sleep until the timer fires and gets blocked trying to take the lock.
- time.Sleep(middleDuration * 2)
+ sleep(middleDuration * 2)
job.Cancel()
lock.Unlock()
}
@@ -199,14 +284,14 @@ func TestJobCancelledRescheduleWithoutLock(t *testing.T) {
select {
case <-ch:
t.Fatal("timer fired after being stopped")
- case <-time.After(middleDuration * 2):
+ case <-after(middleDuration * 2):
}
}
func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
t.Parallel()
- clock := tcpip.NewStdClock()
+ clock, after, sleep := stdClockWithAfterAndSleep()
var lock sync.Mutex
ch := make(chan struct{})
@@ -215,7 +300,7 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
job.Schedule(shortDuration)
for i := 0; i < 10; i++ {
// Sleep until the timer fires and gets blocked trying to take the lock.
- time.Sleep(middleDuration)
+ sleep(middleDuration)
job.Cancel()
job.Schedule(shortDuration)
}
@@ -224,7 +309,7 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
// Wait for double the duration for the last timer to fire.
select {
case <-ch:
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
t.Fatal("timed out waiting for timer to fire")
}
@@ -232,14 +317,14 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
select {
case <-ch:
t.Fatal("no other timers should have fired")
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
}
}
func TestManyJobReschedulesUnderLock(t *testing.T) {
t.Parallel()
- clock := tcpip.NewStdClock()
+ clock, after := stdClockWithAfter()
var lock sync.Mutex
ch := make(chan struct{})
@@ -255,7 +340,7 @@ func TestManyJobReschedulesUnderLock(t *testing.T) {
// Wait for double the duration for the last timer to fire.
select {
case <-ch:
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
t.Fatal("timed out waiting for timer to fire")
}
@@ -263,6 +348,6 @@ func TestManyJobReschedulesUnderLock(t *testing.T) {
select {
case <-ch:
t.Fatal("no other timers should have fired")
- case <-time.After(middleDuration):
+ case <-after(middleDuration):
}
}
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index 7e5c79776..bbc0e3ecc 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
@@ -38,3 +38,22 @@ go_library(
"//pkg/waiter",
],
)
+
+go_test(
+ name = "icmp_x_test",
+ size = "small",
+ srcs = ["icmp_test.go"],
+ deps = [
+ ":icmp",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 8afde7fca..cb316d27a 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -16,6 +16,7 @@ package icmp
import (
"io"
+ "time"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -26,14 +27,12 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-// TODO(https://gvisor.dev/issues/5623): Unit test this package.
-
// +stateify savable
type icmpPacket struct {
icmpPacketEntry
senderAddress tcpip.FullAddress
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- timestamp int64
+ receivedAt time.Time `state:".(int64)"`
}
type endpointState int
@@ -133,7 +132,8 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */)
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
+ e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, bindToDevice)
}
// Close the receive list and drain it.
@@ -160,7 +160,7 @@ func (e *endpoint) Close() {
}
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
-func (e *endpoint) ModerateRecvBuf(copied int) {}
+func (*endpoint) ModerateRecvBuf(int) {}
// SetOwner implements tcpip.Endpoint.SetOwner.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
@@ -193,7 +193,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
Total: p.data.Size(),
ControlMessages: tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: p.timestamp,
+ Timestamp: p.receivedAt.UnixNano(),
},
}
if opts.NeedRemoteAddr {
@@ -304,6 +304,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
nicID := to.NIC
+ if nicID == 0 {
+ nicID = tcpip.NICID(e.ops.GetBindToDevice())
+ }
if e.BindNICID != 0 {
if nicID != 0 && nicID != e.BindNICID {
return 0, &tcpip.ErrNoRoute{}
@@ -348,8 +351,15 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return int64(len(v)), nil
}
+var _ tcpip.SocketOptionsHandler = (*endpoint)(nil)
+
+// HasNIC implements tcpip.SocketOptionsHandler.
+func (e *endpoint) HasNIC(id int32) bool {
+ return e.stack.HasNIC(tcpip.NICID(id))
+}
+
// SetSockOpt sets a socket option.
-func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
+func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error {
return nil
}
@@ -390,7 +400,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
return &tcpip.ErrUnknownProtocolOption{}
}
@@ -606,18 +616,19 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
return nil, nil, &tcpip.ErrNotSupported{}
}
-func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) {
+func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) {
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */)
+ err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice)
return id, err
}
// We need to find a port for the endpoint.
- _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) {
+ _, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */)
+ err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice)
switch err.(type) {
case nil:
return true, nil
@@ -747,8 +758,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
switch e.NetProto {
case header.IPv4ProtocolNumber:
h := header.ICMPv4(pkt.TransportHeader().View())
- // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed
- // after early parsing.
if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
@@ -756,8 +765,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
}
case header.IPv6ProtocolNumber:
h := header.ICMPv6(pkt.TransportHeader().View())
- // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed
- // after early parsing.
if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
@@ -800,7 +807,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
- packet.timestamp = e.stack.Clock().NowNanoseconds()
+ packet.receivedAt = e.stack.Clock().Now()
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
index 28a56a2d5..b8b839e4a 100644
--- a/pkg/tcpip/transport/icmp/endpoint_state.go
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -15,11 +15,23 @@
package icmp
import (
+ "time"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+// saveReceivedAt is invoked by stateify.
+func (p *icmpPacket) saveReceivedAt() int64 {
+ return p.receivedAt.UnixNano()
+}
+
+// loadReceivedAt is invoked by stateify.
+func (p *icmpPacket) loadReceivedAt(nsec int64) {
+ p.receivedAt = time.Unix(0, nsec)
+}
+
// saveData saves icmpPacket.data field.
func (p *icmpPacket) saveData() buffer.VectorisedView {
// We cannot save p.data directly as p.data.views may alias to p.views,
diff --git a/pkg/tcpip/transport/icmp/icmp_test.go b/pkg/tcpip/transport/icmp/icmp_test.go
new file mode 100644
index 000000000..cc950cbde
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/icmp_test.go
@@ -0,0 +1,235 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package icmp_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// TODO(https://gvisor.dev/issues/5623): Finish unit testing the icmp package.
+// See the issue for remaining areas of work.
+
+var (
+ localV4Addr1 = testutil.MustParse4("10.0.0.1")
+ localV4Addr2 = testutil.MustParse4("10.0.0.2")
+ remoteV4Addr = testutil.MustParse4("10.0.0.3")
+)
+
+func addNICWithDefaultRoute(t *testing.T, s *stack.Stack, id tcpip.NICID, name string, addrV4 tcpip.Address) *channel.Endpoint {
+ t.Helper()
+
+ ep := channel.New(1 /* size */, header.IPv4MinimumMTU, "" /* linkAddr */)
+ t.Cleanup(ep.Close)
+
+ wep := stack.LinkEndpoint(ep)
+ if testing.Verbose() {
+ wep = sniffer.New(ep)
+ }
+
+ opts := stack.NICOptions{Name: name}
+ if err := s.CreateNICWithOptions(id, wep, opts); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _) = %s", id, err)
+ }
+
+ if err := s.AddAddress(id, ipv4.ProtocolNumber, addrV4); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s) = %s", id, ipv4.ProtocolNumber, addrV4, err)
+ }
+
+ s.AddRoute(tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: id,
+ })
+
+ return ep
+}
+
+func writePayload(buf []byte) {
+ for i := range buf {
+ buf[i] = byte(i)
+ }
+}
+
+func newICMPv4EchoRequest(payloadSize uint32) buffer.View {
+ buf := buffer.NewView(header.ICMPv4MinimumSize + int(payloadSize))
+ writePayload(buf[header.ICMPv4MinimumSize:])
+
+ icmp := header.ICMPv4(buf)
+ icmp.SetType(header.ICMPv4Echo)
+ // No need to set the checksum; it is reset by the socket before the packet
+ // is sent.
+
+ return buf
+}
+
+// TestWriteUnboundWithBindToDevice exercises writing to an unbound ICMP socket
+// when SO_BINDTODEVICE is set to the non-default NIC for that subnet.
+//
+// Only IPv4 is tested. The logic to determine which NIC to use is agnostic to
+// the version of IP.
+func TestWriteUnboundWithBindToDevice(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
+ HandleLocal: true,
+ })
+
+ // Add two NICs, both with default routes on the same subnet. The first NIC
+ // added will be the default NIC for that subnet.
+ defaultEP := addNICWithDefaultRoute(t, s, 1, "default", localV4Addr1)
+ alternateEP := addNICWithDefaultRoute(t, s, 2, "alternate", localV4Addr2)
+
+ socket, err := s.NewEndpoint(icmp.ProtocolNumber4, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _) = %s", icmp.ProtocolNumber4, ipv4.ProtocolNumber, err)
+ }
+ defer socket.Close()
+
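+ // Size the echo payload so that the resulting IPv4 packet exactly fills the NIC's MTU.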
+ echoPayloadSize := defaultEP.MTU() - header.IPv4MinimumSize - header.ICMPv4MinimumSize
+
+ // Send a packet without SO_BINDTODEVICE. This verifies that the first NIC
+ // added for the subnet is the one used to send packets when the socket is
+ // not explicitly bound to a device.
+ {
+ buf := newICMPv4EchoRequest(echoPayloadSize)
+ r := buf.Reader()
+ n, err := socket.Write(&r, tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: remoteV4Addr},
+ })
+ if err != nil {
+ t.Fatalf("socket.Write(_, {To:%s}) = %s", remoteV4Addr, err)
+ }
+ if n != int64(len(buf)) {
+ t.Fatalf("got n = %d, want n = %d", n, len(buf))
+ }
+
+ // Verify the packet was sent out the default NIC.
+ p, ok := defaultEP.Read()
+ if !ok {
+ t.Fatalf("got defaultEP.Read(_) = _, false; want = _, true (packet wasn't written out)")
+ }
+
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
+
+ checker.IPv4(t, b, []checker.NetworkChecker{
+ checker.SrcAddr(localV4Addr1),
+ checker.DstAddr(remoteV4Addr),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4Echo),
+ checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]),
+ ),
+ }...)
+
+ // Verify the packet was not sent out the alternate NIC.
+ if p, ok := alternateEP.Read(); ok {
+ t.Fatalf("got alternateEP.Read(_) = %+v, true; want = _, false", p)
+ }
+ }
+
+ // Send a packet with SO_BINDTODEVICE. This exercises reliance on
+ // SO_BINDTODEVICE to route the packet to the alternate NIC.
+ {
+ // Use SO_BINDTODEVICE so that sends go over the alternate NIC.
+ socket.SocketOptions().SetBindToDevice(2)
+
+ buf := newICMPv4EchoRequest(echoPayloadSize)
+ r := buf.Reader()
+ n, err := socket.Write(&r, tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: remoteV4Addr},
+ })
+ if err != nil {
+ t.Fatalf("socket.Write(_, {To:%s}) = %s", tcpip.Address(remoteV4Addr), err)
+ }
+ if n != int64(len(buf)) {
+ t.Fatalf("got n = %d, want n = %d", n, len(buf))
+ }
+
+ // Verify the packet was not sent out the default NIC.
+ if p, ok := defaultEP.Read(); ok {
+ t.Fatalf("got defaultEP.Read(_) = %+v, true; want = _, false", p)
+ }
+
+ // Verify the packet was sent out the alternate NIC.
+ p, ok := alternateEP.Read()
+ if !ok {
+ t.Fatalf("got alternateEP.Read(_) = _, false; want = _, true (packet wasn't written out)")
+ }
+
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
+
+ checker.IPv4(t, b, []checker.NetworkChecker{
+ checker.SrcAddr(localV4Addr2),
+ checker.DstAddr(remoteV4Addr),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4Echo),
+ checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]),
+ ),
+ }...)
+ }
+
+ // Send a packet with SO_BINDTODEVICE cleared. This verifies that clearing
+ // the device binding falls back to using the default NIC to send packets.
+ {
+ socket.SocketOptions().SetBindToDevice(0)
+
+ buf := newICMPv4EchoRequest(echoPayloadSize)
+ r := buf.Reader()
+ n, err := socket.Write(&r, tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: remoteV4Addr},
+ })
+ if err != nil {
+ t.Fatalf("socket.Write(_, {To:%s}) = %s", tcpip.Address(remoteV4Addr), err)
+ }
+ if n != int64(len(buf)) {
+ t.Fatalf("got n = %d, want n = %d", n, len(buf))
+ }
+
+ // Verify the packet was sent out the default NIC.
+ p, ok := defaultEP.Read()
+ if !ok {
+ t.Fatalf("got defaultEP.Read(_) = _, false; want = _, true (packet wasn't written out)")
+ }
+
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
+
+ checker.IPv4(t, b, []checker.NetworkChecker{
+ checker.SrcAddr(localV4Addr1),
+ checker.DstAddr(remoteV4Addr),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4Echo),
+ checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]),
+ ),
+ }...)
+
+ // Verify the packet was not sent out the alternate NIC.
+ if p, ok := alternateEP.Read(); ok {
+ t.Fatalf("got alternateEP.Read(_) = %+v, true; want = _, false", p)
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 47f7dd1cb..fa82affc1 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -123,8 +123,6 @@ func (*protocol) Wait() {}
// Parse implements stack.TransportProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
- // TODO(gvisor.dev/issue/170): Implement parsing of ICMP.
- //
// Right now, the Parse() method is tied to enabled protocols passed into
// stack.New. This works for UDP and TCP, but we handle ICMP traffic even
// when netstack users don't pass ICMP as a supported protocol.
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 496eca581..8e7bb6c6e 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -27,6 +27,7 @@ package packet
import (
"fmt"
"io"
+ "time"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -41,9 +42,8 @@ type packet struct {
packetEntry
// data holds the actual packet data, including any headers and
// payload.
- data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- // timestampNS is the unix time at which the packet was received.
- timestampNS int64
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ receivedAt time.Time `state:".(int64)"`
// senderAddr is the network address of the sender.
senderAddr tcpip.FullAddress
// packetInfo holds additional information like the protocol
@@ -159,7 +159,7 @@ func (ep *endpoint) Close() {
}
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
-func (ep *endpoint) ModerateRecvBuf(copied int) {}
+func (*endpoint) ModerateRecvBuf(int) {}
// Read implements tcpip.Endpoint.Read.
func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
@@ -189,7 +189,7 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul
Total: packet.data.Size(),
ControlMessages: tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: packet.timestampNS,
+ Timestamp: packet.receivedAt.UnixNano(),
},
}
if opts.NeedRemoteAddr {
@@ -208,7 +208,6 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul
}
func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error) {
- // TODO(gvisor.dev/issue/173): Implement.
return 0, &tcpip.ErrInvalidOptionValue{}
}
@@ -220,19 +219,19 @@ func (*endpoint) Disconnect() tcpip.Error {
// Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be
// connected, and this function always returns *tcpip.ErrNotSupported.
-func (*endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
+func (*endpoint) Connect(tcpip.FullAddress) tcpip.Error {
return &tcpip.ErrNotSupported{}
}
// Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used
// with Shutdown, and this function always returns *tcpip.ErrNotSupported.
-func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
+func (*endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error {
return &tcpip.ErrNotSupported{}
}
// Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with
// Listen, and this function always returns *tcpip.ErrNotSupported.
-func (*endpoint) Listen(backlog int) tcpip.Error {
+func (*endpoint) Listen(int) tcpip.Error {
return &tcpip.ErrNotSupported{}
}
@@ -244,8 +243,6 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
// Bind implements tcpip.Endpoint.Bind.
func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
- // TODO(gvisor.dev/issue/173): Add Bind support.
-
// "By default, all packets of the specified protocol type are passed
// to a packet socket. To get packets only from a specific interface
// use bind(2) specifying an address in a struct sockaddr_ll to bind
@@ -318,7 +315,7 @@ func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
}
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
+func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error {
return &tcpip.ErrUnknownProtocolOption{}
}
@@ -339,7 +336,7 @@ func (ep *endpoint) UpdateLastError(err tcpip.Error) {
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
return &tcpip.ErrNotSupported{}
}
@@ -385,7 +382,6 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
// Push new packet into receive list and increment the buffer size.
var packet packet
- // TODO(gvisor.dev/issue/173): Return network protocol.
if !pkt.LinkHeader().View().IsEmpty() {
// Get info directly from the ethernet header.
hdr := header.Ethernet(pkt.LinkHeader().View())
@@ -424,7 +420,6 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
default:
panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt))
}
-
} else {
// Raw packets need their ethernet headers prepended before
// queueing.
@@ -451,7 +446,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views())
}
}
- packet.timestampNS = ep.stack.Clock().NowNanoseconds()
+ packet.receivedAt = ep.stack.Clock().Now()
ep.rcvList.PushBack(&packet)
ep.rcvBufSize += packet.data.Size()
@@ -484,7 +479,7 @@ func (ep *endpoint) Stats() tcpip.EndpointStats {
}
// SetOwner implements tcpip.Endpoint.SetOwner.
-func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {}
+func (*endpoint) SetOwner(tcpip.PacketOwner) {}
// SocketOptions implements tcpip.Endpoint.SocketOptions.
func (ep *endpoint) SocketOptions() *tcpip.SocketOptions {
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index 5bd860d20..e729921db 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -15,11 +15,23 @@
package packet
import (
+ "time"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+// saveReceivedAt is invoked by stateify.
+func (p *packet) saveReceivedAt() int64 {
+ return p.receivedAt.UnixNano()
+}
+
+// loadReceivedAt is invoked by stateify.
+func (p *packet) loadReceivedAt(nsec int64) {
+ p.receivedAt = time.Unix(0, nsec)
+}
+
// saveData saves packet.data field.
func (p *packet) saveData() buffer.VectorisedView {
// We cannot save p.data directly as p.data.views may alias to p.views,
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index bcec3d2e7..b6687911a 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -27,6 +27,7 @@ package raw
import (
"io"
+ "time"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -41,9 +42,8 @@ type rawPacket struct {
rawPacketEntry
// data holds the actual packet data, including any headers and
// payload.
- data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- // timestampNS is the unix time at which the packet was received.
- timestampNS int64
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ receivedAt time.Time `state:".(int64)"`
// senderAddr is the network address of the sender.
senderAddr tcpip.FullAddress
}
@@ -183,7 +183,7 @@ func (e *endpoint) Close() {
}
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
-func (e *endpoint) ModerateRecvBuf(copied int) {}
+func (*endpoint) ModerateRecvBuf(int) {}
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.mu.Lock()
@@ -219,7 +219,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
Total: pkt.data.Size(),
ControlMessages: tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: pkt.timestampNS,
+ Timestamp: pkt.receivedAt.UnixNano(),
},
}
if opts.NeedRemoteAddr {
@@ -286,26 +286,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return nil, nil, nil, &tcpip.ErrBadBuffer{}
}
- // If this is an unassociated socket and callee provided a nonzero
- // destination address, route using that address.
- if e.ops.GetHeaderIncluded() {
- ip := header.IPv4(payloadBytes)
- if !ip.IsValid(len(payloadBytes)) {
- return nil, nil, nil, &tcpip.ErrInvalidOptionValue{}
- }
- dstAddr := ip.DestinationAddress()
- // Update dstAddr with the address in the IP header, unless
- // opts.To is set (e.g. if sendto specifies a specific
- // address).
- if dstAddr != tcpip.Address([]byte{0, 0, 0, 0}) && opts.To == nil {
- opts.To = &tcpip.FullAddress{
- NIC: 0, // NIC is unset.
- Addr: dstAddr, // The address from the payload.
- Port: 0, // There are no ports here.
- }
- }
- }
-
// Did the user caller provide a destination? If not, use the connected
// destination.
if opts.To == nil {
@@ -402,7 +382,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
}
// Find a route to the destination.
- route, err := e.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, e.NetProto, false)
+ route, err := e.stack.FindRoute(nic, "", addr.Addr, e.NetProto, false)
if err != nil {
return err
}
@@ -428,7 +408,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
}
// Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets.
-func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
+func (e *endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -439,7 +419,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
}
// Listen implements tcpip.Endpoint.Listen.
-func (*endpoint) Listen(backlog int) tcpip.Error {
+func (*endpoint) Listen(int) tcpip.Error {
return &tcpip.ErrNotSupported{}
}
@@ -513,12 +493,12 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
}
}
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
+func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error {
return &tcpip.ErrUnknownProtocolOption{}
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
return &tcpip.ErrUnknownProtocolOption{}
}
@@ -621,7 +601,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
combinedVV.Append(pkt.Data().ExtractVV())
packet.data = combinedVV
- packet.timestampNS = e.stack.Clock().NowNanoseconds()
+ packet.receivedAt = e.stack.Clock().Now()
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 5d6f2709c..39669b445 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -15,11 +15,23 @@
package raw
import (
+ "time"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+// saveReceivedAt is invoked by stateify.
+func (p *rawPacket) saveReceivedAt() int64 {
+ return p.receivedAt.UnixNano()
+}
+
+// loadReceivedAt is invoked by stateify.
+func (p *rawPacket) loadReceivedAt(nsec int64) {
+ p.receivedAt = time.Unix(0, nsec)
+}
+
// saveData saves rawPacket.data field.
func (p *rawPacket) saveData() buffer.VectorisedView {
// We cannot save p.data directly as p.data.views may alias to p.views,
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 0f20d3856..8436d2cf0 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -41,7 +41,6 @@ go_library(
"protocol.go",
"rack.go",
"rcv.go",
- "rcv_state.go",
"reno.go",
"reno_recovery.go",
"sack.go",
@@ -53,7 +52,6 @@ go_library(
"segment_state.go",
"segment_unsafe.go",
"snd.go",
- "snd_state.go",
"tcp_endpoint_list.go",
"tcp_segment_list.go",
"timer.go",
@@ -134,6 +132,7 @@ go_test(
deps = [
"//pkg/sleep",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/stack",
"@com_github_google_go_cmp//cmp:go_default_library",
],
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index d4bd4e80e..d807b13b7 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -114,8 +114,8 @@ type listenContext struct {
}
// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
-func timeStamp() uint32 {
- return uint32(time.Now().Unix()>>6) & tsMask
+func timeStamp(clock tcpip.Clock) uint32 {
+ return uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Seconds()) >> 6 & tsMask
}
// newListenContext creates a new listen context.
@@ -171,7 +171,7 @@ func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonc
// createCookie creates a SYN cookie for the given id and incoming sequence
// number.
func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Value, data uint32) seqnum.Value {
- ts := timeStamp()
+ ts := timeStamp(l.stack.Clock())
v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset)
v += (l.cookieHash(id, ts, 1) + data) & hashMask
return seqnum.Value(v)
@@ -181,7 +181,7 @@ func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Va
// sequence number. If it is, it also returns the data originally encoded in the
// cookie when createCookie was called.
func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnum.Value, seq seqnum.Value) (uint32, bool) {
- ts := timeStamp()
+ ts := timeStamp(l.stack.Clock())
v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq)
cookieTS := v >> tsOffset
if ((ts - cookieTS) & tsMask) > maxTSDiff {
@@ -247,7 +247,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header
func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
- isn := generateSecureISN(s.id, l.stack.Seed())
+ isn := generateSecureISN(s.id, l.stack.Clock(), l.stack.Seed())
ep, err := l.createConnectingEndpoint(s, opts, queue)
if err != nil {
return nil, err
@@ -550,7 +550,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
e.rcvQueueInfo.rcvQueueMu.Lock()
rcvClosed := e.rcvQueueInfo.RcvClosed
e.rcvQueueInfo.rcvQueueMu.Unlock()
- if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) {
+ if rcvClosed || s.flags.Contains(header.TCPFlagSyn|header.TCPFlagAck) {
// If the endpoint is shutdown, reply with reset.
//
// RFC 793 section 3.4 page 35 (figure 12) outlines that a RST
@@ -560,6 +560,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
switch {
+ case s.flags.Contains(header.TCPFlagRst):
+ e.stack.Stats().DroppedPackets.Increment()
+ return nil
+
case s.flags == header.TCPFlagSyn:
if e.acceptQueueIsFull() {
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
@@ -591,7 +595,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
- TSVal: tcpTimeStamp(time.Now(), timeStampOffset()),
+ TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset(e.stack.Rand())),
TSEcr: opts.TSVal,
MSS: calculateAdvertisedMSS(e.userMSS, route),
}
@@ -611,7 +615,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
return nil
- case (s.flags & header.TCPFlagAck) != 0:
+ case s.flags.Contains(header.TCPFlagAck):
if e.acceptQueueIsFull() {
// Silently drop the ack as the application can't accept
// the connection at this point. The ack will be
@@ -736,6 +740,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
mss: rcvdSynOptions.MSS,
})
+ // Requeue the segment if the ACK completing the handshake has more info
+	// to be processed by the newly established endpoint.
+ if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && n.enqueueSegment(s) {
+ s.incRef()
+ n.newSegmentWaker.Assert()
+ }
+
// Do the delivery in a separate goroutine so
// that we don't block the listen loop in case
// the application is slow to accept or stops
@@ -753,6 +764,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
return nil
default:
+ e.stack.Stats().DroppedPackets.Increment()
return nil
}
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 5e03e7715..e39d1623d 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -19,7 +19,6 @@ import (
"math"
"time"
- "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -92,7 +91,7 @@ type handshake struct {
rcvWndScale int
// startTime is the time at which the first SYN/SYN-ACK was sent.
- startTime time.Time
+ startTime tcpip.MonotonicTime
// deferAccept if non-zero will drop the final ACK for a passive
// handshake till an ACK segment with data is received or the timeout is
@@ -147,21 +146,16 @@ func FindWndScale(wnd seqnum.Size) int {
// resetState resets the state of the handshake object such that it becomes
// ready for a new 3-way handshake.
func (h *handshake) resetState() {
- b := make([]byte, 4)
- if _, err := rand.Read(b); err != nil {
- panic(err)
- }
-
h.state = handshakeSynSent
h.flags = header.TCPFlagSyn
h.ackNum = 0
h.mss = 0
- h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Seed())
+ h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.stack.Seed())
}
// generateSecureISN generates a secure Initial Sequence number based on the
// recommendation here https://tools.ietf.org/html/rfc6528#page-3.
-func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value {
+func generateSecureISN(id stack.TransportEndpointID, clock tcpip.Clock, seed uint32) seqnum.Value {
isnHasher := jenkins.Sum32(seed)
isnHasher.Write([]byte(id.LocalAddress))
isnHasher.Write([]byte(id.RemoteAddress))
@@ -180,7 +174,7 @@ func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value {
//
// Which sort of guarantees that we won't reuse the ISN for a new
// connection for the same tuple for at least 274s.
- isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6)
+ isn := isnHasher.Sum32() + uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Nanoseconds()>>6)
return seqnum.Value(isn)
}
@@ -212,7 +206,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea
// a TCP 3-way handshake is valid. If it's not, a RST segment is sent back in
// response.
func (h *handshake) checkAck(s *segment) bool {
- if s.flagIsSet(header.TCPFlagAck) && s.ackNumber != h.iss+1 {
+ if s.flags.Contains(header.TCPFlagAck) && s.ackNumber != h.iss+1 {
// RFC 793, page 36, states that a reset must be generated when
// the connection is in any non-synchronized state and an
// incoming segment acknowledges something not yet sent. The
@@ -230,8 +224,8 @@ func (h *handshake) checkAck(s *segment) bool {
func (h *handshake) synSentState(s *segment) tcpip.Error {
// RFC 793, page 37, states that in the SYN-SENT state, a reset is
// acceptable if the ack field acknowledges the SYN.
- if s.flagIsSet(header.TCPFlagRst) {
- if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == h.iss+1 {
+ if s.flags.Contains(header.TCPFlagRst) {
+ if s.flags.Contains(header.TCPFlagAck) && s.ackNumber == h.iss+1 {
// RFC 793, page 67, states that "If the RST bit is set [and] If the ACK
// was acceptable then signal the user "error: connection reset", drop
// the segment, enter CLOSED state, delete TCB, and return."
@@ -249,7 +243,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
// We are in the SYN-SENT state. We only care about segments that have
// the SYN flag.
- if !s.flagIsSet(header.TCPFlagSyn) {
+ if !s.flags.Contains(header.TCPFlagSyn) {
return nil
}
@@ -270,7 +264,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
// If this is a SYN ACK response, we only need to acknowledge the SYN
// and the handshake is completed.
- if s.flagIsSet(header.TCPFlagAck) {
+ if s.flags.Contains(header.TCPFlagAck) {
h.state = handshakeCompleted
h.ep.transitionToStateEstablishedLocked(h)
@@ -316,7 +310,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
// synRcvdState handles a segment received when the TCP 3-way handshake is in
// the SYN-RCVD state.
func (h *handshake) synRcvdState(s *segment) tcpip.Error {
- if s.flagIsSet(header.TCPFlagRst) {
+ if s.flags.Contains(header.TCPFlagRst) {
// RFC 793, page 37, states that in the SYN-RCVD state, a reset
// is acceptable if the sequence number is in the window.
if s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) {
@@ -340,13 +334,13 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
return nil
}
- if s.flagIsSet(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 {
+ if s.flags.Contains(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 {
// We received two SYN segments with different sequence
// numbers, so we reset this and restart the whole
// process, except that we don't reset the timer.
ack := s.sequenceNumber.Add(s.logicalLen())
seq := seqnum.Value(0)
- if s.flagIsSet(header.TCPFlagAck) {
+ if s.flags.Contains(header.TCPFlagAck) {
seq = s.ackNumber
}
h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0)
@@ -378,10 +372,10 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
// We have previously received (and acknowledged) the peer's SYN. If the
// peer acknowledges our SYN, the handshake is completed.
- if s.flagIsSet(header.TCPFlagAck) {
+ if s.flags.Contains(header.TCPFlagAck) {
// If deferAccept is not zero and this is a bare ACK and the
// timeout is not hit then drop the ACK.
- if h.deferAccept != 0 && s.data.Size() == 0 && time.Since(h.startTime) < h.deferAccept {
+ if h.deferAccept != 0 && s.data.Size() == 0 && h.ep.stack.Clock().NowMonotonic().Sub(h.startTime) < h.deferAccept {
h.acked = true
h.ep.stack.Stats().DroppedPackets.Increment()
return nil
@@ -412,11 +406,11 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
h.ep.transitionToStateEstablishedLocked(h)
- // If the segment has data then requeue it for the receiver
- // to process it again once main loop is started.
- if s.data.Size() > 0 {
+ // Requeue the segment if the ACK completing the handshake has more info
+	// to be processed by the newly established endpoint.
+ if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && h.ep.enqueueSegment(s) {
s.incRef()
- h.ep.enqueueSegment(s)
+ h.ep.newSegmentWaker.Assert()
}
return nil
}
@@ -426,7 +420,7 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
func (h *handshake) handleSegment(s *segment) tcpip.Error {
h.sndWnd = s.window
- if !s.flagIsSet(header.TCPFlagSyn) && h.sndWndScale > 0 {
+ if !s.flags.Contains(header.TCPFlagSyn) && h.sndWndScale > 0 {
h.sndWnd <<= uint8(h.sndWndScale)
}
@@ -474,7 +468,7 @@ func (h *handshake) processSegments() tcpip.Error {
// start sends the first SYN/SYN-ACK. It does not block, even if link address
// resolution is required.
func (h *handshake) start() {
- h.startTime = time.Now()
+ h.startTime = h.ep.stack.Clock().NowMonotonic()
h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
var sackEnabled tcpip.TCPSACKEnabled
if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil {
@@ -527,7 +521,7 @@ func (h *handshake) complete() tcpip.Error {
defer s.Done()
// Initialize the resend timer.
- timer, err := newBackoffTimer(time.Second, MaxRTO, resendWaker.Assert)
+ timer, err := newBackoffTimer(h.ep.stack.Clock(), time.Second, MaxRTO, resendWaker.Assert)
if err != nil {
return err
}
@@ -552,7 +546,7 @@ func (h *handshake) complete() tcpip.Error {
// The last is required to provide a way for the peer to complete
// the connection with another ACK or data (as ACKs are never
// retransmitted on their own).
- if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept {
+ if h.active || !h.acked || h.deferAccept != 0 && h.ep.stack.Clock().NowMonotonic().Sub(h.startTime) > h.deferAccept {
h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.TransportEndpointInfo.ID,
ttl: h.ep.ttl,
@@ -608,15 +602,15 @@ func (h *handshake) complete() tcpip.Error {
type backoffTimer struct {
timeout time.Duration
maxTimeout time.Duration
- t *time.Timer
+ t tcpip.Timer
}
-func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, tcpip.Error) {
+func newBackoffTimer(clock tcpip.Clock, timeout, maxTimeout time.Duration, f func()) (*backoffTimer, tcpip.Error) {
if timeout > maxTimeout {
return nil, &tcpip.ErrTimeout{}
}
bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout}
- bt.t = time.AfterFunc(timeout, f)
+ bt.t = clock.AfterFunc(timeout, f)
return bt, nil
}
@@ -634,7 +628,7 @@ func (bt *backoffTimer) stop() {
}
func parseSynSegmentOptions(s *segment) header.TCPSynOptions {
- synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck))
+ synOpts := header.ParseSynOptions(s.options, s.flags.Contains(header.TCPFlagAck))
if synOpts.TS {
s.parsedOptions.TSVal = synOpts.TSVal
s.parsedOptions.TSEcr = synOpts.TSEcr
@@ -915,30 +909,13 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, se
return err
}
-func (e *endpoint) handleWrite() {
- e.sndQueueInfo.sndQueueMu.Lock()
- next := e.drainSendQueueLocked()
- e.sndQueueInfo.sndQueueMu.Unlock()
-
- e.sendData(next)
-}
-
-// Move packets from send queue to send list.
-//
-// Precondition: e.sndBufMu must be locked.
-func (e *endpoint) drainSendQueueLocked() *segment {
- first := e.sndQueueInfo.sndQueue.Front()
- if first != nil {
- e.snd.writeList.PushBackList(&e.sndQueueInfo.sndQueue)
- e.sndQueueInfo.SndBufInQueue = 0
- }
- return first
-}
-
// Precondition: e.mu must be locked.
func (e *endpoint) sendData(next *segment) {
// Initialize the next segment to write if it's currently nil.
if e.snd.writeNext == nil {
+ if next == nil {
+ return
+ }
e.snd.writeNext = next
}
@@ -946,17 +923,6 @@ func (e *endpoint) sendData(next *segment) {
e.snd.sendData()
}
-func (e *endpoint) handleClose() {
- if !e.EndpointState().connected() {
- return
- }
- // Drain the send queue.
- e.handleWrite()
-
- // Mark send side as closed.
- e.snd.Closed = true
-}
-
// resetConnectionLocked puts the endpoint in an error state with the given
// error code and sends a RST if and only if the error is not ErrConnectionReset
// indicating that the connection is being reset due to receiving a RST. This
@@ -1136,7 +1102,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err tcpip.Error) {
func (e *endpoint) handleSegmentsLocked(fastPath bool) tcpip.Error {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
- if e.EndpointState().closed() {
+ if state := e.EndpointState(); state.closed() || state == StateTimeWait {
return nil
}
s := e.segmentQueue.dequeue()
@@ -1188,11 +1154,11 @@ func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error)
// the TCPEndpointState after the segment is processed.
defer e.probeSegmentLocked()
- if s.flagIsSet(header.TCPFlagRst) {
+ if s.flags.Contains(header.TCPFlagRst) {
if ok, err := e.handleReset(s); !ok {
return false, err
}
- } else if s.flagIsSet(header.TCPFlagSyn) {
+ } else if s.flags.Contains(header.TCPFlagSyn) {
// See: https://tools.ietf.org/html/rfc5961#section-4.1
// 1) If the SYN bit is set, irrespective of the sequence number, TCP
// MUST send an ACK (also referred to as challenge ACK) to the remote
@@ -1216,7 +1182,7 @@ func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error)
// should then rely on SYN retransmission from the remote end to
// re-establish the connection.
e.snd.maybeSendOutOfWindowAck(s)
- } else if s.flagIsSet(header.TCPFlagAck) {
+ } else if s.flags.Contains(header.TCPFlagAck) {
// Patch the window size in the segment according to the
// send window scale.
s.window <<= e.snd.SndWndScale
@@ -1235,7 +1201,7 @@ func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error)
// Now check if the received segment has caused us to transition
// to a CLOSED state, if yes then terminate processing and do
// not invoke the sender.
- state := e.state
+ state := e.EndpointState()
if state == StateClose {
// When we get into StateClose while processing from the queue,
// return immediately and let the protocolMainloop handle it.
@@ -1267,7 +1233,7 @@ func (e *endpoint) keepaliveTimerExpired() tcpip.Error {
// If a userTimeout is set then abort the connection if it is
// exceeded.
- if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 {
+ if userTimeout != 0 && e.stack.Clock().NowMonotonic().Sub(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 {
e.keepalive.Unlock()
e.stack.Stats().TCP.EstablishedTimedout.Increment()
return &tcpip.ErrTimeout{}
@@ -1322,7 +1288,7 @@ func (e *endpoint) disableKeepaliveTimer() {
// segments.
func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error {
e.mu.Lock()
- var closeTimer *time.Timer
+ var closeTimer tcpip.Timer
var closeWaker sleep.Waker
epilogue := func() {
@@ -1408,14 +1374,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
{
w: &e.sndQueueInfo.sndWaker,
f: func() tcpip.Error {
- e.handleWrite()
- return nil
- },
- },
- {
- w: &e.sndQueueInfo.sndCloseWaker,
- f: func() tcpip.Error {
- e.handleClose()
+ e.sendData(nil /* next */)
return nil
},
},
@@ -1480,11 +1439,19 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
return &tcpip.ErrConnectionReset{}
}
- if n&notifyClose != 0 && closeTimer == nil {
- if e.EndpointState() == StateFinWait2 && e.closed {
+ if n&notifyClose != 0 && e.closed {
+ switch e.EndpointState() {
+ case StateEstablished:
+ // Perform full shutdown if the endpoint is still
+ // established. This can occur when notifyClose
+ // was asserted just before becoming established.
+ e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+ case StateFinWait2:
// The socket has been closed and we are in FIN_WAIT2
// so start the FIN_WAIT2 timer.
- closeTimer = time.AfterFunc(e.tcpLingerTimeout, closeWaker.Assert)
+ if closeTimer == nil {
+ closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert)
+ }
}
}
@@ -1721,7 +1688,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
var timeWaitWaker sleep.Waker
s.AddWaker(&timeWaitWaker, timeWaitDone)
- timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert)
+ timeWaitTimer := e.stack.Clock().AfterFunc(timeWaitDuration, timeWaitWaker.Assert)
defer timeWaitTimer.Stop()
for {
diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go
index 962f1d687..6985194bb 100644
--- a/pkg/tcpip/transport/tcp/cubic.go
+++ b/pkg/tcpip/transport/tcp/cubic.go
@@ -41,7 +41,7 @@ type cubicState struct {
func newCubicCC(s *sender) *cubicState {
return &cubicState{
TCPCubicState: stack.TCPCubicState{
- T: time.Now(),
+ T: s.ep.stack.Clock().NowMonotonic(),
Beta: 0.7,
C: 0.4,
},
@@ -60,7 +60,7 @@ func (c *cubicState) enterCongestionAvoidance() {
// https://tools.ietf.org/html/rfc8312#section-4.8
if c.numCongestionEvents == 0 {
c.K = 0
- c.T = time.Now()
+ c.T = c.s.ep.stack.Clock().NowMonotonic()
c.WLastMax = c.WMax
c.WMax = float64(c.s.SndCwnd)
}
@@ -115,14 +115,15 @@ func (c *cubicState) cubicCwnd(t float64) float64 {
// getCwnd returns the current congestion window as computed by CUBIC.
// Refer: https://tools.ietf.org/html/rfc8312#section-4
func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int {
- elapsed := time.Since(c.T).Seconds()
+ elapsed := c.s.ep.stack.Clock().NowMonotonic().Sub(c.T)
+ elapsedSeconds := elapsed.Seconds()
// Compute the window as per Cubic after 'elapsed' time
// since last congestion event.
- c.WC = c.cubicCwnd(elapsed - c.K)
+ c.WC = c.cubicCwnd(elapsedSeconds - c.K)
// Compute the TCP friendly estimate of the congestion window.
- c.WEst = c.WMax*c.Beta + (3.0*((1.0-c.Beta)/(1.0+c.Beta)))*(elapsed/srtt.Seconds())
+ c.WEst = c.WMax*c.Beta + (3.0*((1.0-c.Beta)/(1.0+c.Beta)))*(elapsedSeconds/srtt.Seconds())
// Make sure in the TCP friendly region CUBIC performs at least
// as well as Reno.
@@ -134,7 +135,7 @@ func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int
// In Concave/Convex region of CUBIC, calculate what CUBIC window
// will be after 1 RTT and use that to grow congestion window
// for every ack.
- tEst := (time.Since(c.T) + srtt).Seconds()
+ tEst := (elapsed + srtt).Seconds()
wtRtt := c.cubicCwnd(tEst - c.K)
// As per 4.3 for each received ACK cwnd must be incremented
// by (w_cubic(t+RTT) - cwnd)/cwnd.
@@ -151,7 +152,7 @@ func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int
func (c *cubicState) HandleLossDetected() {
// See: https://tools.ietf.org/html/rfc8312#section-4.5
c.numCongestionEvents++
- c.T = time.Now()
+ c.T = c.s.ep.stack.Clock().NowMonotonic()
c.WLastMax = c.WMax
c.WMax = float64(c.s.SndCwnd)
@@ -162,7 +163,7 @@ func (c *cubicState) HandleLossDetected() {
// HandleRTOExpired implements congestionControl.HandleRTOExpired.
func (c *cubicState) HandleRTOExpired() {
// See: https://tools.ietf.org/html/rfc8312#section-4.6
- c.T = time.Now()
+ c.T = c.s.ep.stack.Clock().NowMonotonic()
c.numCongestionEvents = 0
c.WLastMax = c.WMax
c.WMax = float64(c.s.SndCwnd)
@@ -193,7 +194,7 @@ func (c *cubicState) fastConvergence() {
// PostRecovery implements congestionControl.PostRecovery.
func (c *cubicState) PostRecovery() {
- c.T = time.Now()
+ c.T = c.s.ep.stack.Clock().NowMonotonic()
}
// reduceSlowStartThreshold returns new SsThresh as described in
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index 512053a04..dff7cb89c 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -16,10 +16,11 @@ package tcp
import (
"encoding/binary"
+ "math/rand"
- "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -141,15 +142,16 @@ func (p *processor) start(wg *sync.WaitGroup) {
// in-order.
type dispatcher struct {
processors []processor
- seed uint32
- wg sync.WaitGroup
+ // seed is a random secret for a jenkins hash.
+ seed uint32
+ wg sync.WaitGroup
}
-func (d *dispatcher) init(nProcessors int) {
+func (d *dispatcher) init(rng *rand.Rand, nProcessors int) {
d.close()
d.wait()
d.processors = make([]processor, nProcessors)
- d.seed = generateRandUint32()
+ d.seed = rng.Uint32()
for i := range d.processors {
p := &d.processors[i]
p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker)
@@ -172,12 +174,11 @@ func (d *dispatcher) wait() {
d.wg.Wait()
}
-func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, clock tcpip.Clock, pkt *stack.PacketBuffer) {
ep := stackEP.(*endpoint)
- s := newIncomingSegment(id, pkt)
+ s := newIncomingSegment(id, clock, pkt)
if !s.parse(pkt.RXTransportChecksumValidated) {
- ep.stack.Stats().MalformedRcvdPackets.Increment()
ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
s.decRef()
@@ -185,7 +186,6 @@ func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.Trans
}
if !s.csumValid {
- ep.stack.Stats().MalformedRcvdPackets.Increment()
ep.stack.Stats().TCP.ChecksumErrors.Increment()
ep.stats.ReceiveErrors.ChecksumErrors.Increment()
s.decRef()
@@ -213,14 +213,6 @@ func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.Trans
d.selectProcessor(id).queueEndpoint(ep)
}
-func generateRandUint32() uint32 {
- b := make([]byte, 4)
- if _, err := rand.Read(b); err != nil {
- panic(err)
- }
- return binary.LittleEndian.Uint32(b)
-}
-
func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor {
var payload [4]byte
binary.LittleEndian.PutUint16(payload[0:], id.LocalPort)
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index f148d505d..5342aacfd 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -421,7 +421,7 @@ func testV4Accept(t *testing.T, c *context.Context) {
r.Reset(data)
nep.Write(&r, tcpip.WriteOptions{})
b = c.GetPacket()
- tcp = header.TCP(header.IPv4(b).Payload())
+ tcp = header.IPv4(b).Payload()
if string(tcp.Payload()) != data {
t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 50d39cbad..4acddc959 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -20,12 +20,12 @@ import (
"fmt"
"io"
"math"
+ "math/rand"
"runtime"
"strings"
"sync/atomic"
"time"
- "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -38,19 +38,15 @@ import (
)
// EndpointState represents the state of a TCP endpoint.
-type EndpointState uint32
+type EndpointState tcpip.EndpointState
// Endpoint states. Note that these are represented in a netstack-specific manner and
// may not be meaningful externally. Specifically, they need to be translated to
// Linux's representation for these states if presented to userspace.
const (
- // Endpoint states internal to netstack. These map to the TCP state CLOSED.
- StateInitial EndpointState = iota
- StateBound
- StateConnecting // Connect() called, but the initial SYN hasn't been sent.
- StateError
-
- // TCP protocol states.
+ _ EndpointState = iota
+ // TCP protocol states in sync with the definitions in
+ // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/tcp_states.h#L13
StateEstablished
StateSynSent
StateSynRecv
@@ -62,6 +58,12 @@ const (
StateLastAck
StateListen
StateClosing
+
+ // Endpoint states internal to netstack.
+ StateInitial
+ StateBound
+ StateConnecting // Connect() called, but the initial SYN hasn't been sent.
+ StateError
)
const (
@@ -97,6 +99,16 @@ func (s EndpointState) connecting() bool {
}
}
+// internal returns true when the state is netstack internal.
+func (s EndpointState) internal() bool {
+ switch s {
+ case StateInitial, StateBound, StateConnecting, StateError:
+ return true
+ default:
+ return false
+ }
+}
+
// handshake returns true when s is one of the states representing an endpoint
// in the middle of a TCP handshake.
func (s EndpointState) handshake() bool {
@@ -281,16 +293,9 @@ type sndQueueInfo struct {
sndQueueMu sync.Mutex `state:"nosave"`
stack.TCPSndBufState
- // sndQueue holds segments that are ready to be sent.
- sndQueue segmentList `state:"wait"`
-
- // sndWaker is used to signal the protocol goroutine when segments are
- // added to the `sndQueue`.
+ // sndWaker is used to signal the protocol goroutine when there may be
+ // segments that need to be sent.
sndWaker sleep.Waker `state:"manual"`
-
- // sndCloseWaker is used to notify the protocol goroutine when the send
- // side is closed.
- sndCloseWaker sleep.Waker `state:"manual"`
}
// rcvQueueInfo contains the endpoint's rcvQueue and associated metadata.
@@ -422,12 +427,12 @@ type endpoint struct {
// state must be read/set using the EndpointState()/setEndpointState()
// methods.
- state EndpointState `state:".(EndpointState)"`
+ state uint32 `state:".(EndpointState)"`
// origEndpointState is only used during a restore phase to save the
// endpoint state at restore time as the socket is moved to it's correct
// state.
- origEndpointState EndpointState `state:"nosave"`
+ origEndpointState uint32 `state:"nosave"`
isPortReserved bool `state:"manual"`
isRegistered bool `state:"manual"`
@@ -468,7 +473,7 @@ type endpoint struct {
// recentTSTime is the unix time when we last updated
// TCPEndpointStateInner.RecentTS.
- recentTSTime time.Time `state:".(unixTime)"`
+ recentTSTime tcpip.MonotonicTime
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -626,7 +631,7 @@ type endpoint struct {
// lastOutOfWindowAckTime is the time at which an ACK was sent in response
// to an out of window segment being received by this endpoint.
- lastOutOfWindowAckTime time.Time `state:".(unixTime)"`
+ lastOutOfWindowAckTime tcpip.MonotonicTime
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
@@ -747,7 +752,7 @@ func (e *endpoint) ResumeWork() {
//
// Precondition: e.mu must be held to call this method.
func (e *endpoint) setEndpointState(state EndpointState) {
- oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+ oldstate := EndpointState(atomic.LoadUint32(&e.state))
switch state {
case StateEstablished:
e.stack.Stats().TCP.CurrentEstablished.Increment()
@@ -764,18 +769,18 @@ func (e *endpoint) setEndpointState(state EndpointState) {
e.stack.Stats().TCP.CurrentEstablished.Decrement()
}
}
- atomic.StoreUint32((*uint32)(&e.state), uint32(state))
+ atomic.StoreUint32(&e.state, uint32(state))
}
// EndpointState returns the current state of the endpoint.
func (e *endpoint) EndpointState() EndpointState {
- return EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+ return EndpointState(atomic.LoadUint32(&e.state))
}
// setRecentTimestamp sets the recentTS field to the provided value.
func (e *endpoint) setRecentTimestamp(recentTS uint32) {
e.RecentTS = recentTS
- e.recentTSTime = time.Now()
+ e.recentTSTime = e.stack.Clock().NowMonotonic()
}
// recentTimestamp returns the value of the recentTS field.
@@ -806,11 +811,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
},
sndQueueInfo: sndQueueInfo{
TCPSndBufState: stack.TCPSndBufState{
- SndMTU: int(math.MaxInt32),
+ SndMTU: math.MaxInt32,
},
},
waiterQueue: waiterQueue,
- state: StateInitial,
+ state: uint32(StateInitial),
keepalive: keepalive{
// Linux defaults.
idle: 2 * time.Hour,
@@ -870,9 +875,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
}
e.segmentQueue.ep = e
- e.TSOffset = timeStampOffset()
+ e.TSOffset = timeStampOffset(e.stack.Rand())
e.acceptCond = sync.NewCond(&e.acceptMu)
- e.keepalive.timer.init(&e.keepalive.waker)
+ e.keepalive.timer.init(e.stack.Clock(), &e.keepalive.waker)
return e
}
@@ -1189,7 +1194,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvQueueInfo.rcvQueueMu.Unlock()
return
}
- now := time.Now()
+ now := e.stack.Clock().NowMonotonic()
if rtt := e.rcvQueueInfo.RcvAutoParams.RTT; rtt == 0 || now.Sub(e.rcvQueueInfo.RcvAutoParams.MeasureTime) < rtt {
e.rcvQueueInfo.RcvAutoParams.CopiedBytes += copied
e.rcvQueueInfo.rcvQueueMu.Unlock()
@@ -1544,12 +1549,11 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
}
// Add data to the send queue.
- s := newOutgoingSegment(e.TransportEndpointInfo.ID, v)
+ s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), v)
e.sndQueueInfo.SndBufUsed += len(v)
- e.sndQueueInfo.SndBufInQueue += seqnum.Size(len(v))
- e.sndQueueInfo.sndQueue.PushBack(s)
+ e.snd.writeList.PushBack(s)
- return e.drainSendQueueLocked(), len(v), nil
+ return s, len(v), nil
}()
// Return if either we didn't queue anything or if an error occurred while
// attempting to queue data.
@@ -1956,6 +1960,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
func (e *endpoint) getTCPInfo() tcpip.TCPInfoOption {
info := tcpip.TCPInfoOption{}
e.LockUser()
+ if state := e.EndpointState(); state.internal() {
+ info.State = tcpip.EndpointState(StateClose)
+ } else {
+ info.State = tcpip.EndpointState(state)
+ }
snd := e.snd
if snd != nil {
// We do not calculate RTT before sending the data packets. If
@@ -2198,7 +2207,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
BindToDevice: bindToDevice,
Dest: addr,
}
- if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil {
+ if _, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */); err != nil {
if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse {
return false, nil
}
@@ -2224,7 +2233,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
// If the endpoint is not in TIME-WAIT or if it is in TIME-WAIT but
// less than 1 second has elapsed since its recentTS was updated then
// we cannot reuse the port.
- if tcpEP.EndpointState() != StateTimeWait || time.Since(tcpEP.recentTSTime) < 1*time.Second {
+ if tcpEP.EndpointState() != StateTimeWait || e.stack.Clock().NowMonotonic().Sub(tcpEP.recentTSTime) < 1*time.Second {
tcpEP.UnlockUser()
return false, nil
}
@@ -2245,7 +2254,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
BindToDevice: bindToDevice,
Dest: addr,
}
- if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil {
+ if _, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */); err != nil {
return false, nil
}
}
@@ -2297,7 +2306,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
// connection setting here.
if !handshake {
e.segmentQueue.mu.Lock()
- for _, l := range []segmentList{e.segmentQueue.list, e.sndQueueInfo.sndQueue, e.snd.writeList} {
+ for _, l := range []segmentList{e.segmentQueue.list, e.snd.writeList} {
for s := l.Front(); s != nil; s = s.Next() {
s.id = e.TransportEndpointInfo.ID
e.sndQueueInfo.sndWaker.Assert()
@@ -2355,6 +2364,9 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
e.notifyProtocolGoroutine(notifyTickleWorker)
return nil
}
+	// Wake up any readers that may be waiting for the stream to become
+ // readable.
+ e.waiterQueue.Notify(waiter.ReadableEvents)
}
// Close for write.
@@ -2370,13 +2382,21 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
}
// Queue fin segment.
- s := newOutgoingSegment(e.TransportEndpointInfo.ID, nil)
- e.sndQueueInfo.sndQueue.PushBack(s)
- e.sndQueueInfo.SndBufInQueue++
+ s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), nil)
+ e.snd.writeList.PushBack(s)
// Mark endpoint as closed.
e.sndQueueInfo.SndClosed = true
e.sndQueueInfo.sndQueueMu.Unlock()
- e.handleClose()
+
+ // Drain the send queue.
+ e.sendData(s)
+
+ // Mark send side as closed.
+ e.snd.Closed = true
+
+	// Wake up any writers that may be waiting for the stream to become
+ // writable.
+ e.waiterQueue.Notify(waiter.WritableEvents)
}
return nil
@@ -2581,7 +2601,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) {
BindToDevice: bindToDevice,
Dest: tcpip.FullAddress{},
}
- port, err := e.stack.ReservePort(portRes, func(p uint16) (bool, tcpip.Error) {
+ port, err := e.stack.ReservePort(e.stack.Rand(), portRes, func(p uint16) (bool, tcpip.Error) {
id := e.TransportEndpointInfo.ID
id.LocalPort = p
// CheckRegisterTransportEndpoint should only return an error if there is a
@@ -2731,7 +2751,7 @@ func (e *endpoint) updateSndBufferUsage(v int) {
// We only notify when there is half the sendBufferSize available after
// a full buffer event occurs. This ensures that we don't wake up
// writers to queue just 1-2 segments and go back to sleep.
- notify = notify && e.sndQueueInfo.SndBufUsed < int(sendBufferSize)>>1
+ notify = notify && e.sndQueueInfo.SndBufUsed < sendBufferSize>>1
e.sndQueueInfo.sndQueueMu.Unlock()
if notify {
@@ -2848,23 +2868,20 @@ func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
// timestamp returns the timestamp value to be used in the TSVal field of the
// timestamp option for outgoing TCP segments for a given endpoint.
func (e *endpoint) timestamp() uint32 {
- return tcpTimeStamp(time.Now(), e.TSOffset)
+ return tcpTimeStamp(e.stack.Clock().NowMonotonic(), e.TSOffset)
}
// tcpTimeStamp returns a timestamp offset by the provided offset. This is
// not inlined above as it's used when SYN cookies are in use and endpoint
// is not created at the time when the SYN cookie is sent.
-func tcpTimeStamp(curTime time.Time, offset uint32) uint32 {
- return uint32(curTime.Unix()*1000+int64(curTime.Nanosecond()/1e6)) + offset
+func tcpTimeStamp(curTime tcpip.MonotonicTime, offset uint32) uint32 {
+ d := curTime.Sub(tcpip.MonotonicTime{})
+ return uint32(d.Milliseconds()) + offset
}
// timeStampOffset returns a randomized timestamp offset to be used when sending
// timestamp values in a timestamp option for a TCP segment.
-func timeStampOffset() uint32 {
- b := make([]byte, 4)
- if _, err := rand.Read(b); err != nil {
- panic(err)
- }
+func timeStampOffset(rng *rand.Rand) uint32 {
// Initialize a random tsOffset that will be added to the recentTS
// every time the timestamp is sent when the Timestamp option is enabled.
//
@@ -2874,7 +2891,7 @@ func timeStampOffset() uint32 {
// NOTE: This is not completely to spec as normally this should be
// initialized in a manner analogous to how sequence numbers are
// randomized on a per-connection basis. But for now this is sufficient.
- return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
+ return rng.Uint32()
}
// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint
@@ -2909,7 +2926,7 @@ func (e *endpoint) completeStateLocked() stack.TCPEndpointState {
s := stack.TCPEndpointState{
TCPEndpointStateInner: e.TCPEndpointStateInner,
ID: stack.TCPEndpointID(e.TransportEndpointInfo.ID),
- SegTime: time.Now(),
+ SegTime: e.stack.Clock().NowMonotonic(),
Receiver: e.rcv.TCPReceiverState,
Sender: e.snd.TCPSenderState,
}
@@ -2937,7 +2954,7 @@ func (e *endpoint) completeStateLocked() stack.TCPEndpointState {
if cubic, ok := e.snd.cc.(*cubicState); ok {
s.Sender.Cubic = cubic.TCPCubicState
- s.Sender.Cubic.TimeSinceLastCongestion = time.Since(s.Sender.Cubic.T)
+ s.Sender.Cubic.TimeSinceLastCongestion = e.stack.Clock().NowMonotonic().Sub(s.Sender.Cubic.T)
}
s.Sender.RACKState = e.snd.rc.TCPRACKState
@@ -3029,14 +3046,16 @@ func GetTCPSendBufferLimits(s tcpip.StackHandler) tcpip.SendBufferSizeOption {
// allowOutOfWindowAck returns true if an out-of-window ACK can be sent now.
func (e *endpoint) allowOutOfWindowAck() bool {
- var limit stack.TCPInvalidRateLimitOption
- if err := e.stack.Option(&limit); err != nil {
- panic(fmt.Sprintf("e.stack.Option(%+v) failed with error: %s", limit, err))
- }
+ now := e.stack.Clock().NowMonotonic()
- now := time.Now()
- if now.Sub(e.lastOutOfWindowAckTime) < time.Duration(limit) {
- return false
+ if e.lastOutOfWindowAckTime != (tcpip.MonotonicTime{}) {
+ var limit stack.TCPInvalidRateLimitOption
+ if err := e.stack.Option(&limit); err != nil {
+ panic(fmt.Sprintf("e.stack.Option(%+v) failed with error: %s", limit, err))
+ }
+ if now.Sub(e.lastOutOfWindowAckTime) < time.Duration(limit) {
+ return false
+ }
}
e.lastOutOfWindowAckTime = now
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 6e9777fe4..952ccacdd 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -154,20 +154,25 @@ func (e *endpoint) afterLoad() {
e.origEndpointState = e.state
// Restore the endpoint to InitialState as it will be moved to
// its origEndpointState during Resume.
- e.state = StateInitial
+ e.state = uint32(StateInitial)
// Condition variables and mutexs are not S/R'ed so reinitialize
// acceptCond with e.acceptMu.
e.acceptCond = sync.NewCond(&e.acceptMu)
- e.keepalive.timer.init(&e.keepalive.waker)
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
+ e.keepalive.timer.init(s.Clock(), &e.keepalive.waker)
+ if snd := e.snd; snd != nil {
+ snd.resendTimer.init(s.Clock(), &snd.resendWaker)
+ snd.reorderTimer.init(s.Clock(), &snd.reorderWaker)
+ snd.probeTimer.init(s.Clock(), &snd.probeWaker)
+ }
e.stack = s
e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits)
e.segmentQueue.thaw()
- epState := e.origEndpointState
+ epState := EndpointState(e.origEndpointState)
switch epState {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
var ss tcpip.TCPSendBufferSizeRangeOption
@@ -281,32 +286,12 @@ func (e *endpoint) Resume(s *stack.Stack) {
}()
case epState == StateClose:
e.isPortReserved = false
- e.state = StateClose
+ e.state = uint32(StateClose)
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
case epState == StateError:
- e.state = StateError
+ e.state = uint32(StateError)
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
}
-
-// saveRecentTSTime is invoked by stateify.
-func (e *endpoint) saveRecentTSTime() unixTime {
- return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()}
-}
-
-// loadRecentTSTime is invoked by stateify.
-func (e *endpoint) loadRecentTSTime(unix unixTime) {
- e.recentTSTime = time.Unix(unix.second, unix.nano)
-}
-
-// saveLastOutOfWindowAckTime is invoked by stateify.
-func (e *endpoint) saveLastOutOfWindowAckTime() unixTime {
- return unixTime{e.lastOutOfWindowAckTime.Unix(), e.lastOutOfWindowAckTime.UnixNano()}
-}
-
-// loadLastOutOfWindowAckTime is invoked by stateify.
-func (e *endpoint) loadLastOutOfWindowAckTime(unix unixTime) {
- e.lastOutOfWindowAckTime = time.Unix(unix.second, unix.nano)
-}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 2f9fe7ee0..65c86823a 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -65,7 +65,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
- s := newIncomingSegment(id, pkt)
+ s := newIncomingSegment(id, f.stack.Clock(), pkt)
defer s.decRef()
// We only care about well-formed SYN packets.
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index a3d1aa1a3..2fc282e73 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -131,7 +131,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) {
// goroutine which is responsible for dequeuing and doing full TCP dispatch of
// the packet.
func (p *protocol) QueuePacket(ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
- p.dispatcher.queuePacket(ep, id, pkt)
+ p.dispatcher.queuePacket(ep, id, p.stack.Clock(), pkt)
}
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
@@ -142,14 +142,14 @@ func (p *protocol) QueuePacket(ep stack.TransportEndpoint, id stack.TransportEnd
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
- s := newIncomingSegment(id, pkt)
+ s := newIncomingSegment(id, p.stack.Clock(), pkt)
defer s.decRef()
if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid {
return stack.UnknownDestinationPacketMalformed
}
- if !s.flagIsSet(header.TCPFlagRst) {
+ if !s.flags.Contains(header.TCPFlagRst) {
replyWithReset(p.stack, s, stack.DefaultTOS, 0)
}
@@ -181,7 +181,7 @@ func replyWithReset(st *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error {
// reset has sequence number zero and the ACK field is set to the sum
// of the sequence number and segment length of the incoming segment.
// The connection remains in the CLOSED state.
- if s.flagIsSet(header.TCPFlagAck) {
+ if s.flags.Contains(header.TCPFlagAck) {
seq = s.ackNumber
} else {
flags |= header.TCPFlagAck
@@ -401,7 +401,7 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Er
case *tcpip.TCPTimeWaitReuseOption:
p.mu.RLock()
- *v = tcpip.TCPTimeWaitReuseOption(p.timeWaitReuse)
+ *v = p.timeWaitReuse
p.mu.RUnlock()
return nil
@@ -481,6 +481,6 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol {
// TODO(gvisor.dev/issue/5243): Set recovery to tcpip.TCPRACKLossDetection.
recovery: 0,
}
- p.dispatcher.init(runtime.GOMAXPROCS(0))
+ p.dispatcher.init(s.Rand(), runtime.GOMAXPROCS(0))
return &p
}
diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go
index 9e332dcf7..0da4eafaa 100644
--- a/pkg/tcpip/transport/tcp/rack.go
+++ b/pkg/tcpip/transport/tcp/rack.go
@@ -79,7 +79,7 @@ func (rc *rackControl) init(snd *sender, iss seqnum.Value) {
// update will update the RACK related fields when an ACK has been received.
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2
func (rc *rackControl) update(seg *segment, ackSeg *segment) {
- rtt := time.Now().Sub(seg.xmitTime)
+ rtt := rc.snd.ep.stack.Clock().NowMonotonic().Sub(seg.xmitTime)
tsOffset := rc.snd.ep.TSOffset
// If the ACK is for a retransmitted packet, do not update if it is a
@@ -115,7 +115,7 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) {
// ending sequence number of the packet which has been acknowledged
// most recently.
endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
- if rc.XmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) {
+ if rc.XmitTime.Before(seg.xmitTime) || (seg.xmitTime == rc.XmitTime && rc.EndSequence.LessThan(endSeq)) {
rc.XmitTime = seg.xmitTime
rc.EndSequence = endSeq
}
@@ -174,7 +174,7 @@ func (s *sender) schedulePTO() {
}
s.rtt.Unlock()
- now := time.Now()
+ now := s.ep.stack.Clock().NowMonotonic()
if s.resendTimer.enabled() {
if now.Add(pto).After(s.resendTimer.target) {
pto = s.resendTimer.target.Sub(now)
@@ -279,7 +279,7 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) {
// been observed RACK uses reo_wnd of zero during loss recovery, in order to
// retransmit quickly, or when the number of DUPACKs exceeds the classic
// DUPACKthreshold.
-func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) {
+func (rc *rackControl) updateRACKReorderWindow() {
dsackSeen := rc.DSACKSeen
snd := rc.snd
@@ -352,7 +352,7 @@ func (rc *rackControl) exitRecovery() {
// detectLoss marks the segment as lost if the reordering window has elapsed
// and the ACK is not received. It will also arm the reorder timer.
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 Step 5.
-func (rc *rackControl) detectLoss(rcvTime time.Time) int {
+func (rc *rackControl) detectLoss(rcvTime tcpip.MonotonicTime) int {
var timeout time.Duration
numLost := 0
for seg := rc.snd.writeList.Front(); seg != nil && seg.xmitCount != 0; seg = seg.Next() {
@@ -366,7 +366,7 @@ func (rc *rackControl) detectLoss(rcvTime time.Time) int {
}
endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
- if seg.xmitTime.Before(rc.XmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) {
+ if seg.xmitTime.Before(rc.XmitTime) || (seg.xmitTime == rc.XmitTime && rc.EndSequence.LessThan(endSeq)) {
timeRemaining := seg.xmitTime.Sub(rcvTime) + rc.RTT + rc.ReoWnd
if timeRemaining <= 0 {
seg.lost = true
@@ -392,7 +392,7 @@ func (rc *rackControl) reorderTimerExpired() tcpip.Error {
return nil
}
- numLost := rc.detectLoss(time.Now())
+ numLost := rc.detectLoss(rc.snd.ep.stack.Clock().NowMonotonic())
if numLost == 0 {
return nil
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 133371455..9ce8fcae9 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -17,7 +17,6 @@ package tcp
import (
"container/heap"
"math"
- "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -50,7 +49,7 @@ type receiver struct {
pendingRcvdSegments segmentHeap
// Time when the last ack was received.
- lastRcvdAckTime time.Time `state:".(unixTime)"`
+ lastRcvdAckTime tcpip.MonotonicTime
}
func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
@@ -63,7 +62,7 @@ func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale
},
rcvWnd: rcvWnd,
rcvWUP: irs + 1,
- lastRcvdAckTime: time.Now(),
+ lastRcvdAckTime: ep.stack.Clock().NowMonotonic(),
}
}
@@ -137,9 +136,9 @@ func (r *receiver) getSendParams() (RcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// rcvWUP RcvNxt RcvAcc new RcvAcc
// <=====curWnd ===>
// <========= newWnd > curWnd ========= >
- if r.RcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.RcvNxt.Add(seqnum.Size(newWnd))) && toGrow {
+ if r.RcvNxt.Add(curWnd).LessThan(r.RcvNxt.Add(newWnd)) && toGrow {
// If the new window moves the right edge, then update RcvAcc.
- r.RcvAcc = r.RcvNxt.Add(seqnum.Size(newWnd))
+ r.RcvAcc = r.RcvNxt.Add(newWnd)
} else {
if newWnd == 0 {
// newWnd is zero but we can't advertise a zero as it would cause window
@@ -245,7 +244,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
TrimSACKBlockList(&r.ep.sack, r.RcvNxt)
// Handle FIN or FIN-ACK.
- if s.flagIsSet(header.TCPFlagFin) {
+ if s.flags.Contains(header.TCPFlagFin) {
r.RcvNxt++
// Send ACK immediately.
@@ -261,7 +260,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
case StateEstablished:
r.ep.setEndpointState(StateCloseWait)
case StateFinWait1:
- if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt {
+ if s.flags.Contains(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt {
// FIN-ACK, transition to TIME-WAIT.
r.ep.setEndpointState(StateTimeWait)
} else {
@@ -296,7 +295,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// Handle ACK (not FIN-ACK, which we handled above) during one of the
// shutdown states.
- if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt {
+ if s.flags.Contains(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt {
switch r.ep.EndpointState() {
case StateFinWait1:
r.ep.setEndpointState(StateFinWait2)
@@ -325,9 +324,9 @@ func (r *receiver) updateRTT() {
// is first acknowledged and the receipt of data that is at least one
// window beyond the sequence number that was acknowledged.
r.ep.rcvQueueInfo.rcvQueueMu.Lock()
- if r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime.IsZero() {
+ if r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime == (tcpip.MonotonicTime{}) {
// New measurement.
- r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now()
+ r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = r.ep.stack.Clock().NowMonotonic()
r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd)
r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
return
@@ -336,14 +335,14 @@ func (r *receiver) updateRTT() {
r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
return
}
- rtt := time.Since(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime)
+ rtt := r.ep.stack.Clock().NowMonotonic().Sub(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime)
// We only store the minimum observed RTT here as this is only used in
// absence of a SRTT available from either timestamps or a sender
// measurement of RTT.
if r.ep.rcvQueueInfo.RcvAutoParams.RTT == 0 || rtt < r.ep.rcvQueueInfo.RcvAutoParams.RTT {
r.ep.rcvQueueInfo.RcvAutoParams.RTT = rtt
}
- r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now()
+ r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = r.ep.stack.Clock().NowMonotonic()
r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd)
r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
}
@@ -424,7 +423,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// while the FIN is considered to occur after
// the last actual data octet in a segment in
// which it occurs.
- if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.RcvNxt+1) {
+ if closed && (!s.flags.Contains(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.RcvNxt+1) {
return true, &tcpip.ErrConnectionAborted{}
}
}
@@ -467,11 +466,11 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) {
}
// Store the time of the last ack.
- r.lastRcvdAckTime = time.Now()
+ r.lastRcvdAckTime = r.ep.stack.Clock().NowMonotonic()
// Defer segment processing if it can't be consumed now.
if !r.consumeSegment(s, segSeq, segLen) {
- if segLen > 0 || s.flagIsSet(header.TCPFlagFin) {
+ if segLen > 0 || s.flags.Contains(header.TCPFlagFin) {
// We only store the segment if it's within our buffer size limit.
//
// Only use 75% of the receive buffer queue for out-of-order
@@ -539,7 +538,7 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn
//
// As we do not yet support PAWS, we are being conservative in ignoring
// RSTs by default.
- if s.flagIsSet(header.TCPFlagRst) {
+ if s.flags.Contains(header.TCPFlagRst) {
return false, false
}
@@ -559,13 +558,12 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn
// (2) returns to TIME-WAIT state if the SYN turns out
// to be an old duplicate".
- if s.flagIsSet(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) {
-
+ if s.flags.Contains(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) {
return false, true
}
// Drop the segment if it does not contain an ACK.
- if !s.flagIsSet(header.TCPFlagAck) {
+ if !s.flags.Contains(header.TCPFlagAck) {
return false, false
}
@@ -574,7 +572,7 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn
r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.MaxSentAck, segSeq)
}
- if segSeq.Add(1) == r.RcvNxt && s.flagIsSet(header.TCPFlagFin) {
+ if segSeq.Add(1) == r.RcvNxt && s.flags.Contains(header.TCPFlagFin) {
// If it's a FIN-ACK then resetTimeWait and send an ACK, as it
// indicates our final ACK could have been lost.
r.ep.snd.sendAck()
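The rcv.go hunks above swap wall-clock time.Time values for tcpip.MonotonicTime read from the stack's clock, use the zero value in place of IsZero(), and compute elapsed time with Sub. A minimal sketch of that pattern, using the faketime manual clock that this change's tests rely on; the variable names are illustrative and not taken from rcv.go:

package main

import (
	"fmt"
	"time"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/faketime"
)

func main() {
	clock := faketime.NewManualClock()

	// The zero value marks "no measurement yet", replacing time.Time.IsZero().
	var lastAck tcpip.MonotonicTime
	if lastAck == (tcpip.MonotonicTime{}) {
		lastAck = clock.NowMonotonic()
	}

	// Advance the fake clock instead of sleeping, then measure elapsed time.
	clock.Advance(10 * time.Millisecond)
	elapsed := clock.NowMonotonic().Sub(lastAck)
	fmt.Println(elapsed >= 10*time.Millisecond) // true
}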
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 7e5ba6ef7..ca78c96f2 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -17,7 +17,6 @@ package tcp
import (
"fmt"
"sync/atomic"
- "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -73,9 +72,9 @@ type segment struct {
parsedOptions header.TCPOptions
options []byte `state:".([]byte)"`
hasNewSACKInfo bool
- rcvdTime time.Time `state:".(unixTime)"`
+ rcvdTime tcpip.MonotonicTime
// xmitTime is the last transmit time of this segment.
- xmitTime time.Time `state:".(unixTime)"`
+ xmitTime tcpip.MonotonicTime
xmitCount uint32
// acked indicates if the segment has already been SACKed.
@@ -88,7 +87,7 @@ type segment struct {
lost bool
}
-func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
+func newIncomingSegment(id stack.TransportEndpointID, clock tcpip.Clock, pkt *stack.PacketBuffer) *segment {
netHdr := pkt.Network()
s := &segment{
refCnt: 1,
@@ -100,17 +99,17 @@ func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *
}
s.data = pkt.Data().ExtractVV().Clone(s.views[:])
s.hdr = header.TCP(pkt.TransportHeader().View())
- s.rcvdTime = time.Now()
+ s.rcvdTime = clock.NowMonotonic()
s.dataMemSize = s.data.Size()
return s
}
-func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment {
+func newOutgoingSegment(id stack.TransportEndpointID, clock tcpip.Clock, v buffer.View) *segment {
s := &segment{
refCnt: 1,
id: id,
}
- s.rcvdTime = time.Now()
+ s.rcvdTime = clock.NowMonotonic()
if len(v) != 0 {
s.views[0] = v
s.data = buffer.NewVectorisedView(len(v), s.views[:1])
@@ -149,16 +148,6 @@ func (s *segment) merge(oth *segment) {
oth.dataMemSize = oth.data.Size()
}
-// flagIsSet checks if at least one flag in flags is set in s.flags.
-func (s *segment) flagIsSet(flags header.TCPFlags) bool {
- return s.flags&flags != 0
-}
-
-// flagsAreSet checks if all flags in flags are set in s.flags.
-func (s *segment) flagsAreSet(flags header.TCPFlags) bool {
- return s.flags&flags == flags
-}
-
// setOwner sets the owning endpoint for this segment. It's required
// to be called to ensure memory accounting for receive/send buffer
// queues is done properly.
@@ -198,10 +187,10 @@ func (s *segment) incRef() {
// as the data length plus one for each of the SYN and FIN bits set.
func (s *segment) logicalLen() seqnum.Size {
l := seqnum.Size(s.data.Size())
- if s.flagIsSet(header.TCPFlagSyn) {
+ if s.flags.Contains(header.TCPFlagSyn) {
l++
}
- if s.flagIsSet(header.TCPFlagFin) {
+ if s.flags.Contains(header.TCPFlagFin) {
l++
}
return l
@@ -243,7 +232,7 @@ func (s *segment) parse(skipChecksumValidation bool) bool {
return false
}
- s.options = []byte(s.hdr[header.TCPMinimumSize:])
+ s.options = s.hdr[header.TCPMinimumSize:]
s.parsedOptions = header.ParseTCPOptions(s.options)
if skipChecksumValidation {
s.csumValid = true
@@ -262,5 +251,5 @@ func (s *segment) parse(skipChecksumValidation bool) bool {
// sackBlock returns a header.SACKBlock that represents this segment.
func (s *segment) sackBlock() header.SACKBlock {
- return header.SACKBlock{s.sequenceNumber, s.sequenceNumber.Add(s.logicalLen())}
+ return header.SACKBlock{Start: s.sequenceNumber, End: s.sequenceNumber.Add(s.logicalLen())}
}
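The flagIsSet/flagsAreSet helpers removed above are replaced by Contains and Intersects on the flags field. Judging from the call-site rewrites, Contains asks whether all of the given bits are set and Intersects whether at least one is. A self-contained sketch of those semantics with stand-in names, not the header package's actual definitions:

package main

import "fmt"

type tcpFlags uint8

const (
	flagFin tcpFlags = 1 << iota
	flagSyn
	flagRst
	flagAck
)

// Contains reports whether every bit in want is set.
func (f tcpFlags) Contains(want tcpFlags) bool { return f&want == want }

// Intersects reports whether at least one bit in want is set.
func (f tcpFlags) Intersects(want tcpFlags) bool { return f&want != 0 }

func main() {
	seg := flagAck | flagFin
	fmt.Println(seg.Contains(flagFin))             // true
	fmt.Println(seg.Contains(flagFin | flagSyn))   // false: SYN missing
	fmt.Println(seg.Intersects(flagFin | flagSyn)) // true: FIN present
}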
diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go
index 7422d8c02..dcfa80f95 100644
--- a/pkg/tcpip/transport/tcp/segment_state.go
+++ b/pkg/tcpip/transport/tcp/segment_state.go
@@ -15,8 +15,6 @@
package tcp
import (
- "time"
-
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -55,23 +53,3 @@ func (s *segment) loadOptions(options []byte) {
// allocated so there is no cost here.
s.options = options
}
-
-// saveRcvdTime is invoked by stateify.
-func (s *segment) saveRcvdTime() unixTime {
- return unixTime{s.rcvdTime.Unix(), s.rcvdTime.UnixNano()}
-}
-
-// loadRcvdTime is invoked by stateify.
-func (s *segment) loadRcvdTime(unix unixTime) {
- s.rcvdTime = time.Unix(unix.second, unix.nano)
-}
-
-// saveXmitTime is invoked by stateify.
-func (s *segment) saveXmitTime() unixTime {
- return unixTime{s.rcvdTime.Unix(), s.rcvdTime.UnixNano()}
-}
-
-// loadXmitTime is invoked by stateify.
-func (s *segment) loadXmitTime(unix unixTime) {
- s.rcvdTime = time.Unix(unix.second, unix.nano)
-}
diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go
index 486016fc0..2e6ea06f5 100644
--- a/pkg/tcpip/transport/tcp/segment_test.go
+++ b/pkg/tcpip/transport/tcp/segment_test.go
@@ -19,6 +19,7 @@ import (
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -39,10 +40,11 @@ func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeW
}
func TestSegmentMerge(t *testing.T) {
+ var clock faketime.NullClock
id := stack.TransportEndpointID{}
- seg1 := newOutgoingSegment(id, buffer.NewView(10))
+ seg1 := newOutgoingSegment(id, &clock, buffer.NewView(10))
defer seg1.decRef()
- seg2 := newOutgoingSegment(id, buffer.NewView(20))
+ seg2 := newOutgoingSegment(id, &clock, buffer.NewView(20))
defer seg2.decRef()
checkSegmentSize(t, "seg1", seg1, segmentSizeWants{
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index f43e86677..72d58dcff 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -94,7 +94,7 @@ type sender struct {
// firstRetransmittedSegXmitTime is the original transmit time of
// the first segment that was retransmitted due to RTO expiration.
- firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"`
+ firstRetransmittedSegXmitTime tcpip.MonotonicTime
// zeroWindowProbing is set if the sender is currently probing
// for zero receive window.
@@ -169,7 +169,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
SndUna: iss + 1,
SndNxt: iss + 1,
RTTMeasureSeqNum: iss + 1,
- LastSendTime: time.Now(),
+ LastSendTime: ep.stack.Clock().NowMonotonic(),
MaxPayloadSize: maxPayloadSize,
MaxSentAck: irs + 1,
FastRecovery: stack.TCPFastRecoveryState{
@@ -197,9 +197,9 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
s.SndWndScale = uint8(sndWndScale)
}
- s.resendTimer.init(&s.resendWaker)
- s.reorderTimer.init(&s.reorderWaker)
- s.probeTimer.init(&s.probeWaker)
+ s.resendTimer.init(s.ep.stack.Clock(), &s.resendWaker)
+ s.reorderTimer.init(s.ep.stack.Clock(), &s.reorderWaker)
+ s.probeTimer.init(s.ep.stack.Clock(), &s.probeWaker)
s.updateMaxPayloadSize(int(ep.route.MTU()), 0)
@@ -441,7 +441,7 @@ func (s *sender) retransmitTimerExpired() bool {
// timeout since the first retransmission.
uto := s.ep.userTimeout
- if s.firstRetransmittedSegXmitTime.IsZero() {
+ if s.firstRetransmittedSegXmitTime == (tcpip.MonotonicTime{}) {
// We store the original xmitTime of the segment that we are
// about to retransmit as the retransmission time. This is
// required as by the time the retransmitTimer has expired the
@@ -450,7 +450,7 @@ func (s *sender) retransmitTimerExpired() bool {
s.firstRetransmittedSegXmitTime = s.writeList.Front().xmitTime
}
- elapsed := time.Since(s.firstRetransmittedSegXmitTime)
+ elapsed := s.ep.stack.Clock().NowMonotonic().Sub(s.firstRetransmittedSegXmitTime)
remaining := s.maxRTO
if uto != 0 {
// Cap to the user specified timeout if one is specified.
@@ -616,7 +616,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt
// 'S2' that meets the following 3 criteria for determining
// loss, the sequence range of one segment of up to SMSS
// octets starting with S2 MUST be returned.
- if !s.ep.scoreboard.IsSACKED(header.SACKBlock{segSeq, segSeq.Add(1)}) {
+ if !s.ep.scoreboard.IsSACKED(header.SACKBlock{Start: segSeq, End: segSeq.Add(1)}) {
// NextSeg():
//
// (1.a) S2 is greater than HighRxt
@@ -866,8 +866,8 @@ func (s *sender) enableZeroWindowProbing() {
// We piggyback the probing on the retransmit timer with the
// current retransmission interval, as we may start probing while
// segments are being retransmitted.
- if s.firstRetransmittedSegXmitTime.IsZero() {
- s.firstRetransmittedSegXmitTime = time.Now()
+ if s.firstRetransmittedSegXmitTime == (tcpip.MonotonicTime{}) {
+ s.firstRetransmittedSegXmitTime = s.ep.stack.Clock().NowMonotonic()
}
s.resendTimer.enable(s.RTO)
}
@@ -875,7 +875,7 @@ func (s *sender) enableZeroWindowProbing() {
func (s *sender) disableZeroWindowProbing() {
s.zeroWindowProbing = false
s.unackZeroWindowProbes = 0
- s.firstRetransmittedSegXmitTime = time.Time{}
+ s.firstRetransmittedSegXmitTime = tcpip.MonotonicTime{}
s.resendTimer.disable()
}
@@ -925,7 +925,7 @@ func (s *sender) sendData() {
// "A TCP SHOULD set cwnd to no more than RW before beginning
// transmission if the TCP has not sent data in the interval exceeding
// the retransmission timeout."
- if !s.FastRecovery.Active && s.state != tcpip.RTORecovery && time.Now().Sub(s.LastSendTime) > s.RTO {
+ if !s.FastRecovery.Active && s.state != tcpip.RTORecovery && s.ep.stack.Clock().NowMonotonic().Sub(s.LastSendTime) > s.RTO {
if s.SndCwnd > InitialCwnd {
s.SndCwnd = InitialCwnd
}
@@ -1024,7 +1024,7 @@ func (s *sender) SetPipe() {
if segEnd.LessThan(endSeq) {
endSeq = segEnd
}
- sb := header.SACKBlock{startSeq, endSeq}
+ sb := header.SACKBlock{Start: startSeq, End: endSeq}
// SetPipe():
//
// After initializing pipe to zero, the following steps are
@@ -1132,7 +1132,7 @@ func (s *sender) isDupAck(seg *segment) bool {
// (b) The incoming acknowledgment carries no data.
seg.logicalLen() == 0 &&
// (c) The SYN and FIN bits are both off.
- !seg.flagIsSet(header.TCPFlagFin) && !seg.flagIsSet(header.TCPFlagSyn) &&
+ !seg.flags.Intersects(header.TCPFlagFin|header.TCPFlagSyn) &&
// (d) the ACK number is equal to the greatest acknowledgment received on
// the given connection (TCP.UNA from RFC793).
seg.ackNumber == s.SndUna &&
@@ -1234,7 +1234,7 @@ func checkDSACK(rcvdSeg *segment) bool {
func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// Check if we can extract an RTT measurement from this ack.
if !rcvdSeg.parsedOptions.TS && s.RTTMeasureSeqNum.LessThan(rcvdSeg.ackNumber) {
- s.updateRTO(time.Now().Sub(s.RTTMeasureTime))
+ s.updateRTO(s.ep.stack.Clock().NowMonotonic().Sub(s.RTTMeasureTime))
s.RTTMeasureSeqNum = s.SndNxt
}
@@ -1444,7 +1444,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
if s.SndUna == s.SndNxt {
s.Outstanding = 0
// Reset firstRetransmittedSegXmitTime to the zero value.
- s.firstRetransmittedSegXmitTime = time.Time{}
+ s.firstRetransmittedSegXmitTime = tcpip.MonotonicTime{}
s.resendTimer.disable()
s.probeTimer.disable()
}
@@ -1455,7 +1455,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
// * Upon receiving an ACK:
// * Step 4: Update RACK reordering window
- s.rc.updateRACKReorderWindow(rcvdSeg)
+ s.rc.updateRACKReorderWindow()
// After the reorder window is calculated, detect any loss by checking
// if the time elapsed after the segments are sent is greater than the
@@ -1502,7 +1502,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error {
s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment()
}
}
- seg.xmitTime = time.Now()
+ seg.xmitTime = s.ep.stack.Clock().NowMonotonic()
seg.xmitCount++
seg.lost = false
err := s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber)
@@ -1527,7 +1527,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error {
// sendSegmentFromView sends a new segment containing the given payload, flags
// and sequence number.
func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags header.TCPFlags, seq seqnum.Value) tcpip.Error {
- s.LastSendTime = time.Now()
+ s.LastSendTime = s.ep.stack.Clock().NowMonotonic()
if seq == s.RTTMeasureSeqNum {
s.RTTMeasureTime = s.LastSendTime
}
diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go
deleted file mode 100644
index 2f805d8ce..000000000
--- a/pkg/tcpip/transport/tcp/snd_state.go
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp
-
-import (
- "time"
-)
-
-// +stateify savable
-type unixTime struct {
- second int64
- nano int64
-}
-
-// afterLoad is invoked by stateify.
-func (s *sender) afterLoad() {
- s.resendTimer.init(&s.resendWaker)
- s.reorderTimer.init(&s.reorderWaker)
- s.probeTimer.init(&s.probeWaker)
-}
-
-// saveFirstRetransmittedSegXmitTime is invoked by stateify.
-func (s *sender) saveFirstRetransmittedSegXmitTime() unixTime {
- return unixTime{s.firstRetransmittedSegXmitTime.Unix(), s.firstRetransmittedSegXmitTime.UnixNano()}
-}
-
-// loadFirstRetransmittedSegXmitTime is invoked by stateify.
-func (s *sender) loadFirstRetransmittedSegXmitTime(unix unixTime) {
- s.firstRetransmittedSegXmitTime = time.Unix(unix.second, unix.nano)
-}
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index c58361bc1..d6cf786a1 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -38,7 +38,7 @@ const (
func setStackRACKPermitted(t *testing.T, c *context.Context) {
t.Helper()
- opt := tcpip.TCPRecovery(tcpip.TCPRACKLossDetection)
+ opt := tcpip.TCPRACKLossDetection
if err := c.Stack().SetTransportProtocolOption(header.TCPProtocolNumber, &opt); err != nil {
t.Fatalf("c.s.SetTransportProtocolOption(%d, &%v(%v)): %s", header.TCPProtocolNumber, opt, opt, err)
}
@@ -50,7 +50,7 @@ func TestRACKUpdate(t *testing.T) {
c := context.New(t, uint32(mtu))
defer c.Cleanup()
- var xmitTime time.Time
+ var xmitTime tcpip.MonotonicTime
probeDone := make(chan struct{})
c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
// Validate that the endpoint Sender.RACKState is what we expect.
@@ -79,7 +79,7 @@ func TestRACKUpdate(t *testing.T) {
}
// Write the data.
- xmitTime = time.Now()
+ xmitTime = c.Stack().Clock().NowMonotonic()
var r bytes.Reader
r.Reset(data)
if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 9916182e3..71c4aa85d 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -3451,17 +3451,13 @@ loop:
for {
switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) {
case *tcpip.ErrWouldBlock:
- select {
- case <-ch:
- // Expect the state to be StateError and subsequent Reads to fail with HardError.
- _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
- if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" {
- t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d)
- }
- break loop
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for reset to arrive")
+ <-ch
+ // Expect the state to be StateError and subsequent Reads to fail with HardError.
+ _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
+ if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" {
+ t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d)
}
+ break loop
case *tcpip.ErrConnectionReset:
break loop
default:
@@ -3472,14 +3468,27 @@ loop:
if tcp.EndpointState(c.EP.State()) != tcp.StateError {
t.Fatalf("got EP state is not StateError")
}
- if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
- t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)
+
+ checkValid := func() []error {
+ var errors []error
+ if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
+ errors = append(errors, fmt.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got))
+ }
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ errors = append(errors, fmt.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got))
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ errors = append(errors, fmt.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got))
+ }
+ return errors
}
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+
+ start := time.Now()
+ for time.Since(start) < time.Minute && len(checkValid()) > 0 {
+ time.Sleep(50 * time.Millisecond)
}
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ for _, err := range checkValid() {
+ t.Error(err)
}
}
@@ -4259,7 +4268,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(iss),
+ SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -5335,7 +5344,7 @@ func TestKeepalive(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(next-1)),
+ checker.TCPSeqNum(next-1),
checker.TCPAckNum(uint32(iss)),
checker.TCPFlags(header.TCPFlagAck),
),
@@ -5360,12 +5369,7 @@ func TestKeepalive(t *testing.T) {
})
checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(next)),
- checker.TCPAckNum(uint32(0)),
- checker.TCPFlags(header.TCPFlagRst),
- ),
+ checker.TCP(checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst)),
)
if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
@@ -5507,7 +5511,7 @@ func TestListenBacklogFull(t *testing.T) {
// Now execute send one more SYN. The stack should not respond as the backlog
// is full at this point.
c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + uint16(lastPortOffset),
+ SrcPort: context.TestPort + lastPortOffset,
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: seqnum.Value(context.TestInitialSequenceNumber),
@@ -5884,7 +5888,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
r.Reset(data)
newEP.Write(&r, tcpip.WriteOptions{})
pkt := c.GetPacket()
- tcp = header.TCP(header.IPv4(pkt).Payload())
+ tcp = header.IPv4(pkt).Payload()
if string(tcp.Payload()) != data {
t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
}
@@ -6073,6 +6077,11 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
// complete the connection to test that the large SEQ num
// did not change the state from SYN-RCVD.
+ // Get set up to be notified about connection establishment.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.ReadableEvents)
+ defer c.WQ.EventUnregister(&we)
+
// Send ACK to move to ESTABLISHED state.
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
@@ -6083,32 +6092,12 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
RcvWnd: 30000,
})
+ <-ch
newEP, _, err := c.EP.Accept(nil)
- switch err.(type) {
- case nil, *tcpip.ErrWouldBlock:
- default:
+ if err != nil {
t.Fatalf("Accept failed: %s", err)
}
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Try to accept the connections in the backlog.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- // Wait for connection to be established.
- select {
- case <-ch:
- newEP, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
// Now verify that the TCP socket is usable and in a connected state.
data := "Don't panic"
var r strings.Reader
@@ -6118,7 +6107,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
}
pkt := c.GetPacket()
- tcpHdr = header.TCP(header.IPv4(pkt).Payload())
+ tcpHdr = header.IPv4(pkt).Payload()
if string(tcpHdr.Payload()) != data {
t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data)
}
@@ -6214,12 +6203,26 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
RcvWnd: 30000,
})
- time.Sleep(50 * time.Millisecond)
- if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
- t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want)
+ checkValid := func() []error {
+ var errors []error
+ if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
+ errors = append(errors, fmt.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want))
+ }
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
+ errors = append(errors, fmt.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want))
+ }
+ return errors
}
- if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
- t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want)
+
+ start := time.Now()
+ for time.Since(start) < time.Minute && len(checkValid()) > 0 {
+ time.Sleep(50 * time.Millisecond)
+ }
+ for _, err := range checkValid() {
+ t.Error(err)
+ }
+ if t.Failed() {
+ t.FailNow()
}
we, ch := waiter.NewChannelEntry(nil)
@@ -6230,19 +6233,62 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
_, _, err = c.EP.Accept(nil)
if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
- select {
- case <-ch:
- _, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
+ <-ch
+ _, _, err = c.EP.Accept(nil)
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
}
}
}
+func TestListenDropIncrement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ stats := c.Stack().Stats()
+ c.Create(-1 /*epRcvBuf*/)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ if err := c.EP.Listen(1 /*backlog*/); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ initialDropped := stats.DroppedPackets.Value()
+
+ // Send RST and FIN segments, which are expected to be dropped by the listener.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagRst,
+ })
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagFin,
+ })
+
+ // To ensure that the RST, FIN sent earlier are indeed received and ignored
+ // by the listener, send a SYN and wait for the SYN to be ACKd.
+ irs := seqnum.Value(context.TestInitialSequenceNumber)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ })
+ checker.IPv4(t, c.GetPacket(), checker.TCP(checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.TCPAckNum(uint32(irs)+1),
+ ))
+
+ if got, want := stats.DroppedPackets.Value(), initialDropped+2; got != want {
+ t.Fatalf("got stats.DroppedPackets.Value() = %d, want = %d", got, want)
+ }
+}
+
func TestEndpointBindListenAcceptState(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -6375,7 +6421,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// Allocate a large enough payload for the test.
payloadSize := receiveBufferSize * 2
- b := make([]byte, int(payloadSize))
+ b := make([]byte, payloadSize)
worker := (c.EP).(interface {
StopWork()
@@ -6429,7 +6475,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// ack, 1 for the non-zero window
p := c.GetPacket()
checker.IPv4(t, p, checker.TCP(
- checker.TCPAckNum(uint32(wantAckNum)),
+ checker.TCPAckNum(wantAckNum),
func(t *testing.T, h header.Transport) {
tcp, ok := h.(header.TCP)
if !ok {
@@ -6484,14 +6530,14 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
- tsVal := uint32(rawEP.TSVal)
+ tsVal := rawEP.TSVal
rawEP.NextSeqNum--
rawEP.SendPacketWithTS(nil, tsVal)
rawEP.NextSeqNum++
pkt := rawEP.VerifyAndReturnACKWithTS(tsVal)
curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale
scaleRcvWnd := func(rcvWnd int) uint16 {
- return uint16(rcvWnd >> uint16(c.WindowScale))
+ return uint16(rcvWnd >> c.WindowScale)
}
// Allocate a large array to send to the endpoint.
b := make([]byte, receiveBufferSize*48)
@@ -6619,19 +6665,16 @@ func TestDelayEnabled(t *testing.T) {
defer c.Cleanup()
checkDelayOption(t, c, false, false) // Delay is disabled by default.
- for _, v := range []struct {
- delayEnabled tcpip.TCPDelayEnabled
- wantDelayOption bool
- }{
- {delayEnabled: false, wantDelayOption: false},
- {delayEnabled: true, wantDelayOption: true},
- } {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &v.delayEnabled); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, v.delayEnabled, v.delayEnabled, err)
- }
- checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption)
+ for _, delayEnabled := range []bool{false, true} {
+ t.Run(fmt.Sprintf("delayEnabled=%t", delayEnabled), func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ opt := tcpip.TCPDelayEnabled(delayEnabled)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, delayEnabled, err)
+ }
+ checkDelayOption(t, c, opt, delayEnabled)
+ })
}
}
@@ -7042,7 +7085,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
// Receive the SYN-ACK reply.
b = c.GetPacket()
- tcpHdr = header.TCP(header.IPv4(b).Payload())
+ tcpHdr = header.IPv4(b).Payload()
c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
ackHeaders = &context.Headers{
@@ -7443,7 +7486,7 @@ func TestTCPUserTimeout(t *testing.T) {
select {
case <-notifyCh:
case <-time.After(2 * initRTO):
- t.Fatalf("connection still alive after %s, should have been closed after :%s", 2*initRTO, userTimeout)
+ t.Fatalf("connection still alive after %s, should have been closed after %s", 2*initRTO, userTimeout)
}
// No packet should be received as the connection should be silently
@@ -7467,7 +7510,7 @@ func TestTCPUserTimeout(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(next)),
+ checker.TCPSeqNum(next),
checker.TCPAckNum(uint32(0)),
checker.TCPFlags(header.TCPFlagRst),
),
@@ -7545,7 +7588,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
DstPort: c.Port,
Flags: header.TCPFlagAck,
SeqNum: iss,
- AckNum: seqnum.Value(c.IRS + 1),
+ AckNum: c.IRS + 1,
RcvWnd: 30000,
})
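Several of the tcp_test.go hunks above replace a single fixed sleep followed by stat assertions with a checkValid helper that is polled until the stats settle or a deadline passes. A generic, self-contained sketch of that polling pattern; the helper and names here are illustrative, not from the test file:

package main

import (
	"fmt"
	"time"
)

// waitFor polls check until it reports no errors or the deadline passes, then
// returns whatever errors remain so the caller can report them once.
func waitFor(check func() []error, deadline, interval time.Duration) []error {
	start := time.Now()
	for time.Since(start) < deadline && len(check()) > 0 {
		time.Sleep(interval)
	}
	return check()
}

func main() {
	attempts := 0
	errs := waitFor(func() []error {
		attempts++
		if attempts < 3 {
			return []error{fmt.Errorf("stats not settled yet")}
		}
		return nil
	}, time.Minute, 10*time.Millisecond)
	fmt.Println(len(errs)) // 0
}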
diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go
index 38a335840..5645c772e 100644
--- a/pkg/tcpip/transport/tcp/timer.go
+++ b/pkg/tcpip/transport/tcp/timer.go
@@ -15,21 +15,29 @@
package tcp
import (
+ "math"
"time"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip"
)
type timerState int
const (
+ // The timer is disabled.
timerStateDisabled timerState = iota
+ // The timer is enabled, but the clock timer may be set to an earlier
+ // expiration time due to a previous orphaned state.
timerStateEnabled
+ // The timer is disabled, but the clock timer is enabled, which means that
+ // it will cause a spurious wakeup unless the timer is enabled before the
+ // clock timer fires.
timerStateOrphaned
)
// timer is a timer implementation that reduces the interactions with the
-// runtime timer infrastructure by letting timers run (and potentially
+// clock timer infrastructure by letting timers run (and potentially
// eventually expire) even if they are stopped. It makes it cheaper to
// disable/reenable timers at the expense of spurious wakes. This is useful for
// cases when the same timer is disabled/reenabled repeatedly with relatively
@@ -39,44 +47,37 @@ const (
// (currently at least 200ms), and get disabled when acks are received, and
// reenabled when new pending segments are sent.
//
-// It is advantageous to avoid interacting with the runtime because it acquires
+// It is advantageous to avoid interacting with the clock because it acquires
// a global mutex and performs O(log n) operations, where n is the global number
// of timers, whenever a timer is enabled or disabled, and may make a syscall.
//
// This struct is thread-compatible.
type timer struct {
- // state is the current state of the timer, it can be one of the
- // following values:
- // disabled - the timer is disabled.
- // orphaned - the timer is disabled, but the runtime timer is
- // enabled, which means that it will evetually cause a
- // spurious wake (unless it gets enabled again before
- // then).
- // enabled - the timer is enabled, but the runtime timer may be set
- // to an earlier expiration time due to a previous
- // orphaned state.
state timerState
+ clock tcpip.Clock
+
// target is the expiration time of the current timer. It is only
// meaningful in the enabled state.
- target time.Time
+ target tcpip.MonotonicTime
- // runtimeTarget is the expiration time of the runtime timer. It is
+ // clockTarget is the expiration time of the clock timer. It is
// meaningful in the enabled and orphaned states.
- runtimeTarget time.Time
+ clockTarget tcpip.MonotonicTime
- // timer is the runtime timer used to wait on.
- timer *time.Timer
+ // timer is the clock timer used to wait on.
+ timer tcpip.Timer
}
// init initializes the timer. Once it expires, the given waker will be
// asserted.
-func (t *timer) init(w *sleep.Waker) {
+func (t *timer) init(clock tcpip.Clock, w *sleep.Waker) {
t.state = timerStateDisabled
+ t.clock = clock
- // Initialize a runtime timer that will assert the waker, then
+ // Initialize a clock timer that will assert the waker, then
// immediately stop it.
- t.timer = time.AfterFunc(time.Hour, func() {
+ t.timer = t.clock.AfterFunc(math.MaxInt64, func() {
w.Assert()
})
t.timer.Stop()
@@ -106,9 +107,9 @@ func (t *timer) checkExpiration() bool {
// The timer is enabled, but it may have expired early. Check if that's
// the case, and if so, reset the runtime timer to the correct time.
- now := time.Now()
+ now := t.clock.NowMonotonic()
if now.Before(t.target) {
- t.runtimeTarget = t.target
+ t.clockTarget = t.target
t.timer.Reset(t.target.Sub(now))
return false
}
@@ -134,11 +135,11 @@ func (t *timer) enabled() bool {
// enable enables the timer, programming the runtime timer if necessary.
func (t *timer) enable(d time.Duration) {
- t.target = time.Now().Add(d)
+ t.target = t.clock.NowMonotonic().Add(d)
// Check if we need to set the runtime timer.
- if t.state == timerStateDisabled || t.target.Before(t.runtimeTarget) {
- t.runtimeTarget = t.target
+ if t.state == timerStateDisabled || t.target.Before(t.clockTarget) {
+ t.clockTarget = t.target
t.timer.Reset(d)
}
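The timer.go changes rename the runtime-timer plumbing to a tcpip.Clock-backed timer but keep the lazy disable/enable scheme the comments describe: disabling only marks the timer orphaned, and checkExpiration absorbs the resulting spurious wake-ups. A simplified, self-contained sketch of that state machine built on the standard library timer, assuming the same three states; it is not the package's actual implementation:

package main

import (
	"fmt"
	"time"
)

type state int

const (
	disabled state = iota
	enabled
	orphaned
)

type lazyTimer struct {
	st          state
	target      time.Time // expiration requested by the caller (enabled only)
	clockTarget time.Time // when the underlying timer actually fires
	timer       *time.Timer
}

func newLazyTimer(wake func()) *lazyTimer {
	t := &lazyTimer{}
	t.timer = time.AfterFunc(time.Hour, wake)
	t.timer.Stop()
	return t
}

func (t *lazyTimer) enable(d time.Duration) {
	t.target = time.Now().Add(d)
	// Re-arm the underlying timer only if it is idle or set to fire too late.
	if t.st == disabled || t.target.Before(t.clockTarget) {
		t.clockTarget = t.target
		t.timer.Reset(d)
	}
	t.st = enabled
}

func (t *lazyTimer) disable() {
	// Cheap: leave the underlying timer running and absorb the spurious wake.
	if t.st != disabled {
		t.st = orphaned
	}
}

// checkExpiration is called on every wake; it reports whether the wake was a
// real expiration, re-arming the timer if it woke too early.
func (t *lazyTimer) checkExpiration() bool {
	if t.st != enabled {
		t.st = disabled
		return false
	}
	if now := time.Now(); now.Before(t.target) {
		// Woke early because of an older, shorter arming: re-arm.
		t.clockTarget = t.target
		t.timer.Reset(t.target.Sub(now))
		return false
	}
	t.st = disabled
	return true
}

func main() {
	wake := make(chan struct{}, 1)
	t := newLazyTimer(func() { wake <- struct{}{} })
	t.enable(10 * time.Millisecond)
	t.disable()
	t.enable(30 * time.Millisecond)
	<-wake                           // fires at ~10ms from the first arming
	fmt.Println(t.checkExpiration()) // false: spurious wake, timer re-armed
	<-wake
	fmt.Println(t.checkExpiration()) // true: real expiration
}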
diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go
index dbd6dff54..479752de7 100644
--- a/pkg/tcpip/transport/tcp/timer_test.go
+++ b/pkg/tcpip/transport/tcp/timer_test.go
@@ -19,6 +19,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
)
func TestCleanup(t *testing.T) {
@@ -27,9 +28,11 @@ func TestCleanup(t *testing.T) {
isAssertedTimeoutSeconds = timerDurationSeconds + 1
)
+ clock := faketime.NewManualClock()
+
tmr := timer{}
w := sleep.Waker{}
- tmr.init(&w)
+ tmr.init(clock, &w)
tmr.enable(timerDurationSeconds * time.Second)
tmr.cleanup()
@@ -39,7 +42,7 @@ func TestCleanup(t *testing.T) {
// The waker should not be asserted.
for i := 0; i < isAssertedTimeoutSeconds; i++ {
- time.Sleep(time.Second)
+ clock.Advance(time.Second)
if w.IsAsserted() {
t.Fatalf("waker asserted unexpectedly")
}
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index dd5c910ae..cdc344ab7 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -49,6 +49,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index f7dd50d35..def9d7186 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -17,6 +17,7 @@ package udp
import (
"io"
"sync/atomic"
+ "time"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -34,25 +35,26 @@ type udpPacket struct {
destinationAddress tcpip.FullAddress
packetInfo tcpip.IPPacketInfo
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- timestamp int64
+ receivedAt time.Time `state:".(int64)"`
// tos stores either the receiveTOS or receiveTClass value.
tos uint8
}
// EndpointState represents the state of a UDP endpoint.
-type EndpointState uint32
+type EndpointState tcpip.EndpointState
// Endpoint states. Note that they are represented in a netstack-specific manner and
// may not be meaningful externally. Specifically, they need to be translated to
// Linux's representation for these states if presented to userspace.
const (
- StateInitial EndpointState = iota
+ _ EndpointState = iota
+ StateInitial
StateBound
StateConnected
StateClosed
)
-// String implements fmt.Stringer.String.
+// String implements fmt.Stringer.
func (s EndpointState) String() string {
switch s {
case StateInitial:
@@ -98,7 +100,7 @@ type endpoint struct {
mu sync.RWMutex `state:"nosave"`
// state must be read/set using the EndpointState()/setEndpointState()
// methods.
- state EndpointState
+ state uint32
route *stack.Route `state:"manual"`
dstPort uint16
ttl uint8
@@ -176,7 +178,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
// Linux defaults to TTL=1.
multicastTTL: 1,
multicastMemberships: make(map[multicastMembership]struct{}),
- state: StateInitial,
+ state: uint32(StateInitial),
uniqueID: s.UniqueID(),
}
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
@@ -204,15 +206,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
//
// Precondition: e.mu must be held to call this method.
func (e *endpoint) setEndpointState(state EndpointState) {
- atomic.StoreUint32((*uint32)(&e.state), uint32(state))
+ atomic.StoreUint32(&e.state, uint32(state))
}
// EndpointState() returns the current state of the endpoint.
func (e *endpoint) EndpointState() EndpointState {
- return EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+ return EndpointState(atomic.LoadUint32(&e.state))
}
-// UniqueID implements stack.TransportEndpoint.UniqueID.
+// UniqueID implements stack.TransportEndpoint.
func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
}
@@ -226,14 +228,14 @@ func (e *endpoint) LastError() tcpip.Error {
return err
}
-// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
+// UpdateLastError implements tcpip.SocketOptionsHandler.
func (e *endpoint) UpdateLastError(err tcpip.Error) {
e.lastErrorMu.Lock()
e.lastError = err
e.lastErrorMu.Unlock()
}
-// Abort implements stack.TransportEndpoint.Abort.
+// Abort implements stack.TransportEndpoint.
func (e *endpoint) Abort() {
e.Close()
}
@@ -289,10 +291,10 @@ func (e *endpoint) Close() {
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
-// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
-func (e *endpoint) ModerateRecvBuf(copied int) {}
+// ModerateRecvBuf implements tcpip.Endpoint.
+func (*endpoint) ModerateRecvBuf(int) {}
-// Read implements tcpip.Endpoint.Read.
+// Read implements tcpip.Endpoint.
func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
if err := e.LastError(); err != nil {
return tcpip.ReadResult{}, err
@@ -320,7 +322,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// Control Messages
cm := tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: p.timestamp,
+ Timestamp: p.receivedAt.UnixNano(),
}
if e.ops.GetReceiveTOS() {
cm.HasTOS = true
@@ -581,21 +583,21 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return int64(len(v)), nil
}
-// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet.
+// OnReuseAddressSet implements tcpip.SocketOptionsHandler.
func (e *endpoint) OnReuseAddressSet(v bool) {
e.mu.Lock()
e.portFlags.MostRecent = v
e.mu.Unlock()
}
-// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet.
+// OnReusePortSet implements tcpip.SocketOptionsHandler.
func (e *endpoint) OnReusePortSet(v bool) {
e.mu.Lock()
e.portFlags.LoadBalanced = v
e.mu.Unlock()
}
-// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+// SetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
switch opt {
case tcpip.MTUDiscoverOption:
@@ -629,11 +631,14 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
return nil
}
+var _ tcpip.SocketOptionsHandler = (*endpoint)(nil)
+
+// HasNIC implements tcpip.SocketOptionsHandler.
func (e *endpoint) HasNIC(id int32) bool {
- return id == 0 || e.stack.HasNIC(tcpip.NICID(id))
+ return e.stack.HasNIC(tcpip.NICID(id))
}
-// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
+// SetSockOpt implements tcpip.Endpoint.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
switch v := opt.(type) {
case *tcpip.MulticastInterfaceOption:
@@ -749,7 +754,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
return nil
}
-// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+// GetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.IPv4TOSOption:
@@ -795,14 +800,14 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
}
}
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+// GetSockOpt implements tcpip.Endpoint.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
switch o := opt.(type) {
case *tcpip.MulticastInterfaceOption:
e.mu.Lock()
*o = tcpip.MulticastInterfaceOption{
- e.multicastNICID,
- e.multicastAddr,
+ NIC: e.multicastNICID,
+ InterfaceAddr: e.multicastAddr,
}
e.mu.Unlock()
@@ -872,7 +877,7 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres
return unwrapped, netProto, nil
}
-// Disconnect implements tcpip.Endpoint.Disconnect.
+// Disconnect implements tcpip.Endpoint.
func (e *endpoint) Disconnect() tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -1075,7 +1080,7 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id
BindToDevice: bindToDevice,
Dest: tcpip.FullAddress{},
}
- port, err := e.stack.ReservePort(portRes, nil /* testPort */)
+ port, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */)
if err != nil {
return id, bindToDevice, err
}
@@ -1301,12 +1306,12 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
senderAddress: tcpip.FullAddress{
NIC: pkt.NICID,
Addr: id.RemoteAddress,
- Port: header.UDP(hdr).SourcePort(),
+ Port: hdr.SourcePort(),
},
destinationAddress: tcpip.FullAddress{
NIC: pkt.NICID,
Addr: id.LocalAddress,
- Port: header.UDP(hdr).DestinationPort(),
+ Port: hdr.DestinationPort(),
},
data: pkt.Data().ExtractVV(),
}
@@ -1328,7 +1333,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
packet.packetInfo.LocalAddr = localAddr
packet.packetInfo.DestinationAddr = localAddr
packet.packetInfo.NIC = pkt.NICID
- packet.timestamp = e.stack.Clock().NowNanoseconds()
+ packet.receivedAt = e.stack.Clock().Now()
e.rcvMu.Unlock()
@@ -1386,7 +1391,7 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB
}
}
-// State implements tcpip.Endpoint.State.
+// State implements tcpip.Endpoint.
func (e *endpoint) State() uint32 {
return uint32(e.EndpointState())
}
@@ -1405,19 +1410,19 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
return &e.stats
}
-// Wait implements tcpip.Endpoint.Wait.
+// Wait implements tcpip.Endpoint.
func (*endpoint) Wait() {}
func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr)
}
-// SetOwner implements tcpip.Endpoint.SetOwner.
+// SetOwner implements tcpip.Endpoint.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// SocketOptions implements tcpip.Endpoint.SocketOptions.
+// SocketOptions implements tcpip.Endpoint.
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}
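The UDP endpoint now stores its state as a bare uint32 so it can be read and written with sync/atomic, with EndpointState() and setEndpointState() converting to and from the exported type. A minimal sketch of that accessor pattern with illustrative names:

package main

import (
	"fmt"
	"sync/atomic"
)

type EndpointState uint32

const (
	StateInitial EndpointState = iota + 1
	StateBound
	StateConnected
	StateClosed
)

type endpoint struct {
	state uint32 // accessed only via the helpers below
}

func (e *endpoint) setEndpointState(s EndpointState) {
	atomic.StoreUint32(&e.state, uint32(s))
}

func (e *endpoint) EndpointState() EndpointState {
	return EndpointState(atomic.LoadUint32(&e.state))
}

func main() {
	e := &endpoint{state: uint32(StateInitial)}
	e.setEndpointState(StateConnected)
	fmt.Println(e.EndpointState() == StateConnected) // true
}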
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 4aba68b21..1f638c3f6 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -15,26 +15,38 @@
package udp
import (
+ "time"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+// saveReceivedAt is invoked by stateify.
+func (p *udpPacket) saveReceivedAt() int64 {
+ return p.receivedAt.UnixNano()
+}
+
+// loadReceivedAt is invoked by stateify.
+func (p *udpPacket) loadReceivedAt(nsec int64) {
+ p.receivedAt = time.Unix(0, nsec)
+}
+
// saveData saves udpPacket.data field.
-func (u *udpPacket) saveData() buffer.VectorisedView {
- // We cannot save u.data directly as u.data.views may alias to u.views,
+func (p *udpPacket) saveData() buffer.VectorisedView {
+ // We cannot save p.data directly as p.data.views may alias to p.views,
// which is not allowed by state framework (in-struct pointer).
- return u.data.Clone(nil)
+ return p.data.Clone(nil)
}
// loadData loads udpPacket.data field.
-func (u *udpPacket) loadData(data buffer.VectorisedView) {
- // NOTE: We cannot do the u.data = data.Clone(u.views[:]) optimization
+func (p *udpPacket) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
// here because data.views is not guaranteed to be loaded by now. Plus,
// data.views will be allocated anyway so there really is little point
- // of utilizing u.views for data.views.
- u.data = data
+ // of utilizing p.views for data.views.
+ p.data = data
}
// afterLoad is invoked by stateify.
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index 705ad1f64..7c357cb09 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -90,7 +90,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
ep.RegisterNICID = r.pkt.NICID
ep.boundPortFlags = ep.portFlags
- ep.state = StateConnected
+ ep.state = uint32(StateConnected)
ep.rcvMu.Lock()
ep.rcvReady = true
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index dc2e3f493..4008cacf2 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -16,17 +16,16 @@ package udp_test
import (
"bytes"
- "context"
"fmt"
"io/ioutil"
"math/rand"
"testing"
- "time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
@@ -298,16 +297,18 @@ type testContext struct {
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
- return newDualTestContextWithOptions(t, mtu, stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
- HandleLocal: true,
- })
+ return newDualTestContextWithHandleLocal(t, mtu, true)
}
-func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Options) *testContext {
+func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal bool) *testContext {
t.Helper()
+ options := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
+ HandleLocal: handleLocal,
+ Clock: &faketime.NullClock{},
+ }
s := stack.New(options)
ep := channel.New(256, mtu, "")
wep := stack.LinkEndpoint(ep)
@@ -378,9 +379,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte {
c.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
- defer cancel()
- p, ok := c.linkEP.ReadContext(ctx)
+ p, ok := c.linkEP.Read()
if !ok {
c.t.Fatalf("Packet wasn't written out")
return nil
@@ -534,7 +533,9 @@ func newMinPayload(minSize int) []byte {
func TestBindToDeviceOption(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}})
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: &faketime.NullClock{},
+ })
ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
@@ -606,7 +607,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
case <-ch:
res, err = c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true})
- case <-time.After(300 * time.Millisecond):
+ default:
if packetShouldBeDropped {
return // expected to time out
}
@@ -820,11 +821,7 @@ func TestV4ReadSelfSource(t *testing.T) {
{"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1},
} {
t.Run(tt.name, func(t *testing.T) {
- c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- HandleLocal: tt.handleLocal,
- })
+ c := newDualTestContextWithHandleLocal(t, defaultMTU, tt.handleLocal)
defer c.cleanup()
c.createEndpointForFlow(unicastV4)
@@ -1034,17 +1031,17 @@ func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, che
payload := testWriteNoVerify(c, flow, setDest)
// Received the packet and check the payload.
b := c.getPacketAndVerify(flow, checkers...)
- var udp header.UDP
+ var udpH header.UDP
if flow.isV4() {
- udp = header.UDP(header.IPv4(b).Payload())
+ udpH = header.IPv4(b).Payload()
} else {
- udp = header.UDP(header.IPv6(b).Payload())
+ udpH = header.IPv6(b).Payload()
}
- if !bytes.Equal(payload, udp.Payload()) {
- c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
+ if !bytes.Equal(payload, udpH.Payload()) {
+ c.t.Fatalf("Bad payload: got %x, want %x", udpH.Payload(), payload)
}
- return udp.SourcePort()
+ return udpH.SourcePort()
}
func testDualWrite(c *testContext) uint16 {
@@ -1198,7 +1195,7 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) {
r.Reset(payload)
n, err := c.ep.Write(&r, writeOpts)
if err != nil {
- c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err)
+ c.t.Fatalf("c.ep.Write(...) = %s, want nil", err)
}
if got, want := n, int64(len(payload)); got != want {
c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want)
@@ -1462,7 +1459,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) {
name: "IPv4 unicast",
proto: header.IPv4ProtocolNumber,
flow: unicastV4,
- expectedOriginalDstAddr: tcpip.FullAddress{1, stackAddr, stackPort},
+ expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackAddr, Port: stackPort},
},
{
name: "IPv4 multicast",
@@ -1474,7 +1471,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) {
// behaviour. We still include the test so that once the bug is
// resolved, this test will start to fail and the individual tasked
// with fixing this bug knows to also fix this test :).
- expectedOriginalDstAddr: tcpip.FullAddress{1, multicastAddr, stackPort},
+ expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastAddr, Port: stackPort},
},
{
name: "IPv4 broadcast",
@@ -1486,13 +1483,13 @@ func TestReadRecvOriginalDstAddr(t *testing.T) {
// behaviour. We still include the test so that once the bug is
// resolved, this test will start to fail and the individual tasked
// with fixing this bug knows to also fix this test :).
- expectedOriginalDstAddr: tcpip.FullAddress{1, broadcastAddr, stackPort},
+ expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: broadcastAddr, Port: stackPort},
},
{
name: "IPv6 unicast",
proto: header.IPv6ProtocolNumber,
flow: unicastV6,
- expectedOriginalDstAddr: tcpip.FullAddress{1, stackV6Addr, stackPort},
+ expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackV6Addr, Port: stackPort},
},
{
name: "IPv6 multicast",
@@ -1504,7 +1501,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) {
// behaviour. We still include the test so that once the bug is
// resolved, this test will start to fail and the individual tasked
// with fixing this bug knows to also fix this test :).
- expectedOriginalDstAddr: tcpip.FullAddress{1, multicastV6Addr, stackPort},
+ expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastV6Addr, Port: stackPort},
},
}
@@ -1614,6 +1611,7 @@ func TestTTL(t *testing.T) {
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{p},
+ Clock: &faketime.NullClock{},
})
ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil)
wantTTL = ep.DefaultTTL()
@@ -1759,7 +1757,7 @@ func TestReceiveTosTClass(t *testing.T) {
c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false)
}
- want := true
+ const want = true
optionSetter(want)
got := optionGetter()
@@ -1889,18 +1887,14 @@ func TestV4UnknownDestination(t *testing.T) {
}
}
if !tc.icmpRequired {
- ctx, cancel := context.WithTimeout(context.Background(), time.Second)
- defer cancel()
- if p, ok := c.linkEP.ReadContext(ctx); ok {
+ if p, ok := c.linkEP.Read(); ok {
t.Fatalf("unexpected packet received: %+v", p)
}
return
}
// ICMP required.
- ctx, cancel := context.WithTimeout(context.Background(), time.Second)
- defer cancel()
- p, ok := c.linkEP.ReadContext(ctx)
+ p, ok := c.linkEP.Read()
if !ok {
t.Fatalf("packet wasn't written out")
return
@@ -1987,18 +1981,14 @@ func TestV6UnknownDestination(t *testing.T) {
}
}
if !tc.icmpRequired {
- ctx, cancel := context.WithTimeout(context.Background(), time.Second)
- defer cancel()
- if p, ok := c.linkEP.ReadContext(ctx); ok {
+ if p, ok := c.linkEP.Read(); ok {
t.Fatalf("unexpected packet received: %+v", p)
}
return
}
// ICMP required.
- ctx, cancel := context.WithTimeout(context.Background(), time.Second)
- defer cancel()
- p, ok := c.linkEP.ReadContext(ctx)
+ p, ok := c.linkEP.Read()
if !ok {
t.Fatalf("packet wasn't written out")
return
@@ -2115,8 +2105,8 @@ func TestShortHeader(t *testing.T) {
Data: buf.ToVectorisedView(),
}))
- if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want {
- t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want)
+ if got, want := c.s.Stats().NICs.MalformedL4RcvdPackets.Value(), uint64(1); got != want {
+ t.Errorf("got c.s.Stats().NIC.MalformedL4RcvdPackets.Value() = %d, want = %d", got, want)
}
}
@@ -2124,25 +2114,27 @@ func TestShortHeader(t *testing.T) {
// global and endpoint stats are incremented.
func TestBadChecksumErrors(t *testing.T) {
for _, flow := range []testFlow{unicastV4, unicastV6} {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
+ t.Run(flow.String(), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
- c.createEndpoint(flow.sockProto())
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
+ c.createEndpoint(flow.sockProto())
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
- payload := newPayload()
- c.injectPacket(flow, payload, true /* badChecksum */)
+ payload := newPayload()
+ c.injectPacket(flow, payload, true /* badChecksum */)
- const want = 1
- if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
- }
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+ })
}
}
@@ -2484,6 +2476,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: &faketime.NullClock{},
})
e := channel.New(0, defaultMTU, "")
if err := s.CreateNIC(nicID1, e); err != nil {
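// Editor's sketch (not part of the diff): the hunks above wire a
// faketime.NullClock into the test stacks. NullClock satisfies the clock
// interface expected by stack.Options but, as its name suggests, its timers
// never fire, so these UDP tests no longer depend on wall-clock time. A
// minimal construction mirroring the change; the protocol factories are the
// ones already imported by this test file.
func newTestStackWithNullClock() *stack.Stack {
	return stack.New(stack.Options{
		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
		TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
		Clock:              &faketime.NullClock{},
	})
}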
diff --git a/pkg/test/testutil/BUILD b/pkg/test/testutil/BUILD
index a789c246e..7ff13cf12 100644
--- a/pkg/test/testutil/BUILD
+++ b/pkg/test/testutil/BUILD
@@ -12,6 +12,7 @@ go_library(
],
visibility = ["//:sandbox"],
deps = [
+ "//pkg/sentry/watchdog",
"//pkg/sync",
"//runsc/config",
"//runsc/specutils",
diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go
index 663c83679..f6a3e34c7 100644
--- a/pkg/test/testutil/testutil.go
+++ b/pkg/test/testutil/testutil.go
@@ -42,6 +42,7 @@ import (
"github.com/cenkalti/backoff"
specs "github.com/opencontainers/runtime-spec/specs-go"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/sentry/watchdog"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/runsc/config"
"gvisor.dev/gvisor/runsc/specutils"
@@ -184,6 +185,7 @@ func TestConfig(t *testing.T) *config.Config {
conf.Network = config.NetworkNone
conf.Strace = true
conf.TestOnlyAllowRunAsCurrentUserWithoutChroot = true
+ conf.WatchdogAction = watchdog.Panic
return conf
}
diff --git a/pkg/urpc/urpc.go b/pkg/urpc/urpc.go
index 0e9a829f6..0ef635a2f 100644
--- a/pkg/urpc/urpc.go
+++ b/pkg/urpc/urpc.go
@@ -27,6 +27,7 @@ import (
"os"
"reflect"
"runtime"
+ "time"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
@@ -458,29 +459,45 @@ func (s *Server) StartHandling(client *unet.Socket) {
// No new requests should be initiated after calling Stop. Existing clients
// will be closed after completing any pending RPCs. This method will block
// until all clients have disconnected.
-func (s *Server) Stop() {
- // Wait for all outstanding requests.
- defer s.wg.Wait()
-
+//
+// timeout is how long to wait for ongoing RPCs to complete before client
+// connections are closed.
+func (s *Server) Stop(timeout time.Duration) {
// Call any Stop callbacks.
for _, stopper := range s.stoppers {
stopper.Stop()
}
- // Close all known clients.
- s.mu.Lock()
- defer s.mu.Unlock()
- for client, state := range s.clients {
- switch state {
- case idle:
- // Close connection now.
- client.Close()
- s.clients[client] = closed
- case processing:
- // Request close when done.
- s.clients[client] = closeRequested
+ done := make(chan bool, 1)
+ go func() {
+ if timeout != 0 {
+ timer := time.NewTimer(timeout)
+ defer timer.Stop()
+ select {
+ case <-done:
+ return
+ case <-timer.C:
+ }
}
- }
+
+ // Close all known clients.
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ for client, state := range s.clients {
+ switch state {
+ case idle:
+ // Close connection now.
+ client.Close()
+ s.clients[client] = closed
+ case processing:
+ // Request close when done.
+ s.clients[client] = closeRequested
+ }
+ }
+ }()
+
+ // Wait for all outstanding requests.
+ s.wg.Wait()
+ done <- true
}
// Client is a urpc client.
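// Editor's sketch (not part of the diff): callers of the reworked Stop now
// pass a grace period for in-flight RPCs. A zero duration skips the wait and
// closes idle connections immediately, while Stop itself still blocks until
// every handler has returned. The server parameter and the 10s value below
// are illustrative only.
func shutdownWithGrace(srv *urpc.Server) {
	// Give clients up to 10 seconds to finish ongoing RPCs before their
	// connections are closed.
	srv.Stop(10 * time.Second)
}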
diff --git a/pkg/usermem/BUILD b/pkg/usermem/BUILD
index 3dba36f12..54674ee88 100644
--- a/pkg/usermem/BUILD
+++ b/pkg/usermem/BUILD
@@ -7,12 +7,14 @@ go_library(
srcs = [
"bytes_io.go",
"bytes_io_unsafe.go",
+ "marshal.go",
"usermem.go",
],
visibility = ["//:sandbox"],
deps = [
"//pkg/atomicbitops",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/gohacks",
"//pkg/hostarch",
"//pkg/safemem",
@@ -29,6 +31,7 @@ go_test(
library = ":usermem",
deps = [
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/hostarch",
"//pkg/safemem",
"//pkg/syserror",
diff --git a/pkg/usermem/bytes_io.go b/pkg/usermem/bytes_io.go
index 3da3c0294..4c97b9136 100644
--- a/pkg/usermem/bytes_io.go
+++ b/pkg/usermem/bytes_io.go
@@ -16,6 +16,7 @@ package usermem
import (
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/syserror"
@@ -51,7 +52,7 @@ func (b *BytesIO) CopyIn(ctx context.Context, addr hostarch.Addr, dst []byte, op
// ZeroOut implements IO.ZeroOut.
func (b *BytesIO) ZeroOut(ctx context.Context, addr hostarch.Addr, toZero int64, opts IOOpts) (int64, error) {
if toZero > int64(maxInt) {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
rngN, rngErr := b.rangeCheck(addr, int(toZero))
if rngN == 0 {
@@ -89,7 +90,7 @@ func (b *BytesIO) rangeCheck(addr hostarch.Addr, length int) (int, error) {
return 0, nil
}
if length < 0 {
- return 0, syserror.EINVAL
+ return 0, linuxerr.EINVAL
}
max := hostarch.Addr(len(b.Bytes))
if addr >= max {
diff --git a/pkg/usermem/marshal.go b/pkg/usermem/marshal.go
new file mode 100644
index 000000000..5b5a662dc
--- /dev/null
+++ b/pkg/usermem/marshal.go
@@ -0,0 +1,43 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package usermem
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
+)
+
+// IOCopyContext wraps an object implementing hostarch.IO to implement
+// marshal.CopyContext.
+type IOCopyContext struct {
+ Ctx context.Context
+ IO IO
+ Opts IOOpts
+}
+
+// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer.
+func (i *IOCopyContext) CopyScratchBuffer(size int) []byte {
+ return make([]byte, size)
+}
+
+// CopyOutBytes implements marshal.CopyContext.CopyOutBytes.
+func (i *IOCopyContext) CopyOutBytes(addr hostarch.Addr, b []byte) (int, error) {
+ return i.IO.CopyOut(i.Ctx, addr, b, i.Opts)
+}
+
+// CopyInBytes implements marshal.CopyContext.CopyInBytes.
+func (i *IOCopyContext) CopyInBytes(addr hostarch.Addr, b []byte) (int, error) {
+ return i.IO.CopyIn(i.Ctx, addr, b, i.Opts)
+}
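// Editor's sketch (not part of the diff): IOCopyContext lets any IO
// implementation (a BytesIO in tests, or a task's address space in the
// sentry) back the marshal.CopyContext interface. The Marshallable value and
// the CopyOut signature below are assumptions based on how the marshal
// package is used elsewhere in the tree.
func copyOutExample(ctx context.Context, uio usermem.IO, addr hostarch.Addr, v marshal.Marshallable) (int, error) {
	cc := &usermem.IOCopyContext{
		Ctx:  ctx,
		IO:   uio,
		Opts: usermem.IOOpts{},
	}
	// v serializes itself into a scratch buffer obtained from cc, and the
	// bytes are written to addr through the wrapped IO.
	return v.CopyOut(cc, addr)
}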
diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go
index 0d6d25e50..483e64c1e 100644
--- a/pkg/usermem/usermem.go
+++ b/pkg/usermem/usermem.go
@@ -22,11 +22,11 @@ import (
"strconv"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/syserror"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
// IO provides access to the contents of a virtual memory space.
@@ -382,7 +382,7 @@ func CopyInt32StringsInVec(ctx context.Context, uio IO, ars hostarch.AddrRangeSe
// Parse a single value.
val, err := strconv.ParseInt(string(buf[i:nextI]), 10, 32)
if err != nil {
- return int64(i), syserror.EINVAL
+ return int64(i), linuxerr.EINVAL
}
dsts[j] = int32(val)
@@ -398,7 +398,7 @@ func CopyInt32StringsInVec(ctx context.Context, uio IO, ars hostarch.AddrRangeSe
return int64(i), cperr
}
if j == 0 {
- return int64(i), syserror.EINVAL
+ return int64(i), linuxerr.EINVAL
}
return int64(i), nil
}
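// Editor's sketch (not part of the diff): the syserror.EINVAL returns above
// become linuxerr.EINVAL, and the test hunk below compares errors with
// linuxerr.Equals rather than ==. As assumed here, Equals also treats the
// corresponding unix.Errno as a match, so call sites that still return
// legacy errno values keep passing during the migration.
func isEINVAL(err error) bool {
	return linuxerr.Equals(linuxerr.EINVAL, err)
}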
diff --git a/pkg/usermem/usermem_test.go b/pkg/usermem/usermem_test.go
index 9b697b593..4fedb57a5 100644
--- a/pkg/usermem/usermem_test.go
+++ b/pkg/usermem/usermem_test.go
@@ -22,6 +22,7 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/syserror"
@@ -272,8 +273,8 @@ func TestCopyInt32StringsInVecRequiresOneValidValue(t *testing.T) {
src := BytesIOSequence([]byte(s))
initial := []int32{1, 2}
dsts := append([]int32(nil), initial...)
- if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); err != syserror.EINVAL {
- t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (_, %v)", n, err, syserror.EINVAL)
+ if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); !linuxerr.Equals(linuxerr.EINVAL, err) {
+ t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (_, %v)", n, err, linuxerr.EINVAL)
}
if !reflect.DeepEqual(dsts, initial) {
t.Errorf("dsts: got %v, wanted %v", dsts, initial)