summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.buildkite/pipeline.yaml9
-rw-r--r--.github/workflows/stale.yml2
-rw-r--r--Makefile27
-rw-r--r--debian/BUILD3
-rw-r--r--g3doc/architecture_guide/platforms.md4
-rw-r--r--nogo.yaml2
-rw-r--r--pkg/abi/linux/fs.go3
-rw-r--r--pkg/abi/linux/netfilter.go19
-rw-r--r--pkg/abi/linux/netfilter_ipv6.go12
-rw-r--r--pkg/abi/linux/ptrace_amd64.go5
-rw-r--r--pkg/abi/linux/ptrace_arm64.go5
-rw-r--r--pkg/coverage/BUILD2
-rw-r--r--pkg/coverage/coverage.go103
-rw-r--r--pkg/gohacks/BUILD10
-rw-r--r--pkg/gohacks/gohacks_test.go97
-rw-r--r--pkg/gohacks/gohacks_unsafe.go14
-rw-r--r--pkg/hostarch/BUILD42
-rw-r--r--pkg/hostarch/access_type.go (renamed from pkg/usermem/access_type.go)2
-rw-r--r--pkg/hostarch/addr.go (renamed from pkg/usermem/addr.go)4
-rw-r--r--pkg/hostarch/addr_range_seq_test.go (renamed from pkg/usermem/addr_range_seq_test.go)2
-rw-r--r--pkg/hostarch/addr_range_seq_unsafe.go (renamed from pkg/usermem/addr_range_seq_unsafe.go)2
-rw-r--r--pkg/hostarch/hostarch.go7
-rw-r--r--pkg/hostarch/hostarch_arm64.go (renamed from pkg/usermem/usermem_arm64.go)2
-rw-r--r--pkg/hostarch/hostarch_x86.go (renamed from pkg/usermem/usermem_x86.go)2
-rw-r--r--pkg/marshal/BUILD2
-rw-r--r--pkg/marshal/marshal.go23
-rw-r--r--pkg/marshal/marshal_impl_util.go8
-rw-r--r--pkg/marshal/primitive/BUILD1
-rw-r--r--pkg/marshal/primitive/primitive.go53
-rw-r--r--pkg/merkletree/BUILD4
-rw-r--r--pkg/merkletree/merkletree.go6
-rw-r--r--pkg/merkletree/merkletree_test.go171
-rw-r--r--pkg/metric/metric.go175
-rw-r--r--pkg/metric/metric.proto11
-rw-r--r--pkg/metric/metric_test.go92
-rw-r--r--pkg/refsvfs2/refs_map.go25
-rw-r--r--pkg/ring0/BUILD2
-rw-r--r--pkg/ring0/defs_amd64.go4
-rw-r--r--pkg/ring0/defs_arm64.go4
-rw-r--r--pkg/ring0/gen_offsets/BUILD2
-rw-r--r--pkg/ring0/kernel_amd64.go35
-rw-r--r--pkg/ring0/kernel_arm64.go8
-rw-r--r--pkg/ring0/lib_amd64.go6
-rw-r--r--pkg/ring0/lib_amd64.s12
-rw-r--r--pkg/ring0/lib_arm64.go6
-rw-r--r--pkg/ring0/lib_arm64.s8
-rw-r--r--pkg/ring0/pagetables/BUILD4
-rw-r--r--pkg/ring0/pagetables/allocator_unsafe.go10
-rw-r--r--pkg/ring0/pagetables/pagetables.go18
-rw-r--r--pkg/ring0/pagetables/pagetables_aarch64.go6
-rw-r--r--pkg/ring0/pagetables/pagetables_amd64_test.go34
-rw-r--r--pkg/ring0/pagetables/pagetables_arm64_test.go42
-rw-r--r--pkg/ring0/pagetables/pagetables_test.go35
-rw-r--r--pkg/ring0/pagetables/pagetables_x86.go6
-rw-r--r--pkg/safecopy/atomic_amd64.s24
-rw-r--r--pkg/safecopy/atomic_arm64.s24
-rw-r--r--pkg/safecopy/memclr_amd64.s6
-rw-r--r--pkg/safecopy/memclr_arm64.s6
-rw-r--r--pkg/safecopy/memcpy_amd64.s6
-rw-r--r--pkg/safecopy/memcpy_arm64.s6
-rw-r--r--pkg/safecopy/safecopy.go22
-rw-r--r--pkg/safecopy/safecopy_test.go64
-rw-r--r--pkg/safecopy/safecopy_unsafe.go12
-rw-r--r--pkg/safecopy/sighandler_amd64.s6
-rw-r--r--pkg/safecopy/sighandler_arm64.s6
-rw-r--r--pkg/safemem/BUILD1
-rw-r--r--pkg/safemem/block_unsafe.go19
-rw-r--r--pkg/seccomp/BUILD2
-rw-r--r--pkg/seccomp/seccomp_test.go4
-rw-r--r--pkg/sentry/arch/BUILD1
-rw-r--r--pkg/sentry/arch/arch.go20
-rw-r--r--pkg/sentry/arch/arch_amd64.go32
-rw-r--r--pkg/sentry/arch/arch_arm64.go28
-rw-r--r--pkg/sentry/arch/auxv.go4
-rw-r--r--pkg/sentry/arch/fpu/BUILD2
-rw-r--r--pkg/sentry/arch/fpu/fpu_amd64.go21
-rw-r--r--pkg/sentry/arch/fpu/fpu_arm64.go2
-rw-r--r--pkg/sentry/arch/signal.go50
-rw-r--r--pkg/sentry/arch/signal_amd64.go8
-rw-r--r--pkg/sentry/arch/signal_arm64.go6
-rw-r--r--pkg/sentry/arch/signal_stack.go10
-rw-r--r--pkg/sentry/arch/stack.go44
-rw-r--r--pkg/sentry/arch/stack_unsafe.go6
-rw-r--r--pkg/sentry/devices/memdev/zero.go1
-rw-r--r--pkg/sentry/devices/tundev/BUILD1
-rw-r--r--pkg/sentry/devices/tundev/tundev.go5
-rw-r--r--pkg/sentry/fs/BUILD1
-rw-r--r--pkg/sentry/fs/anon/BUILD2
-rw-r--r--pkg/sentry/fs/anon/anon.go4
-rw-r--r--pkg/sentry/fs/copy_up.go3
-rw-r--r--pkg/sentry/fs/dev/BUILD1
-rw-r--r--pkg/sentry/fs/dev/dev.go11
-rw-r--r--pkg/sentry/fs/dev/net_tun.go5
-rw-r--r--pkg/sentry/fs/fdpipe/BUILD1
-rw-r--r--pkg/sentry/fs/fdpipe/pipe_test.go4
-rw-r--r--pkg/sentry/fs/fsutil/BUILD2
-rw-r--r--pkg/sentry/fs/fsutil/dirty_set.go4
-rw-r--r--pkg/sentry/fs/fsutil/dirty_set_test.go12
-rw-r--r--pkg/sentry/fs/fsutil/file_range_set.go8
-rw-r--r--pkg/sentry/fs/fsutil/host_file_mapper.go6
-rw-r--r--pkg/sentry/fs/fsutil/host_mappable.go13
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached.go21
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached_test.go51
-rw-r--r--pkg/sentry/fs/gofer/BUILD1
-rw-r--r--pkg/sentry/fs/gofer/attr.go4
-rw-r--r--pkg/sentry/fs/host/socket.go12
-rw-r--r--pkg/sentry/fs/inotify.go3
-rw-r--r--pkg/sentry/fs/inotify_event.go9
-rw-r--r--pkg/sentry/fs/offset.go4
-rw-r--r--pkg/sentry/fs/overlay.go10
-rw-r--r--pkg/sentry/fs/proc/BUILD1
-rw-r--r--pkg/sentry/fs/proc/exec_args.go7
-rw-r--r--pkg/sentry/fs/proc/inode.go4
-rw-r--r--pkg/sentry/fs/proc/meminfo.go4
-rw-r--r--pkg/sentry/fs/proc/net.go20
-rw-r--r--pkg/sentry/fs/proc/seqfile/BUILD1
-rw-r--r--pkg/sentry/fs/proc/seqfile/seqfile.go3
-rw-r--r--pkg/sentry/fs/proc/sys_net.go21
-rw-r--r--pkg/sentry/fs/proc/task.go13
-rw-r--r--pkg/sentry/fs/proc/uid_gid_map.go3
-rw-r--r--pkg/sentry/fs/ramfs/BUILD2
-rw-r--r--pkg/sentry/fs/ramfs/tree.go4
-rw-r--r--pkg/sentry/fs/sys/BUILD2
-rw-r--r--pkg/sentry/fs/sys/sys.go6
-rw-r--r--pkg/sentry/fs/timerfd/BUILD1
-rw-r--r--pkg/sentry/fs/timerfd/timerfd.go3
-rw-r--r--pkg/sentry/fs/tmpfs/BUILD2
-rw-r--r--pkg/sentry/fs/tmpfs/file_test.go3
-rw-r--r--pkg/sentry/fs/tmpfs/inode_file.go27
-rw-r--r--pkg/sentry/fs/tmpfs/tmpfs.go16
-rw-r--r--pkg/sentry/fs/tty/BUILD1
-rw-r--r--pkg/sentry/fs/tty/dir.go3
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/BUILD48
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/base.go261
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/cgroupfs.go425
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/cpu.go70
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/cpuacct.go114
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/cpuset.go39
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/job.go64
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/memory.go74
-rw-r--r--pkg/sentry/fsimpl/eventfd/BUILD1
-rw-r--r--pkg/sentry/fsimpl/eventfd/eventfd.go7
-rw-r--r--pkg/sentry/fsimpl/fuse/BUILD2
-rw-r--r--pkg/sentry/fsimpl/fuse/read_write.go20
-rw-r--r--pkg/sentry/fsimpl/fuse/request_response.go20
-rw-r--r--pkg/sentry/fsimpl/fuse/utils_test.go13
-rw-r--r--pkg/sentry/fsimpl/gofer/BUILD1
-rw-r--r--pkg/sentry/fsimpl/gofer/directory.go4
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go39
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go252
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer_test.go6
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go39
-rw-r--r--pkg/sentry/fsimpl/gofer/save_restore.go4
-rw-r--r--pkg/sentry/fsimpl/host/BUILD1
-rw-r--r--pkg/sentry/fsimpl/host/host.go8
-rw-r--r--pkg/sentry/fsimpl/host/save_restore.go11
-rw-r--r--pkg/sentry/fsimpl/host/socket.go19
-rw-r--r--pkg/sentry/fsimpl/kernfs/BUILD1
-rw-r--r--pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go10
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go4
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs.go10
-rw-r--r--pkg/sentry/fsimpl/kernfs/mmap_util.go14
-rw-r--r--pkg/sentry/fsimpl/overlay/BUILD1
-rw-r--r--pkg/sentry/fsimpl/overlay/copy_up.go6
-rw-r--r--pkg/sentry/fsimpl/overlay/regular_file.go9
-rw-r--r--pkg/sentry/fsimpl/pipefs/BUILD2
-rw-r--r--pkg/sentry/fsimpl/pipefs/pipefs.go4
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD1
-rw-r--r--pkg/sentry/fsimpl/proc/filesystem.go6
-rw-r--r--pkg/sentry/fsimpl/proc/task.go23
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go60
-rw-r--r--pkg/sentry/fsimpl/proc/task_net.go20
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go19
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_files.go20
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go11
-rw-r--r--pkg/sentry/fsimpl/proc/yama.go3
-rw-r--r--pkg/sentry/fsimpl/sys/sys.go8
-rw-r--r--pkg/sentry/fsimpl/testutil/BUILD1
-rw-r--r--pkg/sentry/fsimpl/testutil/testutil.go4
-rw-r--r--pkg/sentry/fsimpl/timerfd/BUILD1
-rw-r--r--pkg/sentry/fsimpl/timerfd/timerfd.go3
-rw-r--r--pkg/sentry/fsimpl/tmpfs/BUILD1
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go26
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go14
-rw-r--r--pkg/sentry/fsimpl/verity/BUILD3
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go102
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go314
-rw-r--r--pkg/sentry/fsimpl/verity/verity_test.go3
-rw-r--r--pkg/sentry/hostmm/BUILD2
-rw-r--r--pkg/sentry/hostmm/hostmm.go6
-rw-r--r--pkg/sentry/kernel/BUILD6
-rw-r--r--pkg/sentry/kernel/cgroup.go281
-rw-r--r--pkg/sentry/kernel/eventfd/BUILD1
-rw-r--r--pkg/sentry/kernel/eventfd/eventfd.go7
-rw-r--r--pkg/sentry/kernel/futex/BUILD3
-rw-r--r--pkg/sentry/kernel/futex/futex.go48
-rw-r--r--pkg/sentry/kernel/futex/futex_test.go16
-rw-r--r--pkg/sentry/kernel/kcov.go8
-rw-r--r--pkg/sentry/kernel/kernel.go52
-rw-r--r--pkg/sentry/kernel/pipe/BUILD1
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go6
-rw-r--r--pkg/sentry/kernel/pipe/vfs.go21
-rw-r--r--pkg/sentry/kernel/ptrace.go11
-rw-r--r--pkg/sentry/kernel/ptrace_amd64.go3
-rw-r--r--pkg/sentry/kernel/ptrace_arm64.go4
-rw-r--r--pkg/sentry/kernel/rseq.go35
-rw-r--r--pkg/sentry/kernel/seccomp.go10
-rw-r--r--pkg/sentry/kernel/shm/BUILD1
-rw-r--r--pkg/sentry/kernel/shm/shm.go30
-rw-r--r--pkg/sentry/kernel/syscalls.go8
-rw-r--r--pkg/sentry/kernel/task.go18
-rw-r--r--pkg/sentry/kernel/task_cgroup.go138
-rw-r--r--pkg/sentry/kernel/task_clone.go11
-rw-r--r--pkg/sentry/kernel/task_exit.go4
-rw-r--r--pkg/sentry/kernel/task_futex.go27
-rw-r--r--pkg/sentry/kernel/task_image.go4
-rw-r--r--pkg/sentry/kernel/task_log.go13
-rw-r--r--pkg/sentry/kernel/task_run.go8
-rw-r--r--pkg/sentry/kernel/task_signals.go16
-rw-r--r--pkg/sentry/kernel/task_start.go9
-rw-r--r--pkg/sentry/kernel/task_syscall.go12
-rw-r--r--pkg/sentry/kernel/task_usermem.go67
-rw-r--r--pkg/sentry/kernel/threads.go9
-rw-r--r--pkg/sentry/kernel/timekeeper_test.go4
-rw-r--r--pkg/sentry/kernel/vdso.go4
-rw-r--r--pkg/sentry/loader/BUILD1
-rw-r--r--pkg/sentry/loader/elf.go49
-rw-r--r--pkg/sentry/loader/loader.go11
-rw-r--r--pkg/sentry/loader/vdso.go25
-rw-r--r--pkg/sentry/memmap/BUILD3
-rw-r--r--pkg/sentry/memmap/mapping_set.go28
-rw-r--r--pkg/sentry/memmap/mapping_set_test.go65
-rw-r--r--pkg/sentry/memmap/memmap.go43
-rw-r--r--pkg/sentry/mm/BUILD14
-rw-r--r--pkg/sentry/mm/address_space.go10
-rw-r--r--pkg/sentry/mm/aio_context.go21
-rw-r--r--pkg/sentry/mm/io.go75
-rw-r--r--pkg/sentry/mm/lifecycle.go4
-rw-r--r--pkg/sentry/mm/metadata.go18
-rw-r--r--pkg/sentry/mm/mm.go24
-rw-r--r--pkg/sentry/mm/mm_test.go35
-rw-r--r--pkg/sentry/mm/pma.go74
-rw-r--r--pkg/sentry/mm/procfs.go20
-rw-r--r--pkg/sentry/mm/shm.go6
-rw-r--r--pkg/sentry/mm/special_mappable.go14
-rw-r--r--pkg/sentry/mm/syscalls.go106
-rw-r--r--pkg/sentry/mm/vma.go86
-rw-r--r--pkg/sentry/pgalloc/BUILD3
-rw-r--r--pkg/sentry/pgalloc/pgalloc.go34
-rw-r--r--pkg/sentry/pgalloc/pgalloc_test.go6
-rw-r--r--pkg/sentry/pgalloc/save_restore.go10
-rw-r--r--pkg/sentry/platform/BUILD1
-rw-r--r--pkg/sentry/platform/kvm/BUILD5
-rw-r--r--pkg/sentry/platform/kvm/address_space.go18
-rw-r--r--pkg/sentry/platform/kvm/bluepill.go13
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64.s12
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.s12
-rw-r--r--pkg/sentry/platform/kvm/bluepill_fault.go4
-rw-r--r--pkg/sentry/platform/kvm/context.go6
-rw-r--r--pkg/sentry/platform/kvm/kvm.go10
-rw-r--r--pkg/sentry/platform/kvm/kvm_amd64_test.go37
-rw-r--r--pkg/sentry/platform/kvm/kvm_amd64_test.s (renamed from pkg/tcpip/transport/tcp/cubic_state.go)22
-rw-r--r--pkg/sentry/platform/kvm/kvm_test.go8
-rw-r--r--pkg/sentry/platform/kvm/machine.go8
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go46
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64.go18
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go14
-rw-r--r--pkg/sentry/platform/kvm/physical_map.go10
-rw-r--r--pkg/sentry/platform/kvm/virtual_map.go6
-rw-r--r--pkg/sentry/platform/kvm/virtual_map_test.go14
-rw-r--r--pkg/sentry/platform/mmap_min_addr.go8
-rw-r--r--pkg/sentry/platform/platform.go43
-rw-r--r--pkg/sentry/platform/ptrace/BUILD2
-rw-r--r--pkg/sentry/platform/ptrace/ptrace.go26
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_unsafe.go4
-rw-r--r--pkg/sentry/platform/ptrace/stub_amd64.s6
-rw-r--r--pkg/sentry/platform/ptrace/stub_arm64.s6
-rw-r--r--pkg/sentry/platform/ptrace/stub_unsafe.go17
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go64
-rw-r--r--pkg/sentry/socket/BUILD1
-rw-r--r--pkg/sentry/socket/control/BUILD4
-rw-r--r--pkg/sentry/socket/control/control.go26
-rw-r--r--pkg/sentry/socket/control/control_test.go6
-rw-r--r--pkg/sentry/socket/hostinet/BUILD1
-rw-r--r--pkg/sentry/socket/hostinet/socket.go19
-rw-r--r--pkg/sentry/socket/hostinet/socket_unsafe.go3
-rw-r--r--pkg/sentry/socket/hostinet/stack.go10
-rw-r--r--pkg/sentry/socket/netfilter/BUILD2
-rw-r--r--pkg/sentry/socket/netfilter/extensions.go4
-rw-r--r--pkg/sentry/socket/netfilter/ipv4.go4
-rw-r--r--pkg/sentry/socket/netfilter/ipv6.go4
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go20
-rw-r--r--pkg/sentry/socket/netfilter/owner_matcher.go6
-rw-r--r--pkg/sentry/socket/netfilter/targets.go212
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go6
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go6
-rw-r--r--pkg/sentry/socket/netlink/BUILD1
-rw-r--r--pkg/sentry/socket/netlink/message.go14
-rw-r--r--pkg/sentry/socket/netlink/socket.go9
-rw-r--r--pkg/sentry/socket/netstack/BUILD1
-rw-r--r--pkg/sentry/socket/netstack/netstack.go112
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go7
-rw-r--r--pkg/sentry/socket/socket.go21
-rw-r--r--pkg/sentry/socket/unix/BUILD1
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go10
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned_state.go2
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go3
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless_state.go2
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go31
-rw-r--r--pkg/sentry/socket/unix/unix.go3
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go3
-rw-r--r--pkg/sentry/strace/BUILD2
-rw-r--r--pkg/sentry/strace/epoll.go7
-rw-r--r--pkg/sentry/strace/poll.go5
-rw-r--r--pkg/sentry/strace/select.go5
-rw-r--r--pkg/sentry/strace/signal.go9
-rw-r--r--pkg/sentry/strace/socket.go33
-rw-r--r--pkg/sentry/strace/strace.go39
-rw-r--r--pkg/sentry/syscalls/linux/BUILD1
-rw-r--r--pkg/sentry/syscalls/linux/error.go12
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go6
-rw-r--r--pkg/sentry/syscalls/linux/sigset.go16
-rw-r--r--pkg/sentry/syscalls/linux/sys_aio.go17
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go34
-rw-r--r--pkg/sentry/syscalls/linux/sys_futex.go12
-rw-r--r--pkg/sentry/syscalls/linux/sys_getdents.go3
-rw-r--r--pkg/sentry/syscalls/linux/sys_mempolicy.go15
-rw-r--r--pkg/sentry/syscalls/linux/sys_mmap.go15
-rw-r--r--pkg/sentry/syscalls/linux/sys_mount.go7
-rw-r--r--pkg/sentry/syscalls/linux/sys_pipe.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_poll.go18
-rw-r--r--pkg/sentry/syscalls/linux/sys_random.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_rlimit.go6
-rw-r--r--pkg/sentry/syscalls/linux/sys_seccomp.go10
-rw-r--r--pkg/sentry/syscalls/linux/sys_sem.go6
-rw-r--r--pkg/sentry/syscalls/linux/sys_signal.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go50
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat.go10
-rw-r--r--pkg/sentry/syscalls/linux/sys_thread.go9
-rw-r--r--pkg/sentry/syscalls/linux/sys_time.go12
-rw-r--r--pkg/sentry/syscalls/linux/sys_xattr.go12
-rw-r--r--pkg/sentry/syscalls/linux/timespec.go28
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/BUILD1
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/aio.go16
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/execve.go5
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/filesystem.go19
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/getdents.go21
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/mmap.go7
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/mount.go9
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/path.go5
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/pipe.go5
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/poll.go25
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/setstat.go11
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/signal.go5
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/socket.go51
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/stat.go9
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/xattr.go11
-rw-r--r--pkg/sentry/time/BUILD1
-rw-r--r--pkg/sentry/vfs/BUILD1
-rw-r--r--pkg/sentry/vfs/anonfs.go4
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go3
-rw-r--r--pkg/sentry/vfs/filesystem_impl_util.go4
-rw-r--r--pkg/sentry/vfs/inotify.go11
-rw-r--r--pkg/sentry/vfs/mount.go17
-rw-r--r--pkg/sync/BUILD1
-rw-r--r--pkg/sync/generic_seqatomic_unsafe.go3
-rw-r--r--pkg/sync/runtime_unsafe.go14
-rw-r--r--pkg/sync/seqatomictest/BUILD1
-rw-r--r--pkg/tcpip/BUILD31
-rw-r--r--pkg/tcpip/checker/checker.go20
-rw-r--r--pkg/tcpip/hash/jenkins/jenkins.go20
-rw-r--r--pkg/tcpip/header/BUILD2
-rw-r--r--pkg/tcpip/header/eth_test.go3
-rw-r--r--pkg/tcpip/header/igmp_test.go6
-rw-r--r--pkg/tcpip/header/ipv4.go58
-rw-r--r--pkg/tcpip/header/ipv4_test.go75
-rw-r--r--pkg/tcpip/header/ipv6.go87
-rw-r--r--pkg/tcpip/header/ipv6_test.go104
-rw-r--r--pkg/tcpip/header/ndp_test.go11
-rw-r--r--pkg/tcpip/header/tcp.go53
-rw-r--r--pkg/tcpip/header/udp.go27
-rw-r--r--pkg/tcpip/network/BUILD1
-rw-r--r--pkg/tcpip/network/arp/BUILD1
-rw-r--r--pkg/tcpip/network/arp/arp_test.go14
-rw-r--r--pkg/tcpip/network/internal/ip/generic_multicast_protocol.go57
-rw-r--r--pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go27
-rw-r--r--pkg/tcpip/network/internal/ip/stats.go46
-rw-r--r--pkg/tcpip/network/ip_test.go104
-rw-r--r--pkg/tcpip/network/ipv4/BUILD1
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go12
-rw-r--r--pkg/tcpip/network/ipv4/igmp_test.go20
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go124
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go117
-rw-r--r--pkg/tcpip/network/ipv6/BUILD1
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go4
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go105
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go109
-rw-r--r--pkg/tcpip/network/ipv6/mld.go22
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go157
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go10
-rw-r--r--pkg/tcpip/network/multicast_group_test.go30
-rw-r--r--pkg/tcpip/ports/BUILD1
-rw-r--r--pkg/tcpip/ports/ports.go26
-rw-r--r--pkg/tcpip/ports/ports_test.go36
-rw-r--r--pkg/tcpip/socketops.go58
-rw-r--r--pkg/tcpip/stack/BUILD4
-rw-r--r--pkg/tcpip/stack/conntrack.go234
-rw-r--r--pkg/tcpip/stack/hook_string.go41
-rw-r--r--pkg/tcpip/stack/iptables.go7
-rw-r--r--pkg/tcpip/stack/iptables_targets.go78
-rw-r--r--pkg/tcpip/stack/ndp_test.go20
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go8
-rw-r--r--pkg/tcpip/stack/packet_buffer.go11
-rw-r--r--pkg/tcpip/stack/route.go2
-rw-r--r--pkg/tcpip/stack/stack.go309
-rw-r--r--pkg/tcpip/stack/stack_global_state.go72
-rw-r--r--pkg/tcpip/stack/stack_options.go4
-rw-r--r--pkg/tcpip/stack/stack_test.go71
-rw-r--r--pkg/tcpip/stack/tcp.go451
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go17
-rw-r--r--pkg/tcpip/stack/transport_test.go4
-rw-r--r--pkg/tcpip/tcpip.go59
-rw-r--r--pkg/tcpip/tests/integration/BUILD6
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go194
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go13
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go16
-rw-r--r--pkg/tcpip/tests/integration/route_test.go5
-rw-r--r--pkg/tcpip/tests/utils/utils.go8
-rw-r--r--pkg/tcpip/testutil/BUILD18
-rw-r--r--pkg/tcpip/testutil/testutil.go43
-rw-r--r--pkg/tcpip/testutil/testutil_test.go103
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go53
-rw-r--r--pkg/tcpip/transport/icmp/endpoint_state.go33
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go74
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go25
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go76
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go33
-rw-r--r--pkg/tcpip/transport/tcp/BUILD3
-rw-r--r--pkg/tcpip/transport/tcp/accept.go291
-rw-r--r--pkg/tcpip/transport/tcp/connect.go134
-rw-r--r--pkg/tcpip/transport/tcp/cubic.go119
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go2
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go14
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go880
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go82
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go74
-rw-r--r--pkg/tcpip/transport/tcp/rack.go129
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go173
-rw-r--r--pkg/tcpip/transport/tcp/reno.go30
-rw-r--r--pkg/tcpip/transport/tcp/reno_recovery.go14
-rw-r--r--pkg/tcpip/transport/tcp/sack_recovery.go18
-rw-r--r--pkg/tcpip/transport/tcp/segment.go14
-rw-r--r--pkg/tcpip/transport/tcp/segment_queue.go4
-rw-r--r--pkg/tcpip/transport/tcp/snd.go440
-rw-r--r--pkg/tcpip/transport/tcp/snd_state.go20
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go10
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go14
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go243
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go8
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go4
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go102
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go34
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go3
-rw-r--r--pkg/test/dockerutil/BUILD2
-rw-r--r--pkg/test/dockerutil/container.go9
-rw-r--r--pkg/usermem/BUILD24
-rw-r--r--pkg/usermem/bytes_io.go21
-rw-r--r--pkg/usermem/bytes_io_unsafe.go7
-rw-r--r--pkg/usermem/usermem.go52
-rw-r--r--pkg/usermem/usermem_test.go9
-rw-r--r--runsc/BUILD2
-rw-r--r--runsc/boot/BUILD3
-rw-r--r--runsc/boot/controller.go2
-rw-r--r--runsc/boot/fs.go48
-rw-r--r--runsc/boot/loader.go16
-rw-r--r--runsc/boot/loader_test.go17
-rw-r--r--runsc/boot/vfs.go94
-rw-r--r--runsc/cli/BUILD2
-rw-r--r--runsc/cli/main.go11
-rw-r--r--runsc/cmd/BUILD2
-rw-r--r--runsc/cmd/do.go108
-rw-r--r--runsc/cmd/gofer.go6
-rw-r--r--runsc/cmd/mitigate.go49
-rw-r--r--runsc/cmd/mitigate_extras.go (renamed from pkg/tcpip/transport/tcp/rack_state.go)22
-rw-r--r--runsc/cmd/mitigate_test.go7
-rw-r--r--runsc/cmd/symbolize.go6
-rw-r--r--runsc/cmd/verity_prepare.go108
-rw-r--r--runsc/config/config.go6
-rw-r--r--runsc/config/flags.go4
-rw-r--r--runsc/container/BUILD4
-rw-r--r--runsc/container/container.go2
-rw-r--r--runsc/mitigate/mitigate.go5
-rw-r--r--runsc/mitigate/mitigate_test.go13
-rw-r--r--runsc/mitigate/mock/mock.go31
-rw-r--r--runsc/sandbox/BUILD1
-rw-r--r--runsc/sandbox/sandbox.go45
-rw-r--r--runsc/specutils/fs.go18
-rw-r--r--runsc/specutils/specutils.go16
-rw-r--r--shim/BUILD1
-rw-r--r--test/benchmarks/base/BUILD3
-rw-r--r--test/benchmarks/database/BUILD1
-rw-r--r--test/benchmarks/fs/BUILD2
-rw-r--r--test/benchmarks/media/BUILD1
-rw-r--r--test/benchmarks/ml/BUILD1
-rw-r--r--test/benchmarks/network/BUILD5
-rw-r--r--test/e2e/BUILD3
-rw-r--r--test/e2e/integration_test.go77
-rw-r--r--test/e2e/regression_test.go47
-rw-r--r--test/fsstress/BUILD4
-rw-r--r--test/fsstress/fsstress_test.go43
-rw-r--r--test/image/image_test.go5
-rw-r--r--test/iptables/BUILD2
-rw-r--r--test/iptables/iptables_test.go8
-rw-r--r--test/iptables/iptables_util.go61
-rw-r--r--test/iptables/nat.go128
-rw-r--r--test/packetdrill/BUILD1
-rw-r--r--test/packetimpact/runner/defs.bzl6
-rw-r--r--test/packetimpact/runner/dut.go30
-rw-r--r--test/packetimpact/testbench/BUILD2
-rw-r--r--test/packetimpact/testbench/rawsockets.go4
-rw-r--r--test/packetimpact/tests/BUILD27
-rw-r--r--test/packetimpact/tests/tcp_info_test.go6
-rw-r--r--test/packetimpact/tests/tcp_listen_backlog_test.go86
-rw-r--r--test/packetimpact/tests/tcp_rack_test.go6
-rw-r--r--test/packetimpact/tests/tcp_retransmits_test.go4
-rw-r--r--test/packetimpact/tests/tcp_syncookie_test.go70
-rw-r--r--test/perf/BUILD9
-rw-r--r--test/perf/linux/getpid_benchmark.cc18
-rw-r--r--test/perf/linux/write_benchmark.cc12
-rw-r--r--test/runner/defs.bzl105
-rw-r--r--test/runner/runner.go1
-rw-r--r--test/runtimes/defs.bzl1
-rw-r--r--test/syscalls/BUILD18
-rw-r--r--test/syscalls/linux/32bit.cc8
-rw-r--r--test/syscalls/linux/BUILD56
-rw-r--r--test/syscalls/linux/accept_bind.cc70
-rw-r--r--test/syscalls/linux/alarm.cc8
-rw-r--r--test/syscalls/linux/cgroup.cc452
-rw-r--r--test/syscalls/linux/chmod.cc10
-rw-r--r--test/syscalls/linux/dev.cc2
-rw-r--r--test/syscalls/linux/epoll.cc4
-rw-r--r--test/syscalls/linux/eventfd.cc2
-rw-r--r--test/syscalls/linux/flock.cc12
-rw-r--r--test/syscalls/linux/fpsig_fork.cc57
-rw-r--r--test/syscalls/linux/futex.cc34
-rw-r--r--test/syscalls/linux/inotify.cc18
-rw-r--r--test/syscalls/linux/itimer.cc6
-rw-r--r--test/syscalls/linux/open.cc4
-rw-r--r--test/syscalls/linux/open_create.cc10
-rw-r--r--test/syscalls/linux/packet_socket_raw.cc6
-rw-r--r--test/syscalls/linux/partial_bad_buffer.cc2
-rw-r--r--test/syscalls/linux/ping_socket.cc2
-rw-r--r--test/syscalls/linux/pipe.cc2
-rw-r--r--test/syscalls/linux/poll.cc2
-rw-r--r--test/syscalls/linux/ppoll.cc2
-rw-r--r--test/syscalls/linux/pread64.cc2
-rw-r--r--test/syscalls/linux/proc.cc2
-rw-r--r--test/syscalls/linux/proc_net.cc26
-rw-r--r--test/syscalls/linux/proc_net_unix.cc83
-rw-r--r--test/syscalls/linux/proc_pid_uid_gid_map.cc10
-rw-r--r--test/syscalls/linux/pselect.cc2
-rw-r--r--test/syscalls/linux/ptrace.cc5
-rw-r--r--test/syscalls/linux/raw_socket.cc12
-rw-r--r--test/syscalls/linux/read.cc2
-rw-r--r--test/syscalls/linux/readv.cc2
-rw-r--r--test/syscalls/linux/select.cc4
-rw-r--r--test/syscalls/linux/semaphore.cc14
-rw-r--r--test/syscalls/linux/sendfile.cc4
-rw-r--r--test/syscalls/linux/sigtimedwait.cc4
-rw-r--r--test/syscalls/linux/socket.cc6
-rw-r--r--test/syscalls/linux/socket_bind_to_device_distribution.cc48
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc1032
-rw-r--r--test/syscalls/linux/socket_inet_loopback_nogotsan.cc23
-rw-r--r--test/syscalls/linux/socket_ip_tcp_generic.cc2
-rw-r--r--test/syscalls/linux/socket_ip_unbound_netlink.cc10
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound.cc620
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc324
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_loopback_nogotsan.cc17
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc49
-rw-r--r--test/syscalls/linux/socket_ipv6_udp_unbound.cc20
-rw-r--r--test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc29
-rw-r--r--test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc6
-rw-r--r--test/syscalls/linux/socket_stream_blocking.cc4
-rw-r--r--test/syscalls/linux/socket_test_util.cc46
-rw-r--r--test/syscalls/linux/socket_test_util.h14
-rw-r--r--test/syscalls/linux/socket_unix_non_stream.cc2
-rw-r--r--test/syscalls/linux/splice.cc2
-rw-r--r--test/syscalls/linux/symlink.cc4
-rw-r--r--test/syscalls/linux/tcp_socket.cc190
-rw-r--r--test/syscalls/linux/timerfd.cc2
-rw-r--r--test/syscalls/linux/truncate.cc4
-rw-r--r--test/syscalls/linux/tuntap.cc15
-rw-r--r--test/syscalls/linux/udp_bind.cc41
-rw-r--r--test/syscalls/linux/udp_socket.cc253
-rw-r--r--test/syscalls/linux/unlink.cc4
-rw-r--r--test/syscalls/linux/verity_ioctl.cc188
-rw-r--r--test/syscalls/linux/verity_mount.cc53
-rw-r--r--test/syscalls/linux/vfork.cc4
-rw-r--r--test/syscalls/linux/xattr.cc4
-rw-r--r--test/util/BUILD18
-rw-r--r--test/util/cgroup_util.cc236
-rw-r--r--test/util/cgroup_util.h119
-rw-r--r--test/util/fs_util.cc44
-rw-r--r--test/util/fs_util.h12
-rw-r--r--test/util/save_util.cc26
-rw-r--r--tools/BUILD8
-rw-r--r--tools/bazeldefs/go.bzl2
-rw-r--r--tools/bigquery/BUILD1
-rw-r--r--tools/bigquery/bigquery.go8
-rw-r--r--tools/deps.bzl114
-rw-r--r--tools/go_marshal/defs.bzl3
-rw-r--r--tools/go_marshal/gomarshal/generator.go4
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go24
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go14
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_dynamic.go12
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go60
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_struct.go51
-rw-r--r--tools/go_marshal/gomarshal/generator_tests.go6
-rw-r--r--tools/go_marshal/test/BUILD3
-rw-r--r--tools/go_marshal/test/benchmark_test.go80
-rw-r--r--tools/go_marshal/test/escape/BUILD2
-rw-r--r--tools/go_marshal/test/escape/escape.go22
-rw-r--r--tools/go_marshal/test/marshal_test.go57
-rw-r--r--tools/nogo/analyzers.go6
-rw-r--r--tools/nogo/check/main.go17
-rw-r--r--tools/nogo/defs.bzl46
-rw-r--r--website/BUILD4
628 files changed, 13268 insertions, 7371 deletions
diff --git a/.buildkite/pipeline.yaml b/.buildkite/pipeline.yaml
index 3bc5041c0..c1b478dc3 100644
--- a/.buildkite/pipeline.yaml
+++ b/.buildkite/pipeline.yaml
@@ -55,6 +55,9 @@ steps:
# Basic unit tests.
- <<: *common
+ label: ":golang: Nogo tests"
+ command: make nogo-tests
+ - <<: *common
label: ":test_tube: Unit tests"
command: make unit-tests
- <<: *common
@@ -69,9 +72,6 @@ steps:
# Integration tests.
- <<: *common
- label: ":parachute: FUSE tests"
- command: make fuse-tests
- - <<: *common
label: ":docker: Docker tests"
command: make docker-tests
- <<: *common
@@ -90,6 +90,9 @@ steps:
label: ":person_in_lotus_position: KVM tests"
command: make kvm-tests
- <<: *common
+ label: ":weight_lifter: Fsstress test"
+ command: make fsstress-test
+ - <<: *common
label: ":docker: Containerd 1.3.9 tests"
command: make containerd-test-1.3.9
- <<: *common
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
index 3a4aa22e2..a9e0a4717 100644
--- a/.github/workflows/stale.yml
+++ b/.github/workflows/stale.yml
@@ -15,7 +15,7 @@ jobs:
stale-issue-label: 'stale'
stale-pr-label: 'stale'
exempt-issue-labels: 'exported, type: bug, type: cleanup, type: enhancement, type: process, type: proposal, type: question'
- exempt-pr-labels: 'ready to pull'
+ exempt-pr-labels: 'ready to pull, exported'
stale-issue-message: 'This issue is stale because it has been open 90 days with no activity. Remove the stale label or comment or this will be closed in 30 days.'
stale-pr-message: 'This pull request is stale because it has been open 90 days with no activity. Remove the stale label or comment or this will be closed in 30 days.'
days-before-stale: 90
diff --git a/Makefile b/Makefile
index 0f79b6a18..e32d1b99e 100644
--- a/Makefile
+++ b/Makefile
@@ -144,6 +144,7 @@ dev: $(RUNTIME_BIN) ## Installs a set of local runtimes. Requires sudo.
@$(call configure_noreload,$(RUNTIME)-p,--net-raw --profile)
@$(call configure_noreload,$(RUNTIME)-vfs2-d,--net-raw --debug --strace --log-packets --vfs2)
@$(call configure_noreload,$(RUNTIME)-vfs2-fuse-d,--net-raw --debug --strace --log-packets --vfs2 --fuse)
+ @$(call configure_noreload,$(RUNTIME)-vfs2-cgroup-d,--net-raw --debug --strace --log-packets --vfs2 --cgroupfs)
@$(call reload_docker)
.PHONY: dev
@@ -179,12 +180,12 @@ smoke-tests: ## Runs a simple smoke test after build runsc.
@$(call run,//runsc,--alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do true)
.PHONY: smoke-tests
-fuse-tests:
- @$(call test,--test_tag_filters=fuse $(PARTITIONS) test/fuse/...)
-.PHONY: fuse-tests
+nogo-tests:
+ @$(call test,--build_tag_filters=nogo --test_tag_filters=nogo //:all pkg/... tools/...)
+.PHONY: nogo-tests
unit-tests: ## Local package unit tests in pkg/..., tools/.., etc.
- @$(call test,//:all pkg/... tools/...)
+ @$(call test,--build_tag_filters=-nogo --test_tag_filters=-nogo //:all pkg/... tools/...)
.PHONY: unit-tests
runsc-tests: ## Run all tests in runsc/...
@@ -192,7 +193,7 @@ runsc-tests: ## Run all tests in runsc/...
.PHONY: runsc-tests
tests: ## Runs all unit tests and syscall tests.
-tests: unit-tests runsc-tests syscall-tests
+tests: unit-tests nogo-tests runsc-tests syscall-tests
.PHONY: tests
integration-tests: ## Run all standard integration tests.
@@ -204,6 +205,9 @@ network-tests: ## Run all networking integration tests.
network-tests: iptables-tests packetdrill-tests packetimpact-tests
.PHONY: network-tests
+# The set of system call targets.
+SYSCALL_TARGETS := test/syscalls/... test/fuse/...
+
syscall-%-tests:
@$(call test,--test_tag_filters=runsc_$* $(PARTITIONS) test/syscalls/...)
@@ -212,7 +216,8 @@ syscall-native-tests:
.PHONY: syscall-native-tests
syscall-tests: ## Run all system call tests.
- @$(call test,$(PARTITIONS) test/syscalls/...)
+ @$(call test,$(PARTITIONS) $(SYSCALL_TARGETS))
+.PHONY: syscall-tests
%-runtime-tests: load-runtimes_% $(RUNTIME_BIN)
@$(call install_runtime,$(RUNTIME),) # Ensure flags are cleared.
@@ -340,7 +345,8 @@ BENCHMARKS_FILTER := .
BENCHMARKS_OPTIONS := -test.benchtime=30s
BENCHMARKS_ARGS := -test.v -test.bench=$(BENCHMARKS_FILTER) $(BENCHMARKS_OPTIONS)
BENCHMARKS_PROFILE := -pprof-dir=/tmp/profile -pprof-cpu -pprof-heap -pprof-block -pprof-mutex
-BENCH_RUNTIME_ARGS ?= --vfs2
+BENCH_VFS := --vfs2
+BENCH_RUNTIME_ARGS ?=
init-benchmark-table: ## Initializes a BigQuery table with the benchmark schema.
@$(call run,//tools/parsers:parser,init --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE))
@@ -361,13 +367,14 @@ run_benchmark = \
benchmark-platforms: load-benchmarks $(RUNTIME_BIN) ## Runs benchmarks for runc and all given platforms in BENCHMARK_PLATFORMS.
@$(foreach PLATFORM,$(BENCHMARKS_PLATFORMS), \
- $(call run_benchmark,$(PLATFORM),--platform=$(PLATFORM) $(BENCH_RUNTIME_ARGS)) && \
- ) true
+ $(call run_benchmark,$(PLATFORM),--platform=$(PLATFORM) $(BENCH_RUNTIME_ARGS) --vfs2) && \
+ $(call run_benchmark,$(PLATFORM)_vfs1,--platform=$(PLATFORM) $(BENCH_RUNTIME_ARGS)) && \
+ ) true
@$(call run_benchmark,runc)
.PHONY: benchmark-platforms
run-benchmark: load-benchmarks $(RUNTIME_BIN) ## Runs single benchmark and optionally sends data to BigQuery.
- @$(call run_benchmark,$(RUNTIME),$(BENCH_RUNTIME_ARGS))
+ @$(call run_benchmark,$(RUNTIME)$(BENCH_VFS),$(BENCH_RUNTIME_ARGS) $(BENCH_VFS))
.PHONY: run-benchmark
##
diff --git a/debian/BUILD b/debian/BUILD
index 64aa2369a..32cc209bf 100644
--- a/debian/BUILD
+++ b/debian/BUILD
@@ -29,6 +29,9 @@ pkg_deb(
arm64 = "arm64",
),
changes = "runsc.changes",
+ conffiles = [
+ "/etc/containerd/runsc.toml",
+ ],
data = ":debian-data",
deb = "runsc.deb",
# Note that the description_file will be flatten (all newlines removed),
diff --git a/g3doc/architecture_guide/platforms.md b/g3doc/architecture_guide/platforms.md
index d112c9a28..e19c77236 100644
--- a/g3doc/architecture_guide/platforms.md
+++ b/g3doc/architecture_guide/platforms.md
@@ -18,8 +18,8 @@ type Context interface {
}
type AddressSpace interface {
- MapFile(addr usermem.Addr, f File, fr FileRange, at usermem.AccessType, ...) error
- Unmap(addr usermem.Addr, length uint64)
+ MapFile(addr hostarch.Addr, f File, fr FileRange, at hostarch.AccessType, ...) error
+ Unmap(addr hostarch.Addr, length uint64)
}
```
diff --git a/nogo.yaml b/nogo.yaml
index c0445a837..1e72d9e29 100644
--- a/nogo.yaml
+++ b/nogo.yaml
@@ -55,8 +55,6 @@ global:
# Same story for underscores.
- "should not use ALL_CAPS in Go names"
- "should not use underscores in Go names"
- # TODO(b/179817829): Upgrade to flock to v0.8.0.
- - "flock.NewFlock is deprecated: Use New instead"
exclude:
# Generated: exempt all.
- pkg/shim/runtimeoptions/runtimeoptions_cri.go
diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go
index 0d921ed6f..cad24fcc7 100644
--- a/pkg/abi/linux/fs.go
+++ b/pkg/abi/linux/fs.go
@@ -19,8 +19,10 @@ package linux
// See linux/magic.h.
const (
ANON_INODE_FS_MAGIC = 0x09041934
+ CGROUP_SUPER_MAGIC = 0x27e0eb
DEVPTS_SUPER_MAGIC = 0x00001cd1
EXT_SUPER_MAGIC = 0xef53
+ FUSE_SUPER_MAGIC = 0x65735546
OVERLAYFS_SUPER_MAGIC = 0x794c7630
PIPEFS_MAGIC = 0x50495045
PROC_SUPER_MAGIC = 0x9fa0
@@ -29,7 +31,6 @@ const (
SYSFS_MAGIC = 0x62656572
TMPFS_MAGIC = 0x01021994
V9FS_MAGIC = 0x01021997
- FUSE_SUPER_MAGIC = 0x65735546
)
// Filesystem path limits, from uapi/linux/limits.h.
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index 378f1baf3..35c632168 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -145,13 +145,13 @@ func (ke *KernelIPTEntry) SizeBytes() int {
// MarshalBytes implements marshal.Marshallable.MarshalBytes.
func (ke *KernelIPTEntry) MarshalBytes(dst []byte) {
- ke.Entry.MarshalBytes(dst)
+ ke.Entry.MarshalUnsafe(dst)
ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():])
}
// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) {
- ke.Entry.UnmarshalBytes(src)
+ ke.Entry.UnmarshalUnsafe(src)
ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():])
}
@@ -375,6 +375,17 @@ type XTRedirectTarget struct {
// SizeOfXTRedirectTarget is the size of an XTRedirectTarget.
const SizeOfXTRedirectTarget = 56
+// XTSNATTarget triggers Source NAT when reached.
+// Adding 4 bytes of padding to make the struct 8 byte aligned.
+type XTSNATTarget struct {
+ Target XTEntryTarget
+ NfRange NfNATIPV4MultiRangeCompat
+ _ [4]byte
+}
+
+// SizeOfXTSNATTarget is the size of an XTSNATTarget.
+const SizeOfXTSNATTarget = 56
+
// IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds
// to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h.
//
@@ -429,7 +440,7 @@ func (ke *KernelIPTGetEntries) SizeBytes() int {
// MarshalBytes implements marshal.Marshallable.MarshalBytes.
func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) {
- ke.IPTGetEntries.MarshalBytes(dst)
+ ke.IPTGetEntries.MarshalUnsafe(dst)
marshalledUntil := ke.IPTGetEntries.SizeBytes()
for i := range ke.Entrytable {
ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:])
@@ -439,7 +450,7 @@ func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) {
// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) {
- ke.IPTGetEntries.UnmarshalBytes(src)
+ ke.IPTGetEntries.UnmarshalUnsafe(src)
unmarshalledUntil := ke.IPTGetEntries.SizeBytes()
for i := range ke.Entrytable {
ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:])
diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go
index b953e62dc..f7c70b430 100644
--- a/pkg/abi/linux/netfilter_ipv6.go
+++ b/pkg/abi/linux/netfilter_ipv6.go
@@ -86,7 +86,7 @@ func (ke *KernelIP6TGetEntries) SizeBytes() int {
// MarshalBytes implements marshal.Marshallable.MarshalBytes.
func (ke *KernelIP6TGetEntries) MarshalBytes(dst []byte) {
- ke.IPTGetEntries.MarshalBytes(dst)
+ ke.IPTGetEntries.MarshalUnsafe(dst)
marshalledUntil := ke.IPTGetEntries.SizeBytes()
for i := range ke.Entrytable {
ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:])
@@ -96,7 +96,7 @@ func (ke *KernelIP6TGetEntries) MarshalBytes(dst []byte) {
// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
func (ke *KernelIP6TGetEntries) UnmarshalBytes(src []byte) {
- ke.IPTGetEntries.UnmarshalBytes(src)
+ ke.IPTGetEntries.UnmarshalUnsafe(src)
unmarshalledUntil := ke.IPTGetEntries.SizeBytes()
for i := range ke.Entrytable {
ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:])
@@ -149,8 +149,8 @@ type IP6TEntry struct {
const SizeOfIP6TEntry = 168
// KernelIP6TEntry is identical to IP6TEntry, but includes the Elems field.
-// KernelIP6TEntry itself is not Marshallable but it implements some methods of
-// marshal.Marshallable that help in other implementations of Marshallable.
+//
+// +marshal dynamic
type KernelIP6TEntry struct {
Entry IP6TEntry
@@ -168,13 +168,13 @@ func (ke *KernelIP6TEntry) SizeBytes() int {
// MarshalBytes implements marshal.Marshallable.MarshalBytes.
func (ke *KernelIP6TEntry) MarshalBytes(dst []byte) {
- ke.Entry.MarshalBytes(dst)
+ ke.Entry.MarshalUnsafe(dst)
ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():])
}
// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
func (ke *KernelIP6TEntry) UnmarshalBytes(src []byte) {
- ke.Entry.UnmarshalBytes(src)
+ ke.Entry.UnmarshalUnsafe(src)
ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():])
}
diff --git a/pkg/abi/linux/ptrace_amd64.go b/pkg/abi/linux/ptrace_amd64.go
index 50e22fe7e..e722971f1 100644
--- a/pkg/abi/linux/ptrace_amd64.go
+++ b/pkg/abi/linux/ptrace_amd64.go
@@ -61,3 +61,8 @@ func (p *PtraceRegs) InstructionPointer() uint64 {
func (p *PtraceRegs) StackPointer() uint64 {
return p.Rsp
}
+
+// SetStackPointer sets the stack pointer to the specified value.
+func (p *PtraceRegs) SetStackPointer(sp uint64) {
+ p.Rsp = sp
+}
diff --git a/pkg/abi/linux/ptrace_arm64.go b/pkg/abi/linux/ptrace_arm64.go
index da36811d2..3d0906565 100644
--- a/pkg/abi/linux/ptrace_arm64.go
+++ b/pkg/abi/linux/ptrace_arm64.go
@@ -38,3 +38,8 @@ func (p *PtraceRegs) InstructionPointer() uint64 {
func (p *PtraceRegs) StackPointer() uint64 {
return p.Sp
}
+
+// SetStackPointer sets the stack pointer to the specified value.
+func (p *PtraceRegs) SetStackPointer(sp uint64) {
+ p.Sp = sp
+}
diff --git a/pkg/coverage/BUILD b/pkg/coverage/BUILD
index a198e8028..ace5895f8 100644
--- a/pkg/coverage/BUILD
+++ b/pkg/coverage/BUILD
@@ -7,8 +7,8 @@ go_library(
srcs = ["coverage.go"],
visibility = ["//:sandbox"],
deps = [
+ "//pkg/hostarch",
"//pkg/sync",
- "//pkg/usermem",
"@io_bazel_rules_go//go/tools/coverdata",
],
)
diff --git a/pkg/coverage/coverage.go b/pkg/coverage/coverage.go
index 6f3d72e83..b33a20802 100644
--- a/pkg/coverage/coverage.go
+++ b/pkg/coverage/coverage.go
@@ -26,20 +26,25 @@ import (
"fmt"
"io"
"sort"
+ "sync/atomic"
"testing"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/usermem"
"github.com/bazelbuild/rules_go/go/tools/coverdata"
)
-// coverageMu must be held while accessing coverdata.Cover. This prevents
-// concurrent reads/writes from multiple threads collecting coverage data.
-var coverageMu sync.RWMutex
+var (
+ // coverageMu must be held while accessing coverdata.Cover. This prevents
+ // concurrent reads/writes from multiple threads collecting coverage data.
+ coverageMu sync.RWMutex
-// once ensures that globalData is only initialized once.
-var once sync.Once
+ // reportOutput is the place to write out a coverage report. It should be
+ // closed after the report is written. It is protected by reportOutputMu.
+ reportOutput io.WriteCloser
+ reportOutputMu sync.Mutex
+)
// blockBitLength is the number of bits used to represent coverage block index
// in a synthetic PC (the rest are used to represent the file index). Even
@@ -51,12 +56,26 @@ var once sync.Once
// file and every block.
const blockBitLength = 16
-// KcovAvailable returns whether the kcov coverage interface is available. It is
-// available as long as coverage is enabled for some files.
-func KcovAvailable() bool {
+// Available returns whether any coverage data is available.
+func Available() bool {
return len(coverdata.Cover.Blocks) > 0
}
+// EnableReport sets up coverage reporting.
+func EnableReport(w io.WriteCloser) {
+ reportOutputMu.Lock()
+ defer reportOutputMu.Unlock()
+ reportOutput = w
+}
+
+// KcovSupported returns whether the kcov interface should be made available.
+//
+// If coverage reporting is on, do not turn on kcov, which will consume
+// coverage data.
+func KcovSupported() bool {
+ return (reportOutput == nil) && Available()
+}
+
var globalData struct {
// files is the set of covered files sorted by filename. It is calculated at
// startup.
@@ -65,6 +84,9 @@ var globalData struct {
// syntheticPCs are a set of PCs calculated at startup, where the PC
// at syntheticPCs[i][j] corresponds to file i, block j.
syntheticPCs [][]uint64
+
+ // once ensures that globalData is only initialized once.
+ once sync.Once
}
// ClearCoverageData clears existing coverage data.
@@ -141,7 +163,7 @@ func ConsumeCoverageData(w io.Writer) int {
// Non-zero coverage data found; consume it and report as a PC.
counters[index] = 0
pc := globalData.syntheticPCs[fileNum][index]
- usermem.ByteOrder.PutUint64(pcBuffer[:], pc)
+ hostarch.ByteOrder.PutUint64(pcBuffer[:], pc)
n, err := w.Write(pcBuffer[:])
if err != nil {
if err == io.EOF {
@@ -166,7 +188,7 @@ func ConsumeCoverageData(w io.Writer) int {
// InitCoverageData initializes globalData. It should be called before any kcov
// data is written.
func InitCoverageData() {
- once.Do(func() {
+ globalData.once.Do(func() {
// First, order all files. Then calculate synthetic PCs for every block
// (using the well-defined ordering for files as well).
for file := range coverdata.Cover.Blocks {
@@ -185,6 +207,38 @@ func InitCoverageData() {
})
}
+// reportOnce ensures that a coverage report is written at most once. For a
+// complete coverage report, Report should be called during the sandbox teardown
+// process. Report is called from multiple places (which may overlap) so that a
+// coverage report is written in different sandbox exit scenarios.
+var reportOnce sync.Once
+
+// Report writes out a coverage report with all blocks that have been covered.
+//
+// TODO(b/144576401): Decide whether this should actually be in LCOV format
+func Report() error {
+ if reportOutput == nil {
+ return nil
+ }
+
+ var err error
+ reportOnce.Do(func() {
+ for file, counters := range coverdata.Cover.Counters {
+ blocks := coverdata.Cover.Blocks[file]
+ for i := 0; i < len(counters); i++ {
+ if atomic.LoadUint32(&counters[i]) > 0 {
+ err = writeBlock(reportOutput, file, blocks[i])
+ if err != nil {
+ return
+ }
+ }
+ }
+ }
+ reportOutput.Close()
+ })
+ return err
+}
+
// Symbolize prints information about the block corresponding to pc.
func Symbolize(out io.Writer, pc uint64) error {
fileNum, blockNum := syntheticPCToIndexes(pc)
@@ -196,18 +250,32 @@ func Symbolize(out io.Writer, pc uint64) error {
if err != nil {
return err
}
- writeBlock(out, pc, file, block)
- return nil
+ return writeBlockWithPC(out, pc, file, block)
}
// WriteAllBlocks prints all information about all blocks along with their
// corresponding synthetic PCs.
-func WriteAllBlocks(out io.Writer) {
+func WriteAllBlocks(out io.Writer) error {
for fileNum, file := range globalData.files {
for blockNum, block := range coverdata.Cover.Blocks[file] {
- writeBlock(out, calculateSyntheticPC(fileNum, blockNum), file, block)
+ if err := writeBlockWithPC(out, calculateSyntheticPC(fileNum, blockNum), file, block); err != nil {
+ return err
+ }
}
}
+ return nil
+}
+
+func writeBlockWithPC(out io.Writer, pc uint64, file string, block testing.CoverBlock) error {
+ if _, err := io.WriteString(out, fmt.Sprintf("%#x\n", pc)); err != nil {
+ return err
+ }
+ return writeBlock(out, file, block)
+}
+
+func writeBlock(out io.Writer, file string, block testing.CoverBlock) error {
+ _, err := io.WriteString(out, fmt.Sprintf("%s:%d.%d,%d.%d\n", file, block.Line0, block.Col0, block.Line1, block.Col1))
+ return err
}
func calculateSyntheticPC(fileNum int, blockNum int) uint64 {
@@ -239,8 +307,3 @@ func blockFromIndex(file string, i int) (testing.CoverBlock, error) {
}
return blocks[i], nil
}
-
-func writeBlock(out io.Writer, pc uint64, file string, block testing.CoverBlock) {
- io.WriteString(out, fmt.Sprintf("%#x\n", pc))
- io.WriteString(out, fmt.Sprintf("%s:%d.%d,%d.%d\n", file, block.Line0, block.Col0, block.Line1, block.Col1))
-}
diff --git a/pkg/gohacks/BUILD b/pkg/gohacks/BUILD
index 35683fe98..b4e05f922 100644
--- a/pkg/gohacks/BUILD
+++ b/pkg/gohacks/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -10,3 +10,11 @@ go_library(
stateify = False,
visibility = ["//:sandbox"],
)
+
+go_test(
+ name = "gohacks_test",
+ size = "small",
+ srcs = ["gohacks_test.go"],
+ library = ":gohacks",
+ deps = ["@org_golang_x_sys//unix:go_default_library"],
+)
diff --git a/pkg/gohacks/gohacks_test.go b/pkg/gohacks/gohacks_test.go
new file mode 100644
index 000000000..e18c8abc7
--- /dev/null
+++ b/pkg/gohacks/gohacks_test.go
@@ -0,0 +1,97 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gohacks
+
+import (
+ "io/ioutil"
+ "math/rand"
+ "os"
+ "runtime/debug"
+ "testing"
+
+ "golang.org/x/sys/unix"
+)
+
+func randBuf(size int) []byte {
+ b := make([]byte, size)
+ for i := range b {
+ b[i] = byte(rand.Intn(256))
+ }
+ return b
+}
+
+// Size of a page in bytes. Cloned from hostarch.PageSize to avoid a circular
+// dependency.
+const pageSize = 4096
+
+func testCopy(dst, src []byte) (panicked bool) {
+ defer func() {
+ if r := recover(); r != nil {
+ panicked = true
+ }
+ }()
+ debug.SetPanicOnFault(true)
+ copy(dst, src)
+ return panicked
+}
+
+func TestSegVOnMemmove(t *testing.T) {
+ // Test that SIGSEGVs received by runtime.memmove when *not* doing
+ // CopyIn or CopyOut work gets propagated to the runtime.
+ const bufLen = pageSize
+ a, err := unix.Mmap(-1, 0, bufLen, unix.PROT_NONE, unix.MAP_ANON|unix.MAP_PRIVATE)
+ if err != nil {
+ t.Fatalf("Mmap failed: %v", err)
+
+ }
+ defer unix.Munmap(a)
+ b := randBuf(bufLen)
+
+ if !testCopy(b, a) {
+ t.Fatalf("testCopy didn't panic when it should have")
+ }
+
+ if !testCopy(a, b) {
+ t.Fatalf("testCopy didn't panic when it should have")
+ }
+}
+
+func TestSigbusOnMemmove(t *testing.T) {
+ // Test that SIGBUS received by runtime.memmove when *not* doing
+ // CopyIn or CopyOut work gets propagated to the runtime.
+ const bufLen = pageSize
+ f, err := ioutil.TempFile("", "sigbus_test")
+ if err != nil {
+ t.Fatalf("TempFile failed: %v", err)
+ }
+ os.Remove(f.Name())
+ defer f.Close()
+
+ a, err := unix.Mmap(int(f.Fd()), 0, bufLen, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED)
+ if err != nil {
+ t.Fatalf("Mmap failed: %v", err)
+
+ }
+ defer unix.Munmap(a)
+ b := randBuf(bufLen)
+
+ if !testCopy(b, a) {
+ t.Fatalf("testCopy didn't panic when it should have")
+ }
+
+ if !testCopy(a, b) {
+ t.Fatalf("testCopy didn't panic when it should have")
+ }
+}
diff --git a/pkg/gohacks/gohacks_unsafe.go b/pkg/gohacks/gohacks_unsafe.go
index 10bbb1f58..374aac2b4 100644
--- a/pkg/gohacks/gohacks_unsafe.go
+++ b/pkg/gohacks/gohacks_unsafe.go
@@ -75,3 +75,17 @@ func StringFromImmutableBytes(bs []byte) string {
// strings.Builder.String().
return *(*string)(unsafe.Pointer(&bs))
}
+
+// Note that go:linkname silently doesn't work if the local name is exported,
+// necessitating an indirection for exported functions.
+
+// Memmove is runtime.memmove, exported for SeqAtomicLoad/SeqAtomicTryLoad<T>.
+//
+//go:nosplit
+func Memmove(to, from unsafe.Pointer, n uintptr) {
+ memmove(to, from, n)
+}
+
+//go:linkname memmove runtime.memmove
+//go:noescape
+func memmove(to, from unsafe.Pointer, n uintptr)
diff --git a/pkg/hostarch/BUILD b/pkg/hostarch/BUILD
new file mode 100644
index 000000000..1e8def4d9
--- /dev/null
+++ b/pkg/hostarch/BUILD
@@ -0,0 +1,42 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "addr_range",
+ out = "addr_range.go",
+ package = "hostarch",
+ prefix = "Addr",
+ template = "//pkg/segment:generic_range",
+ types = {
+ "T": "Addr",
+ },
+)
+
+go_test(
+ name = "hostarch_test",
+ size = "small",
+ srcs = [
+ "addr_range_seq_test.go",
+ ],
+ library = ":hostarch",
+)
+
+go_library(
+ name = "hostarch",
+ srcs = [
+ "access_type.go",
+ "addr.go",
+ "addr_range.go",
+ "addr_range_seq_unsafe.go",
+ "hostarch.go",
+ "hostarch_arm64.go",
+ "hostarch_x86.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/gohacks",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/usermem/access_type.go b/pkg/hostarch/access_type.go
index 2cfca29af..e30476840 100644
--- a/pkg/usermem/access_type.go
+++ b/pkg/hostarch/access_type.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package usermem
+package hostarch
import "golang.org/x/sys/unix"
diff --git a/pkg/usermem/addr.go b/pkg/hostarch/addr.go
index c4100481e..0cf0f3c81 100644
--- a/pkg/usermem/addr.go
+++ b/pkg/hostarch/addr.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package usermem
+package hostarch
import (
"fmt"
@@ -57,7 +57,7 @@ func (v Addr) RoundUp() (addr Addr, ok bool) {
func (v Addr) MustRoundUp() Addr {
addr, ok := v.RoundUp()
if !ok {
- panic(fmt.Sprintf("usermem.Addr(%d).RoundUp() wraps", v))
+ panic(fmt.Sprintf("hostarch.Addr(%d).RoundUp() wraps", v))
}
return addr
}
diff --git a/pkg/usermem/addr_range_seq_test.go b/pkg/hostarch/addr_range_seq_test.go
index 82f735026..5726dfd19 100644
--- a/pkg/usermem/addr_range_seq_test.go
+++ b/pkg/hostarch/addr_range_seq_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package usermem
+package hostarch
import (
"testing"
diff --git a/pkg/usermem/addr_range_seq_unsafe.go b/pkg/hostarch/addr_range_seq_unsafe.go
index c9a1415a0..ecc17d595 100644
--- a/pkg/usermem/addr_range_seq_unsafe.go
+++ b/pkg/hostarch/addr_range_seq_unsafe.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package usermem
+package hostarch
import (
"bytes"
diff --git a/pkg/hostarch/hostarch.go b/pkg/hostarch/hostarch.go
new file mode 100644
index 000000000..fdd29c567
--- /dev/null
+++ b/pkg/hostarch/hostarch.go
@@ -0,0 +1,7 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package hostarch contains host arch address operations for user memory.
+package hostarch
diff --git a/pkg/usermem/usermem_arm64.go b/pkg/hostarch/hostarch_arm64.go
index 7e7529585..a31a8aeeb 100644
--- a/pkg/usermem/usermem_arm64.go
+++ b/pkg/hostarch/hostarch_arm64.go
@@ -14,7 +14,7 @@
// +build arm64
-package usermem
+package hostarch
import (
"encoding/binary"
diff --git a/pkg/usermem/usermem_x86.go b/pkg/hostarch/hostarch_x86.go
index d96f829fb..af6ef2b7f 100644
--- a/pkg/usermem/usermem_x86.go
+++ b/pkg/hostarch/hostarch_x86.go
@@ -14,7 +14,7 @@
// +build amd64 386
-package usermem
+package hostarch
import "encoding/binary"
diff --git a/pkg/marshal/BUILD b/pkg/marshal/BUILD
index aac0161fa..7cd89e639 100644
--- a/pkg/marshal/BUILD
+++ b/pkg/marshal/BUILD
@@ -11,5 +11,5 @@ go_library(
visibility = [
"//:sandbox",
],
- deps = ["//pkg/usermem"],
+ deps = ["//pkg/hostarch"],
)
diff --git a/pkg/marshal/marshal.go b/pkg/marshal/marshal.go
index d8cb44b40..7da450ce8 100644
--- a/pkg/marshal/marshal.go
+++ b/pkg/marshal/marshal.go
@@ -23,7 +23,7 @@ package marshal
import (
"io"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// CopyContext defines the memory operations required to marshal to and from
@@ -36,11 +36,11 @@ type CopyContext interface {
// CopyOutBytes writes the contents of b to the task's memory. See
// kernel.CopyOutBytes.
- CopyOutBytes(addr usermem.Addr, b []byte) (int, error)
+ CopyOutBytes(addr hostarch.Addr, b []byte) (int, error)
// CopyInBytes reads the contents of the task's memory to b. See
// kernel.CopyInBytes.
- CopyInBytes(addr usermem.Addr, b []byte) (int, error)
+ CopyInBytes(addr hostarch.Addr, b []byte) (int, error)
}
// Marshallable represents operations on a type that can be marshalled to and
@@ -108,7 +108,7 @@ type Marshallable interface {
// If the copy-in from the task memory is only partially successful, CopyIn
// should still attempt to deserialize as much data as possible. See comment
// for UnmarshalBytes.
- CopyIn(cc CopyContext, addr usermem.Addr) (int, error)
+ CopyIn(cc CopyContext, addr hostarch.Addr) (int, error)
// CopyOut serializes a Marshallable type to a task's memory. This may only
// be called from a task goroutine. This is more efficient than calling
@@ -119,7 +119,7 @@ type Marshallable interface {
// The copy-out to the task memory may be partially successful, in which
// case CopyOut returns how much data was serialized. See comment for
// MarshalBytes for implications.
- CopyOut(cc CopyContext, addr usermem.Addr) (int, error)
+ CopyOut(cc CopyContext, addr hostarch.Addr) (int, error)
// CopyOutN is like CopyOut, but explicitly requests a partial
// copy-out. Note that this may yield unexpected results for non-packed
@@ -127,7 +127,7 @@ type Marshallable interface {
// comment on MarshalBytes.
//
// The limit must be less than or equal to SizeBytes().
- CopyOutN(cc CopyContext, addr usermem.Addr, limit int) (int, error)
+ CopyOutN(cc CopyContext, addr hostarch.Addr, limit int) (int, error)
}
// go-marshal generates additional functions for a type based on additional
@@ -157,15 +157,18 @@ type Marshallable interface {
// func UnmarshalUnsafeFooSlice(dst []Foo, src []byte) (int, error) { ... }
//
// // CopyFooSliceIn copies in a slice of Foo objects from the task's memory.
-// func CopyFooSliceIn(cc marshal.CopyContext, addr usermem.Addr, dst []Foo) (int, error) { ... }
+// func CopyFooSliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []Foo) (int, error) { ... }
//
// // CopyFooSliceIn copies out a slice of Foo objects to the task's memory.
-// func CopyFooSliceOut(cc marshal.CopyContext, addr usermem.Addr, src []Foo) (int, error) { ... }
+// func CopyFooSliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []Foo) (int, error) { ... }
//
// The name of the functions are of the format "Copy%sIn" and "Copy%sOut", where
// %s is the first argument to the slice clause. This directive is not supported
// for newtypes on arrays.
//
+// Note: Partial copies are not supported for Slice API UnmarshalUnsafe and
+// MarshalUnsafe.
+//
// The slice clause also takes an optional second argument, which must be the
// value "inner":
//
@@ -175,10 +178,10 @@ type Marshallable interface {
// This is only valid on newtypes on primitives, and causes the generated
// functions to accept slices of the inner type instead:
//
-// func CopyInt32SliceIn(cc marshal.CopyContext, addr usermem.Addr, dst []int32) (int, error) { ... }
+// func CopyInt32SliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []int32) (int, error) { ... }
//
// Without "inner", they would instead be:
//
-// func CopyInt32SliceIn(cc marshal.CopyContext, addr usermem.Addr, dst []Int32) (int, error) { ... }
+// func CopyInt32SliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst []Int32) (int, error) { ... }
//
// This may help avoid a cast depending on how the generated functions are used.
diff --git a/pkg/marshal/marshal_impl_util.go b/pkg/marshal/marshal_impl_util.go
index ea75e09f2..9e6a6fa29 100644
--- a/pkg/marshal/marshal_impl_util.go
+++ b/pkg/marshal/marshal_impl_util.go
@@ -17,7 +17,7 @@ package marshal
import (
"io"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// StubMarshallable implements the Marshallable interface.
@@ -63,16 +63,16 @@ func (StubMarshallable) UnmarshalUnsafe(src []byte) {
}
// CopyIn implements Marshallable.CopyIn.
-func (StubMarshallable) CopyIn(cc CopyContext, addr usermem.Addr) (int, error) {
+func (StubMarshallable) CopyIn(cc CopyContext, addr hostarch.Addr) (int, error) {
panic("Please implement your own CopyIn function")
}
// CopyOut implements Marshallable.CopyOut.
-func (StubMarshallable) CopyOut(cc CopyContext, addr usermem.Addr) (int, error) {
+func (StubMarshallable) CopyOut(cc CopyContext, addr hostarch.Addr) (int, error) {
panic("Please implement your own CopyOut function")
}
// CopyOutN implements Marshallable.CopyOutN.
-func (StubMarshallable) CopyOutN(cc CopyContext, addr usermem.Addr, limit int) (int, error) {
+func (StubMarshallable) CopyOutN(cc CopyContext, addr hostarch.Addr, limit int) (int, error) {
panic("Please implement your own CopyOutN function")
}
diff --git a/pkg/marshal/primitive/BUILD b/pkg/marshal/primitive/BUILD
index d77a11c79..190b57c29 100644
--- a/pkg/marshal/primitive/BUILD
+++ b/pkg/marshal/primitive/BUILD
@@ -13,6 +13,7 @@ go_library(
],
deps = [
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/marshal",
"//pkg/usermem",
],
diff --git a/pkg/marshal/primitive/primitive.go b/pkg/marshal/primitive/primitive.go
index 4b342de6b..32c8ed138 100644
--- a/pkg/marshal/primitive/primitive.go
+++ b/pkg/marshal/primitive/primitive.go
@@ -20,6 +20,7 @@ import (
"io"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -102,17 +103,17 @@ func (b *ByteSlice) UnmarshalUnsafe(src []byte) {
}
// CopyIn implements marshal.Marshallable.CopyIn.
-func (b *ByteSlice) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {
+func (b *ByteSlice) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {
return cc.CopyInBytes(addr, *b)
}
// CopyOut implements marshal.Marshallable.CopyOut.
-func (b *ByteSlice) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {
+func (b *ByteSlice) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {
return cc.CopyOutBytes(addr, *b)
}
// CopyOutN implements marshal.Marshallable.CopyOutN.
-func (b *ByteSlice) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {
+func (b *ByteSlice) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) {
return cc.CopyOutBytes(addr, (*b)[:limit])
}
@@ -131,7 +132,7 @@ var _ marshal.Marshallable = (*ByteSlice)(nil)
// CopyInt8In is a convenient wrapper for copying in an int8 from the task's
// memory.
-func CopyInt8In(cc marshal.CopyContext, addr usermem.Addr, dst *int8) (int, error) {
+func CopyInt8In(cc marshal.CopyContext, addr hostarch.Addr, dst *int8) (int, error) {
var buf Int8
n, err := buf.CopyIn(cc, addr)
if err != nil {
@@ -143,14 +144,14 @@ func CopyInt8In(cc marshal.CopyContext, addr usermem.Addr, dst *int8) (int, erro
// CopyInt8Out is a convenient wrapper for copying out an int8 to the task's
// memory.
-func CopyInt8Out(cc marshal.CopyContext, addr usermem.Addr, src int8) (int, error) {
+func CopyInt8Out(cc marshal.CopyContext, addr hostarch.Addr, src int8) (int, error) {
srcP := Int8(src)
return srcP.CopyOut(cc, addr)
}
// CopyUint8In is a convenient wrapper for copying in a uint8 from the task's
// memory.
-func CopyUint8In(cc marshal.CopyContext, addr usermem.Addr, dst *uint8) (int, error) {
+func CopyUint8In(cc marshal.CopyContext, addr hostarch.Addr, dst *uint8) (int, error) {
var buf Uint8
n, err := buf.CopyIn(cc, addr)
if err != nil {
@@ -162,7 +163,7 @@ func CopyUint8In(cc marshal.CopyContext, addr usermem.Addr, dst *uint8) (int, er
// CopyUint8Out is a convenient wrapper for copying out a uint8 to the task's
// memory.
-func CopyUint8Out(cc marshal.CopyContext, addr usermem.Addr, src uint8) (int, error) {
+func CopyUint8Out(cc marshal.CopyContext, addr hostarch.Addr, src uint8) (int, error) {
srcP := Uint8(src)
return srcP.CopyOut(cc, addr)
}
@@ -171,7 +172,7 @@ func CopyUint8Out(cc marshal.CopyContext, addr usermem.Addr, src uint8) (int, er
// CopyInt16In is a convenient wrapper for copying in an int16 from the task's
// memory.
-func CopyInt16In(cc marshal.CopyContext, addr usermem.Addr, dst *int16) (int, error) {
+func CopyInt16In(cc marshal.CopyContext, addr hostarch.Addr, dst *int16) (int, error) {
var buf Int16
n, err := buf.CopyIn(cc, addr)
if err != nil {
@@ -183,14 +184,14 @@ func CopyInt16In(cc marshal.CopyContext, addr usermem.Addr, dst *int16) (int, er
// CopyInt16Out is a convenient wrapper for copying out an int16 to the task's
// memory.
-func CopyInt16Out(cc marshal.CopyContext, addr usermem.Addr, src int16) (int, error) {
+func CopyInt16Out(cc marshal.CopyContext, addr hostarch.Addr, src int16) (int, error) {
srcP := Int16(src)
return srcP.CopyOut(cc, addr)
}
// CopyUint16In is a convenient wrapper for copying in a uint16 from the task's
// memory.
-func CopyUint16In(cc marshal.CopyContext, addr usermem.Addr, dst *uint16) (int, error) {
+func CopyUint16In(cc marshal.CopyContext, addr hostarch.Addr, dst *uint16) (int, error) {
var buf Uint16
n, err := buf.CopyIn(cc, addr)
if err != nil {
@@ -202,7 +203,7 @@ func CopyUint16In(cc marshal.CopyContext, addr usermem.Addr, dst *uint16) (int,
// CopyUint16Out is a convenient wrapper for copying out a uint16 to the task's
// memory.
-func CopyUint16Out(cc marshal.CopyContext, addr usermem.Addr, src uint16) (int, error) {
+func CopyUint16Out(cc marshal.CopyContext, addr hostarch.Addr, src uint16) (int, error) {
srcP := Uint16(src)
return srcP.CopyOut(cc, addr)
}
@@ -211,7 +212,7 @@ func CopyUint16Out(cc marshal.CopyContext, addr usermem.Addr, src uint16) (int,
// CopyInt32In is a convenient wrapper for copying in an int32 from the task's
// memory.
-func CopyInt32In(cc marshal.CopyContext, addr usermem.Addr, dst *int32) (int, error) {
+func CopyInt32In(cc marshal.CopyContext, addr hostarch.Addr, dst *int32) (int, error) {
var buf Int32
n, err := buf.CopyIn(cc, addr)
if err != nil {
@@ -223,14 +224,14 @@ func CopyInt32In(cc marshal.CopyContext, addr usermem.Addr, dst *int32) (int, er
// CopyInt32Out is a convenient wrapper for copying out an int32 to the task's
// memory.
-func CopyInt32Out(cc marshal.CopyContext, addr usermem.Addr, src int32) (int, error) {
+func CopyInt32Out(cc marshal.CopyContext, addr hostarch.Addr, src int32) (int, error) {
srcP := Int32(src)
return srcP.CopyOut(cc, addr)
}
// CopyUint32In is a convenient wrapper for copying in a uint32 from the task's
// memory.
-func CopyUint32In(cc marshal.CopyContext, addr usermem.Addr, dst *uint32) (int, error) {
+func CopyUint32In(cc marshal.CopyContext, addr hostarch.Addr, dst *uint32) (int, error) {
var buf Uint32
n, err := buf.CopyIn(cc, addr)
if err != nil {
@@ -242,7 +243,7 @@ func CopyUint32In(cc marshal.CopyContext, addr usermem.Addr, dst *uint32) (int,
// CopyUint32Out is a convenient wrapper for copying out a uint32 to the task's
// memory.
-func CopyUint32Out(cc marshal.CopyContext, addr usermem.Addr, src uint32) (int, error) {
+func CopyUint32Out(cc marshal.CopyContext, addr hostarch.Addr, src uint32) (int, error) {
srcP := Uint32(src)
return srcP.CopyOut(cc, addr)
}
@@ -251,7 +252,7 @@ func CopyUint32Out(cc marshal.CopyContext, addr usermem.Addr, src uint32) (int,
// CopyInt64In is a convenient wrapper for copying in an int64 from the task's
// memory.
-func CopyInt64In(cc marshal.CopyContext, addr usermem.Addr, dst *int64) (int, error) {
+func CopyInt64In(cc marshal.CopyContext, addr hostarch.Addr, dst *int64) (int, error) {
var buf Int64
n, err := buf.CopyIn(cc, addr)
if err != nil {
@@ -263,14 +264,14 @@ func CopyInt64In(cc marshal.CopyContext, addr usermem.Addr, dst *int64) (int, er
// CopyInt64Out is a convenient wrapper for copying out an int64 to the task's
// memory.
-func CopyInt64Out(cc marshal.CopyContext, addr usermem.Addr, src int64) (int, error) {
+func CopyInt64Out(cc marshal.CopyContext, addr hostarch.Addr, src int64) (int, error) {
srcP := Int64(src)
return srcP.CopyOut(cc, addr)
}
// CopyUint64In is a convenient wrapper for copying in a uint64 from the task's
// memory.
-func CopyUint64In(cc marshal.CopyContext, addr usermem.Addr, dst *uint64) (int, error) {
+func CopyUint64In(cc marshal.CopyContext, addr hostarch.Addr, dst *uint64) (int, error) {
var buf Uint64
n, err := buf.CopyIn(cc, addr)
if err != nil {
@@ -282,14 +283,14 @@ func CopyUint64In(cc marshal.CopyContext, addr usermem.Addr, dst *uint64) (int,
// CopyUint64Out is a convenient wrapper for copying out a uint64 to the task's
// memory.
-func CopyUint64Out(cc marshal.CopyContext, addr usermem.Addr, src uint64) (int, error) {
+func CopyUint64Out(cc marshal.CopyContext, addr hostarch.Addr, src uint64) (int, error) {
srcP := Uint64(src)
return srcP.CopyOut(cc, addr)
}
// CopyByteSliceIn is a convenient wrapper for copying in a []byte from the
// task's memory.
-func CopyByteSliceIn(cc marshal.CopyContext, addr usermem.Addr, dst *[]byte) (int, error) {
+func CopyByteSliceIn(cc marshal.CopyContext, addr hostarch.Addr, dst *[]byte) (int, error) {
var buf ByteSlice
n, err := buf.CopyIn(cc, addr)
if err != nil {
@@ -301,14 +302,14 @@ func CopyByteSliceIn(cc marshal.CopyContext, addr usermem.Addr, dst *[]byte) (in
// CopyByteSliceOut is a convenient wrapper for copying out a []byte to the
// task's memory.
-func CopyByteSliceOut(cc marshal.CopyContext, addr usermem.Addr, src []byte) (int, error) {
+func CopyByteSliceOut(cc marshal.CopyContext, addr hostarch.Addr, src []byte) (int, error) {
srcP := ByteSlice(src)
return srcP.CopyOut(cc, addr)
}
// CopyStringIn is a convenient wrapper for copying in a string from the
// task's memory.
-func CopyStringIn(cc marshal.CopyContext, addr usermem.Addr, dst *string) (int, error) {
+func CopyStringIn(cc marshal.CopyContext, addr hostarch.Addr, dst *string) (int, error) {
var buf ByteSlice
n, err := buf.CopyIn(cc, addr)
if err != nil {
@@ -320,12 +321,12 @@ func CopyStringIn(cc marshal.CopyContext, addr usermem.Addr, dst *string) (int,
// CopyStringOut is a convenient wrapper for copying out a string to the task's
// memory.
-func CopyStringOut(cc marshal.CopyContext, addr usermem.Addr, src string) (int, error) {
+func CopyStringOut(cc marshal.CopyContext, addr hostarch.Addr, src string) (int, error) {
srcP := ByteSlice(src)
return srcP.CopyOut(cc, addr)
}
-// IOCopyContext wraps an object implementing usermem.IO to implement
+// IOCopyContext wraps an object implementing hostarch.IO to implement
// marshal.CopyContext.
type IOCopyContext struct {
Ctx context.Context
@@ -339,11 +340,11 @@ func (i *IOCopyContext) CopyScratchBuffer(size int) []byte {
}
// CopyOutBytes implements marshal.CopyContext.CopyOutBytes.
-func (i *IOCopyContext) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) {
+func (i *IOCopyContext) CopyOutBytes(addr hostarch.Addr, b []byte) (int, error) {
return i.IO.CopyOut(i.Ctx, addr, b, i.Opts)
}
// CopyInBytes implements marshal.CopyContext.CopyInBytes.
-func (i *IOCopyContext) CopyInBytes(addr usermem.Addr, b []byte) (int, error) {
+func (i *IOCopyContext) CopyInBytes(addr hostarch.Addr, b []byte) (int, error) {
return i.IO.CopyIn(i.Ctx, addr, b, i.Opts)
}
diff --git a/pkg/merkletree/BUILD b/pkg/merkletree/BUILD
index 501a9ef21..dcd6c3bf5 100644
--- a/pkg/merkletree/BUILD
+++ b/pkg/merkletree/BUILD
@@ -8,7 +8,7 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/usermem",
+ "//pkg/hostarch",
],
)
@@ -18,6 +18,6 @@ go_test(
library = ":merkletree",
deps = [
"//pkg/abi/linux",
- "//pkg/usermem",
+ "//pkg/hostarch",
],
)
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
index d7209ace3..6450f664c 100644
--- a/pkg/merkletree/merkletree.go
+++ b/pkg/merkletree/merkletree.go
@@ -24,7 +24,8 @@ import (
"io"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
const (
@@ -65,7 +66,7 @@ type Layout struct {
// of a tree. dataSize specifies the size of input data in bytes.
func InitLayout(dataSize int64, hashAlgorithms int, dataAndTreeInSameFile bool) (Layout, error) {
layout := Layout{
- blockSize: usermem.PageSize,
+ blockSize: hostarch.PageSize,
}
// TODO(b/156980949): Allow config SHA384.
@@ -237,6 +238,7 @@ func Generate(params *GenerateParams) ([]byte, error) {
Mode: params.Mode,
UID: params.UID,
GID: params.GID,
+ Children: params.Children,
SymlinkTarget: params.SymlinkTarget,
}
diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go
index ed332b3f1..5d6f8df1b 100644
--- a/pkg/merkletree/merkletree_test.go
+++ b/pkg/merkletree/merkletree_test.go
@@ -24,7 +24,8 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
func TestLayout(t *testing.T) {
@@ -58,7 +59,7 @@ func TestLayout(t *testing.T) {
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: true,
expectedDigestSize: 32,
- expectedLevelOffset: []int64{usermem.PageSize},
+ expectedLevelOffset: []int64{hostarch.PageSize},
},
{
name: "SmallSizeSHA512SameFile",
@@ -66,7 +67,7 @@ func TestLayout(t *testing.T) {
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
dataAndTreeInSameFile: true,
expectedDigestSize: 64,
- expectedLevelOffset: []int64{usermem.PageSize},
+ expectedLevelOffset: []int64{hostarch.PageSize},
},
{
name: "MiddleSizeSHA256SeparateFile",
@@ -74,7 +75,7 @@ func TestLayout(t *testing.T) {
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: false,
expectedDigestSize: 32,
- expectedLevelOffset: []int64{0, 2 * usermem.PageSize, 3 * usermem.PageSize},
+ expectedLevelOffset: []int64{0, 2 * hostarch.PageSize, 3 * hostarch.PageSize},
},
{
name: "MiddleSizeSHA512SeparateFile",
@@ -82,7 +83,7 @@ func TestLayout(t *testing.T) {
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
dataAndTreeInSameFile: false,
expectedDigestSize: 64,
- expectedLevelOffset: []int64{0, 4 * usermem.PageSize, 5 * usermem.PageSize},
+ expectedLevelOffset: []int64{0, 4 * hostarch.PageSize, 5 * hostarch.PageSize},
},
{
name: "MiddleSizeSHA256SameFile",
@@ -90,7 +91,7 @@ func TestLayout(t *testing.T) {
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: true,
expectedDigestSize: 32,
- expectedLevelOffset: []int64{245 * usermem.PageSize, 247 * usermem.PageSize, 248 * usermem.PageSize},
+ expectedLevelOffset: []int64{245 * hostarch.PageSize, 247 * hostarch.PageSize, 248 * hostarch.PageSize},
},
{
name: "MiddleSizeSHA512SameFile",
@@ -98,39 +99,39 @@ func TestLayout(t *testing.T) {
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
dataAndTreeInSameFile: true,
expectedDigestSize: 64,
- expectedLevelOffset: []int64{245 * usermem.PageSize, 249 * usermem.PageSize, 250 * usermem.PageSize},
+ expectedLevelOffset: []int64{245 * hostarch.PageSize, 249 * hostarch.PageSize, 250 * hostarch.PageSize},
},
{
name: "LargeSizeSHA256SeparateFile",
- dataSize: 4096 * int64(usermem.PageSize),
+ dataSize: 4096 * int64(hostarch.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: false,
expectedDigestSize: 32,
- expectedLevelOffset: []int64{0, 32 * usermem.PageSize, 33 * usermem.PageSize},
+ expectedLevelOffset: []int64{0, 32 * hostarch.PageSize, 33 * hostarch.PageSize},
},
{
name: "LargeSizeSHA512SeparateFile",
- dataSize: 4096 * int64(usermem.PageSize),
+ dataSize: 4096 * int64(hostarch.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
dataAndTreeInSameFile: false,
expectedDigestSize: 64,
- expectedLevelOffset: []int64{0, 64 * usermem.PageSize, 65 * usermem.PageSize},
+ expectedLevelOffset: []int64{0, 64 * hostarch.PageSize, 65 * hostarch.PageSize},
},
{
name: "LargeSizeSHA256SameFile",
- dataSize: 4096 * int64(usermem.PageSize),
+ dataSize: 4096 * int64(hostarch.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: true,
expectedDigestSize: 32,
- expectedLevelOffset: []int64{4096 * usermem.PageSize, 4128 * usermem.PageSize, 4129 * usermem.PageSize},
+ expectedLevelOffset: []int64{4096 * hostarch.PageSize, 4128 * hostarch.PageSize, 4129 * hostarch.PageSize},
},
{
name: "LargeSizeSHA512SameFile",
- dataSize: 4096 * int64(usermem.PageSize),
+ dataSize: 4096 * int64(hostarch.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
dataAndTreeInSameFile: true,
expectedDigestSize: 64,
- expectedLevelOffset: []int64{4096 * usermem.PageSize, 4160 * usermem.PageSize, 4161 * usermem.PageSize},
+ expectedLevelOffset: []int64{4096 * hostarch.PageSize, 4160 * hostarch.PageSize, 4161 * hostarch.PageSize},
},
}
@@ -140,8 +141,8 @@ func TestLayout(t *testing.T) {
if err != nil {
t.Fatalf("Failed to InitLayout: %v", err)
}
- if l.blockSize != int64(usermem.PageSize) {
- t.Errorf("Got blockSize %d, want %d", l.blockSize, usermem.PageSize)
+ if l.blockSize != int64(hostarch.PageSize) {
+ t.Errorf("Got blockSize %d, want %d", l.blockSize, hostarch.PageSize)
}
if l.digestSize != tc.expectedDigestSize {
t.Errorf("Got digestSize %d, want %d", l.digestSize, sha256DigestSize)
@@ -202,56 +203,56 @@ func TestGenerate(t *testing.T) {
}{
{
name: "OnePageZeroesSHA256SeparateFile",
- data: bytes.Repeat([]byte{0}, usermem.PageSize),
+ data: bytes.Repeat([]byte{0}, hostarch.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: false,
expectedHash: []byte{9, 115, 238, 230, 38, 140, 195, 70, 207, 144, 202, 118, 23, 113, 32, 129, 226, 239, 177, 69, 161, 26, 14, 113, 16, 37, 30, 96, 19, 148, 132, 27},
},
{
name: "OnePageZeroesSHA256SameFile",
- data: bytes.Repeat([]byte{0}, usermem.PageSize),
+ data: bytes.Repeat([]byte{0}, hostarch.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: true,
expectedHash: []byte{9, 115, 238, 230, 38, 140, 195, 70, 207, 144, 202, 118, 23, 113, 32, 129, 226, 239, 177, 69, 161, 26, 14, 113, 16, 37, 30, 96, 19, 148, 132, 27},
},
{
name: "OnePageZeroesSHA512SeparateFile",
- data: bytes.Repeat([]byte{0}, usermem.PageSize),
+ data: bytes.Repeat([]byte{0}, hostarch.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
dataAndTreeInSameFile: false,
expectedHash: []byte{127, 8, 95, 11, 83, 101, 51, 39, 170, 235, 39, 43, 135, 243, 145, 118, 148, 58, 27, 155, 182, 205, 44, 47, 5, 223, 215, 17, 35, 16, 43, 104, 43, 11, 8, 88, 171, 7, 249, 243, 14, 62, 126, 218, 23, 159, 237, 237, 42, 226, 39, 25, 87, 48, 253, 191, 116, 213, 37, 3, 187, 152, 154, 14},
},
{
name: "OnePageZeroesSHA512SameFile",
- data: bytes.Repeat([]byte{0}, usermem.PageSize),
+ data: bytes.Repeat([]byte{0}, hostarch.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
dataAndTreeInSameFile: true,
expectedHash: []byte{127, 8, 95, 11, 83, 101, 51, 39, 170, 235, 39, 43, 135, 243, 145, 118, 148, 58, 27, 155, 182, 205, 44, 47, 5, 223, 215, 17, 35, 16, 43, 104, 43, 11, 8, 88, 171, 7, 249, 243, 14, 62, 126, 218, 23, 159, 237, 237, 42, 226, 39, 25, 87, 48, 253, 191, 116, 213, 37, 3, 187, 152, 154, 14},
},
{
name: "MultiplePageZeroesSHA256SeparateFile",
- data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1),
+ data: bytes.Repeat([]byte{0}, 128*hostarch.PageSize+1),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: false,
expectedHash: []byte{247, 158, 42, 215, 180, 106, 0, 28, 77, 64, 132, 162, 74, 65, 250, 161, 243, 66, 129, 44, 197, 8, 145, 14, 94, 206, 156, 184, 145, 145, 20, 185},
},
{
name: "MultiplePageZeroesSHA256SameFile",
- data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1),
+ data: bytes.Repeat([]byte{0}, 128*hostarch.PageSize+1),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: true,
expectedHash: []byte{247, 158, 42, 215, 180, 106, 0, 28, 77, 64, 132, 162, 74, 65, 250, 161, 243, 66, 129, 44, 197, 8, 145, 14, 94, 206, 156, 184, 145, 145, 20, 185},
},
{
name: "MultiplePageZeroesSHA512SeparateFile",
- data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1),
+ data: bytes.Repeat([]byte{0}, 128*hostarch.PageSize+1),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
dataAndTreeInSameFile: false,
expectedHash: []byte{100, 121, 14, 30, 104, 200, 142, 182, 190, 78, 23, 68, 157, 174, 23, 75, 174, 250, 250, 25, 66, 45, 235, 103, 129, 49, 78, 127, 173, 154, 121, 35, 37, 115, 60, 217, 26, 205, 253, 253, 236, 145, 107, 109, 232, 19, 72, 92, 4, 191, 181, 205, 191, 57, 234, 177, 144, 235, 143, 30, 15, 197, 109, 81},
},
{
name: "MultiplePageZeroesSHA512SameFile",
- data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1),
+ data: bytes.Repeat([]byte{0}, 128*hostarch.PageSize+1),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
dataAndTreeInSameFile: true,
expectedHash: []byte{100, 121, 14, 30, 104, 200, 142, 182, 190, 78, 23, 68, 157, 174, 23, 75, 174, 250, 250, 25, 66, 45, 235, 103, 129, 49, 78, 127, 173, 154, 121, 35, 37, 115, 60, 217, 26, 205, 253, 253, 236, 145, 107, 109, 232, 19, 72, 92, 4, 191, 181, 205, 191, 57, 234, 177, 144, 235, 143, 30, 15, 197, 109, 81},
@@ -286,28 +287,28 @@ func TestGenerate(t *testing.T) {
},
{
name: "OnePageASHA256SeparateFile",
- data: bytes.Repeat([]byte{'a'}, usermem.PageSize),
+ data: bytes.Repeat([]byte{'a'}, hostarch.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: false,
expectedHash: []byte{132, 54, 112, 142, 156, 19, 50, 140, 138, 240, 192, 154, 100, 120, 242, 69, 64, 217, 62, 166, 127, 88, 23, 197, 100, 66, 255, 215, 214, 229, 54, 1},
},
{
name: "OnePageASHA256SameFile",
- data: bytes.Repeat([]byte{'a'}, usermem.PageSize),
+ data: bytes.Repeat([]byte{'a'}, hostarch.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: true,
expectedHash: []byte{132, 54, 112, 142, 156, 19, 50, 140, 138, 240, 192, 154, 100, 120, 242, 69, 64, 217, 62, 166, 127, 88, 23, 197, 100, 66, 255, 215, 214, 229, 54, 1},
},
{
name: "OnePageASHA512SeparateFile",
- data: bytes.Repeat([]byte{'a'}, usermem.PageSize),
+ data: bytes.Repeat([]byte{'a'}, hostarch.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
dataAndTreeInSameFile: false,
expectedHash: []byte{165, 46, 176, 116, 47, 209, 101, 193, 64, 185, 30, 9, 52, 22, 24, 154, 135, 220, 232, 168, 215, 45, 222, 226, 207, 104, 160, 10, 156, 98, 245, 250, 76, 21, 68, 204, 65, 118, 69, 52, 210, 155, 36, 109, 233, 103, 1, 40, 218, 89, 125, 38, 247, 194, 2, 225, 119, 155, 65, 99, 182, 111, 110, 145},
},
{
name: "OnePageASHA512SameFile",
- data: bytes.Repeat([]byte{'a'}, usermem.PageSize),
+ data: bytes.Repeat([]byte{'a'}, hostarch.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
dataAndTreeInSameFile: true,
expectedHash: []byte{165, 46, 176, 116, 47, 209, 101, 193, 64, 185, 30, 9, 52, 22, 24, 154, 135, 220, 232, 168, 215, 45, 222, 226, 207, 104, 160, 10, 156, 98, 245, 250, 76, 21, 68, 204, 65, 118, 69, 52, 210, 155, 36, 109, 233, 103, 1, 40, 218, 89, 125, 38, 247, 194, 2, 225, 119, 155, 65, 99, 182, 111, 110, 145},
@@ -415,14 +416,14 @@ func TestVerifyInvalidRange(t *testing.T) {
// Verify range starts outside data range.
{
name: "StartOutsideRange",
- verifyStart: usermem.PageSize,
+ verifyStart: hostarch.PageSize,
verifySize: 1,
},
// Verify range ends outside data range.
{
name: "EndOutsideRange",
verifyStart: 0,
- verifySize: 2 * usermem.PageSize,
+ verifySize: 2 * hostarch.PageSize,
},
// Verify range with negative size.
{
@@ -434,7 +435,7 @@ func TestVerifyInvalidRange(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var buf bytes.Buffer
- _, params := prepareVerify(t, usermem.PageSize /* dataSize */, defaultHashAlgorithm, false /* dataAndTreeInSameFile */, false /* isSymlink */, tc.verifyStart, tc.verifySize, &buf)
+ _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, false /* dataAndTreeInSameFile */, false /* isSymlink */, tc.verifyStart, tc.verifySize, &buf)
if _, err := Verify(&params); errors.Is(err, nil) {
t.Errorf("Verification succeeded when expected to fail")
}
@@ -467,7 +468,7 @@ func TestVerifyUnmodifiedMetadata(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var buf bytes.Buffer
- _, params := prepareVerify(t, usermem.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, tc.isSymlink, 0 /* verifyStart */, 0 /* verifySize */, &buf)
+ _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, tc.isSymlink, 0 /* verifyStart */, 0 /* verifySize */, &buf)
if tc.isSymlink {
params.SymlinkTarget = defaultSymlinkPath
}
@@ -495,7 +496,7 @@ func TestVerifyModifiedName(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var buf bytes.Buffer
- _, params := prepareVerify(t, usermem.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf)
+ _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf)
params.Name += "abc"
if _, err := Verify(&params); errors.Is(err, nil) {
t.Errorf("Verification succeeded when expected to fail")
@@ -521,7 +522,7 @@ func TestVerifyModifiedSize(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var buf bytes.Buffer
- _, params := prepareVerify(t, usermem.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf)
+ _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf)
params.Size--
if _, err := Verify(&params); errors.Is(err, nil) {
t.Errorf("Verification succeeded when expected to fail")
@@ -547,7 +548,7 @@ func TestVerifyModifiedMode(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var buf bytes.Buffer
- _, params := prepareVerify(t, usermem.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf)
+ _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf)
params.Mode++
if _, err := Verify(&params); errors.Is(err, nil) {
t.Errorf("Verification succeeded when expected to fail")
@@ -573,7 +574,7 @@ func TestVerifyModifiedUID(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var buf bytes.Buffer
- _, params := prepareVerify(t, usermem.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf)
+ _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf)
params.UID++
if _, err := Verify(&params); errors.Is(err, nil) {
t.Errorf("Verification succeeded when expected to fail")
@@ -599,7 +600,7 @@ func TestVerifyModifiedGID(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var buf bytes.Buffer
- _, params := prepareVerify(t, usermem.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf)
+ _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf)
params.GID++
if _, err := Verify(&params); errors.Is(err, nil) {
t.Errorf("Verification succeeded when expected to fail")
@@ -625,7 +626,7 @@ func TestVerifyModifiedChildren(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var buf bytes.Buffer
- _, params := prepareVerify(t, usermem.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf)
+ _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf)
params.Children["abc"] = struct{}{}
if _, err := Verify(&params); errors.Is(err, nil) {
t.Errorf("Verification succeeded when expected to fail")
@@ -636,7 +637,7 @@ func TestVerifyModifiedChildren(t *testing.T) {
func TestVerifyModifiedSymlink(t *testing.T) {
var buf bytes.Buffer
- _, params := prepareVerify(t, usermem.PageSize /* dataSize */, defaultHashAlgorithm, false /* dataAndTreeInSameFile */, true /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf)
+ _, params := prepareVerify(t, hostarch.PageSize /* dataSize */, defaultHashAlgorithm, false /* dataAndTreeInSameFile */, true /* isSymlink */, 0 /* verifyStart */, 0 /* verifySize */, &buf)
params.SymlinkTarget = "merkle_modified_test_link"
if _, err := Verify(&params); err == nil {
t.Errorf("Verification succeeded when expected to fail")
@@ -652,30 +653,30 @@ func TestModifyOutsideVerifyRange(t *testing.T) {
}{
{
name: "BeforeRangeSeparateFile",
- modifyByte: 4*usermem.PageSize - 1,
+ modifyByte: 4*hostarch.PageSize - 1,
dataAndTreeInSameFile: false,
},
{
name: "BeforeRangeSameFile",
- modifyByte: 4*usermem.PageSize - 1,
+ modifyByte: 4*hostarch.PageSize - 1,
dataAndTreeInSameFile: true,
},
{
name: "AfterRangeSeparateFile",
- modifyByte: 5 * usermem.PageSize,
+ modifyByte: 5 * hostarch.PageSize,
dataAndTreeInSameFile: false,
},
{
name: "AfterRangeSameFile",
- modifyByte: 5 * usermem.PageSize,
+ modifyByte: 5 * hostarch.PageSize,
dataAndTreeInSameFile: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- dataSize := int64(8 * usermem.PageSize)
- verifyStart := int64(4 * usermem.PageSize)
- verifySize := int64(usermem.PageSize)
+ dataSize := int64(8 * hostarch.PageSize)
+ verifyStart := int64(4 * hostarch.PageSize)
+ verifySize := int64(hostarch.PageSize)
var buf bytes.Buffer
// Modified byte is outside verify range. Verify should succeed.
data, params := prepareVerify(t, dataSize, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, verifyStart, verifySize, &buf)
@@ -712,16 +713,16 @@ func TestModifyInsideVerifyRange(t *testing.T) {
// to fail.
{
name: "BlockAlignedRangeSeparateFile",
- verifyStart: 4 * usermem.PageSize,
- verifySize: usermem.PageSize,
- modifyByte: 4 * usermem.PageSize,
+ verifyStart: 4 * hostarch.PageSize,
+ verifySize: hostarch.PageSize,
+ modifyByte: 4 * hostarch.PageSize,
dataAndTreeInSameFile: false,
},
{
name: "BlockAlignedRangeSameFile",
- verifyStart: 4 * usermem.PageSize,
- verifySize: usermem.PageSize,
- modifyByte: 4 * usermem.PageSize,
+ verifyStart: 4 * hostarch.PageSize,
+ verifySize: hostarch.PageSize,
+ modifyByte: 4 * hostarch.PageSize,
dataAndTreeInSameFile: true,
},
// The tests below use a non-block-aligned verify range.
@@ -729,48 +730,48 @@ func TestModifyInsideVerifyRange(t *testing.T) {
// verify to fail.
{
name: "ModifyStartSeparateFile",
- verifyStart: 4*usermem.PageSize + 123,
- verifySize: 2 * usermem.PageSize,
- modifyByte: 4*usermem.PageSize + 123,
+ verifyStart: 4*hostarch.PageSize + 123,
+ verifySize: 2 * hostarch.PageSize,
+ modifyByte: 4*hostarch.PageSize + 123,
dataAndTreeInSameFile: false,
},
{
name: "ModifyStartSameFile",
- verifyStart: 4*usermem.PageSize + 123,
- verifySize: 2 * usermem.PageSize,
- modifyByte: 4*usermem.PageSize + 123,
+ verifyStart: 4*hostarch.PageSize + 123,
+ verifySize: 2 * hostarch.PageSize,
+ modifyByte: 4*hostarch.PageSize + 123,
dataAndTreeInSameFile: true,
},
// Modifying a byte at the end of verify range should cause
// verify to fail.
{
name: "ModifyEndSeparateFile",
- verifyStart: 4*usermem.PageSize + 123,
- verifySize: 2 * usermem.PageSize,
- modifyByte: 6*usermem.PageSize + 123,
+ verifyStart: 4*hostarch.PageSize + 123,
+ verifySize: 2 * hostarch.PageSize,
+ modifyByte: 6*hostarch.PageSize + 123,
dataAndTreeInSameFile: false,
},
{
name: "ModifyEndSameFile",
- verifyStart: 4*usermem.PageSize + 123,
- verifySize: 2 * usermem.PageSize,
- modifyByte: 6*usermem.PageSize + 123,
+ verifyStart: 4*hostarch.PageSize + 123,
+ verifySize: 2 * hostarch.PageSize,
+ modifyByte: 6*hostarch.PageSize + 123,
dataAndTreeInSameFile: true,
},
// Modifying a byte in the middle verified block should cause
// verify to fail.
{
name: "ModifyMiddleSeparateFile",
- verifyStart: 4*usermem.PageSize + 123,
- verifySize: 2 * usermem.PageSize,
- modifyByte: 5*usermem.PageSize + 123,
+ verifyStart: 4*hostarch.PageSize + 123,
+ verifySize: 2 * hostarch.PageSize,
+ modifyByte: 5*hostarch.PageSize + 123,
dataAndTreeInSameFile: false,
},
{
name: "ModifyMiddleSameFile",
- verifyStart: 4*usermem.PageSize + 123,
- verifySize: 2 * usermem.PageSize,
- modifyByte: 5*usermem.PageSize + 123,
+ verifyStart: 4*hostarch.PageSize + 123,
+ verifySize: 2 * hostarch.PageSize,
+ modifyByte: 5*hostarch.PageSize + 123,
dataAndTreeInSameFile: true,
},
// Modifying a byte in the first block in the verified range
@@ -778,16 +779,16 @@ func TestModifyInsideVerifyRange(t *testing.T) {
// out of verify range.
{
name: "ModifyFirstBlockSeparateFile",
- verifyStart: 4*usermem.PageSize + 123,
- verifySize: 2 * usermem.PageSize,
- modifyByte: 4*usermem.PageSize + 122,
+ verifyStart: 4*hostarch.PageSize + 123,
+ verifySize: 2 * hostarch.PageSize,
+ modifyByte: 4*hostarch.PageSize + 122,
dataAndTreeInSameFile: false,
},
{
name: "ModifyFirstBlockSameFile",
- verifyStart: 4*usermem.PageSize + 123,
- verifySize: 2 * usermem.PageSize,
- modifyByte: 4*usermem.PageSize + 122,
+ verifyStart: 4*hostarch.PageSize + 123,
+ verifySize: 2 * hostarch.PageSize,
+ modifyByte: 4*hostarch.PageSize + 122,
dataAndTreeInSameFile: true,
},
// Modifying a byte in the last block in the verified range
@@ -795,22 +796,22 @@ func TestModifyInsideVerifyRange(t *testing.T) {
// out of verify range.
{
name: "ModifyLastBlockSeparateFile",
- verifyStart: 4*usermem.PageSize + 123,
- verifySize: 2 * usermem.PageSize,
- modifyByte: 6*usermem.PageSize + 124,
+ verifyStart: 4*hostarch.PageSize + 123,
+ verifySize: 2 * hostarch.PageSize,
+ modifyByte: 6*hostarch.PageSize + 124,
dataAndTreeInSameFile: false,
},
{
name: "ModifyLastBlockSameFile",
- verifyStart: 4*usermem.PageSize + 123,
- verifySize: 2 * usermem.PageSize,
- modifyByte: 6*usermem.PageSize + 124,
+ verifyStart: 4*hostarch.PageSize + 123,
+ verifySize: 2 * hostarch.PageSize,
+ modifyByte: 6*hostarch.PageSize + 124,
dataAndTreeInSameFile: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- dataSize := int64(8 * usermem.PageSize)
+ dataSize := int64(8 * hostarch.PageSize)
var buf bytes.Buffer
data, params := prepareVerify(t, dataSize, defaultHashAlgorithm, tc.dataAndTreeInSameFile, false /* isSymlink */, tc.verifyStart, tc.verifySize, &buf)
// Flip a bit in data and checks Verify results.
@@ -854,7 +855,7 @@ func TestVerifyRandom(t *testing.T) {
rand.Seed(time.Now().UnixNano())
// Use a random dataSize. Minimum size 2 so that we can pick a random
// portion from it.
- dataSize := rand.Int63n(200*usermem.PageSize) + 2
+ dataSize := rand.Int63n(200*hostarch.PageSize) + 2
// Pick a random portion of data.
start := rand.Int63n(dataSize - 1)
diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go
index c9f9357de..2a2f0d611 100644
--- a/pkg/metric/metric.go
+++ b/pkg/metric/metric.go
@@ -38,7 +38,7 @@ var (
)
// Uint64Metric encapsulates a uint64 that represents some kind of metric to be
-// monitored.
+// monitored. We currently support metrics with at most one field.
//
// Metrics are not saved across save/restore and thus reset to zero on restore.
//
@@ -46,6 +46,16 @@ var (
type Uint64Metric struct {
// value is the actual value of the metric. It must be accessed atomically.
value uint64
+
+ // numFields is the number of metric fields. It is immutable once
+ // initialized.
+ numFields int
+
+ // mu protects the below fields.
+ mu sync.RWMutex `state:"nosave"`
+
+ // fields is the map of fields in the metric.
+ fields map[string]uint64
}
var (
@@ -97,8 +107,19 @@ type customUint64Metric struct {
// metadata describes the metric. It is immutable.
metadata *pb.MetricMetadata
- // value returns the current value of the metric.
- value func() uint64
+ // value returns the current value of the metric for the given set of
+ // fields. It takes a variadic number of field values as argument.
+ value func(fieldValues ...string) uint64
+}
+
+// Field contains the field name and allowed values for the metric which is
+// used in registration of the metric.
+type Field struct {
+ // name is the metric field name.
+ name string
+
+ // allowedValues is the list of allowed values for the field.
+ allowedValues []string
}
// RegisterCustomUint64Metric registers a metric with the given name.
@@ -109,7 +130,8 @@ type customUint64Metric struct {
// Preconditions:
// * name must be globally unique.
// * Initialize/Disable have not been called.
-func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.MetricMetadata_Units, description string, value func() uint64) error {
+// * value is expected to accept exactly len(fields) arguments.
+func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.MetricMetadata_Units, description string, value func(...string) uint64, fields ...Field) error {
if initialized {
return ErrInitializationDone
}
@@ -129,13 +151,25 @@ func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.Met
},
value: value,
}
+
+ // Metrics can exist without fields.
+ if len(fields) > 1 {
+ panic("Sentry metrics support at most one field")
+ }
+
+ for _, field := range fields {
+ allMetrics.m[name].metadata.Fields = append(allMetrics.m[name].metadata.Fields, &pb.MetricMetadata_Field{
+ FieldName: field.name,
+ AllowedValues: field.allowedValues,
+ })
+ }
return nil
}
-// MustRegisterCustomUint64Metric calls RegisterCustomUint64Metric and panics
-// if it returns an error.
-func MustRegisterCustomUint64Metric(name string, cumulative, sync bool, description string, value func() uint64) {
- if err := RegisterCustomUint64Metric(name, cumulative, sync, pb.MetricMetadata_UNITS_NONE, description, value); err != nil {
+// MustRegisterCustomUint64Metric calls RegisterCustomUint64Metric for metrics
+// without fields and panics if it returns an error.
+func MustRegisterCustomUint64Metric(name string, cumulative, sync bool, description string, value func(...string) uint64, fields ...Field) {
+ if err := RegisterCustomUint64Metric(name, cumulative, sync, pb.MetricMetadata_UNITS_NONE, description, value, fields...); err != nil {
panic(fmt.Sprintf("Unable to register metric %q: %v", name, err))
}
}
@@ -144,15 +178,24 @@ func MustRegisterCustomUint64Metric(name string, cumulative, sync bool, descript
// name.
//
// Metrics must be statically defined (i.e., at init).
-func NewUint64Metric(name string, sync bool, units pb.MetricMetadata_Units, description string) (*Uint64Metric, error) {
- var m Uint64Metric
- return &m, RegisterCustomUint64Metric(name, true /* cumulative */, sync, units, description, m.Value)
+func NewUint64Metric(name string, sync bool, units pb.MetricMetadata_Units, description string, fields ...Field) (*Uint64Metric, error) {
+ m := Uint64Metric{
+ numFields: len(fields),
+ }
+
+ if m.numFields == 1 {
+ m.fields = make(map[string]uint64)
+ for _, fieldValue := range fields[0].allowedValues {
+ m.fields[fieldValue] = 0
+ }
+ }
+ return &m, RegisterCustomUint64Metric(name, true /* cumulative */, sync, units, description, m.Value, fields...)
}
// MustCreateNewUint64Metric calls NewUint64Metric and panics if it returns an
// error.
-func MustCreateNewUint64Metric(name string, sync bool, description string) *Uint64Metric {
- m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NONE, description)
+func MustCreateNewUint64Metric(name string, sync bool, description string, fields ...Field) *Uint64Metric {
+ m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NONE, description, fields...)
if err != nil {
panic(fmt.Sprintf("Unable to create metric %q: %v", name, err))
}
@@ -169,19 +212,56 @@ func MustCreateNewUint64NanosecondsMetric(name string, sync bool, description st
return m
}
-// Value returns the current value of the metric.
-func (m *Uint64Metric) Value() uint64 {
- return atomic.LoadUint64(&m.value)
+// Value returns the current value of the metric for the given set of fields.
+func (m *Uint64Metric) Value(fieldValues ...string) uint64 {
+ if m.numFields != len(fieldValues) {
+ panic(fmt.Sprintf("Number of fieldValues %d is not equal to the number of metric fields %d", len(fieldValues), m.numFields))
+ }
+
+ switch m.numFields {
+ case 0:
+ return atomic.LoadUint64(&m.value)
+ case 1:
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+
+ fieldValue := fieldValues[0]
+ if _, ok := m.fields[fieldValue]; !ok {
+ panic(fmt.Sprintf("Metric does not allow to have field value %s", fieldValue))
+ }
+ return m.fields[fieldValue]
+ default:
+ panic("Sentry metrics do not support more than one field")
+ }
}
-// Increment increments the metric by 1.
-func (m *Uint64Metric) Increment() {
- atomic.AddUint64(&m.value, 1)
+// Increment increments the metric field by 1.
+func (m *Uint64Metric) Increment(fieldValues ...string) {
+ m.IncrementBy(1, fieldValues...)
}
// IncrementBy increments the metric by v.
-func (m *Uint64Metric) IncrementBy(v uint64) {
- atomic.AddUint64(&m.value, v)
+func (m *Uint64Metric) IncrementBy(v uint64, fieldValues ...string) {
+ if m.numFields != len(fieldValues) {
+ panic(fmt.Sprintf("Number of fieldValues %d is not equal to the number of metric fields %d", len(fieldValues), m.numFields))
+ }
+
+ switch m.numFields {
+ case 0:
+ atomic.AddUint64(&m.value, v)
+ return
+ case 1:
+ fieldValue := fieldValues[0]
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ if _, ok := m.fields[fieldValue]; !ok {
+ panic(fmt.Sprintf("Metric does not allow to have field value %s", fieldValue))
+ }
+ m.fields[fieldValue] += v
+ default:
+ panic("Sentry metrics do not support more than one field")
+ }
}
// metricSet holds named metrics.
@@ -199,14 +279,30 @@ func makeMetricSet() metricSet {
// Values returns a snapshot of all values in m.
func (m *metricSet) Values() metricValues {
vals := make(metricValues)
+
for k, v := range m.m {
- vals[k] = v.value()
+ fields := v.metadata.GetFields()
+ switch len(fields) {
+ case 0:
+ vals[k] = v.value()
+ case 1:
+ values := fields[0].GetAllowedValues()
+ fieldsMap := make(map[string]uint64)
+ for _, fieldValue := range values {
+ fieldsMap[fieldValue] = v.value(fieldValue)
+ }
+ vals[k] = fieldsMap
+ default:
+ panic(fmt.Sprintf("Unsupported number of metric fields: %d", len(fields)))
+ }
}
return vals
}
-// metricValues contains a copy of the values of all metrics.
-type metricValues map[string]uint64
+// metricValues contains a copy of the values of all metrics. It is a map
+// with key as metric name and value can be either uint64 or map[string]uint64
+// to support metrics with one field.
+type metricValues map[string]interface{}
var (
// emitMu protects metricsAtLastEmit and ensures that all emitted
@@ -233,14 +329,37 @@ func EmitMetricUpdate() {
snapshot := allMetrics.Values()
m := pb.MetricUpdate{}
+ // On the first call metricsAtLastEmit will be empty. Include all
+ // metrics then.
for k, v := range snapshot {
- // On the first call metricsAtLastEmit will be empty. Include
- // all metrics then.
- if prev, ok := metricsAtLastEmit[k]; !ok || prev != v {
+ prev, ok := metricsAtLastEmit[k]
+ switch t := v.(type) {
+ case uint64:
+ // Metric exists and value did not change.
+ if ok && prev.(uint64) == t {
+ continue
+ }
+
m.Metrics = append(m.Metrics, &pb.MetricValue{
Name: k,
- Value: &pb.MetricValue_Uint64Value{v},
+ Value: &pb.MetricValue_Uint64Value{t},
})
+ case map[string]uint64:
+ for fieldValue, metricValue := range t {
+ // Emit data on the first call only if the field
+ // value has been incremented. For all other
+ // calls, emit data if the field value has been
+ // changed from the previous emit.
+ if (!ok && metricValue == 0) || (ok && prev.(map[string]uint64)[fieldValue] == metricValue) {
+ continue
+ }
+
+ m.Metrics = append(m.Metrics, &pb.MetricValue{
+ Name: k,
+ FieldValues: []string{fieldValue},
+ Value: &pb.MetricValue_Uint64Value{metricValue},
+ })
+ }
}
}
diff --git a/pkg/metric/metric.proto b/pkg/metric/metric.proto
index 3cc89047d..53c8b4b50 100644
--- a/pkg/metric/metric.proto
+++ b/pkg/metric/metric.proto
@@ -48,6 +48,15 @@ message MetricMetadata {
// units is the units of the metric value.
Units units = 6;
+
+ message Field {
+ string field_name = 1;
+ repeated string allowed_values = 2;
+ }
+
+ // fields contains the metric fields. Currently a metric can have at most
+ // one field.
+ repeated Field fields = 7;
}
// MetricRegistration contains the metadata for all metrics that will be in
@@ -66,6 +75,8 @@ message MetricValue {
oneof value {
uint64 uint64_value = 2;
}
+
+ repeated string field_values = 4;
}
// MetricUpdate contains new values for multiple distinct metrics.
diff --git a/pkg/metric/metric_test.go b/pkg/metric/metric_test.go
index aefd0ea5c..c71dfd460 100644
--- a/pkg/metric/metric_test.go
+++ b/pkg/metric/metric_test.go
@@ -59,8 +59,9 @@ func reset() {
}
const (
- fooDescription = "Foo!"
- barDescription = "Bar Baz"
+ fooDescription = "Foo!"
+ barDescription = "Bar Baz"
+ counterDescription = "Counter"
)
func TestInitialize(t *testing.T) {
@@ -95,7 +96,7 @@ func TestInitialize(t *testing.T) {
foundBar := false
for _, m := range mr.Metrics {
if m.Type != pb.MetricMetadata_TYPE_UINT64 {
- t.Errorf("Metadata %+v Type got %v want %v", m, m.Type, pb.MetricMetadata_TYPE_UINT64)
+ t.Errorf("Metadata %+v Type got %v want pb.MetricMetadata_TYPE_UINT64", m, m.Type)
}
if !m.Cumulative {
t.Errorf("Metadata %+v Cumulative got false want true", m)
@@ -256,3 +257,88 @@ func TestEmitMetricUpdate(t *testing.T) {
t.Errorf("%v: Value got %v want 1", m, uv.Uint64Value)
}
}
+
+func TestEmitMetricUpdateWithFields(t *testing.T) {
+ defer reset()
+
+ field := Field{
+ name: "weirdness_type",
+ allowedValues: []string{"weird1", "weird2"}}
+
+ counter, err := NewUint64Metric("/weirdness", false, pb.MetricMetadata_UNITS_NONE, counterDescription, field)
+ if err != nil {
+ t.Fatalf("NewUint64Metric got err %v want nil", err)
+ }
+
+ Initialize()
+
+ // Don't care about the registration metrics.
+ emitter.Reset()
+ EmitMetricUpdate()
+
+ // For metrics with fields, we do not emit data unless the value is
+ // incremented.
+ if len(emitter) != 0 {
+ t.Fatalf("EmitMetricUpdate emitted %d events want 0", len(emitter))
+ }
+
+ counter.IncrementBy(4, "weird1")
+ counter.Increment("weird2")
+
+ emitter.Reset()
+ EmitMetricUpdate()
+
+ if len(emitter) != 1 {
+ t.Fatalf("EmitMetricUpdate emitted %d events want 1", len(emitter))
+ }
+
+ update, ok := emitter[0].(*pb.MetricUpdate)
+ if !ok {
+ t.Fatalf("emitter %v got %T want pb.MetricUpdate", emitter[0], emitter[0])
+ }
+
+ if len(update.Metrics) != 2 {
+ t.Errorf("MetricUpdate got %d metrics want 2", len(update.Metrics))
+ }
+
+ foundWeird1 := false
+ foundWeird2 := false
+ for i := 0; i < len(update.Metrics); i++ {
+ m := update.Metrics[i]
+
+ if m.Name != "/weirdness" {
+ t.Errorf("Metric %+v name got %q want '/weirdness'", m, m.Name)
+ }
+ if len(m.FieldValues) != 1 {
+ t.Errorf("MetricUpdate got %d fields want 1", len(m.FieldValues))
+ }
+
+ switch m.FieldValues[0] {
+ case "weird1":
+ uv, ok := m.Value.(*pb.MetricValue_Uint64Value)
+ if !ok {
+ t.Errorf("%+v: value %v got %T want pb.MetricValue_Uint64Value", m, m.Value, m.Value)
+ }
+ if uv.Uint64Value != 4 {
+ t.Errorf("%v: Value got %v want 4", m, uv.Uint64Value)
+ }
+ foundWeird1 = true
+ case "weird2":
+ uv, ok := m.Value.(*pb.MetricValue_Uint64Value)
+ if !ok {
+ t.Errorf("%+v: value %v got %T want pb.MetricValue_Uint64Value", m, m.Value, m.Value)
+ }
+ if uv.Uint64Value != 1 {
+ t.Errorf("%v: Value got %v want 1", m, uv.Uint64Value)
+ }
+ foundWeird2 = true
+ }
+ }
+
+ if !foundWeird1 {
+ t.Errorf("Field value weird1 not found: %+v", emitter)
+ }
+ if !foundWeird2 {
+ t.Errorf("Field value weird2 not found: %+v", emitter)
+ }
+}
diff --git a/pkg/refsvfs2/refs_map.go b/pkg/refsvfs2/refs_map.go
index 0472eca3f..fb8984dd6 100644
--- a/pkg/refsvfs2/refs_map.go
+++ b/pkg/refsvfs2/refs_map.go
@@ -112,20 +112,27 @@ func logEvent(obj CheckedObject, msg string) {
log.Infof("[%s %p] %s:\n%s", obj.RefType(), obj, msg, refs_vfs1.FormatStack(refs_vfs1.RecordStack()))
}
+// checkOnce makes sure that leak checking is only done once. DoLeakCheck is
+// called from multiple places (which may overlap) to cover different sandbox
+// exit scenarios.
+var checkOnce sync.Once
+
// DoLeakCheck iterates through the live object map and logs a message for each
// object. It is called once no reference-counted objects should be reachable
// anymore, at which point anything left in the map is considered a leak.
func DoLeakCheck() {
if leakCheckEnabled() {
- liveObjectsMu.Lock()
- defer liveObjectsMu.Unlock()
- leaked := len(liveObjects)
- if leaked > 0 {
- msg := fmt.Sprintf("Leak checking detected %d leaked objects:\n", leaked)
- for obj := range liveObjects {
- msg += obj.LeakMessage() + "\n"
+ checkOnce.Do(func() {
+ liveObjectsMu.Lock()
+ defer liveObjectsMu.Unlock()
+ leaked := len(liveObjects)
+ if leaked > 0 {
+ msg := fmt.Sprintf("Leak checking detected %d leaked objects:\n", leaked)
+ for obj := range liveObjects {
+ msg += obj.LeakMessage() + "\n"
+ }
+ log.Warningf(msg)
}
- log.Warningf(msg)
- }
+ })
}
}
diff --git a/pkg/ring0/BUILD b/pkg/ring0/BUILD
index 885958456..fda6ba601 100644
--- a/pkg/ring0/BUILD
+++ b/pkg/ring0/BUILD
@@ -77,10 +77,10 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/cpuid",
+ "//pkg/hostarch",
"//pkg/ring0/pagetables",
"//pkg/safecopy",
"//pkg/sentry/arch",
"//pkg/sentry/arch/fpu",
- "//pkg/usermem",
],
)
diff --git a/pkg/ring0/defs_amd64.go b/pkg/ring0/defs_amd64.go
index ceddf719d..76776c65c 100644
--- a/pkg/ring0/defs_amd64.go
+++ b/pkg/ring0/defs_amd64.go
@@ -17,7 +17,7 @@
package ring0
import (
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
var (
@@ -25,7 +25,7 @@ var (
UserspaceSize = uintptr(1) << (VirtualAddressBits() - 1)
// MaximumUserAddress is the largest possible user address.
- MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1)
+ MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(hostarch.PageSize-1)
// KernelStartAddress is the starting kernel address.
KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1)
diff --git a/pkg/ring0/defs_arm64.go b/pkg/ring0/defs_arm64.go
index c372b02bb..0125690d2 100644
--- a/pkg/ring0/defs_arm64.go
+++ b/pkg/ring0/defs_arm64.go
@@ -17,7 +17,7 @@
package ring0
import (
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
var (
@@ -25,7 +25,7 @@ var (
UserspaceSize = uintptr(1) << (VirtualAddressBits())
// MaximumUserAddress is the largest possible user address.
- MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1)
+ MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(hostarch.PageSize-1)
// KernelStartAddress is the starting kernel address.
KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1)
diff --git a/pkg/ring0/gen_offsets/BUILD b/pkg/ring0/gen_offsets/BUILD
index f421e1687..9ea8f9a4f 100644
--- a/pkg/ring0/gen_offsets/BUILD
+++ b/pkg/ring0/gen_offsets/BUILD
@@ -33,9 +33,9 @@ go_binary(
],
deps = [
"//pkg/cpuid",
+ "//pkg/hostarch",
"//pkg/ring0/pagetables",
"//pkg/sentry/arch",
"//pkg/sentry/arch/fpu",
- "//pkg/usermem",
],
)
diff --git a/pkg/ring0/kernel_amd64.go b/pkg/ring0/kernel_amd64.go
index 33c259757..41dfd0bf9 100644
--- a/pkg/ring0/kernel_amd64.go
+++ b/pkg/ring0/kernel_amd64.go
@@ -20,7 +20,7 @@ import (
"encoding/binary"
"reflect"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// init initializes architecture-specific state.
@@ -34,7 +34,7 @@ func (k *Kernel) init(maxCPUs int) {
entries = make([]kernelEntry, maxCPUs+padding-1)
totalSize := entrySize * uintptr(maxCPUs+padding-1)
addr := reflect.ValueOf(&entries[0]).Pointer()
- if addr&(usermem.PageSize-1) == 0 && totalSize >= usermem.PageSize {
+ if addr&(hostarch.PageSize-1) == 0 && totalSize >= hostarch.PageSize {
// The runtime forces power-of-2 alignment for allocations, and we are therefore
// safe once the first address is aligned and the chunk is at least a full page.
break
@@ -44,10 +44,10 @@ func (k *Kernel) init(maxCPUs int) {
k.cpuEntries = entries
k.globalIDT = &idt64{}
- if reflect.TypeOf(idt64{}).Size() != usermem.PageSize {
+ if reflect.TypeOf(idt64{}).Size() != hostarch.PageSize {
panic("Size of globalIDT should be PageSize")
}
- if reflect.ValueOf(k.globalIDT).Pointer()&(usermem.PageSize-1) != 0 {
+ if reflect.ValueOf(k.globalIDT).Pointer()&(hostarch.PageSize-1) != 0 {
panic("Allocated globalIDT should be page aligned")
}
@@ -71,13 +71,13 @@ func (k *Kernel) EntryRegions() map[uintptr]uintptr {
addr := reflect.ValueOf(&k.cpuEntries[0]).Pointer()
size := reflect.TypeOf(kernelEntry{}).Size() * uintptr(len(k.cpuEntries))
- end, _ := usermem.Addr(addr + size).RoundUp()
- regions[uintptr(usermem.Addr(addr).RoundDown())] = uintptr(end)
+ end, _ := hostarch.Addr(addr + size).RoundUp()
+ regions[uintptr(hostarch.Addr(addr).RoundDown())] = uintptr(end)
addr = reflect.ValueOf(k.globalIDT).Pointer()
size = reflect.TypeOf(idt64{}).Size()
- end, _ = usermem.Addr(addr + size).RoundUp()
- regions[uintptr(usermem.Addr(addr).RoundDown())] = uintptr(end)
+ end, _ = hostarch.Addr(addr + size).RoundUp()
+ regions[uintptr(hostarch.Addr(addr).RoundDown())] = uintptr(end)
return regions
}
@@ -250,6 +250,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
}
SaveFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy out floating point.
WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS.
+ RestoreKernelFPState() // escapes: no. Restore kernel MXCSR.
return
}
@@ -321,3 +322,21 @@ func SetCPUIDFaulting(on bool) bool {
func ReadCR2() uintptr {
return readCR2()
}
+
+// kernelMXCSR is the value of the mxcsr register in the Sentry.
+//
+// The MXCSR control configuration is initialized once and never changed. Look
+// at src/cmd/compile/abi-internal.md in the golang sources for more details.
+var kernelMXCSR uint32
+
+// RestoreKernelFPState restores the Sentry floating point state.
+//
+//go:nosplit
+func RestoreKernelFPState() {
+ // Restore the MXCSR control configuration.
+ ldmxcsr(&kernelMXCSR)
+}
+
+func init() {
+ stmxcsr(&kernelMXCSR)
+}
diff --git a/pkg/ring0/kernel_arm64.go b/pkg/ring0/kernel_arm64.go
index 7975e5f92..21db910a2 100644
--- a/pkg/ring0/kernel_arm64.go
+++ b/pkg/ring0/kernel_arm64.go
@@ -65,7 +65,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
storeEl0Fpstate(switchOpts.FloatingPointState.BytePointer())
if switchOpts.Flush {
- FlushTlbByASID(uintptr(switchOpts.UserASID))
+ LocalFlushTlbByASID(uintptr(switchOpts.UserASID))
}
regs := switchOpts.Registers
@@ -89,3 +89,9 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
return
}
+
+// RestoreKernelFPState restores the Sentry floating point state.
+//
+//go:nosplit
+func RestoreKernelFPState() {
+}
diff --git a/pkg/ring0/lib_amd64.go b/pkg/ring0/lib_amd64.go
index 0ec5c3bc5..3e6bb9663 100644
--- a/pkg/ring0/lib_amd64.go
+++ b/pkg/ring0/lib_amd64.go
@@ -61,6 +61,12 @@ func wrgsbase(addr uintptr)
// wrgsmsr writes to the GS_BASE MSR.
func wrgsmsr(addr uintptr)
+// stmxcsr reads the MXCSR control and status register.
+func stmxcsr(addr *uint32)
+
+// ldmxcsr writes to the MXCSR control and status register.
+func ldmxcsr(addr *uint32)
+
// readCR2 reads the current CR2 value.
func readCR2() uintptr
diff --git a/pkg/ring0/lib_amd64.s b/pkg/ring0/lib_amd64.s
index 2fe83568a..70a43e79e 100644
--- a/pkg/ring0/lib_amd64.s
+++ b/pkg/ring0/lib_amd64.s
@@ -198,3 +198,15 @@ TEXT ·rdmsr(SB),NOSPLIT,$0-16
MOVL AX, ret+8(FP)
MOVL DX, ret+12(FP)
RET
+
+// stmxcsr reads the MXCSR control and status register.
+TEXT ·stmxcsr(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), SI
+ STMXCSR (SI)
+ RET
+
+// ldmxcsr writes to the MXCSR control and status register.
+TEXT ·ldmxcsr(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), SI
+ LDMXCSR (SI)
+ RET
diff --git a/pkg/ring0/lib_arm64.go b/pkg/ring0/lib_arm64.go
index edf24eda3..5eabd4296 100644
--- a/pkg/ring0/lib_arm64.go
+++ b/pkg/ring0/lib_arm64.go
@@ -31,6 +31,9 @@ func FlushTlbByVA(addr uintptr)
// FlushTlbByASID invalidates tlb by ASID/Inner-Shareable.
func FlushTlbByASID(asid uintptr)
+// LocalFlushTlbByASID invalidates tlb by ASID.
+func LocalFlushTlbByASID(asid uintptr)
+
// FlushTlbAll invalidates all tlb.
func FlushTlbAll()
@@ -62,9 +65,10 @@ func LoadFloatingPoint(*byte)
// SaveFloatingPoint saves floating point state.
func SaveFloatingPoint(*byte)
+// FPSIMDDisableTrap disables fpsimd.
func FPSIMDDisableTrap()
-// DisableVFP disables fpsimd.
+// FPSIMDEnableTrap enables fpsimd.
func FPSIMDEnableTrap()
// Init sets function pointers based on architectural features.
diff --git a/pkg/ring0/lib_arm64.s b/pkg/ring0/lib_arm64.s
index e39b32841..69ebaf519 100644
--- a/pkg/ring0/lib_arm64.s
+++ b/pkg/ring0/lib_arm64.s
@@ -32,6 +32,14 @@ TEXT ·FlushTlbByASID(SB),NOSPLIT,$0-8
DSB $11 // dsb(ish)
RET
+TEXT ·LocalFlushTlbByASID(SB),NOSPLIT,$0-8
+ MOVD asid+0(FP), R1
+ LSL $TLBI_ASID_SHIFT, R1, R1
+ DSB $10 // dsb(ishst)
+ WORD $0xd5088741 // tlbi aside1, x1
+ DSB $11 // dsb(ish)
+ RET
+
TEXT ·LocalFlushTlbAll(SB),NOSPLIT,$0
DSB $6 // dsb(nshst)
WORD $0xd508871f // __tlbi(vmalle1)
diff --git a/pkg/ring0/pagetables/BUILD b/pkg/ring0/pagetables/BUILD
index 65a978cbb..f855f4d42 100644
--- a/pkg/ring0/pagetables/BUILD
+++ b/pkg/ring0/pagetables/BUILD
@@ -68,8 +68,8 @@ go_library(
"//pkg/sentry/platform/kvm:__subpackages__",
],
deps = [
+ "//pkg/hostarch",
"//pkg/sync",
- "//pkg/usermem",
],
)
@@ -84,5 +84,5 @@ go_test(
":walker_check_arm64",
],
library = ":pagetables",
- deps = ["//pkg/usermem"],
+ deps = ["//pkg/hostarch"],
)
diff --git a/pkg/ring0/pagetables/allocator_unsafe.go b/pkg/ring0/pagetables/allocator_unsafe.go
index d08bfdeb3..191d0942b 100644
--- a/pkg/ring0/pagetables/allocator_unsafe.go
+++ b/pkg/ring0/pagetables/allocator_unsafe.go
@@ -17,23 +17,23 @@ package pagetables
import (
"unsafe"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// newAlignedPTEs returns a set of aligned PTEs.
func newAlignedPTEs() *PTEs {
ptes := new(PTEs)
- offset := physicalFor(ptes) & (usermem.PageSize - 1)
+ offset := physicalFor(ptes) & (hostarch.PageSize - 1)
if offset == 0 {
// Already aligned.
return ptes
}
// Need to force an aligned allocation.
- unaligned := make([]byte, (2*usermem.PageSize)-1)
- offset = uintptr(unsafe.Pointer(&unaligned[0])) & (usermem.PageSize - 1)
+ unaligned := make([]byte, (2*hostarch.PageSize)-1)
+ offset = uintptr(unsafe.Pointer(&unaligned[0])) & (hostarch.PageSize - 1)
if offset != 0 {
- offset = usermem.PageSize - offset
+ offset = hostarch.PageSize - offset
}
return (*PTEs)(unsafe.Pointer(&unaligned[offset]))
}
diff --git a/pkg/ring0/pagetables/pagetables.go b/pkg/ring0/pagetables/pagetables.go
index 8c0a6aa82..3f17fba49 100644
--- a/pkg/ring0/pagetables/pagetables.go
+++ b/pkg/ring0/pagetables/pagetables.go
@@ -21,7 +21,7 @@
package pagetables
import (
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// PageTables is a set of page tables.
@@ -142,7 +142,7 @@ func (*mapVisitor) requiresSplit() bool { return true }
//
// +checkescape:hard,stack
//go:nosplit
-func (p *PageTables) Map(addr usermem.Addr, length uintptr, opts MapOpts, physical uintptr) bool {
+func (p *PageTables) Map(addr hostarch.Addr, length uintptr, opts MapOpts, physical uintptr) bool {
if p.readOnlyShared {
panic("Should not modify read-only shared pagetables.")
}
@@ -198,7 +198,7 @@ func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) bool {
//
// +checkescape:hard,stack
//go:nosplit
-func (p *PageTables) Unmap(addr usermem.Addr, length uintptr) bool {
+func (p *PageTables) Unmap(addr hostarch.Addr, length uintptr) bool {
if p.readOnlyShared {
panic("Should not modify read-only shared pagetables.")
}
@@ -249,7 +249,7 @@ func (v *emptyVisitor) visit(start uintptr, pte *PTE, align uintptr) bool {
//
// +checkescape:hard,stack
//go:nosplit
-func (p *PageTables) IsEmpty(addr usermem.Addr, length uintptr) bool {
+func (p *PageTables) IsEmpty(addr hostarch.Addr, length uintptr) bool {
w := emptyWalker{
pageTables: p,
}
@@ -298,9 +298,9 @@ func (*lookupVisitor) requiresSplit() bool { return false }
//
// +checkescape:hard,stack
//go:nosplit
-func (p *PageTables) Lookup(addr usermem.Addr, findFirst bool) (virtual usermem.Addr, physical, size uintptr, opts MapOpts) {
- mask := uintptr(usermem.PageSize - 1)
- addr &^= usermem.Addr(mask)
+func (p *PageTables) Lookup(addr hostarch.Addr, findFirst bool) (virtual hostarch.Addr, physical, size uintptr, opts MapOpts) {
+ mask := uintptr(hostarch.PageSize - 1)
+ addr &^= hostarch.Addr(mask)
w := lookupWalker{
pageTables: p,
visitor: lookupVisitor{
@@ -308,12 +308,12 @@ func (p *PageTables) Lookup(addr usermem.Addr, findFirst bool) (virtual usermem.
findFirst: findFirst,
},
}
- end := ^usermem.Addr(0) &^ usermem.Addr(mask)
+ end := ^hostarch.Addr(0) &^ hostarch.Addr(mask)
if !findFirst {
end = addr + 1
}
w.iterateRange(uintptr(addr), uintptr(end))
- return usermem.Addr(w.visitor.target), w.visitor.physical, w.visitor.size, w.visitor.opts
+ return hostarch.Addr(w.visitor.target), w.visitor.physical, w.visitor.size, w.visitor.opts
}
// MarkReadOnlyShared marks the pagetables read-only and can be shared.
diff --git a/pkg/ring0/pagetables/pagetables_aarch64.go b/pkg/ring0/pagetables/pagetables_aarch64.go
index 163a3aea3..86eb00a4f 100644
--- a/pkg/ring0/pagetables/pagetables_aarch64.go
+++ b/pkg/ring0/pagetables/pagetables_aarch64.go
@@ -19,7 +19,7 @@ package pagetables
import (
"sync/atomic"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// archPageTables is architecture-specific data.
@@ -85,7 +85,7 @@ const (
// MapOpts are x86 options.
type MapOpts struct {
// AccessType defines permissions.
- AccessType usermem.AccessType
+ AccessType hostarch.AccessType
// Global indicates the page is globally accessible.
Global bool
@@ -120,7 +120,7 @@ func (p *PTE) Opts() MapOpts {
v := atomic.LoadUintptr((*uintptr)(p))
return MapOpts{
- AccessType: usermem.AccessType{
+ AccessType: hostarch.AccessType{
Read: true,
Write: v&readOnly == 0,
Execute: v&xn == 0,
diff --git a/pkg/ring0/pagetables/pagetables_amd64_test.go b/pkg/ring0/pagetables/pagetables_amd64_test.go
index 54e8e554f..a13c616ae 100644
--- a/pkg/ring0/pagetables/pagetables_amd64_test.go
+++ b/pkg/ring0/pagetables/pagetables_amd64_test.go
@@ -19,19 +19,19 @@ package pagetables
import (
"testing"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
func Test2MAnd4K(t *testing.T) {
pt := New(NewRuntimeAllocator())
// Map a small page and a huge page.
- pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42)
- pt.Map(0x00007f0000000000, pmdSize, MapOpts{AccessType: usermem.Read}, pmdSize*47)
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite}, pteSize*42)
+ pt.Map(0x00007f0000000000, pmdSize, MapOpts{AccessType: hostarch.Read}, pmdSize*47)
checkMappings(t, pt, []mapping{
- {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}},
- {0x00007f0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read}},
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite}},
+ {0x00007f0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: hostarch.Read}},
})
}
@@ -39,12 +39,12 @@ func Test1GAnd4K(t *testing.T) {
pt := New(NewRuntimeAllocator())
// Map a small page and a super page.
- pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42)
- pt.Map(0x00007f0000000000, pudSize, MapOpts{AccessType: usermem.Read}, pudSize*47)
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite}, pteSize*42)
+ pt.Map(0x00007f0000000000, pudSize, MapOpts{AccessType: hostarch.Read}, pudSize*47)
checkMappings(t, pt, []mapping{
- {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}},
- {0x00007f0000000000, pudSize, pudSize * 47, MapOpts{AccessType: usermem.Read}},
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite}},
+ {0x00007f0000000000, pudSize, pudSize * 47, MapOpts{AccessType: hostarch.Read}},
})
}
@@ -52,12 +52,12 @@ func TestSplit1GPage(t *testing.T) {
pt := New(NewRuntimeAllocator())
// Map a super page and knock out the middle.
- pt.Map(0x00007f0000000000, pudSize, MapOpts{AccessType: usermem.Read}, pudSize*42)
- pt.Unmap(usermem.Addr(0x00007f0000000000+pteSize), pudSize-(2*pteSize))
+ pt.Map(0x00007f0000000000, pudSize, MapOpts{AccessType: hostarch.Read}, pudSize*42)
+ pt.Unmap(hostarch.Addr(0x00007f0000000000+pteSize), pudSize-(2*pteSize))
checkMappings(t, pt, []mapping{
- {0x00007f0000000000, pteSize, pudSize * 42, MapOpts{AccessType: usermem.Read}},
- {0x00007f0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: usermem.Read}},
+ {0x00007f0000000000, pteSize, pudSize * 42, MapOpts{AccessType: hostarch.Read}},
+ {0x00007f0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: hostarch.Read}},
})
}
@@ -65,11 +65,11 @@ func TestSplit2MPage(t *testing.T) {
pt := New(NewRuntimeAllocator())
// Map a huge page and knock out the middle.
- pt.Map(0x00007f0000000000, pmdSize, MapOpts{AccessType: usermem.Read}, pmdSize*42)
- pt.Unmap(usermem.Addr(0x00007f0000000000+pteSize), pmdSize-(2*pteSize))
+ pt.Map(0x00007f0000000000, pmdSize, MapOpts{AccessType: hostarch.Read}, pmdSize*42)
+ pt.Unmap(hostarch.Addr(0x00007f0000000000+pteSize), pmdSize-(2*pteSize))
checkMappings(t, pt, []mapping{
- {0x00007f0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: usermem.Read}},
- {0x00007f0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: usermem.Read}},
+ {0x00007f0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: hostarch.Read}},
+ {0x00007f0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: hostarch.Read}},
})
}
diff --git a/pkg/ring0/pagetables/pagetables_arm64_test.go b/pkg/ring0/pagetables/pagetables_arm64_test.go
index 2f73d424f..2514b9ac5 100644
--- a/pkg/ring0/pagetables/pagetables_arm64_test.go
+++ b/pkg/ring0/pagetables/pagetables_arm64_test.go
@@ -19,24 +19,24 @@ package pagetables
import (
"testing"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
func Test2MAnd4K(t *testing.T) {
pt := New(NewRuntimeAllocator())
// Map a small page and a huge page.
- pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42)
- pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*47)
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite, User: true}, pteSize*42)
+ pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: hostarch.Read, User: true}, pmdSize*47)
- pt.Map(0xffff000000400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: false}, pteSize*42)
- pt.Map(0xffffff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: false}, pmdSize*47)
+ pt.Map(0xffff000000400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite, User: false}, pteSize*42)
+ pt.Map(0xffffff0000000000, pmdSize, MapOpts{AccessType: hostarch.Read, User: false}, pmdSize*47)
checkMappings(t, pt, []mapping{
- {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}},
- {0x0000ff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: true}},
- {0xffff000000400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: false}},
- {0xffffff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: false}},
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite, User: true}},
+ {0x0000ff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: hostarch.Read, User: true}},
+ {0xffff000000400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite, User: false}},
+ {0xffffff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: hostarch.Read, User: false}},
})
}
@@ -44,12 +44,12 @@ func Test1GAnd4K(t *testing.T) {
pt := New(NewRuntimeAllocator())
// Map a small page and a super page.
- pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42)
- pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*47)
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite, User: true}, pteSize*42)
+ pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: hostarch.Read, User: true}, pudSize*47)
checkMappings(t, pt, []mapping{
- {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}},
- {0x0000ff0000000000, pudSize, pudSize * 47, MapOpts{AccessType: usermem.Read, User: true}},
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite, User: true}},
+ {0x0000ff0000000000, pudSize, pudSize * 47, MapOpts{AccessType: hostarch.Read, User: true}},
})
}
@@ -57,12 +57,12 @@ func TestSplit1GPage(t *testing.T) {
pt := New(NewRuntimeAllocator())
// Map a super page and knock out the middle.
- pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*42)
- pt.Unmap(usermem.Addr(0x0000ff0000000000+pteSize), pudSize-(2*pteSize))
+ pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: hostarch.Read, User: true}, pudSize*42)
+ pt.Unmap(hostarch.Addr(0x0000ff0000000000+pteSize), pudSize-(2*pteSize))
checkMappings(t, pt, []mapping{
- {0x0000ff0000000000, pteSize, pudSize * 42, MapOpts{AccessType: usermem.Read, User: true}},
- {0x0000ff0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}},
+ {0x0000ff0000000000, pteSize, pudSize * 42, MapOpts{AccessType: hostarch.Read, User: true}},
+ {0x0000ff0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: hostarch.Read, User: true}},
})
}
@@ -70,11 +70,11 @@ func TestSplit2MPage(t *testing.T) {
pt := New(NewRuntimeAllocator())
// Map a huge page and knock out the middle.
- pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*42)
- pt.Unmap(usermem.Addr(0x0000ff0000000000+pteSize), pmdSize-(2*pteSize))
+ pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: hostarch.Read, User: true}, pmdSize*42)
+ pt.Unmap(hostarch.Addr(0x0000ff0000000000+pteSize), pmdSize-(2*pteSize))
checkMappings(t, pt, []mapping{
- {0x0000ff0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: usermem.Read, User: true}},
- {0x0000ff0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}},
+ {0x0000ff0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: hostarch.Read, User: true}},
+ {0x0000ff0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: hostarch.Read, User: true}},
})
}
diff --git a/pkg/ring0/pagetables/pagetables_test.go b/pkg/ring0/pagetables/pagetables_test.go
index 772f4fc5e..df93dcb6a 100644
--- a/pkg/ring0/pagetables/pagetables_test.go
+++ b/pkg/ring0/pagetables/pagetables_test.go
@@ -15,9 +15,8 @@
package pagetables
import (
+ "gvisor.dev/gvisor/pkg/hostarch"
"testing"
-
- "gvisor.dev/gvisor/pkg/usermem"
)
type mapping struct {
@@ -90,7 +89,7 @@ func TestUnmap(t *testing.T) {
pt := New(NewRuntimeAllocator())
// Map and unmap one entry.
- pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42)
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite}, pteSize*42)
pt.Unmap(0x400000, pteSize)
checkMappings(t, pt, nil)
@@ -100,10 +99,10 @@ func TestReadOnly(t *testing.T) {
pt := New(NewRuntimeAllocator())
// Map one entry.
- pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.Read}, pteSize*42)
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.Read}, pteSize*42)
checkMappings(t, pt, []mapping{
- {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.Read}},
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.Read}},
})
}
@@ -111,10 +110,10 @@ func TestReadWrite(t *testing.T) {
pt := New(NewRuntimeAllocator())
// Map one entry.
- pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42)
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite}, pteSize*42)
checkMappings(t, pt, []mapping{
- {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}},
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite}},
})
}
@@ -122,12 +121,12 @@ func TestSerialEntries(t *testing.T) {
pt := New(NewRuntimeAllocator())
// Map two sequential entries.
- pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42)
- pt.Map(0x401000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*47)
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite}, pteSize*42)
+ pt.Map(0x401000, pteSize, MapOpts{AccessType: hostarch.ReadWrite}, pteSize*47)
checkMappings(t, pt, []mapping{
- {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}},
- {0x401000, pteSize, pteSize * 47, MapOpts{AccessType: usermem.ReadWrite}},
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite}},
+ {0x401000, pteSize, pteSize * 47, MapOpts{AccessType: hostarch.ReadWrite}},
})
}
@@ -135,11 +134,11 @@ func TestSpanningEntries(t *testing.T) {
pt := New(NewRuntimeAllocator())
// Span a pgd with two pages.
- pt.Map(0x00007efffffff000, 2*pteSize, MapOpts{AccessType: usermem.Read}, pteSize*42)
+ pt.Map(0x00007efffffff000, 2*pteSize, MapOpts{AccessType: hostarch.Read}, pteSize*42)
checkMappings(t, pt, []mapping{
- {0x00007efffffff000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.Read}},
- {0x00007f0000000000, pteSize, pteSize * 43, MapOpts{AccessType: usermem.Read}},
+ {0x00007efffffff000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.Read}},
+ {0x00007f0000000000, pteSize, pteSize * 43, MapOpts{AccessType: hostarch.Read}},
})
}
@@ -147,11 +146,11 @@ func TestSparseEntries(t *testing.T) {
pt := New(NewRuntimeAllocator())
// Map two entries in different pgds.
- pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite}, pteSize*42)
- pt.Map(0x00007f0000000000, pteSize, MapOpts{AccessType: usermem.Read}, pteSize*47)
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite}, pteSize*42)
+ pt.Map(0x00007f0000000000, pteSize, MapOpts{AccessType: hostarch.Read}, pteSize*47)
checkMappings(t, pt, []mapping{
- {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite}},
- {0x00007f0000000000, pteSize, pteSize * 47, MapOpts{AccessType: usermem.Read}},
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite}},
+ {0x00007f0000000000, pteSize, pteSize * 47, MapOpts{AccessType: hostarch.Read}},
})
}
diff --git a/pkg/ring0/pagetables/pagetables_x86.go b/pkg/ring0/pagetables/pagetables_x86.go
index 32edd2f0a..e43698173 100644
--- a/pkg/ring0/pagetables/pagetables_x86.go
+++ b/pkg/ring0/pagetables/pagetables_x86.go
@@ -19,7 +19,7 @@ package pagetables
import (
"sync/atomic"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// archPageTables is architecture-specific data.
@@ -63,7 +63,7 @@ const (
// MapOpts are x86 options.
type MapOpts struct {
// AccessType defines permissions.
- AccessType usermem.AccessType
+ AccessType hostarch.AccessType
// Global indicates the page is globally accessible.
Global bool
@@ -97,7 +97,7 @@ func (p *PTE) Valid() bool {
func (p *PTE) Opts() MapOpts {
v := atomic.LoadUintptr((*uintptr)(p))
return MapOpts{
- AccessType: usermem.AccessType{
+ AccessType: hostarch.AccessType{
Read: v&present != 0,
Write: v&writable != 0,
Execute: v&executeDisable == 0,
diff --git a/pkg/safecopy/atomic_amd64.s b/pkg/safecopy/atomic_amd64.s
index a0cd78f33..290579e53 100644
--- a/pkg/safecopy/atomic_amd64.s
+++ b/pkg/safecopy/atomic_amd64.s
@@ -44,6 +44,12 @@ TEXT ·swapUint32(SB), NOSPLIT, $0-24
MOVL AX, old+16(FP)
RET
+// func addrOfSwapUint32() uintptr
+TEXT ·addrOfSwapUint32(SB), $0-8
+ MOVQ $·swapUint32(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
+
// handleSwapUint64Fault returns the value stored in DI. Control is transferred
// to it when swapUint64 below receives SIGSEGV or SIGBUS, with the signal
// number stored in DI.
@@ -74,6 +80,12 @@ TEXT ·swapUint64(SB), NOSPLIT, $0-28
MOVQ AX, old+16(FP)
RET
+// func addrOfSwapUint64() uintptr
+TEXT ·addrOfSwapUint64(SB), $0-8
+ MOVQ $·swapUint64(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
+
// handleCompareAndSwapUint32Fault returns the value stored in DI. Control is
// transferred to it when swapUint64 below receives SIGSEGV or SIGBUS, with the
// signal number stored in DI.
@@ -107,6 +119,12 @@ TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24
MOVL AX, prev+16(FP)
RET
+// func addrOfCompareAndSwapUint32() uintptr
+TEXT ·addrOfCompareAndSwapUint32(SB), $0-8
+ MOVQ $·compareAndSwapUint32(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
+
// handleLoadUint32Fault returns the value stored in DI. Control is transferred
// to it when LoadUint32 below receives SIGSEGV or SIGBUS, with the signal
// number stored in DI.
@@ -134,3 +152,9 @@ TEXT ·loadUint32(SB), NOSPLIT, $0-16
MOVL (AX), BX
MOVL BX, val+8(FP)
RET
+
+// func addrOfLoadUint32() uintptr
+TEXT ·addrOfLoadUint32(SB), $0-8
+ MOVQ $·loadUint32(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
diff --git a/pkg/safecopy/atomic_arm64.s b/pkg/safecopy/atomic_arm64.s
index d58ed71f7..55c031a3c 100644
--- a/pkg/safecopy/atomic_arm64.s
+++ b/pkg/safecopy/atomic_arm64.s
@@ -33,6 +33,12 @@ again:
MOVW R2, old+16(FP)
RET
+// func addrOfSwapUint32() uintptr
+TEXT ·addrOfSwapUint32(SB), $0-8
+ MOVD $·swapUint32(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
// handleSwapUint64Fault returns the value stored in R1. Control is transferred
// to it when swapUint64 below receives SIGSEGV or SIGBUS, with the signal
// number stored in R1.
@@ -62,6 +68,12 @@ again:
MOVD R2, old+16(FP)
RET
+// func addrOfSwapUint64() uintptr
+TEXT ·addrOfSwapUint64(SB), $0-8
+ MOVD $·swapUint64(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
// handleCompareAndSwapUint32Fault returns the value stored in R1. Control is
// transferred to it when compareAndSwapUint32 below receives SIGSEGV or SIGBUS,
// with the signal number stored in R1.
@@ -97,6 +109,12 @@ done:
MOVW R3, prev+16(FP)
RET
+// func addrOfCompareAndSwapUint32() uintptr
+TEXT ·addrOfCompareAndSwapUint32(SB), $0-8
+ MOVD $·compareAndSwapUint32(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
// handleLoadUint32Fault returns the value stored in DI. Control is transferred
// to it when LoadUint32 below receives SIGSEGV or SIGBUS, with the signal
// number stored in DI.
@@ -124,3 +142,9 @@ TEXT ·loadUint32(SB), NOSPLIT, $0-16
LDARW (R0), R1
MOVW R1, val+8(FP)
RET
+
+// func addrOfLoadUint32() uintptr
+TEXT ·addrOfLoadUint32(SB), $0-8
+ MOVD $·loadUint32(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
diff --git a/pkg/safecopy/memclr_amd64.s b/pkg/safecopy/memclr_amd64.s
index 64cf32f05..4abaecaff 100644
--- a/pkg/safecopy/memclr_amd64.s
+++ b/pkg/safecopy/memclr_amd64.s
@@ -145,3 +145,9 @@ _129through256:
MOVOU X0, -32(DI)(BX*1)
MOVOU X0, -16(DI)(BX*1)
RET
+
+// func addrOfMemclr() uintptr
+TEXT ·addrOfMemclr(SB), $0-8
+ MOVQ $·memclr(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
diff --git a/pkg/safecopy/memclr_arm64.s b/pkg/safecopy/memclr_arm64.s
index 7361b9067..c789bfeb3 100644
--- a/pkg/safecopy/memclr_arm64.s
+++ b/pkg/safecopy/memclr_arm64.s
@@ -72,3 +72,9 @@ head_loop:
CMP $16, R1
BLT tail_zero
B aligned_to_16
+
+// func addrOfMemclr() uintptr
+TEXT ·addrOfMemclr(SB), $0-8
+ MOVD $·memclr(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
diff --git a/pkg/safecopy/memcpy_amd64.s b/pkg/safecopy/memcpy_amd64.s
index 00b46c18f..1d63ca1fd 100644
--- a/pkg/safecopy/memcpy_amd64.s
+++ b/pkg/safecopy/memcpy_amd64.s
@@ -217,3 +217,9 @@ move_129through256:
MOVOU -16(SI)(BX*1), X15
MOVOU X15, -16(DI)(BX*1)
RET
+
+// func addrOfMemcpy() uintptr
+TEXT ·addrOfMemcpy(SB), $0-8
+ MOVQ $·memcpy(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
diff --git a/pkg/safecopy/memcpy_arm64.s b/pkg/safecopy/memcpy_arm64.s
index e7e541565..7b3f50aa5 100644
--- a/pkg/safecopy/memcpy_arm64.s
+++ b/pkg/safecopy/memcpy_arm64.s
@@ -76,3 +76,9 @@ forwardtailloop:
CMP R3, R9
BNE forwardtailloop
RET
+
+// func addrOfMemcpy() uintptr
+TEXT ·addrOfMemcpy(SB), $0-8
+ MOVD $·memcpy(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
diff --git a/pkg/safecopy/safecopy.go b/pkg/safecopy/safecopy.go
index 1e0af5889..df63dd5f1 100644
--- a/pkg/safecopy/safecopy.go
+++ b/pkg/safecopy/safecopy.go
@@ -18,7 +18,6 @@ package safecopy
import (
"fmt"
- "reflect"
"runtime"
"golang.org/x/sys/unix"
@@ -91,6 +90,11 @@ var (
// signals.
func signalHandler()
+// addrOfSignalHandler returns the start address of signalHandler.
+//
+// See comment on addrOfMemcpy for more details.
+func addrOfSignalHandler() uintptr
+
// FindEndAddress returns the end address (one byte beyond the last) of the
// function that contains the specified address (begin).
func FindEndAddress(begin uintptr) uintptr {
@@ -111,26 +115,26 @@ func initializeAddresses() {
// The following functions are written in assembly language, so they won't
// be inlined by the existing compiler/linker. Tests will fail if this
// assumption is violated.
- memcpyBegin = reflect.ValueOf(memcpy).Pointer()
+ memcpyBegin = addrOfMemcpy()
memcpyEnd = FindEndAddress(memcpyBegin)
- memclrBegin = reflect.ValueOf(memclr).Pointer()
+ memclrBegin = addrOfMemclr()
memclrEnd = FindEndAddress(memclrBegin)
- swapUint32Begin = reflect.ValueOf(swapUint32).Pointer()
+ swapUint32Begin = addrOfSwapUint32()
swapUint32End = FindEndAddress(swapUint32Begin)
- swapUint64Begin = reflect.ValueOf(swapUint64).Pointer()
+ swapUint64Begin = addrOfSwapUint64()
swapUint64End = FindEndAddress(swapUint64Begin)
- compareAndSwapUint32Begin = reflect.ValueOf(compareAndSwapUint32).Pointer()
+ compareAndSwapUint32Begin = addrOfCompareAndSwapUint32()
compareAndSwapUint32End = FindEndAddress(compareAndSwapUint32Begin)
- loadUint32Begin = reflect.ValueOf(loadUint32).Pointer()
+ loadUint32Begin = addrOfLoadUint32()
loadUint32End = FindEndAddress(loadUint32Begin)
}
func init() {
initializeAddresses()
- if err := ReplaceSignalHandler(unix.SIGSEGV, reflect.ValueOf(signalHandler).Pointer(), &savedSigSegVHandler); err != nil {
+ if err := ReplaceSignalHandler(unix.SIGSEGV, addrOfSignalHandler(), &savedSigSegVHandler); err != nil {
panic(fmt.Sprintf("Unable to set handler for SIGSEGV: %v", err))
}
- if err := ReplaceSignalHandler(unix.SIGBUS, reflect.ValueOf(signalHandler).Pointer(), &savedSigBusHandler); err != nil {
+ if err := ReplaceSignalHandler(unix.SIGBUS, addrOfSignalHandler(), &savedSigBusHandler); err != nil {
panic(fmt.Sprintf("Unable to set handler for SIGBUS: %v", err))
}
syserror.AddErrorUnwrapper(func(e error) (unix.Errno, bool) {
diff --git a/pkg/safecopy/safecopy_test.go b/pkg/safecopy/safecopy_test.go
index d2ce8ff86..55743e69c 100644
--- a/pkg/safecopy/safecopy_test.go
+++ b/pkg/safecopy/safecopy_test.go
@@ -19,15 +19,13 @@ import (
"fmt"
"io/ioutil"
"math/rand"
- "os"
- "runtime/debug"
"testing"
"unsafe"
"golang.org/x/sys/unix"
)
-// Size of a page in bytes. Cloned from usermem.PageSize to avoid a circular
+// Size of a page in bytes. Cloned from hostarch.PageSize to avoid a circular
// dependency.
const pageSize = 4096
@@ -568,63 +566,3 @@ func TestCompareAndSwapUint32BusError(t *testing.T) {
}
})
}
-
-func testCopy(dst, src []byte) (panicked bool) {
- defer func() {
- if r := recover(); r != nil {
- panicked = true
- }
- }()
- debug.SetPanicOnFault(true)
- copy(dst, src)
- return
-}
-
-func TestSegVOnMemmove(t *testing.T) {
- // Test that SIGSEGVs received by runtime.memmove when *not* doing
- // CopyIn or CopyOut work gets propagated to the runtime.
- const bufLen = pageSize
- a, err := unix.Mmap(-1, 0, bufLen, unix.PROT_NONE, unix.MAP_ANON|unix.MAP_PRIVATE)
- if err != nil {
- t.Fatalf("Mmap failed: %v", err)
-
- }
- defer unix.Munmap(a)
- b := randBuf(bufLen)
-
- if !testCopy(b, a) {
- t.Fatalf("testCopy didn't panic when it should have")
- }
-
- if !testCopy(a, b) {
- t.Fatalf("testCopy didn't panic when it should have")
- }
-}
-
-func TestSigbusOnMemmove(t *testing.T) {
- // Test that SIGBUS received by runtime.memmove when *not* doing
- // CopyIn or CopyOut work gets propagated to the runtime.
- const bufLen = pageSize
- f, err := ioutil.TempFile("", "sigbus_test")
- if err != nil {
- t.Fatalf("TempFile failed: %v", err)
- }
- os.Remove(f.Name())
- defer f.Close()
-
- a, err := unix.Mmap(int(f.Fd()), 0, bufLen, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED)
- if err != nil {
- t.Fatalf("Mmap failed: %v", err)
-
- }
- defer unix.Munmap(a)
- b := randBuf(bufLen)
-
- if !testCopy(b, a) {
- t.Fatalf("testCopy didn't panic when it should have")
- }
-
- if !testCopy(a, b) {
- t.Fatalf("testCopy didn't panic when it should have")
- }
-}
diff --git a/pkg/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go
index a075cf88e..efbc2ddc1 100644
--- a/pkg/safecopy/safecopy_unsafe.go
+++ b/pkg/safecopy/safecopy_unsafe.go
@@ -89,6 +89,18 @@ func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig
//go:noescape
func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32)
+// Return the start address of the functions above.
+//
+// In Go 1.17+, Go references to assembly functions resolve to an ABIInternal
+// wrapper function rather than the function itself. We must reference from
+// assembly to get the ABI0 (i.e., primary) address.
+func addrOfMemcpy() uintptr
+func addrOfMemclr() uintptr
+func addrOfSwapUint32() uintptr
+func addrOfSwapUint64() uintptr
+func addrOfCompareAndSwapUint32() uintptr
+func addrOfLoadUint32() uintptr
+
// CopyIn copies len(dst) bytes from src to dst. It returns the number of bytes
// copied and an error if SIGSEGV or SIGBUS is received while reading from src.
func CopyIn(dst []byte, src unsafe.Pointer) (int, error) {
diff --git a/pkg/safecopy/sighandler_amd64.s b/pkg/safecopy/sighandler_amd64.s
index 475ae48e9..0b5e8df66 100644
--- a/pkg/safecopy/sighandler_amd64.s
+++ b/pkg/safecopy/sighandler_amd64.s
@@ -131,3 +131,9 @@ handle_fault:
MOVL DI, REG_RDI(DX)
RET
+
+// func addrOfSignalHandler() uintptr
+TEXT ·addrOfSignalHandler(SB), $0-8
+ MOVQ $·signalHandler(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
diff --git a/pkg/safecopy/sighandler_arm64.s b/pkg/safecopy/sighandler_arm64.s
index 53e4ac2c1..41ed70ff9 100644
--- a/pkg/safecopy/sighandler_arm64.s
+++ b/pkg/safecopy/sighandler_arm64.s
@@ -141,3 +141,9 @@ handle_fault:
MOVW R0, REG_R1(R2)
RET
+
+// func addrOfSignalHandler() uintptr
+TEXT ·addrOfSignalHandler(SB), $0-8
+ MOVD $·signalHandler(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
diff --git a/pkg/safemem/BUILD b/pkg/safemem/BUILD
index 3fda3a9cc..2c7cc8769 100644
--- a/pkg/safemem/BUILD
+++ b/pkg/safemem/BUILD
@@ -14,6 +14,7 @@ go_library(
deps = [
"//pkg/gohacks",
"//pkg/safecopy",
+ "//pkg/sync",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/safemem/block_unsafe.go b/pkg/safemem/block_unsafe.go
index 93879bb4f..4af534385 100644
--- a/pkg/safemem/block_unsafe.go
+++ b/pkg/safemem/block_unsafe.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/safecopy"
+ "gvisor.dev/gvisor/pkg/sync"
)
// A Block is a range of contiguous bytes, similar to []byte but with the
@@ -223,8 +224,22 @@ func Copy(dst, src Block) (int, error) {
func Zero(dst Block) (int, error) {
if !dst.needSafecopy {
bs := dst.ToSlice()
- for i := range bs {
- bs[i] = 0
+ if !sync.RaceEnabled {
+ // If the race detector isn't enabled, the golang
+ // compiler replaces the next loop with memclr
+ // (https://github.com/golang/go/issues/5373).
+ for i := range bs {
+ bs[i] = 0
+ }
+ } else {
+ bsLen := len(bs)
+ if bsLen == 0 {
+ return 0, nil
+ }
+ bs[0] = 0
+ for i := 1; i < bsLen; i *= 2 {
+ copy(bs[i:], bs[:i])
+ }
}
return len(bs), nil
}
diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD
index 201dd072f..169fc0ac3 100644
--- a/pkg/seccomp/BUILD
+++ b/pkg/seccomp/BUILD
@@ -54,6 +54,6 @@ go_test(
deps = [
"//pkg/abi/linux",
"//pkg/bpf",
- "//pkg/usermem",
+ "//pkg/hostarch",
],
)
diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go
index db06d1f1b..68feddf31 100644
--- a/pkg/seccomp/seccomp_test.go
+++ b/pkg/seccomp/seccomp_test.go
@@ -29,7 +29,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// newVictim makes a victim binary.
@@ -57,7 +57,7 @@ func dataAsInput(d *linux.SeccompData) bpf.Input {
d.MarshalUnsafe(buf)
return bpf.InputBytes{
Data: buf,
- Order: usermem.ByteOrder,
+ Order: hostarch.ByteOrder,
}
}
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
index f660f1614..c9c52530d 100644
--- a/pkg/sentry/arch/BUILD
+++ b/pkg/sentry/arch/BUILD
@@ -32,6 +32,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/cpuid",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/marshal",
"//pkg/marshal/primitive",
diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go
index 921151137..290863ee6 100644
--- a/pkg/sentry/arch/arch.go
+++ b/pkg/sentry/arch/arch.go
@@ -22,11 +22,11 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/limits"
- "gvisor.dev/gvisor/pkg/usermem"
)
// Arch describes an architecture.
@@ -188,11 +188,11 @@ type Context interface {
// returned layout must be no lower than min, and MaxAddr for the returned
// layout must be no higher than max. Repeated calls to NewMmapLayout may
// return different layouts.
- NewMmapLayout(min, max usermem.Addr, limits *limits.LimitSet) (MmapLayout, error)
+ NewMmapLayout(min, max hostarch.Addr, limits *limits.LimitSet) (MmapLayout, error)
// PIELoadAddress returns a preferred load address for a
// position-independent executable within l.
- PIELoadAddress(l MmapLayout) usermem.Addr
+ PIELoadAddress(l MmapLayout) hostarch.Addr
// FeatureSet returns the FeatureSet in use in this context.
FeatureSet() *cpuid.FeatureSet
@@ -257,18 +257,18 @@ const (
// +stateify savable
type MmapLayout struct {
// MinAddr is the lowest mappable address.
- MinAddr usermem.Addr
+ MinAddr hostarch.Addr
// MaxAddr is the highest mappable address.
- MaxAddr usermem.Addr
+ MaxAddr hostarch.Addr
// BottomUpBase is the lowest address that may be returned for a
// MmapBottomUp mmap.
- BottomUpBase usermem.Addr
+ BottomUpBase hostarch.Addr
// TopDownBase is the highest address that may be returned for a
// MmapTopDown mmap.
- TopDownBase usermem.Addr
+ TopDownBase hostarch.Addr
// DefaultDirection is the direction for most non-fixed mmaps in this
// layout.
@@ -316,9 +316,9 @@ type SyscallArgument struct {
// SyscallArguments represents the set of arguments passed to a syscall.
type SyscallArguments [6]SyscallArgument
-// Pointer returns the usermem.Addr representation of a pointer argument.
-func (a SyscallArgument) Pointer() usermem.Addr {
- return usermem.Addr(a.Value)
+// Pointer returns the hostarch.Addr representation of a pointer argument.
+func (a SyscallArgument) Pointer() hostarch.Addr {
+ return hostarch.Addr(a.Value)
}
// Int returns the int32 representation of a 32-bit signed integer argument.
diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go
index 2571be60f..d6b4d2357 100644
--- a/pkg/sentry/arch/arch_amd64.go
+++ b/pkg/sentry/arch/arch_amd64.go
@@ -23,11 +23,11 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/limits"
- "gvisor.dev/gvisor/pkg/usermem"
)
// Host specifies the host architecture.
@@ -37,7 +37,7 @@ const Host = AMD64
const (
// maxAddr64 is the maximum userspace address. It is TASK_SIZE in Linux
// for a 64-bit process.
- maxAddr64 usermem.Addr = (1 << 47) - usermem.PageSize
+ maxAddr64 hostarch.Addr = (1 << 47) - hostarch.PageSize
// maxStackRand64 is the maximum randomization to apply to the stack.
// It is defined by arch/x86/mm/mmap.c:stack_maxrandom_size in Linux.
@@ -45,7 +45,7 @@ const (
// maxMmapRand64 is the maximum randomization to apply to the mmap
// layout. It is defined by arch/x86/mm/mmap.c:arch_mmap_rnd in Linux.
- maxMmapRand64 = (1 << 28) * usermem.PageSize
+ maxMmapRand64 = (1 << 28) * hostarch.PageSize
// minGap64 is the minimum gap to leave at the top of the address space
// for the stack. It is defined by arch/x86/mm/mmap.c:MIN_GAP in Linux.
@@ -56,7 +56,7 @@ const (
//
// The Platform {Min,Max}UserAddress() may preclude loading at this
// address. See other preferredFoo comments below.
- preferredPIELoadAddr usermem.Addr = maxAddr64 / 3 * 2
+ preferredPIELoadAddr hostarch.Addr = maxAddr64 / 3 * 2
)
// These constants are selected as heuristics to help make the Platform's
@@ -92,13 +92,13 @@ const (
// This is all "preferred" because the layout min/max address may not
// allow us to select such a TopDownBase, in which case we have to fall
// back to a layout that TSAN may not be happy with.
- preferredTopDownAllocMin usermem.Addr = 0x7e8000000000
- preferredAllocationGap = 128 << 30 // 128 GB
- preferredTopDownBaseMin = preferredTopDownAllocMin + preferredAllocationGap
+ preferredTopDownAllocMin hostarch.Addr = 0x7e8000000000
+ preferredAllocationGap = 128 << 30 // 128 GB
+ preferredTopDownBaseMin = preferredTopDownAllocMin + preferredAllocationGap
// minMmapRand64 is the smallest we are willing to make the
// randomization to stay above preferredTopDownBaseMin.
- minMmapRand64 = (1 << 26) * usermem.PageSize
+ minMmapRand64 = (1 << 26) * hostarch.PageSize
)
// context64 represents an AMD64 context.
@@ -207,12 +207,12 @@ func (c *context64) FeatureSet() *cpuid.FeatureSet {
}
// mmapRand returns a random adjustment for randomizing an mmap layout.
-func mmapRand(max uint64) usermem.Addr {
- return usermem.Addr(rand.Int63n(int64(max))).RoundDown()
+func mmapRand(max uint64) hostarch.Addr {
+ return hostarch.Addr(rand.Int63n(int64(max))).RoundDown()
}
// NewMmapLayout implements Context.NewMmapLayout consistently with Linux.
-func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (MmapLayout, error) {
+func (c *context64) NewMmapLayout(min, max hostarch.Addr, r *limits.LimitSet) (MmapLayout, error) {
min, ok := min.RoundUp()
if !ok {
return MmapLayout{}, unix.EINVAL
@@ -230,7 +230,7 @@ func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (Mm
// MAX_GAP in Linux.
maxGap := (max / 6) * 5
- gap := usermem.Addr(stackSize.Cur)
+ gap := hostarch.Addr(stackSize.Cur)
if gap < minGap64 {
gap = minGap64
}
@@ -243,7 +243,7 @@ func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (Mm
}
topDownMin := max - gap - maxMmapRand64
- maxRand := usermem.Addr(maxMmapRand64)
+ maxRand := hostarch.Addr(maxMmapRand64)
if topDownMin < preferredTopDownBaseMin {
// Try to keep TopDownBase above preferredTopDownBaseMin by
// shrinking maxRand.
@@ -278,7 +278,7 @@ func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (Mm
}
// PIELoadAddress implements Context.PIELoadAddress.
-func (c *context64) PIELoadAddress(l MmapLayout) usermem.Addr {
+func (c *context64) PIELoadAddress(l MmapLayout) hostarch.Addr {
base := preferredPIELoadAddr
max, ok := base.AddLength(maxMmapRand64)
if !ok {
@@ -311,7 +311,7 @@ func (c *context64) PtracePeekUser(addr uintptr) (marshal.Marshallable, error) {
regs := c.ptraceGetRegs()
buf := make([]byte, regs.SizeBytes())
regs.MarshalUnsafe(buf)
- return c.Native(uintptr(usermem.ByteOrder.Uint64(buf[addr:]))), nil
+ return c.Native(uintptr(hostarch.ByteOrder.Uint64(buf[addr:]))), nil
}
// Note: x86 debug registers are missing.
return c.Native(0), nil
@@ -326,7 +326,7 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error {
regs := c.ptraceGetRegs()
buf := make([]byte, regs.SizeBytes())
regs.MarshalUnsafe(buf)
- usermem.ByteOrder.PutUint64(buf[addr:], uint64(data))
+ hostarch.ByteOrder.PutUint64(buf[addr:], uint64(data))
_, err := c.PtraceSetRegs(bytes.NewBuffer(buf))
return err
}
diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go
index 14ad9483b..348f238fd 100644
--- a/pkg/sentry/arch/arch_arm64.go
+++ b/pkg/sentry/arch/arch_arm64.go
@@ -22,11 +22,11 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/limits"
- "gvisor.dev/gvisor/pkg/usermem"
)
// Host specifies the host architecture.
@@ -36,7 +36,7 @@ const Host = ARM64
const (
// maxAddr64 is the maximum userspace address. It is TASK_SIZE in Linux
// for a 64-bit process.
- maxAddr64 usermem.Addr = (1 << 48)
+ maxAddr64 hostarch.Addr = (1 << 48)
// maxStackRand64 is the maximum randomization to apply to the stack.
// It is defined by arch/arm64/mm/mmap.c:(STACK_RND_MASK << PAGE_SHIFT) in Linux.
@@ -44,7 +44,7 @@ const (
// maxMmapRand64 is the maximum randomization to apply to the mmap
// layout. It is defined by arch/arm64/mm/mmap.c:arch_mmap_rnd in Linux.
- maxMmapRand64 = (1 << 33) * usermem.PageSize
+ maxMmapRand64 = (1 << 33) * hostarch.PageSize
// minGap64 is the minimum gap to leave at the top of the address space
// for the stack. It is defined by arch/arm64/mm/mmap.c:MIN_GAP in Linux.
@@ -55,7 +55,7 @@ const (
//
// The Platform {Min,Max}UserAddress() may preclude loading at this
// address. See other preferredFoo comments below.
- preferredPIELoadAddr usermem.Addr = maxAddr64 / 6 * 5
+ preferredPIELoadAddr hostarch.Addr = maxAddr64 / 6 * 5
)
var (
@@ -66,13 +66,13 @@ var (
// These constants are selected as heuristics to help make the Platform's
// potentially limited address space conform as closely to Linux as possible.
const (
- preferredTopDownAllocMin usermem.Addr = 0x7e8000000000
- preferredAllocationGap = 128 << 30 // 128 GB
- preferredTopDownBaseMin = preferredTopDownAllocMin + preferredAllocationGap
+ preferredTopDownAllocMin hostarch.Addr = 0x7e8000000000
+ preferredAllocationGap = 128 << 30 // 128 GB
+ preferredTopDownBaseMin = preferredTopDownAllocMin + preferredAllocationGap
// minMmapRand64 is the smallest we are willing to make the
// randomization to stay above preferredTopDownBaseMin.
- minMmapRand64 = (1 << 18) * usermem.PageSize
+ minMmapRand64 = (1 << 18) * hostarch.PageSize
)
// context64 represents an ARM64 context.
@@ -187,12 +187,12 @@ func (c *context64) FeatureSet() *cpuid.FeatureSet {
}
// mmapRand returns a random adjustment for randomizing an mmap layout.
-func mmapRand(max uint64) usermem.Addr {
- return usermem.Addr(rand.Int63n(int64(max))).RoundDown()
+func mmapRand(max uint64) hostarch.Addr {
+ return hostarch.Addr(rand.Int63n(int64(max))).RoundDown()
}
// NewMmapLayout implements Context.NewMmapLayout consistently with Linux.
-func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (MmapLayout, error) {
+func (c *context64) NewMmapLayout(min, max hostarch.Addr, r *limits.LimitSet) (MmapLayout, error) {
min, ok := min.RoundUp()
if !ok {
return MmapLayout{}, unix.EINVAL
@@ -210,7 +210,7 @@ func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (Mm
// MAX_GAP in Linux.
maxGap := (max / 6) * 5
- gap := usermem.Addr(stackSize.Cur)
+ gap := hostarch.Addr(stackSize.Cur)
if gap < minGap64 {
gap = minGap64
}
@@ -223,7 +223,7 @@ func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (Mm
}
topDownMin := max - gap - maxMmapRand64
- maxRand := usermem.Addr(maxMmapRand64)
+ maxRand := hostarch.Addr(maxMmapRand64)
if topDownMin < preferredTopDownBaseMin {
// Try to keep TopDownBase above preferredTopDownBaseMin by
// shrinking maxRand.
@@ -258,7 +258,7 @@ func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (Mm
}
// PIELoadAddress implements Context.PIELoadAddress.
-func (c *context64) PIELoadAddress(l MmapLayout) usermem.Addr {
+func (c *context64) PIELoadAddress(l MmapLayout) hostarch.Addr {
base := preferredPIELoadAddr
max, ok := base.AddLength(maxMmapRand64)
if !ok {
diff --git a/pkg/sentry/arch/auxv.go b/pkg/sentry/arch/auxv.go
index 2b4c8f3fc..19ca18121 100644
--- a/pkg/sentry/arch/auxv.go
+++ b/pkg/sentry/arch/auxv.go
@@ -15,7 +15,7 @@
package arch
import (
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// An AuxEntry represents an entry in an ELF auxiliary vector.
@@ -23,7 +23,7 @@ import (
// +stateify savable
type AuxEntry struct {
Key uint64
- Value usermem.Addr
+ Value hostarch.Addr
}
// An Auxv represents an ELF auxiliary vector.
diff --git a/pkg/sentry/arch/fpu/BUILD b/pkg/sentry/arch/fpu/BUILD
index 0a5395267..4e4f20639 100644
--- a/pkg/sentry/arch/fpu/BUILD
+++ b/pkg/sentry/arch/fpu/BUILD
@@ -13,9 +13,9 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/cpuid",
+ "//pkg/hostarch",
"//pkg/sync",
"//pkg/syserror",
- "//pkg/usermem",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/arch/fpu/fpu_amd64.go b/pkg/sentry/arch/fpu/fpu_amd64.go
index 3a62f51be..f0ba26736 100644
--- a/pkg/sentry/arch/fpu/fpu_amd64.go
+++ b/pkg/sentry/arch/fpu/fpu_amd64.go
@@ -21,9 +21,9 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// initX86FPState (defined in asm files) sets up initial state.
@@ -146,11 +146,11 @@ const (
// any of the reserved bits of the MXCSR register." - Intel SDM Vol. 1, Section
// 10.5.1.2 "SSE State")
func sanitizeMXCSR(f State) {
- mxcsr := usermem.ByteOrder.Uint32(f[mxcsrOffset:])
+ mxcsr := hostarch.ByteOrder.Uint32(f[mxcsrOffset:])
initMXCSRMask.Do(func() {
temp := State(alignedBytes(uint(ptraceFPRegsSize), 16))
initX86FPState(&temp[0], false /* useXsave */)
- mxcsrMask = usermem.ByteOrder.Uint32(temp[mxcsrMaskOffset:])
+ mxcsrMask = hostarch.ByteOrder.Uint32(temp[mxcsrMaskOffset:])
if mxcsrMask == 0 {
// "If the value of the MXCSR_MASK field is 00000000H, then the
// MXCSR_MASK value is the default value of 0000FFBFH." - Intel SDM
@@ -160,7 +160,7 @@ func sanitizeMXCSR(f State) {
}
})
mxcsr &= mxcsrMask
- usermem.ByteOrder.PutUint32(f[mxcsrOffset:], mxcsr)
+ hostarch.ByteOrder.PutUint32(f[mxcsrOffset:], mxcsr)
}
// PtraceGetXstateRegs implements ptrace(PTRACE_GETREGS, NT_X86_XSTATE) by
@@ -177,7 +177,7 @@ func (s *State) PtraceGetXstateRegs(dst io.Writer, maxlen int, featureSet *cpuid
// Area". Linux uses the first 8 bytes of this area to store the OS XSTATE
// mask. GDB relies on this: see
// gdb/x86-linux-nat.c:x86_linux_read_description().
- usermem.ByteOrder.PutUint64(f[userXstateXCR0Offset:], featureSet.ValidXCR0Mask())
+ hostarch.ByteOrder.PutUint64(f[userXstateXCR0Offset:], featureSet.ValidXCR0Mask())
if len(f) > maxlen {
f = f[:maxlen]
}
@@ -208,9 +208,9 @@ func (s *State) PtraceSetXstateRegs(src io.Reader, maxlen int, featureSet *cpuid
// Force reserved bits in MXCSR to 0. This is consistent with Linux.
sanitizeMXCSR(State(f))
// Users can't enable *more* XCR0 bits than what we, and the CPU, support.
- xstateBV := usermem.ByteOrder.Uint64(f[xstateBVOffset:])
+ xstateBV := hostarch.ByteOrder.Uint64(f[xstateBVOffset:])
xstateBV &= featureSet.ValidXCR0Mask()
- usermem.ByteOrder.PutUint64(f[xstateBVOffset:], xstateBV)
+ hostarch.ByteOrder.PutUint64(f[xstateBVOffset:], xstateBV)
// Force XCOMP_BV and reserved bytes in the XSAVE header to 0.
reserved := f[xsaveHeaderZeroedOffset : xsaveHeaderZeroedOffset+xsaveHeaderZeroedBytes]
for i := range reserved {
@@ -219,6 +219,11 @@ func (s *State) PtraceSetXstateRegs(src io.Reader, maxlen int, featureSet *cpuid
return copy(*s, f), nil
}
+// SetMXCSR sets the MXCSR control/status register in the state.
+func (s *State) SetMXCSR(mxcsr uint32) {
+ hostarch.ByteOrder.PutUint32((*s)[mxcsrOffset:], mxcsr)
+}
+
// BytePointer returns a pointer to the first byte of the state.
//
//go:nosplit
@@ -266,7 +271,7 @@ func (s *State) AfterLoad() {
// What was in use?
savedBV := fxsaveBV
if len(old) >= xstateBVOffset+8 {
- savedBV = usermem.ByteOrder.Uint64(old[xstateBVOffset:])
+ savedBV = hostarch.ByteOrder.Uint64(old[xstateBVOffset:])
}
// Supported features must be a superset of saved features.
diff --git a/pkg/sentry/arch/fpu/fpu_arm64.go b/pkg/sentry/arch/fpu/fpu_arm64.go
index d2f62631d..46634661f 100644
--- a/pkg/sentry/arch/fpu/fpu_arm64.go
+++ b/pkg/sentry/arch/fpu/fpu_arm64.go
@@ -58,6 +58,8 @@ func (s *State) Fork() State {
}
// BytePointer returns a pointer to the first byte of the state.
+//
+//go:nosplit
func (s *State) BytePointer() *byte {
return &(*s)[0]
}
diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go
index 35d2e07c3..67d7edf68 100644
--- a/pkg/sentry/arch/signal.go
+++ b/pkg/sentry/arch/signal.go
@@ -16,7 +16,7 @@ package arch
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// SignalAct represents the action that should be taken when a signal is
@@ -154,107 +154,107 @@ func (s *SignalInfo) FixSignalCodeForUser() {
// PID returns the si_pid field.
func (s *SignalInfo) PID() int32 {
- return int32(usermem.ByteOrder.Uint32(s.Fields[0:4]))
+ return int32(hostarch.ByteOrder.Uint32(s.Fields[0:4]))
}
// SetPID mutates the si_pid field.
func (s *SignalInfo) SetPID(val int32) {
- usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
+ hostarch.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
}
// UID returns the si_uid field.
func (s *SignalInfo) UID() int32 {
- return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
+ return int32(hostarch.ByteOrder.Uint32(s.Fields[4:8]))
}
// SetUID mutates the si_uid field.
func (s *SignalInfo) SetUID(val int32) {
- usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
+ hostarch.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
}
// Sigval returns the sigval field, which is aliased to both si_int and si_ptr.
func (s *SignalInfo) Sigval() uint64 {
- return usermem.ByteOrder.Uint64(s.Fields[8:16])
+ return hostarch.ByteOrder.Uint64(s.Fields[8:16])
}
// SetSigval mutates the sigval field.
func (s *SignalInfo) SetSigval(val uint64) {
- usermem.ByteOrder.PutUint64(s.Fields[8:16], val)
+ hostarch.ByteOrder.PutUint64(s.Fields[8:16], val)
}
// TimerID returns the si_timerid field.
func (s *SignalInfo) TimerID() linux.TimerID {
- return linux.TimerID(usermem.ByteOrder.Uint32(s.Fields[0:4]))
+ return linux.TimerID(hostarch.ByteOrder.Uint32(s.Fields[0:4]))
}
// SetTimerID sets the si_timerid field.
func (s *SignalInfo) SetTimerID(val linux.TimerID) {
- usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
+ hostarch.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
}
// Overrun returns the si_overrun field.
func (s *SignalInfo) Overrun() int32 {
- return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
+ return int32(hostarch.ByteOrder.Uint32(s.Fields[4:8]))
}
// SetOverrun sets the si_overrun field.
func (s *SignalInfo) SetOverrun(val int32) {
- usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
+ hostarch.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
}
// Addr returns the si_addr field.
func (s *SignalInfo) Addr() uint64 {
- return usermem.ByteOrder.Uint64(s.Fields[0:8])
+ return hostarch.ByteOrder.Uint64(s.Fields[0:8])
}
// SetAddr sets the si_addr field.
func (s *SignalInfo) SetAddr(val uint64) {
- usermem.ByteOrder.PutUint64(s.Fields[0:8], val)
+ hostarch.ByteOrder.PutUint64(s.Fields[0:8], val)
}
// Status returns the si_status field.
func (s *SignalInfo) Status() int32 {
- return int32(usermem.ByteOrder.Uint32(s.Fields[8:12]))
+ return int32(hostarch.ByteOrder.Uint32(s.Fields[8:12]))
}
// SetStatus mutates the si_status field.
func (s *SignalInfo) SetStatus(val int32) {
- usermem.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
+ hostarch.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
}
// CallAddr returns the si_call_addr field.
func (s *SignalInfo) CallAddr() uint64 {
- return usermem.ByteOrder.Uint64(s.Fields[0:8])
+ return hostarch.ByteOrder.Uint64(s.Fields[0:8])
}
// SetCallAddr mutates the si_call_addr field.
func (s *SignalInfo) SetCallAddr(val uint64) {
- usermem.ByteOrder.PutUint64(s.Fields[0:8], val)
+ hostarch.ByteOrder.PutUint64(s.Fields[0:8], val)
}
// Syscall returns the si_syscall field.
func (s *SignalInfo) Syscall() int32 {
- return int32(usermem.ByteOrder.Uint32(s.Fields[8:12]))
+ return int32(hostarch.ByteOrder.Uint32(s.Fields[8:12]))
}
// SetSyscall mutates the si_syscall field.
func (s *SignalInfo) SetSyscall(val int32) {
- usermem.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
+ hostarch.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
}
// Arch returns the si_arch field.
func (s *SignalInfo) Arch() uint32 {
- return usermem.ByteOrder.Uint32(s.Fields[12:16])
+ return hostarch.ByteOrder.Uint32(s.Fields[12:16])
}
// SetArch mutates the si_arch field.
func (s *SignalInfo) SetArch(val uint32) {
- usermem.ByteOrder.PutUint32(s.Fields[12:16], val)
+ hostarch.ByteOrder.PutUint32(s.Fields[12:16], val)
}
// Band returns the si_band field.
func (s *SignalInfo) Band() int64 {
- return int64(usermem.ByteOrder.Uint64(s.Fields[0:8]))
+ return int64(hostarch.ByteOrder.Uint64(s.Fields[0:8]))
}
// SetBand mutates the si_band field.
@@ -262,15 +262,15 @@ func (s *SignalInfo) SetBand(val int64) {
// Note: this assumes the platform uses `long` as `__ARCH_SI_BAND_T`.
// On some platforms, which gVisor doesn't support, `__ARCH_SI_BAND_T` is
// `int`. See siginfo.h.
- usermem.ByteOrder.PutUint64(s.Fields[0:8], uint64(val))
+ hostarch.ByteOrder.PutUint64(s.Fields[0:8], uint64(val))
}
// FD returns the si_fd field.
func (s *SignalInfo) FD() uint32 {
- return usermem.ByteOrder.Uint32(s.Fields[8:12])
+ return hostarch.ByteOrder.Uint32(s.Fields[8:12])
}
// SetFD mutates the si_fd field.
func (s *SignalInfo) SetFD(val uint32) {
- usermem.ByteOrder.PutUint32(s.Fields[8:12], val)
+ hostarch.ByteOrder.PutUint32(s.Fields[8:12], val)
}
diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go
index ee3743483..082ed92b1 100644
--- a/pkg/sentry/arch/signal_amd64.go
+++ b/pkg/sentry/arch/signal_amd64.go
@@ -21,10 +21,10 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
- "gvisor.dev/gvisor/pkg/usermem"
)
// SignalContext64 is equivalent to struct sigcontext, the type passed as the
@@ -133,7 +133,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
// space on the user stack naturally caps the amount of memory the
// sentry will allocate for this purpose.
fpSize, _ := c.fpuFrameSize()
- sp = (sp - usermem.Addr(fpSize)) & ^usermem.Addr(63)
+ sp = (sp - hostarch.Addr(fpSize)) & ^hostarch.Addr(63)
// Construct the UContext64 now since we need its size.
uc := &UContext64{
@@ -180,8 +180,8 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
ucSize := uc.SizeBytes()
// st.Arch.Width() is for the restorer address. sizeof(siginfo) == 128.
frameSize := int(st.Arch.Width()) + ucSize + 128
- frameBottom := (sp-usermem.Addr(frameSize)) & ^usermem.Addr(15) - 8
- sp = frameBottom + usermem.Addr(frameSize)
+ frameBottom := (sp-hostarch.Addr(frameSize)) & ^hostarch.Addr(15) - 8
+ sp = frameBottom + hostarch.Addr(frameSize)
st.Bottom = sp
// Prior to proceeding, figure out if the frame will exhaust the range
diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go
index 53281dcba..da71fb873 100644
--- a/pkg/sentry/arch/signal_arm64.go
+++ b/pkg/sentry/arch/signal_arm64.go
@@ -19,9 +19,9 @@ package arch
import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
- "gvisor.dev/gvisor/pkg/usermem"
)
// SignalContext64 is equivalent to struct sigcontext, the type passed as the
@@ -107,8 +107,8 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
// sizeof(siginfo) == 128.
// R30 stores the restorer address.
frameSize := ucSize + 128
- frameBottom := (sp - usermem.Addr(frameSize)) & ^usermem.Addr(15)
- sp = frameBottom + usermem.Addr(frameSize)
+ frameBottom := (sp - hostarch.Addr(frameSize)) & ^hostarch.Addr(15)
+ sp = frameBottom + hostarch.Addr(frameSize)
st.Bottom = sp
// Prior to proceeding, figure out if the frame will exhaust the range
diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go
index a1eae98f9..c732c7503 100644
--- a/pkg/sentry/arch/signal_stack.go
+++ b/pkg/sentry/arch/signal_stack.go
@@ -17,8 +17,8 @@
package arch
import (
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
- "gvisor.dev/gvisor/pkg/usermem"
)
const (
@@ -36,8 +36,8 @@ func (s SignalStack) IsEnabled() bool {
}
// Top returns the stack's top address.
-func (s SignalStack) Top() usermem.Addr {
- return usermem.Addr(s.Addr + s.Size)
+func (s SignalStack) Top() hostarch.Addr {
+ return hostarch.Addr(s.Addr + s.Size)
}
// SetOnStack marks this signal stack as in use.
@@ -49,8 +49,8 @@ func (s *SignalStack) SetOnStack() {
}
// Contains checks if the stack pointer is within this stack.
-func (s *SignalStack) Contains(sp usermem.Addr) bool {
- return usermem.Addr(s.Addr) < sp && sp <= usermem.Addr(s.Addr+s.Size)
+func (s *SignalStack) Contains(sp hostarch.Addr) bool {
+ return hostarch.Addr(s.Addr) < sp && sp <= hostarch.Addr(s.Addr+s.Size)
}
// NativeSignalStack is a type that is equivalent to stack_t in the guest
diff --git a/pkg/sentry/arch/stack.go b/pkg/sentry/arch/stack.go
index 5f06c751d..65a794c7c 100644
--- a/pkg/sentry/arch/stack.go
+++ b/pkg/sentry/arch/stack.go
@@ -16,18 +16,20 @@ package arch
import (
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
+
"gvisor.dev/gvisor/pkg/usermem"
)
-// Stack is a simple wrapper around a usermem.IO and an address. Stack
+// Stack is a simple wrapper around a hostarch.IO and an address. Stack
// implements marshal.CopyContext, and marshallable values can be pushed or
// popped from the stack through the marshal.Marshallable interface.
//
// Stack is not thread-safe.
type Stack struct {
// Our arch info.
- // We use this for automatic Native conversion of usermem.Addrs during
+ // We use this for automatic Native conversion of hostarch.Addrs during
// Push() and Pop().
Arch Context
@@ -35,7 +37,7 @@ type Stack struct {
IO usermem.IO
// Our current stack bottom.
- Bottom usermem.Addr
+ Bottom hostarch.Addr
// Scratch buffer used for marshalling to avoid having to repeatedly
// allocate scratch memory.
@@ -59,20 +61,20 @@ func (s *Stack) CopyScratchBuffer(size int) []byte {
// StackBottomMagic is the special address callers must past to all stack
// marshalling operations to cause the src/dst address to be computed based on
// the current end of the stack.
-const StackBottomMagic = ^usermem.Addr(0) // usermem.Addr(-1)
+const StackBottomMagic = ^hostarch.Addr(0) // hostarch.Addr(-1)
// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. CopyOutBytes
// computes an appropriate address based on the current end of the
// stack. Callers use the sentinel address StackBottomMagic to marshal methods
// to indicate this.
-func (s *Stack) CopyOutBytes(sentinel usermem.Addr, b []byte) (int, error) {
+func (s *Stack) CopyOutBytes(sentinel hostarch.Addr, b []byte) (int, error) {
if sentinel != StackBottomMagic {
panic("Attempted to copy out to stack with absolute address")
}
c := len(b)
- n, err := s.IO.CopyOut(context.Background(), s.Bottom-usermem.Addr(c), b, usermem.IOOpts{})
+ n, err := s.IO.CopyOut(context.Background(), s.Bottom-hostarch.Addr(c), b, usermem.IOOpts{})
if err == nil && n == c {
- s.Bottom -= usermem.Addr(n)
+ s.Bottom -= hostarch.Addr(n)
}
return n, err
}
@@ -81,21 +83,21 @@ func (s *Stack) CopyOutBytes(sentinel usermem.Addr, b []byte) (int, error) {
// an appropriate address based on the current end of the stack. Callers must
// use the sentinel address StackBottomMagic to marshal methods to indicate
// this.
-func (s *Stack) CopyInBytes(sentinel usermem.Addr, b []byte) (int, error) {
+func (s *Stack) CopyInBytes(sentinel hostarch.Addr, b []byte) (int, error) {
if sentinel != StackBottomMagic {
panic("Attempted to copy in from stack with absolute address")
}
n, err := s.IO.CopyIn(context.Background(), s.Bottom, b, usermem.IOOpts{})
if err == nil {
- s.Bottom += usermem.Addr(n)
+ s.Bottom += hostarch.Addr(n)
}
return n, err
}
// Align aligns the stack to the given offset.
func (s *Stack) Align(offset int) {
- if s.Bottom%usermem.Addr(offset) != 0 {
- s.Bottom -= (s.Bottom % usermem.Addr(offset))
+ if s.Bottom%hostarch.Addr(offset) != 0 {
+ s.Bottom -= (s.Bottom % hostarch.Addr(offset))
}
}
@@ -119,16 +121,16 @@ func (s *Stack) PushNullTerminatedByteSlice(bs []byte) (int, error) {
// stack.
type StackLayout struct {
// ArgvStart is the beginning of the argument vector.
- ArgvStart usermem.Addr
+ ArgvStart hostarch.Addr
// ArgvEnd is the end of the argument vector.
- ArgvEnd usermem.Addr
+ ArgvEnd hostarch.Addr
// EnvvStart is the beginning of the environment vector.
- EnvvStart usermem.Addr
+ EnvvStart hostarch.Addr
// EnvvEnd is the end of the environment vector.
- EnvvEnd usermem.Addr
+ EnvvEnd hostarch.Addr
}
// Load pushes the given args, env and aux vector to the stack using the
@@ -148,7 +150,7 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error)
// to be in this order. See: https://www.uclibc.org/docs/psABI-x86_64.pdf
// page 29.
l.EnvvEnd = s.Bottom
- envAddrs := make([]usermem.Addr, len(env))
+ envAddrs := make([]hostarch.Addr, len(env))
for i := len(env) - 1; i >= 0; i-- {
if _, err := s.PushNullTerminatedByteSlice([]byte(env[i])); err != nil {
return StackLayout{}, err
@@ -159,7 +161,7 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error)
// Push our strings.
l.ArgvEnd = s.Bottom
- argAddrs := make([]usermem.Addr, len(args))
+ argAddrs := make([]hostarch.Addr, len(args))
for i := len(args) - 1; i >= 0; i-- {
if _, err := s.PushNullTerminatedByteSlice([]byte(args[i])); err != nil {
return StackLayout{}, err
@@ -178,7 +180,7 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error)
argvSize := s.Arch.Width() * uint(len(args)+1)
envvSize := s.Arch.Width() * uint(len(env)+1)
auxvSize := s.Arch.Width() * 2 * uint(len(aux)+1)
- total := usermem.Addr(argvSize) + usermem.Addr(envvSize) + usermem.Addr(auxvSize) + usermem.Addr(s.Arch.Width())
+ total := hostarch.Addr(argvSize) + hostarch.Addr(envvSize) + hostarch.Addr(auxvSize) + hostarch.Addr(s.Arch.Width())
expectedBottom := s.Bottom - total
if expectedBottom%32 != 0 {
s.Bottom -= expectedBottom % 32
@@ -188,11 +190,11 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error)
// NOTE: We need an extra zero here per spec.
// The Push function will automatically terminate
// strings and arrays with a single null value.
- auxv := make([]usermem.Addr, 0, len(aux))
+ auxv := make([]hostarch.Addr, 0, len(aux))
for _, a := range aux {
- auxv = append(auxv, usermem.Addr(a.Key), a.Value)
+ auxv = append(auxv, hostarch.Addr(a.Key), a.Value)
}
- auxv = append(auxv, usermem.Addr(0))
+ auxv = append(auxv, hostarch.Addr(0))
_, err := s.pushAddrSliceAndTerminator(auxv)
if err != nil {
return StackLayout{}, err
diff --git a/pkg/sentry/arch/stack_unsafe.go b/pkg/sentry/arch/stack_unsafe.go
index 0e478e434..f4712d58f 100644
--- a/pkg/sentry/arch/stack_unsafe.go
+++ b/pkg/sentry/arch/stack_unsafe.go
@@ -17,19 +17,19 @@ package arch
import (
"unsafe"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
- "gvisor.dev/gvisor/pkg/usermem"
)
// pushAddrSliceAndTerminator copies a slices of addresses to the stack, and
// also pushes an extra null address element at the end of the slice.
//
// Internally, we unsafely transmute the slice type from the arch-dependent
-// []usermem.Addr type, to a slice of fixed-sized ints so that we can pass it to
+// []hostarch.Addr type, to a slice of fixed-sized ints so that we can pass it to
// go-marshal.
//
// On error, the contents of the stack and the bottom cursor are undefined.
-func (s *Stack) pushAddrSliceAndTerminator(src []usermem.Addr) (int, error) {
+func (s *Stack) pushAddrSliceAndTerminator(src []hostarch.Addr) (int, error) {
// Note: Stack grows upwards, so push the terminator first.
switch s.Arch.Width() {
case 8:
diff --git a/pkg/sentry/devices/memdev/zero.go b/pkg/sentry/devices/memdev/zero.go
index 1929e41cd..49c53452a 100644
--- a/pkg/sentry/devices/memdev/zero.go
+++ b/pkg/sentry/devices/memdev/zero.go
@@ -93,6 +93,7 @@ func (fd *zeroFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) erro
// "/dev/zero (deleted)".
opts.Offset = 0
opts.MappingIdentity = &fd.vfsfd
+ opts.SentryOwnedContent = true
opts.MappingIdentity.IncRef()
return nil
}
diff --git a/pkg/sentry/devices/tundev/BUILD b/pkg/sentry/devices/tundev/BUILD
index 71c59287c..8b38d574d 100644
--- a/pkg/sentry/devices/tundev/BUILD
+++ b/pkg/sentry/devices/tundev/BUILD
@@ -9,6 +9,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/sentry/arch",
"//pkg/sentry/fsimpl/devtmpfs",
"//pkg/sentry/inet",
diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go
index c43158aa4..a12eeb8e7 100644
--- a/pkg/sentry/devices/tundev/tundev.go
+++ b/pkg/sentry/devices/tundev/tundev.go
@@ -18,6 +18,7 @@ package tundev
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
"gvisor.dev/gvisor/pkg/sentry/inet"
@@ -89,7 +90,7 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg
}
// Validate flags.
- flags, err := netstack.LinuxToTUNFlags(usermem.ByteOrder.Uint16(req.Data[:]))
+ flags, err := netstack.LinuxToTUNFlags(hostarch.ByteOrder.Uint16(req.Data[:]))
if err != nil {
return 0, err
}
@@ -98,7 +99,7 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg
case linux.TUNGETIFF:
var req linux.IFReq
copy(req.IFName[:], fd.device.Name())
- usermem.ByteOrder.PutUint16(req.Data[:], netstack.TUNFlagsToLinux(fd.device.Flags()))
+ hostarch.ByteOrder.PutUint16(req.Data[:], netstack.TUNFlagsToLinux(fd.device.Flags()))
_, err := req.CopyOut(t, data)
return 0, err
diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD
index 420fbae34..0dc100f9b 100644
--- a/pkg/sentry/fs/BUILD
+++ b/pkg/sentry/fs/BUILD
@@ -48,6 +48,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/amutex",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/p9",
"//pkg/refs",
diff --git a/pkg/sentry/fs/anon/BUILD b/pkg/sentry/fs/anon/BUILD
index aedcecfa1..1ce56d79f 100644
--- a/pkg/sentry/fs/anon/BUILD
+++ b/pkg/sentry/fs/anon/BUILD
@@ -12,9 +12,9 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
- "//pkg/usermem",
],
)
diff --git a/pkg/sentry/fs/anon/anon.go b/pkg/sentry/fs/anon/anon.go
index 5c421f5fb..8bda22a8e 100644
--- a/pkg/sentry/fs/anon/anon.go
+++ b/pkg/sentry/fs/anon/anon.go
@@ -19,9 +19,9 @@ package anon
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.dev/gvisor/pkg/usermem"
)
// NewInode constructs an anonymous Inode that is not associated
@@ -37,6 +37,6 @@ func NewInode(ctx context.Context) *fs.Inode {
Type: fs.Anonymous,
DeviceID: PseudoDevice.DeviceID(),
InodeID: PseudoDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
})
}
diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go
index 58deb25fc..5aa668873 100644
--- a/pkg/sentry/fs/copy_up.go
+++ b/pkg/sentry/fs/copy_up.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sync"
@@ -339,7 +340,7 @@ func cleanupUpper(ctx context.Context, parent *Inode, name string, copyUpErr err
// size is the same used by io.Copy.
var copyUpBuffers = sync.Pool{
New: func() interface{} {
- b := make([]byte, 8*usermem.PageSize)
+ b := make([]byte, 8*hostarch.PageSize)
return &b
},
}
diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD
index 9379a4d7b..23a3a9a2d 100644
--- a/pkg/sentry/fs/dev/BUILD
+++ b/pkg/sentry/fs/dev/BUILD
@@ -18,6 +18,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/rand",
"//pkg/safemem",
"//pkg/sentry/arch",
diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go
index acbd401a0..e84ba7a5d 100644
--- a/pkg/sentry/fs/dev/dev.go
+++ b/pkg/sentry/fs/dev/dev.go
@@ -19,6 +19,7 @@ import (
"math"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
@@ -49,7 +50,7 @@ func newCharacterDevice(ctx context.Context, iops fs.InodeOperations, msrc *fs.M
return fs.NewInode(ctx, iops, msrc, fs.StableAttr{
DeviceID: devDevice.DeviceID(),
InodeID: devDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: fs.CharacterDevice,
DeviceFileMajor: major,
DeviceFileMinor: minor,
@@ -60,7 +61,7 @@ func newMemDevice(ctx context.Context, iops fs.InodeOperations, msrc *fs.MountSo
return fs.NewInode(ctx, iops, msrc, fs.StableAttr{
DeviceID: devDevice.DeviceID(),
InodeID: devDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: fs.CharacterDevice,
DeviceFileMajor: memDevMajor,
DeviceFileMinor: minor,
@@ -72,7 +73,7 @@ func newDirectory(ctx context.Context, contents map[string]*fs.Inode, msrc *fs.M
return fs.NewInode(ctx, iops, msrc, fs.StableAttr{
DeviceID: devDevice.DeviceID(),
InodeID: devDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: fs.Directory,
})
}
@@ -82,7 +83,7 @@ func newSymlink(ctx context.Context, target string, msrc *fs.MountSource) *fs.In
return fs.NewInode(ctx, iops, msrc, fs.StableAttr{
DeviceID: devDevice.DeviceID(),
InodeID: devDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: fs.Symlink,
})
}
@@ -137,7 +138,7 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
return fs.NewInode(ctx, iops, msrc, fs.StableAttr{
DeviceID: devDevice.DeviceID(),
InodeID: devDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: fs.Directory,
})
}
diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go
index 11a2984d8..77e8d222a 100644
--- a/pkg/sentry/fs/dev/net_tun.go
+++ b/pkg/sentry/fs/dev/net_tun.go
@@ -17,6 +17,7 @@ package dev
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -110,7 +111,7 @@ func (n *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io user
}
// Validate flags.
- flags, err := netstack.LinuxToTUNFlags(usermem.ByteOrder.Uint16(req.Data[:]))
+ flags, err := netstack.LinuxToTUNFlags(hostarch.ByteOrder.Uint16(req.Data[:]))
if err != nil {
return 0, err
}
@@ -119,7 +120,7 @@ func (n *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io user
case linux.TUNGETIFF:
var req linux.IFReq
copy(req.IFName[:], n.device.Name())
- usermem.ByteOrder.PutUint16(req.Data[:], netstack.TUNFlagsToLinux(n.device.Flags()))
+ hostarch.ByteOrder.PutUint16(req.Data[:], netstack.TUNFlagsToLinux(n.device.Flags()))
_, err := req.CopyOut(t, data)
return 0, err
diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD
index c83baf464..2120f2bad 100644
--- a/pkg/sentry/fs/fdpipe/BUILD
+++ b/pkg/sentry/fs/fdpipe/BUILD
@@ -40,6 +40,7 @@ go_test(
"//pkg/context",
"//pkg/fd",
"//pkg/fdnotifier",
+ "//pkg/hostarch",
"//pkg/sentry/contexttest",
"//pkg/sentry/fs",
"//pkg/syserror",
diff --git a/pkg/sentry/fs/fdpipe/pipe_test.go b/pkg/sentry/fs/fdpipe/pipe_test.go
index faeb3908c..ab0e9dac7 100644
--- a/pkg/sentry/fs/fdpipe/pipe_test.go
+++ b/pkg/sentry/fs/fdpipe/pipe_test.go
@@ -27,6 +27,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
func singlePipeFD() (int, error) {
@@ -52,7 +54,7 @@ func mockPipeDirent(t *testing.T) *fs.Dirent {
}
inode := fs.NewInode(ctx, node, fs.NewMockMountSource(nil), fs.StableAttr{
Type: fs.Pipe,
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
})
return fs.NewDirent(ctx, inode, "")
}
diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD
index d388f0e92..6469cc3a9 100644
--- a/pkg/sentry/fs/fsutil/BUILD
+++ b/pkg/sentry/fs/fsutil/BUILD
@@ -76,6 +76,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/safemem",
"//pkg/sentry/arch",
@@ -105,6 +106,7 @@ go_test(
library = ":fsutil",
deps = [
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/safemem",
"//pkg/sentry/contexttest",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/fs/fsutil/dirty_set.go b/pkg/sentry/fs/fsutil/dirty_set.go
index 2c9446c1d..38383e730 100644
--- a/pkg/sentry/fs/fsutil/dirty_set.go
+++ b/pkg/sentry/fs/fsutil/dirty_set.go
@@ -18,9 +18,9 @@ import (
"math"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/usermem"
)
// DirtySet maps offsets into a memmap.Mappable to DirtyInfo. It is used to
@@ -215,7 +215,7 @@ func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRan
if max < wbr.Start {
break
}
- ims, err := mem.MapInternal(cseg.FileRangeOf(wbr), usermem.Read)
+ ims, err := mem.MapInternal(cseg.FileRangeOf(wbr), hostarch.Read)
if err != nil {
return err
}
diff --git a/pkg/sentry/fs/fsutil/dirty_set_test.go b/pkg/sentry/fs/fsutil/dirty_set_test.go
index e3579c23c..48448c97c 100644
--- a/pkg/sentry/fs/fsutil/dirty_set_test.go
+++ b/pkg/sentry/fs/fsutil/dirty_set_test.go
@@ -18,18 +18,18 @@ import (
"reflect"
"testing"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/usermem"
)
func TestDirtySet(t *testing.T) {
var set DirtySet
- set.MarkDirty(memmap.MappableRange{0, 2 * usermem.PageSize})
- set.KeepDirty(memmap.MappableRange{usermem.PageSize, 2 * usermem.PageSize})
- set.MarkClean(memmap.MappableRange{0, 2 * usermem.PageSize})
+ set.MarkDirty(memmap.MappableRange{0, 2 * hostarch.PageSize})
+ set.KeepDirty(memmap.MappableRange{hostarch.PageSize, 2 * hostarch.PageSize})
+ set.MarkClean(memmap.MappableRange{0, 2 * hostarch.PageSize})
want := &DirtySegmentDataSlices{
- Start: []uint64{usermem.PageSize},
- End: []uint64{2 * usermem.PageSize},
+ Start: []uint64{hostarch.PageSize},
+ End: []uint64{2 * hostarch.PageSize},
Values: []DirtyInfo{{Keep: true}},
}
if got := set.ExportSortedSlices(); !reflect.DeepEqual(got, want) {
diff --git a/pkg/sentry/fs/fsutil/file_range_set.go b/pkg/sentry/fs/fsutil/file_range_set.go
index 1dc409d38..fdaceb1db 100644
--- a/pkg/sentry/fs/fsutil/file_range_set.go
+++ b/pkg/sentry/fs/fsutil/file_range_set.go
@@ -20,11 +20,11 @@ import (
"math"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/usermem"
)
// FileRangeSet maps offsets into a memmap.Mappable to offsets into a
@@ -130,7 +130,7 @@ func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.Map
// MemoryFile.AllocateAndFill truncates down to a page
// boundary, but FileRangeSet.Fill is supposed to
// zero-fill to the end of the page in this case.
- donepgaddr, ok := usermem.Addr(done).RoundUp()
+ donepgaddr, ok := hostarch.Addr(done).RoundUp()
if donepg := uint64(donepgaddr); ok && donepg != done {
dsts.DropFirst64(donepg - done)
done = donepg
@@ -184,7 +184,7 @@ func (frs *FileRangeSet) DropAll(mf *pgalloc.MemoryFile) {
// bytes after the new EOF on the same page are zeroed, and pages after the new
// EOF are freed.
func (frs *FileRangeSet) Truncate(end uint64, mf *pgalloc.MemoryFile) {
- pgendaddr, ok := usermem.Addr(end).RoundUp()
+ pgendaddr, ok := hostarch.Addr(end).RoundUp()
if ok {
pgend := uint64(pgendaddr)
@@ -208,7 +208,7 @@ func (frs *FileRangeSet) Truncate(end uint64, mf *pgalloc.MemoryFile) {
if seg.Ok() {
fr := seg.FileRange()
fr.Start += end - seg.Start()
- ims, err := mf.MapInternal(fr, usermem.Write)
+ ims, err := mf.MapInternal(fr, hostarch.Write)
if err != nil {
// There's no good recourse from here. This means
// that we can't keep cached memory consistent with
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go
index 54f7b7cdc..23528bf25 100644
--- a/pkg/sentry/fs/fsutil/host_file_mapper.go
+++ b/pkg/sentry/fs/fsutil/host_file_mapper.go
@@ -18,11 +18,11 @@ import (
"fmt"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/usermem"
)
// HostFileMapper caches mappings of an arbitrary host file descriptor. It is
@@ -50,13 +50,13 @@ type HostFileMapper struct {
}
const (
- chunkShift = usermem.HugePageShift
+ chunkShift = hostarch.HugePageShift
chunkSize = 1 << chunkShift
chunkMask = chunkSize - 1
)
func pagesInChunk(mr memmap.MappableRange, chunkStart uint64) int32 {
- return int32(mr.Intersect(memmap.MappableRange{chunkStart, chunkStart + chunkSize}).Length() / usermem.PageSize)
+ return int32(mr.Intersect(memmap.MappableRange{chunkStart, chunkStart + chunkSize}).Length() / hostarch.PageSize)
}
type mapping struct {
diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go
index c15d8a946..e1e38b498 100644
--- a/pkg/sentry/fs/fsutil/host_mappable.go
+++ b/pkg/sentry/fs/fsutil/host_mappable.go
@@ -18,6 +18,7 @@ import (
"math"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -59,7 +60,7 @@ func NewHostMappable(backingFile CachedFileObject) *HostMappable {
}
// AddMapping implements memmap.Mappable.AddMapping.
-func (h *HostMappable) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+func (h *HostMappable) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error {
// Hot path. Avoid defers.
h.mu.Lock()
mapped := h.mappings.AddMapping(ms, ar, offset, writable)
@@ -71,7 +72,7 @@ func (h *HostMappable) AddMapping(ctx context.Context, ms memmap.MappingSpace, a
}
// RemoveMapping implements memmap.Mappable.RemoveMapping.
-func (h *HostMappable) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+func (h *HostMappable) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) {
// Hot path. Avoid defers.
h.mu.Lock()
unmapped := h.mappings.RemoveMapping(ms, ar, offset, writable)
@@ -82,18 +83,18 @@ func (h *HostMappable) RemoveMapping(ctx context.Context, ms memmap.MappingSpace
}
// CopyMapping implements memmap.Mappable.CopyMapping.
-func (h *HostMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+func (h *HostMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error {
return h.AddMapping(ctx, ms, dstAR, offset, writable)
}
// Translate implements memmap.Mappable.Translate.
-func (h *HostMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+func (h *HostMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) {
return []memmap.Translation{
{
Source: optional,
File: h,
Offset: optional.Start,
- Perms: usermem.AnyAccess,
+ Perms: hostarch.AnyAccess,
},
}, nil
}
@@ -124,7 +125,7 @@ func (h *HostMappable) NotifyChangeFD() error {
}
// MapInternal implements memmap.File.MapInternal.
-func (h *HostMappable) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+func (h *HostMappable) MapInternal(fr memmap.FileRange, at hostarch.AccessType) (safemem.BlockSeq, error) {
return h.hostFileMapper.MapInternal(fr, h.backingFile.FD(), at.Write)
}
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index 0ed7aafa5..7856b354b 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -19,6 +19,7 @@ import (
"io"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -622,7 +623,7 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
switch {
case seg.Ok():
// Get internal mappings from the cache.
- ims, err := mem.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
+ ims, err := mem.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), hostarch.Read)
if err != nil {
unlock()
return done, err
@@ -647,7 +648,7 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
// Read into the cache, then re-enter the loop to read from the
// cache.
reqMR := memmap.MappableRange{
- Start: uint64(usermem.Addr(gapMR.Start).RoundDown()),
+ Start: uint64(hostarch.Addr(gapMR.Start).RoundDown()),
End: fs.OffsetPageEnd(int64(gapMR.End)),
}
optMR := gap.Range()
@@ -729,7 +730,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error
case seg.Ok() && seg.Start() < mr.End:
// Get internal mappings from the cache.
segMR := seg.Range().Intersect(mr)
- ims, err := mf.MapInternal(seg.FileRangeOf(segMR), usermem.Write)
+ ims, err := mf.MapInternal(seg.FileRangeOf(segMR), hostarch.Write)
if err != nil {
rw.maybeGrowFile()
rw.c.dataMu.Unlock()
@@ -786,7 +787,7 @@ func (c *CachingInodeOperations) useHostPageCache() bool {
}
// AddMapping implements memmap.Mappable.AddMapping.
-func (c *CachingInodeOperations) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+func (c *CachingInodeOperations) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error {
// Hot path. Avoid defers.
c.mapsMu.Lock()
mapped := c.mappings.AddMapping(ms, ar, offset, writable)
@@ -808,7 +809,7 @@ func (c *CachingInodeOperations) AddMapping(ctx context.Context, ms memmap.Mappi
}
// RemoveMapping implements memmap.Mappable.RemoveMapping.
-func (c *CachingInodeOperations) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+func (c *CachingInodeOperations) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) {
// Hot path. Avoid defers.
c.mapsMu.Lock()
unmapped := c.mappings.RemoveMapping(ms, ar, offset, writable)
@@ -836,12 +837,12 @@ func (c *CachingInodeOperations) RemoveMapping(ctx context.Context, ms memmap.Ma
}
// CopyMapping implements memmap.Mappable.CopyMapping.
-func (c *CachingInodeOperations) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+func (c *CachingInodeOperations) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error {
return c.AddMapping(ctx, ms, dstAR, offset, writable)
}
// Translate implements memmap.Mappable.Translate.
-func (c *CachingInodeOperations) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+func (c *CachingInodeOperations) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) {
// Hot path. Avoid defer.
if c.useHostPageCache() {
mr := optional
@@ -853,7 +854,7 @@ func (c *CachingInodeOperations) Translate(ctx context.Context, required, option
Source: mr,
File: c,
Offset: mr.Start,
- Perms: usermem.AnyAccess,
+ Perms: hostarch.AnyAccess,
},
}, nil
}
@@ -885,7 +886,7 @@ func (c *CachingInodeOperations) Translate(ctx context.Context, required, option
segMR := seg.Range().Intersect(optional)
// TODO(jamieliu): Make Translations writable even if writability is
// not required if already kept-dirty by another writable translation.
- perms := usermem.AccessType{
+ perms := hostarch.AccessType{
Read: true,
Execute: true,
}
@@ -1050,7 +1051,7 @@ func (c *CachingInodeOperations) DecRef(fr memmap.FileRange) {
// MapInternal implements memmap.File.MapInternal. This is used when we
// directly map an underlying host fd and CachingInodeOperations is used as the
// memmap.File during translation.
-func (c *CachingInodeOperations) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+func (c *CachingInodeOperations) MapInternal(fr memmap.FileRange, at hostarch.AccessType) (safemem.BlockSeq, error) {
return c.hostFileMapper.MapInternal(fr, c.backingFile.FD(), at.Write)
}
diff --git a/pkg/sentry/fs/fsutil/inode_cached_test.go b/pkg/sentry/fs/fsutil/inode_cached_test.go
index 1547584c5..e107c3096 100644
--- a/pkg/sentry/fs/fsutil/inode_cached_test.go
+++ b/pkg/sentry/fs/fsutil/inode_cached_test.go
@@ -20,6 +20,7 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -249,7 +250,7 @@ func (f *sliceBackingFile) Allocate(ctx context.Context, offset int64, length in
type noopMappingSpace struct{}
// Invalidate implements memmap.MappingSpace.Invalidate.
-func (noopMappingSpace) Invalidate(ar usermem.AddrRange, opts memmap.InvalidateOpts) {
+func (noopMappingSpace) Invalidate(ar hostarch.AddrRange, opts memmap.InvalidateOpts) {
}
func anonInode(ctx context.Context) *fs.Inode {
@@ -259,14 +260,14 @@ func anonInode(ctx context.Context) *fs.Inode {
}, 0),
}, fs.NewPseudoMountSource(ctx), fs.StableAttr{
Type: fs.Anonymous,
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
})
}
func pagesOf(bs ...byte) []byte {
- buf := make([]byte, 0, len(bs)*usermem.PageSize)
+ buf := make([]byte, 0, len(bs)*hostarch.PageSize)
for _, b := range bs {
- buf = append(buf, bytes.Repeat([]byte{b}, usermem.PageSize)...)
+ buf = append(buf, bytes.Repeat([]byte{b}, hostarch.PageSize)...)
}
return buf
}
@@ -292,28 +293,28 @@ func TestRead(t *testing.T) {
// expects to only cache mapped pages), then call Translate to force it to
// be cached.
var ms noopMappingSpace
- ar := usermem.AddrRange{usermem.PageSize, 2 * usermem.PageSize}
- if err := iops.AddMapping(ctx, ms, ar, usermem.PageSize, true); err != nil {
+ ar := hostarch.AddrRange{hostarch.PageSize, 2 * hostarch.PageSize}
+ if err := iops.AddMapping(ctx, ms, ar, hostarch.PageSize, true); err != nil {
t.Fatalf("AddMapping got %v, want nil", err)
}
- mr := memmap.MappableRange{usermem.PageSize, 2 * usermem.PageSize}
- if _, err := iops.Translate(ctx, mr, mr, usermem.Read); err != nil {
+ mr := memmap.MappableRange{hostarch.PageSize, 2 * hostarch.PageSize}
+ if _, err := iops.Translate(ctx, mr, mr, hostarch.Read); err != nil {
t.Fatalf("Translate got %v, want nil", err)
}
- if cached := iops.cache.Span(); cached != usermem.PageSize {
- t.Errorf("SpanRange got %d, want %d", cached, usermem.PageSize)
+ if cached := iops.cache.Span(); cached != hostarch.PageSize {
+ t.Errorf("SpanRange got %d, want %d", cached, hostarch.PageSize)
}
// Try to read 4 pages. The first and third pages should be read directly
// from the "file", the second page should be read from the cache, and only
// 3 pages (the size of the file) should be readable.
- rbuf := make([]byte, 4*usermem.PageSize)
+ rbuf := make([]byte, 4*hostarch.PageSize)
dst := usermem.BytesIOSequence(rbuf)
n, err := iops.Read(ctx, file, dst, 0)
- if n != 3*usermem.PageSize || (err != nil && err != io.EOF) {
- t.Fatalf("Read got (%d, %v), want (%d, nil or EOF)", n, err, 3*usermem.PageSize)
+ if n != 3*hostarch.PageSize || (err != nil && err != io.EOF) {
+ t.Fatalf("Read got (%d, %v), want (%d, nil or EOF)", n, err, 3*hostarch.PageSize)
}
- rbuf = rbuf[:3*usermem.PageSize]
+ rbuf = rbuf[:3*hostarch.PageSize]
// Did we get the bytes we expect?
if !bytes.Equal(rbuf, buf) {
@@ -323,7 +324,7 @@ func TestRead(t *testing.T) {
// Delete the memory mapping before iops.Release(). The cached page will
// either be evicted by ctx's pgalloc.MemoryFile, or dropped by
// iops.Release().
- iops.RemoveMapping(ctx, ms, ar, usermem.PageSize, true)
+ iops.RemoveMapping(ctx, ms, ar, hostarch.PageSize, true)
}
func TestWrite(t *testing.T) {
@@ -348,25 +349,25 @@ func TestWrite(t *testing.T) {
// CachingInodeOperations expects to only cache mapped pages), then call
// Translate to force them to be cached.
var ms noopMappingSpace
- ar := usermem.AddrRange{usermem.PageSize, 3 * usermem.PageSize}
- if err := iops.AddMapping(ctx, ms, ar, usermem.PageSize, true); err != nil {
+ ar := hostarch.AddrRange{hostarch.PageSize, 3 * hostarch.PageSize}
+ if err := iops.AddMapping(ctx, ms, ar, hostarch.PageSize, true); err != nil {
t.Fatalf("AddMapping got %v, want nil", err)
}
- defer iops.RemoveMapping(ctx, ms, ar, usermem.PageSize, true)
- mr := memmap.MappableRange{usermem.PageSize, 3 * usermem.PageSize}
- if _, err := iops.Translate(ctx, mr, mr, usermem.Read); err != nil {
+ defer iops.RemoveMapping(ctx, ms, ar, hostarch.PageSize, true)
+ mr := memmap.MappableRange{hostarch.PageSize, 3 * hostarch.PageSize}
+ if _, err := iops.Translate(ctx, mr, mr, hostarch.Read); err != nil {
t.Fatalf("Translate got %v, want nil", err)
}
- if cached := iops.cache.Span(); cached != 2*usermem.PageSize {
- t.Errorf("SpanRange got %d, want %d", cached, 2*usermem.PageSize)
+ if cached := iops.cache.Span(); cached != 2*hostarch.PageSize {
+ t.Errorf("SpanRange got %d, want %d", cached, 2*hostarch.PageSize)
}
// Write to the first 2 pages.
wbuf := pagesOf('e', 'f')
src := usermem.BytesIOSequence(wbuf)
n, err := iops.Write(ctx, src, 0)
- if n != 2*usermem.PageSize || err != nil {
- t.Fatalf("Write got (%d, %v), want (%d, nil)", n, err, 2*usermem.PageSize)
+ if n != 2*hostarch.PageSize || err != nil {
+ t.Fatalf("Write got (%d, %v), want (%d, nil)", n, err, 2*hostarch.PageSize)
}
// The first page should have been written directly, since it was not cached.
@@ -382,7 +383,7 @@ func TestWrite(t *testing.T) {
}
// Now the second page should have been written as well.
- copy(want[usermem.PageSize:], pagesOf('f'))
+ copy(want[hostarch.PageSize:], pagesOf('f'))
if !bytes.Equal(buf, want) {
t.Errorf("File contents are %v, want %v", buf, want)
}
diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD
index b210e0e7e..c4a069832 100644
--- a/pkg/sentry/fs/gofer/BUILD
+++ b/pkg/sentry/fs/gofer/BUILD
@@ -27,6 +27,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/fd",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/p9",
"//pkg/refs",
diff --git a/pkg/sentry/fs/gofer/attr.go b/pkg/sentry/fs/gofer/attr.go
index cffc756cc..d6bff3f40 100644
--- a/pkg/sentry/fs/gofer/attr.go
+++ b/pkg/sentry/fs/gofer/attr.go
@@ -17,11 +17,11 @@ package gofer
import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/usermem"
)
// getattr returns the 9p attributes of the p9.File. On success, Mode, Size, and RDev
@@ -98,7 +98,7 @@ func bsize(pattr p9.Attr) int64 {
// Some files, particularly those that are not on a local file system,
// may have no clue of their block size. Better not to report something
// misleading or buggy and have a safe default.
- return usermem.PageSize
+ return hostarch.PageSize
}
// ntype returns an fs.InodeType from 9p attributes.
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index 0b3d0617f..46a2dc47d 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -384,8 +384,16 @@ func (c *ConnectedEndpoint) CloseUnread() {}
// SetSendBufferSize implements transport.ConnectedEndpoint.SetSendBufferSize.
func (c *ConnectedEndpoint) SetSendBufferSize(v int64) (newSz int64) {
- // gVisor does not permit setting of SO_SNDBUF for host backed unix domain
- // sockets.
+ // gVisor does not permit setting of SO_SNDBUF for host backed unix
+ // domain sockets.
+ return atomic.LoadInt64(&c.sndbuf)
+}
+
+// SetReceiveBufferSize implements transport.ConnectedEndpoint.SetReceiveBufferSize.
+func (c *ConnectedEndpoint) SetReceiveBufferSize(v int64) (newSz int64) {
+ // gVisor does not permit setting of SO_RCVBUF for host backed unix
+ // domain sockets. Receive buffer does not have any effect for unix
+ // sockets and we claim to be the same as send buffer.
return atomic.LoadInt64(&c.sndbuf)
}
diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go
index fb81d903d..1b83643db 100644
--- a/pkg/sentry/fs/inotify.go
+++ b/pkg/sentry/fs/inotify.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
@@ -216,7 +217,7 @@ func (i *Inotify) Ioctl(ctx context.Context, _ *File, io usermem.IO, args arch.S
n += uint32(e.sizeOf())
}
var buf [4]byte
- usermem.ByteOrder.PutUint32(buf[:], n)
+ hostarch.ByteOrder.PutUint32(buf[:], n)
_, err := io.CopyOut(ctx, args[2].Pointer(), buf[:], usermem.IOOpts{})
return 0, err
diff --git a/pkg/sentry/fs/inotify_event.go b/pkg/sentry/fs/inotify_event.go
index 686e1b1cd..399aff1ed 100644
--- a/pkg/sentry/fs/inotify_event.go
+++ b/pkg/sentry/fs/inotify_event.go
@@ -19,6 +19,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -100,10 +101,10 @@ func (e *Event) sizeOf() int {
// construct the output. We use a buffer allocated ahead of time for
// performance. buf must be at least inotifyEventBaseSize bytes.
func (e *Event) CopyTo(ctx context.Context, buf []byte, dst usermem.IOSequence) (int64, error) {
- usermem.ByteOrder.PutUint32(buf[0:], uint32(e.wd))
- usermem.ByteOrder.PutUint32(buf[4:], e.mask)
- usermem.ByteOrder.PutUint32(buf[8:], e.cookie)
- usermem.ByteOrder.PutUint32(buf[12:], e.len)
+ hostarch.ByteOrder.PutUint32(buf[0:], uint32(e.wd))
+ hostarch.ByteOrder.PutUint32(buf[4:], e.mask)
+ hostarch.ByteOrder.PutUint32(buf[8:], e.cookie)
+ hostarch.ByteOrder.PutUint32(buf[12:], e.len)
writeLen := 0
diff --git a/pkg/sentry/fs/offset.go b/pkg/sentry/fs/offset.go
index 53b5df175..3a8c97d8f 100644
--- a/pkg/sentry/fs/offset.go
+++ b/pkg/sentry/fs/offset.go
@@ -17,14 +17,14 @@ package fs
import (
"math"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// OffsetPageEnd returns the file offset rounded up to the nearest
// page boundary. OffsetPageEnd panics if rounding up causes overflow,
// which shouldn't be possible given that offset is an int64.
func OffsetPageEnd(offset int64) uint64 {
- end, ok := usermem.Addr(offset).RoundUp()
+ end, ok := hostarch.Addr(offset).RoundUp()
if !ok {
panic("impossible overflow")
}
diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go
index 01a1235b8..f96f5a3e5 100644
--- a/pkg/sentry/fs/overlay.go
+++ b/pkg/sentry/fs/overlay.go
@@ -19,11 +19,11 @@ import (
"strings"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// The virtual filesystem implements an overlay configuration. For a high-level
@@ -274,7 +274,7 @@ func (o *overlayEntry) markDirectoryDirty() {
}
// AddMapping implements memmap.Mappable.AddMapping.
-func (o *overlayEntry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+func (o *overlayEntry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error {
o.mapsMu.Lock()
defer o.mapsMu.Unlock()
if err := o.inodeLocked().Mappable().AddMapping(ctx, ms, ar, offset, writable); err != nil {
@@ -285,7 +285,7 @@ func (o *overlayEntry) AddMapping(ctx context.Context, ms memmap.MappingSpace, a
}
// RemoveMapping implements memmap.Mappable.RemoveMapping.
-func (o *overlayEntry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+func (o *overlayEntry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) {
o.mapsMu.Lock()
defer o.mapsMu.Unlock()
o.inodeLocked().Mappable().RemoveMapping(ctx, ms, ar, offset, writable)
@@ -293,7 +293,7 @@ func (o *overlayEntry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace
}
// CopyMapping implements memmap.Mappable.CopyMapping.
-func (o *overlayEntry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+func (o *overlayEntry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error {
o.mapsMu.Lock()
defer o.mapsMu.Unlock()
if err := o.inodeLocked().Mappable().CopyMapping(ctx, ms, srcAR, dstAR, offset, writable); err != nil {
@@ -304,7 +304,7 @@ func (o *overlayEntry) CopyMapping(ctx context.Context, ms memmap.MappingSpace,
}
// Translate implements memmap.Mappable.Translate.
-func (o *overlayEntry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+func (o *overlayEntry) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) {
o.dataMu.RLock()
defer o.dataMu.RUnlock()
return o.inodeLocked().Mappable().Translate(ctx, required, optional, at)
diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD
index b8b2281a8..7af7e0b45 100644
--- a/pkg/sentry/fs/proc/BUILD
+++ b/pkg/sentry/fs/proc/BUILD
@@ -30,6 +30,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
diff --git a/pkg/sentry/fs/proc/exec_args.go b/pkg/sentry/fs/proc/exec_args.go
index e6171dd1d..24426b225 100644
--- a/pkg/sentry/fs/proc/exec_args.go
+++ b/pkg/sentry/fs/proc/exec_args.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -113,7 +114,7 @@ func (f *execArgFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequen
defer m.DecUsers(ctx)
// Figure out the bounds of the exec arg we are trying to read.
- var execArgStart, execArgEnd usermem.Addr
+ var execArgStart, execArgEnd hostarch.Addr
switch f.arg {
case cmdlineExecArg:
execArgStart, execArgEnd = m.ArgvStart(), m.ArgvEnd()
@@ -172,8 +173,8 @@ func (f *execArgFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequen
// https://elixir.bootlin.com/linux/v4.20/source/fs/proc/base.c#L208
// we'll return one page total between argv and envp because of the
// above page restrictions.
- if lengthEnvv > usermem.PageSize-len(buf) {
- lengthEnvv = usermem.PageSize - len(buf)
+ if lengthEnvv > hostarch.PageSize-len(buf) {
+ lengthEnvv = hostarch.PageSize - len(buf)
}
// Make a new buffer to fit the whole thing
tmp := make([]byte, length+lengthEnvv)
diff --git a/pkg/sentry/fs/proc/inode.go b/pkg/sentry/fs/proc/inode.go
index d2859a4c2..78132f7a5 100644
--- a/pkg/sentry/fs/proc/inode.go
+++ b/pkg/sentry/fs/proc/inode.go
@@ -17,13 +17,13 @@ package proc
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/mm"
- "gvisor.dev/gvisor/pkg/usermem"
)
// LINT.IfChange
@@ -125,7 +125,7 @@ func newProcInode(ctx context.Context, iops fs.InodeOperations, msrc *fs.MountSo
sattr := fs.StableAttr{
DeviceID: device.ProcDevice.DeviceID(),
InodeID: device.ProcDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: typ,
}
if t != nil {
diff --git a/pkg/sentry/fs/proc/meminfo.go b/pkg/sentry/fs/proc/meminfo.go
index 91617267d..7d975d333 100644
--- a/pkg/sentry/fs/proc/meminfo.go
+++ b/pkg/sentry/fs/proc/meminfo.go
@@ -19,10 +19,10 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/usermem"
)
// LINT.IfChange
@@ -53,7 +53,7 @@ func (d *meminfoData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle)
anon := snapshot.Anonymous + snapshot.Tmpfs
file := snapshot.PageCache + snapshot.Mapped
// We don't actually have active/inactive LRUs, so just make up numbers.
- activeFile := (file / 2) &^ (usermem.PageSize - 1)
+ activeFile := (file / 2) &^ (hostarch.PageSize - 1)
inactiveFile := file - activeFile
var buf bytes.Buffer
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
index 203cfa061..91c35eea9 100644
--- a/pkg/sentry/fs/proc/net.go
+++ b/pkg/sentry/fs/proc/net.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
@@ -35,7 +36,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/usermem"
)
// LINT.IfChange
@@ -367,10 +367,10 @@ func (n *netRoute) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]
)
if len(rt.GatewayAddr) == header.IPv4AddressSize {
flags |= linux.RTF_GATEWAY
- gw = usermem.ByteOrder.Uint32(rt.GatewayAddr)
+ gw = hostarch.ByteOrder.Uint32(rt.GatewayAddr)
}
if len(rt.DstAddr) == header.IPv4AddressSize {
- prefix = usermem.ByteOrder.Uint32(rt.DstAddr)
+ prefix = hostarch.ByteOrder.Uint32(rt.DstAddr)
}
l := fmt.Sprintf(
"%s\t%08X\t%08X\t%04X\t%d\t%d\t%d\t%08X\t%d\t%d\t%d",
@@ -520,7 +520,7 @@ func networkToHost16(n uint16) uint16 {
// binary.BigEndian.Uint16() require a read of binary.BigEndian and an
// interface method call, defeating inlining.
buf := [2]byte{byte(n >> 8 & 0xff), byte(n & 0xff)}
- return usermem.ByteOrder.Uint16(buf[:])
+ return hostarch.ByteOrder.Uint16(buf[:])
}
func writeInetAddr(w io.Writer, family int, i linux.SockAddr) {
@@ -542,14 +542,14 @@ func writeInetAddr(w io.Writer, family int, i linux.SockAddr) {
// __be32 which is a typedef for an unsigned int, and is printed with
// %X. This means that for a little-endian machine, Linux prints the
// least-significant byte of the address first. To emulate this, we first
- // invert the byte order for the address using usermem.ByteOrder.Uint32,
+ // invert the byte order for the address using hostarch.ByteOrder.Uint32,
// which makes it have the equivalent encoding to a __be32 on a little
// endian machine. Note that this operation is a no-op on a big endian
// machine. Then similar to Linux, we format it with %X, which will print
// the most-significant byte of the __be32 address first, which is now
// actually the least-significant byte of the original address in
// linux.SockAddrInet.Addr on little endian machines, due to the conversion.
- addr := usermem.ByteOrder.Uint32(a.Addr[:])
+ addr := hostarch.ByteOrder.Uint32(a.Addr[:])
fmt.Fprintf(w, "%08X:%04X ", addr, port)
case linux.AF_INET6:
@@ -559,10 +559,10 @@ func writeInetAddr(w io.Writer, family int, i linux.SockAddr) {
}
port := networkToHost16(a.Port)
- addr0 := usermem.ByteOrder.Uint32(a.Addr[0:4])
- addr1 := usermem.ByteOrder.Uint32(a.Addr[4:8])
- addr2 := usermem.ByteOrder.Uint32(a.Addr[8:12])
- addr3 := usermem.ByteOrder.Uint32(a.Addr[12:16])
+ addr0 := hostarch.ByteOrder.Uint32(a.Addr[0:4])
+ addr1 := hostarch.ByteOrder.Uint32(a.Addr[4:8])
+ addr2 := hostarch.ByteOrder.Uint32(a.Addr[8:12])
+ addr3 := hostarch.ByteOrder.Uint32(a.Addr[12:16])
fmt.Fprintf(w, "%08X%08X%08X%08X:%04X ", addr0, addr1, addr2, addr3, port)
}
}
diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD
index 21338d912..713b81e08 100644
--- a/pkg/sentry/fs/proc/seqfile/BUILD
+++ b/pkg/sentry/fs/proc/seqfile/BUILD
@@ -9,6 +9,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/proc/device",
diff --git a/pkg/sentry/fs/proc/seqfile/seqfile.go b/pkg/sentry/fs/proc/seqfile/seqfile.go
index 6121f0e95..b01688b1d 100644
--- a/pkg/sentry/fs/proc/seqfile/seqfile.go
+++ b/pkg/sentry/fs/proc/seqfile/seqfile.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
@@ -131,7 +132,7 @@ func NewSeqFileInode(ctx context.Context, source SeqSource, msrc *fs.MountSource
sattr := fs.StableAttr{
DeviceID: device.ProcDevice.DeviceID(),
InodeID: device.ProcDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: fs.SpecialFile,
}
return fs.NewInode(ctx, iops, msrc, sattr)
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
index bbe282c03..1d09afdd7 100644
--- a/pkg/sentry/fs/proc/sys_net.go
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
@@ -76,7 +77,7 @@ func newTCPMemInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack, dir
sattr := fs.StableAttr{
DeviceID: device.ProcDevice.DeviceID(),
InodeID: device.ProcDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: fs.SpecialFile,
}
return fs.NewInode(ctx, tm, msrc, sattr)
@@ -136,7 +137,7 @@ func (f *tcpMemFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequen
f.tcpMemInode.mu.Lock()
defer f.tcpMemInode.mu.Unlock()
- src = src.TakeFirst(usermem.PageSize - 1)
+ src = src.TakeFirst(hostarch.PageSize - 1)
size, err := readSize(f.tcpMemInode.dir, f.tcpMemInode.s)
if err != nil {
return 0, err
@@ -192,7 +193,7 @@ func newTCPSackInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *f
sattr := fs.StableAttr{
DeviceID: device.ProcDevice.DeviceID(),
InodeID: device.ProcDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: fs.SpecialFile,
}
return fs.NewInode(ctx, ts, msrc, sattr)
@@ -264,7 +265,7 @@ func (f *tcpSackFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSeque
// Only consider size of one memory page for input for performance reasons.
// We are only reading if it's zero or not anyway.
- src = src.TakeFirst(usermem.PageSize - 1)
+ src = src.TakeFirst(hostarch.PageSize - 1)
var v int32
n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
@@ -294,7 +295,7 @@ func newTCPRecoveryInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack
sattr := fs.StableAttr{
DeviceID: device.ProcDevice.DeviceID(),
InodeID: device.ProcDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: fs.SpecialFile,
}
return fs.NewInode(ctx, ts, msrc, sattr)
@@ -354,7 +355,7 @@ func (f *tcpRecoveryFile) Write(ctx context.Context, _ *fs.File, src usermem.IOS
if src.NumBytes() == 0 {
return 0, nil
}
- src = src.TakeFirst(usermem.PageSize - 1)
+ src = src.TakeFirst(hostarch.PageSize - 1)
var v int32
n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
@@ -413,7 +414,7 @@ func newIPForwardingInode(ctx context.Context, msrc *fs.MountSource, s inet.Stac
sattr := fs.StableAttr{
DeviceID: device.ProcDevice.DeviceID(),
InodeID: device.ProcDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: fs.SpecialFile,
}
return fs.NewInode(ctx, ipf, msrc, sattr)
@@ -486,7 +487,7 @@ func (f *ipForwardingFile) Write(ctx context.Context, _ *fs.File, src usermem.IO
// Only consider size of one memory page for input for performance reasons.
// We are only reading if it's zero or not anyway.
- src = src.TakeFirst(usermem.PageSize - 1)
+ src = src.TakeFirst(hostarch.PageSize - 1)
var v int32
n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
@@ -524,7 +525,7 @@ func newPortRangeInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack)
sattr := fs.StableAttr{
DeviceID: device.ProcDevice.DeviceID(),
InodeID: device.ProcDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: fs.SpecialFile,
}
return fs.NewInode(ctx, ipf, msrc, sattr)
@@ -589,7 +590,7 @@ func (pf *portRangeFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSe
// Only consider size of one memory page for input for performance
// reasons.
- src = src.TakeFirst(usermem.PageSize - 1)
+ src = src.TakeFirst(hostarch.PageSize - 1)
ports := make([]int32, 2)
n, err := usermem.CopyInt32StringsInVec(ctx, src.IO, src.Addrs, ports, src.Opts)
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index f43d6c221..ae5ed25f9 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
@@ -469,7 +470,7 @@ func (m *memDataFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequen
defer mm.DecUsers(ctx)
// Buffer the read data because of MM locks
buf := make([]byte, dst.NumBytes())
- n, readErr := mm.CopyIn(ctx, usermem.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true})
+ n, readErr := mm.CopyIn(ctx, hostarch.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true})
if n > 0 {
if _, err := dst.CopyOut(ctx, buf[:n]); err != nil {
return 0, syserror.EFAULT
@@ -632,7 +633,7 @@ func (s *taskStatData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle)
rss = mm.ResidentSetSize()
}
})
- fmt.Fprintf(&buf, "%d %d ", vss, rss/usermem.PageSize)
+ fmt.Fprintf(&buf, "%d %d ", vss, rss/hostarch.PageSize)
// rsslim.
fmt.Fprintf(&buf, "%d ", s.t.ThreadGroup().Limits().Get(limits.Rss).Cur)
@@ -684,7 +685,7 @@ func (s *statmData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([
})
var buf bytes.Buffer
- fmt.Fprintf(&buf, "%d %d 0 0 0 0 0\n", vss/usermem.PageSize, rss/usermem.PageSize)
+ fmt.Fprintf(&buf, "%d %d 0 0 0 0 0\n", vss/hostarch.PageSize, rss/hostarch.PageSize)
return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*statmData)(nil)}}, 0
}
@@ -939,8 +940,8 @@ func (f *auxvecFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequenc
buf := make([]byte, size)
for i, e := range auxv {
- usermem.ByteOrder.PutUint64(buf[16*i:], e.Key)
- usermem.ByteOrder.PutUint64(buf[16*i+8:], uint64(e.Value))
+ hostarch.ByteOrder.PutUint64(buf[16*i:], e.Key)
+ hostarch.ByteOrder.PutUint64(buf[16*i+8:], uint64(e.Value))
}
n, err := dst.CopyOut(ctx, buf[offset:])
@@ -1020,7 +1021,7 @@ func (f *oomScoreAdjFile) Write(ctx context.Context, _ *fs.File, src usermem.IOS
}
// Limit input size so as not to impact performance if input size is large.
- src = src.TakeFirst(usermem.PageSize - 1)
+ src = src.TakeFirst(hostarch.PageSize - 1)
var v int32
n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
diff --git a/pkg/sentry/fs/proc/uid_gid_map.go b/pkg/sentry/fs/proc/uid_gid_map.go
index 2bc9485d8..30d5ad4cf 100644
--- a/pkg/sentry/fs/proc/uid_gid_map.go
+++ b/pkg/sentry/fs/proc/uid_gid_map.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -132,7 +133,7 @@ func (imfo *idMapFileOperations) Write(ctx context.Context, file *fs.File, src u
// the system page size, and the write must be performed at the start of
// the file ..." - user_namespaces(7)
srclen := src.NumBytes()
- if srclen >= usermem.PageSize || offset != 0 {
+ if srclen >= hostarch.PageSize || offset != 0 {
return 0, syserror.EINVAL
}
b := make([]byte, srclen)
diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD
index a51d00d86..4a3d9636b 100644
--- a/pkg/sentry/fs/ramfs/BUILD
+++ b/pkg/sentry/fs/ramfs/BUILD
@@ -14,13 +14,13 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/sentry/fs",
"//pkg/sentry/fs/anon",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/socket/unix/transport",
"//pkg/sync",
"//pkg/syserror",
- "//pkg/usermem",
"//pkg/waiter",
"@org_golang_x_sys//unix:go_default_library",
],
diff --git a/pkg/sentry/fs/ramfs/tree.go b/pkg/sentry/fs/ramfs/tree.go
index dfc9d3453..0ace636c9 100644
--- a/pkg/sentry/fs/ramfs/tree.go
+++ b/pkg/sentry/fs/ramfs/tree.go
@@ -20,9 +20,9 @@ import (
"strings"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
- "gvisor.dev/gvisor/pkg/usermem"
)
// MakeDirectoryTree constructs a ramfs tree of all directories containing
@@ -71,7 +71,7 @@ func emptyDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
return fs.NewInode(ctx, dir, msrc, fs.StableAttr{
DeviceID: anon.PseudoDevice.DeviceID(),
InodeID: anon.PseudoDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: fs.Directory,
})
}
diff --git a/pkg/sentry/fs/sys/BUILD b/pkg/sentry/fs/sys/BUILD
index f2e8b9932..fdbc5f912 100644
--- a/pkg/sentry/fs/sys/BUILD
+++ b/pkg/sentry/fs/sys/BUILD
@@ -14,11 +14,11 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/ramfs",
"//pkg/sentry/kernel",
- "//pkg/usermem",
],
)
diff --git a/pkg/sentry/fs/sys/sys.go b/pkg/sentry/fs/sys/sys.go
index 0891645e4..101779a7a 100644
--- a/pkg/sentry/fs/sys/sys.go
+++ b/pkg/sentry/fs/sys/sys.go
@@ -17,16 +17,16 @@ package sys
import (
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
- "gvisor.dev/gvisor/pkg/usermem"
)
func newFile(ctx context.Context, node fs.InodeOperations, msrc *fs.MountSource) *fs.Inode {
sattr := fs.StableAttr{
DeviceID: sysfsDevice.DeviceID(),
InodeID: sysfsDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: fs.SpecialFile,
}
return fs.NewInode(ctx, node, msrc, sattr)
@@ -37,7 +37,7 @@ func newDir(ctx context.Context, msrc *fs.MountSource, contents map[string]*fs.I
return fs.NewInode(ctx, d, msrc, fs.StableAttr{
DeviceID: sysfsDevice.DeviceID(),
InodeID: sysfsDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: fs.SpecialDirectory,
})
}
diff --git a/pkg/sentry/fs/timerfd/BUILD b/pkg/sentry/fs/timerfd/BUILD
index d16cdb4df..c7977a217 100644
--- a/pkg/sentry/fs/timerfd/BUILD
+++ b/pkg/sentry/fs/timerfd/BUILD
@@ -8,6 +8,7 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/sentry/fs",
"//pkg/sentry/fs/anon",
"//pkg/sentry/fs/fsutil",
diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go
index 46511a6ac..c8ebe256c 100644
--- a/pkg/sentry/fs/timerfd/timerfd.go
+++ b/pkg/sentry/fs/timerfd/timerfd.go
@@ -20,6 +20,7 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -124,7 +125,7 @@ func (t *TimerOperations) Read(ctx context.Context, file *fs.File, dst usermem.I
}
if val := atomic.SwapUint64(&t.val, 0); val != 0 {
var buf [sizeofUint64]byte
- usermem.ByteOrder.PutUint64(buf[:], val)
+ hostarch.ByteOrder.PutUint64(buf[:], val)
if _, err := dst.CopyOut(ctx, buf[:]); err != nil {
// Linux does not undo consuming the number of expirations even if
// writing to userspace fails.
diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD
index b521a86a2..90398376a 100644
--- a/pkg/sentry/fs/tmpfs/BUILD
+++ b/pkg/sentry/fs/tmpfs/BUILD
@@ -15,6 +15,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/safemem",
"//pkg/sentry/device",
"//pkg/sentry/fs",
@@ -42,6 +43,7 @@ go_test(
library = ":tmpfs",
deps = [
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/sentry/fs",
"//pkg/sentry/kernel/contexttest",
"//pkg/sentry/usage",
diff --git a/pkg/sentry/fs/tmpfs/file_test.go b/pkg/sentry/fs/tmpfs/file_test.go
index d4d613ea9..1718f9372 100644
--- a/pkg/sentry/fs/tmpfs/file_test.go
+++ b/pkg/sentry/fs/tmpfs/file_test.go
@@ -19,6 +19,7 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
"gvisor.dev/gvisor/pkg/sentry/usage"
@@ -31,7 +32,7 @@ func newFileInode(ctx context.Context) *fs.Inode {
return fs.NewInode(ctx, iops, m, fs.StableAttr{
DeviceID: tmpfsDevice.DeviceID(),
InodeID: tmpfsDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: fs.RegularFile,
})
}
diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go
index ad4aea282..f4de8c968 100644
--- a/pkg/sentry/fs/tmpfs/inode_file.go
+++ b/pkg/sentry/fs/tmpfs/inode_file.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -125,7 +126,7 @@ func NewMemfdInode(ctx context.Context, allowSeals bool) *fs.Inode {
Type: fs.RegularFile,
DeviceID: tmpfsDevice.DeviceID(),
InodeID: tmpfsDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
})
}
@@ -392,7 +393,7 @@ func (rw *fileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
switch {
case seg.Ok():
// Get internal mappings.
- ims, err := mf.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
+ ims, err := mf.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), hostarch.Read)
if err != nil {
return done, err
}
@@ -463,7 +464,7 @@ func (rw *fileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error)
//
// See Linux, mm/filemap.c:generic_perform_write() and
// mm/shmem.c:shmem_write_begin().
- if pgstart := int64(usermem.Addr(rw.f.attr.Size).RoundDown()); end > pgstart {
+ if pgstart := int64(hostarch.Addr(rw.f.attr.Size).RoundDown()); end > pgstart {
end = pgstart
}
if end <= rw.offset {
@@ -483,8 +484,8 @@ func (rw *fileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error)
mf := rw.f.kernel.MemoryFile()
// Page-aligned mr for when we need to allocate memory. RoundUp can't
// overflow since end is an int64.
- pgstartaddr := usermem.Addr(rw.offset).RoundDown()
- pgendaddr, _ := usermem.Addr(end).RoundUp()
+ pgstartaddr := hostarch.Addr(rw.offset).RoundDown()
+ pgendaddr, _ := hostarch.Addr(end).RoundUp()
pgMR := memmap.MappableRange{uint64(pgstartaddr), uint64(pgendaddr)}
var done uint64
@@ -494,7 +495,7 @@ func (rw *fileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error)
switch {
case seg.Ok():
// Get internal mappings.
- ims, err := mf.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Write)
+ ims, err := mf.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), hostarch.Write)
if err != nil {
return done, err
}
@@ -527,7 +528,7 @@ func (rw *fileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error)
}
// AddMapping implements memmap.Mappable.AddMapping.
-func (f *fileInodeOperations) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+func (f *fileInodeOperations) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error {
f.mapsMu.Lock()
defer f.mapsMu.Unlock()
@@ -544,7 +545,7 @@ func (f *fileInodeOperations) AddMapping(ctx context.Context, ms memmap.MappingS
pagesBefore := f.writableMappingPages
// ar is guaranteed to be page aligned per memmap.Mappable.
- f.writableMappingPages += uint64(ar.Length() / usermem.PageSize)
+ f.writableMappingPages += uint64(ar.Length() / hostarch.PageSize)
if f.writableMappingPages < pagesBefore {
panic(fmt.Sprintf("Overflow while mapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, f.writableMappingPages))
@@ -555,7 +556,7 @@ func (f *fileInodeOperations) AddMapping(ctx context.Context, ms memmap.MappingS
}
// RemoveMapping implements memmap.Mappable.RemoveMapping.
-func (f *fileInodeOperations) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+func (f *fileInodeOperations) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) {
f.mapsMu.Lock()
defer f.mapsMu.Unlock()
@@ -565,7 +566,7 @@ func (f *fileInodeOperations) RemoveMapping(ctx context.Context, ms memmap.Mappi
pagesBefore := f.writableMappingPages
// ar is guaranteed to be page aligned per memmap.Mappable.
- f.writableMappingPages -= uint64(ar.Length() / usermem.PageSize)
+ f.writableMappingPages -= uint64(ar.Length() / hostarch.PageSize)
if f.writableMappingPages > pagesBefore {
panic(fmt.Sprintf("Underflow while unmapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, f.writableMappingPages))
@@ -574,12 +575,12 @@ func (f *fileInodeOperations) RemoveMapping(ctx context.Context, ms memmap.Mappi
}
// CopyMapping implements memmap.Mappable.CopyMapping.
-func (f *fileInodeOperations) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+func (f *fileInodeOperations) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error {
return f.AddMapping(ctx, ms, dstAR, offset, writable)
}
// Translate implements memmap.Mappable.Translate.
-func (f *fileInodeOperations) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+func (f *fileInodeOperations) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) {
f.dataMu.Lock()
defer f.dataMu.Unlock()
@@ -612,7 +613,7 @@ func (f *fileInodeOperations) Translate(ctx context.Context, required, optional
Source: segMR,
File: mf,
Offset: seg.FileRangeOf(segMR).Start,
- Perms: usermem.AnyAccess,
+ Perms: hostarch.AnyAccess,
})
translatedEnd = segMR.End
}
diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go
index cf4ed5de0..577052888 100644
--- a/pkg/sentry/fs/tmpfs/tmpfs.go
+++ b/pkg/sentry/fs/tmpfs/tmpfs.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
@@ -28,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
var fsInfo = fs.Info{
@@ -41,8 +41,8 @@ var fsInfo = fs.Info{
// chosen to ensure that BlockSize * Blocks does not overflow int64 (which
// applications may also handle incorrectly).
// TODO(b/29637826): allow configuring a tmpfs size and enforce it.
- TotalBlocks: math.MaxInt64 / usermem.PageSize,
- FreeBlocks: math.MaxInt64 / usermem.PageSize,
+ TotalBlocks: math.MaxInt64 / hostarch.PageSize,
+ FreeBlocks: math.MaxInt64 / hostarch.PageSize,
}
// rename implements fs.InodeOperations.Rename for tmpfs nodes.
@@ -99,7 +99,7 @@ func NewDir(ctx context.Context, contents map[string]*fs.Inode, owner fs.FileOwn
return fs.NewInode(ctx, d, msrc, fs.StableAttr{
DeviceID: tmpfsDevice.DeviceID(),
InodeID: tmpfsDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: fs.Directory,
})
}
@@ -232,7 +232,7 @@ func (d *Dir) newCreateOps() *ramfs.CreateOps {
return fs.NewInode(ctx, iops, dir.MountSource, fs.StableAttr{
DeviceID: tmpfsDevice.DeviceID(),
InodeID: tmpfsDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: fs.RegularFile,
}), nil
},
@@ -281,7 +281,7 @@ func NewSymlink(ctx context.Context, target string, owner fs.FileOwner, msrc *fs
return fs.NewInode(ctx, s, msrc, fs.StableAttr{
DeviceID: tmpfsDevice.DeviceID(),
InodeID: tmpfsDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: fs.Symlink,
})
}
@@ -311,7 +311,7 @@ func NewSocket(ctx context.Context, socket transport.BoundEndpoint, owner fs.Fil
return fs.NewInode(ctx, s, msrc, fs.StableAttr{
DeviceID: tmpfsDevice.DeviceID(),
InodeID: tmpfsDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: fs.Socket,
})
}
@@ -348,7 +348,7 @@ func NewFifo(ctx context.Context, owner fs.FileOwner, perms fs.FilePermissions,
return fs.NewInode(ctx, fifoIops, msrc, fs.StableAttr{
DeviceID: tmpfsDevice.DeviceID(),
InodeID: tmpfsDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: fs.Pipe,
})
}
diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD
index e6d0eb359..86ada820e 100644
--- a/pkg/sentry/fs/tty/BUILD
+++ b/pkg/sentry/fs/tty/BUILD
@@ -17,6 +17,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/marshal/primitive",
"//pkg/refs",
"//pkg/safemem",
diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go
index c2da80bc2..13c9dbe7d 100644
--- a/pkg/sentry/fs/tty/dir.go
+++ b/pkg/sentry/fs/tty/dir.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -122,7 +123,7 @@ func newDir(ctx context.Context, m *fs.MountSource) *fs.Inode {
// TODO(b/75267214): Since ptsDevice must be shared between
// different mounts, we must not assign fixed numbers.
InodeID: ptsDevice.NextIno(),
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
Type: fs.Directory,
})
}
diff --git a/pkg/sentry/fsimpl/cgroupfs/BUILD b/pkg/sentry/fsimpl/cgroupfs/BUILD
new file mode 100644
index 000000000..37efb641a
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/BUILD
@@ -0,0 +1,48 @@
+load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+licenses(["notice"])
+
+go_template_instance(
+ name = "dir_refs",
+ out = "dir_refs.go",
+ package = "cgroupfs",
+ prefix = "dir",
+ template = "//pkg/refsvfs2:refs_template",
+ types = {
+ "T": "dir",
+ },
+)
+
+go_library(
+ name = "cgroupfs",
+ srcs = [
+ "base.go",
+ "cgroupfs.go",
+ "cpu.go",
+ "cpuacct.go",
+ "cpuset.go",
+ "dir_refs.go",
+ "job.go",
+ "memory.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/coverage",
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/refsvfs2",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/usage",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/cgroupfs/base.go b/pkg/sentry/fsimpl/cgroupfs/base.go
new file mode 100644
index 000000000..0f54888d8
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/base.go
@@ -0,0 +1,261 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+ "strconv"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// controllerCommon implements kernel.CgroupController.
+//
+// Must call init before use.
+//
+// +stateify savable
+type controllerCommon struct {
+ ty kernel.CgroupControllerType
+ fs *filesystem
+}
+
+func (c *controllerCommon) init(ty kernel.CgroupControllerType, fs *filesystem) {
+ c.ty = ty
+ c.fs = fs
+}
+
+// Type implements kernel.CgroupController.Type.
+func (c *controllerCommon) Type() kernel.CgroupControllerType {
+ return kernel.CgroupControllerType(c.ty)
+}
+
+// HierarchyID implements kernel.CgroupController.HierarchyID.
+func (c *controllerCommon) HierarchyID() uint32 {
+ return c.fs.hierarchyID
+}
+
+// NumCgroups implements kernel.CgroupController.NumCgroups.
+func (c *controllerCommon) NumCgroups() uint64 {
+ return atomic.LoadUint64(&c.fs.numCgroups)
+}
+
+// Enabled implements kernel.CgroupController.Enabled.
+//
+// Controllers are currently always enabled.
+func (c *controllerCommon) Enabled() bool {
+ return true
+}
+
+// Filesystem implements kernel.CgroupController.Filesystem.
+func (c *controllerCommon) Filesystem() *vfs.Filesystem {
+ return c.fs.VFSFilesystem()
+}
+
+// RootCgroup implements kernel.CgroupController.RootCgroup.
+func (c *controllerCommon) RootCgroup() kernel.Cgroup {
+ return c.fs.rootCgroup()
+}
+
+// controller is an interface for common functionality related to all cgroups.
+// It is an extension of the public cgroup interface, containing cgroup
+// functionality private to cgroupfs.
+type controller interface {
+ kernel.CgroupController
+
+ // AddControlFiles should extend the contents map with inodes representing
+ // control files defined by this controller.
+ AddControlFiles(ctx context.Context, creds *auth.Credentials, c *cgroupInode, contents map[string]kernfs.Inode)
+}
+
+// cgroupInode implements kernel.CgroupImpl and kernfs.Inode.
+//
+// +stateify savable
+type cgroupInode struct {
+ dir
+ fs *filesystem
+
+ // ts is the list of tasks in this cgroup. The kernel is responsible for
+ // removing tasks from this list before they're destroyed, so any tasks on
+ // this list are always valid.
+ //
+ // ts, and cgroup membership in general is protected by fs.tasksMu.
+ ts map[*kernel.Task]struct{}
+}
+
+var _ kernel.CgroupImpl = (*cgroupInode)(nil)
+
+func (fs *filesystem) newCgroupInode(ctx context.Context, creds *auth.Credentials) kernfs.Inode {
+ c := &cgroupInode{
+ fs: fs,
+ ts: make(map[*kernel.Task]struct{}),
+ }
+
+ contents := make(map[string]kernfs.Inode)
+ contents["cgroup.procs"] = fs.newControllerFile(ctx, creds, &cgroupProcsData{c})
+ contents["tasks"] = fs.newControllerFile(ctx, creds, &tasksData{c})
+
+ for _, ctl := range fs.controllers {
+ ctl.AddControlFiles(ctx, creds, c, contents)
+ }
+
+ c.dir.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|linux.FileMode(0555))
+ c.dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ c.dir.InitRefs()
+ c.dir.IncLinks(c.dir.OrderedChildren.Populate(contents))
+
+ atomic.AddUint64(&fs.numCgroups, 1)
+
+ return c
+}
+
+func (c *cgroupInode) HierarchyID() uint32 {
+ return c.fs.hierarchyID
+}
+
+// Controllers implements kernel.CgroupImpl.Controllers.
+func (c *cgroupInode) Controllers() []kernel.CgroupController {
+ return c.fs.kcontrollers
+}
+
+// Enter implements kernel.CgroupImpl.Enter.
+func (c *cgroupInode) Enter(t *kernel.Task) {
+ c.fs.tasksMu.Lock()
+ c.ts[t] = struct{}{}
+ c.fs.tasksMu.Unlock()
+}
+
+// Leave implements kernel.CgroupImpl.Leave.
+func (c *cgroupInode) Leave(t *kernel.Task) {
+ c.fs.tasksMu.Lock()
+ delete(c.ts, t)
+ c.fs.tasksMu.Unlock()
+}
+
+func sortTIDs(tids []kernel.ThreadID) {
+ sort.Slice(tids, func(i, j int) bool { return tids[i] < tids[j] })
+}
+
+// +stateify savable
+type cgroupProcsData struct {
+ *cgroupInode
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cgroupProcsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ t := kernel.TaskFromContext(ctx)
+ currPidns := t.ThreadGroup().PIDNamespace()
+
+ pgids := make(map[kernel.ThreadID]struct{})
+
+ d.fs.tasksMu.RLock()
+ defer d.fs.tasksMu.RUnlock()
+
+ for task := range d.ts {
+ // Map dedups pgid, since iterating over all tasks produces multiple
+ // entries for the group leaders.
+ if pgid := currPidns.IDOfThreadGroup(task.ThreadGroup()); pgid != 0 {
+ pgids[pgid] = struct{}{}
+ }
+ }
+
+ pgidList := make([]kernel.ThreadID, 0, len(pgids))
+ for pgid, _ := range pgids {
+ pgidList = append(pgidList, pgid)
+ }
+ sortTIDs(pgidList)
+
+ for _, pgid := range pgidList {
+ fmt.Fprintf(buf, "%d\n", pgid)
+ }
+
+ return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (d *cgroupProcsData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ // TODO(b/183137098): Payload is the pid for a process to add to this cgroup.
+ return src.NumBytes(), nil
+}
+
+// +stateify savable
+type tasksData struct {
+ *cgroupInode
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *tasksData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ t := kernel.TaskFromContext(ctx)
+ currPidns := t.ThreadGroup().PIDNamespace()
+
+ var pids []kernel.ThreadID
+
+ d.fs.tasksMu.RLock()
+ defer d.fs.tasksMu.RUnlock()
+
+ for task := range d.ts {
+ if pid := currPidns.IDOfTask(task); pid != 0 {
+ pids = append(pids, pid)
+ }
+ }
+ sortTIDs(pids)
+
+ for _, pid := range pids {
+ fmt.Fprintf(buf, "%d\n", pid)
+ }
+
+ return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (d *tasksData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ // TODO(b/183137098): Payload is the pid for a process to add to this cgroup.
+ return src.NumBytes(), nil
+}
+
+// parseInt64FromString interprets src as string encoding a int64 value, and
+// returns the parsed value.
+func parseInt64FromString(ctx context.Context, src usermem.IOSequence, offset int64) (val, len int64, err error) {
+ const maxInt64StrLen = 20 // i.e. len(fmt.Sprintf("%d", math.MinInt64)) == 20
+
+ t := kernel.TaskFromContext(ctx)
+ src = src.DropFirst64(offset)
+
+ buf := t.CopyScratchBuffer(maxInt64StrLen)
+ n, err := src.CopyIn(ctx, buf)
+ if err != nil {
+ return 0, int64(n), err
+ }
+ buf = buf[:n]
+
+ val, err = strconv.ParseInt(string(buf), 10, 64)
+ if err != nil {
+ // Note: This also handles zero-len writes if offset is beyond the end
+ // of src, or src is empty.
+ ctx.Warningf("cgroupfs.parseInt64FromString: failed to parse %q: %v", string(buf), err)
+ return 0, int64(n), syserror.EINVAL
+ }
+
+ return val, int64(n), nil
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
new file mode 100644
index 000000000..bd3e69757
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
@@ -0,0 +1,425 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package cgroupfs implements cgroupfs.
+//
+// A cgroup is a collection of tasks on the system, organized into a tree-like
+// structure similar to a filesystem directory tree. In fact, each cgroup is
+// represented by a directory on cgroupfs, and is manipulated through control
+// files in the directory.
+//
+// All cgroups on a system are organized into hierarchies. Hierarchies are a
+// distinct tree of cgroups, with a common set of controllers. One or more
+// cgroupfs mounts may point to each hierarchy. These mounts provide a common
+// view into the same tree of cgroups.
+//
+// A controller (also known as a "resource controller", or a cgroup "subsystem")
+// determines the behaviour of each cgroup.
+//
+// In addition to cgroupfs, the kernel has a cgroup registry that tracks
+// system-wide state related to cgroups such as active hierarchies and the
+// controllers associated with them.
+//
+// Since cgroupfs doesn't allow hardlinks, there is a unique mapping between
+// cgroupfs dentries and inodes.
+//
+// # Synchronization
+//
+// Cgroup hierarchy creation and destruction is protected by the
+// kernel.CgroupRegistry.mu. Once created, a hierarchy's set of controllers, the
+// filesystem associated with it, and the root cgroup for the hierarchy are
+// immutable.
+//
+// Membership of tasks within cgroups is protected by
+// cgroupfs.filesystem.tasksMu. Tasks also maintain a set of all cgroups they're
+// in, and this list is protected by Task.mu.
+//
+// Lock order:
+//
+// kernel.CgroupRegistry.mu
+// cgroupfs.filesystem.mu
+// Task.mu
+// cgroupfs.filesystem.tasksMu.
+package cgroupfs
+
+import (
+ "fmt"
+ "sort"
+ "strconv"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const (
+ // Name is the default filesystem name.
+ Name = "cgroup"
+ readonlyFileMode = linux.FileMode(0444)
+ writableFileMode = linux.FileMode(0644)
+ defaultMaxCachedDentries = uint64(1000)
+)
+
+const (
+ controllerCPU = kernel.CgroupControllerType("cpu")
+ controllerCPUAcct = kernel.CgroupControllerType("cpuacct")
+ controllerCPUSet = kernel.CgroupControllerType("cpuset")
+ controllerJob = kernel.CgroupControllerType("job")
+ controllerMemory = kernel.CgroupControllerType("memory")
+)
+
+var allControllers = []kernel.CgroupControllerType{
+ controllerCPU,
+ controllerCPUAcct,
+ controllerCPUSet,
+ controllerJob,
+ controllerMemory,
+}
+
+// SupportedMountOptions is the set of supported mount options for cgroupfs.
+var SupportedMountOptions = []string{"all", "cpu", "cpuacct", "cpuset", "job", "memory"}
+
+// FilesystemType implements vfs.FilesystemType.
+//
+// +stateify savable
+type FilesystemType struct{}
+
+// InternalData contains internal data passed in to the cgroupfs mount via
+// vfs.GetFilesystemOptions.InternalData.
+//
+// +stateify savable
+type InternalData struct {
+ DefaultControlValues map[string]int64
+}
+
+// filesystem implements vfs.FilesystemImpl.
+//
+// +stateify savable
+type filesystem struct {
+ kernfs.Filesystem
+ devMinor uint32
+
+ // hierarchyID is the id the cgroup registry assigns to this hierarchy. Has
+ // the value kernel.InvalidCgroupHierarchyID until the FS is fully
+ // initialized.
+ //
+ // hierarchyID is immutable after initialization.
+ hierarchyID uint32
+
+ // controllers and kcontrollers are both the list of controllers attached to
+ // this cgroupfs. Both lists are the same set of controllers, but typecast
+ // to different interfaces for convenience. Both must stay in sync, and are
+ // immutable.
+ controllers []controller
+ kcontrollers []kernel.CgroupController
+
+ numCgroups uint64 // Protected by atomic ops.
+
+ root *kernfs.Dentry
+
+ // tasksMu serializes task membership changes across all cgroups within a
+ // filesystem.
+ tasksMu sync.RWMutex `state:"nosave"`
+}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// Release implements vfs.FilesystemType.Release.
+func (FilesystemType) Release(ctx context.Context) {}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ maxCachedDentries := defaultMaxCachedDentries
+ if str, ok := mopts["dentry_cache_limit"]; ok {
+ delete(mopts, "dentry_cache_limit")
+ maxCachedDentries, err = strconv.ParseUint(str, 10, 64)
+ if err != nil {
+ ctx.Warningf("sys.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str)
+ return nil, nil, syserror.EINVAL
+ }
+ }
+
+ var wantControllers []kernel.CgroupControllerType
+ if _, ok := mopts["cpu"]; ok {
+ delete(mopts, "cpu")
+ wantControllers = append(wantControllers, controllerCPU)
+ }
+ if _, ok := mopts["cpuacct"]; ok {
+ delete(mopts, "cpuacct")
+ wantControllers = append(wantControllers, controllerCPUAcct)
+ }
+ if _, ok := mopts["cpuset"]; ok {
+ delete(mopts, "cpuset")
+ wantControllers = append(wantControllers, controllerCPUSet)
+ }
+ if _, ok := mopts["job"]; ok {
+ delete(mopts, "job")
+ wantControllers = append(wantControllers, controllerJob)
+ }
+ if _, ok := mopts["memory"]; ok {
+ delete(mopts, "memory")
+ wantControllers = append(wantControllers, controllerMemory)
+ }
+ if _, ok := mopts["all"]; ok {
+ if len(wantControllers) > 0 {
+ ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: other controllers specified with all: %v", wantControllers)
+ return nil, nil, syserror.EINVAL
+ }
+
+ delete(mopts, "all")
+ wantControllers = allControllers
+ }
+
+ if len(wantControllers) == 0 {
+ // Specifying no controllers implies all controllers.
+ wantControllers = allControllers
+ }
+
+ if len(mopts) != 0 {
+ ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: unknown options: %v", mopts)
+ return nil, nil, syserror.EINVAL
+ }
+
+ k := kernel.KernelFromContext(ctx)
+ r := k.CgroupRegistry()
+
+ // "It is not possible to mount the same controller against multiple
+ // cgroup hierarchies. For example, it is not possible to mount both
+ // the cpu and cpuacct controllers against one hierarchy, and to mount
+ // the cpu controller alone against another hierarchy." - man cgroups(7)
+ //
+ // Is there a hierarchy available with all the controllers we want? If so,
+ // this mount is a view into the same hierarchy.
+ //
+ // Note: we're guaranteed to have at least one requested controller, since
+ // no explicit controller name implies all controllers.
+ if vfsfs := r.FindHierarchy(wantControllers); vfsfs != nil {
+ fs := vfsfs.Impl().(*filesystem)
+ ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: mounting new view to hierarchy %v", fs.hierarchyID)
+ fs.root.IncRef()
+ return vfsfs, fs.root.VFSDentry(), nil
+ }
+
+ // No existing hierarchy with the exactly controllers found. Make a new
+ // one. Note that it's possible this mount creation is unsatisfiable, if one
+ // or more of the requested controllers are already on existing
+ // hierarchies. We'll find out about such collisions when we try to register
+ // the new hierarchy later.
+ fs := &filesystem{
+ devMinor: devMinor,
+ }
+ fs.MaxCachedDentries = maxCachedDentries
+ fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
+
+ var defaults map[string]int64
+ if opts.InternalData != nil {
+ ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: default control values: %v", defaults)
+ defaults = opts.InternalData.(*InternalData).DefaultControlValues
+ }
+
+ for _, ty := range wantControllers {
+ var c controller
+ switch ty {
+ case controllerCPU:
+ c = newCPUController(fs, defaults)
+ case controllerCPUAcct:
+ c = newCPUAcctController(fs)
+ case controllerCPUSet:
+ c = newCPUSetController(fs)
+ case controllerJob:
+ c = newJobController(fs)
+ case controllerMemory:
+ c = newMemoryController(fs, defaults)
+ default:
+ panic(fmt.Sprintf("Unreachable: unknown cgroup controller %q", ty))
+ }
+ fs.controllers = append(fs.controllers, c)
+ }
+
+ if len(defaults) != 0 {
+ // Internal data is always provided at sentry startup and unused values
+ // indicate a problem with the sandbox config. Fail fast.
+ panic(fmt.Sprintf("cgroupfs.FilesystemType.GetFilesystem: unknown internal mount data: %v", defaults))
+ }
+
+ // Controllers usually appear in alphabetical order when displayed. Sort it
+ // here now, so it never needs to be sorted elsewhere.
+ sort.Slice(fs.controllers, func(i, j int) bool { return fs.controllers[i].Type() < fs.controllers[j].Type() })
+ fs.kcontrollers = make([]kernel.CgroupController, 0, len(fs.controllers))
+ for _, c := range fs.controllers {
+ fs.kcontrollers = append(fs.kcontrollers, c)
+ }
+
+ root := fs.newCgroupInode(ctx, creds)
+ var rootD kernfs.Dentry
+ rootD.InitRoot(&fs.Filesystem, root)
+ fs.root = &rootD
+
+ // Register controllers. The registry may be modified concurrently, so if we
+ // get an error, we raced with someone else who registered the same
+ // controllers first.
+ hid, err := r.Register(fs.kcontrollers)
+ if err != nil {
+ ctx.Infof("cgroupfs.FilesystemType.GetFilesystem: failed to register new hierarchy with controllers %v: %v", wantControllers, err)
+ rootD.DecRef(ctx)
+ fs.VFSFilesystem().DecRef(ctx)
+ return nil, nil, syserror.EBUSY
+ }
+ fs.hierarchyID = hid
+
+ // Move all existing tasks to the root of the new hierarchy.
+ k.PopulateNewCgroupHierarchy(fs.rootCgroup())
+
+ return fs.VFSFilesystem(), rootD.VFSDentry(), nil
+}
+
+func (fs *filesystem) rootCgroup() kernel.Cgroup {
+ return kernel.Cgroup{
+ Dentry: fs.root,
+ CgroupImpl: fs.root.Inode().(kernel.CgroupImpl),
+ }
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release(ctx context.Context) {
+ k := kernel.KernelFromContext(ctx)
+ r := k.CgroupRegistry()
+
+ if fs.hierarchyID != kernel.InvalidCgroupHierarchyID {
+ k.ReleaseCgroupHierarchy(fs.hierarchyID)
+ r.Unregister(fs.hierarchyID)
+ }
+
+ fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.Filesystem.Release(ctx)
+}
+
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ var cnames []string
+ for _, c := range fs.controllers {
+ cnames = append(cnames, string(c.Type()))
+ }
+ return strings.Join(cnames, ",")
+}
+
+// +stateify savable
+type implStatFS struct{}
+
+// StatFS implements kernfs.Inode.StatFS.
+func (*implStatFS) StatFS(context.Context, *vfs.Filesystem) (linux.Statfs, error) {
+ return vfs.GenericStatFS(linux.CGROUP_SUPER_MAGIC), nil
+}
+
+// dir implements kernfs.Inode for a generic cgroup resource controller
+// directory. Specific controllers extend this to add their own functionality.
+//
+// +stateify savable
+type dir struct {
+ dirRefs
+ kernfs.InodeAlwaysValid
+ kernfs.InodeAttrs
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren // TODO(b/183137098): Implement mkdir.
+ kernfs.OrderedChildren
+ implStatFS
+
+ locks vfs.FileLocks
+}
+
+// Keep implements kernfs.Inode.Keep.
+func (*dir) Keep() bool {
+ return true
+}
+
+// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed.
+func (*dir) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+// Open implements kernfs.Inode.Open.
+func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), kd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{
+ SeekEnd: kernfs.SeekEndStaticEntries,
+ })
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+// DecRef implements kernfs.Inode.DecRef.
+func (d *dir) DecRef(ctx context.Context) {
+ d.dirRefs.DecRef(func() { d.Destroy(ctx) })
+}
+
+// StatFS implements kernfs.Inode.StatFS.
+func (d *dir) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error) {
+ return vfs.GenericStatFS(linux.CGROUP_SUPER_MAGIC), nil
+}
+
+// controllerFile represents a generic control file that appears within a cgroup
+// directory.
+//
+// +stateify savable
+type controllerFile struct {
+ kernfs.DynamicBytesFile
+}
+
+func (fs *filesystem) newControllerFile(ctx context.Context, creds *auth.Credentials, data vfs.DynamicBytesSource) kernfs.Inode {
+ f := &controllerFile{}
+ f.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), data, readonlyFileMode)
+ return f
+}
+
+func (fs *filesystem) newControllerWritableFile(ctx context.Context, creds *auth.Credentials, data vfs.WritableDynamicBytesSource) kernfs.Inode {
+ f := &controllerFile{}
+ f.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), data, writableFileMode)
+ return f
+}
+
+// staticControllerFile represents a generic control file that appears within a
+// cgroup directory which always returns the same data when read.
+// staticControllerFiles are not writable.
+//
+// +stateify savable
+type staticControllerFile struct {
+ kernfs.DynamicBytesFile
+ vfs.StaticData
+}
+
+// Note: We let the caller provide the mode so that static files may be used to
+// fake both readable and writable control files. However, static files are
+// effectively readonly, as attempting to write to them will return EIO
+// regardless of the mode.
+func (fs *filesystem) newStaticControllerFile(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, data string) kernfs.Inode {
+ f := &staticControllerFile{StaticData: vfs.StaticData{Data: data}}
+ f.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), f, mode)
+ return f
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/cpu.go b/pkg/sentry/fsimpl/cgroupfs/cpu.go
new file mode 100644
index 000000000..24d86a277
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/cpu.go
@@ -0,0 +1,70 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// +stateify savable
+type cpuController struct {
+ controllerCommon
+
+ // CFS bandwidth control parameters, values in microseconds.
+ cfsPeriod int64
+ cfsQuota int64
+
+ // CPU shares, values should be (num core * 1024).
+ shares int64
+}
+
+var _ controller = (*cpuController)(nil)
+
+func newCPUController(fs *filesystem, defaults map[string]int64) *cpuController {
+ // Default values for controller parameters from Linux.
+ c := &cpuController{
+ cfsPeriod: 100000,
+ cfsQuota: -1,
+ shares: 1024,
+ }
+
+ if val, ok := defaults["cpu.cfs_period_us"]; ok {
+ c.cfsPeriod = val
+ delete(defaults, "cpu.cfs_period_us")
+ }
+ if val, ok := defaults["cpu.cfs_quota_us"]; ok {
+ c.cfsQuota = val
+ delete(defaults, "cpu.cfs_quota_us")
+ }
+ if val, ok := defaults["cpu.shares"]; ok {
+ c.shares = val
+ delete(defaults, "cpu.shares")
+ }
+
+ c.controllerCommon.init(controllerCPU, fs)
+ return c
+}
+
+// AddControlFiles implements controller.AddControlFiles.
+func (c *cpuController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
+ contents["cpu.cfs_period_us"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.cfsPeriod))
+ contents["cpu.cfs_quota_us"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.cfsQuota))
+ contents["cpu.shares"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.shares))
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/cpuacct.go b/pkg/sentry/fsimpl/cgroupfs/cpuacct.go
new file mode 100644
index 000000000..d4104a00e
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/cpuacct.go
@@ -0,0 +1,114 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+)
+
+// +stateify savable
+type cpuacctController struct {
+ controllerCommon
+}
+
+var _ controller = (*cpuacctController)(nil)
+
+func newCPUAcctController(fs *filesystem) *cpuacctController {
+ c := &cpuacctController{}
+ c.controllerCommon.init(controllerCPUAcct, fs)
+ return c
+}
+
+// AddControlFiles implements controller.AddControlFiles.
+func (c *cpuacctController) AddControlFiles(ctx context.Context, creds *auth.Credentials, cg *cgroupInode, contents map[string]kernfs.Inode) {
+ cpuacctCG := &cpuacctCgroup{cg}
+ contents["cpuacct.stat"] = c.fs.newControllerFile(ctx, creds, &cpuacctStatData{cpuacctCG})
+ contents["cpuacct.usage"] = c.fs.newControllerFile(ctx, creds, &cpuacctUsageData{cpuacctCG})
+ contents["cpuacct.usage_user"] = c.fs.newControllerFile(ctx, creds, &cpuacctUsageUserData{cpuacctCG})
+ contents["cpuacct.usage_sys"] = c.fs.newControllerFile(ctx, creds, &cpuacctUsageSysData{cpuacctCG})
+}
+
+// +stateify savable
+type cpuacctCgroup struct {
+ *cgroupInode
+}
+
+func (c *cpuacctCgroup) collectCPUStats() usage.CPUStats {
+ var cs usage.CPUStats
+ c.fs.tasksMu.RLock()
+ // Note: This isn't very accurate, since the tasks are potentially
+ // still running as we accumulate their stats.
+ for t := range c.ts {
+ cs.Accumulate(t.CPUStats())
+ }
+ c.fs.tasksMu.RUnlock()
+ return cs
+}
+
+// +stateify savable
+type cpuacctStatData struct {
+ *cpuacctCgroup
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cpuacctStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ cs := d.collectCPUStats()
+ fmt.Fprintf(buf, "user %d\n", linux.ClockTFromDuration(cs.UserTime))
+ fmt.Fprintf(buf, "system %d\n", linux.ClockTFromDuration(cs.SysTime))
+ return nil
+}
+
+// +stateify savable
+type cpuacctUsageData struct {
+ *cpuacctCgroup
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cpuacctUsageData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ cs := d.collectCPUStats()
+ fmt.Fprintf(buf, "%d\n", cs.UserTime.Nanoseconds()+cs.SysTime.Nanoseconds())
+ return nil
+}
+
+// +stateify savable
+type cpuacctUsageUserData struct {
+ *cpuacctCgroup
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cpuacctUsageUserData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ cs := d.collectCPUStats()
+ fmt.Fprintf(buf, "%d\n", cs.UserTime.Nanoseconds())
+ return nil
+}
+
+// +stateify savable
+type cpuacctUsageSysData struct {
+ *cpuacctCgroup
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cpuacctUsageSysData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ cs := d.collectCPUStats()
+ fmt.Fprintf(buf, "%d\n", cs.SysTime.Nanoseconds())
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/cpuset.go b/pkg/sentry/fsimpl/cgroupfs/cpuset.go
new file mode 100644
index 000000000..ac547f8e2
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/cpuset.go
@@ -0,0 +1,39 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// +stateify savable
+type cpusetController struct {
+ controllerCommon
+}
+
+var _ controller = (*cpusetController)(nil)
+
+func newCPUSetController(fs *filesystem) *cpusetController {
+ c := &cpusetController{}
+ c.controllerCommon.init(controllerCPUSet, fs)
+ return c
+}
+
+// AddControlFiles implements controller.AddControlFiles.
+func (c *cpusetController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
+ // This controller is currently intentionally empty.
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/job.go b/pkg/sentry/fsimpl/cgroupfs/job.go
new file mode 100644
index 000000000..48919c338
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/job.go
@@ -0,0 +1,64 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// +stateify savable
+type jobController struct {
+ controllerCommon
+ id int64
+}
+
+var _ controller = (*jobController)(nil)
+
+func newJobController(fs *filesystem) *jobController {
+ c := &jobController{}
+ c.controllerCommon.init(controllerJob, fs)
+ return c
+}
+
+func (c *jobController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
+ contents["job.id"] = c.fs.newControllerWritableFile(ctx, creds, &jobIDData{c: c})
+}
+
+// +stateify savable
+type jobIDData struct {
+ c *jobController
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *jobIDData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "%d\n", d.c.id)
+ return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (d *jobIDData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ val, n, err := parseInt64FromString(ctx, src, offset)
+ if err != nil {
+ return n, err
+ }
+ d.c.id = val
+ return n, nil
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/memory.go b/pkg/sentry/fsimpl/cgroupfs/memory.go
new file mode 100644
index 000000000..485c98376
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/memory.go
@@ -0,0 +1,74 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+ "bytes"
+ "fmt"
+ "math"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+)
+
+// +stateify savable
+type memoryController struct {
+ controllerCommon
+
+ limitBytes int64
+}
+
+var _ controller = (*memoryController)(nil)
+
+func newMemoryController(fs *filesystem, defaults map[string]int64) *memoryController {
+ c := &memoryController{
+ // Linux sets this to (PAGE_COUNTER_MAX * PAGE_SIZE) by default, which
+ // is ~ 2**63 on a 64-bit system. So essentially, inifinity. The exact
+ // value isn't very important.
+ limitBytes: math.MaxInt64,
+ }
+ if val, ok := defaults["memory.limit_in_bytes"]; ok {
+ c.limitBytes = val
+ delete(defaults, "memory.limit_in_bytes")
+ }
+ c.controllerCommon.init(controllerMemory, fs)
+ return c
+}
+
+// AddControlFiles implements controller.AddControlFiles.
+func (c *memoryController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
+ contents["memory.usage_in_bytes"] = c.fs.newControllerFile(ctx, creds, &memoryUsageInBytesData{})
+ contents["memory.limit_in_bytes"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.limitBytes))
+}
+
+// +stateify savable
+type memoryUsageInBytesData struct{}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *memoryUsageInBytesData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ // TODO(b/183151557): This is a giant hack, we're using system-wide
+ // accounting since we know there is only one cgroup.
+ k := kernel.KernelFromContext(ctx)
+ mf := k.MemoryFile()
+ mf.UpdateUsage()
+ _, totalBytes := usage.MemoryAccounting.Copy()
+
+ fmt.Fprintf(buf, "%d\n", totalBytes)
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/eventfd/BUILD b/pkg/sentry/fsimpl/eventfd/BUILD
index bcb01bb08..c09fdc7f9 100644
--- a/pkg/sentry/fsimpl/eventfd/BUILD
+++ b/pkg/sentry/fsimpl/eventfd/BUILD
@@ -10,6 +10,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/fdnotifier",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/sentry/vfs",
"//pkg/syserror",
diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go
index 30bd05357..4f79cfcb7 100644
--- a/pkg/sentry/fsimpl/eventfd/eventfd.go
+++ b/pkg/sentry/fsimpl/eventfd/eventfd.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
@@ -188,7 +189,7 @@ func (efd *EventFileDescription) read(ctx context.Context, dst usermem.IOSequenc
efd.queue.Notify(waiter.WritableEvents)
var buf [8]byte
- usermem.ByteOrder.PutUint64(buf[:], val)
+ hostarch.ByteOrder.PutUint64(buf[:], val)
_, err := dst.CopyOut(ctx, buf[:])
return err
}
@@ -196,7 +197,7 @@ func (efd *EventFileDescription) read(ctx context.Context, dst usermem.IOSequenc
// Preconditions: Must be called with efd.mu locked.
func (efd *EventFileDescription) hostWriteLocked(val uint64) error {
var buf [8]byte
- usermem.ByteOrder.PutUint64(buf[:], val)
+ hostarch.ByteOrder.PutUint64(buf[:], val)
_, err := unix.Write(efd.hostfd, buf[:])
if err == unix.EWOULDBLOCK {
return syserror.ErrWouldBlock
@@ -209,7 +210,7 @@ func (efd *EventFileDescription) write(ctx context.Context, src usermem.IOSequen
if _, err := src.CopyIn(ctx, buf[:]); err != nil {
return err
}
- val := usermem.ByteOrder.Uint64(buf[:])
+ val := hostarch.ByteOrder.Uint64(buf[:])
return efd.Signal(val)
}
diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD
index 155c0f56d..3a4777fbe 100644
--- a/pkg/sentry/fsimpl/fuse/BUILD
+++ b/pkg/sentry/fsimpl/fuse/BUILD
@@ -46,6 +46,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/marshal",
"//pkg/refs",
@@ -75,6 +76,7 @@ go_test(
library = ":fuse",
deps = [
"//pkg/abi/linux",
+ "//pkg/hostarch",
"//pkg/marshal",
"//pkg/sentry/fsimpl/testutil",
"//pkg/sentry/kernel",
diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go
index 23ce91849..66ea889f9 100644
--- a/pkg/sentry/fsimpl/fuse/read_write.go
+++ b/pkg/sentry/fsimpl/fuse/read_write.go
@@ -20,11 +20,11 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// ReadInPages sends FUSE_READ requests for the size after round it up to
@@ -43,10 +43,10 @@ func (fs *filesystem) ReadInPages(ctx context.Context, fd *regularFileFD, off ui
}
// Round up to a multiple of page size.
- readSize, _ := usermem.PageRoundUp(uint64(size))
+ readSize, _ := hostarch.PageRoundUp(uint64(size))
// One request cannnot exceed either maxRead or maxPages.
- maxPages := fs.conn.maxRead >> usermem.PageShift
+ maxPages := fs.conn.maxRead >> hostarch.PageShift
if maxPages > uint32(fs.conn.maxPages) {
maxPages = uint32(fs.conn.maxPages)
}
@@ -54,9 +54,9 @@ func (fs *filesystem) ReadInPages(ctx context.Context, fd *regularFileFD, off ui
var outs [][]byte
var sizeRead uint32
- // readSize is a multiple of usermem.PageSize.
+ // readSize is a multiple of hostarch.PageSize.
// Always request bytes as a multiple of pages.
- pagesRead, pagesToRead := uint32(0), uint32(readSize>>usermem.PageShift)
+ pagesRead, pagesToRead := uint32(0), uint32(readSize>>hostarch.PageShift)
// Reuse the same struct for unmarshalling to avoid unnecessary memory allocation.
in := linux.FUSEReadIn{
@@ -76,8 +76,8 @@ func (fs *filesystem) ReadInPages(ctx context.Context, fd *regularFileFD, off ui
pagesCanRead = maxPages
}
- in.Offset = off + (uint64(pagesRead) << usermem.PageShift)
- in.Size = pagesCanRead << usermem.PageShift
+ in.Offset = off + (uint64(pagesRead) << hostarch.PageShift)
+ in.Size = pagesCanRead << hostarch.PageShift
// TODO(gvisor.dev/issue/3247): support async read.
@@ -159,7 +159,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64,
}
// One request cannnot exceed either maxWrite or maxPages.
- maxWrite := uint32(fs.conn.maxPages) << usermem.PageShift
+ maxWrite := uint32(fs.conn.maxPages) << hostarch.PageShift
if maxWrite > fs.conn.maxWrite {
maxWrite = fs.conn.maxWrite
}
@@ -188,8 +188,8 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64,
// Limit the write size to one page.
// Note that the bigWrites flag is obsolete,
// latest libfuse always sets it on.
- if !fs.conn.bigWrites && toWrite > usermem.PageSize {
- toWrite = usermem.PageSize
+ if !fs.conn.bigWrites && toWrite > hostarch.PageSize {
+ toWrite = hostarch.PageSize
}
// Limit the write size to maxWrite.
diff --git a/pkg/sentry/fsimpl/fuse/request_response.go b/pkg/sentry/fsimpl/fuse/request_response.go
index 10fb9d7d2..8a72489fa 100644
--- a/pkg/sentry/fsimpl/fuse/request_response.go
+++ b/pkg/sentry/fsimpl/fuse/request_response.go
@@ -19,10 +19,10 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/usermem"
)
// fuseInitRes is a variable-length wrapper of linux.FUSEInitOut. The FUSE
@@ -45,29 +45,29 @@ func (r *fuseInitRes) UnmarshalBytes(src []byte) {
out := &r.initOut
// Introduced before FUSE kernel version 7.13.
- out.Major = uint32(usermem.ByteOrder.Uint32(src[:4]))
+ out.Major = uint32(hostarch.ByteOrder.Uint32(src[:4]))
src = src[4:]
- out.Minor = uint32(usermem.ByteOrder.Uint32(src[:4]))
+ out.Minor = uint32(hostarch.ByteOrder.Uint32(src[:4]))
src = src[4:]
- out.MaxReadahead = uint32(usermem.ByteOrder.Uint32(src[:4]))
+ out.MaxReadahead = uint32(hostarch.ByteOrder.Uint32(src[:4]))
src = src[4:]
- out.Flags = uint32(usermem.ByteOrder.Uint32(src[:4]))
+ out.Flags = uint32(hostarch.ByteOrder.Uint32(src[:4]))
src = src[4:]
- out.MaxBackground = uint16(usermem.ByteOrder.Uint16(src[:2]))
+ out.MaxBackground = uint16(hostarch.ByteOrder.Uint16(src[:2]))
src = src[2:]
- out.CongestionThreshold = uint16(usermem.ByteOrder.Uint16(src[:2]))
+ out.CongestionThreshold = uint16(hostarch.ByteOrder.Uint16(src[:2]))
src = src[2:]
- out.MaxWrite = uint32(usermem.ByteOrder.Uint32(src[:4]))
+ out.MaxWrite = uint32(hostarch.ByteOrder.Uint32(src[:4]))
src = src[4:]
// Introduced in FUSE kernel version 7.23.
if len(src) >= 4 {
- out.TimeGran = uint32(usermem.ByteOrder.Uint32(src[:4]))
+ out.TimeGran = uint32(hostarch.ByteOrder.Uint32(src[:4]))
src = src[4:]
}
// Introduced in FUSE kernel version 7.28.
if len(src) >= 2 {
- out.MaxPages = uint16(usermem.ByteOrder.Uint16(src[:2]))
+ out.MaxPages = uint16(hostarch.ByteOrder.Uint16(src[:2]))
src = src[2:]
}
_ = src // Remove unused warning.
diff --git a/pkg/sentry/fsimpl/fuse/utils_test.go b/pkg/sentry/fsimpl/fuse/utils_test.go
index 2c0cc0f4e..b0bab0066 100644
--- a/pkg/sentry/fsimpl/fuse/utils_test.go
+++ b/pkg/sentry/fsimpl/fuse/utils_test.go
@@ -24,7 +24,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
func setup(t *testing.T) *testutil.System {
@@ -82,12 +83,12 @@ func (t *testPayload) SizeBytes() int {
// MarshalBytes implements marshal.Marshallable.MarshalBytes.
func (t *testPayload) MarshalBytes(dst []byte) {
- usermem.ByteOrder.PutUint32(dst[:4], t.data)
+ hostarch.ByteOrder.PutUint32(dst[:4], t.data)
}
// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
func (t *testPayload) UnmarshalBytes(src []byte) {
- *t = testPayload{data: usermem.ByteOrder.Uint32(src[:4])}
+ *t = testPayload{data: hostarch.ByteOrder.Uint32(src[:4])}
}
// Packed implements marshal.Marshallable.Packed.
@@ -106,17 +107,17 @@ func (t *testPayload) UnmarshalUnsafe(src []byte) {
}
// CopyOutN implements marshal.Marshallable.CopyOutN.
-func (t *testPayload) CopyOutN(task marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {
+func (t *testPayload) CopyOutN(task marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) {
panic("not implemented")
}
// CopyOut implements marshal.Marshallable.CopyOut.
-func (t *testPayload) CopyOut(task marshal.CopyContext, addr usermem.Addr) (int, error) {
+func (t *testPayload) CopyOut(task marshal.CopyContext, addr hostarch.Addr) (int, error) {
panic("not implemented")
}
// CopyIn implements marshal.Marshallable.CopyIn.
-func (t *testPayload) CopyIn(task marshal.CopyContext, addr usermem.Addr) (int, error) {
+func (t *testPayload) CopyIn(task marshal.CopyContext, addr hostarch.Addr) (int, error) {
panic("not implemented")
}
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index 807b6ed1f..6d5258a9b 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -51,6 +51,7 @@ go_library(
"//pkg/fd",
"//pkg/fdnotifier",
"//pkg/fspath",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/p9",
"//pkg/refs",
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
index 9da01cba3..177e42649 100644
--- a/pkg/sentry/fsimpl/gofer/directory.go
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -28,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
func (d *dentry) isDir() bool {
@@ -98,7 +98,7 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) {
mode: uint32(opts.mode),
uid: uint32(opts.kuid),
gid: uint32(opts.kgid),
- blockSize: usermem.PageSize, // arbitrary
+ blockSize: hostarch.PageSize, // arbitrary
atime: now,
mtime: now,
ctime: now,
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 43c3c5a2d..4b5621043 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -141,21 +141,8 @@ func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, dsp **
return
}
ds := **dsp
- // Only go through calling dentry.checkCachingLocked() (which requires
- // re-locking renameMu) if we actually have any dentries with zero refs.
- checkAny := false
- for i := range ds {
- if atomic.LoadInt64(&ds[i].refs) == 0 {
- checkAny = true
- break
- }
- }
- if checkAny {
- fs.renameMu.Lock()
- for _, d := range ds {
- d.checkCachingLocked(ctx)
- }
- fs.renameMu.Unlock()
+ for _, d := range ds {
+ d.checkCachingLocked(ctx, false /* renameMuWriteLocked */)
}
putDentrySlice(*dsp)
}
@@ -166,7 +153,7 @@ func (fs *filesystem) renameMuUnlockAndCheckCaching(ctx context.Context, ds **[]
return
}
for _, d := range **ds {
- d.checkCachingLocked(ctx)
+ d.checkCachingLocked(ctx, true /* renameMuWriteLocked */)
}
fs.renameMu.Unlock()
putDentrySlice(*ds)
@@ -339,8 +326,10 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir
}
parent.cacheNewChildLocked(child, name)
// For now, child has 0 references, so our caller should call
- // child.checkCachingLocked().
+ // child.checkCachingLocked(). parent gained a ref so we should also call
+ // parent.checkCachingLocked() so it can be removed from the cache if needed.
*ds = appendDentry(*ds, child)
+ *ds = appendDentry(*ds, parent)
return child, nil
}
@@ -723,6 +712,8 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
}
}
d.IncRef()
+ // Call d.checkCachingLocked() so it can be removed from the cache if needed.
+ ds = appendDentry(ds, d)
return &d.vfsd, nil
}
@@ -744,6 +735,8 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa
return nil, err
}
d.IncRef()
+ // Call d.checkCachingLocked() so it can be removed from the cache if needed.
+ ds = appendDentry(ds, d)
return &d.vfsd, nil
}
@@ -782,7 +775,7 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
creds := rp.Credentials()
- return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, _ **[]*dentry) error {
+ return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, ds **[]*dentry) error {
// If the parent is a setgid directory, use the parent's GID
// rather than the caller's and enable setgid.
kgid := creds.EffectiveKGID
@@ -802,6 +795,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
kuid: creds.EffectiveKUID,
kgid: creds.EffectiveKGID,
})
+ *ds = appendDentry(*ds, parent)
}
if fs.opts.interop != InteropModeShared {
parent.incLinks()
@@ -855,6 +849,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
kgid: creds.EffectiveKGID,
endpoint: opts.Endpoint,
})
+ *ds = appendDentry(*ds, parent)
return nil
case linux.S_IFIFO:
parent.createSyntheticChildLocked(&createSyntheticOpts{
@@ -864,6 +859,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
kgid: creds.EffectiveKGID,
pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize),
})
+ *ds = appendDentry(*ds, parent)
return nil
}
// Retain error from gofer if synthetic file cannot be created internally.
@@ -912,6 +908,8 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
start.IncRef()
defer start.DecRef(ctx)
unlock()
+ // start is intentionally not added to ds (which would remove it from the
+ // cache) because doing so regresses performance in practice.
return start.open(ctx, rp, &opts)
}
@@ -965,6 +963,8 @@ afterTrailingSymlink:
child.IncRef()
defer child.DecRef(ctx)
unlock()
+ // child is intentionally not added to ds (which would remove it from the
+ // cache) because doing so regresses performance in practice.
return child.open(ctx, rp, &opts)
}
@@ -1212,6 +1212,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
}
// Insert the dentry into the tree.
d.cacheNewChildLocked(child, name)
+ *ds = appendDentry(*ds, d)
if d.cachedMetadataAuthoritative() {
d.touchCMtime()
d.dirents = nil
@@ -1403,6 +1404,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
oldParent.decRefNoCaching()
ds = appendDentry(ds, oldParent)
newParent.IncRef()
+ ds = appendDentry(ds, newParent)
if renamed.isSynthetic() {
oldParent.syntheticChildren--
newParent.syntheticChildren++
@@ -1546,6 +1548,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath
if d.isSocket() {
if !d.isSynthetic() {
d.IncRef()
+ ds = appendDentry(ds, d)
return &endpoint{
dentry: d,
path: opts.Addr,
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 692da02c1..fb42c5f62 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -18,15 +18,17 @@
// Lock order:
// regularFileFD/directoryFD.mu
// filesystem.renameMu
-// dentry.dirMu
-// filesystem.syncMu
-// dentry.metadataMu
-// *** "memmap.Mappable locks" below this point
-// dentry.mapsMu
-// *** "memmap.Mappable locks taken by Translate" below this point
-// dentry.handleMu
-// dentry.dataMu
-// filesystem.inoMu
+// dentry.cachingMu
+// filesystem.cacheMu
+// dentry.dirMu
+// filesystem.syncMu
+// dentry.metadataMu
+// *** "memmap.Mappable locks" below this point
+// dentry.mapsMu
+// *** "memmap.Mappable locks taken by Translate" below this point
+// dentry.handleMu
+// dentry.dataMu
+// filesystem.inoMu
// specialFileFD.mu
// specialFileFD.bufMu
//
@@ -44,6 +46,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
refs_vfs1 "gvisor.dev/gvisor/pkg/refs"
@@ -60,7 +63,6 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/unet"
- "gvisor.dev/gvisor/pkg/usermem"
)
// Name is the default filesystem name.
@@ -140,7 +142,8 @@ type filesystem struct {
// cachedDentries contains all dentries with 0 references. (Due to race
// conditions, it may also contain dentries with non-zero references.)
// cachedDentriesLen is the number of dentries in cachedDentries. These fields
- // are protected by renameMu.
+ // are protected by cacheMu.
+ cacheMu sync.Mutex `state:"nosave"`
cachedDentries dentryList
cachedDentriesLen uint64
@@ -620,11 +623,11 @@ func (fs *filesystem) Release(ctx context.Context) {
// the reference count on every synthetic dentry. Synthetic dentries have one
// reference for existence that should be dropped during filesystem.Release.
//
-// Precondition: d.fs.renameMu is locked.
+// Precondition: d.fs.renameMu is locked for writing.
func (d *dentry) releaseSyntheticRecursiveLocked(ctx context.Context) {
if d.isSynthetic() {
d.decRefNoCaching()
- d.checkCachingLocked(ctx)
+ d.checkCachingLocked(ctx, true /* renameMuWriteLocked */)
}
if d.isDir() {
var children []*dentry
@@ -682,9 +685,13 @@ type dentry struct {
// deleted. deleted is accessed using atomic memory operations.
deleted uint32
+ // cachingMu is used to synchronize concurrent dentry caching attempts on
+ // this dentry.
+ cachingMu sync.Mutex `state:"nosave"`
+
// If cached is true, dentryEntry links dentry into
// filesystem.cachedDentries. cached and dentryEntry are protected by
- // filesystem.renameMu.
+ // cachingMu.
cached bool
dentryEntry
@@ -872,7 +879,7 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma
mode: uint32(attr.Mode),
uid: uint32(fs.opts.dfltuid),
gid: uint32(fs.opts.dfltgid),
- blockSize: usermem.PageSize,
+ blockSize: hostarch.PageSize,
readFD: -1,
writeFD: -1,
mmapFD: -1,
@@ -980,36 +987,63 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) {
}
// Preconditions: !d.isSynthetic().
+// Preconditions: d.metadataMu is locked.
+func (d *dentry) refreshSizeLocked(ctx context.Context) error {
+ d.handleMu.RLock()
+
+ if d.writeFD < 0 {
+ d.handleMu.RUnlock()
+ // Ask the gofer if we don't have a host FD.
+ return d.updateFromGetattrLocked(ctx)
+ }
+
+ var stat unix.Statx_t
+ err := unix.Statx(int(d.writeFD), "", unix.AT_EMPTY_PATH, unix.STATX_SIZE, &stat)
+ d.handleMu.RUnlock() // must be released before updateSizeLocked()
+ if err != nil {
+ return err
+ }
+ d.updateSizeLocked(stat.Size)
+ return nil
+}
+
+// Preconditions: !d.isSynthetic().
func (d *dentry) updateFromGetattr(ctx context.Context) error {
- // Use d.readFile or d.writeFile, which represent 9P fids that have been
+ // d.metadataMu must be locked *before* we getAttr so that we do not end up
+ // updating stale attributes in d.updateFromP9AttrsLocked().
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ return d.updateFromGetattrLocked(ctx)
+}
+
+// Preconditions:
+// * !d.isSynthetic().
+// * d.metadataMu is locked.
+func (d *dentry) updateFromGetattrLocked(ctx context.Context) error {
+ // Use d.readFile or d.writeFile, which represent 9P FIDs that have been
// opened, in preference to d.file, which represents a 9P fid that has not.
// This may be significantly more efficient in some implementations. Prefer
// d.writeFile over d.readFile since some filesystem implementations may
// update a writable handle's metadata after writes to that handle, without
// making metadata updates immediately visible to read-only handles
// representing the same file.
- var (
- file p9file
- handleMuRLocked bool
- )
- // d.metadataMu must be locked *before* we getAttr so that we do not end up
- // updating stale attributes in d.updateFromP9AttrsLocked().
- d.metadataMu.Lock()
- defer d.metadataMu.Unlock()
d.handleMu.RLock()
- if !d.writeFile.isNil() {
+ handleMuRLocked := true
+ var file p9file
+ switch {
+ case !d.writeFile.isNil():
file = d.writeFile
- handleMuRLocked = true
- } else if !d.readFile.isNil() {
+ case !d.readFile.isNil():
file = d.readFile
- handleMuRLocked = true
- } else {
+ default:
file = d.file
d.handleMu.RUnlock()
+ handleMuRLocked = false
}
+
_, attrMask, attr, err := file.getAttr(ctx, dentryAttrMask())
if handleMuRLocked {
- d.handleMu.RUnlock()
+ d.handleMu.RUnlock() // must be released before updateFromP9AttrsLocked()
}
if err != nil {
return err
@@ -1104,24 +1138,27 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
defer d.metadataMu.Unlock()
// As with Linux, if the UID, GID, or file size is changing, we have to
- // clear permission bits. Note that when set, clearSGID causes
- // permissions to be updated, but does not modify stat.Mask, as
- // modification would cause an extra inotify flag to be set.
- clearSGID := stat.Mask&linux.STATX_UID != 0 && stat.UID != atomic.LoadUint32(&d.uid) ||
- stat.Mask&linux.STATX_GID != 0 && stat.GID != atomic.LoadUint32(&d.gid) ||
+ // clear permission bits. Note that when set, clearSGID may cause
+ // permissions to be updated.
+ clearSGID := (stat.Mask&linux.STATX_UID != 0 && stat.UID != atomic.LoadUint32(&d.uid)) ||
+ (stat.Mask&linux.STATX_GID != 0 && stat.GID != atomic.LoadUint32(&d.gid)) ||
stat.Mask&linux.STATX_SIZE != 0
if clearSGID {
if stat.Mask&linux.STATX_MODE != 0 {
stat.Mode = uint16(vfs.ClearSUIDAndSGID(uint32(stat.Mode)))
} else {
- stat.Mode = uint16(vfs.ClearSUIDAndSGID(atomic.LoadUint32(&d.mode)))
+ oldMode := atomic.LoadUint32(&d.mode)
+ if updatedMode := vfs.ClearSUIDAndSGID(oldMode); updatedMode != oldMode {
+ stat.Mode = uint16(updatedMode)
+ stat.Mask |= linux.STATX_MODE
+ }
}
}
if !d.isSynthetic() {
if stat.Mask != 0 {
if err := d.file.setAttr(ctx, p9.SetAttrMask{
- Permissions: stat.Mask&linux.STATX_MODE != 0 || clearSGID,
+ Permissions: stat.Mask&linux.STATX_MODE != 0,
UID: stat.Mask&linux.STATX_UID != 0,
GID: stat.Mask&linux.STATX_GID != 0,
Size: stat.Mask&linux.STATX_SIZE != 0,
@@ -1156,7 +1193,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
return nil
}
}
- if stat.Mask&linux.STATX_MODE != 0 || clearSGID {
+ if stat.Mask&linux.STATX_MODE != 0 {
atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode))
}
if stat.Mask&linux.STATX_UID != 0 {
@@ -1217,8 +1254,8 @@ func (d *dentry) updateSizeLocked(newSize uint64) {
// so we can't race with Write or another truncate.)
d.dataMu.Unlock()
if d.size < oldSize {
- oldpgend, _ := usermem.PageRoundUp(oldSize)
- newpgend, _ := usermem.PageRoundUp(d.size)
+ oldpgend, _ := hostarch.PageRoundUp(oldSize)
+ newpgend, _ := hostarch.PageRoundUp(d.size)
if oldpgend != newpgend {
d.mapsMu.Lock()
d.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
@@ -1312,9 +1349,7 @@ func (d *dentry) TryIncRef() bool {
// DecRef implements vfs.DentryImpl.DecRef.
func (d *dentry) DecRef(ctx context.Context) {
if d.decRefNoCaching() == 0 {
- d.fs.renameMu.Lock()
- d.checkCachingLocked(ctx)
- d.fs.renameMu.Unlock()
+ d.checkCachingLocked(ctx, false /* renameMuWriteLocked */)
}
}
@@ -1374,15 +1409,16 @@ func (d *dentry) Watches() *vfs.Watches {
//
// If no watches are left on this dentry and it has no references, cache it.
func (d *dentry) OnZeroWatches(ctx context.Context) {
- if atomic.LoadInt64(&d.refs) == 0 {
- d.fs.renameMu.Lock()
- d.checkCachingLocked(ctx)
- d.fs.renameMu.Unlock()
- }
+ d.checkCachingLocked(ctx, false /* renameMuWriteLocked */)
}
-// checkCachingLocked should be called after d's reference count becomes 0 or it
-// becomes disowned.
+// checkCachingLocked should be called after d's reference count becomes 0 or
+// it becomes disowned.
+//
+// For performance, checkCachingLocked can also be called after d's reference
+// count becomes non-zero, so that d can be removed from the LRU cache. This
+// may help in reducing the size of the cache and hence reduce evictions. Note
+// that this is not necessary for correctness.
//
// It may be called on a destroyed dentry. For example,
// renameMu[R]UnlockAndCheckCaching may call checkCachingLocked multiple times
@@ -1390,33 +1426,46 @@ func (d *dentry) OnZeroWatches(ctx context.Context) {
// operation. One of the calls may destroy the dentry, so subsequent calls will
// do nothing.
//
-// Preconditions: d.fs.renameMu must be locked for writing; it may be
-// temporarily unlocked.
-func (d *dentry) checkCachingLocked(ctx context.Context) {
- // Dentries with a non-zero reference count must be retained. (The only way
- // to obtain a reference on a dentry with zero references is via path
- // resolution, which requires renameMu, so if d.refs is zero then it will
- // remain zero while we hold renameMu for writing.)
+// Preconditions: d.fs.renameMu must be locked for writing if
+// renameMuWriteLocked is true; it may be temporarily unlocked.
+func (d *dentry) checkCachingLocked(ctx context.Context, renameMuWriteLocked bool) {
+ d.cachingMu.Lock()
refs := atomic.LoadInt64(&d.refs)
if refs == -1 {
// Dentry has already been destroyed.
+ d.cachingMu.Unlock()
return
}
if refs > 0 {
- // This isn't strictly necessary (fs.cachedDentries is permitted to
- // contain dentries with non-zero refs, which are skipped by
- // fs.evictCachedDentryLocked() upon reaching the end of the LRU), but
- // since we are already holding fs.renameMu for writing we may as well.
+ // fs.cachedDentries is permitted to contain dentries with non-zero refs,
+ // which are skipped by fs.evictCachedDentryLocked() upon reaching the end
+ // of the LRU. But it is still beneficial to remove d from the cache as we
+ // are already holding d.cachingMu. Keeping a cleaner cache also reduces
+ // the number of evictions (which is expensive as it acquires fs.renameMu).
d.removeFromCacheLocked()
+ d.cachingMu.Unlock()
return
}
// Deleted and invalidated dentries with zero references are no longer
// reachable by path resolution and should be dropped immediately.
if d.vfsd.IsDead() {
+ d.removeFromCacheLocked()
+ d.cachingMu.Unlock()
+ if !renameMuWriteLocked {
+ // Need to lock d.fs.renameMu for writing as needed by d.destroyLocked().
+ d.fs.renameMu.Lock()
+ defer d.fs.renameMu.Unlock()
+ // Now that renameMu is locked for writing, no more refs can be taken on
+ // d because path resolution requires renameMu for reading at least.
+ if atomic.LoadInt64(&d.refs) != 0 {
+ // Destroy d only if its ref is still 0. If not, either someone took a
+ // ref on it or it got destroyed before fs.renameMu could be acquired.
+ return
+ }
+ }
if d.isDeleted() {
d.watches.HandleDeletion(ctx)
}
- d.removeFromCacheLocked()
d.destroyLocked(ctx)
return
}
@@ -1426,24 +1475,36 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
// d.watches cannot concurrently transition from zero to non-zero, because
// adding a watch requires holding a reference on d.
if d.watches.Size() > 0 {
- // As in the refs > 0 case, this is not strictly necessary.
+ // As in the refs > 0 case, removing d is beneficial.
d.removeFromCacheLocked()
+ d.cachingMu.Unlock()
return
}
if atomic.LoadInt32(&d.fs.released) != 0 {
+ d.cachingMu.Unlock()
+ if !renameMuWriteLocked {
+ // Need to lock d.fs.renameMu to access d.parent. Lock it for writing as
+ // needed by d.destroyLocked() later.
+ d.fs.renameMu.Lock()
+ defer d.fs.renameMu.Unlock()
+ }
if d.parent != nil {
d.parent.dirMu.Lock()
delete(d.parent.children, d.name)
d.parent.dirMu.Unlock()
}
d.destroyLocked(ctx)
+ return
}
+ d.fs.cacheMu.Lock()
// If d is already cached, just move it to the front of the LRU.
if d.cached {
d.fs.cachedDentries.Remove(d)
d.fs.cachedDentries.PushFront(d)
+ d.fs.cacheMu.Unlock()
+ d.cachingMu.Unlock()
return
}
// Cache the dentry, then evict the least recently used cached dentry if
@@ -1451,18 +1512,28 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
d.fs.cachedDentries.PushFront(d)
d.fs.cachedDentriesLen++
d.cached = true
- if d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries {
+ shouldEvict := d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries
+ d.fs.cacheMu.Unlock()
+ d.cachingMu.Unlock()
+
+ if shouldEvict {
+ if !renameMuWriteLocked {
+ // Need to lock d.fs.renameMu for writing as needed by
+ // d.evictCachedDentryLocked().
+ d.fs.renameMu.Lock()
+ defer d.fs.renameMu.Unlock()
+ }
d.fs.evictCachedDentryLocked(ctx)
- // Whether or not victim was destroyed, we brought fs.cachedDentriesLen
- // back down to fs.opts.maxCachedDentries, so we don't loop.
}
}
-// Preconditions: d.fs.renameMu must be locked for writing.
+// Preconditions: d.cachingMu must be locked.
func (d *dentry) removeFromCacheLocked() {
if d.cached {
+ d.fs.cacheMu.Lock()
d.fs.cachedDentries.Remove(d)
d.fs.cachedDentriesLen--
+ d.fs.cacheMu.Unlock()
d.cached = false
}
}
@@ -1477,28 +1548,43 @@ func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) {
// Preconditions:
// * fs.renameMu must be locked for writing; it may be temporarily unlocked.
-// * fs.cachedDentriesLen != 0.
func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) {
+ fs.cacheMu.Lock()
victim := fs.cachedDentries.Back()
+ fs.cacheMu.Unlock()
+ if victim == nil {
+ // fs.cachedDentries may have become empty between when it was checked and
+ // when we locked fs.cacheMu.
+ return
+ }
+
+ victim.cachingMu.Lock()
victim.removeFromCacheLocked()
// victim.refs or victim.watches.Size() may have become non-zero from an
// earlier path resolution since it was inserted into fs.cachedDentries.
- if atomic.LoadInt64(&victim.refs) == 0 && victim.watches.Size() == 0 {
- if victim.parent != nil {
- victim.parent.dirMu.Lock()
- if !victim.vfsd.IsDead() {
- // Note that victim can't be a mount point (in any mount
- // namespace), since VFS holds references on mount points.
- fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd)
- delete(victim.parent.children, victim.name)
- // We're only deleting the dentry, not the file it
- // represents, so we don't need to update
- // victimParent.dirents etc.
- }
- victim.parent.dirMu.Unlock()
+ if atomic.LoadInt64(&victim.refs) != 0 || victim.watches.Size() != 0 {
+ victim.cachingMu.Unlock()
+ return
+ }
+ if victim.parent != nil {
+ victim.parent.dirMu.Lock()
+ if !victim.vfsd.IsDead() {
+ // Note that victim can't be a mount point (in any mount
+ // namespace), since VFS holds references on mount points.
+ fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd)
+ delete(victim.parent.children, victim.name)
+ // We're only deleting the dentry, not the file it
+ // represents, so we don't need to update
+ // victimParent.dirents etc.
}
- victim.destroyLocked(ctx)
+ victim.parent.dirMu.Unlock()
}
+ // Safe to unlock cachingMu now that victim.vfsd.IsDead(). Henceforth any
+ // concurrent caching attempts on victim will attempt to destroy it and so
+ // will try to acquire fs.renameMu (which we have already acquired). Hence,
+ // fs.renameMu will synchronize the destroy attempts.
+ victim.cachingMu.Unlock()
+ victim.destroyLocked(ctx)
}
// destroyLocked destroys the dentry.
@@ -1584,7 +1670,7 @@ func (d *dentry) destroyLocked(ctx context.Context) {
// Drop the reference held by d on its parent without recursively locking
// d.fs.renameMu.
if d.parent != nil && d.parent.decRefNoCaching() == 0 {
- d.parent.checkCachingLocked(ctx)
+ d.parent.checkCachingLocked(ctx, true /* renameMuWriteLocked */)
}
refsvfs2.Unregister(d)
}
diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go
index 76f08e252..806392d50 100644
--- a/pkg/sentry/fsimpl/gofer/gofer_test.go
+++ b/pkg/sentry/fsimpl/gofer/gofer_test.go
@@ -55,7 +55,7 @@ func TestDestroyIdempotent(t *testing.T) {
fs.renameMu.Lock()
defer fs.renameMu.Unlock()
- child.checkCachingLocked(ctx)
+ child.checkCachingLocked(ctx, true /* renameMuWriteLocked */)
if got := atomic.LoadInt64(&child.refs); got != -1 {
t.Fatalf("child.refs=%d, want: -1", got)
}
@@ -63,6 +63,6 @@ func TestDestroyIdempotent(t *testing.T) {
if got := atomic.LoadInt64(&parent.refs); got != -1 {
t.Fatalf("parent.refs=%d, want: -1", got)
}
- child.checkCachingLocked(ctx)
- child.checkCachingLocked(ctx)
+ child.checkCachingLocked(ctx, true /* renameMuWriteLocked */)
+ child.checkCachingLocked(ctx, true /* renameMuWriteLocked */)
}
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 4f1ad0c88..f0e7bbaf7 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/safemem"
@@ -203,18 +204,19 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
}
d := fd.dentry()
+
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+
// If the fd was opened with O_APPEND, make sure the file size is updated.
// There is a possible race here if size is modified externally after
// metadata cache is updated.
if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() {
- if err := d.updateFromGetattr(ctx); err != nil {
+ if err := d.refreshSizeLocked(ctx); err != nil {
return 0, offset, err
}
}
- d.metadataMu.Lock()
- defer d.metadataMu.Unlock()
-
// Set offset to file size if the fd was opened with O_APPEND.
if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
// Holding d.metadataMu is sufficient for reading d.size.
@@ -291,8 +293,8 @@ func (fd *regularFileFD) writeCache(ctx context.Context, d *dentry, offset int64
}
// Remove touched pages from the cache.
- pgstart := usermem.PageRoundDown(uint64(offset))
- pgend, ok := usermem.PageRoundUp(uint64(offset + src.NumBytes()))
+ pgstart := hostarch.PageRoundDown(uint64(offset))
+ pgend, ok := hostarch.PageRoundUp(uint64(offset + src.NumBytes()))
if !ok {
return syserror.EINVAL
}
@@ -408,7 +410,7 @@ func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error)
switch {
case seg.Ok():
// Get internal mappings from the cache.
- ims, err := mf.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
+ ims, err := mf.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), hostarch.Read)
if err != nil {
dataMuUnlock()
rw.d.handleMu.RUnlock()
@@ -434,9 +436,9 @@ func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error)
if fillCache {
// Read into the cache, then re-enter the loop to read from the
// cache.
- gapEnd, _ := usermem.PageRoundUp(gapMR.End)
+ gapEnd, _ := hostarch.PageRoundUp(gapMR.End)
reqMR := memmap.MappableRange{
- Start: usermem.PageRoundDown(gapMR.Start),
+ Start: hostarch.PageRoundDown(gapMR.Start),
End: gapEnd,
}
optMR := gap.Range()
@@ -527,7 +529,7 @@ func (rw *dentryReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, erro
case seg.Ok():
// Get internal mappings from the cache.
segMR := seg.Range().Intersect(mr)
- ims, err := mf.MapInternal(seg.FileRangeOf(segMR), usermem.Write)
+ ims, err := mf.MapInternal(seg.FileRangeOf(segMR), hostarch.Write)
if err != nil {
retErr = err
goto exitLoop
@@ -700,6 +702,7 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt
}
// After this point, d may be used as a memmap.Mappable.
d.pf.hostFileMapperInitOnce.Do(d.pf.hostFileMapper.Init)
+ opts.SentryOwnedContent = d.fs.opts.forcePageCache
return vfs.GenericConfigureMMap(&fd.vfsfd, d, opts)
}
@@ -714,7 +717,7 @@ func (d *dentry) mayCachePages() bool {
}
// AddMapping implements memmap.Mappable.AddMapping.
-func (d *dentry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+func (d *dentry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error {
d.mapsMu.Lock()
mapped := d.mappings.AddMapping(ms, ar, offset, writable)
// Do this unconditionally since whether we have a host FD can change
@@ -735,7 +738,7 @@ func (d *dentry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar user
}
// RemoveMapping implements memmap.Mappable.RemoveMapping.
-func (d *dentry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+func (d *dentry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) {
d.mapsMu.Lock()
unmapped := d.mappings.RemoveMapping(ms, ar, offset, writable)
for _, r := range unmapped {
@@ -759,12 +762,12 @@ func (d *dentry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar u
}
// CopyMapping implements memmap.Mappable.CopyMapping.
-func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error {
return d.AddMapping(ctx, ms, dstAR, offset, writable)
}
// Translate implements memmap.Mappable.Translate.
-func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) {
d.handleMu.RLock()
if d.mmapFD >= 0 && !d.fs.opts.forcePageCache {
d.handleMu.RUnlock()
@@ -777,7 +780,7 @@ func (d *dentry) Translate(ctx context.Context, required, optional memmap.Mappab
Source: mr,
File: &d.pf,
Offset: mr.Start,
- Perms: usermem.AnyAccess,
+ Perms: hostarch.AnyAccess,
},
}, nil
}
@@ -786,7 +789,7 @@ func (d *dentry) Translate(ctx context.Context, required, optional memmap.Mappab
// Constrain translations to d.size (rounded up) to prevent translation to
// pages that may be concurrently truncated.
- pgend, _ := usermem.PageRoundUp(d.size)
+ pgend, _ := hostarch.PageRoundUp(d.size)
var beyondEOF bool
if required.End > pgend {
if required.Start >= pgend {
@@ -811,7 +814,7 @@ func (d *dentry) Translate(ctx context.Context, required, optional memmap.Mappab
segMR := seg.Range().Intersect(optional)
// TODO(jamieliu): Make Translations writable even if writability is
// not required if already kept-dirty by another writable translation.
- perms := usermem.AccessType{
+ perms := hostarch.AccessType{
Read: true,
Execute: true,
}
@@ -954,7 +957,7 @@ func (d *dentryPlatformFile) DecRef(fr memmap.FileRange) {
}
// MapInternal implements memmap.File.MapInternal.
-func (d *dentryPlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+func (d *dentryPlatformFile) MapInternal(fr memmap.FileRange, at hostarch.AccessType) (safemem.BlockSeq, error) {
d.handleMu.RLock()
defer d.handleMu.RUnlock()
return d.hostFileMapper.MapInternal(fr, int(d.mmapFD), at.Write)
diff --git a/pkg/sentry/fsimpl/gofer/save_restore.go b/pkg/sentry/fsimpl/gofer/save_restore.go
index c90071e4e..83e841a51 100644
--- a/pkg/sentry/fsimpl/gofer/save_restore.go
+++ b/pkg/sentry/fsimpl/gofer/save_restore.go
@@ -22,12 +22,12 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
type saveRestoreContextID int
@@ -85,7 +85,7 @@ func (fs *filesystem) PrepareSave(ctx context.Context) error {
func (fd *specialFileFD) savePipeData(ctx context.Context) error {
fd.bufMu.Lock()
defer fd.bufMu.Unlock()
- var buf [usermem.PageSize]byte
+ var buf [hostarch.PageSize]byte
for {
n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:])), ^uint64(0))
if n != 0 {
diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD
index 4ae9d6d5e..b94dfeb7f 100644
--- a/pkg/sentry/fsimpl/host/BUILD
+++ b/pkg/sentry/fsimpl/host/BUILD
@@ -47,6 +47,7 @@ go_library(
"//pkg/context",
"//pkg/fdnotifier",
"//pkg/fspath",
+ "//pkg/hostarch",
"//pkg/iovec",
"//pkg/log",
"//pkg/marshal/primitive",
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index b9cce4181..a81f550b1 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/hostfd"
@@ -431,8 +432,8 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
}
oldSize := uint64(hostStat.Size)
if s.Size < oldSize {
- oldpgend, _ := usermem.PageRoundUp(oldSize)
- newpgend, _ := usermem.PageRoundUp(s.Size)
+ oldpgend, _ := hostarch.PageRoundUp(oldSize)
+ newpgend, _ := hostarch.PageRoundUp(s.Size)
if oldpgend != newpgend {
i.CachedMappable.InvalidateRange(memmap.MappableRange{newpgend, oldpgend})
}
@@ -459,6 +460,9 @@ func (i *inode) DecRef(ctx context.Context) {
if err := unix.Close(i.hostFD); err != nil {
log.Warningf("failed to close host fd %d: %v", i.hostFD, err)
}
+ // We can't rely on fdnotifier when closing the fd, because the event may race
+ // with fdnotifier.RemoveFD. Instead, notify the queue explicitly.
+ i.queue.Notify(waiter.EventHUp | waiter.ReadableEvents | waiter.WritableEvents)
})
}
diff --git a/pkg/sentry/fsimpl/host/save_restore.go b/pkg/sentry/fsimpl/host/save_restore.go
index 5688bddc8..c502d8e99 100644
--- a/pkg/sentry/fsimpl/host/save_restore.go
+++ b/pkg/sentry/fsimpl/host/save_restore.go
@@ -21,9 +21,9 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/hostfd"
- "gvisor.dev/gvisor/pkg/usermem"
)
// beforeSave is invoked by stateify.
@@ -38,7 +38,7 @@ func (i *inode) beforeSave() {
// EBADF from the read.
i.bufMu.Lock()
defer i.bufMu.Unlock()
- var buf [usermem.PageSize]byte
+ var buf [hostarch.PageSize]byte
for {
n, err := hostfd.Preadv2(int32(i.hostFD), safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:])), -1 /* offset */, 0 /* flags */)
if n != 0 {
@@ -68,3 +68,10 @@ func (i *inode) afterLoad() {
}
}
}
+
+// afterLoad is invoked by stateify.
+func (c *ConnectedEndpoint) afterLoad() {
+ if err := c.initFromOptions(); err != nil {
+ panic(fmt.Sprintf("initFromOptions failed: %v", err))
+ }
+}
diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go
index 60e237ac7..ca85f5601 100644
--- a/pkg/sentry/fsimpl/host/socket.go
+++ b/pkg/sentry/fsimpl/host/socket.go
@@ -39,7 +39,7 @@ import (
func newEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue) (transport.Endpoint, error) {
// Set up an external transport.Endpoint using the host fd.
addr := fmt.Sprintf("hostfd:[%d]", hostFD)
- e, err := NewConnectedEndpoint(ctx, hostFD, addr, true /* saveable */)
+ e, err := NewConnectedEndpoint(hostFD, addr)
if err != nil {
return nil, err.ToError()
}
@@ -86,7 +86,10 @@ type ConnectedEndpoint struct {
// for restoring them.
func (c *ConnectedEndpoint) init() *syserr.Error {
c.InitRefs()
+ return c.initFromOptions()
+}
+func (c *ConnectedEndpoint) initFromOptions() *syserr.Error {
family, err := unix.GetsockoptInt(c.fd, unix.SOL_SOCKET, unix.SO_DOMAIN)
if err != nil {
return syserr.FromError(err)
@@ -123,7 +126,7 @@ func (c *ConnectedEndpoint) init() *syserr.Error {
// The caller is responsible for calling Init(). Additionaly, Release needs to
// be called twice because ConnectedEndpoint is both a transport.Receiver and
// transport.ConnectedEndpoint.
-func NewConnectedEndpoint(ctx context.Context, hostFD int, addr string, saveable bool) (*ConnectedEndpoint, *syserr.Error) {
+func NewConnectedEndpoint(hostFD int, addr string) (*ConnectedEndpoint, *syserr.Error) {
e := ConnectedEndpoint{
fd: hostFD,
addr: addr,
@@ -330,8 +333,16 @@ func (c *ConnectedEndpoint) CloseUnread() {}
// SetSendBufferSize implements transport.ConnectedEndpoint.SetSendBufferSize.
func (c *ConnectedEndpoint) SetSendBufferSize(v int64) (newSz int64) {
- // gVisor does not permit setting of SO_SNDBUF for host backed unix domain
- // sockets.
+ // gVisor does not permit setting of SO_SNDBUF for host backed unix
+ // domain sockets.
+ return atomic.LoadInt64(&c.sndbuf)
+}
+
+// SetReceiveBufferSize implements transport.ConnectedEndpoint.SetReceiveBufferSize.
+func (c *ConnectedEndpoint) SetReceiveBufferSize(v int64) (newSz int64) {
+ // gVisor does not permit setting of SO_RCVBUF for host backed unix
+ // domain sockets. Receive buffer does not have any effect for unix
+ // sockets and we claim to be the same as send buffer.
return atomic.LoadInt64(&c.sndbuf)
}
diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD
index 6dbc7e34d..b7d13cced 100644
--- a/pkg/sentry/fsimpl/kernfs/BUILD
+++ b/pkg/sentry/fsimpl/kernfs/BUILD
@@ -105,6 +105,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/fspath",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/refs",
"//pkg/refsvfs2",
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
index 65054b0ea..84b1c3745 100644
--- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -25,8 +25,10 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// DynamicBytesFile implements kernfs.Inode and represents a read-only
-// file whose contents are backed by a vfs.DynamicBytesSource.
+// DynamicBytesFile implements kernfs.Inode and represents a read-only file
+// whose contents are backed by a vfs.DynamicBytesSource. If data additionally
+// implements vfs.WritableDynamicBytesSource, the file also supports dispatching
+// writes to the implementer, but note that this will not update the source data.
//
// Must be instantiated with NewDynamicBytesFile or initialized with Init
// before first use.
@@ -40,7 +42,9 @@ type DynamicBytesFile struct {
InodeNotSymlink
locks vfs.FileLocks
- data vfs.DynamicBytesSource
+ // data can additionally implement vfs.WritableDynamicBytesSource to support
+ // writes.
+ data vfs.DynamicBytesSource
}
var _ Inode = (*DynamicBytesFile)(nil)
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index 6b890a39c..3d0866ecf 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -20,12 +20,12 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// InodeNoopRefCount partially implements the Inode interface, specifically the
@@ -206,7 +206,7 @@ func (a *InodeAttrs) Init(ctx context.Context, creds *auth.Credentials, devMajor
atomic.StoreUint32(&a.uid, uint32(creds.EffectiveKUID))
atomic.StoreUint32(&a.gid, uint32(creds.EffectiveKGID))
atomic.StoreUint32(&a.nlink, nlink)
- atomic.StoreUint32(&a.blockSize, usermem.PageSize)
+ atomic.StoreUint32(&a.blockSize, hostarch.PageSize)
now := ktime.NowFromContext(ctx).Nanoseconds()
atomic.StoreInt64(&a.atime, now)
atomic.StoreInt64(&a.mtime, now)
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index 565d723f0..16486eeae 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -61,6 +61,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -508,6 +509,15 @@ func (d *Dentry) Inode() Inode {
return d.inode
}
+// FSLocalPath returns an absolute path to d, relative to the root of its
+// filesystem.
+func (d *Dentry) FSLocalPath() string {
+ var b fspath.Builder
+ _ = genericPrependPath(vfs.VirtualDentry{}, nil, d, &b)
+ b.PrependByte('/')
+ return b.String()
+}
+
// The Inode interface maps filesystem-level operations that operate on paths to
// equivalent operations on specific filesystem nodes.
//
diff --git a/pkg/sentry/fsimpl/kernfs/mmap_util.go b/pkg/sentry/fsimpl/kernfs/mmap_util.go
index bd6a134b4..d1539d904 100644
--- a/pkg/sentry/fsimpl/kernfs/mmap_util.go
+++ b/pkg/sentry/fsimpl/kernfs/mmap_util.go
@@ -16,11 +16,11 @@ package kernfs
import (
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/usermem"
)
// inodePlatformFile implements memmap.File. It exists solely because inode
@@ -66,7 +66,7 @@ func (i *inodePlatformFile) DecRef(fr memmap.FileRange) {
}
// MapInternal implements memmap.File.MapInternal.
-func (i *inodePlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+func (i *inodePlatformFile) MapInternal(fr memmap.FileRange, at hostarch.AccessType) (safemem.BlockSeq, error) {
return i.fileMapper.MapInternal(fr, i.hostFD, at.Write)
}
@@ -100,7 +100,7 @@ func (i *CachedMappable) Init(hostFD int) {
}
// AddMapping implements memmap.Mappable.AddMapping.
-func (i *CachedMappable) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+func (i *CachedMappable) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error {
i.mapsMu.Lock()
mapped := i.mappings.AddMapping(ms, ar, offset, writable)
for _, r := range mapped {
@@ -111,7 +111,7 @@ func (i *CachedMappable) AddMapping(ctx context.Context, ms memmap.MappingSpace,
}
// RemoveMapping implements memmap.Mappable.RemoveMapping.
-func (i *CachedMappable) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+func (i *CachedMappable) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) {
i.mapsMu.Lock()
unmapped := i.mappings.RemoveMapping(ms, ar, offset, writable)
for _, r := range unmapped {
@@ -121,19 +121,19 @@ func (i *CachedMappable) RemoveMapping(ctx context.Context, ms memmap.MappingSpa
}
// CopyMapping implements memmap.Mappable.CopyMapping.
-func (i *CachedMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+func (i *CachedMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error {
return i.AddMapping(ctx, ms, dstAR, offset, writable)
}
// Translate implements memmap.Mappable.Translate.
-func (i *CachedMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+func (i *CachedMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) {
mr := optional
return []memmap.Translation{
{
Source: mr,
File: &i.pf,
Offset: mr.Start,
- Perms: usermem.AnyAccess,
+ Perms: hostarch.AnyAccess,
},
}, nil
}
diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD
index bf13bbbf4..5504476c8 100644
--- a/pkg/sentry/fsimpl/overlay/BUILD
+++ b/pkg/sentry/fsimpl/overlay/BUILD
@@ -30,6 +30,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/fspath",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/refs",
"//pkg/refsvfs2",
diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go
index 27b00cf6f..45aa5a494 100644
--- a/pkg/sentry/fsimpl/overlay/copy_up.go
+++ b/pkg/sentry/fsimpl/overlay/copy_up.go
@@ -21,11 +21,11 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
func (d *dentry) isCopiedUp() bool {
@@ -138,8 +138,8 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
// We may have memory mappings of the file on the lower layer.
// Switch to mapping the file on the upper layer instead.
mmapOpts = &memmap.MMapOpts{
- Perms: usermem.ReadWrite,
- MaxPerms: usermem.ReadWrite,
+ Perms: hostarch.ReadWrite,
+ MaxPerms: hostarch.ReadWrite,
}
if err := newFD.ConfigureMMap(ctx, mmapOpts); err != nil {
cleanupUndoCopyUp()
diff --git a/pkg/sentry/fsimpl/overlay/regular_file.go b/pkg/sentry/fsimpl/overlay/regular_file.go
index d791c06db..43bfd69a3 100644
--- a/pkg/sentry/fsimpl/overlay/regular_file.go
+++ b/pkg/sentry/fsimpl/overlay/regular_file.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -445,7 +446,7 @@ func (fd *regularFileFD) ensureMappable(ctx context.Context, opts *memmap.MMapOp
}
// AddMapping implements memmap.Mappable.AddMapping.
-func (d *dentry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+func (d *dentry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error {
d.mapsMu.Lock()
defer d.mapsMu.Unlock()
if err := d.wrappedMappable.AddMapping(ctx, ms, ar, offset, writable); err != nil {
@@ -458,7 +459,7 @@ func (d *dentry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar user
}
// RemoveMapping implements memmap.Mappable.RemoveMapping.
-func (d *dentry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+func (d *dentry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) {
d.mapsMu.Lock()
defer d.mapsMu.Unlock()
d.wrappedMappable.RemoveMapping(ctx, ms, ar, offset, writable)
@@ -468,7 +469,7 @@ func (d *dentry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar u
}
// CopyMapping implements memmap.Mappable.CopyMapping.
-func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error {
d.mapsMu.Lock()
defer d.mapsMu.Unlock()
if err := d.wrappedMappable.CopyMapping(ctx, ms, srcAR, dstAR, offset, writable); err != nil {
@@ -481,7 +482,7 @@ func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR,
}
// Translate implements memmap.Mappable.Translate.
-func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) {
d.dataMu.RLock()
defer d.dataMu.RUnlock()
return d.wrappedMappable.Translate(ctx, required, optional, at)
diff --git a/pkg/sentry/fsimpl/pipefs/BUILD b/pkg/sentry/fsimpl/pipefs/BUILD
index 5950a2d59..278ee3c92 100644
--- a/pkg/sentry/fsimpl/pipefs/BUILD
+++ b/pkg/sentry/fsimpl/pipefs/BUILD
@@ -10,12 +10,12 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/fspath",
+ "//pkg/hostarch",
"//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/pipe",
"//pkg/sentry/kernel/time",
"//pkg/sentry/vfs",
"//pkg/syserror",
- "//pkg/usermem",
],
)
diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go
index 3f05e444e..08aedc2ad 100644
--- a/pkg/sentry/fsimpl/pipefs/pipefs.go
+++ b/pkg/sentry/fsimpl/pipefs/pipefs.go
@@ -22,13 +22,13 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// +stateify savable
@@ -131,7 +131,7 @@ func (i *inode) Stat(_ context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOpti
ts := linux.NsecToStatxTimestamp(i.ctime.Nanoseconds())
return linux.Statx{
Mask: linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS,
- Blksize: usermem.PageSize,
+ Blksize: hostarch.PageSize,
Nlink: 1,
UID: uint32(i.uid),
GID: uint32(i.gid),
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index d47a4fff9..2b628bd55 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -81,6 +81,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/refs",
"//pkg/refsvfs2",
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
index 254a8b062..ce8f55b1f 100644
--- a/pkg/sentry/fsimpl/proc/filesystem.go
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -86,13 +86,13 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF
procfs.MaxCachedDentries = maxCachedDentries
procfs.VFSFilesystem().Init(vfsObj, &ft, procfs)
- var cgroups map[string]string
+ var fakeCgroupControllers map[string]string
if opts.InternalData != nil {
data := opts.InternalData.(*InternalData)
- cgroups = data.Cgroups
+ fakeCgroupControllers = data.Cgroups
}
- inode := procfs.newTasksInode(ctx, k, pidns, cgroups)
+ inode := procfs.newTasksInode(ctx, k, pidns, fakeCgroupControllers)
var dentry kernfs.Dentry
dentry.InitRoot(&procfs.Filesystem, inode)
return procfs.VFSFilesystem(), dentry.VFSDentry(), nil
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index fea138f93..d05cc1508 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -47,7 +47,7 @@ type taskInode struct {
var _ kernfs.Inode = (*taskInode)(nil)
-func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) (kernfs.Inode, error) {
+func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, fakeCgroupControllers map[string]string) (kernfs.Inode, error) {
if task.ExitState() == kernel.TaskExitDead {
return nil, syserror.ESRCH
}
@@ -82,10 +82,12 @@ func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns
"uid_map": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}),
}
if isThreadGroup {
- contents["task"] = fs.newSubtasks(ctx, task, pidns, cgroupControllers)
+ contents["task"] = fs.newSubtasks(ctx, task, pidns, fakeCgroupControllers)
}
- if len(cgroupControllers) > 0 {
- contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, newCgroupData(cgroupControllers))
+ if len(fakeCgroupControllers) > 0 {
+ contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, newFakeCgroupData(fakeCgroupControllers))
+ } else {
+ contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &taskCgroupData{task: task})
}
taskInode := &taskInode{task: task}
@@ -226,11 +228,14 @@ func newIO(t *kernel.Task, isThreadGroup bool) *ioData {
return &ioData{ioUsage: t}
}
-// newCgroupData creates inode that shows cgroup information.
-// From man 7 cgroups: "For each cgroup hierarchy of which the process is a
-// member, there is one entry containing three colon-separated fields:
-// hierarchy-ID:controller-list:cgroup-path"
-func newCgroupData(controllers map[string]string) dynamicInode {
+// newFakeCgroupData creates an inode that shows fake cgroup
+// information passed in as mount options. From man 7 cgroups: "For
+// each cgroup hierarchy of which the process is a member, there is
+// one entry containing three colon-separated fields:
+// hierarchy-ID:controller-list:cgroup-path"
+//
+// TODO(b/182488796): Remove once all users adopt cgroupfs.
+func newFakeCgroupData(controllers map[string]string) dynamicInode {
var buf bytes.Buffer
// The hierarchy ids must be positive integers (for cgroup v1), but the
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index fdae163d1..b294dfd6a 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
@@ -122,8 +123,8 @@ func (d *auxvData) Generate(ctx context.Context, buf *bytes.Buffer) error {
buf.Grow((len(auxv) + 1) * 16)
for _, e := range auxv {
var tmp [16]byte
- usermem.ByteOrder.PutUint64(tmp[:8], e.Key)
- usermem.ByteOrder.PutUint64(tmp[8:], uint64(e.Value))
+ hostarch.ByteOrder.PutUint64(tmp[:8], e.Key)
+ hostarch.ByteOrder.PutUint64(tmp[8:], uint64(e.Value))
buf.Write(tmp[:])
}
var atNull [16]byte
@@ -168,15 +169,15 @@ func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error {
defer m.DecUsers(ctx)
// Figure out the bounds of the exec arg we are trying to read.
- var ar usermem.AddrRange
+ var ar hostarch.AddrRange
switch d.arg {
case cmdlineDataArg:
- ar = usermem.AddrRange{
+ ar = hostarch.AddrRange{
Start: m.ArgvStart(),
End: m.ArgvEnd(),
}
case environDataArg:
- ar = usermem.AddrRange{
+ ar = hostarch.AddrRange{
Start: m.EnvvStart(),
End: m.EnvvEnd(),
}
@@ -192,7 +193,7 @@ func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error {
// until Linux 4.9 (272ddc8b3735 "proc: don't use FOLL_FORCE for reading
// cmdline and environment").
writer := &bufferWriter{buf: buf}
- if n, err := m.CopyInTo(ctx, usermem.AddrRangeSeqOf(ar), writer, usermem.IOOpts{}); n == 0 || err != nil {
+ if n, err := m.CopyInTo(ctx, hostarch.AddrRangeSeqOf(ar), writer, usermem.IOOpts{}); n == 0 || err != nil {
// Nothing to copy or something went wrong.
return err
}
@@ -209,7 +210,7 @@ func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error {
}
// There is no NULL terminator in the string, return into envp.
- arEnvv := usermem.AddrRange{
+ arEnvv := hostarch.AddrRange{
Start: m.EnvvStart(),
End: m.EnvvEnd(),
}
@@ -218,11 +219,11 @@ func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error {
// https://elixir.bootlin.com/linux/v4.20/source/fs/proc/base.c#L208
// we'll return one page total between argv and envp because of the
// above page restrictions.
- if buf.Len() >= usermem.PageSize {
+ if buf.Len() >= hostarch.PageSize {
// Returned at least one page already, nothing else to add.
return nil
}
- remaining := usermem.PageSize - buf.Len()
+ remaining := hostarch.PageSize - buf.Len()
if int(arEnvv.Length()) > remaining {
end, ok := arEnvv.Start.AddLength(uint64(remaining))
if !ok {
@@ -230,7 +231,7 @@ func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error {
}
arEnvv.End = end
}
- if _, err := m.CopyInTo(ctx, usermem.AddrRangeSeqOf(arEnvv), writer, usermem.IOOpts{}); err != nil {
+ if _, err := m.CopyInTo(ctx, hostarch.AddrRangeSeqOf(arEnvv), writer, usermem.IOOpts{}); err != nil {
return err
}
@@ -323,7 +324,7 @@ func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset in
// the system page size, and the write must be performed at the start of
// the file ..." - user_namespaces(7)
srclen := src.NumBytes()
- if srclen >= usermem.PageSize || offset != 0 {
+ if srclen >= hostarch.PageSize || offset != 0 {
return 0, syserror.EINVAL
}
b := make([]byte, srclen)
@@ -481,7 +482,7 @@ func (fd *memFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64
defer m.DecUsers(ctx)
// Buffer the read data because of MM locks
buf := make([]byte, dst.NumBytes())
- n, readErr := m.CopyIn(ctx, usermem.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true})
+ n, readErr := m.CopyIn(ctx, hostarch.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true})
if n > 0 {
if _, err := dst.CopyOut(ctx, buf[:n]); err != nil {
return 0, syserror.EFAULT
@@ -613,7 +614,7 @@ func (s *taskStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
rss = mm.ResidentSetSize()
}
})
- fmt.Fprintf(buf, "%d %d ", vss, rss/usermem.PageSize)
+ fmt.Fprintf(buf, "%d %d ", vss, rss/hostarch.PageSize)
// rsslim.
fmt.Fprintf(buf, "%d ", s.task.ThreadGroup().Limits().Get(limits.Rss).Cur)
@@ -655,7 +656,7 @@ func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error {
}
})
- fmt.Fprintf(buf, "%d %d 0 0 0 0 0\n", vss/usermem.PageSize, rss/usermem.PageSize)
+ fmt.Fprintf(buf, "%d %d 0 0 0 0 0\n", vss/hostarch.PageSize, rss/hostarch.PageSize)
return nil
}
@@ -774,7 +775,7 @@ func (o *oomScoreAdj) Write(ctx context.Context, src usermem.IOSequence, offset
}
// Limit input size so as not to impact performance if input size is large.
- src = src.TakeFirst(usermem.PageSize - 1)
+ src = src.TakeFirst(hostarch.PageSize - 1)
var v int32
n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
@@ -1099,3 +1100,32 @@ func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) err
func (fd *namespaceFD) Release(ctx context.Context) {
fd.inode.DecRef(ctx)
}
+
+// taskCgroupData generates data for /proc/[pid]/cgroup.
+//
+// +stateify savable
+type taskCgroupData struct {
+ dynamicBytesFileSetAttr
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*taskCgroupData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *taskCgroupData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ // When a task is existing on Linux, a task's cgroup set is cleared and
+ // reset to the initial cgroup set, which is essentially the set of root
+ // cgroups. Because of this, the /proc/<pid>/cgroup file is always readable
+ // on Linux throughout a task's lifetime.
+ //
+ // The sentry removes tasks from cgroups during the exit process, but
+ // doesn't move them into an initial cgroup set, so partway through task
+ // exit this file show a task is in no cgroups, which is incorrect. Instead,
+ // once a task has left its cgroups, we return an error.
+ if d.task.ExitState() >= kernel.TaskExitInitiated {
+ return syserror.ESRCH
+ }
+
+ d.task.GenerateProcTaskCgroup(buf)
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go
index d4f6a5a9b..177cb828f 100644
--- a/pkg/sentry/fsimpl/proc/task_net.go
+++ b/pkg/sentry/fsimpl/proc/task_net.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/inet"
@@ -34,7 +35,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/usermem"
)
func (fs *filesystem) newTaskNetDir(ctx context.Context, task *kernel.Task) kernfs.Inode {
@@ -295,7 +295,7 @@ func networkToHost16(n uint16) uint16 {
// binary.BigEndian.Uint16() require a read of binary.BigEndian and an
// interface method call, defeating inlining.
buf := [2]byte{byte(n >> 8 & 0xff), byte(n & 0xff)}
- return usermem.ByteOrder.Uint16(buf[:])
+ return hostarch.ByteOrder.Uint16(buf[:])
}
func writeInetAddr(w io.Writer, family int, i linux.SockAddr) {
@@ -317,14 +317,14 @@ func writeInetAddr(w io.Writer, family int, i linux.SockAddr) {
// __be32 which is a typedef for an unsigned int, and is printed with
// %X. This means that for a little-endian machine, Linux prints the
// least-significant byte of the address first. To emulate this, we first
- // invert the byte order for the address using usermem.ByteOrder.Uint32,
+ // invert the byte order for the address using hostarch.ByteOrder.Uint32,
// which makes it have the equivalent encoding to a __be32 on a little
// endian machine. Note that this operation is a no-op on a big endian
// machine. Then similar to Linux, we format it with %X, which will print
// the most-significant byte of the __be32 address first, which is now
// actually the least-significant byte of the original address in
// linux.SockAddrInet.Addr on little endian machines, due to the conversion.
- addr := usermem.ByteOrder.Uint32(a.Addr[:])
+ addr := hostarch.ByteOrder.Uint32(a.Addr[:])
fmt.Fprintf(w, "%08X:%04X ", addr, port)
case linux.AF_INET6:
@@ -334,10 +334,10 @@ func writeInetAddr(w io.Writer, family int, i linux.SockAddr) {
}
port := networkToHost16(a.Port)
- addr0 := usermem.ByteOrder.Uint32(a.Addr[0:4])
- addr1 := usermem.ByteOrder.Uint32(a.Addr[4:8])
- addr2 := usermem.ByteOrder.Uint32(a.Addr[8:12])
- addr3 := usermem.ByteOrder.Uint32(a.Addr[12:16])
+ addr0 := hostarch.ByteOrder.Uint32(a.Addr[0:4])
+ addr1 := hostarch.ByteOrder.Uint32(a.Addr[4:8])
+ addr2 := hostarch.ByteOrder.Uint32(a.Addr[8:12])
+ addr3 := hostarch.ByteOrder.Uint32(a.Addr[12:16])
fmt.Fprintf(w, "%08X%08X%08X%08X:%04X ", addr0, addr1, addr2, addr3, port)
}
}
@@ -739,10 +739,10 @@ func (d *netRouteData) Generate(ctx context.Context, buf *bytes.Buffer) error {
)
if len(rt.GatewayAddr) == header.IPv4AddressSize {
flags |= linux.RTF_GATEWAY
- gw = usermem.ByteOrder.Uint32(rt.GatewayAddr)
+ gw = hostarch.ByteOrder.Uint32(rt.GatewayAddr)
}
if len(rt.DstAddr) == header.IPv4AddressSize {
- prefix = usermem.ByteOrder.Uint32(rt.DstAddr)
+ prefix = hostarch.ByteOrder.Uint32(rt.DstAddr)
}
l := fmt.Sprintf(
"%s\t%08X\t%08X\t%04X\t%d\t%d\t%d\t%08X\t%d\t%d\t%d",
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index fdc580610..7c7543f14 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -54,15 +54,15 @@ type tasksInode struct {
// '/proc/self' and '/proc/thread-self' have custom directory offsets in
// Linux. So handle them outside of OrderedChildren.
- // cgroupControllers is a map of controller name to directory in the
+ // fakeCgroupControllers is a map of controller name to directory in the
// cgroup hierarchy. These controllers are immutable and will be listed
// in /proc/pid/cgroup if not nil.
- cgroupControllers map[string]string
+ fakeCgroupControllers map[string]string
}
var _ kernfs.Inode = (*tasksInode)(nil)
-func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode {
+func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, fakeCgroupControllers map[string]string) *tasksInode {
root := auth.NewRootCredentials(pidns.UserNamespace())
contents := map[string]kernfs.Inode{
"cpuinfo": fs.newInode(ctx, root, 0444, newStaticFileSetStat(cpuInfoData(k))),
@@ -76,11 +76,16 @@ func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns
"uptime": fs.newInode(ctx, root, 0444, &uptimeData{}),
"version": fs.newInode(ctx, root, 0444, &versionData{}),
}
+ // If fakeCgroupControllers are provided, don't create a cgroupfs backed
+ // /proc/cgroup as it will not match the fake controllers.
+ if len(fakeCgroupControllers) == 0 {
+ contents["cgroups"] = fs.newInode(ctx, root, 0444, &cgroupsData{})
+ }
inode := &tasksInode{
- pidns: pidns,
- fs: fs,
- cgroupControllers: cgroupControllers,
+ pidns: pidns,
+ fs: fs,
+ fakeCgroupControllers: fakeCgroupControllers,
}
inode.InodeAttrs.Init(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
inode.InitRefs()
@@ -118,7 +123,7 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err
return nil, syserror.ENOENT
}
- return i.fs.newTaskInode(ctx, task, i.pidns, true, i.cgroupControllers)
+ return i.fs.newTaskInode(ctx, task, i.pidns, true, i.fakeCgroupControllers)
}
// IterDirents implements kernfs.inodeDirectory.IterDirents.
diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go
index 01b7a6678..e1a8b4409 100644
--- a/pkg/sentry/fsimpl/proc/tasks_files.go
+++ b/pkg/sentry/fsimpl/proc/tasks_files.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -28,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// +stateify savable
@@ -270,7 +270,7 @@ func (*meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
anon := snapshot.Anonymous + snapshot.Tmpfs
file := snapshot.PageCache + snapshot.Mapped
// We don't actually have active/inactive LRUs, so just make up numbers.
- activeFile := (file / 2) &^ (usermem.PageSize - 1)
+ activeFile := (file / 2) &^ (hostarch.PageSize - 1)
inactiveFile := file - activeFile
fmt.Fprintf(buf, "MemTotal: %8d kB\n", totalSize/1024)
@@ -384,3 +384,19 @@ func (d *filesystemsData) Generate(ctx context.Context, buf *bytes.Buffer) error
k.VFS().GenerateProcFilesystems(buf)
return nil
}
+
+// cgroupsData backs /proc/cgroups.
+//
+// +stateify savable
+type cgroupsData struct {
+ dynamicBytesFileSetAttr
+}
+
+var _ dynamicInode = (*cgroupsData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (*cgroupsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ r := kernel.KernelFromContext(ctx).CgroupRegistry()
+ r.GenerateProcCgroups(buf)
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index fb274b78e..9b14dd6b9 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -214,7 +215,7 @@ func (d *tcpSackData) Write(ctx context.Context, src usermem.IOSequence, offset
}
// Limit the amount of memory allocated.
- src = src.TakeFirst(usermem.PageSize - 1)
+ src = src.TakeFirst(hostarch.PageSize - 1)
var v int32
n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
@@ -262,7 +263,7 @@ func (d *tcpRecoveryData) Write(ctx context.Context, src usermem.IOSequence, off
}
// Limit the amount of memory allocated.
- src = src.TakeFirst(usermem.PageSize - 1)
+ src = src.TakeFirst(hostarch.PageSize - 1)
var v int32
n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
@@ -318,7 +319,7 @@ func (d *tcpMemData) Write(ctx context.Context, src usermem.IOSequence, offset i
defer d.mu.Unlock()
// Limit the amount of memory allocated.
- src = src.TakeFirst(usermem.PageSize - 1)
+ src = src.TakeFirst(hostarch.PageSize - 1)
size, err := d.readSizeLocked()
if err != nil {
return 0, err
@@ -406,7 +407,7 @@ func (ipf *ipForwarding) Write(ctx context.Context, src usermem.IOSequence, offs
}
// Limit input size so as not to impact performance if input size is large.
- src = src.TakeFirst(usermem.PageSize - 1)
+ src = src.TakeFirst(hostarch.PageSize - 1)
var v int32
n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
@@ -463,7 +464,7 @@ func (pr *portRange) Write(ctx context.Context, src usermem.IOSequence, offset i
// Limit input size so as not to impact performance if input size is
// large.
- src = src.TakeFirst(usermem.PageSize - 1)
+ src = src.TakeFirst(hostarch.PageSize - 1)
ports := make([]int32, 2)
n, err := usermem.CopyInt32StringsInVec(ctx, src.IO, src.Addrs, ports, src.Opts)
diff --git a/pkg/sentry/fsimpl/proc/yama.go b/pkg/sentry/fsimpl/proc/yama.go
index aebfe8944..e039ec45e 100644
--- a/pkg/sentry/fsimpl/proc/yama.go
+++ b/pkg/sentry/fsimpl/proc/yama.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -62,7 +63,7 @@ func (s *yamaPtraceScope) Write(ctx context.Context, src usermem.IOSequence, off
}
// Limit the amount of memory allocated.
- src = src.TakeFirst(usermem.PageSize - 1)
+ src = src.TakeFirst(hostarch.PageSize - 1)
var v int32
n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index 1d9280dae..14eb10dcd 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -122,11 +122,11 @@ func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs
}
func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs.Inode {
- // If kcov is available, set up /sys/kernel/debug/kcov. Technically, debugfs
- // should be mounted at debug/, but for our purposes, it is sufficient to
- // keep it in sys.
+ // Set up /sys/kernel/debug/kcov. Technically, debugfs should be
+ // mounted at debug/, but for our purposes, it is sufficient to keep it
+ // in sys.
var children map[string]kernfs.Inode
- if coverage.KcovAvailable() {
+ if coverage.KcovSupported() {
log.Debugf("Set up /sys/kernel/debug/kcov")
children = map[string]kernfs.Inode{
"debug": fs.newDir(ctx, creds, linux.FileMode(0700), map[string]kernfs.Inode{
diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD
index 400a97996..b3f9d1010 100644
--- a/pkg/sentry/fsimpl/testutil/BUILD
+++ b/pkg/sentry/fsimpl/testutil/BUILD
@@ -15,6 +15,7 @@ go_library(
"//pkg/context",
"//pkg/cpuid",
"//pkg/fspath",
+ "//pkg/hostarch",
"//pkg/memutil",
"//pkg/sentry/fsbridge",
"//pkg/sentry/fsimpl/tmpfs",
diff --git a/pkg/sentry/fsimpl/testutil/testutil.go b/pkg/sentry/fsimpl/testutil/testutil.go
index 1a8525b06..59e6f9c92 100644
--- a/pkg/sentry/fsimpl/testutil/testutil.go
+++ b/pkg/sentry/fsimpl/testutil/testutil.go
@@ -30,6 +30,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// System represents the context for a single test.
@@ -105,7 +107,7 @@ func (s *System) Destroy() {
// ReadToEnd reads the contents of fd until EOF to a string.
func (s *System) ReadToEnd(fd *vfs.FileDescription) (string, error) {
- buf := make([]byte, usermem.PageSize)
+ buf := make([]byte, hostarch.PageSize)
bufIOSeq := usermem.BytesIOSequence(buf)
opts := vfs.ReadOptions{}
diff --git a/pkg/sentry/fsimpl/timerfd/BUILD b/pkg/sentry/fsimpl/timerfd/BUILD
index fbb02a271..7ce7dc429 100644
--- a/pkg/sentry/fsimpl/timerfd/BUILD
+++ b/pkg/sentry/fsimpl/timerfd/BUILD
@@ -8,6 +8,7 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/sentry/kernel/time",
"//pkg/sentry/vfs",
"//pkg/syserror",
diff --git a/pkg/sentry/fsimpl/timerfd/timerfd.go b/pkg/sentry/fsimpl/timerfd/timerfd.go
index 64d33c3a8..cbb8b67c5 100644
--- a/pkg/sentry/fsimpl/timerfd/timerfd.go
+++ b/pkg/sentry/fsimpl/timerfd/timerfd.go
@@ -19,6 +19,7 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
@@ -72,7 +73,7 @@ func (tfd *TimerFileDescription) Read(ctx context.Context, dst usermem.IOSequenc
}
if val := atomic.SwapUint64(&tfd.val, 0); val != 0 {
var buf [sizeofUint64]byte
- usermem.ByteOrder.PutUint64(buf[:], val)
+ hostarch.ByteOrder.PutUint64(buf[:], val)
if _, err := dst.CopyOut(ctx, buf[:]); err != nil {
// Linux does not undo consuming the number of
// expirations even if writing to userspace fails.
diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD
index 09957c2b7..e21fddd7f 100644
--- a/pkg/sentry/fsimpl/tmpfs/BUILD
+++ b/pkg/sentry/fsimpl/tmpfs/BUILD
@@ -59,6 +59,7 @@ go_library(
"//pkg/amutex",
"//pkg/context",
"//pkg/fspath",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/refs",
"//pkg/refsvfs2",
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index a6d161882..c45bddff6 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -224,7 +225,7 @@ func (rf *regularFile) truncateLocked(newSize uint64) (bool, error) {
}
// AddMapping implements memmap.Mappable.AddMapping.
-func (rf *regularFile) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+func (rf *regularFile) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error {
rf.mapsMu.Lock()
defer rf.mapsMu.Unlock()
rf.dataMu.RLock()
@@ -240,7 +241,7 @@ func (rf *regularFile) AddMapping(ctx context.Context, ms memmap.MappingSpace, a
pagesBefore := rf.writableMappingPages
// ar is guaranteed to be page aligned per memmap.Mappable.
- rf.writableMappingPages += uint64(ar.Length() / usermem.PageSize)
+ rf.writableMappingPages += uint64(ar.Length() / hostarch.PageSize)
if rf.writableMappingPages < pagesBefore {
panic(fmt.Sprintf("Overflow while mapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages))
@@ -251,7 +252,7 @@ func (rf *regularFile) AddMapping(ctx context.Context, ms memmap.MappingSpace, a
}
// RemoveMapping implements memmap.Mappable.RemoveMapping.
-func (rf *regularFile) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+func (rf *regularFile) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) {
rf.mapsMu.Lock()
defer rf.mapsMu.Unlock()
@@ -261,7 +262,7 @@ func (rf *regularFile) RemoveMapping(ctx context.Context, ms memmap.MappingSpace
pagesBefore := rf.writableMappingPages
// ar is guaranteed to be page aligned per memmap.Mappable.
- rf.writableMappingPages -= uint64(ar.Length() / usermem.PageSize)
+ rf.writableMappingPages -= uint64(ar.Length() / hostarch.PageSize)
if rf.writableMappingPages > pagesBefore {
panic(fmt.Sprintf("Underflow while unmapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages))
@@ -270,12 +271,12 @@ func (rf *regularFile) RemoveMapping(ctx context.Context, ms memmap.MappingSpace
}
// CopyMapping implements memmap.Mappable.CopyMapping.
-func (rf *regularFile) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+func (rf *regularFile) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error {
return rf.AddMapping(ctx, ms, dstAR, offset, writable)
}
// Translate implements memmap.Mappable.Translate.
-func (rf *regularFile) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+func (rf *regularFile) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) {
rf.dataMu.Lock()
defer rf.dataMu.Unlock()
@@ -307,7 +308,7 @@ func (rf *regularFile) Translate(ctx context.Context, required, optional memmap.
Source: segMR,
File: rf.memFile,
Offset: seg.FileRangeOf(segMR).Start,
- Perms: usermem.AnyAccess,
+ Perms: hostarch.AnyAccess,
})
translatedEnd = segMR.End
}
@@ -487,6 +488,7 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (
// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
file := fd.inode().impl.(*regularFile)
+ opts.SentryOwnedContent = true
return vfs.GenericConfigureMMap(&fd.vfsfd, file, opts)
}
@@ -539,7 +541,7 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er
switch {
case seg.Ok():
// Get internal mappings.
- ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
+ ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), hostarch.Read)
if err != nil {
return done, err
}
@@ -608,7 +610,7 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64,
//
// See Linux, mm/filemap.c:generic_perform_write() and
// mm/shmem.c:shmem_write_begin().
- if pgstart := uint64(usermem.Addr(rw.file.size).RoundDown()); end > pgstart {
+ if pgstart := uint64(hostarch.Addr(rw.file.size).RoundDown()); end > pgstart {
end = pgstart
}
if end <= rw.off {
@@ -619,8 +621,8 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64,
// Page-aligned mr for when we need to allocate memory. RoundUp can't
// overflow since end is an int64.
- pgstartaddr := usermem.Addr(rw.off).RoundDown()
- pgendaddr, _ := usermem.Addr(end).RoundUp()
+ pgstartaddr := hostarch.Addr(rw.off).RoundDown()
+ pgendaddr, _ := hostarch.Addr(end).RoundUp()
pgMR := memmap.MappableRange{uint64(pgstartaddr), uint64(pgendaddr)}
var (
@@ -633,7 +635,7 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64,
switch {
case seg.Ok():
// Get internal mappings.
- ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Write)
+ ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), hostarch.Write)
if err != nil {
retErr = err
goto exitLoop
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index 8df81f589..9ae25ce9e 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -36,6 +36,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -43,7 +44,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs/memxattr"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// Name is the default filesystem name.
@@ -252,8 +252,8 @@ func (d *dentry) releaseChildrenLocked(ctx context.Context) {
// immutable
var globalStatfs = linux.Statfs{
Type: linux.TMPFS_MAGIC,
- BlockSize: usermem.PageSize,
- FragmentSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
+ FragmentSize: hostarch.PageSize,
NameLength: linux.NAME_MAX,
// tmpfs currently does not support configurable size limits. In Linux,
@@ -263,9 +263,9 @@ var globalStatfs = linux.Statfs{
// chosen to ensure that BlockSize * Blocks does not overflow int64 (which
// applications may also handle incorrectly).
// TODO(b/29637826): allow configuring a tmpfs size and enforce it.
- Blocks: math.MaxInt64 / usermem.PageSize,
- BlocksFree: math.MaxInt64 / usermem.PageSize,
- BlocksAvailable: math.MaxInt64 / usermem.PageSize,
+ Blocks: math.MaxInt64 / hostarch.PageSize,
+ BlocksFree: math.MaxInt64 / hostarch.PageSize,
+ BlocksAvailable: math.MaxInt64 / hostarch.PageSize,
}
// dentry implements vfs.DentryImpl.
@@ -485,7 +485,7 @@ func (i *inode) statTo(stat *linux.Statx) {
linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_SIZE |
linux.STATX_BLOCKS | linux.STATX_ATIME | linux.STATX_CTIME |
linux.STATX_MTIME
- stat.Blksize = usermem.PageSize
+ stat.Blksize = hostarch.PageSize
stat.Nlink = atomic.LoadUint32(&i.nlink)
stat.UID = atomic.LoadUint32(&i.uid)
stat.GID = atomic.LoadUint32(&i.gid)
diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD
index e265be0ee..d473a922d 100644
--- a/pkg/sentry/fsimpl/verity/BUILD
+++ b/pkg/sentry/fsimpl/verity/BUILD
@@ -14,13 +14,16 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/fspath",
+ "//pkg/hostarch",
"//pkg/marshal/primitive",
"//pkg/merkletree",
"//pkg/refsvfs2",
+ "//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fs/lock",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
"//pkg/sync",
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 6cb1a23e0..ca8090bbf 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -200,7 +200,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
// contains the expected xattrs. If the file or the xattr does not
// exist, it indicates unexpected modifications to the file system.
if err == syserror.ENOENT || err == syserror.ENODATA {
- return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err))
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err))
}
if err != nil {
return nil, err
@@ -209,7 +209,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
// unexpected modifications to the file system.
offset, err := strconv.Atoi(off)
if err != nil {
- return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err))
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err))
}
// Open parent Merkle tree file to read and verify child's hash.
@@ -223,12 +223,14 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
// The parent Merkle tree file should have been created. If it's
// missing, it indicates an unexpected modification to the file system.
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err))
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err))
}
if err != nil {
return nil, err
}
+ defer parentMerkleFD.DecRef(ctx)
+
// dataSize is the size of raw data for the Merkle tree. For a file,
// dataSize is the size of the whole file. For a directory, dataSize is
// the size of all its children's hashes.
@@ -241,7 +243,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
// contains the expected xattrs. If the file or the xattr does not
// exist, it indicates unexpected modifications to the file system.
if err == syserror.ENOENT || err == syserror.ENODATA {
- return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err))
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err))
}
if err != nil {
return nil, err
@@ -251,7 +253,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
// unexpected modifications to the file system.
parentSize, err := strconv.Atoi(dataSize)
if err != nil {
- return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
}
fdReader := FileReadWriteSeeker{
@@ -264,7 +266,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
Start: parent.lowerVD,
}, &vfs.StatOptions{})
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err))
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err))
}
if err != nil {
return nil, err
@@ -294,7 +296,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
})
parent.hashMu.RUnlock()
if err != nil && err != io.EOF {
- return nil, alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err))
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err))
}
// Cache child hash when it's verified the first time.
@@ -331,19 +333,21 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry
Flags: linux.O_RDONLY,
})
if err == syserror.ENOENT {
- return alertIntegrityViolation(fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err))
+ return fs.alertIntegrityViolation(fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err))
}
if err != nil {
return err
}
+ defer fd.DecRef(ctx)
+
merkleSize, err := fd.GetXattr(ctx, &vfs.GetXattrOptions{
Name: merkleSizeXattr,
Size: sizeOfStringInt32,
})
if err == syserror.ENODATA {
- return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err))
+ return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err))
}
if err != nil {
return err
@@ -351,7 +355,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry
size, err := strconv.Atoi(merkleSize)
if err != nil {
- return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
+ return fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
}
if d.isDir() && len(d.childrenNames) == 0 {
@@ -361,14 +365,14 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry
})
if err == syserror.ENODATA {
- return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenOffsetXattr, childPath, err))
+ return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenOffsetXattr, childPath, err))
}
if err != nil {
return err
}
childrenOffset, err := strconv.Atoi(childrenOffString)
if err != nil {
- return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err))
+ return fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err))
}
childrenSizeString, err := fd.GetXattr(ctx, &vfs.GetXattrOptions{
@@ -377,23 +381,23 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry
})
if err == syserror.ENODATA {
- return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenSizeXattr, childPath, err))
+ return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenSizeXattr, childPath, err))
}
if err != nil {
return err
}
childrenSize, err := strconv.Atoi(childrenSizeString)
if err != nil {
- return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err))
+ return fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err))
}
childrenNames := make([]byte, childrenSize)
if _, err := fd.PRead(ctx, usermem.BytesIOSequence(childrenNames), int64(childrenOffset), vfs.ReadOptions{}); err != nil {
- return alertIntegrityViolation(fmt.Sprintf("Failed to read children map for %s: %v", childPath, err))
+ return fs.alertIntegrityViolation(fmt.Sprintf("Failed to read children map for %s: %v", childPath, err))
}
if err := json.Unmarshal(childrenNames, &d.childrenNames); err != nil {
- return alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames of %s: %v", childPath, err))
+ return fs.alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames of %s: %v", childPath, err))
}
}
@@ -438,7 +442,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry
}
if _, err := merkletree.Verify(params); err != nil && err != io.EOF {
- return alertIntegrityViolation(fmt.Sprintf("Verification stat for %s failed: %v", childPath, err))
+ return fs.alertIntegrityViolation(fmt.Sprintf("Verification stat for %s failed: %v", childPath, err))
}
d.mode = uint32(stat.Mode)
d.uid = stat.UID
@@ -471,7 +475,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
// The file was previously accessed. If the
// file does not exist now, it indicates an
// unexpected modification to the file system.
- return nil, alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", path))
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", path))
}
if err != nil {
return nil, err
@@ -483,7 +487,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
// does not exist now, it indicates an unexpected
// modification to the file system.
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", path))
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", path))
}
if err != nil {
return nil, err
@@ -553,8 +557,8 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
}
childVD, err := parent.getLowerAt(ctx, vfsObj, name)
- if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(fmt.Sprintf("file %s expected but not found", parentPath+"/"+name))
+ if parent.verityEnabled() && err == syserror.ENOENT {
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("file %s expected but not found", parentPath+"/"+name))
}
if err != nil {
return nil, err
@@ -565,30 +569,31 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
defer childVD.DecRef(ctx)
childMerkleVD, err := parent.getLowerAt(ctx, vfsObj, merklePrefix+name)
- if err == syserror.ENOENT {
- if !fs.allowRuntimeEnable {
- return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath+"/"+name))
- }
- childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
- Root: parent.lowerVD,
- Start: parent.lowerVD,
- Path: fspath.Parse(merklePrefix + name),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR | linux.O_CREAT,
- Mode: 0644,
- })
- if err != nil {
- return nil, err
- }
- childMerkleFD.DecRef(ctx)
- childMerkleVD, err = parent.getLowerAt(ctx, vfsObj, merklePrefix+name)
- if err != nil {
+ if err != nil {
+ if err == syserror.ENOENT {
+ if parent.verityEnabled() {
+ return nil, fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath+"/"+name))
+ }
+ childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: parent.lowerVD,
+ Start: parent.lowerVD,
+ Path: fspath.Parse(merklePrefix + name),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR | linux.O_CREAT,
+ Mode: 0644,
+ })
+ if err != nil {
+ return nil, err
+ }
+ childMerkleFD.DecRef(ctx)
+ childMerkleVD, err = parent.getLowerAt(ctx, vfsObj, merklePrefix+name)
+ if err != nil {
+ return nil, err
+ }
+ } else {
return nil, err
}
}
- if err != nil && err != syserror.ENOENT {
- return nil, err
- }
// Clear the Merkle tree file if they are to be generated at runtime.
// TODO(b/182315468): Optimize the Merkle tree generate process to
@@ -632,8 +637,6 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
childVD.IncRef()
childMerkleVD.IncRef()
- parent.IncRef()
- child.parent = parent
child.name = name
child.mode = uint32(stat.Mode)
@@ -657,6 +660,9 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
}
}
+ parent.IncRef()
+ child.parent = parent
+
return child, nil
}
@@ -855,7 +861,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
// missing, it indicates an unexpected modification to the file system.
if err != nil {
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(fmt.Sprintf("File %s expected but not found", path))
+ return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("File %s expected but not found", path))
}
return nil, err
}
@@ -878,7 +884,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
// the file system.
if err != nil {
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path))
+ return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path))
}
return nil, err
}
@@ -903,7 +909,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
})
if err != nil {
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path))
+ return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path))
}
return nil, err
}
@@ -921,7 +927,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
if err != nil {
if err == syserror.ENOENT {
parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD)
- return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath))
+ return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath))
}
return nil, err
}
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index 0d9b0ee2c..458c7fcb6 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -34,6 +34,8 @@
package verity
import (
+ "bytes"
+ "encoding/hex"
"encoding/json"
"fmt"
"math"
@@ -44,13 +46,16 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/merkletree"
"gvisor.dev/gvisor/pkg/refsvfs2"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -93,14 +98,18 @@ const (
)
var (
- // action specifies the action towards detected violation.
- action ViolationAction
-
// verityMu synchronizes concurrent operations that enable verity and perform
// verification checks.
verityMu sync.RWMutex
)
+// Mount option names for verityfs.
+const (
+ moptLowerPath = "lower_path"
+ moptRootHash = "root_hash"
+ moptRootName = "root_name"
+)
+
// HashAlgorithm is a type specifying the algorithm used to hash the file
// content.
type HashAlgorithm int
@@ -167,6 +176,12 @@ type filesystem struct {
// system.
alg HashAlgorithm
+ // action specifies the action towards detected violation.
+ action ViolationAction
+
+ // opts is the string mount options passed to opts.Data.
+ opts string
+
// renameMu synchronizes renaming with non-renaming operations in order
// to ensure consistent lock ordering between dentry.dirMu in different
// dentries.
@@ -189,9 +204,6 @@ type filesystem struct {
//
// +stateify savable
type InternalFilesystemOptions struct {
- // RootMerkleFileName is the name of the verity root Merkle tree file.
- RootMerkleFileName string
-
// LowerName is the name of the filesystem wrapped by verity fs.
LowerName string
@@ -199,9 +211,6 @@ type InternalFilesystemOptions struct {
// system.
Alg HashAlgorithm
- // RootHash is the root hash of the overall verity file system.
- RootHash []byte
-
// AllowRuntimeEnable specifies whether the verity file system allows
// enabling verification for files (i.e. building Merkle trees) during
// runtime.
@@ -226,8 +235,8 @@ func (FilesystemType) Release(ctx context.Context) {}
// alertIntegrityViolation alerts a violation of integrity, which usually means
// unexpected modification to the file system is detected. In ErrorOnViolation
// mode, it returns EIO, otherwise it panic.
-func alertIntegrityViolation(msg string) error {
- if action == ErrorOnViolation {
+func (fs *filesystem) alertIntegrityViolation(msg string) error {
+ if fs.action == ErrorOnViolation {
return syserror.EIO
}
panic(msg)
@@ -235,28 +244,99 @@ func alertIntegrityViolation(msg string) error {
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ var rootHash []byte
+ if encodedRootHash, ok := mopts[moptRootHash]; ok {
+ delete(mopts, moptRootHash)
+ hash, err := hex.DecodeString(encodedRootHash)
+ if err != nil {
+ ctx.Warningf("verity.FilesystemType.GetFilesystem: Failed to decode root hash: %v", err)
+ return nil, nil, syserror.EINVAL
+ }
+ rootHash = hash
+ }
+ var lowerPathname string
+ if path, ok := mopts[moptLowerPath]; ok {
+ delete(mopts, moptLowerPath)
+ lowerPathname = path
+ }
+ rootName := "root"
+ if root, ok := mopts[moptRootName]; ok {
+ delete(mopts, moptRootName)
+ rootName = root
+ }
+
+ // Check for unparsed options.
+ if len(mopts) != 0 {
+ ctx.Warningf("verity.FilesystemType.GetFilesystem: unknown options: %v", mopts)
+ return nil, nil, syserror.EINVAL
+ }
+
+ // Handle internal options.
iopts, ok := opts.InternalData.(InternalFilesystemOptions)
- if !ok {
+ if len(lowerPathname) == 0 && !ok {
ctx.Warningf("verity.FilesystemType.GetFilesystem: missing verity configs")
return nil, nil, syserror.EINVAL
}
- action = iopts.Action
-
- // Mount the lower file system. The lower file system is wrapped inside
- // verity, and should not be exposed or connected.
- mopts := &vfs.MountOptions{
- GetFilesystemOptions: iopts.LowerGetFSOptions,
- InternalMount: true,
+ if len(lowerPathname) != 0 {
+ if ok {
+ ctx.Warningf("verity.FilesystemType.GetFilesystem: unexpected verity configs with specified lower path")
+ return nil, nil, syserror.EINVAL
+ }
+ iopts = InternalFilesystemOptions{
+ AllowRuntimeEnable: len(rootHash) == 0,
+ Action: ErrorOnViolation,
+ }
}
- mnt, err := vfsObj.MountDisconnected(ctx, creds, "", iopts.LowerName, mopts)
- if err != nil {
- return nil, nil, err
+
+ var lowerMount *vfs.Mount
+ var mountedLowerVD vfs.VirtualDentry
+ // Use an existing mount if lowerPath is provided.
+ if len(lowerPathname) != 0 {
+ vfsroot := vfs.RootFromContext(ctx)
+ if vfsroot.Ok() {
+ defer vfsroot.DecRef(ctx)
+ }
+ lowerPath := fspath.Parse(lowerPathname)
+ if !lowerPath.Absolute {
+ ctx.Infof("verity.FilesystemType.GetFilesystem: lower_path %q must be absolute", lowerPathname)
+ return nil, nil, syserror.EINVAL
+ }
+ var err error
+ mountedLowerVD, err = vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{
+ Root: vfsroot,
+ Start: vfsroot,
+ Path: lowerPath,
+ FollowFinalSymlink: true,
+ }, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ ctx.Infof("verity.FilesystemType.GetFilesystem: failed to resolve lower_path %q: %v", lowerPathname, err)
+ return nil, nil, err
+ }
+ lowerMount = mountedLowerVD.Mount()
+ defer mountedLowerVD.DecRef(ctx)
+ } else {
+ // Mount the lower file system. The lower file system is wrapped inside
+ // verity, and should not be exposed or connected.
+ mountOpts := &vfs.MountOptions{
+ GetFilesystemOptions: iopts.LowerGetFSOptions,
+ InternalMount: true,
+ }
+ mnt, err := vfsObj.MountDisconnected(ctx, creds, "", iopts.LowerName, mountOpts)
+ if err != nil {
+ return nil, nil, err
+ }
+ lowerMount = mnt
}
fs := &filesystem{
creds: creds.Fork(),
alg: iopts.Alg,
- lowerMount: mnt,
+ lowerMount: lowerMount,
+ action: iopts.Action,
+ opts: opts.Data,
allowRuntimeEnable: iopts.AllowRuntimeEnable,
}
fs.vfsfs.Init(vfsObj, &fstype, fs)
@@ -264,11 +344,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// Construct the root dentry.
d := fs.newDentry()
d.refs = 1
- lowerVD := vfs.MakeVirtualDentry(mnt, mnt.Root())
+ lowerVD := vfs.MakeVirtualDentry(lowerMount, lowerMount.Root())
lowerVD.IncRef()
d.lowerVD = lowerVD
- rootMerkleName := merkleRootPrefix + iopts.RootMerkleFileName
+ rootMerkleName := merkleRootPrefix + rootName
lowerMerkleVD, err := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{
Root: lowerVD,
@@ -309,7 +389,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// the root Merkle file, or it's never generated.
fs.vfsfs.DecRef(ctx)
d.DecRef(ctx)
- return nil, nil, alertIntegrityViolation("Failed to find root Merkle file")
+ return nil, nil, fs.alertIntegrityViolation("Failed to find root Merkle file")
}
// Clear the Merkle tree file if they are to be generated at runtime.
@@ -348,9 +428,15 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
d.mode = uint32(stat.Mode)
d.uid = stat.UID
d.gid = stat.GID
- d.hash = make([]byte, len(iopts.RootHash))
d.childrenNames = make(map[string]struct{})
+ d.hashMu.Lock()
+ d.hash = make([]byte, len(rootHash))
+ copy(d.hash, rootHash)
+ d.hashMu.Unlock()
+
+ fs.rootDentry = d
+
if !d.isDir() {
ctx.Warningf("verity root must be a directory")
return nil, nil, syserror.EINVAL
@@ -366,7 +452,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
Size: sizeOfStringInt32,
})
if err == syserror.ENOENT || err == syserror.ENODATA {
- return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenOffsetXattr, err))
+ return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenOffsetXattr, err))
}
if err != nil {
return nil, nil, err
@@ -374,7 +460,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
off, err := strconv.Atoi(offString)
if err != nil {
- return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err))
+ return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err))
}
sizeString, err := vfsObj.GetXattrAt(ctx, creds, &vfs.PathOperation{
@@ -385,14 +471,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
Size: sizeOfStringInt32,
})
if err == syserror.ENOENT || err == syserror.ENODATA {
- return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenSizeXattr, err))
+ return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenSizeXattr, err))
}
if err != nil {
return nil, nil, err
}
size, err := strconv.Atoi(sizeString)
if err != nil {
- return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err))
+ return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err))
}
lowerMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
@@ -402,19 +488,21 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
Flags: linux.O_RDONLY,
})
if err == syserror.ENOENT {
- return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to open root Merkle file: %v", err))
+ return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to open root Merkle file: %v", err))
}
if err != nil {
return nil, nil, err
}
+ defer lowerMerkleFD.DecRef(ctx)
+
childrenNames := make([]byte, size)
if _, err := lowerMerkleFD.PRead(ctx, usermem.BytesIOSequence(childrenNames), int64(off), vfs.ReadOptions{}); err != nil {
- return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to read root children map: %v", err))
+ return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to read root children map: %v", err))
}
if err := json.Unmarshal(childrenNames, &d.childrenNames); err != nil {
- return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames: %v", err))
+ return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames: %v", err))
}
if err := fs.verifyStatAndChildrenLocked(ctx, d, stat); err != nil {
@@ -422,13 +510,8 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
}
- d.hashMu.Lock()
- copy(d.hash, iopts.RootHash)
- d.hashMu.Unlock()
d.vfsd.Init(d)
- fs.rootDentry = d
-
return &fs.vfsfs, &d.vfsd, nil
}
@@ -439,7 +522,7 @@ func (fs *filesystem) Release(ctx context.Context) {
// MountOptions implements vfs.FilesystemImpl.MountOptions.
func (fs *filesystem) MountOptions() string {
- return ""
+ return fs.opts
}
// dentry implements vfs.DentryImpl.
@@ -720,6 +803,10 @@ type fileDescription struct {
// underlying file system.
lowerFD *vfs.FileDescription
+ // lowerMappable is the memmap.Mappable corresponding to this file in the
+ // underlying file system.
+ lowerMappable memmap.Mappable
+
// merkleReader is the read-only FileDescription corresponding to the
// Merkle tree file in the underlying file system.
merkleReader *vfs.FileDescription
@@ -792,7 +879,7 @@ func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCa
// Verify that the child is expected.
if dirent.Name != "." && dirent.Name != ".." {
if _, ok := fd.d.childrenNames[dirent.Name]; !ok {
- return alertIntegrityViolation(fmt.Sprintf("Unexpected children %s", dirent.Name))
+ return fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Unexpected children %s", dirent.Name))
}
}
}
@@ -806,7 +893,7 @@ func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCa
// The result should contain all children plus "." and "..".
if fd.d.verityEnabled() && len(ds) != len(fd.d.childrenNames)+2 {
- return alertIntegrityViolation(fmt.Sprintf("Unexpected children number %d", len(ds)))
+ return fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Unexpected children number %d", len(ds)))
}
for fd.off < int64(len(ds)) {
@@ -978,7 +1065,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) {
// or directory other than the root, the parent Merkle tree file should
// have also been initialized.
if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || (fd.parentMerkleWriter == nil && fd.d != fd.d.fs.rootDentry) {
- return 0, alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds")
+ return 0, fd.d.fs.alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds")
}
hash, dataSize, err := fd.generateMerkleLocked(ctx)
@@ -1033,7 +1120,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) {
}
// measureVerity returns the hash of fd, saved in verityDigest.
-func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest usermem.Addr) (uintptr, error) {
+func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest hostarch.Addr) (uintptr, error) {
t := kernel.TaskFromContext(ctx)
if t == nil {
return 0, syserror.EINVAL
@@ -1051,7 +1138,7 @@ func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest userm
if fd.d.fs.allowRuntimeEnable {
return 0, syserror.ENODATA
}
- return 0, alertIntegrityViolation("Ioctl measureVerity: no hash found")
+ return 0, fd.d.fs.alertIntegrityViolation("Ioctl measureVerity: no hash found")
}
// The first part of VerityDigest is the metadata.
@@ -1072,11 +1159,11 @@ func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest userm
}
// Now copy the root hash bytes to the memory after metadata.
- _, err := t.CopyOutBytes(usermem.Addr(uintptr(verityDigest)+linux.SizeOfDigestMetadata), fd.d.hash)
+ _, err := t.CopyOutBytes(hostarch.Addr(uintptr(verityDigest)+linux.SizeOfDigestMetadata), fd.d.hash)
return 0, err
}
-func (fd *fileDescription) verityFlags(ctx context.Context, flags usermem.Addr) (uintptr, error) {
+func (fd *fileDescription) verityFlags(ctx context.Context, flags hostarch.Addr) (uintptr, error) {
f := int32(0)
fd.d.hashMu.RLock()
@@ -1141,7 +1228,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
// contains the expected xattrs. If the xattr does not exist, it
// indicates unexpected modifications to the file system.
if err == syserror.ENODATA {
- return 0, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err))
+ return 0, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err))
}
if err != nil {
return 0, err
@@ -1151,7 +1238,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
// unexpected modifications to the file system.
size, err := strconv.Atoi(dataSize)
if err != nil {
- return 0, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err))
+ return 0, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err))
}
dataReader := FileReadWriteSeeker{
@@ -1184,7 +1271,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
})
fd.d.hashMu.RUnlock()
if err != nil {
- return 0, alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err))
+ return 0, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err))
}
return n, err
}
@@ -1199,6 +1286,24 @@ func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, op
return 0, syserror.EROFS
}
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *fileDescription) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ if err := fd.lowerFD.ConfigureMMap(ctx, opts); err != nil {
+ return err
+ }
+ fd.lowerMappable = opts.Mappable
+ if opts.MappingIdentity != nil {
+ opts.MappingIdentity.DecRef(ctx)
+ opts.MappingIdentity = nil
+ }
+
+ // Check if mmap is allowed on the lower filesystem.
+ if !opts.SentryOwnedContent {
+ return syserror.ENODEV
+ }
+ return vfs.GenericConfigureMMap(&fd.vfsfd, fd, opts)
+}
+
// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error {
return fd.lowerFD.LockBSD(ctx, ownerPID, t, block)
@@ -1224,6 +1329,115 @@ func (fd *fileDescription) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t
return fd.lowerFD.TestPOSIX(ctx, uid, t, r)
}
+// Translate implements memmap.Mappable.Translate.
+func (fd *fileDescription) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) {
+ ts, err := fd.lowerMappable.Translate(ctx, required, optional, at)
+ if err != nil {
+ return ts, err
+ }
+
+ // dataSize is the size of the whole file.
+ dataSize, err := fd.merkleReader.GetXattr(ctx, &vfs.GetXattrOptions{
+ Name: merkleSizeXattr,
+ Size: sizeOfStringInt32,
+ })
+
+ // The Merkle tree file for the child should have been created and
+ // contains the expected xattrs. If the xattr does not exist, it
+ // indicates unexpected modifications to the file system.
+ if err == syserror.ENODATA {
+ return ts, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err))
+ }
+ if err != nil {
+ return ts, err
+ }
+
+ // The dataSize xattr should be an integer. If it's not, it indicates
+ // unexpected modifications to the file system.
+ size, err := strconv.Atoi(dataSize)
+ if err != nil {
+ return ts, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err))
+ }
+
+ merkleReader := FileReadWriteSeeker{
+ FD: fd.merkleReader,
+ Ctx: ctx,
+ }
+
+ for _, t := range ts {
+ // Content integrity relies on sentry owning the backing data. MapInternal is guaranteed
+ // to fetch sentry owned memory because we disallow verity mmaps otherwise.
+ ims, err := t.File.MapInternal(memmap.FileRange{t.Offset, t.Offset + t.Source.Length()}, hostarch.Read)
+ if err != nil {
+ return nil, err
+ }
+ dataReader := mmapReadSeeker{ims, t.Source.Start}
+ var buf bytes.Buffer
+ _, err = merkletree.Verify(&merkletree.VerifyParams{
+ Out: &buf,
+ File: &dataReader,
+ Tree: &merkleReader,
+ Size: int64(size),
+ Name: fd.d.name,
+ Mode: fd.d.mode,
+ UID: fd.d.uid,
+ GID: fd.d.gid,
+ HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(),
+ ReadOffset: int64(t.Source.Start),
+ ReadSize: int64(t.Source.Length()),
+ Expected: fd.d.hash,
+ DataAndTreeInSameFile: false,
+ })
+ if err != nil {
+ return ts, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err))
+ }
+ }
+ return ts, err
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (fd *fileDescription) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error {
+ return fd.lowerMappable.AddMapping(ctx, ms, ar, offset, writable)
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (fd *fileDescription) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) {
+ fd.lowerMappable.RemoveMapping(ctx, ms, ar, offset, writable)
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (fd *fileDescription) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error {
+ return fd.lowerMappable.CopyMapping(ctx, ms, srcAR, dstAR, offset, writable)
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (fd *fileDescription) InvalidateUnsavable(context.Context) error {
+ return nil
+}
+
+// mmapReadSeeker is a helper struct used by fileDescription.Translate to pass
+// a safemem.BlockSeq pointing to the mapped region as io.ReaderAt.
+type mmapReadSeeker struct {
+ safemem.BlockSeq
+ Offset uint64
+}
+
+// ReadAt implements io.ReaderAt.ReadAt. off is the offset into the mapped file.
+func (r *mmapReadSeeker) ReadAt(p []byte, off int64) (int, error) {
+ bs := r.BlockSeq
+ // Adjust the offset into the mapped file to get the offset into the internally
+ // mapped region.
+ readOffset := off - int64(r.Offset)
+ if readOffset < 0 {
+ return 0, syserror.EINVAL
+ }
+ bs.DropFirst64(uint64(readOffset))
+ view := bs.TakeFirst64(uint64(len(p)))
+ dst := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(p))
+ n, err := safemem.CopySeq(dst, view)
+ return int(n), err
+}
+
// FileReadWriteSeeker is a helper struct to pass a vfs.FileDescription as
// io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc.
type FileReadWriteSeeker struct {
diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go
index 57bd65202..5c78a0019 100644
--- a/pkg/sentry/fsimpl/verity/verity_test.go
+++ b/pkg/sentry/fsimpl/verity/verity_test.go
@@ -89,10 +89,11 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem,
AllowUserMount: true,
})
+ data := "root_name=" + rootMerkleFilename
mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{
GetFilesystemOptions: vfs.GetFilesystemOptions{
+ Data: data,
InternalData: InternalFilesystemOptions{
- RootMerkleFileName: rootMerkleFilename,
LowerName: "tmpfs",
Alg: hashAlg,
AllowRuntimeEnable: true,
diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD
index 300b7ccce..66fa1ad40 100644
--- a/pkg/sentry/hostmm/BUILD
+++ b/pkg/sentry/hostmm/BUILD
@@ -13,8 +13,8 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/fd",
+ "//pkg/hostarch",
"//pkg/log",
- "//pkg/usermem",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/hostmm/hostmm.go b/pkg/sentry/hostmm/hostmm.go
index c47b96b54..285ea9050 100644
--- a/pkg/sentry/hostmm/hostmm.go
+++ b/pkg/sentry/hostmm/hostmm.go
@@ -23,8 +23,8 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/usermem"
)
// NotifyCurrentMemcgPressureCallback requests that f is called whenever the
@@ -88,7 +88,7 @@ func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error)
if n != sizeofUint64 {
panic(fmt.Sprintf("short read from memory pressure level eventfd: got %d bytes, wanted %d", n, sizeofUint64))
}
- val := usermem.ByteOrder.Uint64(buf[:])
+ val := hostarch.ByteOrder.Uint64(buf[:])
if val >= stopVal {
// Assume this was due to the notifier's "destructor" (the
// function returned by NotifyCurrentMemcgPressureCallback
@@ -103,7 +103,7 @@ func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error)
return func() {
rw := fd.NewReadWriter(eventFD.FD())
var buf [sizeofUint64]byte
- usermem.ByteOrder.PutUint64(buf[:], stopVal)
+ hostarch.ByteOrder.PutUint64(buf[:], stopVal)
for {
n, err := rw.Write(buf[:])
if err != nil {
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index c53e3e720..a1ec6daab 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -141,6 +141,7 @@ go_library(
srcs = [
"abstract_socket_namespace.go",
"aio.go",
+ "cgroup.go",
"context.go",
"fd_table.go",
"fd_table_refs.go",
@@ -178,6 +179,7 @@ go_library(
"task.go",
"task_acct.go",
"task_block.go",
+ "task_cgroup.go",
"task_clone.go",
"task_context.go",
"task_exec.go",
@@ -226,6 +228,7 @@ go_library(
"//pkg/eventchannel",
"//pkg/fspath",
"//pkg/goid",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/marshal",
"//pkg/marshal/primitive",
@@ -240,6 +243,7 @@ go_library(
"//pkg/sentry/fs/lock",
"//pkg/sentry/fs/timerfd",
"//pkg/sentry/fsbridge",
+ "//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/fsimpl/pipefs",
"//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/fsimpl/timerfd",
@@ -294,6 +298,7 @@ go_test(
deps = [
"//pkg/abi",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/sentry/arch",
"//pkg/sentry/contexttest",
"//pkg/sentry/fs",
@@ -305,6 +310,5 @@ go_test(
"//pkg/sentry/usage",
"//pkg/sync",
"//pkg/syserror",
- "//pkg/usermem",
],
)
diff --git a/pkg/sentry/kernel/cgroup.go b/pkg/sentry/kernel/cgroup.go
new file mode 100644
index 000000000..1f1c63f37
--- /dev/null
+++ b/pkg/sentry/kernel/cgroup.go
@@ -0,0 +1,281 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// InvalidCgroupHierarchyID indicates an uninitialized hierarchy ID.
+const InvalidCgroupHierarchyID uint32 = 0
+
+// CgroupControllerType is the name of a cgroup controller.
+type CgroupControllerType string
+
+// CgroupController is the common interface to cgroup controllers available to
+// the entire sentry. The controllers themselves are defined by cgroupfs.
+//
+// Callers of this interface are often unable access synchronization needed to
+// ensure returned values remain valid. Some of values returned from this
+// interface are thus snapshots in time, and may become stale. This is ok for
+// many callers like procfs.
+type CgroupController interface {
+ // Returns the type of this cgroup controller (ex "memory", "cpu"). Returned
+ // value is valid for the lifetime of the controller.
+ Type() CgroupControllerType
+
+ // Hierarchy returns the ID of the hierarchy this cgroup controller is
+ // attached to. Returned value is valid for the lifetime of the controller.
+ HierarchyID() uint32
+
+ // Filesystem returns the filesystem this controller is attached to.
+ // Returned value is valid for the lifetime of the controller.
+ Filesystem() *vfs.Filesystem
+
+ // RootCgroup returns the root cgroup for this controller. Returned value is
+ // valid for the lifetime of the controller.
+ RootCgroup() Cgroup
+
+ // NumCgroups returns the number of cgroups managed by this controller.
+ // Returned value is a snapshot in time.
+ NumCgroups() uint64
+
+ // Enabled returns whether this controller is enabled. Returned value is a
+ // snapshot in time.
+ Enabled() bool
+}
+
+// Cgroup represents a named pointer to a cgroup in cgroupfs. When a task enters
+// a cgroup, it holds a reference on the underlying dentry pointing to the
+// cgroup.
+//
+// +stateify savable
+type Cgroup struct {
+ *kernfs.Dentry
+ CgroupImpl
+}
+
+func (c *Cgroup) decRef() {
+ c.Dentry.DecRef(context.Background())
+}
+
+// Path returns the absolute path of c, relative to its hierarchy root.
+func (c *Cgroup) Path() string {
+ return c.FSLocalPath()
+}
+
+// HierarchyID returns the id of the hierarchy that contains this cgroup.
+func (c *Cgroup) HierarchyID() uint32 {
+ // Note: a cgroup is guaranteed to have at least one controller.
+ return c.Controllers()[0].HierarchyID()
+}
+
+// CgroupImpl is the common interface to cgroups.
+type CgroupImpl interface {
+ Controllers() []CgroupController
+ Enter(t *Task)
+ Leave(t *Task)
+}
+
+// hierarchy represents a cgroupfs filesystem instance, with a unique set of
+// controllers attached to it. Multiple cgroupfs mounts may reference the same
+// hierarchy.
+//
+// +stateify savable
+type hierarchy struct {
+ id uint32
+ // These are a subset of the controllers in CgroupRegistry.controllers,
+ // grouped here by hierarchy for conveninent lookup.
+ controllers map[CgroupControllerType]CgroupController
+ // fs is not owned by hierarchy. The FS is responsible for unregistering the
+ // hierarchy on destruction, which removes this association.
+ fs *vfs.Filesystem
+}
+
+func (h *hierarchy) match(ctypes []CgroupControllerType) bool {
+ if len(ctypes) != len(h.controllers) {
+ return false
+ }
+ for _, ty := range ctypes {
+ if _, ok := h.controllers[ty]; !ok {
+ return false
+ }
+ }
+ return true
+}
+
+// CgroupRegistry tracks the active set of cgroup controllers on the system.
+//
+// +stateify savable
+type CgroupRegistry struct {
+ // lastHierarchyID is the id of the last allocated cgroup hierarchy. Valid
+ // ids are from 1 to math.MaxUint32. Must be accessed through atomic ops.
+ //
+ lastHierarchyID uint32
+
+ mu sync.Mutex `state:"nosave"`
+
+ // controllers is the set of currently known cgroup controllers on the
+ // system. Protected by mu.
+ //
+ // +checklocks:mu
+ controllers map[CgroupControllerType]CgroupController
+
+ // hierarchies is the active set of cgroup hierarchies. Protected by mu.
+ //
+ // +checklocks:mu
+ hierarchies map[uint32]hierarchy
+}
+
+func newCgroupRegistry() *CgroupRegistry {
+ return &CgroupRegistry{
+ controllers: make(map[CgroupControllerType]CgroupController),
+ hierarchies: make(map[uint32]hierarchy),
+ }
+}
+
+// nextHierarchyID returns a newly allocated, unique hierarchy ID.
+func (r *CgroupRegistry) nextHierarchyID() (uint32, error) {
+ if hid := atomic.AddUint32(&r.lastHierarchyID, 1); hid != 0 {
+ return hid, nil
+ }
+ return InvalidCgroupHierarchyID, fmt.Errorf("cgroup hierarchy ID overflow")
+}
+
+// FindHierarchy returns a cgroup filesystem containing exactly the set of
+// controllers named in names. If no such FS is found, FindHierarchy return
+// nil. FindHierarchy takes a reference on the returned FS, which is transferred
+// to the caller.
+func (r *CgroupRegistry) FindHierarchy(ctypes []CgroupControllerType) *vfs.Filesystem {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ for _, h := range r.hierarchies {
+ if h.match(ctypes) {
+ h.fs.IncRef()
+ return h.fs
+ }
+ }
+
+ return nil
+}
+
+// Register registers the provided set of controllers with the registry as a new
+// hierarchy. If any controller is already registered, the function returns an
+// error without modifying the registry. The hierarchy can be later referenced
+// by the returned id.
+func (r *CgroupRegistry) Register(cs []CgroupController) (uint32, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if len(cs) == 0 {
+ return InvalidCgroupHierarchyID, fmt.Errorf("can't register hierarchy with no controllers")
+ }
+
+ for _, c := range cs {
+ if _, ok := r.controllers[c.Type()]; ok {
+ return InvalidCgroupHierarchyID, fmt.Errorf("controllers may only be mounted on a single hierarchy")
+ }
+ }
+
+ hid, err := r.nextHierarchyID()
+ if err != nil {
+ return hid, err
+ }
+
+ h := hierarchy{
+ id: hid,
+ controllers: make(map[CgroupControllerType]CgroupController),
+ fs: cs[0].Filesystem(),
+ }
+ for _, c := range cs {
+ n := c.Type()
+ r.controllers[n] = c
+ h.controllers[n] = c
+ }
+ r.hierarchies[hid] = h
+ return hid, nil
+}
+
+// Unregister removes a previously registered hierarchy from the registry. If
+// the controller was not previously registered, Unregister is a no-op.
+func (r *CgroupRegistry) Unregister(hid uint32) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if h, ok := r.hierarchies[hid]; ok {
+ for name, _ := range h.controllers {
+ delete(r.controllers, name)
+ }
+ delete(r.hierarchies, hid)
+ }
+}
+
+// computeInitialGroups takes a reference on each of the returned cgroups. The
+// caller takes ownership of this returned reference.
+func (r *CgroupRegistry) computeInitialGroups(inherit map[Cgroup]struct{}) map[Cgroup]struct{} {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ ctlSet := make(map[CgroupControllerType]CgroupController)
+ cgset := make(map[Cgroup]struct{})
+
+ // Remember controllers from the inherited cgroups set...
+ for cg, _ := range inherit {
+ cg.IncRef() // Ref transferred to caller.
+ for _, ctl := range cg.Controllers() {
+ ctlSet[ctl.Type()] = ctl
+ cgset[cg] = struct{}{}
+ }
+ }
+
+ // ... and add the root cgroups of all the missing controllers.
+ for name, ctl := range r.controllers {
+ if _, ok := ctlSet[name]; !ok {
+ cg := ctl.RootCgroup()
+ cg.IncRef() // Ref transferred to caller.
+ cgset[cg] = struct{}{}
+ }
+ }
+ return cgset
+}
+
+// GenerateProcCgroups writes the contents of /proc/cgroups to buf.
+func (r *CgroupRegistry) GenerateProcCgroups(buf *bytes.Buffer) {
+ r.mu.Lock()
+ entries := make([]string, 0, len(r.controllers))
+ for _, c := range r.controllers {
+ en := 0
+ if c.Enabled() {
+ en = 1
+ }
+ entries = append(entries, fmt.Sprintf("%s\t%d\t%d\t%d\n", c.Type(), c.HierarchyID(), c.NumCgroups(), en))
+ }
+ r.mu.Unlock()
+
+ sort.Strings(entries)
+ fmt.Fprint(buf, "#subsys_name\thierarchy\tnum_cgroups\tenabled\n")
+ for _, e := range entries {
+ fmt.Fprint(buf, e)
+ }
+}
diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD
index 7ecbd29ab..564c3d42e 100644
--- a/pkg/sentry/kernel/eventfd/BUILD
+++ b/pkg/sentry/kernel/eventfd/BUILD
@@ -10,6 +10,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/fdnotifier",
+ "//pkg/hostarch",
"//pkg/sentry/fs",
"//pkg/sentry/fs/anon",
"//pkg/sentry/fs/fsutil",
diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go
index 2aca02fd5..4466fbc9d 100644
--- a/pkg/sentry/kernel/eventfd/eventfd.go
+++ b/pkg/sentry/kernel/eventfd/eventfd.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -186,7 +187,7 @@ func (e *EventOperations) read(ctx context.Context, dst usermem.IOSequence) erro
e.wq.Notify(waiter.WritableEvents)
var buf [8]byte
- usermem.ByteOrder.PutUint64(buf[:], val)
+ hostarch.ByteOrder.PutUint64(buf[:], val)
_, err := dst.CopyOut(ctx, buf[:])
return err
}
@@ -194,7 +195,7 @@ func (e *EventOperations) read(ctx context.Context, dst usermem.IOSequence) erro
// Must be called with e.mu locked.
func (e *EventOperations) hostWrite(val uint64) error {
var buf [8]byte
- usermem.ByteOrder.PutUint64(buf[:], val)
+ hostarch.ByteOrder.PutUint64(buf[:], val)
_, err := unix.Write(e.hostfd, buf[:])
if err == unix.EWOULDBLOCK {
return syserror.ErrWouldBlock
@@ -207,7 +208,7 @@ func (e *EventOperations) write(ctx context.Context, src usermem.IOSequence) err
if _, err := src.CopyIn(ctx, buf[:]); err != nil {
return err
}
- val := usermem.ByteOrder.Uint64(buf[:])
+ val := hostarch.ByteOrder.Uint64(buf[:])
return e.Signal(val)
}
diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD
index 041e3d4ca..a75686cf3 100644
--- a/pkg/sentry/kernel/futex/BUILD
+++ b/pkg/sentry/kernel/futex/BUILD
@@ -37,6 +37,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/sentry/memmap",
"//pkg/sync",
@@ -52,8 +53,8 @@ go_test(
library = ":futex",
deps = [
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/sync",
- "//pkg/usermem",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go
index e4dcc4d40..0427cf3f4 100644
--- a/pkg/sentry/kernel/futex/futex.go
+++ b/pkg/sentry/kernel/futex/futex.go
@@ -20,10 +20,10 @@ package futex
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// KeyKind indicates the type of a Key.
@@ -83,8 +83,8 @@ func (k *Key) clone() Key {
}
// Preconditions: k.Kind == KindPrivate or KindSharedPrivate.
-func (k *Key) addr() usermem.Addr {
- return usermem.Addr(k.Offset)
+func (k *Key) addr() hostarch.Addr {
+ return hostarch.Addr(k.Offset)
}
// matches returns true if a wakeup on k2 should wake a waiter waiting on k.
@@ -97,14 +97,14 @@ func (k *Key) matches(k2 *Key) bool {
type Target interface {
context.Context
- // SwapUint32 gives access to usermem.IO.SwapUint32.
- SwapUint32(addr usermem.Addr, new uint32) (uint32, error)
+ // SwapUint32 gives access to hostarch.IO.SwapUint32.
+ SwapUint32(addr hostarch.Addr, new uint32) (uint32, error)
- // CompareAndSwap gives access to usermem.IO.CompareAndSwapUint32.
- CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error)
+ // CompareAndSwap gives access to hostarch.IO.CompareAndSwapUint32.
+ CompareAndSwapUint32(addr hostarch.Addr, old, new uint32) (uint32, error)
- // LoadUint32 gives access to usermem.IO.LoadUint32.
- LoadUint32(addr usermem.Addr) (uint32, error)
+ // LoadUint32 gives access to hostarch.IO.LoadUint32.
+ LoadUint32(addr hostarch.Addr) (uint32, error)
// GetSharedKey returns a Key with kind KindSharedPrivate or
// KindSharedMappable corresponding to the memory mapped at address addr.
@@ -112,11 +112,11 @@ type Target interface {
// If GetSharedKey returns a Key with a non-nil MappingIdentity, a
// reference is held on the MappingIdentity, which must be dropped by the
// caller when the Key is no longer in use.
- GetSharedKey(addr usermem.Addr) (Key, error)
+ GetSharedKey(addr hostarch.Addr) (Key, error)
}
// check performs a basic equality check on the given address.
-func check(t Target, addr usermem.Addr, val uint32) error {
+func check(t Target, addr hostarch.Addr, val uint32) error {
cur, err := t.LoadUint32(addr)
if err != nil {
return err
@@ -128,7 +128,7 @@ func check(t Target, addr usermem.Addr, val uint32) error {
}
// atomicOp performs a complex operation on the given address.
-func atomicOp(t Target, addr usermem.Addr, opIn uint32) (bool, error) {
+func atomicOp(t Target, addr hostarch.Addr, opIn uint32) (bool, error) {
opType := (opIn >> 28) & 0xf
cmp := (opIn >> 24) & 0xf
opArg := (opIn >> 12) & 0xfff
@@ -328,7 +328,7 @@ const (
)
// getKey returns a Key representing address addr in c.
-func getKey(t Target, addr usermem.Addr, private bool) (Key, error) {
+func getKey(t Target, addr hostarch.Addr, private bool) (Key, error) {
// Ensure the address is aligned.
// It must be a DWORD boundary.
if addr&0x3 != 0 {
@@ -341,7 +341,7 @@ func getKey(t Target, addr usermem.Addr, private bool) (Key, error) {
}
// bucketIndexForAddr returns the index into Manager.buckets for addr.
-func bucketIndexForAddr(addr usermem.Addr) uintptr {
+func bucketIndexForAddr(addr hostarch.Addr) uintptr {
// - The bottom 2 bits of addr must be 0, per getKey.
//
// - On amd64, the top 16 bits of addr (bits 48-63) must be equal to bit 47
@@ -448,7 +448,7 @@ func (m *Manager) lockBuckets(k1, k2 *Key) (*bucket, *bucket) {
// Wake wakes up to n waiters matching the bitmask on the given addr.
// The number of waiters woken is returned.
-func (m *Manager) Wake(t Target, addr usermem.Addr, private bool, bitmask uint32, n int) (int, error) {
+func (m *Manager) Wake(t Target, addr hostarch.Addr, private bool, bitmask uint32, n int) (int, error) {
// This function is very hot; avoid defer.
k, err := getKey(t, addr, private)
if err != nil {
@@ -463,7 +463,7 @@ func (m *Manager) Wake(t Target, addr usermem.Addr, private bool, bitmask uint32
return r, nil
}
-func (m *Manager) doRequeue(t Target, addr, naddr usermem.Addr, private bool, checkval bool, val uint32, nwake int, nreq int) (int, error) {
+func (m *Manager) doRequeue(t Target, addr, naddr hostarch.Addr, private bool, checkval bool, val uint32, nwake int, nreq int) (int, error) {
k1, err := getKey(t, addr, private)
if err != nil {
return 0, err
@@ -498,14 +498,14 @@ func (m *Manager) doRequeue(t Target, addr, naddr usermem.Addr, private bool, ch
// Requeue wakes up to nwake waiters on the given addr, and unconditionally
// requeues up to nreq waiters on naddr.
-func (m *Manager) Requeue(t Target, addr, naddr usermem.Addr, private bool, nwake int, nreq int) (int, error) {
+func (m *Manager) Requeue(t Target, addr, naddr hostarch.Addr, private bool, nwake int, nreq int) (int, error) {
return m.doRequeue(t, addr, naddr, private, false, 0, nwake, nreq)
}
// RequeueCmp atomically checks that the addr contains val (via the Target),
// wakes up to nwake waiters on addr and then unconditionally requeues nreq
// waiters on naddr.
-func (m *Manager) RequeueCmp(t Target, addr, naddr usermem.Addr, private bool, val uint32, nwake int, nreq int) (int, error) {
+func (m *Manager) RequeueCmp(t Target, addr, naddr hostarch.Addr, private bool, val uint32, nwake int, nreq int) (int, error) {
return m.doRequeue(t, addr, naddr, private, true, val, nwake, nreq)
}
@@ -513,7 +513,7 @@ func (m *Manager) RequeueCmp(t Target, addr, naddr usermem.Addr, private bool, v
// waiters unconditionally from addr1, and, based on the original value at addr2
// and a comparison encoded in op, wakes up to nwake2 waiters from addr2.
// It returns the total number of waiters woken.
-func (m *Manager) WakeOp(t Target, addr1, addr2 usermem.Addr, private bool, nwake1 int, nwake2 int, op uint32) (int, error) {
+func (m *Manager) WakeOp(t Target, addr1, addr2 hostarch.Addr, private bool, nwake1 int, nwake2 int, op uint32) (int, error) {
k1, err := getKey(t, addr1, private)
if err != nil {
return 0, err
@@ -553,7 +553,7 @@ func (m *Manager) WakeOp(t Target, addr1, addr2 usermem.Addr, private bool, nwak
// enqueues w to be woken by a send to w.C. If WaitPrepare returns nil, the
// Waiter must be subsequently removed by calling WaitComplete, whether or not
// a wakeup is received on w.C.
-func (m *Manager) WaitPrepare(w *Waiter, t Target, addr usermem.Addr, private bool, val uint32, bitmask uint32) error {
+func (m *Manager) WaitPrepare(w *Waiter, t Target, addr hostarch.Addr, private bool, val uint32, bitmask uint32) error {
k, err := getKey(t, addr, private)
if err != nil {
return err
@@ -631,7 +631,7 @@ func (m *Manager) WaitComplete(w *Waiter, t Target) {
// FUTEX_OWNER_DIED is only set by the Linux when robust lists are in use (see
// exit_robust_list()). Given we don't support robust lists, although handled
// below, it's never set.
-func (m *Manager) LockPI(w *Waiter, t Target, addr usermem.Addr, tid uint32, private, try bool) (bool, error) {
+func (m *Manager) LockPI(w *Waiter, t Target, addr hostarch.Addr, tid uint32, private, try bool) (bool, error) {
k, err := getKey(t, addr, private)
if err != nil {
return false, err
@@ -663,7 +663,7 @@ func (m *Manager) LockPI(w *Waiter, t Target, addr usermem.Addr, tid uint32, pri
return success, nil
}
-func (m *Manager) lockPILocked(w *Waiter, t Target, addr usermem.Addr, tid uint32, b *bucket, try bool) (bool, error) {
+func (m *Manager) lockPILocked(w *Waiter, t Target, addr hostarch.Addr, tid uint32, b *bucket, try bool) (bool, error) {
for {
cur, err := t.LoadUint32(addr)
if err != nil {
@@ -724,7 +724,7 @@ func (m *Manager) lockPILocked(w *Waiter, t Target, addr usermem.Addr, tid uint3
// The address provided must contain the caller's TID. If there are waiters,
// TID of the next waiter (FIFO) is set to the given address, and the waiter
// woken up. If there are no waiters, 0 is set to the address.
-func (m *Manager) UnlockPI(t Target, addr usermem.Addr, tid uint32, private bool) error {
+func (m *Manager) UnlockPI(t Target, addr hostarch.Addr, tid uint32, private bool) error {
k, err := getKey(t, addr, private)
if err != nil {
return err
@@ -738,7 +738,7 @@ func (m *Manager) UnlockPI(t Target, addr usermem.Addr, tid uint32, private bool
return err
}
-func (m *Manager) unlockPILocked(t Target, addr usermem.Addr, tid uint32, b *bucket, key *Key) error {
+func (m *Manager) unlockPILocked(t Target, addr hostarch.Addr, tid uint32, b *bucket, key *Key) error {
cur, err := t.LoadUint32(addr)
if err != nil {
return err
diff --git a/pkg/sentry/kernel/futex/futex_test.go b/pkg/sentry/kernel/futex/futex_test.go
index ba7f95d8a..deba44e5c 100644
--- a/pkg/sentry/kernel/futex/futex_test.go
+++ b/pkg/sentry/kernel/futex/futex_test.go
@@ -23,8 +23,8 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/usermem"
)
// testData implements the Target interface, and allows us to
@@ -43,23 +43,23 @@ func newTestData(size uint) testData {
}
}
-func (t testData) SwapUint32(addr usermem.Addr, new uint32) (uint32, error) {
+func (t testData) SwapUint32(addr hostarch.Addr, new uint32) (uint32, error) {
val := atomic.SwapUint32((*uint32)(unsafe.Pointer(&t.data[addr])), new)
return val, nil
}
-func (t testData) CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error) {
+func (t testData) CompareAndSwapUint32(addr hostarch.Addr, old, new uint32) (uint32, error) {
if atomic.CompareAndSwapUint32((*uint32)(unsafe.Pointer(&t.data[addr])), old, new) {
return old, nil
}
return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t.data[addr]))), nil
}
-func (t testData) LoadUint32(addr usermem.Addr) (uint32, error) {
+func (t testData) LoadUint32(addr hostarch.Addr) (uint32, error) {
return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t.data[addr]))), nil
}
-func (t testData) GetSharedKey(addr usermem.Addr) (Key, error) {
+func (t testData) GetSharedKey(addr hostarch.Addr) (Key, error) {
return Key{
Kind: KindSharedMappable,
Offset: uint64(addr),
@@ -73,7 +73,7 @@ func futexKind(private bool) string {
return "shared"
}
-func newPreparedTestWaiter(t *testing.T, m *Manager, ta Target, addr usermem.Addr, private bool, val uint32, bitmask uint32) *Waiter {
+func newPreparedTestWaiter(t *testing.T, m *Manager, ta Target, addr hostarch.Addr, private bool, val uint32, bitmask uint32) *Waiter {
w := NewWaiter()
if err := m.WaitPrepare(w, ta, addr, private, val, bitmask); err != nil {
t.Fatalf("WaitPrepare failed: %v", err)
@@ -463,12 +463,12 @@ const (
// Beyond being used as a Locker, this is a simple mechanism for
// changing the underlying values for simpler tests.
type testMutex struct {
- a usermem.Addr
+ a hostarch.Addr
d testData
m *Manager
}
-func newTestMutex(addr usermem.Addr, d testData, m *Manager) *testMutex {
+func newTestMutex(addr hostarch.Addr, d testData, m *Manager) *testMutex {
return &testMutex{a: addr, d: d, m: m}
}
diff --git a/pkg/sentry/kernel/kcov.go b/pkg/sentry/kernel/kcov.go
index 4fcdfc541..4b943106b 100644
--- a/pkg/sentry/kernel/kcov.go
+++ b/pkg/sentry/kernel/kcov.go
@@ -22,13 +22,13 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/coverage"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// kcovAreaSizeMax is the maximum number of uint64 entries allowed in the kcov
@@ -130,7 +130,7 @@ func (kcov *Kcov) InitTrace(size uint64) error {
// To simplify all the logic around mapping, we require that the length of the
// shared region is a multiple of the system page size.
- if (8*size)&(usermem.PageSize-1) != 0 {
+ if (8*size)&(hostarch.PageSize-1) != 0 {
return syserror.EINVAL
}
@@ -286,7 +286,7 @@ func (rw *kcovReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
}
// Get internal mappings.
- bs, err := rw.mf.MapInternal(memmap.FileRange{start, end}, usermem.Read)
+ bs, err := rw.mf.MapInternal(memmap.FileRange{start, end}, hostarch.Read)
if err != nil {
return 0, err
}
@@ -314,7 +314,7 @@ func (rw *kcovReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error)
}
// Get internal mapping.
- bs, err := rw.mf.MapInternal(memmap.FileRange{start, end}, usermem.Write)
+ bs, err := rw.mf.MapInternal(memmap.FileRange{start, end}, hostarch.Write)
if err != nil {
return 0, err
}
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 43065b45a..e6e9da898 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -294,6 +294,11 @@ type Kernel struct {
// YAMAPtraceScope is the current level of YAMA ptrace restrictions.
YAMAPtraceScope int32
+
+ // cgroupRegistry contains the set of active cgroup controllers on the
+ // system. It is controller by cgroupfs. Nil if cgroupfs is unavailable on
+ // the system.
+ cgroupRegistry *CgroupRegistry
}
// InitKernelArgs holds arguments to Init.
@@ -438,6 +443,8 @@ func (k *Kernel) Init(args InitKernelArgs) error {
k.socketMount = socketMount
k.socketsVFS2 = make(map[*vfs.FileDescription]*SocketRecord)
+
+ k.cgroupRegistry = newCgroupRegistry()
}
return nil
}
@@ -1815,6 +1822,11 @@ func (k *Kernel) SocketMount() *vfs.Mount {
return k.socketMount
}
+// CgroupRegistry returns the cgroup registry.
+func (k *Kernel) CgroupRegistry() *CgroupRegistry {
+ return k.cgroupRegistry
+}
+
// Release releases resources owned by k.
//
// Precondition: This should only be called after the kernel is fully
@@ -1831,3 +1843,43 @@ func (k *Kernel) Release() {
k.timekeeper.Destroy()
k.vdso.Release(ctx)
}
+
+// PopulateNewCgroupHierarchy moves all tasks into a newly created cgroup
+// hierarchy.
+//
+// Precondition: root must be a new cgroup with no tasks. This implies the
+// controllers for root are also new and currently manage no task, which in turn
+// implies the new cgroup can be populated without migrating tasks between
+// cgroups.
+func (k *Kernel) PopulateNewCgroupHierarchy(root Cgroup) {
+ k.tasks.mu.RLock()
+ k.tasks.forEachTaskLocked(func(t *Task) {
+ if t.exitState != TaskExitNone {
+ return
+ }
+ t.mu.Lock()
+ t.enterCgroupLocked(root)
+ t.mu.Unlock()
+ })
+ k.tasks.mu.RUnlock()
+}
+
+// ReleaseCgroupHierarchy moves all tasks out of all cgroups belonging to the
+// hierarchy with the provided id. This is intended for use during hierarchy
+// teardown, as otherwise the tasks would be orphaned w.r.t to some controllers.
+func (k *Kernel) ReleaseCgroupHierarchy(hid uint32) {
+ k.tasks.mu.RLock()
+ k.tasks.forEachTaskLocked(func(t *Task) {
+ if t.exitState != TaskExitNone {
+ return
+ }
+ t.mu.Lock()
+ for cg, _ := range t.cgroups {
+ if cg.HierarchyID() == hid {
+ t.leaveCgroupLocked(cg)
+ }
+ }
+ t.mu.Unlock()
+ })
+ k.tasks.mu.RUnlock()
+}
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index beba6d97d..34c617b08 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -21,6 +21,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/amutex",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/marshal/primitive",
"//pkg/safemem",
"//pkg/sentry/arch",
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index d004f2357..06769931a 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -22,18 +22,18 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
const (
// MinimumPipeSize is a hard limit of the minimum size of a pipe.
// It corresponds to fs/pipe.c:pipe_min_size.
- MinimumPipeSize = usermem.PageSize
+ MinimumPipeSize = hostarch.PageSize
// MaximumPipeSize is a hard limit on the maximum size of a pipe.
// It corresponds to fs/pipe.c:pipe_max_size.
@@ -41,7 +41,7 @@ const (
// DefaultPipeSize is the system-wide default size of a pipe in bytes.
// It corresponds to pipe_fs_i.h:PIPE_DEF_BUFFERS.
- DefaultPipeSize = 16 * usermem.PageSize
+ DefaultPipeSize = 16 * hostarch.PageSize
// atomicIOBytes is the maximum number of bytes that the pipe will
// guarantee atomic reads or writes atomically.
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
index e524afad5..95b948edb 100644
--- a/pkg/sentry/kernel/pipe/vfs.go
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -17,6 +17,7 @@ package pipe
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -274,7 +275,7 @@ func (fd *VFSPipeFD) SpliceToNonPipe(ctx context.Context, out *vfs.FileDescripti
}
src := usermem.IOSequence{
IO: fd,
- Addrs: usermem.AddrRangeSeqOf(usermem.AddrRange{0, usermem.Addr(count)}),
+ Addrs: hostarch.AddrRangeSeqOf(hostarch.AddrRange{0, hostarch.Addr(count)}),
}
var (
@@ -302,7 +303,7 @@ func (fd *VFSPipeFD) SpliceToNonPipe(ctx context.Context, out *vfs.FileDescripti
func (fd *VFSPipeFD) SpliceFromNonPipe(ctx context.Context, in *vfs.FileDescription, off, count int64) (int64, error) {
dst := usermem.IOSequence{
IO: fd,
- Addrs: usermem.AddrRangeSeqOf(usermem.AddrRange{0, usermem.Addr(count)}),
+ Addrs: hostarch.AddrRangeSeqOf(hostarch.AddrRange{0, hostarch.Addr(count)}),
}
var (
@@ -328,7 +329,7 @@ func (fd *VFSPipeFD) SpliceFromNonPipe(ctx context.Context, in *vfs.FileDescript
// fd.pipe.Notify(waiter.WritableEvents) after the read is completed.
//
// Preconditions: fd.pipe.mu must be locked.
-func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, opts usermem.IOOpts) (int, error) {
+func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr hostarch.Addr, dst []byte, opts usermem.IOOpts) (int, error) {
n, err := fd.pipe.peekLocked(int64(len(dst)), func(srcs safemem.BlockSeq) (uint64, error) {
return safemem.CopySeq(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(dst)), srcs)
})
@@ -340,7 +341,7 @@ func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte,
// is completed.
//
// Preconditions: fd.pipe.mu must be locked.
-func (fd *VFSPipeFD) CopyOut(ctx context.Context, addr usermem.Addr, src []byte, opts usermem.IOOpts) (int, error) {
+func (fd *VFSPipeFD) CopyOut(ctx context.Context, addr hostarch.Addr, src []byte, opts usermem.IOOpts) (int, error) {
n, err := fd.pipe.writeLocked(int64(len(src)), func(dsts safemem.BlockSeq) (uint64, error) {
return safemem.CopySeq(dsts, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(src)))
})
@@ -350,7 +351,7 @@ func (fd *VFSPipeFD) CopyOut(ctx context.Context, addr usermem.Addr, src []byte,
// ZeroOut implements usermem.IO.ZeroOut.
//
// Preconditions: fd.pipe.mu must be locked.
-func (fd *VFSPipeFD) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int64, opts usermem.IOOpts) (int64, error) {
+func (fd *VFSPipeFD) ZeroOut(ctx context.Context, addr hostarch.Addr, toZero int64, opts usermem.IOOpts) (int64, error) {
n, err := fd.pipe.writeLocked(toZero, func(dsts safemem.BlockSeq) (uint64, error) {
return safemem.ZeroSeq(dsts)
})
@@ -362,7 +363,7 @@ func (fd *VFSPipeFD) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int6
// fd.pipe.Notify(waiter.WritableEvents) after the read is completed.
//
// Preconditions: fd.pipe.mu must be locked.
-func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) {
+func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars hostarch.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) {
return fd.pipe.peekLocked(ars.NumBytes(), func(srcs safemem.BlockSeq) (uint64, error) {
return dst.WriteFromBlocks(srcs)
})
@@ -373,25 +374,25 @@ func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst
// is completed.
//
// Preconditions: fd.pipe.mu must be locked.
-func (fd *VFSPipeFD) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) {
+func (fd *VFSPipeFD) CopyOutFrom(ctx context.Context, ars hostarch.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) {
return fd.pipe.writeLocked(ars.NumBytes(), func(dsts safemem.BlockSeq) (uint64, error) {
return src.ReadToBlocks(dsts)
})
}
// SwapUint32 implements usermem.IO.SwapUint32.
-func (fd *VFSPipeFD) SwapUint32(ctx context.Context, addr usermem.Addr, new uint32, opts usermem.IOOpts) (uint32, error) {
+func (fd *VFSPipeFD) SwapUint32(ctx context.Context, addr hostarch.Addr, new uint32, opts usermem.IOOpts) (uint32, error) {
// How did a pipe get passed as the virtual address space to futex(2)?
panic("VFSPipeFD.SwapUint32 called unexpectedly")
}
// CompareAndSwapUint32 implements usermem.IO.CompareAndSwapUint32.
-func (fd *VFSPipeFD) CompareAndSwapUint32(ctx context.Context, addr usermem.Addr, old, new uint32, opts usermem.IOOpts) (uint32, error) {
+func (fd *VFSPipeFD) CompareAndSwapUint32(ctx context.Context, addr hostarch.Addr, old, new uint32, opts usermem.IOOpts) (uint32, error) {
panic("VFSPipeFD.CompareAndSwapUint32 called unexpectedly")
}
// LoadUint32 implements usermem.IO.LoadUint32.
-func (fd *VFSPipeFD) LoadUint32(ctx context.Context, addr usermem.Addr, opts usermem.IOOpts) (uint32, error) {
+func (fd *VFSPipeFD) LoadUint32(ctx context.Context, addr hostarch.Addr, opts usermem.IOOpts) (uint32, error) {
panic("VFSPipeFD.LoadUint32 called unexpectedly")
}
diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go
index f5a60e749..57c7659e7 100644
--- a/pkg/sentry/kernel/ptrace.go
+++ b/pkg/sentry/kernel/ptrace.go
@@ -19,6 +19,7 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/mm"
@@ -1011,7 +1012,7 @@ func (t *Task) ptraceSetOptionsLocked(opts uintptr) error {
}
// Ptrace implements the ptrace system call.
-func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
+func (t *Task) Ptrace(req int64, pid ThreadID, addr, data hostarch.Addr) error {
// PTRACE_TRACEME ignores all other arguments.
if req == linux.PTRACE_TRACEME {
return t.ptraceTraceme()
@@ -1190,7 +1191,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
panic(fmt.Sprintf("%#x + %#x overflows. Invalid reg size > %#x", ar.Start, n, ar.Length()))
}
ar.End = end
- return t.CopyOutIovecs(data, usermem.AddrRangeSeqOf(ar))
+ return t.CopyOutIovecs(data, hostarch.AddrRangeSeqOf(ar))
case linux.PTRACE_SETREGSET:
ars, err := t.CopyInIovecs(data, 1)
@@ -1214,8 +1215,8 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
return err
}
t.p.FullStateChanged()
- ar.End -= usermem.Addr(n)
- return t.CopyOutIovecs(data, usermem.AddrRangeSeqOf(ar))
+ ar.End -= hostarch.Addr(n)
+ return t.CopyOutIovecs(data, hostarch.AddrRangeSeqOf(ar))
case linux.PTRACE_GETSIGINFO:
t.tg.pidns.owner.mu.RLock()
@@ -1267,7 +1268,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
case linux.PTRACE_GETEVENTMSG:
t.tg.pidns.owner.mu.RLock()
defer t.tg.pidns.owner.mu.RUnlock()
- _, err := primitive.CopyUint64Out(t, usermem.Addr(data), target.ptraceEventMsg)
+ _, err := primitive.CopyUint64Out(t, hostarch.Addr(data), target.ptraceEventMsg)
return err
// PEEKSIGINFO is unimplemented but seems to have no users anywhere.
diff --git a/pkg/sentry/kernel/ptrace_amd64.go b/pkg/sentry/kernel/ptrace_amd64.go
index 7aea3dcd8..5ae05b5c3 100644
--- a/pkg/sentry/kernel/ptrace_amd64.go
+++ b/pkg/sentry/kernel/ptrace_amd64.go
@@ -18,12 +18,13 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
// ptraceArch implements arch-specific ptrace commands.
-func (t *Task) ptraceArch(target *Task, req int64, addr, data usermem.Addr) error {
+func (t *Task) ptraceArch(target *Task, req int64, addr, data hostarch.Addr) error {
switch req {
case linux.PTRACE_PEEKUSR: // aka PTRACE_PEEKUSER
n, err := target.Arch().PtracePeekUser(uintptr(addr))
diff --git a/pkg/sentry/kernel/ptrace_arm64.go b/pkg/sentry/kernel/ptrace_arm64.go
index d971b96b3..46dd84cbc 100644
--- a/pkg/sentry/kernel/ptrace_arm64.go
+++ b/pkg/sentry/kernel/ptrace_arm64.go
@@ -17,11 +17,11 @@
package kernel
import (
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// ptraceArch implements arch-specific ptrace commands.
-func (t *Task) ptraceArch(target *Task, req int64, addr, data usermem.Addr) error {
+func (t *Task) ptraceArch(target *Task, req int64, addr, data hostarch.Addr) error {
return syserror.EIO
}
diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go
index 2a9023fdf..4bc5bca44 100644
--- a/pkg/sentry/kernel/rseq.go
+++ b/pkg/sentry/kernel/rseq.go
@@ -18,6 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/hostcpu"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -43,8 +44,8 @@ type OldRSeqCriticalRegion struct {
// application handler while its instruction pointer is in CriticalSection,
// set the instruction pointer to Restart and application register r10 (on
// amd64) to the former instruction pointer.
- CriticalSection usermem.AddrRange
- Restart usermem.Addr
+ CriticalSection hostarch.AddrRange
+ Restart hostarch.Addr
}
// RSeqAvailable returns true if t supports (old and new) restartable sequences.
@@ -55,7 +56,7 @@ func (t *Task) RSeqAvailable() bool {
// SetRSeq registers addr as this thread's rseq structure.
//
// Preconditions: The caller must be running on the task goroutine.
-func (t *Task) SetRSeq(addr usermem.Addr, length, signature uint32) error {
+func (t *Task) SetRSeq(addr hostarch.Addr, length, signature uint32) error {
if t.rseqAddr != 0 {
if t.rseqAddr != addr {
return syserror.EINVAL
@@ -100,7 +101,7 @@ func (t *Task) SetRSeq(addr usermem.Addr, length, signature uint32) error {
// ClearRSeq unregisters addr as this thread's rseq structure.
//
// Preconditions: The caller must be running on the task goroutine.
-func (t *Task) ClearRSeq(addr usermem.Addr, length, signature uint32) error {
+func (t *Task) ClearRSeq(addr hostarch.Addr, length, signature uint32) error {
if t.rseqAddr == 0 {
return syserror.EINVAL
}
@@ -166,7 +167,7 @@ func (t *Task) SetOldRSeqCriticalRegion(r OldRSeqCriticalRegion) error {
// CPU number.
//
// Preconditions: The caller must be running on the task goroutine.
-func (t *Task) OldRSeqCPUAddr() usermem.Addr {
+func (t *Task) OldRSeqCPUAddr() hostarch.Addr {
return t.oldRSeqCPUAddr
}
@@ -177,7 +178,7 @@ func (t *Task) OldRSeqCPUAddr() usermem.Addr {
// * t.RSeqAvailable() == true.
// * The caller must be running on the task goroutine.
// * t's AddressSpace must be active.
-func (t *Task) SetOldRSeqCPUAddr(addr usermem.Addr) error {
+func (t *Task) SetOldRSeqCPUAddr(addr hostarch.Addr) error {
t.oldRSeqCPUAddr = addr
// Check that addr is writable.
@@ -221,7 +222,7 @@ func (t *Task) oldRSeqCopyOutCPU() error {
}
buf := t.CopyScratchBuffer(4)
- usermem.ByteOrder.PutUint32(buf, uint32(t.rseqCPU))
+ hostarch.ByteOrder.PutUint32(buf, uint32(t.rseqCPU))
_, err := t.CopyOutBytes(t.oldRSeqCPUAddr, buf)
return err
}
@@ -236,8 +237,8 @@ func (t *Task) rseqCopyOutCPU() error {
buf := t.CopyScratchBuffer(8)
// CPUIDStart and CPUID are the first two fields in linux.RSeq.
- usermem.ByteOrder.PutUint32(buf, uint32(t.rseqCPU)) // CPUIDStart
- usermem.ByteOrder.PutUint32(buf[4:], uint32(t.rseqCPU)) // CPUID
+ hostarch.ByteOrder.PutUint32(buf, uint32(t.rseqCPU)) // CPUIDStart
+ hostarch.ByteOrder.PutUint32(buf[4:], uint32(t.rseqCPU)) // CPUID
// N.B. This write is not atomic, but since this occurs on the task
// goroutine then as long as userspace uses a single-instruction read
// it can't see an invalid value.
@@ -251,8 +252,8 @@ func (t *Task) rseqCopyOutCPU() error {
func (t *Task) rseqClearCPU() error {
buf := t.CopyScratchBuffer(8)
// CPUIDStart and CPUID are the first two fields in linux.RSeq.
- usermem.ByteOrder.PutUint32(buf, 0) // CPUIDStart
- usermem.ByteOrder.PutUint32(buf[4:], linux.RSEQ_CPU_ID_UNINITIALIZED) // CPUID
+ hostarch.ByteOrder.PutUint32(buf, 0) // CPUIDStart
+ hostarch.ByteOrder.PutUint32(buf[4:], linux.RSEQ_CPU_ID_UNINITIALIZED) // CPUID
// N.B. This write is not atomic, but since this occurs on the task
// goroutine then as long as userspace uses a single-instruction read
// it can't see an invalid value.
@@ -305,7 +306,7 @@ func (t *Task) rseqAddrInterrupt() {
return
}
- critAddr := usermem.Addr(usermem.ByteOrder.Uint64(buf))
+ critAddr := hostarch.Addr(hostarch.ByteOrder.Uint64(buf))
if critAddr == 0 {
return
}
@@ -325,7 +326,7 @@ func (t *Task) rseqAddrInterrupt() {
return
}
- start := usermem.Addr(cs.Start)
+ start := hostarch.Addr(cs.Start)
critRange, ok := start.ToRange(cs.PostCommitOffset)
if !ok {
t.Debugf("Invalid start and offset in %+v", cs)
@@ -334,7 +335,7 @@ func (t *Task) rseqAddrInterrupt() {
return
}
- abort := usermem.Addr(cs.Abort)
+ abort := hostarch.Addr(cs.Abort)
if critRange.Contains(abort) {
t.Debugf("Abort in critical section in %+v", cs)
t.forceSignal(linux.SIGSEGV, false /* unconditional */)
@@ -353,7 +354,7 @@ func (t *Task) rseqAddrInterrupt() {
return
}
- sig := usermem.ByteOrder.Uint32(buf)
+ sig := hostarch.ByteOrder.Uint32(buf)
if sig != t.rseqSignature {
t.Debugf("Mismatched rseq signature %d != %d", sig, t.rseqSignature)
t.forceSignal(linux.SIGSEGV, false /* unconditional */)
@@ -376,7 +377,7 @@ func (t *Task) rseqAddrInterrupt() {
}
// Finally we can actually decide whether or not to restart.
- if !critRange.Contains(usermem.Addr(t.Arch().IP())) {
+ if !critRange.Contains(hostarch.Addr(t.Arch().IP())) {
return
}
@@ -386,7 +387,7 @@ func (t *Task) rseqAddrInterrupt() {
// Preconditions: The caller must be running on the task goroutine.
func (t *Task) oldRSeqInterrupt() {
r := t.tg.oldRSeqCritical.Load().(*OldRSeqCriticalRegion)
- if ip := t.Arch().IP(); r.CriticalSection.Contains(usermem.Addr(ip)) {
+ if ip := t.Arch().IP(); r.CriticalSection.Contains(hostarch.Addr(ip)) {
t.Debugf("Interrupted rseq critical section at %#x; restarting at %#x", ip, r.Restart)
t.Arch().SetIP(uintptr(r.Restart))
t.Arch().SetOldRSeqInterruptedIP(ip)
diff --git a/pkg/sentry/kernel/seccomp.go b/pkg/sentry/kernel/seccomp.go
index 8163a6132..a95e174a2 100644
--- a/pkg/sentry/kernel/seccomp.go
+++ b/pkg/sentry/kernel/seccomp.go
@@ -18,9 +18,9 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
const maxSyscallFilterInstructions = 1 << 15
@@ -35,11 +35,11 @@ func dataAsBPFInput(t *Task, d *linux.SeccompData) bpf.Input {
return bpf.InputBytes{
Data: buf,
// Go-marshal always uses the native byte order.
- Order: usermem.ByteOrder,
+ Order: hostarch.ByteOrder,
}
}
-func seccompSiginfo(t *Task, errno, sysno int32, ip usermem.Addr) *arch.SignalInfo {
+func seccompSiginfo(t *Task, errno, sysno int32, ip hostarch.Addr) *arch.SignalInfo {
si := &arch.SignalInfo{
Signo: int32(linux.SIGSYS),
Errno: errno,
@@ -56,7 +56,7 @@ func seccompSiginfo(t *Task, errno, sysno int32, ip usermem.Addr) *arch.SignalIn
// in because vsyscalls do not use the values in t.Arch().)
//
// Preconditions: The caller must be running on the task goroutine.
-func (t *Task) checkSeccompSyscall(sysno int32, args arch.SyscallArguments, ip usermem.Addr) linux.BPFAction {
+func (t *Task) checkSeccompSyscall(sysno int32, args arch.SyscallArguments, ip hostarch.Addr) linux.BPFAction {
result := linux.BPFAction(t.evaluateSyscallFilters(sysno, args, ip))
action := result & linux.SECCOMP_RET_ACTION
switch action {
@@ -102,7 +102,7 @@ func (t *Task) checkSeccompSyscall(sysno int32, args arch.SyscallArguments, ip u
return action
}
-func (t *Task) evaluateSyscallFilters(sysno int32, args arch.SyscallArguments, ip usermem.Addr) uint32 {
+func (t *Task) evaluateSyscallFilters(sysno int32, args arch.SyscallArguments, ip hostarch.Addr) uint32 {
data := linux.SeccompData{
Nr: sysno,
Arch: t.image.st.AuditNumber,
diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD
index 073e14507..1c3c0794f 100644
--- a/pkg/sentry/kernel/shm/BUILD
+++ b/pkg/sentry/kernel/shm/BUILD
@@ -28,6 +28,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/refs",
"//pkg/refsvfs2",
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go
index 92d60ba78..a73f1bdca 100644
--- a/pkg/sentry/kernel/shm/shm.go
+++ b/pkg/sentry/kernel/shm/shm.go
@@ -38,6 +38,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -47,7 +48,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// Key represents a shm segment key. Analogous to a file name.
@@ -197,13 +197,13 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size ui
}
var sizeAligned uint64
- if val, ok := usermem.Addr(size).RoundUp(); ok {
+ if val, ok := hostarch.Addr(size).RoundUp(); ok {
sizeAligned = uint64(val)
} else {
return nil, syserror.EINVAL
}
- if numPages := sizeAligned / usermem.PageSize; r.totalPages+numPages > linux.SHMALL {
+ if numPages := sizeAligned / hostarch.PageSize; r.totalPages+numPages > linux.SHMALL {
// "... allocating a segment of the requested size would cause the
// system to exceed the system-wide limit on shared memory (SHMALL)."
// - man shmget(2)
@@ -232,7 +232,7 @@ func (r *Registry) newShm(ctx context.Context, pid int32, key Key, creator fs.Fi
panic(fmt.Sprintf("context.Context %T lacks non-nil value for key %T", ctx, pgalloc.CtxMemoryFileProvider))
}
- effectiveSize := uint64(usermem.Addr(size).MustRoundUp())
+ effectiveSize := uint64(hostarch.Addr(size).MustRoundUp())
fr, err := mfp.MemoryFile().Allocate(effectiveSize, usage.Anonymous)
if err != nil {
return nil, err
@@ -267,7 +267,7 @@ func (r *Registry) newShm(ctx context.Context, pid int32, key Key, creator fs.Fi
r.shms[id] = shm
r.keysToShms[key] = shm
- r.totalPages += effectiveSize / usermem.PageSize
+ r.totalPages += effectiveSize / hostarch.PageSize
return shm, nil
}
@@ -318,7 +318,7 @@ func (r *Registry) remove(s *Shm) {
}
delete(r.shms, s.ID)
- r.totalPages -= s.effectiveSize / usermem.PageSize
+ r.totalPages -= s.effectiveSize / hostarch.PageSize
}
// Release drops the self-reference of each active shm segment in the registry.
@@ -386,7 +386,7 @@ type Shm struct {
// effectiveSize of the segment, rounding up to the next page
// boundary. Immutable.
//
- // Invariant: effectiveSize must be a multiple of usermem.PageSize.
+ // Invariant: effectiveSize must be a multiple of hostarch.PageSize.
effectiveSize uint64
// fr is the offset into mfp.MemoryFile() that backs this contents of this
@@ -467,7 +467,7 @@ func (s *Shm) Msync(context.Context, memmap.MappableRange) error {
}
// AddMapping implements memmap.Mappable.AddMapping.
-func (s *Shm) AddMapping(ctx context.Context, _ memmap.MappingSpace, _ usermem.AddrRange, _ uint64, _ bool) error {
+func (s *Shm) AddMapping(ctx context.Context, _ memmap.MappingSpace, _ hostarch.AddrRange, _ uint64, _ bool) error {
s.mu.Lock()
defer s.mu.Unlock()
s.attachTime = ktime.NowFromContext(ctx)
@@ -482,7 +482,7 @@ func (s *Shm) AddMapping(ctx context.Context, _ memmap.MappingSpace, _ usermem.A
}
// RemoveMapping implements memmap.Mappable.RemoveMapping.
-func (s *Shm) RemoveMapping(ctx context.Context, _ memmap.MappingSpace, _ usermem.AddrRange, _ uint64, _ bool) {
+func (s *Shm) RemoveMapping(ctx context.Context, _ memmap.MappingSpace, _ hostarch.AddrRange, _ uint64, _ bool) {
s.mu.Lock()
defer s.mu.Unlock()
// RemoveMapping may be called during task exit, when ctx
@@ -503,12 +503,12 @@ func (s *Shm) RemoveMapping(ctx context.Context, _ memmap.MappingSpace, _ userme
}
// CopyMapping implements memmap.Mappable.CopyMapping.
-func (*Shm) CopyMapping(context.Context, memmap.MappingSpace, usermem.AddrRange, usermem.AddrRange, uint64, bool) error {
+func (*Shm) CopyMapping(context.Context, memmap.MappingSpace, hostarch.AddrRange, hostarch.AddrRange, uint64, bool) error {
return nil
}
// Translate implements memmap.Mappable.Translate.
-func (s *Shm) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+func (s *Shm) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) {
var err error
if required.End > s.fr.Length() {
err = &memmap.BusError{syserror.EFAULT}
@@ -519,7 +519,7 @@ func (s *Shm) Translate(ctx context.Context, required, optional memmap.MappableR
Source: source,
File: s.mfp.MemoryFile(),
Offset: s.fr.Start + source.Start,
- Perms: usermem.AnyAccess,
+ Perms: hostarch.AnyAccess,
},
}, err
}
@@ -543,7 +543,7 @@ type AttachOpts struct {
//
// Postconditions: The returned MMapOpts are valid only as long as a reference
// continues to be held on s.
-func (s *Shm) ConfigureAttach(ctx context.Context, addr usermem.Addr, opts AttachOpts) (memmap.MMapOpts, error) {
+func (s *Shm) ConfigureAttach(ctx context.Context, addr hostarch.Addr, opts AttachOpts) (memmap.MMapOpts, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.pendingDestruction && s.ReadRefs() == 0 {
@@ -565,12 +565,12 @@ func (s *Shm) ConfigureAttach(ctx context.Context, addr usermem.Addr, opts Attac
Offset: 0,
Addr: addr,
Fixed: opts.Remap,
- Perms: usermem.AccessType{
+ Perms: hostarch.AccessType{
Read: true,
Write: !opts.Readonly,
Execute: opts.Execute,
},
- MaxPerms: usermem.AnyAccess,
+ MaxPerms: hostarch.AnyAccess,
Mappable: s,
MappingIdentity: s,
}, nil
diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go
index 332bdb8e8..953d4310e 100644
--- a/pkg/sentry/kernel/syscalls.go
+++ b/pkg/sentry/kernel/syscalls.go
@@ -20,9 +20,9 @@ import (
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/usermem"
)
// maxSyscallNum is the highest supported syscall number.
@@ -243,7 +243,7 @@ type SyscallTable struct {
// Emulate is a collection of instruction addresses to emulate. The
// keys are addresses, and the values are system call numbers.
- Emulate map[usermem.Addr]uintptr
+ Emulate map[hostarch.Addr]uintptr
// The function to call in case of a missing system call.
Missing MissingFn
@@ -316,7 +316,7 @@ func (s *SyscallTable) Init() {
}
if s.Emulate == nil {
// Ensure non-nil emulate table.
- s.Emulate = make(map[usermem.Addr]uintptr)
+ s.Emulate = make(map[hostarch.Addr]uintptr)
}
max := s.MaxSysno() // Checked during RegisterSyscallTable.
@@ -359,7 +359,7 @@ func (s *SyscallTable) LookupNo(name string) (uintptr, error) {
}
// LookupEmulate looks up an emulation syscall number.
-func (s *SyscallTable) LookupEmulate(addr usermem.Addr) (uintptr, bool) {
+func (s *SyscallTable) LookupEmulate(addr hostarch.Addr) (uintptr, bool) {
sysno, ok := s.Emulate[addr]
return sysno, ok
}
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index 36141dd09..be1371855 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/inet"
@@ -33,7 +34,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -470,7 +470,7 @@ type Task struct {
// ThreadID to 0, and wake any futex waiters.
//
// cleartid is exclusive to the task goroutine.
- cleartid usermem.Addr
+ cleartid hostarch.Addr
// This is mostly a fake cpumask just for sched_set/getaffinity as we
// don't really control the affinity.
@@ -540,12 +540,12 @@ type Task struct {
// oldRSeqCPUAddr is a pointer to the userspace old rseq CPU variable.
//
// oldRSeqCPUAddr is exclusive to the task goroutine.
- oldRSeqCPUAddr usermem.Addr
+ oldRSeqCPUAddr hostarch.Addr
// rseqAddr is a pointer to the userspace linux.RSeq structure.
//
// rseqAddr is exclusive to the task goroutine.
- rseqAddr usermem.Addr
+ rseqAddr hostarch.Addr
// rseqSignature is the signature that the rseq abort IP must be signed
// with.
@@ -575,7 +575,7 @@ type Task struct {
// robustList is a pointer to the head of the tasks's robust futex
// list.
- robustList usermem.Addr
+ robustList hostarch.Addr
// startTime is the real time at which the task started. It is set when
// a Task is created or invokes execve(2).
@@ -587,6 +587,12 @@ type Task struct {
//
// kcov is exclusive to the task goroutine.
kcov *Kcov
+
+ // cgroups is the set of cgroups this task belongs to. This may be empty if
+ // no cgroup controllers are enabled. Protected by mu.
+ //
+ // +checklocks:mu
+ cgroups map[Cgroup]struct{}
}
func (t *Task) savePtraceTracer() *Task {
@@ -652,7 +658,7 @@ func (t *Task) Kernel() *Kernel {
// SetClearTID sets t's cleartid.
//
// Preconditions: The caller must be running on the task goroutine.
-func (t *Task) SetClearTID(addr usermem.Addr) {
+func (t *Task) SetClearTID(addr hostarch.Addr) {
t.cleartid = addr
}
diff --git a/pkg/sentry/kernel/task_cgroup.go b/pkg/sentry/kernel/task_cgroup.go
new file mode 100644
index 000000000..25d2504fa
--- /dev/null
+++ b/pkg/sentry/kernel/task_cgroup.go
@@ -0,0 +1,138 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// EnterInitialCgroups moves t into an initial set of cgroups.
+//
+// Precondition: t isn't in any cgroups yet, t.cgs is empty.
+//
+// +checklocksignore parent.mu is conditionally acquired.
+func (t *Task) EnterInitialCgroups(parent *Task) {
+ var inherit map[Cgroup]struct{}
+ if parent != nil {
+ parent.mu.Lock()
+ defer parent.mu.Unlock()
+ inherit = parent.cgroups
+ }
+ joinSet := t.k.cgroupRegistry.computeInitialGroups(inherit)
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ // Transfer ownership of joinSet refs to the task's cgset.
+ t.cgroups = joinSet
+ for c, _ := range t.cgroups {
+ // Since t isn't in any cgroup yet, we can skip the check against
+ // existing cgroups.
+ c.Enter(t)
+ }
+}
+
+// EnterCgroup moves t into c.
+func (t *Task) EnterCgroup(c Cgroup) error {
+ newControllers := make(map[CgroupControllerType]struct{})
+ for _, ctl := range c.Controllers() {
+ newControllers[ctl.Type()] = struct{}{}
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ for oldCG, _ := range t.cgroups {
+ for _, oldCtl := range oldCG.Controllers() {
+ if _, ok := newControllers[oldCtl.Type()]; ok {
+ // Already in a cgroup with the same controller as one of the
+ // new ones. Requires migration between cgroups.
+ //
+ // TODO(b/183137098): Implement cgroup migration.
+ log.Warningf("Cgroup migration is not implemented")
+ return syserror.EBUSY
+ }
+ }
+ }
+
+ // No migration required.
+ t.enterCgroupLocked(c)
+
+ return nil
+}
+
+// +checklocks:t.mu
+func (t *Task) enterCgroupLocked(c Cgroup) {
+ c.IncRef()
+ t.cgroups[c] = struct{}{}
+ c.Enter(t)
+}
+
+// LeaveCgroups removes t out from all its cgroups.
+func (t *Task) LeaveCgroups() {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ for c, _ := range t.cgroups {
+ t.leaveCgroupLocked(c)
+ }
+}
+
+// +checklocks:t.mu
+func (t *Task) leaveCgroupLocked(c Cgroup) {
+ c.Leave(t)
+ delete(t.cgroups, c)
+ c.decRef()
+}
+
+// taskCgroupEntry represents a line in /proc/<pid>/cgroup, and is used to
+// format a cgroup for display.
+type taskCgroupEntry struct {
+ hierarchyID uint32
+ controllers string
+ path string
+}
+
+// GenerateProcTaskCgroup writes the contents of /proc/<pid>/cgroup for t to buf.
+func (t *Task) GenerateProcTaskCgroup(buf *bytes.Buffer) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ cgEntries := make([]taskCgroupEntry, 0, len(t.cgroups))
+ for c, _ := range t.cgroups {
+ ctls := c.Controllers()
+ ctlNames := make([]string, 0, len(ctls))
+ for _, ctl := range ctls {
+ ctlNames = append(ctlNames, string(ctl.Type()))
+ }
+
+ cgEntries = append(cgEntries, taskCgroupEntry{
+ // Note: We're guaranteed to have at least one controller, and all
+ // controllers are guaranteed to be on the same hierarchy.
+ hierarchyID: ctls[0].HierarchyID(),
+ controllers: strings.Join(ctlNames, ","),
+ path: c.Path(),
+ })
+ }
+
+ sort.Slice(cgEntries, func(i, j int) bool { return cgEntries[i].hierarchyID > cgEntries[j].hierarchyID })
+ for _, cgE := range cgEntries {
+ fmt.Fprintf(buf, "%d:%s:%s\n", cgE.hierarchyID, cgE.controllers, cgE.path)
+ }
+}
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index f305e69c0..405771f3f 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
"gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -85,12 +86,12 @@ type CloneOptions struct {
// Stack is the initial stack pointer of the new task. If Stack is 0, the
// new task will start with the same stack pointer as its parent.
- Stack usermem.Addr
+ Stack hostarch.Addr
// If SetTLS is true, set the new task's TLS (thread-local storage)
// descriptor to TLS. If SetTLS is false, TLS is ignored.
SetTLS bool
- TLS usermem.Addr
+ TLS hostarch.Addr
// If ChildClearTID is true, when the child exits, 0 is written to the
// address ChildTID in the child's memory, and if the write is successful a
@@ -101,7 +102,7 @@ type CloneOptions struct {
// Linux, failed writes are silently ignored.)
ChildClearTID bool
ChildSetTID bool
- ChildTID usermem.Addr
+ ChildTID hostarch.Addr
// If ParentSetTID is true, the child's thread ID (in the parent's PID
// namespace) is written to address ParentTID in the parent's memory. (As
@@ -112,7 +113,7 @@ type CloneOptions struct {
// and child's memory, but this is a documentation error fixed by
// 87ab04792ced ("clone.2: Fix description of CLONE_PARENT_SETTID").
ParentSetTID bool
- ParentTID usermem.Addr
+ ParentTID hostarch.Addr
// If Vfork is true, place the parent in vforkStop until the cloned task
// releases its TaskImage.
@@ -268,7 +269,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
}
tg := t.tg
- rseqAddr := usermem.Addr(0)
+ rseqAddr := hostarch.Addr(0)
rseqSignature := uint32(0)
if opts.NewThreadGroup {
if tg.mounts != nil {
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index ad59e4f60..b1af1a7ef 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -275,6 +275,10 @@ func (*runExitMain) execute(t *Task) taskRunState {
t.fsContext.DecRef(t)
t.fdTable.DecRef(t)
+ // Detach task from all cgroups. This must happen before potentially the
+ // last ref to the cgroupfs mount is dropped below.
+ t.LeaveCgroups()
+
t.mu.Lock()
if t.mountNamespaceVFS2 != nil {
t.mountNamespaceVFS2.DecRef(t)
diff --git a/pkg/sentry/kernel/task_futex.go b/pkg/sentry/kernel/task_futex.go
index 195c7da9b..4dc41b82b 100644
--- a/pkg/sentry/kernel/task_futex.go
+++ b/pkg/sentry/kernel/task_futex.go
@@ -16,6 +16,7 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/usermem"
@@ -30,33 +31,33 @@ func (t *Task) Futex() *futex.Manager {
}
// SwapUint32 implements futex.Target.SwapUint32.
-func (t *Task) SwapUint32(addr usermem.Addr, new uint32) (uint32, error) {
+func (t *Task) SwapUint32(addr hostarch.Addr, new uint32) (uint32, error) {
return t.MemoryManager().SwapUint32(t, addr, new, usermem.IOOpts{
AddressSpaceActive: true,
})
}
// CompareAndSwapUint32 implements futex.Target.CompareAndSwapUint32.
-func (t *Task) CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error) {
+func (t *Task) CompareAndSwapUint32(addr hostarch.Addr, old, new uint32) (uint32, error) {
return t.MemoryManager().CompareAndSwapUint32(t, addr, old, new, usermem.IOOpts{
AddressSpaceActive: true,
})
}
// LoadUint32 implements futex.Target.LoadUint32.
-func (t *Task) LoadUint32(addr usermem.Addr) (uint32, error) {
+func (t *Task) LoadUint32(addr hostarch.Addr) (uint32, error) {
return t.MemoryManager().LoadUint32(t, addr, usermem.IOOpts{
AddressSpaceActive: true,
})
}
// GetSharedKey implements futex.Target.GetSharedKey.
-func (t *Task) GetSharedKey(addr usermem.Addr) (futex.Key, error) {
+func (t *Task) GetSharedKey(addr hostarch.Addr) (futex.Key, error) {
return t.MemoryManager().GetSharedFutexKey(t, addr)
}
// GetRobustList sets the robust futex list for the task.
-func (t *Task) GetRobustList() usermem.Addr {
+func (t *Task) GetRobustList() hostarch.Addr {
t.mu.Lock()
addr := t.robustList
t.mu.Unlock()
@@ -64,7 +65,7 @@ func (t *Task) GetRobustList() usermem.Addr {
}
// SetRobustList sets the robust futex list for the task.
-func (t *Task) SetRobustList(addr usermem.Addr) {
+func (t *Task) SetRobustList(addr hostarch.Addr) {
t.mu.Lock()
t.robustList = addr
t.mu.Unlock()
@@ -84,28 +85,28 @@ func (t *Task) exitRobustList() {
}
var rl linux.RobustListHead
- if _, err := rl.CopyIn(t, usermem.Addr(addr)); err != nil {
+ if _, err := rl.CopyIn(t, hostarch.Addr(addr)); err != nil {
return
}
next := primitive.Uint64(rl.List)
done := 0
- var pendingLockAddr usermem.Addr
+ var pendingLockAddr hostarch.Addr
if rl.ListOpPending != 0 {
- pendingLockAddr = usermem.Addr(rl.ListOpPending + rl.FutexOffset)
+ pendingLockAddr = hostarch.Addr(rl.ListOpPending + rl.FutexOffset)
}
// Wake up normal elements.
- for usermem.Addr(next) != addr {
+ for hostarch.Addr(next) != addr {
// We traverse to the next element of the list before we
// actually wake anything. This prevents the race where waking
// this futex causes a modification of the list.
- thisLockAddr := usermem.Addr(uint64(next) + rl.FutexOffset)
+ thisLockAddr := hostarch.Addr(uint64(next) + rl.FutexOffset)
// Try to decode the next element in the list before waking the
// current futex. But don't check the error until after we've
// woken the current futex. Linux does it in this order too
- _, nextErr := next.CopyIn(t, usermem.Addr(next))
+ _, nextErr := next.CopyIn(t, hostarch.Addr(next))
// Wakeup the current futex if it's not pending.
if thisLockAddr != pendingLockAddr {
@@ -133,7 +134,7 @@ func (t *Task) exitRobustList() {
}
// wakeRobustListOne wakes a single futex from the robust list.
-func (t *Task) wakeRobustListOne(addr usermem.Addr) {
+func (t *Task) wakeRobustListOne(addr hostarch.Addr) {
// Bit 0 in address signals PI futex.
pi := addr&1 == 1
addr = addr &^ 1
diff --git a/pkg/sentry/kernel/task_image.go b/pkg/sentry/kernel/task_image.go
index ce5fbd299..bd5543d4e 100644
--- a/pkg/sentry/kernel/task_image.go
+++ b/pkg/sentry/kernel/task_image.go
@@ -19,12 +19,12 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/sentry/loader"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/usermem"
)
var errNoSyscalls = syserr.New("no syscall table found", linux.ENOEXEC)
@@ -129,7 +129,7 @@ func (t *Task) Stack() *arch.Stack {
return &arch.Stack{
Arch: t.Arch(),
IO: t.MemoryManager(),
- Bottom: usermem.Addr(t.Arch().Stack()),
+ Bottom: hostarch.Addr(t.Arch().Stack()),
}
}
diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go
index c70e5e6ce..72b9a0384 100644
--- a/pkg/sentry/kernel/task_log.go
+++ b/pkg/sentry/kernel/task_log.go
@@ -20,6 +20,7 @@ import (
"sort"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -108,9 +109,9 @@ func (t *Task) debugDumpStack() {
return
}
t.Debugf("Stack:")
- start := usermem.Addr(t.Arch().Stack())
+ start := hostarch.Addr(t.Arch().Stack())
// Round addr down to a 16-byte boundary.
- start &= ^usermem.Addr(15)
+ start &= ^hostarch.Addr(15)
// Print 16 bytes per line, one byte at a time.
for offset := uint64(0); offset < maxStackDebugBytes; offset += 16 {
addr, ok := start.AddLength(offset)
@@ -127,7 +128,7 @@ func (t *Task) debugDumpStack() {
t.Debugf("%x: % x", addr, data[:n])
}
if err != nil {
- t.Debugf("Error reading stack at address %x: %v", addr+usermem.Addr(n), err)
+ t.Debugf("Error reading stack at address %x: %v", addr+hostarch.Addr(n), err)
break
}
}
@@ -147,9 +148,9 @@ func (t *Task) debugDumpCode() {
}
t.Debugf("Code:")
// Print code on both sides of the instruction register.
- start := usermem.Addr(t.Arch().IP()) - maxCodeDebugBytes/2
+ start := hostarch.Addr(t.Arch().IP()) - maxCodeDebugBytes/2
// Round addr down to a 16-byte boundary.
- start &= ^usermem.Addr(15)
+ start &= ^hostarch.Addr(15)
// Print 16 bytes per line, one byte at a time.
for offset := uint64(0); offset < maxCodeDebugBytes; offset += 16 {
addr, ok := start.AddLength(offset)
@@ -166,7 +167,7 @@ func (t *Task) debugDumpCode() {
t.Debugf("%x: % x", addr, data[:n])
}
if err != nil {
- t.Debugf("Error reading stack at address %x: %v", addr+usermem.Addr(n), err)
+ t.Debugf("Error reading stack at address %x: %v", addr+hostarch.Addr(n), err)
break
}
}
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
index 3ccecf4b6..068f25af1 100644
--- a/pkg/sentry/kernel/task_run.go
+++ b/pkg/sentry/kernel/task_run.go
@@ -23,13 +23,13 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/goid"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/hostcpu"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// A taskRunState is a reified state in the task state machine. See README.md
@@ -148,7 +148,7 @@ func (*runApp) handleCPUIDInstruction(t *Task) error {
region := trace.StartRegion(t.traceContext, cpuidRegion)
expected := arch.CPUIDInstruction[:]
found := make([]byte, len(expected))
- _, err := t.CopyInBytes(usermem.Addr(t.Arch().IP()), found)
+ _, err := t.CopyInBytes(hostarch.Addr(t.Arch().IP()), found)
if err == nil && bytes.Equal(expected, found) {
// Skip the cpuid instruction.
t.Arch().CPUIDEmulate(t)
@@ -307,8 +307,8 @@ func (app *runApp) execute(t *Task) taskRunState {
// normally.
if at.Any() {
region := trace.StartRegion(t.traceContext, faultRegion)
- addr := usermem.Addr(info.Addr())
- err := t.MemoryManager().HandleUserFault(t, addr, at, usermem.Addr(t.Arch().Stack()))
+ addr := hostarch.Addr(info.Addr())
+ err := t.MemoryManager().HandleUserFault(t, addr, at, hostarch.Addr(t.Arch().Stack()))
region.End()
if err == nil {
// The fault was handled appropriately.
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
index 75af3af79..c2b9fc08f 100644
--- a/pkg/sentry/kernel/task_signals.go
+++ b/pkg/sentry/kernel/task_signals.go
@@ -23,11 +23,11 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/eventchannel"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -243,7 +243,7 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct)
// Are executing on the main stack,
// or the provided alternate stack?
- sp := usermem.Addr(t.Arch().Stack())
+ sp := hostarch.Addr(t.Arch().Stack())
// N.B. This is a *copy* of the alternate stack that the user's signal
// handler expects to see in its ucontext (even if it's not in use).
@@ -251,7 +251,7 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct)
if act.IsOnStack() && alt.IsEnabled() {
alt.SetOnStack()
if !alt.Contains(sp) {
- sp = usermem.Addr(alt.Top())
+ sp = hostarch.Addr(alt.Top())
}
}
@@ -652,7 +652,7 @@ func (t *Task) SignalStack() arch.SignalStack {
// onSignalStack returns true if the task is executing on the given signal stack.
func (t *Task) onSignalStack(alt arch.SignalStack) bool {
- sp := usermem.Addr(t.Arch().Stack())
+ sp := hostarch.Addr(t.Arch().Stack())
return alt.Contains(sp)
}
@@ -720,7 +720,7 @@ func (tg *ThreadGroup) SetSignalAct(sig linux.Signal, actptr *arch.SignalAct) (a
// CopyOutSignalAct converts the given SignalAct into an architecture-specific
// type and then copies it out to task memory.
-func (t *Task) CopyOutSignalAct(addr usermem.Addr, s *arch.SignalAct) error {
+func (t *Task) CopyOutSignalAct(addr hostarch.Addr, s *arch.SignalAct) error {
n := t.Arch().NewSignalAct()
n.SerializeFrom(s)
_, err := n.CopyOut(t, addr)
@@ -729,7 +729,7 @@ func (t *Task) CopyOutSignalAct(addr usermem.Addr, s *arch.SignalAct) error {
// CopyInSignalAct copies an architecture-specific sigaction type from task
// memory and then converts it into a SignalAct.
-func (t *Task) CopyInSignalAct(addr usermem.Addr) (arch.SignalAct, error) {
+func (t *Task) CopyInSignalAct(addr hostarch.Addr) (arch.SignalAct, error) {
n := t.Arch().NewSignalAct()
var s arch.SignalAct
if _, err := n.CopyIn(t, addr); err != nil {
@@ -741,7 +741,7 @@ func (t *Task) CopyInSignalAct(addr usermem.Addr) (arch.SignalAct, error) {
// CopyOutSignalStack converts the given SignalStack into an
// architecture-specific type and then copies it out to task memory.
-func (t *Task) CopyOutSignalStack(addr usermem.Addr, s *arch.SignalStack) error {
+func (t *Task) CopyOutSignalStack(addr hostarch.Addr, s *arch.SignalStack) error {
n := t.Arch().NewSignalStack()
n.SerializeFrom(s)
_, err := n.CopyOut(t, addr)
@@ -750,7 +750,7 @@ func (t *Task) CopyOutSignalStack(addr usermem.Addr, s *arch.SignalStack) error
// CopyInSignalStack copies an architecture-specific stack_t from task memory
// and then converts it into a SignalStack.
-func (t *Task) CopyInSignalStack(addr usermem.Addr) (arch.SignalStack, error) {
+func (t *Task) CopyInSignalStack(addr hostarch.Addr) (arch.SignalStack, error) {
n := t.Arch().NewSignalStack()
var s arch.SignalStack
if _, err := n.CopyIn(t, addr); err != nil {
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index 36e1384f1..32031cd70 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -17,6 +17,7 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -25,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// TaskConfig defines the configuration of a new Task (see below).
@@ -86,7 +86,7 @@ type TaskConfig struct {
MountNamespaceVFS2 *vfs.MountNamespace
// RSeqAddr is a pointer to the the userspace linux.RSeq structure.
- RSeqAddr usermem.Addr
+ RSeqAddr hostarch.Addr
// RSeqSignature is the signature that the rseq abort IP must be signed
// with.
@@ -151,6 +151,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
rseqSignature: cfg.RSeqSignature,
futexWaiter: futex.NewWaiter(),
containerID: cfg.ContainerID,
+ cgroups: make(map[Cgroup]struct{}),
}
t.creds.Store(cfg.Credentials)
t.endStopCond.L = &t.tg.signalHandlers.mu
@@ -189,6 +190,10 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
t.parent.children[t] = struct{}{}
}
+ if VFS2Enabled {
+ t.EnterInitialCgroups(t.parent)
+ }
+
if tg.leader == nil {
// New thread group.
tg.leader = t
diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go
index 2e84bd88a..2c658d001 100644
--- a/pkg/sentry/kernel/task_syscall.go
+++ b/pkg/sentry/kernel/task_syscall.go
@@ -22,12 +22,12 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
var vsyscallCount = metric.MustCreateNewUint64Metric("/kernel/vsyscall_count", false /* sync */, "Number of times vsyscalls were invoked by the application")
@@ -153,7 +153,7 @@ func (t *Task) doSyscall() taskRunState {
// Check seccomp filters. The nil check is for performance (as seccomp use
// is rare), not needed for correctness.
if t.syscallFilters.Load() != nil {
- switch r := t.checkSeccompSyscall(int32(sysno), args, usermem.Addr(t.Arch().IP())); r {
+ switch r := t.checkSeccompSyscall(int32(sysno), args, hostarch.Addr(t.Arch().IP())); r {
case linux.SECCOMP_RET_ERRNO, linux.SECCOMP_RET_TRAP:
t.Debugf("Syscall %d: denied by seccomp", sysno)
return (*runSyscallExit)(nil)
@@ -283,12 +283,12 @@ func (*runSyscallExit) execute(t *Task) taskRunState {
// doVsyscall is the entry point for a vsyscall invocation of syscall sysno, as
// indicated by an execution fault at address addr. doVsyscall returns the
// task's next run state.
-func (t *Task) doVsyscall(addr usermem.Addr, sysno uintptr) taskRunState {
+func (t *Task) doVsyscall(addr hostarch.Addr, sysno uintptr) taskRunState {
vsyscallCount.Increment()
// Grab the caller up front, to make sure there's a sensible stack.
caller := t.Arch().Native(uintptr(0))
- if _, err := caller.CopyIn(t, usermem.Addr(t.Arch().Stack())); err != nil {
+ if _, err := caller.CopyIn(t, hostarch.Addr(t.Arch().Stack())); err != nil {
t.Debugf("vsyscall %d: error reading return address from stack: %v", sysno, err)
t.forceSignal(linux.SIGSEGV, false /* unconditional */)
t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
@@ -322,7 +322,7 @@ func (t *Task) doVsyscall(addr usermem.Addr, sysno uintptr) taskRunState {
}
type runVsyscallAfterPtraceEventSeccomp struct {
- addr usermem.Addr
+ addr hostarch.Addr
sysno uintptr
caller marshal.Marshallable
}
@@ -337,7 +337,7 @@ func (r *runVsyscallAfterPtraceEventSeccomp) execute(t *Task) taskRunState {
// currently emulated call. ... The tracer MUST NOT modify rip or rsp." -
// Documentation/prctl/seccomp_filter.txt. On Linux, changing orig_ax or ip
// causes do_exit(SIGSYS), and changing sp is ignored.
- if (sysno != ^uintptr(0) && sysno != r.sysno) || usermem.Addr(t.Arch().IP()) != r.addr {
+ if (sysno != ^uintptr(0) && sysno != r.sysno) || hostarch.Addr(t.Arch().IP()) != r.addr {
t.PrepareExit(ExitStatus{Signo: int(linux.SIGSYS)})
return (*runExit)(nil)
}
diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go
index 94dabbcd8..fc6d9438a 100644
--- a/pkg/sentry/kernel/task_usermem.go
+++ b/pkg/sentry/kernel/task_usermem.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -27,7 +28,7 @@ import (
// MAX_RW_COUNT is the maximum size in bytes of a single read or write.
// Reads and writes that exceed this size may be silently truncated.
// (Linux: include/linux/fs.h:MAX_RW_COUNT)
-var MAX_RW_COUNT = int(usermem.Addr(math.MaxInt32).RoundDown())
+var MAX_RW_COUNT = int(hostarch.Addr(math.MaxInt32).RoundDown())
// Activate ensures that the task has an active address space.
func (t *Task) Activate() {
@@ -49,7 +50,7 @@ func (t *Task) Deactivate() {
// data without reflection and pass in a byte slice.
//
// This Task's AddressSpace must be active.
-func (t *Task) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) {
+func (t *Task) CopyInBytes(addr hostarch.Addr, dst []byte) (int, error) {
return t.MemoryManager().CopyIn(t, addr, dst, usermem.IOOpts{
AddressSpaceActive: true,
})
@@ -59,7 +60,7 @@ func (t *Task) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) {
// data without reflection and pass in a byte slice.
//
// This Task's AddressSpace must be active.
-func (t *Task) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) {
+func (t *Task) CopyOutBytes(addr hostarch.Addr, src []byte) (int, error) {
return t.MemoryManager().CopyOut(t, addr, src, usermem.IOOpts{
AddressSpaceActive: true,
})
@@ -70,7 +71,7 @@ func (t *Task) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) {
// user memory that is unmapped or not readable by the user.
//
// This Task's AddressSpace must be active.
-func (t *Task) CopyInString(addr usermem.Addr, maxlen int) (string, error) {
+func (t *Task) CopyInString(addr hostarch.Addr, maxlen int) (string, error) {
return usermem.CopyStringIn(t, t.MemoryManager(), addr, maxlen, usermem.IOOpts{
AddressSpaceActive: true,
})
@@ -90,7 +91,7 @@ func (t *Task) CopyInString(addr usermem.Addr, maxlen int) (string, error) {
// { "abc" } => 4 (3 for length, 1 for elements)
//
// This Task's AddressSpace must be active.
-func (t *Task) CopyInVector(addr usermem.Addr, maxElemSize, maxTotalSize int) ([]string, error) {
+func (t *Task) CopyInVector(addr hostarch.Addr, maxElemSize, maxTotalSize int) ([]string, error) {
var v []string
for {
argAddr := t.Arch().Native(0)
@@ -109,12 +110,12 @@ func (t *Task) CopyInVector(addr usermem.Addr, maxElemSize, maxTotalSize int) ([
if maxTotalSize < thisMax {
thisMax = maxTotalSize
}
- arg, err := t.CopyInString(usermem.Addr(t.Arch().Value(argAddr)), thisMax)
+ arg, err := t.CopyInString(hostarch.Addr(t.Arch().Value(argAddr)), thisMax)
if err != nil {
return v, err
}
v = append(v, arg)
- addr += usermem.Addr(t.Arch().Width())
+ addr += hostarch.Addr(t.Arch().Width())
maxTotalSize -= len(arg) + 1
}
return v, nil
@@ -126,7 +127,7 @@ func (t *Task) CopyInVector(addr usermem.Addr, maxElemSize, maxTotalSize int) ([
// Preconditions: Same as usermem.IO.CopyOut, plus:
// * The caller must be running on the task goroutine.
// * t's AddressSpace must be active.
-func (t *Task) CopyOutIovecs(addr usermem.Addr, src usermem.AddrRangeSeq) error {
+func (t *Task) CopyOutIovecs(addr hostarch.Addr, src hostarch.AddrRangeSeq) error {
switch t.Arch().Width() {
case 8:
const itemLen = 16
@@ -137,8 +138,8 @@ func (t *Task) CopyOutIovecs(addr usermem.Addr, src usermem.AddrRangeSeq) error
b := t.CopyScratchBuffer(itemLen)
for ; !src.IsEmpty(); src = src.Tail() {
ar := src.Head()
- usermem.ByteOrder.PutUint64(b[0:8], uint64(ar.Start))
- usermem.ByteOrder.PutUint64(b[8:16], uint64(ar.Length()))
+ hostarch.ByteOrder.PutUint64(b[0:8], uint64(ar.Start))
+ hostarch.ByteOrder.PutUint64(b[8:16], uint64(ar.Length()))
if _, err := t.CopyOutBytes(addr, b); err != nil {
return err
}
@@ -153,8 +154,8 @@ func (t *Task) CopyOutIovecs(addr usermem.Addr, src usermem.AddrRangeSeq) error
}
// CopyInIovecs copies an array of numIovecs struct iovecs from the memory
-// mapped at addr, converts them to usermem.AddrRanges, and returns them as a
-// usermem.AddrRangeSeq.
+// mapped at addr, converts them to hostarch.AddrRanges, and returns them as a
+// hostarch.AddrRangeSeq.
//
// CopyInIovecs shares the following properties with Linux's
// lib/iov_iter.c:import_iovec() => fs/read_write.c:rw_copy_check_uvector():
@@ -175,42 +176,42 @@ func (t *Task) CopyOutIovecs(addr usermem.Addr, src usermem.AddrRangeSeq) error
// Preconditions: Same as usermem.IO.CopyIn, plus:
// * The caller must be running on the task goroutine.
// * t's AddressSpace must be active.
-func (t *Task) CopyInIovecs(addr usermem.Addr, numIovecs int) (usermem.AddrRangeSeq, error) {
+func (t *Task) CopyInIovecs(addr hostarch.Addr, numIovecs int) (hostarch.AddrRangeSeq, error) {
if numIovecs == 0 {
- return usermem.AddrRangeSeq{}, nil
+ return hostarch.AddrRangeSeq{}, nil
}
- var dst []usermem.AddrRange
+ var dst []hostarch.AddrRange
if numIovecs > 1 {
- dst = make([]usermem.AddrRange, 0, numIovecs)
+ dst = make([]hostarch.AddrRange, 0, numIovecs)
}
switch t.Arch().Width() {
case 8:
const itemLen = 16
if _, ok := addr.AddLength(uint64(numIovecs) * itemLen); !ok {
- return usermem.AddrRangeSeq{}, syserror.EFAULT
+ return hostarch.AddrRangeSeq{}, syserror.EFAULT
}
b := t.CopyScratchBuffer(itemLen)
for i := 0; i < numIovecs; i++ {
if _, err := t.CopyInBytes(addr, b); err != nil {
- return usermem.AddrRangeSeq{}, err
+ return hostarch.AddrRangeSeq{}, err
}
- base := usermem.Addr(usermem.ByteOrder.Uint64(b[0:8]))
- length := usermem.ByteOrder.Uint64(b[8:16])
+ base := hostarch.Addr(hostarch.ByteOrder.Uint64(b[0:8]))
+ length := hostarch.ByteOrder.Uint64(b[8:16])
if length > math.MaxInt64 {
- return usermem.AddrRangeSeq{}, syserror.EINVAL
+ return hostarch.AddrRangeSeq{}, syserror.EINVAL
}
ar, ok := t.MemoryManager().CheckIORange(base, int64(length))
if !ok {
- return usermem.AddrRangeSeq{}, syserror.EFAULT
+ return hostarch.AddrRangeSeq{}, syserror.EFAULT
}
if numIovecs == 1 {
// Special case to avoid allocating dst.
- return usermem.AddrRangeSeqOf(ar).TakeFirst(MAX_RW_COUNT), nil
+ return hostarch.AddrRangeSeqOf(ar).TakeFirst(MAX_RW_COUNT), nil
}
dst = append(dst, ar)
@@ -218,7 +219,7 @@ func (t *Task) CopyInIovecs(addr usermem.Addr, numIovecs int) (usermem.AddrRange
}
default:
- return usermem.AddrRangeSeq{}, syserror.ENOSYS
+ return hostarch.AddrRangeSeq{}, syserror.ENOSYS
}
// Truncate to MAX_RW_COUNT.
@@ -226,13 +227,13 @@ func (t *Task) CopyInIovecs(addr usermem.Addr, numIovecs int) (usermem.AddrRange
for i := range dst {
dstlen := uint64(dst[i].Length())
if rem := uint64(MAX_RW_COUNT) - total; rem < dstlen {
- dst[i].End -= usermem.Addr(dstlen - rem)
+ dst[i].End -= hostarch.Addr(dstlen - rem)
dstlen = rem
}
total += dstlen
}
- return usermem.AddrRangeSeqFromSlice(dst), nil
+ return hostarch.AddrRangeSeqFromSlice(dst), nil
}
// SingleIOSequence returns a usermem.IOSequence representing [addr,
@@ -245,7 +246,7 @@ func (t *Task) CopyInIovecs(addr usermem.Addr, numIovecs int) (usermem.AddrRange
// write syscalls in Linux do not use import_single_range(). However they check
// access_ok() in fs/read_write.c:vfs_read/vfs_write, and overflowing address
// ranges are truncated to MAX_RW_COUNT by fs/read_write.c:rw_verify_area().)
-func (t *Task) SingleIOSequence(addr usermem.Addr, length int, opts usermem.IOOpts) (usermem.IOSequence, error) {
+func (t *Task) SingleIOSequence(addr hostarch.Addr, length int, opts usermem.IOOpts) (usermem.IOSequence, error) {
if length > MAX_RW_COUNT {
length = MAX_RW_COUNT
}
@@ -255,7 +256,7 @@ func (t *Task) SingleIOSequence(addr usermem.Addr, length int, opts usermem.IOOp
}
return usermem.IOSequence{
IO: t.MemoryManager(),
- Addrs: usermem.AddrRangeSeqOf(ar),
+ Addrs: hostarch.AddrRangeSeqOf(ar),
Opts: opts,
}, nil
}
@@ -267,7 +268,7 @@ func (t *Task) SingleIOSequence(addr usermem.Addr, length int, opts usermem.IOOp
// IovecsIOSequence is analogous to Linux's lib/iov_iter.c:import_iovec().
//
// Preconditions: Same as Task.CopyInIovecs.
-func (t *Task) IovecsIOSequence(addr usermem.Addr, iovcnt int, opts usermem.IOOpts) (usermem.IOSequence, error) {
+func (t *Task) IovecsIOSequence(addr hostarch.Addr, iovcnt int, opts usermem.IOOpts) (usermem.IOSequence, error) {
if iovcnt < 0 || iovcnt > linux.UIO_MAXIOV {
return usermem.IOSequence{}, syserror.EINVAL
}
@@ -317,7 +318,7 @@ func (cc *taskCopyContext) getMemoryManager() (*mm.MemoryManager, error) {
}
// CopyInBytes implements marshal.CopyContext.CopyInBytes.
-func (cc *taskCopyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) {
+func (cc *taskCopyContext) CopyInBytes(addr hostarch.Addr, dst []byte) (int, error) {
tmm, err := cc.getMemoryManager()
if err != nil {
return 0, err
@@ -327,7 +328,7 @@ func (cc *taskCopyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, erro
}
// CopyOutBytes implements marshal.CopyContext.CopyOutBytes.
-func (cc *taskCopyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) {
+func (cc *taskCopyContext) CopyOutBytes(addr hostarch.Addr, src []byte) (int, error) {
tmm, err := cc.getMemoryManager()
if err != nil {
return 0, err
@@ -360,11 +361,11 @@ func (cc *ownTaskCopyContext) CopyScratchBuffer(size int) []byte {
}
// CopyInBytes implements marshal.CopyContext.CopyInBytes.
-func (cc *ownTaskCopyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) {
+func (cc *ownTaskCopyContext) CopyInBytes(addr hostarch.Addr, dst []byte) (int, error) {
return cc.t.MemoryManager().CopyIn(cc.t, addr, dst, cc.opts)
}
// CopyOutBytes implements marshal.CopyContext.CopyOutBytes.
-func (cc *ownTaskCopyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) {
+func (cc *ownTaskCopyContext) CopyOutBytes(addr hostarch.Addr, src []byte) (int, error) {
return cc.t.MemoryManager().CopyOut(cc.t, addr, src, cc.opts)
}
diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go
index 09d070ec8..77ad62445 100644
--- a/pkg/sentry/kernel/threads.go
+++ b/pkg/sentry/kernel/threads.go
@@ -114,6 +114,15 @@ func (ts *TaskSet) forEachThreadGroupLocked(f func(tg *ThreadGroup)) {
}
}
+// forEachTaskLocked applies f to each Task in ts.
+//
+// Preconditions: ts.mu must be locked (for reading or writing).
+func (ts *TaskSet) forEachTaskLocked(f func(t *Task)) {
+ for t := range ts.Root.tids {
+ f(t)
+ }
+}
+
// A PIDNamespace represents a PID namespace, a bimap between thread IDs and
// tasks. See the pid_namespaces(7) man page for further details.
//
diff --git a/pkg/sentry/kernel/timekeeper_test.go b/pkg/sentry/kernel/timekeeper_test.go
index cf2f7ca72..dfc3c0719 100644
--- a/pkg/sentry/kernel/timekeeper_test.go
+++ b/pkg/sentry/kernel/timekeeper_test.go
@@ -17,12 +17,12 @@ package kernel
import (
"testing"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
sentrytime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// mockClocks is a sentrytime.Clocks that simply returns the times in the
@@ -54,7 +54,7 @@ func (c *mockClocks) GetTime(id sentrytime.ClockID) (int64, error) {
func stateTestClocklessTimekeeper(tb testing.TB) *Timekeeper {
ctx := contexttest.Context(tb)
mfp := pgalloc.MemoryFileProviderFromContext(ctx)
- fr, err := mfp.MemoryFile().Allocate(usermem.PageSize, usage.Anonymous)
+ fr, err := mfp.MemoryFile().Allocate(hostarch.PageSize, usage.Anonymous)
if err != nil {
tb.Fatalf("failed to allocate memory: %v", err)
}
diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go
index 9e5c2d26f..cc0917504 100644
--- a/pkg/sentry/kernel/vdso.go
+++ b/pkg/sentry/kernel/vdso.go
@@ -17,10 +17,10 @@ package kernel
import (
"fmt"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/usermem"
)
// vdsoParams are the parameters exposed to the VDSO.
@@ -96,7 +96,7 @@ func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *VDSO
// access returns a mapping of the param page.
func (v *VDSOParamPage) access() (safemem.Block, error) {
- bs, err := v.mfp.MemoryFile().MapInternal(v.fr, usermem.ReadWrite)
+ bs, err := v.mfp.MemoryFile().MapInternal(v.fr, hostarch.ReadWrite)
if err != nil {
return safemem.Block{}, err
}
diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD
index ab074b400..ecb6603a1 100644
--- a/pkg/sentry/loader/BUILD
+++ b/pkg/sentry/loader/BUILD
@@ -18,6 +18,7 @@ go_library(
"//pkg/binary",
"//pkg/context",
"//pkg/cpuid",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/rand",
"//pkg/safemem",
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index cd9fa4031..e92d9fdc3 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
@@ -41,7 +42,7 @@ const (
// maxTotalPhdrSize is the maximum combined size of all program
// headers. Linux limits this to one page.
- maxTotalPhdrSize = usermem.PageSize
+ maxTotalPhdrSize = hostarch.PageSize
)
var (
@@ -52,8 +53,8 @@ var (
prog64Size = int(binary.Size(elf.Prog64{}))
)
-func progFlagsAsPerms(f elf.ProgFlag) usermem.AccessType {
- var p usermem.AccessType
+func progFlagsAsPerms(f elf.ProgFlag) hostarch.AccessType {
+ var p hostarch.AccessType
if f&elf.PF_R == elf.PF_R {
p.Read = true
}
@@ -75,7 +76,7 @@ type elfInfo struct {
arch arch.Arch
// entry is the program entry point.
- entry usermem.Addr
+ entry hostarch.Addr
// phdrs are the program headers.
phdrs []elf.ProgHeader
@@ -230,7 +231,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
return elfInfo{
os: os,
arch: a,
- entry: usermem.Addr(hdr.Entry),
+ entry: hostarch.Addr(hdr.Entry),
phdrs: phdrs,
phdrOff: hdr.Phoff,
phdrSize: prog64Size,
@@ -240,9 +241,9 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
// mapSegment maps a phdr into the Task. offset is the offset to apply to
// phdr.Vaddr.
-func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr *elf.ProgHeader, offset usermem.Addr) error {
+func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr *elf.ProgHeader, offset hostarch.Addr) error {
// We must make a page-aligned mapping.
- adjust := usermem.Addr(phdr.Vaddr).PageOffset()
+ adjust := hostarch.Addr(phdr.Vaddr).PageOffset()
addr, ok := offset.AddLength(phdr.Vaddr)
if !ok {
@@ -250,14 +251,14 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr
ctx.Warningf("Computed segment load address overflows: %#x + %#x", phdr.Vaddr, offset)
return syserror.ENOEXEC
}
- addr -= usermem.Addr(adjust)
+ addr -= hostarch.Addr(adjust)
fileSize := phdr.Filesz + adjust
if fileSize < phdr.Filesz {
ctx.Infof("Computed segment file size overflows: %#x + %#x", phdr.Filesz, adjust)
return syserror.ENOEXEC
}
- ms, ok := usermem.Addr(fileSize).RoundUp()
+ ms, ok := hostarch.Addr(fileSize).RoundUp()
if !ok {
ctx.Infof("fileSize %#x too large", fileSize)
return syserror.ENOEXEC
@@ -281,7 +282,7 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr
Unmap: true,
Private: true,
Perms: prot,
- MaxPerms: usermem.AnyAccess,
+ MaxPerms: hostarch.AnyAccess,
}
defer func() {
if mopts.MappingIdentity != nil {
@@ -312,7 +313,7 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr
panic(fmt.Sprintf("zeroSize too big? %#x", uint64(zeroSize)))
}
if _, err := m.ZeroOut(ctx, zeroAddr, zeroSize, usermem.IOOpts{IgnorePermissions: true}); err != nil {
- ctx.Warningf("Failed to zero end of page [%#x, %#x): %v", zeroAddr, zeroAddr+usermem.Addr(zeroSize), err)
+ ctx.Warningf("Failed to zero end of page [%#x, %#x): %v", zeroAddr, zeroAddr+hostarch.Addr(zeroSize), err)
return err
}
}
@@ -330,7 +331,7 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr
if !ok {
panic(fmt.Sprintf("anonymous memory doesn't fit in pre-sized range? %#x + %#x", addr, mapSize))
}
- anonSize, ok := usermem.Addr(memSize - mapSize).RoundUp()
+ anonSize, ok := hostarch.Addr(memSize - mapSize).RoundUp()
if !ok {
ctx.Infof("extra anon pages too large: %#x", memSize-mapSize)
return syserror.ENOEXEC
@@ -339,7 +340,7 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr
// N.B. Linux uses vm_brk_flags to map these pages, which only
// honors the X bit, always mapping at least RW. ignoring These
// pages are not included in the final brk region.
- prot := usermem.ReadWrite
+ prot := hostarch.ReadWrite
if phdr.Flags&elf.PF_X == elf.PF_X {
prot.Execute = true
}
@@ -352,7 +353,7 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr
Fixed: true,
Private: true,
Perms: prot,
- MaxPerms: usermem.AnyAccess,
+ MaxPerms: hostarch.AnyAccess,
}); err != nil {
ctx.Infof("Error mapping PT_LOAD segment %v anonymous memory: %v", phdr, err)
return err
@@ -371,19 +372,19 @@ type loadedELF struct {
arch arch.Arch
// entry is the entry point of the ELF.
- entry usermem.Addr
+ entry hostarch.Addr
// start is the end of the ELF.
- start usermem.Addr
+ start hostarch.Addr
// end is the end of the ELF.
- end usermem.Addr
+ end hostarch.Addr
// interpter is the path to the ELF interpreter.
interpreter string
// phdrAddr is the address of the ELF program headers.
- phdrAddr usermem.Addr
+ phdrAddr hostarch.Addr
// phdrSize is the size of a single program header in the ELF.
phdrSize int
@@ -407,14 +408,14 @@ type loadedELF struct {
// It does not load the ELF interpreter, or return any auxv entries.
//
// Preconditions: f is an ELF file.
-func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, info elfInfo, sharedLoadOffset usermem.Addr) (loadedELF, error) {
+func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, info elfInfo, sharedLoadOffset hostarch.Addr) (loadedELF, error) {
first := true
- var start, end usermem.Addr
+ var start, end hostarch.Addr
var interpreter string
for _, phdr := range info.phdrs {
switch phdr.Type {
case elf.PT_LOAD:
- vaddr := usermem.Addr(phdr.Vaddr)
+ vaddr := hostarch.Addr(phdr.Vaddr)
if first {
first = false
start = vaddr
@@ -492,7 +493,7 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, in
// Note that the vaddr of the first PT_LOAD segment is ignored when
// choosing the load address (even if it is non-zero). The vaddr does
// become an offset from that load address.
- var offset usermem.Addr
+ var offset hostarch.Addr
if info.sharedObject {
totalSize := end - start
totalSize, ok := totalSize.RoundUp()
@@ -688,8 +689,8 @@ func loadELF(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, error
// ELF-specific auxv entries.
bin.auxv = arch.Auxv{
arch.AuxEntry{linux.AT_PHDR, bin.phdrAddr},
- arch.AuxEntry{linux.AT_PHENT, usermem.Addr(bin.phdrSize)},
- arch.AuxEntry{linux.AT_PHNUM, usermem.Addr(bin.phdrNum)},
+ arch.AuxEntry{linux.AT_PHENT, hostarch.Addr(bin.phdrSize)},
+ arch.AuxEntry{linux.AT_PHNUM, hostarch.Addr(bin.phdrNum)},
arch.AuxEntry{linux.AT_ENTRY, bin.entry},
}
if bin.interpreter != "" {
diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go
index c69b62db9..47e3775a3 100644
--- a/pkg/sentry/loader/loader.go
+++ b/pkg/sentry/loader/loader.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
@@ -266,17 +267,17 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V
// Add generic auxv entries.
auxv := append(loaded.auxv, arch.Auxv{
- arch.AuxEntry{linux.AT_UID, usermem.Addr(c.RealKUID.In(c.UserNamespace).OrOverflow())},
- arch.AuxEntry{linux.AT_EUID, usermem.Addr(c.EffectiveKUID.In(c.UserNamespace).OrOverflow())},
- arch.AuxEntry{linux.AT_GID, usermem.Addr(c.RealKGID.In(c.UserNamespace).OrOverflow())},
- arch.AuxEntry{linux.AT_EGID, usermem.Addr(c.EffectiveKGID.In(c.UserNamespace).OrOverflow())},
+ arch.AuxEntry{linux.AT_UID, hostarch.Addr(c.RealKUID.In(c.UserNamespace).OrOverflow())},
+ arch.AuxEntry{linux.AT_EUID, hostarch.Addr(c.EffectiveKUID.In(c.UserNamespace).OrOverflow())},
+ arch.AuxEntry{linux.AT_GID, hostarch.Addr(c.RealKGID.In(c.UserNamespace).OrOverflow())},
+ arch.AuxEntry{linux.AT_EGID, hostarch.Addr(c.EffectiveKGID.In(c.UserNamespace).OrOverflow())},
// The conditions that require AT_SECURE = 1 never arise. See
// kernel.Task.updateCredsForExecLocked.
arch.AuxEntry{linux.AT_SECURE, 0},
arch.AuxEntry{linux.AT_CLKTCK, linux.CLOCKS_PER_SEC},
arch.AuxEntry{linux.AT_EXECFN, execfn},
arch.AuxEntry{linux.AT_RANDOM, random},
- arch.AuxEntry{linux.AT_PAGESZ, usermem.PageSize},
+ arch.AuxEntry{linux.AT_PAGESZ, hostarch.PageSize},
arch.AuxEntry{linux.AT_SYSINFO_EHDR, vdsoAddr},
}...)
auxv = append(auxv, extraAuxv...)
diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go
index a32d37d62..fd54261fd 100644
--- a/pkg/sentry/loader/vdso.go
+++ b/pkg/sentry/loader/vdso.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -90,7 +91,7 @@ func validateVDSO(ctx context.Context, f fullReader, size uint64) (elfInfo, erro
var first *elf.ProgHeader
var prev *elf.ProgHeader
- var prevEnd usermem.Addr
+ var prevEnd hostarch.Addr
for i, phdr := range info.phdrs {
if phdr.Type != elf.PT_LOAD {
continue
@@ -119,7 +120,7 @@ func validateVDSO(ctx context.Context, f fullReader, size uint64) (elfInfo, erro
return elfInfo{}, syserror.ENOEXEC
}
- start := usermem.Addr(memoryOffset)
+ start := hostarch.Addr(memoryOffset)
end, ok := start.AddLength(phdr.Memsz)
if !ok {
log.Warningf("PT_LOAD segment size overflows: %#x + %#x", start, end)
@@ -210,7 +211,7 @@ func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) {
}
// Then copy it into a VDSO mapping.
- size, ok := usermem.Addr(len(vdsodata.Binary)).RoundUp()
+ size, ok := hostarch.Addr(len(vdsodata.Binary)).RoundUp()
if !ok {
return nil, fmt.Errorf("VDSO size overflows? %#x", len(vdsodata.Binary))
}
@@ -221,7 +222,7 @@ func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) {
return nil, fmt.Errorf("unable to allocate VDSO memory: %v", err)
}
- ims, err := mf.MapInternal(vdso, usermem.ReadWrite)
+ ims, err := mf.MapInternal(vdso, hostarch.ReadWrite)
if err != nil {
mf.DecRef(vdso)
return nil, fmt.Errorf("unable to map VDSO memory: %v", err)
@@ -234,7 +235,7 @@ func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) {
}
// Finally, allocate a param page for this VDSO.
- paramPage, err := mf.Allocate(usermem.PageSize, usage.System)
+ paramPage, err := mf.Allocate(hostarch.PageSize, usage.System)
if err != nil {
mf.DecRef(vdso)
return nil, fmt.Errorf("unable to allocate VDSO param page: %v", err)
@@ -266,7 +267,7 @@ func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) {
// compatibility with such binaries, we load the VDSO much like Linux.
//
// loadVDSO takes a reference on the VDSO and parameter page FrameRegions.
-func loadVDSO(ctx context.Context, m *mm.MemoryManager, v *VDSO, bin loadedELF) (usermem.Addr, error) {
+func loadVDSO(ctx context.Context, m *mm.MemoryManager, v *VDSO, bin loadedELF) (hostarch.Addr, error) {
if v.os != bin.os {
ctx.Warningf("Binary ELF OS %v and VDSO ELF OS %v differ", bin.os, v.os)
return 0, syserror.ENOEXEC
@@ -297,8 +298,8 @@ func loadVDSO(ctx context.Context, m *mm.MemoryManager, v *VDSO, bin loadedELF)
Fixed: true,
Unmap: true,
Private: true,
- Perms: usermem.Read,
- MaxPerms: usermem.Read,
+ Perms: hostarch.Read,
+ MaxPerms: hostarch.Read,
})
if err != nil {
ctx.Infof("Unable to map VDSO param page: %v", err)
@@ -318,8 +319,8 @@ func loadVDSO(ctx context.Context, m *mm.MemoryManager, v *VDSO, bin loadedELF)
Fixed: true,
Unmap: true,
Private: true,
- Perms: usermem.Read,
- MaxPerms: usermem.AnyAccess,
+ Perms: hostarch.Read,
+ MaxPerms: hostarch.AnyAccess,
})
if err != nil {
ctx.Infof("Unable to map VDSO: %v", err)
@@ -349,7 +350,7 @@ func loadVDSO(ctx context.Context, m *mm.MemoryManager, v *VDSO, bin loadedELF)
return 0, syserror.ENOEXEC
}
segPage := segAddr.RoundDown()
- segSize := usermem.Addr(phdr.Memsz)
+ segSize := hostarch.Addr(phdr.Memsz)
segSize, ok = segSize.AddLength(segAddr.PageOffset())
if !ok {
ctx.Warningf("PT_LOAD segment memsize %#x + offset %#x overflows", phdr.Memsz, segAddr.PageOffset())
@@ -371,7 +372,7 @@ func loadVDSO(ctx context.Context, m *mm.MemoryManager, v *VDSO, bin loadedELF)
}
perms := progFlagsAsPerms(phdr.Flags)
- if perms != usermem.Read {
+ if perms != hostarch.Read {
if err := m.MProtect(segPage, uint64(segSize), perms, false); err != nil {
ctx.Warningf("Unable to set PT_LOAD segment protections %+v at [%#x, %#x): %v", perms, segAddr, segEnd, err)
return 0, syserror.ENOEXEC
diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD
index 2c95669cd..c30e88725 100644
--- a/pkg/sentry/memmap/BUILD
+++ b/pkg/sentry/memmap/BUILD
@@ -51,6 +51,7 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/safemem",
"//pkg/syserror",
@@ -63,5 +64,5 @@ go_test(
size = "small",
srcs = ["mapping_set_test.go"],
library = ":memmap",
- deps = ["//pkg/usermem"],
+ deps = ["//pkg/hostarch"],
)
diff --git a/pkg/sentry/memmap/mapping_set.go b/pkg/sentry/memmap/mapping_set.go
index 457ed87f8..32863bb5e 100644
--- a/pkg/sentry/memmap/mapping_set.go
+++ b/pkg/sentry/memmap/mapping_set.go
@@ -18,7 +18,7 @@ import (
"fmt"
"math"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// MappingSet maps offsets into a Mappable to mappings of those offsets. It is
@@ -39,7 +39,7 @@ type MappingsOfRange map[MappingOfRange]struct{}
// +stateify savable
type MappingOfRange struct {
MappingSpace MappingSpace
- AddrRange usermem.AddrRange
+ AddrRange hostarch.AddrRange
Writable bool
}
@@ -89,9 +89,9 @@ func (mappingSetFunctions) Merge(r1 MappableRange, val1 MappingsOfRange, r2 Mapp
// region with k1.
k2 := MappingOfRange{
MappingSpace: k1.MappingSpace,
- AddrRange: usermem.AddrRange{
+ AddrRange: hostarch.AddrRange{
Start: k1.AddrRange.End,
- End: k1.AddrRange.End + usermem.Addr(r2.Length()),
+ End: k1.AddrRange.End + hostarch.Addr(r2.Length()),
},
Writable: k1.Writable,
}
@@ -102,7 +102,7 @@ func (mappingSetFunctions) Merge(r1 MappableRange, val1 MappingsOfRange, r2 Mapp
// OK. Add it to the merged map.
merged[MappingOfRange{
MappingSpace: k1.MappingSpace,
- AddrRange: usermem.AddrRange{
+ AddrRange: hostarch.AddrRange{
Start: k1.AddrRange.Start,
End: k2.AddrRange.End,
},
@@ -124,11 +124,11 @@ func (mappingSetFunctions) Split(r MappableRange, val MappingsOfRange, split uin
// split is a value in MappableRange, we need the offset into the
// corresponding MappingsOfRange.
- offset := usermem.Addr(split - r.Start)
+ offset := hostarch.Addr(split - r.Start)
for k := range val {
k1 := MappingOfRange{
MappingSpace: k.MappingSpace,
- AddrRange: usermem.AddrRange{
+ AddrRange: hostarch.AddrRange{
Start: k.AddrRange.Start,
End: k.AddrRange.Start + offset,
},
@@ -138,7 +138,7 @@ func (mappingSetFunctions) Split(r MappableRange, val MappingsOfRange, split uin
k2 := MappingOfRange{
MappingSpace: k.MappingSpace,
- AddrRange: usermem.AddrRange{
+ AddrRange: hostarch.AddrRange{
Start: k.AddrRange.Start + offset,
End: k.AddrRange.End,
},
@@ -157,18 +157,18 @@ func (mappingSetFunctions) Split(r MappableRange, val MappingsOfRange, split uin
// indicating that ms maps addresses [0x4000, 0x6000) to MappableRange [0x0,
// 0x2000). Then for subsetRange = [0x1000, 0x2000), subsetMapping returns a
// MappingOfRange for which AddrRange = [0x5000, 0x6000).
-func subsetMapping(wholeRange, subsetRange MappableRange, ms MappingSpace, addr usermem.Addr, writable bool) MappingOfRange {
+func subsetMapping(wholeRange, subsetRange MappableRange, ms MappingSpace, addr hostarch.Addr, writable bool) MappingOfRange {
if !wholeRange.IsSupersetOf(subsetRange) {
panic(fmt.Sprintf("%v is not a superset of %v", wholeRange, subsetRange))
}
offset := subsetRange.Start - wholeRange.Start
- start := addr + usermem.Addr(offset)
+ start := addr + hostarch.Addr(offset)
return MappingOfRange{
MappingSpace: ms,
- AddrRange: usermem.AddrRange{
+ AddrRange: hostarch.AddrRange{
Start: start,
- End: start + usermem.Addr(subsetRange.Length()),
+ End: start + hostarch.Addr(subsetRange.Length()),
},
Writable: writable,
}
@@ -178,7 +178,7 @@ func subsetMapping(wholeRange, subsetRange MappableRange, ms MappingSpace, addr
// previously had no mappings.
//
// Preconditions: Same as Mappable.AddMapping.
-func (s *MappingSet) AddMapping(ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) []MappableRange {
+func (s *MappingSet) AddMapping(ms MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) []MappableRange {
mr := MappableRange{offset, offset + uint64(ar.Length())}
var mapped []MappableRange
seg, gap := s.Find(mr.Start)
@@ -205,7 +205,7 @@ func (s *MappingSet) AddMapping(ms MappingSpace, ar usermem.AddrRange, offset ui
// MappableRanges that now have no mappings.
//
// Preconditions: Same as Mappable.RemoveMapping.
-func (s *MappingSet) RemoveMapping(ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) []MappableRange {
+func (s *MappingSet) RemoveMapping(ms MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) []MappableRange {
mr := MappableRange{offset, offset + uint64(ar.Length())}
var unmapped []MappableRange
diff --git a/pkg/sentry/memmap/mapping_set_test.go b/pkg/sentry/memmap/mapping_set_test.go
index d39efe38f..5cb81fde7 100644
--- a/pkg/sentry/memmap/mapping_set_test.go
+++ b/pkg/sentry/memmap/mapping_set_test.go
@@ -15,24 +15,23 @@
package memmap
import (
+ "gvisor.dev/gvisor/pkg/hostarch"
"reflect"
"testing"
-
- "gvisor.dev/gvisor/pkg/usermem"
)
type testMappingSpace struct {
// Ideally we'd store the full ranges that were invalidated, rather
// than individual calls to Invalidate, as they are an implementation
// detail, but this is the simplest way for now.
- inv []usermem.AddrRange
+ inv []hostarch.AddrRange
}
func (n *testMappingSpace) reset() {
- n.inv = []usermem.AddrRange{}
+ n.inv = []hostarch.AddrRange{}
}
-func (n *testMappingSpace) Invalidate(ar usermem.AddrRange, opts InvalidateOpts) {
+func (n *testMappingSpace) Invalidate(ar hostarch.AddrRange, opts InvalidateOpts) {
n.inv = append(n.inv, ar)
}
@@ -40,16 +39,16 @@ func TestAddRemoveMapping(t *testing.T) {
set := MappingSet{}
ms := &testMappingSpace{}
- mapped := set.AddMapping(ms, usermem.AddrRange{0x10000, 0x12000}, 0x1000, true)
+ mapped := set.AddMapping(ms, hostarch.AddrRange{0x10000, 0x12000}, 0x1000, true)
if got, want := mapped, []MappableRange{{0x1000, 0x3000}}; !reflect.DeepEqual(got, want) {
t.Errorf("AddMapping: got %+v, wanted %+v", got, want)
}
- // Mappings (usermem.AddrRanges => memmap.MappableRange):
+ // Mappings (hostarch.AddrRanges => memmap.MappableRange):
// [0x10000, 0x12000) => [0x1000, 0x3000)
t.Log(&set)
- mapped = set.AddMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, true)
+ mapped = set.AddMapping(ms, hostarch.AddrRange{0x20000, 0x21000}, 0x2000, true)
if len(mapped) != 0 {
t.Errorf("AddMapping: got %+v, wanted []", mapped)
}
@@ -59,7 +58,7 @@ func TestAddRemoveMapping(t *testing.T) {
// [0x11000, 0x12000) and [0x20000, 0x21000) => [0x2000, 0x3000)
t.Log(&set)
- mapped = set.AddMapping(ms, usermem.AddrRange{0x30000, 0x31000}, 0x4000, true)
+ mapped = set.AddMapping(ms, hostarch.AddrRange{0x30000, 0x31000}, 0x4000, true)
if got, want := mapped, []MappableRange{{0x4000, 0x5000}}; !reflect.DeepEqual(got, want) {
t.Errorf("AddMapping: got %+v, wanted %+v", got, want)
}
@@ -70,7 +69,7 @@ func TestAddRemoveMapping(t *testing.T) {
// [0x30000, 0x31000) => [0x4000, 0x5000)
t.Log(&set)
- mapped = set.AddMapping(ms, usermem.AddrRange{0x12000, 0x15000}, 0x3000, true)
+ mapped = set.AddMapping(ms, hostarch.AddrRange{0x12000, 0x15000}, 0x3000, true)
if got, want := mapped, []MappableRange{{0x3000, 0x4000}, {0x5000, 0x6000}}; !reflect.DeepEqual(got, want) {
t.Errorf("AddMapping: got %+v, wanted %+v", got, want)
}
@@ -83,7 +82,7 @@ func TestAddRemoveMapping(t *testing.T) {
// [0x14000, 0x15000) => [0x5000, 0x6000)
t.Log(&set)
- unmapped := set.RemoveMapping(ms, usermem.AddrRange{0x10000, 0x11000}, 0x1000, true)
+ unmapped := set.RemoveMapping(ms, hostarch.AddrRange{0x10000, 0x11000}, 0x1000, true)
if got, want := unmapped, []MappableRange{{0x1000, 0x2000}}; !reflect.DeepEqual(got, want) {
t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want)
}
@@ -95,7 +94,7 @@ func TestAddRemoveMapping(t *testing.T) {
// [0x14000, 0x15000) => [0x5000, 0x6000)
t.Log(&set)
- unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, true)
+ unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x20000, 0x21000}, 0x2000, true)
if len(unmapped) != 0 {
t.Errorf("RemoveMapping: got %+v, wanted []", unmapped)
}
@@ -106,7 +105,7 @@ func TestAddRemoveMapping(t *testing.T) {
// [0x14000, 0x15000) => [0x5000, 0x6000)
t.Log(&set)
- unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x11000, 0x15000}, 0x2000, true)
+ unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x11000, 0x15000}, 0x2000, true)
if got, want := unmapped, []MappableRange{{0x2000, 0x4000}, {0x5000, 0x6000}}; !reflect.DeepEqual(got, want) {
t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want)
}
@@ -115,7 +114,7 @@ func TestAddRemoveMapping(t *testing.T) {
// [0x30000, 0x31000) => [0x4000, 0x5000)
t.Log(&set)
- unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x30000, 0x31000}, 0x4000, true)
+ unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x30000, 0x31000}, 0x4000, true)
if got, want := unmapped, []MappableRange{{0x4000, 0x5000}}; !reflect.DeepEqual(got, want) {
t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want)
}
@@ -125,12 +124,12 @@ func TestInvalidateWholeMapping(t *testing.T) {
set := MappingSet{}
ms := &testMappingSpace{}
- set.AddMapping(ms, usermem.AddrRange{0x10000, 0x11000}, 0, true)
+ set.AddMapping(ms, hostarch.AddrRange{0x10000, 0x11000}, 0, true)
// Mappings:
// [0x10000, 0x11000) => [0, 0x1000)
t.Log(&set)
set.Invalidate(MappableRange{0, 0x1000}, InvalidateOpts{})
- if got, want := ms.inv, []usermem.AddrRange{{0x10000, 0x11000}}; !reflect.DeepEqual(got, want) {
+ if got, want := ms.inv, []hostarch.AddrRange{{0x10000, 0x11000}}; !reflect.DeepEqual(got, want) {
t.Errorf("Invalidate: got %+v, wanted %+v", got, want)
}
}
@@ -139,12 +138,12 @@ func TestInvalidatePartialMapping(t *testing.T) {
set := MappingSet{}
ms := &testMappingSpace{}
- set.AddMapping(ms, usermem.AddrRange{0x10000, 0x13000}, 0, true)
+ set.AddMapping(ms, hostarch.AddrRange{0x10000, 0x13000}, 0, true)
// Mappings:
// [0x10000, 0x13000) => [0, 0x3000)
t.Log(&set)
set.Invalidate(MappableRange{0x1000, 0x2000}, InvalidateOpts{})
- if got, want := ms.inv, []usermem.AddrRange{{0x11000, 0x12000}}; !reflect.DeepEqual(got, want) {
+ if got, want := ms.inv, []hostarch.AddrRange{{0x11000, 0x12000}}; !reflect.DeepEqual(got, want) {
t.Errorf("Invalidate: got %+v, wanted %+v", got, want)
}
}
@@ -153,14 +152,14 @@ func TestInvalidateMultipleMappings(t *testing.T) {
set := MappingSet{}
ms := &testMappingSpace{}
- set.AddMapping(ms, usermem.AddrRange{0x10000, 0x11000}, 0, true)
- set.AddMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, true)
+ set.AddMapping(ms, hostarch.AddrRange{0x10000, 0x11000}, 0, true)
+ set.AddMapping(ms, hostarch.AddrRange{0x20000, 0x21000}, 0x2000, true)
// Mappings:
// [0x10000, 0x11000) => [0, 0x1000)
// [0x12000, 0x13000) => [0x2000, 0x3000)
t.Log(&set)
set.Invalidate(MappableRange{0, 0x3000}, InvalidateOpts{})
- if got, want := ms.inv, []usermem.AddrRange{{0x10000, 0x11000}, {0x20000, 0x21000}}; !reflect.DeepEqual(got, want) {
+ if got, want := ms.inv, []hostarch.AddrRange{{0x10000, 0x11000}, {0x20000, 0x21000}}; !reflect.DeepEqual(got, want) {
t.Errorf("Invalidate: got %+v, wanted %+v", got, want)
}
}
@@ -170,17 +169,17 @@ func TestInvalidateOverlappingMappings(t *testing.T) {
ms1 := &testMappingSpace{}
ms2 := &testMappingSpace{}
- set.AddMapping(ms1, usermem.AddrRange{0x10000, 0x12000}, 0, true)
- set.AddMapping(ms2, usermem.AddrRange{0x20000, 0x22000}, 0x1000, true)
+ set.AddMapping(ms1, hostarch.AddrRange{0x10000, 0x12000}, 0, true)
+ set.AddMapping(ms2, hostarch.AddrRange{0x20000, 0x22000}, 0x1000, true)
// Mappings:
// ms1:[0x10000, 0x12000) => [0, 0x2000)
// ms2:[0x11000, 0x13000) => [0x1000, 0x3000)
t.Log(&set)
set.Invalidate(MappableRange{0x1000, 0x2000}, InvalidateOpts{})
- if got, want := ms1.inv, []usermem.AddrRange{{0x11000, 0x12000}}; !reflect.DeepEqual(got, want) {
+ if got, want := ms1.inv, []hostarch.AddrRange{{0x11000, 0x12000}}; !reflect.DeepEqual(got, want) {
t.Errorf("Invalidate: ms1: got %+v, wanted %+v", got, want)
}
- if got, want := ms2.inv, []usermem.AddrRange{{0x20000, 0x21000}}; !reflect.DeepEqual(got, want) {
+ if got, want := ms2.inv, []hostarch.AddrRange{{0x20000, 0x21000}}; !reflect.DeepEqual(got, want) {
t.Errorf("Invalidate: ms1: got %+v, wanted %+v", got, want)
}
}
@@ -189,7 +188,7 @@ func TestMixedWritableMappings(t *testing.T) {
set := MappingSet{}
ms := &testMappingSpace{}
- mapped := set.AddMapping(ms, usermem.AddrRange{0x10000, 0x12000}, 0x1000, true)
+ mapped := set.AddMapping(ms, hostarch.AddrRange{0x10000, 0x12000}, 0x1000, true)
if got, want := mapped, []MappableRange{{0x1000, 0x3000}}; !reflect.DeepEqual(got, want) {
t.Errorf("AddMapping: got %+v, wanted %+v", got, want)
}
@@ -198,7 +197,7 @@ func TestMixedWritableMappings(t *testing.T) {
// [0x10000, 0x12000) writable => [0x1000, 0x3000)
t.Log(&set)
- mapped = set.AddMapping(ms, usermem.AddrRange{0x20000, 0x22000}, 0x2000, false)
+ mapped = set.AddMapping(ms, hostarch.AddrRange{0x20000, 0x22000}, 0x2000, false)
if got, want := mapped, []MappableRange{{0x3000, 0x4000}}; !reflect.DeepEqual(got, want) {
t.Errorf("AddMapping: got %+v, wanted %+v", got, want)
}
@@ -211,14 +210,14 @@ func TestMixedWritableMappings(t *testing.T) {
// Unmap should fail because we specified the readonly map address range, but
// asked to unmap a writable segment.
- unmapped := set.RemoveMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, true)
+ unmapped := set.RemoveMapping(ms, hostarch.AddrRange{0x20000, 0x21000}, 0x2000, true)
if len(unmapped) != 0 {
t.Errorf("RemoveMapping: got %+v, wanted []", unmapped)
}
// Readonly mapping removed, but writable mapping still exists in the range,
// so no mappable range fully unmapped.
- unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x20000, 0x21000}, 0x2000, false)
+ unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x20000, 0x21000}, 0x2000, false)
if len(unmapped) != 0 {
t.Errorf("RemoveMapping: got %+v, wanted []", unmapped)
}
@@ -228,7 +227,7 @@ func TestMixedWritableMappings(t *testing.T) {
// [0x21000, 0x22000) readonly => [0x3000, 0x4000)
t.Log(&set)
- unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x11000, 0x12000}, 0x2000, true)
+ unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x11000, 0x12000}, 0x2000, true)
if got, want := unmapped, []MappableRange{{0x2000, 0x3000}}; !reflect.DeepEqual(got, want) {
t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want)
}
@@ -239,12 +238,12 @@ func TestMixedWritableMappings(t *testing.T) {
t.Log(&set)
// Unmap should fail since writable bit doesn't match.
- unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x10000, 0x12000}, 0x1000, false)
+ unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x10000, 0x12000}, 0x1000, false)
if len(unmapped) != 0 {
t.Errorf("RemoveMapping: got %+v, wanted []", unmapped)
}
- unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x10000, 0x12000}, 0x1000, true)
+ unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x10000, 0x12000}, 0x1000, true)
if got, want := unmapped, []MappableRange{{0x1000, 0x2000}}; !reflect.DeepEqual(got, want) {
t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want)
}
@@ -253,7 +252,7 @@ func TestMixedWritableMappings(t *testing.T) {
// [0x21000, 0x22000) readonly => [0x3000, 0x4000)
t.Log(&set)
- unmapped = set.RemoveMapping(ms, usermem.AddrRange{0x21000, 0x22000}, 0x3000, false)
+ unmapped = set.RemoveMapping(ms, hostarch.AddrRange{0x21000, 0x22000}, 0x3000, false)
if got, want := unmapped, []MappableRange{{0x3000, 0x4000}}; !reflect.DeepEqual(got, want) {
t.Errorf("RemoveMapping: got %+v, wanted %+v", got, want)
}
diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go
index 49e21026e..610686ea0 100644
--- a/pkg/sentry/memmap/memmap.go
+++ b/pkg/sentry/memmap/memmap.go
@@ -19,8 +19,8 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
- "gvisor.dev/gvisor/pkg/usermem"
)
// Mappable represents a memory-mappable object, a mutable mapping from uint64
@@ -29,8 +29,8 @@ import (
// See mm/mm.go for Mappable's place in the lock order.
//
// All Mappable methods have the following preconditions:
-// * usermem.AddrRanges and MappableRanges must be non-empty (Length() != 0).
-// * usermem.Addrs and Mappable offsets must be page-aligned.
+// * hostarch.AddrRanges and MappableRanges must be non-empty (Length() != 0).
+// * hostarch.Addrs and Mappable offsets must be page-aligned.
type Mappable interface {
// AddMapping notifies the Mappable of a mapping from addresses ar in ms to
// offsets [offset, offset+ar.Length()) in this Mappable.
@@ -42,7 +42,7 @@ type Mappable interface {
// lifetime of the mapping.
//
// Preconditions: offset+ar.Length() does not overflow.
- AddMapping(ctx context.Context, ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error
+ AddMapping(ctx context.Context, ms MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error
// RemoveMapping notifies the Mappable of the removal of a mapping from
// addresses ar in ms to offsets [offset, offset+ar.Length()) in this
@@ -52,7 +52,7 @@ type Mappable interface {
// * offset+ar.Length() does not overflow.
// * The removed mapping must exist. writable must match the
// corresponding call to AddMapping.
- RemoveMapping(ctx context.Context, ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool)
+ RemoveMapping(ctx context.Context, ms MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool)
// CopyMapping notifies the Mappable of an attempt to copy a mapping in ms
// from srcAR to dstAR. For most Mappables, this is equivalent to
@@ -66,7 +66,7 @@ type Mappable interface {
// * offset+srcAR.Length() and offset+dstAR.Length() do not overflow.
// * The mapping at srcAR must exist. writable must match the
// corresponding call to AddMapping.
- CopyMapping(ctx context.Context, ms MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error
+ CopyMapping(ctx context.Context, ms MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error
// Translate returns the Mappable's current mappings for at least the range
// of offsets specified by required, and at most the range of offsets
@@ -90,7 +90,7 @@ type Mappable interface {
// synchronize with invalidation.
//
// Postconditions: See CheckTranslateResult.
- Translate(ctx context.Context, required, optional MappableRange, at usermem.AccessType) ([]Translation, error)
+ Translate(ctx context.Context, required, optional MappableRange, at hostarch.AccessType) ([]Translation, error)
// InvalidateUnsavable requests that the Mappable invalidate Translations
// that cannot be preserved across save/restore.
@@ -113,7 +113,7 @@ type Translation struct {
// Perms is the set of permissions for which platform.AddressSpace.MapFile
// and platform.AddressSpace.MapInternal on this Translation is permitted.
- Perms usermem.AccessType
+ Perms hostarch.AccessType
}
// FileRange returns the FileRange represented by t.
@@ -125,18 +125,18 @@ func (t Translation) FileRange() FileRange {
// postconditions for Mappable.Translate(required, optional, at).
//
// Preconditions: Same as Mappable.Translate.
-func CheckTranslateResult(required, optional MappableRange, at usermem.AccessType, ts []Translation, terr error) error {
+func CheckTranslateResult(required, optional MappableRange, at hostarch.AccessType, ts []Translation, terr error) error {
// Verify that the inputs to Mappable.Translate were valid.
if !required.WellFormed() || required.Length() == 0 {
panic(fmt.Sprintf("invalid required range: %v", required))
}
- if !usermem.Addr(required.Start).IsPageAligned() || !usermem.Addr(required.End).IsPageAligned() {
+ if !hostarch.Addr(required.Start).IsPageAligned() || !hostarch.Addr(required.End).IsPageAligned() {
panic(fmt.Sprintf("unaligned required range: %v", required))
}
if !optional.IsSupersetOf(required) {
panic(fmt.Sprintf("optional range %v is not a superset of required range %v", optional, required))
}
- if !usermem.Addr(optional.Start).IsPageAligned() || !usermem.Addr(optional.End).IsPageAligned() {
+ if !hostarch.Addr(optional.Start).IsPageAligned() || !hostarch.Addr(optional.End).IsPageAligned() {
panic(fmt.Sprintf("unaligned optional range: %v", optional))
}
@@ -148,13 +148,13 @@ func CheckTranslateResult(required, optional MappableRange, at usermem.AccessTyp
if !t.Source.WellFormed() || t.Source.Length() == 0 {
return fmt.Errorf("Translation %+v has invalid Source", t)
}
- if !usermem.Addr(t.Source.Start).IsPageAligned() || !usermem.Addr(t.Source.End).IsPageAligned() {
+ if !hostarch.Addr(t.Source.Start).IsPageAligned() || !hostarch.Addr(t.Source.End).IsPageAligned() {
return fmt.Errorf("Translation %+v has unaligned Source", t)
}
if t.File == nil {
return fmt.Errorf("Translation %+v has nil File", t)
}
- if !usermem.Addr(t.Offset).IsPageAligned() {
+ if !hostarch.Addr(t.Offset).IsPageAligned() {
return fmt.Errorf("Translation %+v has unaligned Offset", t)
}
// Translations must be contiguous and in increasing order of
@@ -210,7 +210,7 @@ func (mr MappableRange) String() string {
return fmt.Sprintf("[%#x, %#x)", mr.Start, mr.End)
}
-// MappingSpace represents a mutable mapping from usermem.Addrs to (Mappable,
+// MappingSpace represents a mutable mapping from hostarch.Addrs to (Mappable,
// uint64 offset) pairs.
type MappingSpace interface {
// Invalidate is called to notify the MappingSpace that values returned by
@@ -223,7 +223,7 @@ type MappingSpace interface {
// Preconditions:
// * ar.Length() != 0.
// * ar must be page-aligned.
- Invalidate(ar usermem.AddrRange, opts InvalidateOpts)
+ Invalidate(ar hostarch.AddrRange, opts InvalidateOpts)
}
// InvalidateOpts holds options to MappingSpace.Invalidate.
@@ -321,7 +321,7 @@ type MMapOpts struct {
Offset uint64
// Addr is the suggested address for the mapping.
- Addr usermem.Addr
+ Addr hostarch.Addr
// Fixed specifies whether this is a fixed mapping (it must be located at
// Addr).
@@ -338,7 +338,7 @@ type MMapOpts struct {
Map32Bit bool
// Perms is the set of permissions to the applied to this mapping.
- Perms usermem.AccessType
+ Perms hostarch.AccessType
// MaxPerms limits the set of permissions that may ever apply to this
// mapping. If Mappable is not nil, all memmap.Translations returned by
@@ -346,7 +346,7 @@ type MMapOpts struct {
//
// Preconditions: MaxAccessType should be an effective AccessType, as
// access cannot be limited beyond effective AccessTypes.
- MaxPerms usermem.AccessType
+ MaxPerms hostarch.AccessType
// Private is true if writes to the mapping should be propagated to a copy
// that is exclusive to the MemoryManager.
@@ -375,6 +375,11 @@ type MMapOpts struct {
//
// If Force is true, Unmap and Fixed must be true.
Force bool
+
+ // SentryOwnedContent indicates the sentry exclusively controls the
+ // underlying memory backing the mapping thus the memory content is
+ // guaranteed not to be modified outside the sentry's purview.
+ SentryOwnedContent bool
}
// File represents a host file that may be mapped into an platform.AddressSpace.
@@ -410,7 +415,7 @@ type File interface {
//
// Postconditions: The returned mapping is valid as long as at least one
// reference is held on the mapped pages.
- MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error)
+ MapInternal(fr FileRange, at hostarch.AccessType) (safemem.BlockSeq, error)
// FD returns the file descriptor represented by the File.
//
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index 6dbeccfe2..b417c2da7 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -28,14 +28,14 @@ go_template_instance(
"trackGaps": "1",
},
imports = {
- "usermem": "gvisor.dev/gvisor/pkg/usermem",
+ "hostarch": "gvisor.dev/gvisor/pkg/hostarch",
},
package = "mm",
prefix = "vma",
template = "//pkg/segment:generic_set",
types = {
- "Key": "usermem.Addr",
- "Range": "usermem.AddrRange",
+ "Key": "hostarch.Addr",
+ "Range": "hostarch.AddrRange",
"Value": "vma",
"Functions": "vmaSetFunctions",
},
@@ -48,14 +48,14 @@ go_template_instance(
"minDegree": "8",
},
imports = {
- "usermem": "gvisor.dev/gvisor/pkg/usermem",
+ "hostarch": "gvisor.dev/gvisor/pkg/hostarch",
},
package = "mm",
prefix = "pma",
template = "//pkg/segment:generic_set",
types = {
- "Key": "usermem.Addr",
- "Range": "usermem.AddrRange",
+ "Key": "hostarch.Addr",
+ "Range": "hostarch.AddrRange",
"Value": "pma",
"Functions": "pmaSetFunctions",
},
@@ -125,6 +125,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/atomicbitops",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/refs",
"//pkg/refsvfs2",
@@ -155,6 +156,7 @@ go_test(
library = ":mm",
deps = [
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/sentry/arch",
"//pkg/sentry/contexttest",
"//pkg/sentry/limits",
diff --git a/pkg/sentry/mm/address_space.go b/pkg/sentry/mm/address_space.go
index a93e76c75..534e0e957 100644
--- a/pkg/sentry/mm/address_space.go
+++ b/pkg/sentry/mm/address_space.go
@@ -19,8 +19,8 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/usermem"
)
// AddressSpace returns the platform.AddressSpace bound to mm.
@@ -172,17 +172,17 @@ func (mm *MemoryManager) Deactivate() {
// * ar.Length() != 0.
// * ar must be page-aligned.
// * pseg == mm.pmas.LowerBoundSegment(ar.Start).
-func (mm *MemoryManager) mapASLocked(pseg pmaIterator, ar usermem.AddrRange, precommit bool) error {
+func (mm *MemoryManager) mapASLocked(pseg pmaIterator, ar hostarch.AddrRange, precommit bool) error {
// By default, map entire pmas at a time, under the assumption that there
// is no cost to mapping more of a pma than necessary.
- mapAR := usermem.AddrRange{0, ^usermem.Addr(usermem.PageSize - 1)}
+ mapAR := hostarch.AddrRange{0, ^hostarch.Addr(hostarch.PageSize - 1)}
if precommit {
// When explicitly precommitting, only map ar, since overmapping may
// incur unexpected resource usage.
mapAR = ar
} else if mapUnit := mm.p.MapUnit(); mapUnit != 0 {
// Limit the range we map to ar, aligned to mapUnit.
- mapMask := usermem.Addr(mapUnit - 1)
+ mapMask := hostarch.Addr(mapUnit - 1)
mapAR.Start = ar.Start &^ mapMask
// If rounding ar.End up overflows, just keep the existing mapAR.End.
if end := (ar.End + mapMask) &^ mapMask; end >= ar.End {
@@ -218,7 +218,7 @@ func (mm *MemoryManager) mapASLocked(pseg pmaIterator, ar usermem.AddrRange, pre
// unmapASLocked removes all AddressSpace mappings for addresses in ar.
//
// Preconditions: mm.activeMu must be locked.
-func (mm *MemoryManager) unmapASLocked(ar usermem.AddrRange) {
+func (mm *MemoryManager) unmapASLocked(ar hostarch.AddrRange) {
if mm.as == nil {
// No AddressSpace? Force all mappings to be unmapped on the next
// Activate.
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go
index 5ab2ef79f..346866d3c 100644
--- a/pkg/sentry/mm/aio_context.go
+++ b/pkg/sentry/mm/aio_context.go
@@ -17,6 +17,7 @@ package mm
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/usage"
@@ -83,7 +84,7 @@ func (mm *MemoryManager) destroyAIOContextLocked(ctx context.Context, id uint64)
// the same address. Then it would be unmapping memory that it doesn't own.
// This is, however, the way Linux implements AIO. Keeps the same [weird]
// semantics in case anyone relies on it.
- mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize)
+ mm.MUnmap(ctx, hostarch.Addr(id), aioRingBufferSize)
delete(mm.aioManager.contexts, id)
aioCtx.destroy()
@@ -259,7 +260,7 @@ type aioMappable struct {
fr memmap.FileRange
}
-var aioRingBufferSize = uint64(usermem.Addr(linux.AIORingSize).MustRoundUp())
+var aioRingBufferSize = uint64(hostarch.Addr(linux.AIORingSize).MustRoundUp())
func newAIOMappable(mfp pgalloc.MemoryFileProvider) (*aioMappable, error) {
fr, err := mfp.MemoryFile().Allocate(aioRingBufferSize, usage.Anonymous)
@@ -300,7 +301,7 @@ func (m *aioMappable) Msync(ctx context.Context, mr memmap.MappableRange) error
}
// AddMapping implements memmap.Mappable.AddMapping.
-func (m *aioMappable) AddMapping(_ context.Context, _ memmap.MappingSpace, ar usermem.AddrRange, offset uint64, _ bool) error {
+func (m *aioMappable) AddMapping(_ context.Context, _ memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, _ bool) error {
// Don't allow mappings to be expanded (in Linux, fs/aio.c:aio_ring_mmap()
// sets VM_DONTEXPAND).
if offset != 0 || uint64(ar.Length()) != aioRingBufferSize {
@@ -310,11 +311,11 @@ func (m *aioMappable) AddMapping(_ context.Context, _ memmap.MappingSpace, ar us
}
// RemoveMapping implements memmap.Mappable.RemoveMapping.
-func (m *aioMappable) RemoveMapping(context.Context, memmap.MappingSpace, usermem.AddrRange, uint64, bool) {
+func (m *aioMappable) RemoveMapping(context.Context, memmap.MappingSpace, hostarch.AddrRange, uint64, bool) {
}
// CopyMapping implements memmap.Mappable.CopyMapping.
-func (m *aioMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, _ bool) error {
+func (m *aioMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, _ bool) error {
// Don't allow mappings to be expanded (in Linux, fs/aio.c:aio_ring_mmap()
// sets VM_DONTEXPAND).
if offset != 0 || uint64(dstAR.Length()) != aioRingBufferSize {
@@ -346,7 +347,7 @@ func (m *aioMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, s
}
// Translate implements memmap.Mappable.Translate.
-func (m *aioMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+func (m *aioMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) {
var err error
if required.End > m.fr.Length() {
err = &memmap.BusError{syserror.EFAULT}
@@ -357,7 +358,7 @@ func (m *aioMappable) Translate(ctx context.Context, required, optional memmap.M
Source: source,
File: m.mfp.MemoryFile(),
Offset: m.fr.Start + source.Start,
- Perms: usermem.AnyAccess,
+ Perms: hostarch.AnyAccess,
},
}, err
}
@@ -389,8 +390,8 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint
// Linux uses "do_mmap_pgoff(..., PROT_READ | PROT_WRITE, ...)" in
// fs/aio.c:aio_setup_ring(). Since we don't implement AIO_RING_MAGIC,
// user mode should not write to this page.
- Perms: usermem.Read,
- MaxPerms: usermem.Read,
+ Perms: hostarch.Read,
+ MaxPerms: hostarch.Read,
})
if err != nil {
return 0, err
@@ -435,6 +436,6 @@ func (mm *MemoryManager) LookupAIOContext(ctx context.Context, id uint64) (*AIOC
// bytes from id).
func (mm *MemoryManager) isValidAddr(ctx context.Context, id uint64) bool {
var buf [4]byte
- _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{})
+ _, err := mm.CopyIn(ctx, hostarch.Addr(id), buf[:], usermem.IOOpts{})
return err == nil
}
diff --git a/pkg/sentry/mm/io.go b/pkg/sentry/mm/io.go
index a8ac48080..16f318ab3 100644
--- a/pkg/sentry/mm/io.go
+++ b/pkg/sentry/mm/io.go
@@ -16,6 +16,7 @@ package mm
import (
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/syserror"
@@ -60,11 +61,11 @@ const (
rwMapMinBytes = 512
)
-// CheckIORange is similar to usermem.Addr.ToRange, but applies bounds checks
+// CheckIORange is similar to hostarch.Addr.ToRange, but applies bounds checks
// consistent with Linux's arch/x86/include/asm/uaccess.h:access_ok().
//
// Preconditions: length >= 0.
-func (mm *MemoryManager) CheckIORange(addr usermem.Addr, length int64) (usermem.AddrRange, bool) {
+func (mm *MemoryManager) CheckIORange(addr hostarch.Addr, length int64) (hostarch.AddrRange, bool) {
// Note that access_ok() constrains end even if length == 0.
ar, ok := addr.ToRange(uint64(length))
return ar, (ok && ar.End <= mm.layout.MaxAddr)
@@ -72,7 +73,7 @@ func (mm *MemoryManager) CheckIORange(addr usermem.Addr, length int64) (usermem.
// checkIOVec applies bound checks consistent with Linux's
// arch/x86/include/asm/uaccess.h:access_ok() to ars.
-func (mm *MemoryManager) checkIOVec(ars usermem.AddrRangeSeq) bool {
+func (mm *MemoryManager) checkIOVec(ars hostarch.AddrRangeSeq) bool {
for !ars.IsEmpty() {
ar := ars.Head()
if _, ok := mm.CheckIORange(ar.Start, int64(ar.Length())); !ok {
@@ -100,7 +101,7 @@ func translateIOError(ctx context.Context, err error) error {
}
// CopyOut implements usermem.IO.CopyOut.
-func (mm *MemoryManager) CopyOut(ctx context.Context, addr usermem.Addr, src []byte, opts usermem.IOOpts) (int, error) {
+func (mm *MemoryManager) CopyOut(ctx context.Context, addr hostarch.Addr, src []byte, opts usermem.IOOpts) (int, error) {
ar, ok := mm.CheckIORange(addr, int64(len(src)))
if !ok {
return 0, syserror.EFAULT
@@ -116,24 +117,24 @@ func (mm *MemoryManager) CopyOut(ctx context.Context, addr usermem.Addr, src []b
}
// Go through internal mappings.
- n64, err := mm.withInternalMappings(ctx, ar, usermem.Write, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) {
+ n64, err := mm.withInternalMappings(ctx, ar, hostarch.Write, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) {
n, err := safemem.CopySeq(ims, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(src)))
return n, translateIOError(ctx, err)
})
return int(n64), err
}
-func (mm *MemoryManager) asCopyOut(ctx context.Context, addr usermem.Addr, src []byte) (int, error) {
+func (mm *MemoryManager) asCopyOut(ctx context.Context, addr hostarch.Addr, src []byte) (int, error) {
var done int
for {
- n, err := mm.as.CopyOut(addr+usermem.Addr(done), src[done:])
+ n, err := mm.as.CopyOut(addr+hostarch.Addr(done), src[done:])
done += n
if err == nil {
return done, nil
}
if f, ok := err.(platform.SegmentationFault); ok {
ar, _ := addr.ToRange(uint64(len(src)))
- if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.Write); err != nil {
+ if err := mm.handleASIOFault(ctx, f.Addr, ar, hostarch.Write); err != nil {
return done, err
}
continue
@@ -143,7 +144,7 @@ func (mm *MemoryManager) asCopyOut(ctx context.Context, addr usermem.Addr, src [
}
// CopyIn implements usermem.IO.CopyIn.
-func (mm *MemoryManager) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, opts usermem.IOOpts) (int, error) {
+func (mm *MemoryManager) CopyIn(ctx context.Context, addr hostarch.Addr, dst []byte, opts usermem.IOOpts) (int, error) {
ar, ok := mm.CheckIORange(addr, int64(len(dst)))
if !ok {
return 0, syserror.EFAULT
@@ -159,24 +160,24 @@ func (mm *MemoryManager) CopyIn(ctx context.Context, addr usermem.Addr, dst []by
}
// Go through internal mappings.
- n64, err := mm.withInternalMappings(ctx, ar, usermem.Read, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) {
+ n64, err := mm.withInternalMappings(ctx, ar, hostarch.Read, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) {
n, err := safemem.CopySeq(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(dst)), ims)
return n, translateIOError(ctx, err)
})
return int(n64), err
}
-func (mm *MemoryManager) asCopyIn(ctx context.Context, addr usermem.Addr, dst []byte) (int, error) {
+func (mm *MemoryManager) asCopyIn(ctx context.Context, addr hostarch.Addr, dst []byte) (int, error) {
var done int
for {
- n, err := mm.as.CopyIn(addr+usermem.Addr(done), dst[done:])
+ n, err := mm.as.CopyIn(addr+hostarch.Addr(done), dst[done:])
done += n
if err == nil {
return done, nil
}
if f, ok := err.(platform.SegmentationFault); ok {
ar, _ := addr.ToRange(uint64(len(dst)))
- if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.Read); err != nil {
+ if err := mm.handleASIOFault(ctx, f.Addr, ar, hostarch.Read); err != nil {
return done, err
}
continue
@@ -186,7 +187,7 @@ func (mm *MemoryManager) asCopyIn(ctx context.Context, addr usermem.Addr, dst []
}
// ZeroOut implements usermem.IO.ZeroOut.
-func (mm *MemoryManager) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int64, opts usermem.IOOpts) (int64, error) {
+func (mm *MemoryManager) ZeroOut(ctx context.Context, addr hostarch.Addr, toZero int64, opts usermem.IOOpts) (int64, error) {
ar, ok := mm.CheckIORange(addr, toZero)
if !ok {
return 0, syserror.EFAULT
@@ -202,23 +203,23 @@ func (mm *MemoryManager) ZeroOut(ctx context.Context, addr usermem.Addr, toZero
}
// Go through internal mappings.
- return mm.withInternalMappings(ctx, ar, usermem.Write, opts.IgnorePermissions, func(dsts safemem.BlockSeq) (uint64, error) {
+ return mm.withInternalMappings(ctx, ar, hostarch.Write, opts.IgnorePermissions, func(dsts safemem.BlockSeq) (uint64, error) {
n, err := safemem.ZeroSeq(dsts)
return n, translateIOError(ctx, err)
})
}
-func (mm *MemoryManager) asZeroOut(ctx context.Context, addr usermem.Addr, toZero int64) (int64, error) {
+func (mm *MemoryManager) asZeroOut(ctx context.Context, addr hostarch.Addr, toZero int64) (int64, error) {
var done int64
for {
- n, err := mm.as.ZeroOut(addr+usermem.Addr(done), uintptr(toZero-done))
+ n, err := mm.as.ZeroOut(addr+hostarch.Addr(done), uintptr(toZero-done))
done += int64(n)
if err == nil {
return done, nil
}
if f, ok := err.(platform.SegmentationFault); ok {
ar, _ := addr.ToRange(uint64(toZero))
- if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.Write); err != nil {
+ if err := mm.handleASIOFault(ctx, f.Addr, ar, hostarch.Write); err != nil {
return done, err
}
continue
@@ -228,7 +229,7 @@ func (mm *MemoryManager) asZeroOut(ctx context.Context, addr usermem.Addr, toZer
}
// CopyOutFrom implements usermem.IO.CopyOutFrom.
-func (mm *MemoryManager) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) {
+func (mm *MemoryManager) CopyOutFrom(ctx context.Context, ars hostarch.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) {
if !mm.checkIOVec(ars) {
return 0, syserror.EFAULT
}
@@ -269,11 +270,11 @@ func (mm *MemoryManager) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeS
}
// Go through internal mappings.
- return mm.withVecInternalMappings(ctx, ars, usermem.Write, opts.IgnorePermissions, src.ReadToBlocks)
+ return mm.withVecInternalMappings(ctx, ars, hostarch.Write, opts.IgnorePermissions, src.ReadToBlocks)
}
// CopyInTo implements usermem.IO.CopyInTo.
-func (mm *MemoryManager) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) {
+func (mm *MemoryManager) CopyInTo(ctx context.Context, ars hostarch.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) {
if !mm.checkIOVec(ars) {
return 0, syserror.EFAULT
}
@@ -306,11 +307,11 @@ func (mm *MemoryManager) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq,
}
// Go through internal mappings.
- return mm.withVecInternalMappings(ctx, ars, usermem.Read, opts.IgnorePermissions, dst.WriteFromBlocks)
+ return mm.withVecInternalMappings(ctx, ars, hostarch.Read, opts.IgnorePermissions, dst.WriteFromBlocks)
}
// SwapUint32 implements usermem.IO.SwapUint32.
-func (mm *MemoryManager) SwapUint32(ctx context.Context, addr usermem.Addr, new uint32, opts usermem.IOOpts) (uint32, error) {
+func (mm *MemoryManager) SwapUint32(ctx context.Context, addr hostarch.Addr, new uint32, opts usermem.IOOpts) (uint32, error) {
ar, ok := mm.CheckIORange(addr, 4)
if !ok {
return 0, syserror.EFAULT
@@ -324,7 +325,7 @@ func (mm *MemoryManager) SwapUint32(ctx context.Context, addr usermem.Addr, new
return old, nil
}
if f, ok := err.(platform.SegmentationFault); ok {
- if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.ReadWrite); err != nil {
+ if err := mm.handleASIOFault(ctx, f.Addr, ar, hostarch.ReadWrite); err != nil {
return 0, err
}
continue
@@ -335,7 +336,7 @@ func (mm *MemoryManager) SwapUint32(ctx context.Context, addr usermem.Addr, new
// Go through internal mappings.
var old uint32
- _, err := mm.withInternalMappings(ctx, ar, usermem.ReadWrite, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) {
+ _, err := mm.withInternalMappings(ctx, ar, hostarch.ReadWrite, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) {
if ims.NumBlocks() != 1 || ims.NumBytes() != 4 {
// Atomicity is unachievable across mappings.
return 0, syserror.EFAULT
@@ -353,7 +354,7 @@ func (mm *MemoryManager) SwapUint32(ctx context.Context, addr usermem.Addr, new
}
// CompareAndSwapUint32 implements usermem.IO.CompareAndSwapUint32.
-func (mm *MemoryManager) CompareAndSwapUint32(ctx context.Context, addr usermem.Addr, old, new uint32, opts usermem.IOOpts) (uint32, error) {
+func (mm *MemoryManager) CompareAndSwapUint32(ctx context.Context, addr hostarch.Addr, old, new uint32, opts usermem.IOOpts) (uint32, error) {
ar, ok := mm.CheckIORange(addr, 4)
if !ok {
return 0, syserror.EFAULT
@@ -367,7 +368,7 @@ func (mm *MemoryManager) CompareAndSwapUint32(ctx context.Context, addr usermem.
return prev, nil
}
if f, ok := err.(platform.SegmentationFault); ok {
- if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.ReadWrite); err != nil {
+ if err := mm.handleASIOFault(ctx, f.Addr, ar, hostarch.ReadWrite); err != nil {
return 0, err
}
continue
@@ -378,7 +379,7 @@ func (mm *MemoryManager) CompareAndSwapUint32(ctx context.Context, addr usermem.
// Go through internal mappings.
var prev uint32
- _, err := mm.withInternalMappings(ctx, ar, usermem.ReadWrite, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) {
+ _, err := mm.withInternalMappings(ctx, ar, hostarch.ReadWrite, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) {
if ims.NumBlocks() != 1 || ims.NumBytes() != 4 {
// Atomicity is unachievable across mappings.
return 0, syserror.EFAULT
@@ -396,7 +397,7 @@ func (mm *MemoryManager) CompareAndSwapUint32(ctx context.Context, addr usermem.
}
// LoadUint32 implements usermem.IO.LoadUint32.
-func (mm *MemoryManager) LoadUint32(ctx context.Context, addr usermem.Addr, opts usermem.IOOpts) (uint32, error) {
+func (mm *MemoryManager) LoadUint32(ctx context.Context, addr hostarch.Addr, opts usermem.IOOpts) (uint32, error) {
ar, ok := mm.CheckIORange(addr, 4)
if !ok {
return 0, syserror.EFAULT
@@ -410,7 +411,7 @@ func (mm *MemoryManager) LoadUint32(ctx context.Context, addr usermem.Addr, opts
return val, nil
}
if f, ok := err.(platform.SegmentationFault); ok {
- if err := mm.handleASIOFault(ctx, f.Addr, ar, usermem.Read); err != nil {
+ if err := mm.handleASIOFault(ctx, f.Addr, ar, hostarch.Read); err != nil {
return 0, err
}
continue
@@ -421,7 +422,7 @@ func (mm *MemoryManager) LoadUint32(ctx context.Context, addr usermem.Addr, opts
// Go through internal mappings.
var val uint32
- _, err := mm.withInternalMappings(ctx, ar, usermem.Read, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) {
+ _, err := mm.withInternalMappings(ctx, ar, hostarch.Read, opts.IgnorePermissions, func(ims safemem.BlockSeq) (uint64, error) {
if ims.NumBlocks() != 1 || ims.NumBytes() != 4 {
// Atomicity is unachievable across mappings.
return 0, syserror.EFAULT
@@ -445,11 +446,11 @@ func (mm *MemoryManager) LoadUint32(ctx context.Context, addr usermem.Addr, opts
// * mm.as != nil.
// * ioar.Length() != 0.
// * ioar.Contains(addr).
-func (mm *MemoryManager) handleASIOFault(ctx context.Context, addr usermem.Addr, ioar usermem.AddrRange, at usermem.AccessType) error {
+func (mm *MemoryManager) handleASIOFault(ctx context.Context, addr hostarch.Addr, ioar hostarch.AddrRange, at hostarch.AccessType) error {
// Try to map all remaining pages in the I/O operation. This RoundUp can't
// overflow because otherwise it would have been caught by CheckIORange.
end, _ := ioar.End.RoundUp()
- ar := usermem.AddrRange{addr.RoundDown(), end}
+ ar := hostarch.AddrRange{addr.RoundDown(), end}
// Don't bother trying existingPMAsLocked; in most cases, if we did have
// existing pmas, we wouldn't have faulted.
@@ -498,7 +499,7 @@ func (mm *MemoryManager) handleASIOFault(ctx context.Context, addr usermem.Addr,
// more useful for usermem.IO methods.
//
// Preconditions: 0 < ar.Length() <= math.MaxInt64.
-func (mm *MemoryManager) withInternalMappings(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool, f func(safemem.BlockSeq) (uint64, error)) (int64, error) {
+func (mm *MemoryManager) withInternalMappings(ctx context.Context, ar hostarch.AddrRange, at hostarch.AccessType, ignorePermissions bool, f func(safemem.BlockSeq) (uint64, error)) (int64, error) {
// If pmas are already available, we can do IO without touching mm.vmas or
// mm.mappingMu.
mm.activeMu.RLock()
@@ -567,7 +568,7 @@ func (mm *MemoryManager) withInternalMappings(ctx context.Context, ar usermem.Ad
// internal mappings for the subset of ars for which this property holds.
//
// Preconditions: !ars.IsEmpty().
-func (mm *MemoryManager) withVecInternalMappings(ctx context.Context, ars usermem.AddrRangeSeq, at usermem.AccessType, ignorePermissions bool, f func(safemem.BlockSeq) (uint64, error)) (int64, error) {
+func (mm *MemoryManager) withVecInternalMappings(ctx context.Context, ars hostarch.AddrRangeSeq, at hostarch.AccessType, ignorePermissions bool, f func(safemem.BlockSeq) (uint64, error)) (int64, error) {
// withInternalMappings is faster than withVecInternalMappings because of
// iterator plumbing (this isn't generally practical in the vector case due
// to iterator invalidation between AddrRanges). Use it if possible.
@@ -630,12 +631,12 @@ func (mm *MemoryManager) withVecInternalMappings(ctx context.Context, ars userme
// truncatedAddrRangeSeq returns a copy of ars, but with the end truncated to
// at most address end on AddrRange arsit.Head(). It is used in vector I/O paths to
-// truncate usermem.AddrRangeSeq when errors occur.
+// truncate hostarch.AddrRangeSeq when errors occur.
//
// Preconditions:
// * !arsit.IsEmpty().
// * end <= arsit.Head().End.
-func truncatedAddrRangeSeq(ars, arsit usermem.AddrRangeSeq, end usermem.Addr) usermem.AddrRangeSeq {
+func truncatedAddrRangeSeq(ars, arsit hostarch.AddrRangeSeq, end hostarch.Addr) hostarch.AddrRangeSeq {
ar := arsit.Head()
if end <= ar.Start {
return ars.TakeFirst64(ars.NumBytes() - arsit.NumBytes())
diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go
index 120707429..a79ef9223 100644
--- a/pkg/sentry/mm/lifecycle.go
+++ b/pkg/sentry/mm/lifecycle.go
@@ -19,12 +19,12 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/usermem"
)
// NewMemoryManager returns a new MemoryManager with no mappings and 1 user.
@@ -139,7 +139,7 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) {
}
srcvseg := mm.vmas.FirstSegment()
dstpgap := mm2.pmas.FirstGap()
- var unmapAR usermem.AddrRange
+ var unmapAR hostarch.AddrRange
for srcpseg := mm.pmas.FirstSegment(); srcpseg.Ok(); srcpseg = srcpseg.NextSegment() {
pma := srcpseg.ValuePtr()
if !pma.private {
diff --git a/pkg/sentry/mm/metadata.go b/pkg/sentry/mm/metadata.go
index 0cfd60f6c..28c5fead9 100644
--- a/pkg/sentry/mm/metadata.go
+++ b/pkg/sentry/mm/metadata.go
@@ -16,9 +16,9 @@ package mm
import (
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
- "gvisor.dev/gvisor/pkg/usermem"
)
// Dumpability describes if and how core dumps should be created.
@@ -54,14 +54,14 @@ func (mm *MemoryManager) SetDumpability(d Dumpability) {
// ArgvStart returns the start of the application argument vector.
//
// There is no guarantee that this value is sensible w.r.t. ArgvEnd.
-func (mm *MemoryManager) ArgvStart() usermem.Addr {
+func (mm *MemoryManager) ArgvStart() hostarch.Addr {
mm.metadataMu.Lock()
defer mm.metadataMu.Unlock()
return mm.argv.Start
}
// SetArgvStart sets the start of the application argument vector.
-func (mm *MemoryManager) SetArgvStart(a usermem.Addr) {
+func (mm *MemoryManager) SetArgvStart(a hostarch.Addr) {
mm.metadataMu.Lock()
defer mm.metadataMu.Unlock()
mm.argv.Start = a
@@ -70,14 +70,14 @@ func (mm *MemoryManager) SetArgvStart(a usermem.Addr) {
// ArgvEnd returns the end of the application argument vector.
//
// There is no guarantee that this value is sensible w.r.t. ArgvStart.
-func (mm *MemoryManager) ArgvEnd() usermem.Addr {
+func (mm *MemoryManager) ArgvEnd() hostarch.Addr {
mm.metadataMu.Lock()
defer mm.metadataMu.Unlock()
return mm.argv.End
}
// SetArgvEnd sets the end of the application argument vector.
-func (mm *MemoryManager) SetArgvEnd(a usermem.Addr) {
+func (mm *MemoryManager) SetArgvEnd(a hostarch.Addr) {
mm.metadataMu.Lock()
defer mm.metadataMu.Unlock()
mm.argv.End = a
@@ -86,14 +86,14 @@ func (mm *MemoryManager) SetArgvEnd(a usermem.Addr) {
// EnvvStart returns the start of the application environment vector.
//
// There is no guarantee that this value is sensible w.r.t. EnvvEnd.
-func (mm *MemoryManager) EnvvStart() usermem.Addr {
+func (mm *MemoryManager) EnvvStart() hostarch.Addr {
mm.metadataMu.Lock()
defer mm.metadataMu.Unlock()
return mm.envv.Start
}
// SetEnvvStart sets the start of the application environment vector.
-func (mm *MemoryManager) SetEnvvStart(a usermem.Addr) {
+func (mm *MemoryManager) SetEnvvStart(a hostarch.Addr) {
mm.metadataMu.Lock()
defer mm.metadataMu.Unlock()
mm.envv.Start = a
@@ -102,14 +102,14 @@ func (mm *MemoryManager) SetEnvvStart(a usermem.Addr) {
// EnvvEnd returns the end of the application environment vector.
//
// There is no guarantee that this value is sensible w.r.t. EnvvStart.
-func (mm *MemoryManager) EnvvEnd() usermem.Addr {
+func (mm *MemoryManager) EnvvEnd() hostarch.Addr {
mm.metadataMu.Lock()
defer mm.metadataMu.Unlock()
return mm.envv.End
}
// SetEnvvEnd sets the end of the application environment vector.
-func (mm *MemoryManager) SetEnvvEnd(a usermem.Addr) {
+func (mm *MemoryManager) SetEnvvEnd(a hostarch.Addr) {
mm.metadataMu.Lock()
defer mm.metadataMu.Unlock()
mm.envv.End = a
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index 92cc87d84..57969b26c 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -36,6 +36,7 @@ package mm
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
@@ -43,7 +44,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/usermem"
)
// MemoryManager implements a virtual address space.
@@ -97,7 +97,7 @@ type MemoryManager struct {
// binary into the mm.
//
// brk is protected by mappingMu.
- brk usermem.AddrRange
+ brk hostarch.AddrRange
// usageAS is vmas.Span(), cached to accelerate RLIMIT_AS checks.
//
@@ -198,14 +198,14 @@ type MemoryManager struct {
// requirements apply to argv; we do not require that argv.WellFormed().
//
// argv is protected by metadataMu.
- argv usermem.AddrRange
+ argv hostarch.AddrRange
// envv is the application envv. This is set up by the loader and may be
// modified by prctl(PR_SET_MM_ENV_START/PR_SET_MM_ENV_END). No
// requirements apply to envv; we do not require that envv.WellFormed().
//
// envv is protected by metadataMu.
- envv usermem.AddrRange
+ envv hostarch.AddrRange
// auxv is the ELF's auxiliary vector.
//
@@ -268,20 +268,20 @@ type vma struct {
// realPerms are the memory permissions on this vma, as defined by the
// application.
- realPerms usermem.AccessType `state:".(int)"`
+ realPerms hostarch.AccessType `state:".(int)"`
// effectivePerms are the memory permissions on this vma which are
// actually used to control access.
//
// Invariant: effectivePerms == realPerms.Effective().
- effectivePerms usermem.AccessType `state:"manual"`
+ effectivePerms hostarch.AccessType `state:"manual"`
// maxPerms limits the set of permissions that may ever apply to this
// memory, as well as accesses for which usermem.IOOpts.IgnorePermissions
// is true (e.g. ptrace(PTRACE_POKEDATA)).
//
// Invariant: maxPerms == maxPerms.Effective().
- maxPerms usermem.AccessType `state:"manual"`
+ maxPerms hostarch.AccessType `state:"manual"`
// private is true if this is a MAP_PRIVATE mapping, such that writes to
// the mapping are propagated to a copy.
@@ -421,8 +421,8 @@ type pma struct {
off uint64
// translatePerms is the permissions returned by memmap.Mappable.Translate.
- // If private is true, translatePerms is usermem.AnyAccess.
- translatePerms usermem.AccessType
+ // If private is true, translatePerms is hostarch.AnyAccess.
+ translatePerms hostarch.AccessType
// effectivePerms is the permissions allowed for non-ignorePermissions
// accesses. maxPerms is the permissions allowed for ignorePermissions
@@ -432,8 +432,8 @@ type pma struct {
//
// These are stored in the pma so that the IO implementation can avoid
// iterating mm.vmas when pmas already exist.
- effectivePerms usermem.AccessType
- maxPerms usermem.AccessType
+ effectivePerms hostarch.AccessType
+ maxPerms hostarch.AccessType
// needCOW is true if writes to the mapping must be propagated to a copy.
needCOW bool
@@ -465,7 +465,7 @@ type privateRefs struct {
}
type invalidateArgs struct {
- ar usermem.AddrRange
+ ar hostarch.AddrRange
opts memmap.InvalidateOpts
}
diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go
index bc53bd41e..1304b0a2f 100644
--- a/pkg/sentry/mm/mm_test.go
+++ b/pkg/sentry/mm/mm_test.go
@@ -18,6 +18,7 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/limits"
@@ -51,7 +52,7 @@ func TestUsageASUpdates(t *testing.T) {
defer mm.DecUsers(ctx)
addr, err := mm.MMap(ctx, memmap.MMapOpts{
- Length: 2 * usermem.PageSize,
+ Length: 2 * hostarch.PageSize,
Private: true,
})
if err != nil {
@@ -62,7 +63,7 @@ func TestUsageASUpdates(t *testing.T) {
t.Fatalf("usageAS believes %v bytes are mapped; %v bytes are actually mapped", mm.usageAS, realUsage)
}
- mm.MUnmap(ctx, addr, usermem.PageSize)
+ mm.MUnmap(ctx, addr, hostarch.PageSize)
realUsage = mm.realUsageAS()
if mm.usageAS != realUsage {
t.Fatalf("usageAS believes %v bytes are mapped; %v bytes are actually mapped", mm.usageAS, realUsage)
@@ -86,10 +87,10 @@ func TestDataASUpdates(t *testing.T) {
defer mm.DecUsers(ctx)
addr, err := mm.MMap(ctx, memmap.MMapOpts{
- Length: 3 * usermem.PageSize,
+ Length: 3 * hostarch.PageSize,
Private: true,
- Perms: usermem.Write,
- MaxPerms: usermem.AnyAccess,
+ Perms: hostarch.Write,
+ MaxPerms: hostarch.AnyAccess,
})
if err != nil {
t.Fatalf("MMap got err %v want nil", err)
@@ -102,19 +103,19 @@ func TestDataASUpdates(t *testing.T) {
t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS)
}
- mm.MUnmap(ctx, addr, usermem.PageSize)
+ mm.MUnmap(ctx, addr, hostarch.PageSize)
realDataAS = mm.realDataAS()
if mm.dataAS != realDataAS {
t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS)
}
- mm.MProtect(addr+usermem.PageSize, usermem.PageSize, usermem.Read, false)
+ mm.MProtect(addr+hostarch.PageSize, hostarch.PageSize, hostarch.Read, false)
realDataAS = mm.realDataAS()
if mm.dataAS != realDataAS {
t.Fatalf("dataAS believes %v bytes are mapped; %v bytes are actually mapped", mm.dataAS, realDataAS)
}
- mm.MRemap(ctx, addr+2*usermem.PageSize, usermem.PageSize, 2*usermem.PageSize, MRemapOpts{
+ mm.MRemap(ctx, addr+2*hostarch.PageSize, hostarch.PageSize, 2*hostarch.PageSize, MRemapOpts{
Move: MRemapMayMove,
})
realDataAS = mm.realDataAS()
@@ -133,7 +134,7 @@ func TestBrkDataLimitUpdates(t *testing.T) {
// Try to extend the brk by one page and expect doing so to fail.
oldBrk, _ := mm.Brk(ctx, 0)
- if newBrk, _ := mm.Brk(ctx, oldBrk+usermem.PageSize); newBrk != oldBrk {
+ if newBrk, _ := mm.Brk(ctx, oldBrk+hostarch.PageSize); newBrk != oldBrk {
t.Errorf("brk() increased data segment above RLIMIT_DATA (old brk = %#x, new brk = %#x", oldBrk, newBrk)
}
}
@@ -145,10 +146,10 @@ func TestIOAfterUnmap(t *testing.T) {
defer mm.DecUsers(ctx)
addr, err := mm.MMap(ctx, memmap.MMapOpts{
- Length: usermem.PageSize,
+ Length: hostarch.PageSize,
Private: true,
- Perms: usermem.Read,
- MaxPerms: usermem.AnyAccess,
+ Perms: hostarch.Read,
+ MaxPerms: hostarch.AnyAccess,
})
if err != nil {
t.Fatalf("MMap got err %v want nil", err)
@@ -164,7 +165,7 @@ func TestIOAfterUnmap(t *testing.T) {
t.Errorf("CopyIn got %d want 1", n)
}
- err = mm.MUnmap(ctx, addr, usermem.PageSize)
+ err = mm.MUnmap(ctx, addr, hostarch.PageSize)
if err != nil {
t.Fatalf("MUnmap got err %v want nil", err)
}
@@ -185,10 +186,10 @@ func TestIOAfterMProtect(t *testing.T) {
defer mm.DecUsers(ctx)
addr, err := mm.MMap(ctx, memmap.MMapOpts{
- Length: usermem.PageSize,
+ Length: hostarch.PageSize,
Private: true,
- Perms: usermem.ReadWrite,
- MaxPerms: usermem.AnyAccess,
+ Perms: hostarch.ReadWrite,
+ MaxPerms: hostarch.AnyAccess,
})
if err != nil {
t.Fatalf("MMap got err %v want nil", err)
@@ -204,7 +205,7 @@ func TestIOAfterMProtect(t *testing.T) {
t.Errorf("CopyOut got %d want 1", n)
}
- err = mm.MProtect(addr, usermem.PageSize, usermem.Read, false)
+ err = mm.MProtect(addr, hostarch.PageSize, hostarch.Read, false)
if err != nil {
t.Errorf("MProtect got err %v want nil", err)
}
diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go
index 7e5f7de64..5583f62b2 100644
--- a/pkg/sentry/mm/pma.go
+++ b/pkg/sentry/mm/pma.go
@@ -18,12 +18,12 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safecopy"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// existingPMAsLocked checks that pmas exist for all addresses in ar, and
@@ -34,7 +34,7 @@ import (
// Preconditions:
// * mm.activeMu must be locked.
// * ar.Length() != 0.
-func (mm *MemoryManager) existingPMAsLocked(ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool, needInternalMappings bool) pmaIterator {
+func (mm *MemoryManager) existingPMAsLocked(ar hostarch.AddrRange, at hostarch.AccessType, ignorePermissions bool, needInternalMappings bool) pmaIterator {
if checkInvariants {
if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
@@ -70,7 +70,7 @@ func (mm *MemoryManager) existingPMAsLocked(ar usermem.AddrRange, at usermem.Acc
// and support access of type (at, ignorePermissions).
//
// Preconditions: mm.activeMu must be locked.
-func (mm *MemoryManager) existingVecPMAsLocked(ars usermem.AddrRangeSeq, at usermem.AccessType, ignorePermissions bool, needInternalMappings bool) bool {
+func (mm *MemoryManager) existingVecPMAsLocked(ars hostarch.AddrRangeSeq, at hostarch.AccessType, ignorePermissions bool, needInternalMappings bool) bool {
for ; !ars.IsEmpty(); ars = ars.Tail() {
if ar := ars.Head(); ar.Length() != 0 && !mm.existingPMAsLocked(ar, at, ignorePermissions, needInternalMappings).Ok() {
return false
@@ -98,7 +98,7 @@ func (mm *MemoryManager) existingVecPMAsLocked(ars usermem.AddrRangeSeq, at user
// * vseg.Range().Contains(ar.Start).
// * vmas must exist for all addresses in ar, and support accesses of type at
// (i.e. permission checks must have been performed against vmas).
-func (mm *MemoryManager) getPMAsLocked(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, at usermem.AccessType) (pmaIterator, pmaGapIterator, error) {
+func (mm *MemoryManager) getPMAsLocked(ctx context.Context, vseg vmaIterator, ar hostarch.AddrRange, at hostarch.AccessType) (pmaIterator, pmaGapIterator, error) {
if checkInvariants {
if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
@@ -118,7 +118,7 @@ func (mm *MemoryManager) getPMAsLocked(ctx context.Context, vseg vmaIterator, ar
end = ar.End.RoundDown()
alignerr = syserror.EFAULT
}
- ar = usermem.AddrRange{ar.Start.RoundDown(), end}
+ ar = hostarch.AddrRange{ar.Start.RoundDown(), end}
pstart, pend, perr := mm.getPMAsInternalLocked(ctx, vseg, ar, at)
if pend.Start() <= ar.Start {
@@ -145,7 +145,7 @@ func (mm *MemoryManager) getPMAsLocked(ctx context.Context, vseg vmaIterator, ar
// * mm.activeMu must be locked for writing.
// * vmas must exist for all addresses in ars, and support accesses of type at
// (i.e. permission checks must have been performed against vmas).
-func (mm *MemoryManager) getVecPMAsLocked(ctx context.Context, ars usermem.AddrRangeSeq, at usermem.AccessType) (usermem.AddrRangeSeq, error) {
+func (mm *MemoryManager) getVecPMAsLocked(ctx context.Context, ars hostarch.AddrRangeSeq, at hostarch.AccessType) (hostarch.AddrRangeSeq, error) {
for arsit := ars; !arsit.IsEmpty(); arsit = arsit.Tail() {
ar := arsit.Head()
if ar.Length() == 0 {
@@ -164,7 +164,7 @@ func (mm *MemoryManager) getVecPMAsLocked(ctx context.Context, ars usermem.AddrR
end = ar.End.RoundDown()
alignerr = syserror.EFAULT
}
- ar = usermem.AddrRange{ar.Start.RoundDown(), end}
+ ar = hostarch.AddrRange{ar.Start.RoundDown(), end}
_, pend, perr := mm.getPMAsInternalLocked(ctx, mm.vmas.FindSegment(ar.Start), ar, at)
if perr != nil {
@@ -191,7 +191,7 @@ func (mm *MemoryManager) getVecPMAsLocked(ctx context.Context, ars usermem.AddrR
//
// getPMAsInternalLocked is an implementation helper for getPMAsLocked and
// getVecPMAsLocked; other clients should call one of those instead.
-func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, at usermem.AccessType) (pmaIterator, pmaGapIterator, error) {
+func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIterator, ar hostarch.AddrRange, at hostarch.AccessType) (pmaIterator, pmaGapIterator, error) {
if checkInvariants {
if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
@@ -245,7 +245,7 @@ func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIter
pseg, pgap = mm.pmas.Insert(pgap, allocAR, pma{
file: mf,
off: fr.Start,
- translatePerms: usermem.AnyAccess,
+ translatePerms: hostarch.AnyAccess,
effectivePerms: vma.effectivePerms,
maxPerms: vma.maxPerms,
// Since we just allocated this memory and have the
@@ -335,7 +335,7 @@ func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIter
// Neither of these cases has enough spatial locality to
// benefit from copying nearby pages, so if the vma is
// executable, only copy the pages required.
- var copyAR usermem.AddrRange
+ var copyAR hostarch.AddrRange
if vseg.ValuePtr().effectivePerms.Execute {
copyAR = pseg.Range().Intersect(ar)
} else {
@@ -366,7 +366,7 @@ func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIter
// Replace the pma with a copy in the part of the address
// range where copying was successful. This doesn't change
// RSS.
- copyAR.End = copyAR.Start + usermem.Addr(fr.Length())
+ copyAR.End = copyAR.Start + hostarch.Addr(fr.Length())
if copyAR != pseg.Range() {
pseg = mm.pmas.Isolate(pseg, copyAR)
pstart = pmaIterator{} // iterators invalidated
@@ -380,7 +380,7 @@ func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIter
mf.IncRef(fr)
oldpma.file = mf
oldpma.off = fr.Start
- oldpma.translatePerms = usermem.AnyAccess
+ oldpma.translatePerms = hostarch.AnyAccess
oldpma.effectivePerms = vma.effectivePerms
oldpma.maxPerms = vma.maxPerms
oldpma.needCOW = false
@@ -499,14 +499,14 @@ const (
// privateAllocUnit may reduce page faults by allowing fewer, larger pmas
// to be mapped, but may result in larger amounts of wasted memory in the
// presence of fragmentation. privateAllocUnit must be a power-of-2
- // multiple of usermem.PageSize.
- privateAllocUnit = usermem.HugePageSize
+ // multiple of hostarch.PageSize.
+ privateAllocUnit = hostarch.HugePageSize
privateAllocMask = privateAllocUnit - 1
)
-func privateAligned(ar usermem.AddrRange) usermem.AddrRange {
- aligned := usermem.AddrRange{ar.Start &^ privateAllocMask, ar.End}
+func privateAligned(ar hostarch.AddrRange) hostarch.AddrRange {
+ aligned := hostarch.AddrRange{ar.Start &^ privateAllocMask, ar.End}
if end := (ar.End + privateAllocMask) &^ privateAllocMask; end >= ar.End {
aligned.End = end
}
@@ -548,7 +548,7 @@ func (mm *MemoryManager) isPMACopyOnWriteLocked(vseg vmaIterator, pseg pmaIterat
rseg := mm.privateRefs.refs.FindSegment(fr.Start)
if rseg.Ok() && rseg.Value() == 1 && fr.End <= rseg.End() {
pma.needCOW = false
- // pma.private => pma.translatePerms == usermem.AnyAccess
+ // pma.private => pma.translatePerms == hostarch.AnyAccess
vma := vseg.ValuePtr()
pma.effectivePerms = vma.effectivePerms
pma.maxPerms = vma.maxPerms
@@ -558,7 +558,7 @@ func (mm *MemoryManager) isPMACopyOnWriteLocked(vseg vmaIterator, pseg pmaIterat
}
// Invalidate implements memmap.MappingSpace.Invalidate.
-func (mm *MemoryManager) Invalidate(ar usermem.AddrRange, opts memmap.InvalidateOpts) {
+func (mm *MemoryManager) Invalidate(ar hostarch.AddrRange, opts memmap.InvalidateOpts) {
if checkInvariants {
if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
@@ -581,7 +581,7 @@ func (mm *MemoryManager) Invalidate(ar usermem.AddrRange, opts memmap.Invalidate
// * mm.activeMu must be locked for writing.
// * ar.Length() != 0.
// * ar must be page-aligned.
-func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivate, invalidateShared bool) {
+func (mm *MemoryManager) invalidateLocked(ar hostarch.AddrRange, invalidatePrivate, invalidateShared bool) {
if checkInvariants {
if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
@@ -627,7 +627,7 @@ func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivat
// Preconditions:
// * ar.Length() != 0.
// * ar must be page-aligned.
-func (mm *MemoryManager) Pin(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool) ([]PinnedRange, error) {
+func (mm *MemoryManager) Pin(ctx context.Context, ar hostarch.AddrRange, at hostarch.AccessType, ignorePermissions bool) ([]PinnedRange, error) {
if checkInvariants {
if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
@@ -683,7 +683,7 @@ func (mm *MemoryManager) Pin(ctx context.Context, ar usermem.AddrRange, at userm
// PinnedRanges are returned by MemoryManager.Pin.
type PinnedRange struct {
// Source is the corresponding range of addresses.
- Source usermem.AddrRange
+ Source hostarch.AddrRange
// File is the mapped file.
File memmap.File
@@ -713,7 +713,7 @@ func Unpin(prs []PinnedRange) {
// * !oldAR.Overlaps(newAR).
// * mm.pmas.IsEmptyRange(newAR).
// * oldAR and newAR must be page-aligned.
-func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) {
+func (mm *MemoryManager) movePMAsLocked(oldAR, newAR hostarch.AddrRange) {
if checkInvariants {
if !oldAR.WellFormed() || oldAR.Length() == 0 || !oldAR.IsPageAligned() {
panic(fmt.Sprintf("invalid oldAR: %v", oldAR))
@@ -731,7 +731,7 @@ func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) {
}
type movedPMA struct {
- oldAR usermem.AddrRange
+ oldAR hostarch.AddrRange
pma pma
}
var movedPMAs []movedPMA
@@ -751,7 +751,7 @@ func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) {
pgap := mm.pmas.FindGap(newAR.Start)
for i := range movedPMAs {
mpma := &movedPMAs[i]
- pmaNewAR := usermem.AddrRange{mpma.oldAR.Start + off, mpma.oldAR.End + off}
+ pmaNewAR := hostarch.AddrRange{mpma.oldAR.Start + off, mpma.oldAR.End + off}
pgap = mm.pmas.Insert(pgap, pmaNewAR, mpma.pma).NextGap()
}
@@ -776,7 +776,7 @@ func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) {
//
// Postconditions: getPMAInternalMappingsLocked does not invalidate iterators
// into mm.pmas.
-func (mm *MemoryManager) getPMAInternalMappingsLocked(pseg pmaIterator, ar usermem.AddrRange) (pmaGapIterator, error) {
+func (mm *MemoryManager) getPMAInternalMappingsLocked(pseg pmaIterator, ar hostarch.AddrRange) (pmaGapIterator, error) {
if checkInvariants {
if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
@@ -808,7 +808,7 @@ func (mm *MemoryManager) getPMAInternalMappingsLocked(pseg pmaIterator, ar userm
//
// Postconditions: getVecPMAInternalMappingsLocked does not invalidate iterators
// into mm.pmas.
-func (mm *MemoryManager) getVecPMAInternalMappingsLocked(ars usermem.AddrRangeSeq) (usermem.AddrRangeSeq, error) {
+func (mm *MemoryManager) getVecPMAInternalMappingsLocked(ars hostarch.AddrRangeSeq) (hostarch.AddrRangeSeq, error) {
for arsit := ars; !arsit.IsEmpty(); arsit = arsit.Tail() {
ar := arsit.Head()
if ar.Length() == 0 {
@@ -829,7 +829,7 @@ func (mm *MemoryManager) getVecPMAInternalMappingsLocked(ars usermem.AddrRangeSe
// in ar.
// * ar.Length() != 0.
// * pseg.Range().Contains(ar.Start).
-func (mm *MemoryManager) internalMappingsLocked(pseg pmaIterator, ar usermem.AddrRange) safemem.BlockSeq {
+func (mm *MemoryManager) internalMappingsLocked(pseg pmaIterator, ar hostarch.AddrRange) safemem.BlockSeq {
if checkInvariants {
if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
@@ -866,7 +866,7 @@ func (mm *MemoryManager) internalMappingsLocked(pseg pmaIterator, ar usermem.Add
// * mm.activeMu must be locked.
// * Internal mappings must have been previously established for all addresses
// in ars.
-func (mm *MemoryManager) vecInternalMappingsLocked(ars usermem.AddrRangeSeq) safemem.BlockSeq {
+func (mm *MemoryManager) vecInternalMappingsLocked(ars hostarch.AddrRangeSeq) safemem.BlockSeq {
var ims []safemem.Block
for ; !ars.IsEmpty(); ars = ars.Tail() {
ar := ars.Head()
@@ -931,7 +931,7 @@ func (mm *MemoryManager) decPrivateRef(fr memmap.FileRange) {
// MemoryManager to reflect the insertion of a pma at ar.
//
// Preconditions: mm.activeMu must be locked for writing.
-func (mm *MemoryManager) addRSSLocked(ar usermem.AddrRange) {
+func (mm *MemoryManager) addRSSLocked(ar hostarch.AddrRange) {
mm.curRSS += uint64(ar.Length())
if mm.curRSS > mm.maxRSS {
mm.maxRSS = mm.curRSS
@@ -942,19 +942,19 @@ func (mm *MemoryManager) addRSSLocked(ar usermem.AddrRange) {
// reflect the removal of a pma at ar.
//
// Preconditions: mm.activeMu must be locked for writing.
-func (mm *MemoryManager) removeRSSLocked(ar usermem.AddrRange) {
+func (mm *MemoryManager) removeRSSLocked(ar hostarch.AddrRange) {
mm.curRSS -= uint64(ar.Length())
}
// pmaSetFunctions implements segment.Functions for pmaSet.
type pmaSetFunctions struct{}
-func (pmaSetFunctions) MinKey() usermem.Addr {
+func (pmaSetFunctions) MinKey() hostarch.Addr {
return 0
}
-func (pmaSetFunctions) MaxKey() usermem.Addr {
- return ^usermem.Addr(0)
+func (pmaSetFunctions) MaxKey() hostarch.Addr {
+ return ^hostarch.Addr(0)
}
func (pmaSetFunctions) ClearValue(pma *pma) {
@@ -962,7 +962,7 @@ func (pmaSetFunctions) ClearValue(pma *pma) {
pma.internalMappings = safemem.BlockSeq{}
}
-func (pmaSetFunctions) Merge(ar1 usermem.AddrRange, pma1 pma, ar2 usermem.AddrRange, pma2 pma) (pma, bool) {
+func (pmaSetFunctions) Merge(ar1 hostarch.AddrRange, pma1 pma, ar2 hostarch.AddrRange, pma2 pma) (pma, bool) {
if pma1.file != pma2.file ||
pma1.off+uint64(ar1.Length()) != pma2.off ||
pma1.translatePerms != pma2.translatePerms ||
@@ -980,7 +980,7 @@ func (pmaSetFunctions) Merge(ar1 usermem.AddrRange, pma1 pma, ar2 usermem.AddrRa
return pma1, true
}
-func (pmaSetFunctions) Split(ar usermem.AddrRange, p pma, split usermem.Addr) (pma, pma) {
+func (pmaSetFunctions) Split(ar hostarch.AddrRange, p pma, split hostarch.Addr) (pma, pma) {
newlen1 := uint64(split - ar.Start)
p2 := p
p2.off += newlen1
@@ -997,7 +997,7 @@ func (pmaSetFunctions) Split(ar usermem.AddrRange, p pma, split usermem.Addr) (p
// Preconditions:
// * mm.activeMu must be locked.
// * addr <= pgap.Start().
-func (mm *MemoryManager) findOrSeekPrevUpperBoundPMA(addr usermem.Addr, pgap pmaGapIterator) pmaIterator {
+func (mm *MemoryManager) findOrSeekPrevUpperBoundPMA(addr hostarch.Addr, pgap pmaGapIterator) pmaIterator {
if checkInvariants {
if !pgap.Ok() {
panic("terminal pma iterator")
@@ -1045,7 +1045,7 @@ func (pseg pmaIterator) fileRange() memmap.FileRange {
// Preconditions:
// * pseg.Range().IsSupersetOf(ar).
// * ar.Length != 0.
-func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) memmap.FileRange {
+func (pseg pmaIterator) fileRangeOf(ar hostarch.AddrRange) memmap.FileRange {
if checkInvariants {
if !pseg.Ok() {
panic("terminal pma iterator")
diff --git a/pkg/sentry/mm/procfs.go b/pkg/sentry/mm/procfs.go
index 73bfbea49..f1440e884 100644
--- a/pkg/sentry/mm/procfs.go
+++ b/pkg/sentry/mm/procfs.go
@@ -19,9 +19,9 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/usermem"
)
const (
@@ -29,7 +29,7 @@ const (
// include/linux/kdev_t.h:MINORBITS
devMinorBits = 20
- vsyscallEnd = usermem.Addr(0xffffffffff601000)
+ vsyscallEnd = hostarch.Addr(0xffffffffff601000)
vsyscallMapsEntry = "ffffffffff600000-ffffffffff601000 r-xp 00000000 00:00 0 [vsyscall]\n"
vsyscallSmapsEntry = vsyscallMapsEntry +
"Size: 4 kB\n" +
@@ -62,7 +62,7 @@ func (mm *MemoryManager) NeedsUpdate(generation int64) bool {
func (mm *MemoryManager) ReadMapsDataInto(ctx context.Context, buf *bytes.Buffer) {
mm.mappingMu.RLock()
defer mm.mappingMu.RUnlock()
- var start usermem.Addr
+ var start hostarch.Addr
for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
mm.appendVMAMapsEntryLocked(ctx, vseg, buf)
@@ -88,9 +88,9 @@ func (mm *MemoryManager) ReadMapsSeqFileData(ctx context.Context, handle seqfile
mm.mappingMu.RLock()
defer mm.mappingMu.RUnlock()
var data []seqfile.SeqData
- var start usermem.Addr
+ var start hostarch.Addr
if handle != nil {
- start = *handle.(*usermem.Addr)
+ start = *handle.(*hostarch.Addr)
}
for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
vmaAddr := vseg.End()
@@ -177,7 +177,7 @@ func (mm *MemoryManager) appendVMAMapsEntryLocked(ctx context.Context, vseg vmaI
func (mm *MemoryManager) ReadSmapsDataInto(ctx context.Context, buf *bytes.Buffer) {
mm.mappingMu.RLock()
defer mm.mappingMu.RUnlock()
- var start usermem.Addr
+ var start hostarch.Addr
for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
mm.vmaSmapsEntryIntoLocked(ctx, vseg, buf)
@@ -196,9 +196,9 @@ func (mm *MemoryManager) ReadSmapsSeqFileData(ctx context.Context, handle seqfil
mm.mappingMu.RLock()
defer mm.mappingMu.RUnlock()
var data []seqfile.SeqData
- var start usermem.Addr
+ var start hostarch.Addr
if handle != nil {
- start = *handle.(*usermem.Addr)
+ start = *handle.(*hostarch.Addr)
}
for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
vmaAddr := vseg.End()
@@ -279,8 +279,8 @@ func (mm *MemoryManager) vmaSmapsEntryIntoLocked(ctx context.Context, vseg vmaIt
// Swap is not implemented.
fmt.Fprintf(b, "Swap: %8d kB\n", 0)
fmt.Fprintf(b, "SwapPss: %8d kB\n", 0)
- fmt.Fprintf(b, "KernelPageSize: %8d kB\n", usermem.PageSize/1024)
- fmt.Fprintf(b, "MMUPageSize: %8d kB\n", usermem.PageSize/1024)
+ fmt.Fprintf(b, "KernelPageSize: %8d kB\n", hostarch.PageSize/1024)
+ fmt.Fprintf(b, "MMUPageSize: %8d kB\n", hostarch.PageSize/1024)
locked := rss
if vma.mlockMode == memmap.MLockNone {
locked = 0
diff --git a/pkg/sentry/mm/shm.go b/pkg/sentry/mm/shm.go
index 6432731d4..3130be80c 100644
--- a/pkg/sentry/mm/shm.go
+++ b/pkg/sentry/mm/shm.go
@@ -16,13 +16,13 @@ package mm
import (
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/kernel/shm"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// DetachShm unmaps a sysv shared memory segment.
-func (mm *MemoryManager) DetachShm(ctx context.Context, addr usermem.Addr) error {
+func (mm *MemoryManager) DetachShm(ctx context.Context, addr hostarch.Addr) error {
if addr != addr.RoundDown() {
// "... shmaddr is not aligned on a page boundary." - man shmdt(2)
return syserror.EINVAL
@@ -52,7 +52,7 @@ func (mm *MemoryManager) DetachShm(ctx context.Context, addr usermem.Addr) error
}
// Remove all vmas that could have been created by the same attach.
- end := addr + usermem.Addr(detached.EffectiveSize())
+ end := addr + hostarch.Addr(detached.EffectiveSize())
for vseg.Ok() && vseg.End() <= end {
vma := vseg.ValuePtr()
if vma.mappable == detached && uint64(vseg.Start()-addr) == vma.off {
diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go
index 48d8b6a2b..e748b7ff8 100644
--- a/pkg/sentry/mm/special_mappable.go
+++ b/pkg/sentry/mm/special_mappable.go
@@ -16,11 +16,11 @@ package mm
import (
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// SpecialMappable implements memmap.MappingIdentity and memmap.Mappable with
@@ -77,21 +77,21 @@ func (m *SpecialMappable) Msync(ctx context.Context, mr memmap.MappableRange) er
}
// AddMapping implements memmap.Mappable.AddMapping.
-func (*SpecialMappable) AddMapping(context.Context, memmap.MappingSpace, usermem.AddrRange, uint64, bool) error {
+func (*SpecialMappable) AddMapping(context.Context, memmap.MappingSpace, hostarch.AddrRange, uint64, bool) error {
return nil
}
// RemoveMapping implements memmap.Mappable.RemoveMapping.
-func (*SpecialMappable) RemoveMapping(context.Context, memmap.MappingSpace, usermem.AddrRange, uint64, bool) {
+func (*SpecialMappable) RemoveMapping(context.Context, memmap.MappingSpace, hostarch.AddrRange, uint64, bool) {
}
// CopyMapping implements memmap.Mappable.CopyMapping.
-func (*SpecialMappable) CopyMapping(context.Context, memmap.MappingSpace, usermem.AddrRange, usermem.AddrRange, uint64, bool) error {
+func (*SpecialMappable) CopyMapping(context.Context, memmap.MappingSpace, hostarch.AddrRange, hostarch.AddrRange, uint64, bool) error {
return nil
}
// Translate implements memmap.Mappable.Translate.
-func (m *SpecialMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+func (m *SpecialMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) {
var err error
if required.End > m.fr.Length() {
err = &memmap.BusError{syserror.EFAULT}
@@ -102,7 +102,7 @@ func (m *SpecialMappable) Translate(ctx context.Context, required, optional memm
Source: source,
File: m.mfp.MemoryFile(),
Offset: m.fr.Start + source.Start,
- Perms: usermem.AnyAccess,
+ Perms: hostarch.AnyAccess,
},
}, err
}
@@ -146,7 +146,7 @@ func NewSharedAnonMappable(length uint64, mfp pgalloc.MemoryFileProvider) (*Spec
if length == 0 {
return nil, syserror.EINVAL
}
- alignedLen, ok := usermem.Addr(length).RoundUp()
+ alignedLen, ok := hostarch.Addr(length).RoundUp()
if !ok {
return nil, syserror.EINVAL
}
diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go
index 69e37330b..7ad6b7c21 100644
--- a/pkg/sentry/mm/syscalls.go
+++ b/pkg/sentry/mm/syscalls.go
@@ -21,20 +21,20 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// HandleUserFault handles an application page fault. sp is the faulting
// application thread's stack pointer.
//
// Preconditions: mm.as != nil.
-func (mm *MemoryManager) HandleUserFault(ctx context.Context, addr usermem.Addr, at usermem.AccessType, sp usermem.Addr) error {
- ar, ok := addr.RoundDown().ToRange(usermem.PageSize)
+func (mm *MemoryManager) HandleUserFault(ctx context.Context, addr hostarch.Addr, at hostarch.AccessType, sp hostarch.Addr) error {
+ ar, ok := addr.RoundDown().ToRange(hostarch.PageSize)
if !ok {
return syserror.EFAULT
}
@@ -72,11 +72,11 @@ func (mm *MemoryManager) HandleUserFault(ctx context.Context, addr usermem.Addr,
}
// MMap establishes a memory mapping.
-func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (usermem.Addr, error) {
+func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (hostarch.Addr, error) {
if opts.Length == 0 {
return 0, syserror.EINVAL
}
- length, ok := usermem.Addr(opts.Length).RoundUp()
+ length, ok := hostarch.Addr(opts.Length).RoundUp()
if !ok {
return 0, syserror.ENOMEM
}
@@ -84,7 +84,7 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (userme
if opts.Mappable != nil {
// Offset must be aligned.
- if usermem.Addr(opts.Offset).RoundDown() != usermem.Addr(opts.Offset) {
+ if hostarch.Addr(opts.Offset).RoundDown() != hostarch.Addr(opts.Offset) {
return 0, syserror.EINVAL
}
// Offset + length must not overflow.
@@ -157,7 +157,7 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (userme
// Preconditions:
// * mm.mappingMu must be locked.
// * vseg.Range().IsSupersetOf(ar).
-func (mm *MemoryManager) populateVMA(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, precommit bool) {
+func (mm *MemoryManager) populateVMA(ctx context.Context, vseg vmaIterator, ar hostarch.AddrRange, precommit bool) {
if !vseg.ValuePtr().effectivePerms.Any() {
// Linux doesn't populate inaccessible pages. See
// mm/gup.c:populate_vma_page_range.
@@ -175,7 +175,7 @@ func (mm *MemoryManager) populateVMA(ctx context.Context, vseg vmaIterator, ar u
}
// Ensure that we have usable pmas.
- pseg, _, err := mm.getPMAsLocked(ctx, vseg, ar, usermem.NoAccess)
+ pseg, _, err := mm.getPMAsLocked(ctx, vseg, ar, hostarch.NoAccess)
if err != nil {
// mm/util.c:vm_mmap_pgoff() ignores the error, if any, from
// mm/gup.c:mm_populate(). If it matters, we'll get it again when
@@ -203,7 +203,7 @@ func (mm *MemoryManager) populateVMA(ctx context.Context, vseg vmaIterator, ar u
// * vseg.Range().IsSupersetOf(ar).
//
// Postconditions: mm.mappingMu will be unlocked.
-func (mm *MemoryManager) populateVMAAndUnlock(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, precommit bool) {
+func (mm *MemoryManager) populateVMAAndUnlock(ctx context.Context, vseg vmaIterator, ar hostarch.AddrRange, precommit bool) {
// See populateVMA above for commentary.
if !vseg.ValuePtr().effectivePerms.Any() {
mm.mappingMu.Unlock()
@@ -221,7 +221,7 @@ func (mm *MemoryManager) populateVMAAndUnlock(ctx context.Context, vseg vmaItera
// mm.mappingMu doesn't need to be write-locked for getPMAsLocked, and it
// isn't needed at all for mapASLocked.
mm.mappingMu.DowngradeLock()
- pseg, _, err := mm.getPMAsLocked(ctx, vseg, ar, usermem.NoAccess)
+ pseg, _, err := mm.getPMAsLocked(ctx, vseg, ar, hostarch.NoAccess)
mm.mappingMu.RUnlock()
if err != nil {
mm.activeMu.Unlock()
@@ -234,7 +234,7 @@ func (mm *MemoryManager) populateVMAAndUnlock(ctx context.Context, vseg vmaItera
}
// MapStack allocates the initial process stack.
-func (mm *MemoryManager) MapStack(ctx context.Context) (usermem.AddrRange, error) {
+func (mm *MemoryManager) MapStack(ctx context.Context) (hostarch.AddrRange, error) {
// maxStackSize is the maximum supported process stack size in bytes.
//
// This limit exists because stack growing isn't implemented, so the entire
@@ -242,7 +242,7 @@ func (mm *MemoryManager) MapStack(ctx context.Context) (usermem.AddrRange, error
const maxStackSize = 128 << 20
stackSize := limits.FromContext(ctx).Get(limits.Stack)
- r, ok := usermem.Addr(stackSize.Cur).RoundUp()
+ r, ok := hostarch.Addr(stackSize.Cur).RoundUp()
sz := uint64(r)
if !ok {
// RLIM_INFINITY rounds up to 0.
@@ -251,16 +251,16 @@ func (mm *MemoryManager) MapStack(ctx context.Context) (usermem.AddrRange, error
ctx.Warningf("Capping stack size from RLIMIT_STACK of %v down to %v.", sz, maxStackSize)
sz = maxStackSize
} else if sz == 0 {
- return usermem.AddrRange{}, syserror.ENOMEM
+ return hostarch.AddrRange{}, syserror.ENOMEM
}
- szaddr := usermem.Addr(sz)
+ szaddr := hostarch.Addr(sz)
ctx.Debugf("Allocating stack with size of %v bytes", sz)
// Determine the stack's desired location. Unlike Linux, address
// randomization can't be disabled.
- stackEnd := mm.layout.MaxAddr - usermem.Addr(mrand.Int63n(int64(mm.layout.MaxStackRand))).RoundDown()
+ stackEnd := mm.layout.MaxAddr - hostarch.Addr(mrand.Int63n(int64(mm.layout.MaxStackRand))).RoundDown()
if stackEnd < szaddr {
- return usermem.AddrRange{}, syserror.ENOMEM
+ return hostarch.AddrRange{}, syserror.ENOMEM
}
stackStart := stackEnd - szaddr
mm.mappingMu.Lock()
@@ -268,8 +268,8 @@ func (mm *MemoryManager) MapStack(ctx context.Context) (usermem.AddrRange, error
_, ar, err := mm.createVMALocked(ctx, memmap.MMapOpts{
Length: sz,
Addr: stackStart,
- Perms: usermem.ReadWrite,
- MaxPerms: usermem.AnyAccess,
+ Perms: hostarch.ReadWrite,
+ MaxPerms: hostarch.AnyAccess,
Private: true,
GrowsDown: true,
MLockMode: mm.defMLockMode,
@@ -279,14 +279,14 @@ func (mm *MemoryManager) MapStack(ctx context.Context) (usermem.AddrRange, error
}
// MUnmap implements the semantics of Linux's munmap(2).
-func (mm *MemoryManager) MUnmap(ctx context.Context, addr usermem.Addr, length uint64) error {
+func (mm *MemoryManager) MUnmap(ctx context.Context, addr hostarch.Addr, length uint64) error {
if addr != addr.RoundDown() {
return syserror.EINVAL
}
if length == 0 {
return syserror.EINVAL
}
- la, ok := usermem.Addr(length).RoundUp()
+ la, ok := hostarch.Addr(length).RoundUp()
if !ok {
return syserror.EINVAL
}
@@ -308,7 +308,7 @@ type MRemapOpts struct {
// NewAddr is the new address for the remapping. NewAddr is ignored unless
// Move is MMRemapMustMove.
- NewAddr usermem.Addr
+ NewAddr hostarch.Addr
}
// MRemapMoveMode controls MRemap's moving behavior.
@@ -328,7 +328,7 @@ const (
)
// MRemap implements the semantics of Linux's mremap(2).
-func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr usermem.Addr, oldSize uint64, newSize uint64, opts MRemapOpts) (usermem.Addr, error) {
+func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr hostarch.Addr, oldSize uint64, newSize uint64, opts MRemapOpts) (hostarch.Addr, error) {
// "Note that old_address has to be page aligned." - mremap(2)
if oldAddr.RoundDown() != oldAddr {
return 0, syserror.EINVAL
@@ -336,9 +336,9 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr usermem.Addr, oldSi
// Linux treats an old_size that rounds up to 0 as 0, which is otherwise a
// valid size. However, new_size can't be 0 after rounding.
- oldSizeAddr, _ := usermem.Addr(oldSize).RoundUp()
+ oldSizeAddr, _ := hostarch.Addr(oldSize).RoundUp()
oldSize = uint64(oldSizeAddr)
- newSizeAddr, ok := usermem.Addr(newSize).RoundUp()
+ newSizeAddr, ok := hostarch.Addr(newSize).RoundUp()
if !ok || newSizeAddr == 0 {
return 0, syserror.EINVAL
}
@@ -392,8 +392,8 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr usermem.Addr, oldSi
if newSize < oldSize {
// If oldAddr+oldSize didn't overflow, oldAddr+newSize can't
// either.
- newEnd := oldAddr + usermem.Addr(newSize)
- mm.unmapLocked(ctx, usermem.AddrRange{newEnd, oldEnd})
+ newEnd := oldAddr + hostarch.Addr(newSize)
+ mm.unmapLocked(ctx, hostarch.AddrRange{newEnd, oldEnd})
}
return oldAddr, nil
}
@@ -438,7 +438,7 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr usermem.Addr, oldSi
}
// Find a location for the new mapping.
- var newAR usermem.AddrRange
+ var newAR hostarch.AddrRange
switch opts.Move {
case MRemapMayMove:
newAddr, err := mm.findAvailableLocked(newSize, findAvailableOpts{})
@@ -457,7 +457,7 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr usermem.Addr, oldSi
if !ok {
return 0, syserror.EINVAL
}
- if (usermem.AddrRange{oldAddr, oldEnd}).Overlaps(newAR) {
+ if (hostarch.AddrRange{oldAddr, oldEnd}).Overlaps(newAR) {
return 0, syserror.EINVAL
}
@@ -479,8 +479,8 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr usermem.Addr, oldSi
// correct: compare Linux's mm/mremap.c:mremap_to() => do_munmap(),
// vma_to_resize().
if newSize < oldSize {
- oldNewEnd := oldAddr + usermem.Addr(newSize)
- mm.unmapLocked(ctx, usermem.AddrRange{oldNewEnd, oldEnd})
+ oldNewEnd := oldAddr + hostarch.Addr(newSize)
+ mm.unmapLocked(ctx, hostarch.AddrRange{oldNewEnd, oldEnd})
oldEnd = oldNewEnd
}
@@ -488,7 +488,7 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr usermem.Addr, oldSi
vseg = mm.vmas.FindSegment(oldAddr)
}
- oldAR := usermem.AddrRange{oldAddr, oldEnd}
+ oldAR := hostarch.AddrRange{oldAddr, oldEnd}
// Check that oldEnd maps to the same vma as oldAddr.
if vseg.End() < oldEnd {
@@ -588,14 +588,14 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr usermem.Addr, oldSi
}
// MProtect implements the semantics of Linux's mprotect(2).
-func (mm *MemoryManager) MProtect(addr usermem.Addr, length uint64, realPerms usermem.AccessType, growsDown bool) error {
+func (mm *MemoryManager) MProtect(addr hostarch.Addr, length uint64, realPerms hostarch.AccessType, growsDown bool) error {
if addr.RoundDown() != addr {
return syserror.EINVAL
}
if length == 0 {
return nil
}
- rlength, ok := usermem.Addr(length).RoundUp()
+ rlength, ok := hostarch.Addr(length).RoundUp()
if !ok {
return syserror.ENOMEM
}
@@ -692,19 +692,19 @@ func (mm *MemoryManager) MProtect(addr usermem.Addr, length uint64, realPerms us
}
// BrkSetup sets mm's brk address to addr and its brk size to 0.
-func (mm *MemoryManager) BrkSetup(ctx context.Context, addr usermem.Addr) {
+func (mm *MemoryManager) BrkSetup(ctx context.Context, addr hostarch.Addr) {
mm.mappingMu.Lock()
defer mm.mappingMu.Unlock()
// Unmap the existing brk.
if mm.brk.Length() != 0 {
mm.unmapLocked(ctx, mm.brk)
}
- mm.brk = usermem.AddrRange{addr, addr}
+ mm.brk = hostarch.AddrRange{addr, addr}
}
// Brk implements the semantics of Linux's brk(2), except that it returns an
// error on failure.
-func (mm *MemoryManager) Brk(ctx context.Context, addr usermem.Addr) (usermem.Addr, error) {
+func (mm *MemoryManager) Brk(ctx context.Context, addr hostarch.Addr) (hostarch.Addr, error) {
mm.mappingMu.Lock()
// Can't defer mm.mappingMu.Unlock(); see below.
@@ -741,8 +741,8 @@ func (mm *MemoryManager) Brk(ctx context.Context, addr usermem.Addr) (usermem.Ad
Fixed: true,
// Compare Linux's
// arch/x86/include/asm/page_types.h:VM_DATA_DEFAULT_FLAGS.
- Perms: usermem.ReadWrite,
- MaxPerms: usermem.AnyAccess,
+ Perms: hostarch.ReadWrite,
+ MaxPerms: hostarch.AnyAccess,
Private: true,
// Linux: mm/mmap.c:sys_brk() => do_brk_flags() includes
// mm->def_flags.
@@ -762,7 +762,7 @@ func (mm *MemoryManager) Brk(ctx context.Context, addr usermem.Addr) (usermem.Ad
}
case newbrkpg < oldbrkpg:
- mm.unmapLocked(ctx, usermem.AddrRange{newbrkpg, oldbrkpg})
+ mm.unmapLocked(ctx, hostarch.AddrRange{newbrkpg, oldbrkpg})
fallthrough
default:
@@ -775,9 +775,9 @@ func (mm *MemoryManager) Brk(ctx context.Context, addr usermem.Addr) (usermem.Ad
// MLock implements the semantics of Linux's mlock()/mlock2()/munlock(),
// depending on mode.
-func (mm *MemoryManager) MLock(ctx context.Context, addr usermem.Addr, length uint64, mode memmap.MLockMode) error {
+func (mm *MemoryManager) MLock(ctx context.Context, addr hostarch.Addr, length uint64, mode memmap.MLockMode) error {
// Linux allows this to overflow.
- la, _ := usermem.Addr(length + addr.PageOffset()).RoundUp()
+ la, _ := hostarch.Addr(length + addr.PageOffset()).RoundUp()
ar, ok := addr.RoundDown().ToRange(uint64(la))
if !ok {
return syserror.EINVAL
@@ -850,7 +850,7 @@ func (mm *MemoryManager) MLock(ctx context.Context, addr usermem.Addr, length ui
mm.mappingMu.RUnlock()
return syserror.ENOMEM
}
- _, _, err := mm.getPMAsLocked(ctx, vseg, vseg.Range().Intersect(ar), usermem.NoAccess)
+ _, _, err := mm.getPMAsLocked(ctx, vseg, vseg.Range().Intersect(ar), hostarch.NoAccess)
if err != nil {
mm.activeMu.Unlock()
mm.mappingMu.RUnlock()
@@ -945,7 +945,7 @@ func (mm *MemoryManager) MLockAll(ctx context.Context, opts MLockAllOpts) error
mm.mappingMu.DowngradeLock()
for vseg := mm.vmas.FirstSegment(); vseg.Ok(); vseg = vseg.NextSegment() {
if vseg.ValuePtr().effectivePerms.Any() {
- mm.getPMAsLocked(ctx, vseg, vseg.Range(), usermem.NoAccess)
+ mm.getPMAsLocked(ctx, vseg, vseg.Range(), hostarch.NoAccess)
}
}
@@ -965,7 +965,7 @@ func (mm *MemoryManager) MLockAll(ctx context.Context, opts MLockAllOpts) error
}
// NumaPolicy implements the semantics of Linux's get_mempolicy(MPOL_F_ADDR).
-func (mm *MemoryManager) NumaPolicy(addr usermem.Addr) (linux.NumaPolicy, uint64, error) {
+func (mm *MemoryManager) NumaPolicy(addr hostarch.Addr) (linux.NumaPolicy, uint64, error) {
mm.mappingMu.RLock()
defer mm.mappingMu.RUnlock()
vseg := mm.vmas.FindSegment(addr)
@@ -977,12 +977,12 @@ func (mm *MemoryManager) NumaPolicy(addr usermem.Addr) (linux.NumaPolicy, uint64
}
// SetNumaPolicy implements the semantics of Linux's mbind().
-func (mm *MemoryManager) SetNumaPolicy(addr usermem.Addr, length uint64, policy linux.NumaPolicy, nodemask uint64) error {
+func (mm *MemoryManager) SetNumaPolicy(addr hostarch.Addr, length uint64, policy linux.NumaPolicy, nodemask uint64) error {
if !addr.IsPageAligned() {
return syserror.EINVAL
}
// Linux allows this to overflow.
- la, _ := usermem.Addr(length).RoundUp()
+ la, _ := hostarch.Addr(length).RoundUp()
ar, ok := addr.ToRange(uint64(la))
if !ok {
return syserror.EINVAL
@@ -1018,7 +1018,7 @@ func (mm *MemoryManager) SetNumaPolicy(addr usermem.Addr, length uint64, policy
}
// SetDontFork implements the semantics of madvise MADV_DONTFORK.
-func (mm *MemoryManager) SetDontFork(addr usermem.Addr, length uint64, dontfork bool) error {
+func (mm *MemoryManager) SetDontFork(addr hostarch.Addr, length uint64, dontfork bool) error {
ar, ok := addr.ToRange(length)
if !ok {
return syserror.EINVAL
@@ -1044,7 +1044,7 @@ func (mm *MemoryManager) SetDontFork(addr usermem.Addr, length uint64, dontfork
}
// Decommit implements the semantics of Linux's madvise(MADV_DONTNEED).
-func (mm *MemoryManager) Decommit(addr usermem.Addr, length uint64) error {
+func (mm *MemoryManager) Decommit(addr hostarch.Addr, length uint64) error {
ar, ok := addr.ToRange(length)
if !ok {
return syserror.EINVAL
@@ -1112,14 +1112,14 @@ type MSyncOpts struct {
}
// MSync implements the semantics of Linux's msync().
-func (mm *MemoryManager) MSync(ctx context.Context, addr usermem.Addr, length uint64, opts MSyncOpts) error {
+func (mm *MemoryManager) MSync(ctx context.Context, addr hostarch.Addr, length uint64, opts MSyncOpts) error {
if addr != addr.RoundDown() {
return syserror.EINVAL
}
if length == 0 {
return nil
}
- la, ok := usermem.Addr(length).RoundUp()
+ la, ok := hostarch.Addr(length).RoundUp()
if !ok {
return syserror.ENOMEM
}
@@ -1188,7 +1188,7 @@ func (mm *MemoryManager) MSync(ctx context.Context, addr usermem.Addr, length ui
}
// GetSharedFutexKey is used by kernel.Task.GetSharedKey.
-func (mm *MemoryManager) GetSharedFutexKey(ctx context.Context, addr usermem.Addr) (futex.Key, error) {
+func (mm *MemoryManager) GetSharedFutexKey(ctx context.Context, addr hostarch.Addr) (futex.Key, error) {
ar, ok := addr.ToRange(4) // sizeof(int32).
if !ok {
return futex.Key{}, syserror.EFAULT
@@ -1196,7 +1196,7 @@ func (mm *MemoryManager) GetSharedFutexKey(ctx context.Context, addr usermem.Add
mm.mappingMu.RLock()
defer mm.mappingMu.RUnlock()
- vseg, _, err := mm.getVMAsLocked(ctx, ar, usermem.Read, false)
+ vseg, _, err := mm.getVMAsLocked(ctx, ar, hostarch.Read, false)
if err != nil {
return futex.Key{}, err
}
@@ -1230,7 +1230,7 @@ func (mm *MemoryManager) VirtualMemorySize() uint64 {
// VirtualMemorySizeRange returns the combined length in bytes of all mappings
// in ar in mm.
-func (mm *MemoryManager) VirtualMemorySizeRange(ar usermem.AddrRange) uint64 {
+func (mm *MemoryManager) VirtualMemorySizeRange(ar hostarch.AddrRange) uint64 {
mm.mappingMu.RLock()
defer mm.mappingMu.RUnlock()
return uint64(mm.vmas.SpanRange(ar))
diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go
index b8df72813..0d019e41d 100644
--- a/pkg/sentry/mm/vma.go
+++ b/pkg/sentry/mm/vma.go
@@ -19,18 +19,18 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// Preconditions:
// * mm.mappingMu must be locked for writing.
// * opts must be valid as defined by the checks in MMap.
-func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOpts) (vmaIterator, usermem.AddrRange, error) {
+func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOpts) (vmaIterator, hostarch.AddrRange, error) {
if opts.MaxPerms != opts.MaxPerms.Effective() {
panic(fmt.Sprintf("Non-effective MaxPerms %s cannot be enforced", opts.MaxPerms))
}
@@ -47,7 +47,7 @@ func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOp
if opts.Force && opts.Unmap && opts.Fixed {
addr = opts.Addr
} else {
- return vmaIterator{}, usermem.AddrRange{}, err
+ return vmaIterator{}, hostarch.AddrRange{}, err
}
}
ar, _ := addr.ToRange(opts.Length)
@@ -58,7 +58,7 @@ func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOp
newUsageAS -= uint64(mm.vmas.SpanRange(ar))
}
if limitAS := limits.FromContext(ctx).Get(limits.AS).Cur; newUsageAS > limitAS {
- return vmaIterator{}, usermem.AddrRange{}, syserror.ENOMEM
+ return vmaIterator{}, hostarch.AddrRange{}, syserror.ENOMEM
}
if opts.MLockMode != memmap.MLockNone {
@@ -66,14 +66,14 @@ func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOp
if creds := auth.CredentialsFromContext(ctx); !creds.HasCapabilityIn(linux.CAP_IPC_LOCK, creds.UserNamespace.Root()) {
mlockLimit := limits.FromContext(ctx).Get(limits.MemoryLocked).Cur
if mlockLimit == 0 {
- return vmaIterator{}, usermem.AddrRange{}, syserror.EPERM
+ return vmaIterator{}, hostarch.AddrRange{}, syserror.EPERM
}
newLockedAS := mm.lockedAS + opts.Length
if opts.Unmap {
newLockedAS -= mm.mlockedBytesRangeLocked(ar)
}
if newLockedAS > mlockLimit {
- return vmaIterator{}, usermem.AddrRange{}, syserror.EAGAIN
+ return vmaIterator{}, hostarch.AddrRange{}, syserror.EAGAIN
}
}
}
@@ -93,7 +93,7 @@ func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOp
// The expression for writable is vma.canWriteMappableLocked(), but we
// don't yet have a vma.
if err := opts.Mappable.AddMapping(ctx, mm, ar, opts.Offset, !opts.Private && opts.MaxPerms.Write); err != nil {
- return vmaIterator{}, usermem.AddrRange{}, err
+ return vmaIterator{}, hostarch.AddrRange{}, err
}
}
@@ -137,7 +137,7 @@ type findAvailableOpts struct {
//
// - Unmap allows existing guard pages in the returned range.
- Addr usermem.Addr
+ Addr hostarch.Addr
Fixed bool
Unmap bool
Map32Bit bool
@@ -153,13 +153,13 @@ const (
// findAvailableLocked finds an allocatable range.
//
// Preconditions: mm.mappingMu must be locked.
-func (mm *MemoryManager) findAvailableLocked(length uint64, opts findAvailableOpts) (usermem.Addr, error) {
+func (mm *MemoryManager) findAvailableLocked(length uint64, opts findAvailableOpts) (hostarch.Addr, error) {
if opts.Fixed {
opts.Map32Bit = false
}
allowedAR := mm.applicationAddrRange()
if opts.Map32Bit {
- allowedAR = allowedAR.Intersect(usermem.AddrRange{map32Start, map32End})
+ allowedAR = allowedAR.Intersect(hostarch.AddrRange{map32Start, map32End})
}
// Does the provided suggestion work?
@@ -181,33 +181,33 @@ func (mm *MemoryManager) findAvailableLocked(length uint64, opts findAvailableOp
}
// Prefer hugepage alignment if a hugepage or more is requested.
- alignment := uint64(usermem.PageSize)
- if length >= usermem.HugePageSize {
- alignment = usermem.HugePageSize
+ alignment := uint64(hostarch.PageSize)
+ if length >= hostarch.HugePageSize {
+ alignment = hostarch.HugePageSize
}
if opts.Map32Bit {
return mm.findLowestAvailableLocked(length, alignment, allowedAR)
}
if mm.layout.DefaultDirection == arch.MmapBottomUp {
- return mm.findLowestAvailableLocked(length, alignment, usermem.AddrRange{mm.layout.BottomUpBase, mm.layout.MaxAddr})
+ return mm.findLowestAvailableLocked(length, alignment, hostarch.AddrRange{mm.layout.BottomUpBase, mm.layout.MaxAddr})
}
- return mm.findHighestAvailableLocked(length, alignment, usermem.AddrRange{mm.layout.MinAddr, mm.layout.TopDownBase})
+ return mm.findHighestAvailableLocked(length, alignment, hostarch.AddrRange{mm.layout.MinAddr, mm.layout.TopDownBase})
}
-func (mm *MemoryManager) applicationAddrRange() usermem.AddrRange {
- return usermem.AddrRange{mm.layout.MinAddr, mm.layout.MaxAddr}
+func (mm *MemoryManager) applicationAddrRange() hostarch.AddrRange {
+ return hostarch.AddrRange{mm.layout.MinAddr, mm.layout.MaxAddr}
}
// Preconditions: mm.mappingMu must be locked.
-func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bounds usermem.AddrRange) (usermem.Addr, error) {
- for gap := mm.vmas.LowerBoundGap(bounds.Start); gap.Ok() && gap.Start() < bounds.End; gap = gap.NextLargeEnoughGap(usermem.Addr(length)) {
+func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bounds hostarch.AddrRange) (hostarch.Addr, error) {
+ for gap := mm.vmas.LowerBoundGap(bounds.Start); gap.Ok() && gap.Start() < bounds.End; gap = gap.NextLargeEnoughGap(hostarch.Addr(length)) {
if gr := gap.availableRange().Intersect(bounds); uint64(gr.Length()) >= length {
// Can we shift up to match the alignment?
if offset := uint64(gr.Start) % alignment; offset != 0 {
if uint64(gr.Length()) >= length+alignment-offset {
// Yes, we're aligned.
- return gr.Start + usermem.Addr(alignment-offset), nil
+ return gr.Start + hostarch.Addr(alignment-offset), nil
}
}
@@ -219,15 +219,15 @@ func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bou
}
// Preconditions: mm.mappingMu must be locked.
-func (mm *MemoryManager) findHighestAvailableLocked(length, alignment uint64, bounds usermem.AddrRange) (usermem.Addr, error) {
- for gap := mm.vmas.UpperBoundGap(bounds.End); gap.Ok() && gap.End() > bounds.Start; gap = gap.PrevLargeEnoughGap(usermem.Addr(length)) {
+func (mm *MemoryManager) findHighestAvailableLocked(length, alignment uint64, bounds hostarch.AddrRange) (hostarch.Addr, error) {
+ for gap := mm.vmas.UpperBoundGap(bounds.End); gap.Ok() && gap.End() > bounds.Start; gap = gap.PrevLargeEnoughGap(hostarch.Addr(length)) {
if gr := gap.availableRange().Intersect(bounds); uint64(gr.Length()) >= length {
// Can we shift down to match the alignment?
- start := gr.End - usermem.Addr(length)
+ start := gr.End - hostarch.Addr(length)
if offset := uint64(start) % alignment; offset != 0 {
- if gr.Start <= start-usermem.Addr(offset) {
+ if gr.Start <= start-hostarch.Addr(offset) {
// Yes, we're aligned.
- return start - usermem.Addr(offset), nil
+ return start - hostarch.Addr(offset), nil
}
}
@@ -239,7 +239,7 @@ func (mm *MemoryManager) findHighestAvailableLocked(length, alignment uint64, bo
}
// Preconditions: mm.mappingMu must be locked.
-func (mm *MemoryManager) mlockedBytesRangeLocked(ar usermem.AddrRange) uint64 {
+func (mm *MemoryManager) mlockedBytesRangeLocked(ar hostarch.AddrRange) uint64 {
var total uint64
for vseg := mm.vmas.LowerBoundSegment(ar.Start); vseg.Ok() && vseg.Start() < ar.End; vseg = vseg.NextSegment() {
if vseg.ValuePtr().mlockMode != memmap.MLockNone {
@@ -264,7 +264,7 @@ func (mm *MemoryManager) mlockedBytesRangeLocked(ar usermem.AddrRange) uint64 {
// Preconditions:
// * mm.mappingMu must be locked for reading; it may be temporarily unlocked.
// * ar.Length() != 0.
-func (mm *MemoryManager) getVMAsLocked(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool) (vmaIterator, vmaGapIterator, error) {
+func (mm *MemoryManager) getVMAsLocked(ctx context.Context, ar hostarch.AddrRange, at hostarch.AccessType, ignorePermissions bool) (vmaIterator, vmaGapIterator, error) {
if checkInvariants {
if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
@@ -320,7 +320,7 @@ func (mm *MemoryManager) getVMAsLocked(ctx context.Context, ar usermem.AddrRange
// temporarily unlocked.
//
// Postconditions: ars is not mutated.
-func (mm *MemoryManager) getVecVMAsLocked(ctx context.Context, ars usermem.AddrRangeSeq, at usermem.AccessType, ignorePermissions bool) (usermem.AddrRangeSeq, error) {
+func (mm *MemoryManager) getVecVMAsLocked(ctx context.Context, ars hostarch.AddrRangeSeq, at hostarch.AccessType, ignorePermissions bool) (hostarch.AddrRangeSeq, error) {
for arsit := ars; !arsit.IsEmpty(); arsit = arsit.Tail() {
ar := arsit.Head()
if ar.Length() == 0 {
@@ -339,7 +339,7 @@ func (mm *MemoryManager) getVecVMAsLocked(ctx context.Context, ars usermem.AddrR
//
// guardBytes is equivalent to Linux's stack_guard_gap after upstream
// 1be7107fbe18 "mm: larger stack guard gap, between vmas".
-const guardBytes = 256 * usermem.PageSize
+const guardBytes = 256 * hostarch.PageSize
// unmapLocked unmaps all addresses in ar and returns the resulting gap in
// mm.vmas.
@@ -348,7 +348,7 @@ const guardBytes = 256 * usermem.PageSize
// * mm.mappingMu must be locked for writing.
// * ar.Length() != 0.
// * ar must be page-aligned.
-func (mm *MemoryManager) unmapLocked(ctx context.Context, ar usermem.AddrRange) vmaGapIterator {
+func (mm *MemoryManager) unmapLocked(ctx context.Context, ar hostarch.AddrRange) vmaGapIterator {
if checkInvariants {
if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
@@ -369,7 +369,7 @@ func (mm *MemoryManager) unmapLocked(ctx context.Context, ar usermem.AddrRange)
// * mm.mappingMu must be locked for writing.
// * ar.Length() != 0.
// * ar must be page-aligned.
-func (mm *MemoryManager) removeVMAsLocked(ctx context.Context, ar usermem.AddrRange) vmaGapIterator {
+func (mm *MemoryManager) removeVMAsLocked(ctx context.Context, ar hostarch.AddrRange) vmaGapIterator {
if checkInvariants {
if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
@@ -426,12 +426,12 @@ func (vma *vma) isPrivateDataLocked() bool {
// vmaSetFunctions implements segment.Functions for vmaSet.
type vmaSetFunctions struct{}
-func (vmaSetFunctions) MinKey() usermem.Addr {
+func (vmaSetFunctions) MinKey() hostarch.Addr {
return 0
}
-func (vmaSetFunctions) MaxKey() usermem.Addr {
- return ^usermem.Addr(0)
+func (vmaSetFunctions) MaxKey() hostarch.Addr {
+ return ^hostarch.Addr(0)
}
func (vmaSetFunctions) ClearValue(vma *vma) {
@@ -440,7 +440,7 @@ func (vmaSetFunctions) ClearValue(vma *vma) {
vma.hint = ""
}
-func (vmaSetFunctions) Merge(ar1 usermem.AddrRange, vma1 vma, ar2 usermem.AddrRange, vma2 vma) (vma, bool) {
+func (vmaSetFunctions) Merge(ar1 hostarch.AddrRange, vma1 vma, ar2 hostarch.AddrRange, vma2 vma) (vma, bool) {
if vma1.mappable != vma2.mappable ||
(vma1.mappable != nil && vma1.off+uint64(ar1.Length()) != vma2.off) ||
vma1.realPerms != vma2.realPerms ||
@@ -462,7 +462,7 @@ func (vmaSetFunctions) Merge(ar1 usermem.AddrRange, vma1 vma, ar2 usermem.AddrRa
return vma1, true
}
-func (vmaSetFunctions) Split(ar usermem.AddrRange, v vma, split usermem.Addr) (vma, vma) {
+func (vmaSetFunctions) Split(ar hostarch.AddrRange, v vma, split hostarch.Addr) (vma, vma) {
v2 := v
if v2.mappable != nil {
v2.off += uint64(split - ar.Start)
@@ -476,7 +476,7 @@ func (vmaSetFunctions) Split(ar usermem.AddrRange, v vma, split usermem.Addr) (v
// Preconditions:
// * vseg.ValuePtr().mappable != nil.
// * vseg.Range().Contains(addr).
-func (vseg vmaIterator) mappableOffsetAt(addr usermem.Addr) uint64 {
+func (vseg vmaIterator) mappableOffsetAt(addr hostarch.Addr) uint64 {
if checkInvariants {
if !vseg.Ok() {
panic("terminal vma iterator")
@@ -503,7 +503,7 @@ func (vseg vmaIterator) mappableRange() memmap.MappableRange {
// * vseg.ValuePtr().mappable != nil.
// * vseg.Range().IsSupersetOf(ar).
// * ar.Length() != 0.
-func (vseg vmaIterator) mappableRangeOf(ar usermem.AddrRange) memmap.MappableRange {
+func (vseg vmaIterator) mappableRangeOf(ar hostarch.AddrRange) memmap.MappableRange {
if checkInvariants {
if !vseg.Ok() {
panic("terminal vma iterator")
@@ -528,7 +528,7 @@ func (vseg vmaIterator) mappableRangeOf(ar usermem.AddrRange) memmap.MappableRan
// * vseg.ValuePtr().mappable != nil.
// * vseg.mappableRange().IsSupersetOf(mr).
// * mr.Length() != 0.
-func (vseg vmaIterator) addrRangeOf(mr memmap.MappableRange) usermem.AddrRange {
+func (vseg vmaIterator) addrRangeOf(mr memmap.MappableRange) hostarch.AddrRange {
if checkInvariants {
if !vseg.Ok() {
panic("terminal vma iterator")
@@ -546,7 +546,7 @@ func (vseg vmaIterator) addrRangeOf(mr memmap.MappableRange) usermem.AddrRange {
vma := vseg.ValuePtr()
vstart := vseg.Start()
- return usermem.AddrRange{vstart + usermem.Addr(mr.Start-vma.off), vstart + usermem.Addr(mr.End-vma.off)}
+ return hostarch.AddrRange{vstart + hostarch.Addr(mr.Start-vma.off), vstart + hostarch.Addr(mr.End-vma.off)}
}
// seekNextLowerBound returns mm.vmas.LowerBoundSegment(addr), but does so by
@@ -555,7 +555,7 @@ func (vseg vmaIterator) addrRangeOf(mr memmap.MappableRange) usermem.AddrRange {
// Preconditions:
// * mm.mappingMu must be locked.
// * addr >= vseg.Start().
-func (vseg vmaIterator) seekNextLowerBound(addr usermem.Addr) vmaIterator {
+func (vseg vmaIterator) seekNextLowerBound(addr hostarch.Addr) vmaIterator {
if checkInvariants {
if !vseg.Ok() {
panic("terminal vma iterator")
@@ -572,7 +572,7 @@ func (vseg vmaIterator) seekNextLowerBound(addr usermem.Addr) vmaIterator {
// availableRange returns the subset of vgap.Range() in which new vmas may be
// created without MMapOpts.Unmap == true.
-func (vgap vmaGapIterator) availableRange() usermem.AddrRange {
+func (vgap vmaGapIterator) availableRange() hostarch.AddrRange {
ar := vgap.Range()
next := vgap.NextSegment()
if !next.Ok() || !next.ValuePtr().growsDown {
@@ -580,7 +580,7 @@ func (vgap vmaGapIterator) availableRange() usermem.AddrRange {
}
// Exclude guard pages.
if ar.Length() < guardBytes {
- return usermem.AddrRange{ar.Start, ar.Start}
+ return hostarch.AddrRange{ar.Start, ar.Start}
}
ar.End -= guardBytes
return ar
diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD
index e5bf13c40..57d73d770 100644
--- a/pkg/sentry/pgalloc/BUILD
+++ b/pkg/sentry/pgalloc/BUILD
@@ -85,6 +85,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/memutil",
"//pkg/safemem",
@@ -106,5 +107,5 @@ go_test(
size = "small",
srcs = ["pgalloc_test.go"],
library = ":pgalloc",
- deps = ["//pkg/usermem"],
+ deps = ["//pkg/hostarch"],
)
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
index a4af3e21b..b81292c46 100644
--- a/pkg/sentry/pgalloc/pgalloc.go
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -31,6 +31,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/hostmm"
@@ -38,7 +39,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// MemoryFile is a memmap.File whose pages may be allocated to arbitrary
@@ -283,7 +283,7 @@ const (
chunkMask = chunkSize - 1
// maxPage is the highest 64-bit page.
- maxPage = math.MaxUint64 &^ (usermem.PageSize - 1)
+ maxPage = math.MaxUint64 &^ (hostarch.PageSize - 1)
)
// NewMemoryFile creates a MemoryFile backed by the given file. If
@@ -344,7 +344,7 @@ func NewMemoryFile(file *os.File, opts MemoryFileOpts) (*MemoryFile, error) {
m, _, errno := unix.Syscall6(
unix.SYS_MMAP,
0,
- usermem.PageSize,
+ hostarch.PageSize,
unix.PROT_EXEC,
unix.MAP_SHARED,
file.Fd(),
@@ -357,7 +357,7 @@ func NewMemoryFile(file *os.File, opts MemoryFileOpts) (*MemoryFile, error) {
if _, _, errno := unix.Syscall(
unix.SYS_MUNMAP,
m,
- usermem.PageSize,
+ hostarch.PageSize,
0); errno != 0 {
panic(fmt.Sprintf("failed to unmap PROT_EXEC MemoryFile mapping: %v", errno))
}
@@ -386,7 +386,7 @@ func (f *MemoryFile) Destroy() {
//
// Preconditions: length must be page-aligned and non-zero.
func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.FileRange, error) {
- if length == 0 || length%usermem.PageSize != 0 {
+ if length == 0 || length%hostarch.PageSize != 0 {
panic(fmt.Sprintf("invalid allocation length: %#x", length))
}
@@ -395,9 +395,9 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.File
// Align hugepage-and-larger allocations on hugepage boundaries to try
// to take advantage of hugetmpfs.
- alignment := uint64(usermem.PageSize)
- if length >= usermem.HugePageSize {
- alignment = usermem.HugePageSize
+ alignment := uint64(hostarch.PageSize)
+ if length >= hostarch.HugePageSize {
+ alignment = hostarch.HugePageSize
}
// Find a range in the underlying file.
@@ -524,13 +524,13 @@ func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r saf
if err != nil {
return memmap.FileRange{}, err
}
- dsts, err := f.MapInternal(fr, usermem.Write)
+ dsts, err := f.MapInternal(fr, hostarch.Write)
if err != nil {
f.DecRef(fr)
return memmap.FileRange{}, err
}
n, err := safemem.ReadFullToBlocks(r, dsts)
- un := uint64(usermem.Addr(n).RoundDown())
+ un := uint64(hostarch.Addr(n).RoundDown())
if un < length {
// Free unused memory and update fr to contain only the memory that is
// still allocated.
@@ -552,7 +552,7 @@ const (
//
// Preconditions: fr.Length() > 0.
func (f *MemoryFile) Decommit(fr memmap.FileRange) error {
- if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 {
+ if !fr.WellFormed() || fr.Length() == 0 || fr.Start%hostarch.PageSize != 0 || fr.End%hostarch.PageSize != 0 {
panic(fmt.Sprintf("invalid range: %v", fr))
}
@@ -614,7 +614,7 @@ func (f *MemoryFile) markDecommitted(fr memmap.FileRange) {
// IncRef implements memmap.File.IncRef.
func (f *MemoryFile) IncRef(fr memmap.FileRange) {
- if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 {
+ if !fr.WellFormed() || fr.Length() == 0 || fr.Start%hostarch.PageSize != 0 || fr.End%hostarch.PageSize != 0 {
panic(fmt.Sprintf("invalid range: %v", fr))
}
@@ -633,7 +633,7 @@ func (f *MemoryFile) IncRef(fr memmap.FileRange) {
// DecRef implements memmap.File.DecRef.
func (f *MemoryFile) DecRef(fr memmap.FileRange) {
- if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 {
+ if !fr.WellFormed() || fr.Length() == 0 || fr.Start%hostarch.PageSize != 0 || fr.End%hostarch.PageSize != 0 {
panic(fmt.Sprintf("invalid range: %v", fr))
}
@@ -669,7 +669,7 @@ func (f *MemoryFile) DecRef(fr memmap.FileRange) {
}
// MapInternal implements memmap.File.MapInternal.
-func (f *MemoryFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+func (f *MemoryFile) MapInternal(fr memmap.FileRange, at hostarch.AccessType) (safemem.BlockSeq, error) {
if !fr.WellFormed() || fr.Length() == 0 {
panic(fmt.Sprintf("invalid range: %v", fr))
}
@@ -935,7 +935,7 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func(
// Ensure that we have sufficient buffer for the call
// (one byte per page). The length of each slice must
// be page-aligned.
- bufLen := len(s) / usermem.PageSize
+ bufLen := len(s) / hostarch.PageSize
if len(buf) < bufLen {
buf = make([]byte, bufLen)
}
@@ -967,8 +967,8 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func(
}
}
committedFR := memmap.FileRange{
- Start: r.Start + uint64(i*usermem.PageSize),
- End: r.Start + uint64(j*usermem.PageSize),
+ Start: r.Start + uint64(i*hostarch.PageSize),
+ End: r.Start + uint64(j*hostarch.PageSize),
}
// Advance seg to committedFR.Start.
for seg.Ok() && seg.End() < committedFR.Start {
diff --git a/pkg/sentry/pgalloc/pgalloc_test.go b/pkg/sentry/pgalloc/pgalloc_test.go
index 405db141f..8d2b7eb5e 100644
--- a/pkg/sentry/pgalloc/pgalloc_test.go
+++ b/pkg/sentry/pgalloc/pgalloc_test.go
@@ -17,12 +17,12 @@ package pgalloc
import (
"testing"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
const (
- page = usermem.PageSize
- hugepage = usermem.HugePageSize
+ page = hostarch.PageSize
+ hugepage = hostarch.HugePageSize
topPage = (1 << 63) - page
)
diff --git a/pkg/sentry/pgalloc/save_restore.go b/pkg/sentry/pgalloc/save_restore.go
index e05c8d074..345cdde55 100644
--- a/pkg/sentry/pgalloc/save_restore.go
+++ b/pkg/sentry/pgalloc/save_restore.go
@@ -23,11 +23,11 @@ import (
"sync/atomic"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/state"
"gvisor.dev/gvisor/pkg/state/wire"
- "gvisor.dev/gvisor/pkg/usermem"
)
// SaveTo writes f's state to the given stream.
@@ -49,11 +49,11 @@ func (f *MemoryFile) SaveTo(ctx context.Context, w wire.Writer) error {
// Ensure that all pages that contain data have knownCommitted set, since
// we only store knownCommitted pages below.
- zeroPage := make([]byte, usermem.PageSize)
+ zeroPage := make([]byte, hostarch.PageSize)
err := f.updateUsageLocked(0, func(bs []byte, committed []byte) error {
- for pgoff := 0; pgoff < len(bs); pgoff += usermem.PageSize {
- i := pgoff / usermem.PageSize
- pg := bs[pgoff : pgoff+usermem.PageSize]
+ for pgoff := 0; pgoff < len(bs); pgoff += hostarch.PageSize {
+ i := pgoff / hostarch.PageSize
+ pg := bs[pgoff : pgoff+hostarch.PageSize]
if !bytes.Equal(pg, zeroPage) {
committed[i] = 1
continue
diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD
index db7d55ef2..7125657b3 100644
--- a/pkg/sentry/platform/BUILD
+++ b/pkg/sentry/platform/BUILD
@@ -13,6 +13,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/seccomp",
"//pkg/sentry/arch",
"//pkg/sentry/hostmm",
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 03a76eb9b..b307832fd 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -43,6 +43,7 @@ go_library(
"//pkg/atomicbitops",
"//pkg/context",
"//pkg/cpuid",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/procid",
"//pkg/ring0",
@@ -56,7 +57,6 @@ go_library(
"//pkg/sentry/platform/interrupt",
"//pkg/sentry/time",
"//pkg/sync",
- "//pkg/usermem",
"@org_golang_x_sys//unix:go_default_library",
],
)
@@ -65,6 +65,7 @@ go_test(
name = "kvm_test",
srcs = [
"kvm_amd64_test.go",
+ "kvm_amd64_test.s",
"kvm_arm64_test.go",
"kvm_test.go",
"virtual_map_test.go",
@@ -76,6 +77,7 @@ go_test(
"requires-kvm",
],
deps = [
+ "//pkg/hostarch",
"//pkg/ring0",
"//pkg/ring0/pagetables",
"//pkg/sentry/arch",
@@ -83,7 +85,6 @@ go_test(
"//pkg/sentry/platform",
"//pkg/sentry/platform/kvm/testutil",
"//pkg/sentry/time",
- "//pkg/usermem",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go
index 25c21e843..5524e8727 100644
--- a/pkg/sentry/platform/kvm/address_space.go
+++ b/pkg/sentry/platform/kvm/address_space.go
@@ -18,11 +18,11 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/atomicbitops"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/usermem"
)
// dirtySet tracks vCPUs for invalidation.
@@ -118,7 +118,7 @@ type hostMapEntry struct {
// +checkescape:hard,stack
//
//go:nosplit
-func (as *addressSpace) mapLocked(addr usermem.Addr, m hostMapEntry, at usermem.AccessType) (inv bool) {
+func (as *addressSpace) mapLocked(addr hostarch.Addr, m hostMapEntry, at hostarch.AccessType) (inv bool) {
for m.length > 0 {
physical, length, ok := translateToPhysical(m.addr)
if !ok {
@@ -144,14 +144,14 @@ func (as *addressSpace) mapLocked(addr usermem.Addr, m hostMapEntry, at usermem.
}, physical) || inv
m.addr += length
m.length -= length
- addr += usermem.Addr(length)
+ addr += hostarch.Addr(length)
}
return inv
}
// MapFile implements platform.AddressSpace.MapFile.
-func (as *addressSpace) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error {
+func (as *addressSpace) MapFile(addr hostarch.Addr, f memmap.File, fr memmap.FileRange, at hostarch.AccessType, precommit bool) error {
as.mu.Lock()
defer as.mu.Unlock()
@@ -165,7 +165,7 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f memmap.File, fr memmap.File
// We don't execute from application file-mapped memory, and guest page
// tables don't care if we have execute permission (but they do need pages
// to be readable).
- bs, err := f.MapInternal(fr, usermem.AccessType{
+ bs, err := f.MapInternal(fr, hostarch.AccessType{
Read: at.Read || at.Execute || precommit,
Write: at.Write,
})
@@ -187,7 +187,7 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f memmap.File, fr memmap.File
// lookup in our host page tables for this translation.
if precommit {
s := b.ToSlice()
- for i := 0; i < len(s); i += usermem.PageSize {
+ for i := 0; i < len(s); i += hostarch.PageSize {
_ = s[i] // Touch to commit.
}
}
@@ -201,7 +201,7 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f memmap.File, fr memmap.File
length: uintptr(b.Len()),
}, at)
inv = inv || prev
- addr += usermem.Addr(b.Len())
+ addr += hostarch.Addr(b.Len())
}
if inv {
as.invalidate()
@@ -215,12 +215,12 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f memmap.File, fr memmap.File
// +checkescape:hard,stack
//
//go:nosplit
-func (as *addressSpace) unmapLocked(addr usermem.Addr, length uint64) bool {
+func (as *addressSpace) unmapLocked(addr hostarch.Addr, length uint64) bool {
return as.pageTables.Unmap(addr, uintptr(length))
}
// Unmap unmaps the given range by calling pagetables.PageTables.Unmap.
-func (as *addressSpace) Unmap(addr usermem.Addr, length uint64) {
+func (as *addressSpace) Unmap(addr hostarch.Addr, length uint64) {
as.mu.Lock()
defer as.mu.Unlock()
diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go
index fd1131638..bb9967b9f 100644
--- a/pkg/sentry/platform/kvm/bluepill.go
+++ b/pkg/sentry/platform/kvm/bluepill.go
@@ -16,7 +16,6 @@ package kvm
import (
"fmt"
- "reflect"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/ring0"
@@ -36,6 +35,14 @@ func sighandler()
// dieArchSetup and the assembly implementation for dieTrampoline.
func dieTrampoline()
+// Return the start address of the functions above.
+//
+// In Go 1.17+, Go references to assembly functions resolve to an ABIInternal
+// wrapper function rather than the function itself. We must reference from
+// assembly to get the ABI0 (i.e., primary) address.
+func addrOfSighandler() uintptr
+func addrOfDieTrampoline() uintptr
+
var (
// bounceSignal is the signal used for bouncing KVM.
//
@@ -87,10 +94,10 @@ func (c *vCPU) die(context *arch.SignalContext64, msg string) {
func init() {
// Install the handler.
- if err := safecopy.ReplaceSignalHandler(bluepillSignal, reflect.ValueOf(sighandler).Pointer(), &savedHandler); err != nil {
+ if err := safecopy.ReplaceSignalHandler(bluepillSignal, addrOfSighandler(), &savedHandler); err != nil {
panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err))
}
// Extract the address for the trampoline.
- dieTrampolineAddr = reflect.ValueOf(dieTrampoline).Pointer()
+ dieTrampolineAddr = addrOfDieTrampoline()
}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s
index 025ea93b5..953024600 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64.s
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.s
@@ -81,8 +81,20 @@ fallback:
MOVQ ·savedHandler(SB), AX
JMP AX
+// func addrOfSighandler() uintptr
+TEXT ·addrOfSighandler(SB), $0-8
+ MOVQ $·sighandler(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
+
// dieTrampoline: see bluepill.go, bluepill_amd64_unsafe.go for documentation.
TEXT ·dieTrampoline(SB),NOSPLIT,$0
PUSHQ BX // First argument (vCPU).
PUSHQ AX // Fake the old RIP as caller.
JMP ·dieHandler(SB)
+
+// func addrOfDieTrampoline() uintptr
+TEXT ·addrOfDieTrampoline(SB), $0-8
+ MOVQ $·dieTrampoline(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s
index 09c7e88e5..308f2a951 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.s
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.s
@@ -92,6 +92,12 @@ fallback:
MOVD ·savedHandler(SB), R7
B (R7)
+// func addrOfSighandler() uintptr
+TEXT ·addrOfSighandler(SB), $0-8
+ MOVD $·sighandler(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
// dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation.
TEXT ·dieTrampoline(SB),NOSPLIT,$0
// R0: Fake the old PC as caller
@@ -99,3 +105,9 @@ TEXT ·dieTrampoline(SB),NOSPLIT,$0
MOVD.P R1, 8(RSP) // R1: First argument (vCPU)
MOVD.P R0, 8(RSP) // R0: Fake the old PC as caller
B ·dieHandler(SB)
+
+// func addrOfDieTrampoline() uintptr
+TEXT ·addrOfDieTrampoline(SB), $0-8
+ MOVD $·dieTrampoline(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go
index 37c53fa02..28a613a54 100644
--- a/pkg/sentry/platform/kvm/bluepill_fault.go
+++ b/pkg/sentry/platform/kvm/bluepill_fault.go
@@ -18,7 +18,7 @@ import (
"sync/atomic"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
const (
@@ -47,7 +47,7 @@ func yield() {
//
//go:nosplit
func calculateBluepillFault(physical uintptr, phyRegions []physicalRegion) (virtualStart, physicalStart, length uintptr, ok bool) {
- alignedPhysical := physical &^ uintptr(usermem.PageSize-1)
+ alignedPhysical := physical &^ uintptr(hostarch.PageSize-1)
for _, pr := range phyRegions {
end := pr.physical + pr.length
if physical < pr.physical || physical >= end {
diff --git a/pkg/sentry/platform/kvm/context.go b/pkg/sentry/platform/kvm/context.go
index 706fa53dc..f4d4473a8 100644
--- a/pkg/sentry/platform/kvm/context.go
+++ b/pkg/sentry/platform/kvm/context.go
@@ -18,11 +18,11 @@ import (
"sync/atomic"
pkgcontext "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/interrupt"
- "gvisor.dev/gvisor/pkg/usermem"
)
// context is an implementation of the platform context.
@@ -40,7 +40,7 @@ type context struct {
}
// Switch runs the provided context in the given address space.
-func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, _ int32) (*arch.SignalInfo, usermem.AccessType, error) {
+func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, _ int32) (*arch.SignalInfo, hostarch.AccessType, error) {
as := mm.AddressSpace()
localAS := as.(*addressSpace)
@@ -50,7 +50,7 @@ func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac a
// Enable interrupts (i.e. calls to vCPU.Notify).
if !c.interrupt.Enable(cpu) {
c.machine.Put(cpu) // Already preempted.
- return nil, usermem.NoAccess, platform.ErrContextInterrupt
+ return nil, hostarch.NoAccess, platform.ErrContextInterrupt
}
// Set the active address space.
diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go
index 92c05a9ad..aac0fdffe 100644
--- a/pkg/sentry/platform/kvm/kvm.go
+++ b/pkg/sentry/platform/kvm/kvm.go
@@ -20,11 +20,11 @@ import (
"os"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/usermem"
)
// userMemoryRegion is a region of physical memory.
@@ -146,13 +146,13 @@ func (*KVM) MapUnit() uint64 {
}
// MinUserAddress returns the lowest available address.
-func (*KVM) MinUserAddress() usermem.Addr {
- return usermem.PageSize
+func (*KVM) MinUserAddress() hostarch.Addr {
+ return hostarch.PageSize
}
// MaxUserAddress returns the first address that may not be used.
-func (*KVM) MaxUserAddress() usermem.Addr {
- return usermem.Addr(ring0.MaximumUserAddress)
+func (*KVM) MaxUserAddress() hostarch.Addr {
+ return hostarch.Addr(ring0.MaximumUserAddress)
}
// NewAddressSpace returns a new pagetable root.
diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go
index e44e995a0..b8dd1e4a5 100644
--- a/pkg/sentry/platform/kvm/kvm_amd64_test.go
+++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go
@@ -49,3 +49,40 @@ func TestSegments(t *testing.T) {
return false
})
}
+
+// stmxcsr reads the MXCSR control and status register.
+func stmxcsr(addr *uint32)
+
+func TestMXCSR(t *testing.T) {
+ applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+ var si arch.SignalInfo
+ switchOpts := ring0.SwitchOpts{
+ Registers: regs,
+ FloatingPointState: &dummyFPState,
+ PageTables: pt,
+ FullRestore: true,
+ }
+
+ const mxcsrControllMask = uint32(0x1f80)
+ mxcsrBefore := uint32(0)
+ mxcsrAfter := uint32(0)
+ stmxcsr(&mxcsrBefore)
+ if mxcsrBefore == 0 {
+ // goruntime sets mxcsr to 0x1f80 and it never changes
+ // the control configuration.
+ panic("mxcsr is zero")
+ }
+ switchOpts.FloatingPointState.SetMXCSR(0)
+ if _, err := c.SwitchToUser(
+ switchOpts, &si); err == platform.ErrContextInterrupt {
+ return true // Retry.
+ } else if err != nil {
+ t.Errorf("application syscall failed: %v", err)
+ }
+ stmxcsr(&mxcsrAfter)
+ if mxcsrAfter&mxcsrControllMask != mxcsrBefore&mxcsrControllMask {
+ t.Errorf("mxcsr = %x (expected %x)", mxcsrBefore, mxcsrAfter)
+ }
+ return false
+ })
+}
diff --git a/pkg/tcpip/transport/tcp/cubic_state.go b/pkg/sentry/platform/kvm/kvm_amd64_test.s
index d0f58cfaf..8e9079867 100644
--- a/pkg/tcpip/transport/tcp/cubic_state.go
+++ b/pkg/sentry/platform/kvm/kvm_amd64_test.s
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,18 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package tcp
+#include "textflag.h"
-import (
- "time"
-)
-
-// saveT is invoked by stateify.
-func (c *cubicState) saveT() unixTime {
- return unixTime{c.t.Unix(), c.t.UnixNano()}
-}
-
-// loadT is invoked by stateify.
-func (c *cubicState) loadT(unix unixTime) {
- c.t = time.Unix(unix.second, unix.nano)
-}
+// stmxcsr reads the MXCSR control and status register.
+TEXT ·stmxcsr(SB),NOSPLIT,$0-8
+ MOVQ addr+0(FP), SI
+ STMXCSR (SI)
+ RET
diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go
index 5bce16dde..ceff09a60 100644
--- a/pkg/sentry/platform/kvm/kvm_test.go
+++ b/pkg/sentry/platform/kvm/kvm_test.go
@@ -22,6 +22,7 @@ import (
"time"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -29,7 +30,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil"
ktime "gvisor.dev/gvisor/pkg/sentry/time"
- "gvisor.dev/gvisor/pkg/usermem"
)
var dummyFPState = fpu.NewState()
@@ -142,8 +142,8 @@ func applicationTest(t testHarness, useHostMappings bool, target func(), fn func
// done for regular user code, but is fine for test
// purposes.)
applyPhysicalRegions(func(pr physicalRegion) bool {
- pt.Map(usermem.Addr(pr.virtual), pr.length, pagetables.MapOpts{
- AccessType: usermem.AnyAccess,
+ pt.Map(hostarch.Addr(pr.virtual), pr.length, pagetables.MapOpts{
+ AccessType: hostarch.AnyAccess,
User: true,
}, pr.physical)
return true // Keep iterating.
@@ -351,7 +351,7 @@ func TestInvalidate(t *testing.T) {
break // Done.
}
// Unmap the page containing data & invalidate.
- pt.Unmap(usermem.Addr(reflect.ValueOf(&data).Pointer() & ^uintptr(usermem.PageSize-1)), usermem.PageSize)
+ pt.Unmap(hostarch.Addr(reflect.ValueOf(&data).Pointer() & ^uintptr(hostarch.PageSize-1)), hostarch.PageSize)
for {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index cc0fb9892..99f036bba 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -21,13 +21,13 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/atomicbitops"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/procid"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
ktime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/usermem"
)
// machine contains state associated with the VM as a whole.
@@ -235,9 +235,9 @@ func newMachine(vm int) (*machine, error) {
applyPhysicalRegions(func(pr physicalRegion) bool {
// Map everything in the lower half.
m.kernel.PageTables.Map(
- usermem.Addr(pr.virtual),
+ hostarch.Addr(pr.virtual),
pr.length,
- pagetables.MapOpts{AccessType: usermem.AnyAccess},
+ pagetables.MapOpts{AccessType: hostarch.AnyAccess},
pr.physical)
return true // Keep iterating.
@@ -444,7 +444,7 @@ func (m *machine) Get() *vCPU {
}
// The vCPU is not be able to transition to
- // vCPUGuest|vCPUUser or to vCPUUser because that
+ // vCPUGuest|vCPUWaiter or to vCPUUser because that
// transition requires holding the machine mutex, as we
// do now. There is no path to register a waiter on
// just the vCPUReady state.
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index 31f56e3c2..d7abfefb4 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -24,13 +24,13 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
ktime "gvisor.dev/gvisor/pkg/sentry/time"
- "gvisor.dev/gvisor/pkg/usermem"
)
// initArchState initializes architecture-specific state.
@@ -41,7 +41,7 @@ func (m *machine) initArchState() error {
unix.SYS_IOCTL,
uintptr(m.fd),
_KVM_SET_TSS_ADDR,
- uintptr(reservedMemory-(3*usermem.PageSize))); errno != 0 {
+ uintptr(reservedMemory-(3*hostarch.PageSize))); errno != 0 {
return errno
}
@@ -261,19 +261,19 @@ func (c *vCPU) setSystemTime() error {
// nonCanonical generates a canonical address return.
//
//go:nosplit
-func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (usermem.AccessType, error) {
+func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) {
*info = arch.SignalInfo{
Signo: signal,
Code: arch.SignalInfoKernel,
}
info.SetAddr(addr) // Include address.
- return usermem.NoAccess, platform.ErrContextSignal
+ return hostarch.NoAccess, platform.ErrContextSignal
}
// fault generates an appropriate fault return.
//
//go:nosplit
-func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (usermem.AccessType, error) {
+func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) {
bluepill(c) // Probably no-op, but may not be.
faultAddr := ring0.ReadCR2()
code, user := c.ErrorCode()
@@ -281,12 +281,12 @@ func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (usermem.AccessType, e
// The last fault serviced by this CPU was not a user
// fault, so we can't reliably trust the faultAddr or
// the code provided here. We need to re-execute.
- return usermem.NoAccess, platform.ErrContextInterrupt
+ return hostarch.NoAccess, platform.ErrContextInterrupt
}
// Reset the pointed SignalInfo.
*info = arch.SignalInfo{Signo: signal}
info.SetAddr(uint64(faultAddr))
- accessType := usermem.AccessType{
+ accessType := hostarch.AccessType{
Read: code&(1<<1) == 0,
Write: code&(1<<1) != 0,
Execute: code&(1<<4) != 0,
@@ -315,14 +315,14 @@ func loadByte(ptr *byte) byte {
//go:nosplit
func prefaultFloatingPointState(data *fpu.State) {
size := len(*data)
- for i := 0; i < size; i += usermem.PageSize {
+ for i := 0; i < size; i += hostarch.PageSize {
loadByte(&(*data)[i])
}
loadByte(&(*data)[size-1])
}
// SwitchToUser unpacks architectural-details.
-func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (usermem.AccessType, error) {
+func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (hostarch.AccessType, error) {
// Check for canonical addresses.
if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Rip) {
return nonCanonical(regs.Rip, int32(unix.SIGSEGV), info)
@@ -358,7 +358,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
switch vector {
case ring0.Syscall, ring0.SyscallInt80:
// Fast path: system call executed.
- return usermem.NoAccess, nil
+ return hostarch.NoAccess, nil
case ring0.PageFault:
return c.fault(int32(unix.SIGSEGV), info)
@@ -369,7 +369,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
Code: 1, // TRAP_BRKPT (breakpoint).
}
info.SetAddr(switchOpts.Registers.Rip) // Include address.
- return usermem.AccessType{}, platform.ErrContextSignal
+ return hostarch.AccessType{}, platform.ErrContextSignal
case ring0.GeneralProtectionFault,
ring0.SegmentNotPresent,
@@ -385,9 +385,9 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
// When CPUID faulting is enabled, we will generate a #GP(0) when
// userspace executes a CPUID instruction. This is handled above,
// because we need to be able to map and read user memory.
- return usermem.AccessType{}, platform.ErrContextSignalCPUID
+ return hostarch.AccessType{}, platform.ErrContextSignalCPUID
}
- return usermem.AccessType{}, platform.ErrContextSignal
+ return hostarch.AccessType{}, platform.ErrContextSignal
case ring0.InvalidOpcode:
*info = arch.SignalInfo{
@@ -395,7 +395,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
Code: 1, // ILL_ILLOPC (illegal opcode).
}
info.SetAddr(switchOpts.Registers.Rip) // Include address.
- return usermem.AccessType{}, platform.ErrContextSignal
+ return hostarch.AccessType{}, platform.ErrContextSignal
case ring0.DivideByZero:
*info = arch.SignalInfo{
@@ -403,7 +403,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
Code: 1, // FPE_INTDIV (divide by zero).
}
info.SetAddr(switchOpts.Registers.Rip) // Include address.
- return usermem.AccessType{}, platform.ErrContextSignal
+ return hostarch.AccessType{}, platform.ErrContextSignal
case ring0.Overflow:
*info = arch.SignalInfo{
@@ -411,7 +411,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
Code: 2, // FPE_INTOVF (integer overflow).
}
info.SetAddr(switchOpts.Registers.Rip) // Include address.
- return usermem.AccessType{}, platform.ErrContextSignal
+ return hostarch.AccessType{}, platform.ErrContextSignal
case ring0.X87FloatingPointException,
ring0.SIMDFloatingPointException:
@@ -420,17 +420,17 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
Code: 7, // FPE_FLTINV (invalid operation).
}
info.SetAddr(switchOpts.Registers.Rip) // Include address.
- return usermem.AccessType{}, platform.ErrContextSignal
+ return hostarch.AccessType{}, platform.ErrContextSignal
case ring0.Vector(bounce): // ring0.VirtualizationException
- return usermem.NoAccess, platform.ErrContextInterrupt
+ return hostarch.NoAccess, platform.ErrContextInterrupt
case ring0.AlignmentCheck:
*info = arch.SignalInfo{
Signo: int32(unix.SIGBUS),
Code: 2, // BUS_ADRERR (physical address does not exist).
}
- return usermem.NoAccess, platform.ErrContextSignal
+ return hostarch.NoAccess, platform.ErrContextSignal
case ring0.NMI:
// An NMI is generated only when a fault is not servicable by
@@ -476,9 +476,9 @@ func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
panic("impossible translation")
}
pageTable.Map(
- usermem.Addr(ring0.KernelStartAddress|r.virtual),
+ hostarch.Addr(ring0.KernelStartAddress|r.virtual),
r.length,
- pagetables.MapOpts{AccessType: usermem.Execute},
+ pagetables.MapOpts{AccessType: hostarch.Execute},
physical)
}
})
@@ -489,9 +489,9 @@ func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
panic("impossible translation")
}
pageTable.Map(
- usermem.Addr(ring0.KernelStartAddress|start),
+ hostarch.Addr(ring0.KernelStartAddress|start),
regionLen,
- pagetables.MapOpts{AccessType: usermem.ReadWrite},
+ pagetables.MapOpts{AccessType: hostarch.ReadWrite},
physical)
}
}
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 2edc9d1b2..cd912f922 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -17,12 +17,12 @@
package kvm
import (
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/usermem"
)
type vCPUArchState struct {
@@ -47,15 +47,15 @@ const (
// Beyond a relatively small number, there are likely few perform
// benefits, since the TLB has likely long since lost any translations
// from more than a few PCIDs past.
- poolPCIDs = 8
+ poolPCIDs = 128
)
func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
applyPhysicalRegions(func(pr physicalRegion) bool {
pageTable.Map(
- usermem.Addr(ring0.KernelStartAddress|pr.virtual),
+ hostarch.Addr(ring0.KernelStartAddress|pr.virtual),
pr.length,
- pagetables.MapOpts{AccessType: usermem.AnyAccess, Global: true},
+ pagetables.MapOpts{AccessType: hostarch.AnyAccess, Global: true},
pr.physical)
return true // Keep iterating.
@@ -117,13 +117,13 @@ func availableRegionsForSetMem() (phyRegions []physicalRegion) {
// nonCanonical generates a canonical address return.
//
//go:nosplit
-func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (usermem.AccessType, error) {
+func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) {
*info = arch.SignalInfo{
Signo: signal,
Code: arch.SignalInfoKernel,
}
info.SetAddr(addr) // Include address.
- return usermem.NoAccess, platform.ErrContextSignal
+ return hostarch.NoAccess, platform.ErrContextSignal
}
// isInstructionAbort returns true if it is an instruction abort.
@@ -148,7 +148,7 @@ func isWriteFault(code uint64) bool {
// fault generates an appropriate fault return.
//
//go:nosplit
-func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (usermem.AccessType, error) {
+func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType, error) {
bluepill(c) // Probably no-op, but may not be.
faultAddr := c.GetFaultAddr()
code, user := c.ErrorCode()
@@ -157,7 +157,7 @@ func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (usermem.AccessType, e
// The last fault serviced by this CPU was not a user
// fault, so we can't reliably trust the faultAddr or
// the code provided here. We need to re-execute.
- return usermem.NoAccess, platform.ErrContextInterrupt
+ return hostarch.NoAccess, platform.ErrContextInterrupt
}
// Reset the pointed SignalInfo.
@@ -174,7 +174,7 @@ func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (usermem.AccessType, e
info.Code = 2
}
- accessType := usermem.AccessType{
+ accessType := hostarch.AccessType{
Read: !isWriteFault(uint64(code)),
Write: isWriteFault(uint64(code)),
Execute: isInstructionAbort(uint64(code)),
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index e7d5f3193..634e55ec0 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -23,12 +23,12 @@ import (
"unsafe"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/usermem"
)
type kvmVcpuInit struct {
@@ -209,7 +209,7 @@ func (c *vCPU) getOneRegister(reg *kvmOneReg) error {
}
// SwitchToUser unpacks architectural-details.
-func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (usermem.AccessType, error) {
+func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (hostarch.AccessType, error) {
// Check for canonical addresses.
if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Pc) {
return nonCanonical(regs.Pc, int32(unix.SIGSEGV), info)
@@ -246,13 +246,13 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
switch vector {
case ring0.Syscall:
// Fast path: system call executed.
- return usermem.NoAccess, nil
+ return hostarch.NoAccess, nil
case ring0.PageFault:
return c.fault(int32(unix.SIGSEGV), info)
case ring0.El0ErrNMI:
return c.fault(int32(unix.SIGBUS), info)
case ring0.Vector(bounce): // ring0.VirtualizationException.
- return usermem.NoAccess, platform.ErrContextInterrupt
+ return hostarch.NoAccess, platform.ErrContextInterrupt
case ring0.El0SyncUndef:
return c.fault(int32(unix.SIGILL), info)
case ring0.El0SyncDbg:
@@ -261,16 +261,16 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
Code: 1, // TRAP_BRKPT (breakpoint).
}
info.SetAddr(switchOpts.Registers.Pc) // Include address.
- return usermem.AccessType{}, platform.ErrContextSignal
+ return hostarch.AccessType{}, platform.ErrContextSignal
case ring0.El0SyncSpPc:
*info = arch.SignalInfo{
Signo: int32(unix.SIGBUS),
Code: 2, // BUS_ADRERR (physical address does not exist).
}
- return usermem.NoAccess, platform.ErrContextSignal
+ return hostarch.NoAccess, platform.ErrContextSignal
case ring0.El0SyncSys,
ring0.El0SyncWfx:
- return usermem.NoAccess, nil // skip for now.
+ return hostarch.NoAccess, nil // skip for now.
default:
panic(fmt.Sprintf("unexpected vector: 0x%x", vector))
}
diff --git a/pkg/sentry/platform/kvm/physical_map.go b/pkg/sentry/platform/kvm/physical_map.go
index 7376d8b8d..d812e6c26 100644
--- a/pkg/sentry/platform/kvm/physical_map.go
+++ b/pkg/sentry/platform/kvm/physical_map.go
@@ -19,9 +19,9 @@ import (
"sort"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/ring0"
- "gvisor.dev/gvisor/pkg/usermem"
)
type region struct {
@@ -81,7 +81,7 @@ func fillAddressSpace() (excludedRegions []region) {
// faultBlockSize, potentially causing up to faultBlockSize bytes in
// internal fragmentation for each physical region. So we need to
// account for this properly during allocation.
- requiredAddr, ok := usermem.Addr(vSize - pSize + faultBlockSize).RoundUp()
+ requiredAddr, ok := hostarch.Addr(vSize - pSize + faultBlockSize).RoundUp()
if !ok {
panic(fmt.Sprintf(
"overflow for vSize (%x) - pSize (%x) + faultBlockSize (%x)",
@@ -99,7 +99,7 @@ func fillAddressSpace() (excludedRegions []region) {
0, 0)
if errno != 0 {
// Attempt half the size; overflow not possible.
- currentAddr, _ := usermem.Addr(current >> 1).RoundUp()
+ currentAddr, _ := hostarch.Addr(current >> 1).RoundUp()
current = uintptr(currentAddr)
continue
}
@@ -134,8 +134,8 @@ func computePhysicalRegions(excludedRegions []region) (physicalRegions []physica
return
}
if virtual == 0 {
- virtual += usermem.PageSize
- length -= usermem.PageSize
+ virtual += hostarch.PageSize
+ length -= hostarch.PageSize
}
if end := virtual + length; end > ring0.MaximumUserAddress {
length -= (end - ring0.MaximumUserAddress)
diff --git a/pkg/sentry/platform/kvm/virtual_map.go b/pkg/sentry/platform/kvm/virtual_map.go
index 4dcdbf8a7..01d9eb39d 100644
--- a/pkg/sentry/platform/kvm/virtual_map.go
+++ b/pkg/sentry/platform/kvm/virtual_map.go
@@ -22,12 +22,12 @@ import (
"regexp"
"strconv"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
type virtualRegion struct {
region
- accessType usermem.AccessType
+ accessType hostarch.AccessType
shared bool
offset uintptr
filename string
@@ -92,7 +92,7 @@ func applyVirtualRegions(fn func(vr virtualRegion)) error {
virtual: uintptr(start),
length: uintptr(end - start),
},
- accessType: usermem.AccessType{
+ accessType: hostarch.AccessType{
Read: read,
Write: write,
Execute: execute,
diff --git a/pkg/sentry/platform/kvm/virtual_map_test.go b/pkg/sentry/platform/kvm/virtual_map_test.go
index 9b4545fdd..1f4a774f3 100644
--- a/pkg/sentry/platform/kvm/virtual_map_test.go
+++ b/pkg/sentry/platform/kvm/virtual_map_test.go
@@ -18,12 +18,12 @@ import (
"testing"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
type checker struct {
ok bool
- accessType usermem.AccessType
+ accessType hostarch.AccessType
}
func (c *checker) Containing(addr uintptr) func(virtualRegion) {
@@ -46,7 +46,7 @@ func TestParseMaps(t *testing.T) {
// MMap a new page.
addr, _, errno := unix.RawSyscall6(
- unix.SYS_MMAP, 0, usermem.PageSize,
+ unix.SYS_MMAP, 0, hostarch.PageSize,
unix.PROT_READ|unix.PROT_WRITE,
unix.MAP_ANONYMOUS|unix.MAP_PRIVATE, 0, 0)
if errno != 0 {
@@ -55,19 +55,19 @@ func TestParseMaps(t *testing.T) {
// Re-parse maps.
if err := applyVirtualRegions(c.Containing(addr)); err != nil {
- unix.RawSyscall(unix.SYS_MUNMAP, addr, usermem.PageSize, 0)
+ unix.RawSyscall(unix.SYS_MUNMAP, addr, hostarch.PageSize, 0)
t.Fatalf("unexpected error: %v", err)
}
// Assert that it now does contain the region.
if !c.ok {
- unix.RawSyscall(unix.SYS_MUNMAP, addr, usermem.PageSize, 0)
+ unix.RawSyscall(unix.SYS_MUNMAP, addr, hostarch.PageSize, 0)
t.Fatalf("updated map does not contain 0x%08x, expected true", addr)
}
// Map the region as PROT_NONE.
newAddr, _, errno := unix.RawSyscall6(
- unix.SYS_MMAP, addr, usermem.PageSize,
+ unix.SYS_MMAP, addr, hostarch.PageSize,
unix.PROT_NONE,
unix.MAP_ANONYMOUS|unix.MAP_FIXED|unix.MAP_PRIVATE, 0, 0)
if errno != 0 {
@@ -89,5 +89,5 @@ func TestParseMaps(t *testing.T) {
}
// Unmap the region.
- unix.RawSyscall(unix.SYS_MUNMAP, addr, usermem.PageSize, 0)
+ unix.RawSyscall(unix.SYS_MUNMAP, addr, hostarch.PageSize, 0)
}
diff --git a/pkg/sentry/platform/mmap_min_addr.go b/pkg/sentry/platform/mmap_min_addr.go
index 091c2e365..7335bd802 100644
--- a/pkg/sentry/platform/mmap_min_addr.go
+++ b/pkg/sentry/platform/mmap_min_addr.go
@@ -20,7 +20,7 @@ import (
"strconv"
"strings"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// systemMMapMinAddrSource is the source file.
@@ -30,8 +30,8 @@ const systemMMapMinAddrSource = "/proc/sys/vm/mmap_min_addr"
var systemMMapMinAddr uint64
// SystemMMapMinAddr returns the minimum system address.
-func SystemMMapMinAddr() usermem.Addr {
- return usermem.Addr(systemMMapMinAddr)
+func SystemMMapMinAddr() hostarch.Addr {
+ return hostarch.Addr(systemMMapMinAddr)
}
// MMapMinAddr is a size zero struct that implements MinUserAddress based on
@@ -41,7 +41,7 @@ type MMapMinAddr struct {
}
// MinUserAddress implements platform.MinUserAddresss.
-func (*MMapMinAddr) MinUserAddress() usermem.Addr {
+func (*MMapMinAddr) MinUserAddress() hostarch.Addr {
return SystemMMapMinAddr()
}
diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go
index dcfe839a7..ef7814a6f 100644
--- a/pkg/sentry/platform/platform.go
+++ b/pkg/sentry/platform/platform.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/seccomp"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/hostmm"
@@ -62,16 +63,16 @@ type Platform interface {
// for AddressSpace.MapFile. As a special case, a MapUnit of 0 indicates
// that the cost of AddressSpace.MapFile is effectively independent of the
// number of pages mapped. If MapUnit is non-zero, it must be a power-of-2
- // multiple of usermem.PageSize.
+ // multiple of hostarch.PageSize.
MapUnit() uint64
// MinUserAddress returns the minimum mappable address on this
// platform.
- MinUserAddress() usermem.Addr
+ MinUserAddress() hostarch.Addr
// MaxUserAddress returns the maximum mappable address on this
// platform.
- MaxUserAddress() usermem.Addr
+ MaxUserAddress() hostarch.Addr
// NewAddressSpace returns a new memory context for this platform.
//
@@ -172,7 +173,7 @@ type MemoryManager interface {
//usermem.IO provides access to the contents of a virtual memory space.
usermem.IO
// MMap establishes a memory mapping.
- MMap(ctx context.Context, opts memmap.MMapOpts) (usermem.Addr, error)
+ MMap(ctx context.Context, opts memmap.MMapOpts) (hostarch.Addr, error)
// AddressSpace returns the AddressSpace bound to mm.
AddressSpace() AddressSpace
}
@@ -195,7 +196,7 @@ type Context interface {
//
// - ErrContextSignal: The Context was interrupted by a signal. The
// returned *arch.SignalInfo contains information about the signal. If
- // arch.SignalInfo.Signo == SIGSEGV, the returned usermem.AccessType
+ // arch.SignalInfo.Signo == SIGSEGV, the returned hostarch.AccessType
// contains the access type of the triggering fault. The caller owns
// the returned SignalInfo.
//
@@ -206,7 +207,7 @@ type Context interface {
// concurrent call to Switch().
//
// - ErrContextCPUPreempted: See the definition of that error for details.
- Switch(ctx context.Context, mm MemoryManager, ac arch.Context, cpu int32) (*arch.SignalInfo, usermem.AccessType, error)
+ Switch(ctx context.Context, mm MemoryManager, ac arch.Context, cpu int32) (*arch.SignalInfo, hostarch.AccessType, error)
// PullFullState() pulls a full state of the application thread.
//
@@ -302,14 +303,14 @@ type AddressSpace interface {
// * at.Any() == true.
// * At least one reference must be held on all pages in fr, and must
// continue to be held as long as pages are mapped.
- MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error
+ MapFile(addr hostarch.Addr, f memmap.File, fr memmap.FileRange, at hostarch.AccessType, precommit bool) error
// Unmap unmaps the given range.
//
// Preconditions:
// * addr is page-aligned.
// * length > 0.
- Unmap(addr usermem.Addr, length uint64)
+ Unmap(addr hostarch.Addr, length uint64)
// Release releases this address space. After releasing, a new AddressSpace
// must be acquired via platform.NewAddressSpace().
@@ -337,67 +338,67 @@ type AddressSpaceIO interface {
// CopyOut copies len(src) bytes from src to the memory mapped at addr. It
// returns the number of bytes copied. If the number of bytes copied is <
// len(src), it returns a non-nil error explaining why.
- CopyOut(addr usermem.Addr, src []byte) (int, error)
+ CopyOut(addr hostarch.Addr, src []byte) (int, error)
// CopyIn copies len(dst) bytes from the memory mapped at addr to dst.
// It returns the number of bytes copied. If the number of bytes copied is
// < len(dst), it returns a non-nil error explaining why.
- CopyIn(addr usermem.Addr, dst []byte) (int, error)
+ CopyIn(addr hostarch.Addr, dst []byte) (int, error)
// ZeroOut sets toZero bytes to 0, starting at addr. It returns the number
// of bytes zeroed. If the number of bytes zeroed is < toZero, it returns a
// non-nil error explaining why.
- ZeroOut(addr usermem.Addr, toZero uintptr) (uintptr, error)
+ ZeroOut(addr hostarch.Addr, toZero uintptr) (uintptr, error)
// SwapUint32 atomically sets the uint32 value at addr to new and returns
// the previous value.
//
// Preconditions: addr must be aligned to a 4-byte boundary.
- SwapUint32(addr usermem.Addr, new uint32) (uint32, error)
+ SwapUint32(addr hostarch.Addr, new uint32) (uint32, error)
// CompareAndSwapUint32 atomically compares the uint32 value at addr to
// old; if they are equal, the value in memory is replaced by new. In
// either case, the previous value stored in memory is returned.
//
// Preconditions: addr must be aligned to a 4-byte boundary.
- CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error)
+ CompareAndSwapUint32(addr hostarch.Addr, old, new uint32) (uint32, error)
// LoadUint32 atomically loads the uint32 value at addr and returns it.
//
// Preconditions: addr must be aligned to a 4-byte boundary.
- LoadUint32(addr usermem.Addr) (uint32, error)
+ LoadUint32(addr hostarch.Addr) (uint32, error)
}
// NoAddressSpaceIO implements AddressSpaceIO methods by panicking.
type NoAddressSpaceIO struct{}
// CopyOut implements AddressSpaceIO.CopyOut.
-func (NoAddressSpaceIO) CopyOut(addr usermem.Addr, src []byte) (int, error) {
+func (NoAddressSpaceIO) CopyOut(addr hostarch.Addr, src []byte) (int, error) {
panic("This platform does not support AddressSpaceIO")
}
// CopyIn implements AddressSpaceIO.CopyIn.
-func (NoAddressSpaceIO) CopyIn(addr usermem.Addr, dst []byte) (int, error) {
+func (NoAddressSpaceIO) CopyIn(addr hostarch.Addr, dst []byte) (int, error) {
panic("This platform does not support AddressSpaceIO")
}
// ZeroOut implements AddressSpaceIO.ZeroOut.
-func (NoAddressSpaceIO) ZeroOut(addr usermem.Addr, toZero uintptr) (uintptr, error) {
+func (NoAddressSpaceIO) ZeroOut(addr hostarch.Addr, toZero uintptr) (uintptr, error) {
panic("This platform does not support AddressSpaceIO")
}
// SwapUint32 implements AddressSpaceIO.SwapUint32.
-func (NoAddressSpaceIO) SwapUint32(addr usermem.Addr, new uint32) (uint32, error) {
+func (NoAddressSpaceIO) SwapUint32(addr hostarch.Addr, new uint32) (uint32, error) {
panic("This platform does not support AddressSpaceIO")
}
// CompareAndSwapUint32 implements AddressSpaceIO.CompareAndSwapUint32.
-func (NoAddressSpaceIO) CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error) {
+func (NoAddressSpaceIO) CompareAndSwapUint32(addr hostarch.Addr, old, new uint32) (uint32, error) {
panic("This platform does not support AddressSpaceIO")
}
// LoadUint32 implements AddressSpaceIO.LoadUint32.
-func (NoAddressSpaceIO) LoadUint32(addr usermem.Addr) (uint32, error) {
+func (NoAddressSpaceIO) LoadUint32(addr hostarch.Addr) (uint32, error) {
panic("This platform does not support AddressSpaceIO")
}
@@ -406,7 +407,7 @@ func (NoAddressSpaceIO) LoadUint32(addr usermem.Addr) (uint32, error) {
// permissions.
type SegmentationFault struct {
// Addr is the address at which the fault occurred.
- Addr usermem.Addr
+ Addr hostarch.Addr
}
// Error implements error.Error.
diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD
index 47efde6a2..d101f2f53 100644
--- a/pkg/sentry/platform/ptrace/BUILD
+++ b/pkg/sentry/platform/ptrace/BUILD
@@ -25,6 +25,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/procid",
"//pkg/safecopy",
@@ -35,7 +36,6 @@ go_library(
"//pkg/sentry/platform",
"//pkg/sentry/platform/interrupt",
"//pkg/sync",
- "//pkg/usermem",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go
index 571bfcc2e..828458ce2 100644
--- a/pkg/sentry/platform/ptrace/ptrace.go
+++ b/pkg/sentry/platform/ptrace/ptrace.go
@@ -49,11 +49,11 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
pkgcontext "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/interrupt"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/usermem"
)
var (
@@ -88,28 +88,28 @@ type context struct {
// lastFaultAddr is the last faulting address; this is only meaningful if
// lastFaultSP is non-nil.
- lastFaultAddr usermem.Addr
+ lastFaultAddr hostarch.Addr
// lastFaultIP is the address of the last faulting instruction;
// this is also only meaningful if lastFaultSP is non-nil.
- lastFaultIP usermem.Addr
+ lastFaultIP hostarch.Addr
}
// Switch runs the provided context in the given address space.
-func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, cpu int32) (*arch.SignalInfo, usermem.AccessType, error) {
+func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, cpu int32) (*arch.SignalInfo, hostarch.AccessType, error) {
as := mm.AddressSpace()
s := as.(*subprocess)
isSyscall := s.switchToApp(c, ac)
var (
faultSP *subprocess
- faultAddr usermem.Addr
- faultIP usermem.Addr
+ faultAddr hostarch.Addr
+ faultIP hostarch.Addr
)
if !isSyscall && linux.Signal(c.signalInfo.Signo) == linux.SIGSEGV {
faultSP = s
- faultAddr = usermem.Addr(c.signalInfo.Addr())
- faultIP = usermem.Addr(ac.IP())
+ faultAddr = hostarch.Addr(c.signalInfo.Addr())
+ faultIP = hostarch.Addr(ac.IP())
}
// Update the context to reflect the outcome of this context switch.
@@ -140,14 +140,14 @@ func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac a
}
if isSyscall {
- return nil, usermem.NoAccess, nil
+ return nil, hostarch.NoAccess, nil
}
si := c.signalInfo
if faultSP == nil {
// Non-fault signal.
- return &si, usermem.NoAccess, platform.ErrContextSignal
+ return &si, hostarch.NoAccess, platform.ErrContextSignal
}
// Got a page fault. Ideally, we'd get real fault type here, but ptrace
@@ -157,7 +157,7 @@ func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac a
// pointer.
//
// It was a write fault if the fault is immediately repeated.
- at := usermem.Read
+ at := hostarch.Read
if faultAddr == faultIP {
at.Execute = true
}
@@ -235,8 +235,8 @@ func (*PTrace) MapUnit() uint64 {
// MaxUserAddress returns the first address that may not be used by user
// applications.
-func (*PTrace) MaxUserAddress() usermem.Addr {
- return usermem.Addr(stubStart)
+func (*PTrace) MaxUserAddress() hostarch.Addr {
+ return hostarch.Addr(stubStart)
}
// NewAddressSpace returns a new subprocess.
diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
index 01e73b019..facb96011 100644
--- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go
+++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
@@ -19,9 +19,9 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/arch/fpu"
- "gvisor.dev/gvisor/pkg/usermem"
)
// getRegs gets the general purpose register set.
@@ -122,7 +122,7 @@ func (t *thread) getSignalInfo(si *arch.SignalInfo) error {
//
// Precondition: the OS thread must be locked and own t.
func (t *thread) clone() (*thread, error) {
- r, ok := usermem.Addr(stackPointer(&t.initRegs)).RoundUp()
+ r, ok := hostarch.Addr(stackPointer(&t.initRegs)).RoundUp()
if !ok {
return nil, unix.EINVAL
}
diff --git a/pkg/sentry/platform/ptrace/stub_amd64.s b/pkg/sentry/platform/ptrace/stub_amd64.s
index 16f9c523e..d5c3f901f 100644
--- a/pkg/sentry/platform/ptrace/stub_amd64.s
+++ b/pkg/sentry/platform/ptrace/stub_amd64.s
@@ -109,6 +109,12 @@ parent_dead:
SYSCALL
HLT
+// func addrOfStub() uintptr
+TEXT ·addrOfStub(SB), $0-8
+ MOVQ $·stub(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
+
// stubCall calls the stub function at the given address with the given PPID.
//
// This is a distinct function because stub, above, may be mapped at any
diff --git a/pkg/sentry/platform/ptrace/stub_arm64.s b/pkg/sentry/platform/ptrace/stub_arm64.s
index 6162df02a..4664cd4ad 100644
--- a/pkg/sentry/platform/ptrace/stub_arm64.s
+++ b/pkg/sentry/platform/ptrace/stub_arm64.s
@@ -102,6 +102,12 @@ parent_dead:
SVC
HLT
+// func addrOfStub() uintptr
+TEXT ·addrOfStub(SB), $0-8
+ MOVD $·stub(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
// stubCall calls the stub function at the given address with the given PPID.
//
// This is a distinct function because stub, above, may be mapped at any
diff --git a/pkg/sentry/platform/ptrace/stub_unsafe.go b/pkg/sentry/platform/ptrace/stub_unsafe.go
index 780227248..1fbdea898 100644
--- a/pkg/sentry/platform/ptrace/stub_unsafe.go
+++ b/pkg/sentry/platform/ptrace/stub_unsafe.go
@@ -19,13 +19,20 @@ import (
"unsafe"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safecopy"
- "gvisor.dev/gvisor/pkg/usermem"
)
// stub is defined in arch-specific assembly.
func stub()
+// addrOfStub returns the start address of stub.
+//
+// In Go 1.17+, Go references to assembly functions resolve to an ABIInternal
+// wrapper function rather than the function itself. We must reference from
+// assembly to get the ABI0 (i.e., primary) address.
+func addrOfStub() uintptr
+
// stubCall calls the stub at the given address with the given pid.
func stubCall(addr, pid uintptr)
@@ -41,12 +48,12 @@ func unsafeSlice(addr uintptr, length int) (slice []byte) {
// stubInit initializes the stub.
func stubInit() {
// Grab the existing stub.
- stubBegin := reflect.ValueOf(stub).Pointer()
+ stubBegin := addrOfStub()
stubLen := int(safecopy.FindEndAddress(stubBegin) - stubBegin)
stubSlice := unsafeSlice(stubBegin, stubLen)
mapLen := uintptr(stubLen)
- if offset := mapLen % usermem.PageSize; offset != 0 {
- mapLen += usermem.PageSize - offset
+ if offset := mapLen % hostarch.PageSize; offset != 0 {
+ mapLen += hostarch.PageSize - offset
}
for stubStart > 0 {
@@ -70,7 +77,7 @@ func stubInit() {
}
// Attempt to begin at a lower address.
- stubStart -= uintptr(usermem.PageSize)
+ stubStart -= uintptr(hostarch.PageSize)
continue
}
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index acccbfe2e..9c73a725a 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -20,13 +20,13 @@ import (
"runtime"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/procid"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/usermem"
)
// Linux kernel errnos which "should never be seen by user programs", but will
@@ -69,7 +69,7 @@ type thread struct {
// threadPool is a collection of threads.
type threadPool struct {
// mu protects below.
- mu sync.Mutex
+ mu sync.RWMutex
// threads is the collection of threads.
//
@@ -85,30 +85,42 @@ type threadPool struct {
//
// Precondition: the runtime OS thread must be locked.
func (tp *threadPool) lookupOrCreate(currentTID int32, newThread func() *thread) *thread {
- tp.mu.Lock()
+ // The overwhelming common case is that the thread is already created.
+ // Optimistically attempt the lookup by only locking for reading.
+ tp.mu.RLock()
t, ok := tp.threads[currentTID]
- if !ok {
- // Before creating a new thread, see if we can find a thread
- // whose system tid has disappeared.
- //
- // TODO(b/77216482): Other parts of this package depend on
- // threads never exiting.
- for origTID, t := range tp.threads {
- // Signal zero is an easy existence check.
- if err := unix.Tgkill(unix.Getpid(), int(origTID), 0); err != nil {
- // This thread has been abandoned; reuse it.
- delete(tp.threads, origTID)
- tp.threads[currentTID] = t
- tp.mu.Unlock()
- return t
- }
- }
+ tp.mu.RUnlock()
+ if ok {
+ return t
+ }
- // Create a new thread.
- t = newThread()
- tp.threads[currentTID] = t
+ tp.mu.Lock()
+ defer tp.mu.Unlock()
+
+ // Another goroutine might have created the thread for currentTID in between
+ // mu.RUnlock() and mu.Lock().
+ if t, ok = tp.threads[currentTID]; ok {
+ return t
+ }
+
+ // Before creating a new thread, see if we can find a thread
+ // whose system tid has disappeared.
+ //
+ // TODO(b/77216482): Other parts of this package depend on
+ // threads never exiting.
+ for origTID, t := range tp.threads {
+ // Signal zero is an easy existence check.
+ if err := unix.Tgkill(unix.Getpid(), int(origTID), 0); err != nil {
+ // This thread has been abandoned; reuse it.
+ delete(tp.threads, origTID)
+ tp.threads[currentTID] = t
+ return t
+ }
}
- tp.mu.Unlock()
+
+ // Create a new thread.
+ t = newThread()
+ tp.threads[currentTID] = t
return t
}
@@ -228,7 +240,7 @@ func newSubprocess(create func() (*thread, error)) (*subprocess, error) {
func (s *subprocess) unmap() {
s.Unmap(0, uint64(stubStart))
if maximumUserAddress != stubEnd {
- s.Unmap(usermem.Addr(stubEnd), uint64(maximumUserAddress-stubEnd))
+ s.Unmap(hostarch.Addr(stubEnd), uint64(maximumUserAddress-stubEnd))
}
}
@@ -615,7 +627,7 @@ func (s *subprocess) syscall(sysno uintptr, args ...arch.SyscallArgument) (uintp
}
// MapFile implements platform.AddressSpace.MapFile.
-func (s *subprocess) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error {
+func (s *subprocess) MapFile(addr hostarch.Addr, f memmap.File, fr memmap.FileRange, at hostarch.AccessType, precommit bool) error {
var flags int
if precommit {
flags |= unix.MAP_POPULATE
@@ -632,7 +644,7 @@ func (s *subprocess) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRan
}
// Unmap implements platform.AddressSpace.Unmap.
-func (s *subprocess) Unmap(addr usermem.Addr, length uint64) {
+func (s *subprocess) Unmap(addr hostarch.Addr, length uint64) {
ar, ok := addr.ToRange(length)
if !ok {
panic(fmt.Sprintf("addr %#x + length %#x overflows", addr, length))
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index 0ce42b6cc..080859125 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -10,6 +10,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/binary",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/marshal",
"//pkg/sentry/device",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
index ebcc891b3..0e0e82365 100644
--- a/pkg/sentry/socket/control/BUILD
+++ b/pkg/sentry/socket/control/BUILD
@@ -16,6 +16,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/binary",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/sentry/fs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
@@ -23,7 +24,6 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
"//pkg/syserror",
- "//pkg/usermem",
],
)
@@ -35,8 +35,8 @@ go_test(
deps = [
"//pkg/abi/linux",
"//pkg/binary",
+ "//pkg/hostarch",
"//pkg/sentry/socket",
- "//pkg/usermem",
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 65b556489..45a05cd63 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -20,13 +20,13 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
const maxInt = int(^uint(0) >> 1)
@@ -181,12 +181,12 @@ func (c *scmCredentials) Equals(oc transport.CredentialsControlMessage) bool {
}
func putUint64(buf []byte, n uint64) []byte {
- usermem.ByteOrder.PutUint64(buf[len(buf):len(buf)+8], n)
+ hostarch.ByteOrder.PutUint64(buf[len(buf):len(buf)+8], n)
return buf[:len(buf)+8]
}
func putUint32(buf []byte, n uint32) []byte {
- usermem.ByteOrder.PutUint32(buf[len(buf):len(buf)+4], n)
+ hostarch.ByteOrder.PutUint32(buf[len(buf):len(buf)+4], n)
return buf[:len(buf)+4]
}
@@ -242,7 +242,7 @@ func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interf
hdrBuf := buf
- buf = binary.Marshal(buf, usermem.ByteOrder, data)
+ buf = binary.Marshal(buf, hostarch.ByteOrder, data)
// If the control message data brought us over capacity, omit it.
if cap(buf) != cap(ob) {
@@ -475,7 +475,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
}
var h linux.ControlMessageHeader
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], usermem.ByteOrder, &h)
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], hostarch.ByteOrder, &h)
if h.Length < uint64(linux.SizeOfControlMessageHeader) {
return socket.ControlMessages{}, syserror.EINVAL
@@ -499,7 +499,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
}
for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight {
- fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
+ fds = append(fds, int32(hostarch.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
}
i += binary.AlignUp(length, width)
@@ -510,7 +510,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
}
var creds linux.ControlMessageCredentials
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds)
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], hostarch.ByteOrder, &creds)
scmCreds, err := NewSCMCredentials(t, creds)
if err != nil {
return socket.ControlMessages{}, err
@@ -523,7 +523,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
return socket.ControlMessages{}, syserror.EINVAL
}
var ts linux.Timeval
- binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], usermem.ByteOrder, &ts)
+ binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], hostarch.ByteOrder, &ts)
cmsgs.IP.Timestamp = ts.ToNsecCapped()
cmsgs.IP.HasTimestamp = true
i += binary.AlignUp(length, width)
@@ -539,7 +539,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
return socket.ControlMessages{}, syserror.EINVAL
}
cmsgs.IP.HasTOS = true
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS)
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], hostarch.ByteOrder, &cmsgs.IP.TOS)
i += binary.AlignUp(length, width)
case linux.IP_PKTINFO:
@@ -549,7 +549,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
cmsgs.IP.HasIPPacketInfo = true
var packetInfo linux.ControlMessageIPPacketInfo
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], hostarch.ByteOrder, &packetInfo)
cmsgs.IP.PacketInfo = packetInfo
i += binary.AlignUp(length, width)
@@ -559,7 +559,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
if length < addr.SizeBytes() {
return socket.ControlMessages{}, syserror.EINVAL
}
- binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr)
+ binary.Unmarshal(buf[i:i+addr.SizeBytes()], hostarch.ByteOrder, &addr)
cmsgs.IP.OriginalDstAddress = &addr
i += binary.AlignUp(length, width)
@@ -583,7 +583,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
return socket.ControlMessages{}, syserror.EINVAL
}
cmsgs.IP.HasTClass = true
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass)
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], hostarch.ByteOrder, &cmsgs.IP.TClass)
i += binary.AlignUp(length, width)
case linux.IPV6_RECVORIGDSTADDR:
@@ -591,7 +591,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
if length < addr.SizeBytes() {
return socket.ControlMessages{}, syserror.EINVAL
}
- binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr)
+ binary.Unmarshal(buf[i:i+addr.SizeBytes()], hostarch.ByteOrder, &addr)
cmsgs.IP.OriginalDstAddress = &addr
i += binary.AlignUp(length, width)
diff --git a/pkg/sentry/socket/control/control_test.go b/pkg/sentry/socket/control/control_test.go
index d40a4cc85..7e28a0cef 100644
--- a/pkg/sentry/socket/control/control_test.go
+++ b/pkg/sentry/socket/control/control_test.go
@@ -22,8 +22,8 @@ import (
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/socket"
- "gvisor.dev/gvisor/pkg/usermem"
)
func TestParse(t *testing.T) {
@@ -35,12 +35,12 @@ func TestParse(t *testing.T) {
Type: linux.SO_TIMESTAMP,
}
buf := make([]byte, 0, length)
- buf = binary.Marshal(buf, usermem.ByteOrder, &hdr)
+ buf = binary.Marshal(buf, hostarch.ByteOrder, &hdr)
ts := linux.Timeval{
Sec: 2401,
Usec: 343,
}
- buf = binary.Marshal(buf, usermem.ByteOrder, &ts)
+ buf = binary.Marshal(buf, hostarch.ByteOrder, &ts)
cmsg, err := Parse(nil, nil, buf, 8 /* width */)
if err != nil {
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index a8e6f172b..a5c2155a2 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -20,6 +20,7 @@ go_library(
"//pkg/binary",
"//pkg/context",
"//pkg/fdnotifier",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/marshal",
"//pkg/marshal/primitive",
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 2d9dbbdba..a784e23b5 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
@@ -321,7 +322,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
if outLen < 0 {
return nil, syserr.ErrInvalidArgument
}
@@ -527,24 +528,24 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s
switch unixCmsg.Header.Type {
case linux.SO_TIMESTAMP:
controlMessages.IP.HasTimestamp = true
- binary.Unmarshal(unixCmsg.Data[:linux.SizeOfTimeval], usermem.ByteOrder, &controlMessages.IP.Timestamp)
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfTimeval], hostarch.ByteOrder, &controlMessages.IP.Timestamp)
}
case linux.SOL_IP:
switch unixCmsg.Header.Type {
case linux.IP_TOS:
controlMessages.IP.HasTOS = true
- binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS)
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], hostarch.ByteOrder, &controlMessages.IP.TOS)
case linux.IP_PKTINFO:
controlMessages.IP.HasIPPacketInfo = true
var packetInfo linux.ControlMessageIPPacketInfo
- binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], hostarch.ByteOrder, &packetInfo)
controlMessages.IP.PacketInfo = packetInfo
case linux.IP_RECVORIGDSTADDR:
var addr linux.SockAddrInet
- binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr)
+ binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], hostarch.ByteOrder, &addr)
controlMessages.IP.OriginalDstAddress = &addr
case unix.IP_RECVERR:
@@ -557,11 +558,11 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s
switch unixCmsg.Header.Type {
case linux.IPV6_TCLASS:
controlMessages.IP.HasTClass = true
- binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass)
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], hostarch.ByteOrder, &controlMessages.IP.TClass)
case linux.IPV6_RECVORIGDSTADDR:
var addr linux.SockAddrInet6
- binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr)
+ binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], hostarch.ByteOrder, &addr)
controlMessages.IP.OriginalDstAddress = &addr
case unix.IPV6_RECVERR:
@@ -574,7 +575,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s
switch unixCmsg.Header.Type {
case linux.TCP_INQ:
controlMessages.IP.HasInq = true
- binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], usermem.ByteOrder, &controlMessages.IP.Inq)
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], hostarch.ByteOrder, &controlMessages.IP.Inq)
}
}
}
@@ -688,7 +689,7 @@ func (s *socketOpsCommon) State() uint32 {
return 0
}
- binary.Unmarshal(buf, usermem.ByteOrder, &info)
+ binary.Unmarshal(buf, hostarch.ByteOrder, &info)
return uint32(info.State)
}
diff --git a/pkg/sentry/socket/hostinet/socket_unsafe.go b/pkg/sentry/socket/hostinet/socket_unsafe.go
index 2890e640d..d3be2d825 100644
--- a/pkg/sentry/socket/hostinet/socket_unsafe.go
+++ b/pkg/sentry/socket/hostinet/socket_unsafe.go
@@ -20,6 +20,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket"
@@ -61,7 +62,7 @@ func ioctl(ctx context.Context, fd int, io usermem.IO, args arch.SyscallArgument
return 0, translateIOSyscallError(errno)
}
var buf [4]byte
- usermem.ByteOrder.PutUint32(buf[:], uint32(val))
+ hostarch.ByteOrder.PutUint32(buf[:], uint32(val))
_, err := io.CopyOut(ctx, args[2].Pointer(), buf[:], usermem.IOOpts{
AddressSpaceActive: true,
})
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index 5bcf92e14..26e8ae17a 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -22,11 +22,13 @@ import (
"reflect"
"strconv"
"strings"
+
"syscall"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/syserr"
@@ -146,7 +148,7 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli
return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid data length (%d bytes, expected at least %d bytes)", len(link.Data), unix.SizeofIfInfomsg)
}
var ifinfo unix.IfInfomsg
- binary.Unmarshal(link.Data[:unix.SizeofIfInfomsg], usermem.ByteOrder, &ifinfo)
+ binary.Unmarshal(link.Data[:unix.SizeofIfInfomsg], hostarch.ByteOrder, &ifinfo)
inetIF := inet.Interface{
DeviceType: ifinfo.Type,
Flags: ifinfo.Flags,
@@ -177,7 +179,7 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli
return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid data length (%d bytes, expected at least %d bytes)", len(addr.Data), unix.SizeofIfAddrmsg)
}
var ifaddr unix.IfAddrmsg
- binary.Unmarshal(addr.Data[:unix.SizeofIfAddrmsg], usermem.ByteOrder, &ifaddr)
+ binary.Unmarshal(addr.Data[:unix.SizeofIfAddrmsg], hostarch.ByteOrder, &ifaddr)
inetAddr := inet.InterfaceAddr{
Family: ifaddr.Family,
PrefixLen: ifaddr.Prefixlen,
@@ -209,7 +211,7 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error)
}
var ifRoute unix.RtMsg
- binary.Unmarshal(routeMsg.Data[:unix.SizeofRtMsg], usermem.ByteOrder, &ifRoute)
+ binary.Unmarshal(routeMsg.Data[:unix.SizeofRtMsg], hostarch.ByteOrder, &ifRoute)
inetRoute := inet.Route{
Family: ifRoute.Family,
DstLen: ifRoute.Dst_len,
@@ -243,7 +245,7 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error)
if len(attr.Value) != expected {
return nil, fmt.Errorf("RTM_GETROUTE returned RTM_NEWROUTE message with invalid attribute data length (%d bytes, expected %d bytes)", len(attr.Value), expected)
}
- binary.Unmarshal(attr.Value, usermem.ByteOrder, &inetRoute.OutputInterface)
+ binary.Unmarshal(attr.Value, hostarch.ByteOrder, &inetRoute.OutputInterface)
}
}
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index 8aea0200f..4381dfa06 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -20,12 +20,12 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/binary",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/sentry/kernel",
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
- "//pkg/usermem",
],
)
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
index e339f9bea..4bd305a44 100644
--- a/pkg/sentry/socket/netfilter/extensions.go
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -19,10 +19,10 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/usermem"
)
// TODO(gvisor.dev/issue/170): The following per-matcher params should be
@@ -89,7 +89,7 @@ func marshalEntryMatch(name string, data []byte) []byte {
copy(matcher.Name[:], name)
buf := make([]byte, 0, size)
- buf = binary.Marshal(buf, usermem.ByteOrder, matcher)
+ buf = binary.Marshal(buf, hostarch.ByteOrder, matcher)
return append(buf, make([]byte, size-len(buf))...)
}
diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go
index 2f913787b..1fc4cb651 100644
--- a/pkg/sentry/socket/netfilter/ipv4.go
+++ b/pkg/sentry/socket/netfilter/ipv4.go
@@ -19,11 +19,11 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/usermem"
)
// emptyIPv4Filter is for comparison with a rule's filters to determine whether
@@ -142,7 +142,7 @@ func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace,
}
var entry linux.IPTEntry
buf := optVal[:linux.SizeOfIPTEntry]
- binary.Unmarshal(buf, usermem.ByteOrder, &entry)
+ binary.Unmarshal(buf, hostarch.ByteOrder, &entry)
initialOptValLen := len(optVal)
optVal = optVal[linux.SizeOfIPTEntry:]
diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go
index 263d9d3b5..67a52b628 100644
--- a/pkg/sentry/socket/netfilter/ipv6.go
+++ b/pkg/sentry/socket/netfilter/ipv6.go
@@ -19,11 +19,11 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/usermem"
)
// emptyIPv6Filter is for comparison with a rule's filters to determine whether
@@ -145,7 +145,7 @@ func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace,
}
var entry linux.IP6TEntry
buf := optVal[:linux.SizeOfIP6TEntry]
- binary.Unmarshal(buf, usermem.ByteOrder, &entry)
+ binary.Unmarshal(buf, hostarch.ByteOrder, &entry)
initialOptValLen := len(optVal)
optVal = optVal[linux.SizeOfIP6TEntry:]
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 7ae18b2a3..c6fa3fd16 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -23,12 +23,12 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/usermem"
)
// enableLogging controls whether to log the (de)serialization of netfilter
@@ -83,7 +83,7 @@ func DefaultLinuxTables() *stack.IPTables {
}
// GetInfo returns information about iptables.
-func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, ipv6 bool) (linux.IPTGetinfo, *syserr.Error) {
+func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr hostarch.Addr, ipv6 bool) (linux.IPTGetinfo, *syserr.Error) {
// Read in the struct and table name.
var info linux.IPTGetinfo
if _, err := info.CopyIn(t, outPtr); err != nil {
@@ -106,7 +106,7 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, ipv6 bool)
}
// GetEntries4 returns netstack's iptables rules.
-func GetEntries4(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) {
+func GetEntries4(t *kernel.Task, stack *stack.Stack, outPtr hostarch.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) {
// Read in the struct and table name.
var userEntries linux.IPTGetEntries
if _, err := userEntries.CopyIn(t, outPtr); err != nil {
@@ -130,7 +130,7 @@ func GetEntries4(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen
}
// GetEntries6 returns netstack's ip6tables rules.
-func GetEntries6(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIP6TGetEntries, *syserr.Error) {
+func GetEntries6(t *kernel.Task, stack *stack.Stack, outPtr hostarch.Addr, outLen int) (linux.KernelIP6TGetEntries, *syserr.Error) {
// Read in the struct and table name. IPv4 and IPv6 utilize structs
// with the same layout.
var userEntries linux.IPTGetEntries
@@ -179,7 +179,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
var replace linux.IPTReplace
replaceBuf := optVal[:linux.SizeOfIPTReplace]
optVal = optVal[linux.SizeOfIPTReplace:]
- binary.Unmarshal(replaceBuf, usermem.ByteOrder, &replace)
+ binary.Unmarshal(replaceBuf, hostarch.ByteOrder, &replace)
// TODO(gvisor.dev/issue/170): Support other tables.
var table stack.Table
@@ -274,10 +274,10 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
}
// TODO(gvisor.dev/issue/170): Support other chains.
- // Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now,
- // make sure all other chains point to ACCEPT rules.
+ // Since we don't support FORWARD, yet, make sure all other chains point to
+ // ACCEPT rules.
for hook, ruleIdx := range table.BuiltinChains {
- if hook := stack.Hook(hook); hook == stack.Forward || hook == stack.Postrouting {
+ if hook := stack.Hook(hook); hook == stack.Forward {
if ruleIdx == stack.HookUnset {
continue
}
@@ -310,7 +310,7 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher,
}
var match linux.XTEntryMatch
buf := optVal[:linux.SizeOfXTEntryMatch]
- binary.Unmarshal(buf, usermem.ByteOrder, &match)
+ binary.Unmarshal(buf, hostarch.ByteOrder, &match)
nflog("set entries: parsed entry match %q: %+v", match.Name.String(), match)
// Check some invariants.
@@ -381,7 +381,7 @@ func hookFromLinux(hook int) stack.Hook {
// TargetRevision returns a linux.XTGetRevision for a given target. It sets
// Revision to the highest supported value, unless the provided revision number
// is larger.
-func TargetRevision(t *kernel.Task, revPtr usermem.Addr, netProto tcpip.NetworkProtocolNumber) (linux.XTGetRevision, *syserr.Error) {
+func TargetRevision(t *kernel.Task, revPtr hostarch.Addr, netProto tcpip.NetworkProtocolNumber) (linux.XTGetRevision, *syserr.Error) {
// Read in the target name and version.
var rev linux.XTGetRevision
if _, err := rev.CopyIn(t, revPtr); err != nil {
diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go
index 5f80d82ea..b2cc6be20 100644
--- a/pkg/sentry/socket/netfilter/owner_matcher.go
+++ b/pkg/sentry/socket/netfilter/owner_matcher.go
@@ -19,8 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/usermem"
)
const matcherNameOwner = "owner"
@@ -60,7 +60,7 @@ func (ownerMarshaler) marshal(mr matcher) []byte {
}
buf := make([]byte, 0, linux.SizeOfIPTOwnerInfo)
- return marshalEntryMatch(matcherNameOwner, binary.Marshal(buf, usermem.ByteOrder, iptOwnerInfo))
+ return marshalEntryMatch(matcherNameOwner, binary.Marshal(buf, hostarch.ByteOrder, iptOwnerInfo))
}
// unmarshal implements matchMaker.unmarshal.
@@ -72,7 +72,7 @@ func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.
// For alignment reasons, the match's total size may
// exceed what's strictly necessary to hold matchData.
var matchData linux.IPTOwnerInfo
- binary.Unmarshal(buf[:linux.SizeOfIPTOwnerInfo], usermem.ByteOrder, &matchData)
+ binary.Unmarshal(buf[:linux.SizeOfIPTOwnerInfo], hostarch.ByteOrder, &matchData)
nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData)
var owner OwnerMatcher
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index f2653d523..38b6491e2 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -19,11 +19,11 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/usermem"
)
// ErrorTargetName is used to mark targets as error targets. Error targets
@@ -35,6 +35,11 @@ const ErrorTargetName = "ERROR"
// change the destination port and/or IP for packets.
const RedirectTargetName = "REDIRECT"
+// SNATTargetName is used to mark targets as SNAT targets. SNAT targets should
+// be reached for only NAT table. These targets will change the source port
+// and/or IP for packets.
+const SNATTargetName = "SNAT"
+
func init() {
// Standard targets include ACCEPT, DROP, RETURN, and JUMP.
registerTargetMaker(&standardTargetMaker{
@@ -59,6 +64,13 @@ func init() {
registerTargetMaker(&nfNATTargetMaker{
NetworkProtocol: header.IPv6ProtocolNumber,
})
+
+ registerTargetMaker(&snatTargetMakerV4{
+ NetworkProtocol: header.IPv4ProtocolNumber,
+ })
+ registerTargetMaker(&snatTargetMakerV6{
+ NetworkProtocol: header.IPv6ProtocolNumber,
+ })
}
// The stack package provides some basic, useful targets for us. The following
@@ -131,6 +143,17 @@ func (rt *redirectTarget) id() targetID {
}
}
+type snatTarget struct {
+ stack.SNATTarget
+}
+
+func (st *snatTarget) id() targetID {
+ return targetID{
+ name: SNATTargetName,
+ networkProtocol: st.NetworkProtocol,
+ }
+}
+
type standardTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
@@ -167,7 +190,7 @@ func (*standardTargetMaker) marshal(target target) []byte {
}
ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
- return binary.Marshal(ret, usermem.ByteOrder, xt)
+ return binary.Marshal(ret, hostarch.ByteOrder, xt)
}
func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
@@ -177,7 +200,7 @@ func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
}
var standardTarget linux.XTStandardTarget
buf = buf[:linux.SizeOfXTStandardTarget]
- binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget)
+ binary.Unmarshal(buf, hostarch.ByteOrder, &standardTarget)
if standardTarget.Verdict < 0 {
// A Verdict < 0 indicates a non-jump verdict.
@@ -223,7 +246,7 @@ func (*errorTargetMaker) marshal(target target) []byte {
copy(xt.Target.Name[:], ErrorTargetName)
ret := make([]byte, 0, linux.SizeOfXTErrorTarget)
- return binary.Marshal(ret, usermem.ByteOrder, xt)
+ return binary.Marshal(ret, hostarch.ByteOrder, xt)
}
func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
@@ -233,7 +256,7 @@ func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar
}
var errTgt linux.XTErrorTarget
buf = buf[:linux.SizeOfXTErrorTarget]
- binary.Unmarshal(buf, usermem.ByteOrder, &errTgt)
+ binary.Unmarshal(buf, hostarch.ByteOrder, &errTgt)
// Error targets are used in 2 cases:
// * An actual error case. These rules have an error named
@@ -281,7 +304,7 @@ func (*redirectTargetMaker) marshal(target target) []byte {
xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED
xt.NfRange.RangeIPV4.MinPort = htons(rt.Port)
xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort
- return binary.Marshal(ret, usermem.ByteOrder, xt)
+ return binary.Marshal(ret, hostarch.ByteOrder, xt)
}
func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
@@ -297,7 +320,7 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
var rt linux.XTRedirectTarget
buf = buf[:linux.SizeOfXTRedirectTarget]
- binary.Unmarshal(buf, usermem.ByteOrder, &rt)
+ binary.Unmarshal(buf, hostarch.ByteOrder, &rt)
// Copy linux.XTRedirectTarget to stack.RedirectTarget.
target := redirectTarget{RedirectTarget: stack.RedirectTarget{
@@ -341,7 +364,7 @@ type nfNATTarget struct {
Range linux.NFNATRange
}
-const nfNATMarhsalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange
+const nfNATMarshalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange
type nfNATTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
@@ -358,7 +381,7 @@ func (*nfNATTargetMaker) marshal(target target) []byte {
rt := target.(*redirectTarget)
nt := nfNATTarget{
Target: linux.XTEntryTarget{
- TargetSize: nfNATMarhsalledSize,
+ TargetSize: nfNATMarshalledSize,
},
Range: linux.NFNATRange{
Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED,
@@ -371,12 +394,12 @@ func (*nfNATTargetMaker) marshal(target target) []byte {
nt.Range.MinProto = htons(rt.Port)
nt.Range.MaxProto = nt.Range.MinProto
- ret := make([]byte, 0, nfNATMarhsalledSize)
- return binary.Marshal(ret, usermem.ByteOrder, nt)
+ ret := make([]byte, 0, nfNATMarshalledSize)
+ return binary.Marshal(ret, hostarch.ByteOrder, nt)
}
func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
- if size := nfNATMarhsalledSize; len(buf) < size {
+ if size := nfNATMarshalledSize; len(buf) < size {
nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size)
return nil, syserr.ErrInvalidArgument
}
@@ -387,8 +410,8 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar
}
var natRange linux.NFNATRange
- buf = buf[linux.SizeOfXTEntryTarget:nfNATMarhsalledSize]
- binary.Unmarshal(buf, usermem.ByteOrder, &natRange)
+ buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize]
+ binary.Unmarshal(buf, hostarch.ByteOrder, &natRange)
// We don't support port or address ranges.
if natRange.MinAddr != natRange.MaxAddr {
@@ -418,6 +441,161 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar
return &target, nil
}
+type snatTargetMakerV4 struct {
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+func (st *snatTargetMakerV4) id() targetID {
+ return targetID{
+ name: SNATTargetName,
+ networkProtocol: st.NetworkProtocol,
+ }
+}
+
+func (*snatTargetMakerV4) marshal(target target) []byte {
+ st := target.(*snatTarget)
+ // This is a snat target named snat.
+ xt := linux.XTSNATTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTSNATTarget,
+ },
+ }
+ copy(xt.Target.Name[:], SNATTargetName)
+
+ xt.NfRange.RangeSize = 1
+ xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_MAP_IPS | linux.NF_NAT_RANGE_PROTO_SPECIFIED
+ xt.NfRange.RangeIPV4.MinPort = htons(st.Port)
+ xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort
+ copy(xt.NfRange.RangeIPV4.MinIP[:], st.Addr)
+ copy(xt.NfRange.RangeIPV4.MaxIP[:], st.Addr)
+ ret := make([]byte, 0, linux.SizeOfXTSNATTarget)
+ return binary.Marshal(ret, hostarch.ByteOrder, xt)
+}
+
+func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
+ if len(buf) < linux.SizeOfXTSNATTarget {
+ nflog("snatTargetMakerV4: buf has insufficient size for snat target %d", len(buf))
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber {
+ nflog("snatTargetMakerV4: bad proto %d", p)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var st linux.XTSNATTarget
+ buf = buf[:linux.SizeOfXTSNATTarget]
+ binary.Unmarshal(buf, hostarch.ByteOrder, &st)
+
+ // Copy linux.XTSNATTarget to stack.SNATTarget.
+ target := snatTarget{SNATTarget: stack.SNATTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ }}
+
+ // RangeSize should be 1.
+ nfRange := st.NfRange
+ if nfRange.RangeSize != 1 {
+ nflog("snatTargetMakerV4: bad rangesize %d", nfRange.RangeSize)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/5772): If the rule doesn't specify the source port,
+ // choose one automatically.
+ if nfRange.RangeIPV4.MinPort == 0 {
+ nflog("snatTargetMakerV4: snat target needs to specify a non-zero port")
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): Port range is not supported yet.
+ if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
+ nflog("snatTargetMakerV4: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
+ return nil, syserr.ErrInvalidArgument
+ }
+ if nfRange.RangeIPV4.MinIP != nfRange.RangeIPV4.MaxIP {
+ nflog("snatTargetMakerV4: MinIP != MaxIP (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
+ target.Port = ntohs(nfRange.RangeIPV4.MinPort)
+
+ return &target, nil
+}
+
+type snatTargetMakerV6 struct {
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+func (st *snatTargetMakerV6) id() targetID {
+ return targetID{
+ name: SNATTargetName,
+ networkProtocol: st.NetworkProtocol,
+ revision: 1,
+ }
+}
+
+func (*snatTargetMakerV6) marshal(target target) []byte {
+ st := target.(*snatTarget)
+ nt := nfNATTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: nfNATMarshalledSize,
+ },
+ Range: linux.NFNATRange{
+ Flags: linux.NF_NAT_RANGE_MAP_IPS | linux.NF_NAT_RANGE_PROTO_SPECIFIED,
+ },
+ }
+ copy(nt.Target.Name[:], SNATTargetName)
+ copy(nt.Range.MinAddr[:], st.Addr)
+ copy(nt.Range.MaxAddr[:], st.Addr)
+ nt.Range.MinProto = htons(st.Port)
+ nt.Range.MaxProto = nt.Range.MinProto
+
+ ret := make([]byte, 0, nfNATMarshalledSize)
+ return binary.Marshal(ret, hostarch.ByteOrder, nt)
+}
+
+func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
+ if size := nfNATMarshalledSize; len(buf) < size {
+ nflog("snatTargetMakerV6: buf has insufficient size (%d) for SNAT V6 target (%d)", len(buf), size)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber {
+ nflog("snatTargetMakerV6: bad proto %d", p)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var natRange linux.NFNATRange
+ buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize]
+ binary.Unmarshal(buf, hostarch.ByteOrder, &natRange)
+
+ // TODO(gvisor.dev/issue/5689): Support port or address ranges.
+ if natRange.MinAddr != natRange.MaxAddr {
+ nflog("snatTargetMakerV6: MinAddr and MaxAddr are different")
+ return nil, syserr.ErrInvalidArgument
+ }
+ if natRange.MinProto != natRange.MaxProto {
+ nflog("snatTargetMakerV6: MinProto and MaxProto are different")
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/5698): Support other NF_NAT_RANGE flags.
+ if natRange.Flags != linux.NF_NAT_RANGE_MAP_IPS|linux.NF_NAT_RANGE_PROTO_SPECIFIED {
+ nflog("snatTargetMakerV6: invalid range flags %d", natRange.Flags)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ target := snatTarget{
+ SNATTarget: stack.SNATTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ Addr: tcpip.Address(natRange.MinAddr[:]),
+ Port: ntohs(natRange.MinProto),
+ },
+ }
+
+ return &target, nil
+}
+
// translateToStandardTarget translates from the value in a
// linux.XTStandardTarget to an stack.Verdict.
func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (target, *syserr.Error) {
@@ -454,7 +632,7 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte, ipv6 bool) (stack.T
}
var target linux.XTEntryTarget
buf := optVal[:linux.SizeOfXTEntryTarget]
- binary.Unmarshal(buf, usermem.ByteOrder, &target)
+ binary.Unmarshal(buf, hostarch.ByteOrder, &target)
return unmarshalTarget(target, filter, optVal)
}
@@ -487,11 +665,11 @@ func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook,
func ntohs(port uint16) uint16 {
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, port)
- return usermem.ByteOrder.Uint16(buf)
+ return hostarch.ByteOrder.Uint16(buf)
}
func htons(port uint16) uint16 {
buf := make([]byte, 2)
- usermem.ByteOrder.PutUint16(buf, port)
+ hostarch.ByteOrder.PutUint16(buf, port)
return binary.BigEndian.Uint16(buf)
}
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index 678d6b578..69557f515 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -19,9 +19,9 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/usermem"
)
const matcherNameTCP = "tcp"
@@ -48,7 +48,7 @@ func (tcpMarshaler) marshal(mr matcher) []byte {
DestinationPortEnd: matcher.destinationPortEnd,
}
buf := make([]byte, 0, linux.SizeOfXTTCP)
- return marshalEntryMatch(matcherNameTCP, binary.Marshal(buf, usermem.ByteOrder, xttcp))
+ return marshalEntryMatch(matcherNameTCP, binary.Marshal(buf, hostarch.ByteOrder, xttcp))
}
// unmarshal implements matchMaker.unmarshal.
@@ -60,7 +60,7 @@ func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma
// For alignment reasons, the match's total size may
// exceed what's strictly necessary to hold matchData.
var matchData linux.XTTCP
- binary.Unmarshal(buf[:linux.SizeOfXTTCP], usermem.ByteOrder, &matchData)
+ binary.Unmarshal(buf[:linux.SizeOfXTTCP], hostarch.ByteOrder, &matchData)
nflog("parseMatchers: parsed XTTCP: %+v", matchData)
if matchData.Option != 0 ||
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index f8568873f..6a60e6bd6 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -19,9 +19,9 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/usermem"
)
const matcherNameUDP = "udp"
@@ -48,7 +48,7 @@ func (udpMarshaler) marshal(mr matcher) []byte {
DestinationPortEnd: matcher.destinationPortEnd,
}
buf := make([]byte, 0, linux.SizeOfXTUDP)
- return marshalEntryMatch(matcherNameUDP, binary.Marshal(buf, usermem.ByteOrder, xtudp))
+ return marshalEntryMatch(matcherNameUDP, binary.Marshal(buf, hostarch.ByteOrder, xtudp))
}
// unmarshal implements matchMaker.unmarshal.
@@ -60,7 +60,7 @@ func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma
// For alignment reasons, the match's total size may exceed what's
// strictly necessary to hold matchData.
var matchData linux.XTUDP
- binary.Unmarshal(buf[:linux.SizeOfXTUDP], usermem.ByteOrder, &matchData)
+ binary.Unmarshal(buf[:linux.SizeOfXTUDP], hostarch.ByteOrder, &matchData)
nflog("parseMatchers: parsed XTUDP: %+v", matchData)
if matchData.InverseFlags != 0 {
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index 9313e1167..171b95c63 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -16,6 +16,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/binary",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/marshal",
"//pkg/marshal/primitive",
"//pkg/sentry/arch",
diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go
index 0899c61d1..ab0e68af7 100644
--- a/pkg/sentry/socket/netlink/message.go
+++ b/pkg/sentry/socket/netlink/message.go
@@ -20,7 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// alignPad returns the length of padding required for alignment.
@@ -42,7 +42,7 @@ type Message struct {
func NewMessage(hdr linux.NetlinkMessageHeader) *Message {
return &Message{
hdr: hdr,
- buf: binary.Marshal(nil, usermem.ByteOrder, hdr),
+ buf: binary.Marshal(nil, hostarch.ByteOrder, hdr),
}
}
@@ -58,7 +58,7 @@ func ParseMessage(buf []byte) (msg *Message, rest []byte, ok bool) {
return
}
var hdr linux.NetlinkMessageHeader
- binary.Unmarshal(hdrBytes, usermem.ByteOrder, &hdr)
+ binary.Unmarshal(hdrBytes, hostarch.ByteOrder, &hdr)
// Msg portion.
totalMsgLen := int(hdr.Length)
@@ -105,7 +105,7 @@ func (m *Message) GetData(msg interface{}) (AttrsView, bool) {
if !ok {
return nil, false
}
- binary.Unmarshal(msgBytes, usermem.ByteOrder, msg)
+ binary.Unmarshal(msgBytes, hostarch.ByteOrder, msg)
numPad := alignPad(linux.NetlinkMessageHeaderSize+size, linux.NLMSG_ALIGNTO)
// Linux permits the last message not being aligned, just consume all of it.
@@ -126,7 +126,7 @@ func (m *Message) GetData(msg interface{}) (AttrsView, bool) {
// calling Finalize.
func (m *Message) Finalize() []byte {
// Update length, which is the first 4 bytes of the header.
- usermem.ByteOrder.PutUint32(m.buf, uint32(len(m.buf)))
+ hostarch.ByteOrder.PutUint32(m.buf, uint32(len(m.buf)))
// Align the message. Note that the message length in the header (set
// above) is the useful length of the message, not the total aligned
@@ -146,7 +146,7 @@ func (m *Message) putZeros(n int) {
// Put serializes v into the message.
func (m *Message) Put(v interface{}) {
- m.buf = binary.Marshal(m.buf, usermem.ByteOrder, v)
+ m.buf = binary.Marshal(m.buf, hostarch.ByteOrder, v)
}
// PutAttr adds v to the message as a netlink attribute.
@@ -251,7 +251,7 @@ func (v AttrsView) ParseFirst() (hdr linux.NetlinkAttrHeader, value []byte, rest
if !ok {
return
}
- binary.Unmarshal(hdrBytes, usermem.ByteOrder, &hdr)
+ binary.Unmarshal(hdrBytes, hostarch.ByteOrder, &hdr)
value, ok = b.Extract(int(hdr.Length) - linux.NetlinkAttrHeaderSize)
if !ok {
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index d5ffc75ce..30c297149 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -222,7 +223,7 @@ func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) {
}
var sa linux.SockAddrNetlink
- binary.Unmarshal(b[:linux.SockAddrNetlinkSize], usermem.ByteOrder, &sa)
+ binary.Unmarshal(b[:linux.SockAddrNetlinkSize], hostarch.ByteOrder, &sa)
if sa.Family != linux.AF_NETLINK {
return nil, syserr.ErrInvalidArgument
@@ -327,7 +328,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
switch name {
@@ -388,7 +389,7 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []
if len(opt) < sizeOfInt32 {
return syserr.ErrInvalidArgument
}
- size := usermem.ByteOrder.Uint32(opt)
+ size := hostarch.ByteOrder.Uint32(opt)
if size < minSendBufferSize {
size = minSendBufferSize
} else if size > maxSendBufferSize {
@@ -411,7 +412,7 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []
if len(opt) < sizeOfInt32 {
return syserr.ErrInvalidArgument
}
- passcred := usermem.ByteOrder.Uint32(opt)
+ passcred := hostarch.ByteOrder.Uint32(opt)
s.ep.SocketOptions().SetPassCred(passcred != 0)
return nil
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index 244d99436..0b39a5b67 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -21,6 +21,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/binary",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/marshal",
"//pkg/marshal/primitive",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 64e70ab9d..312f5f85a 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -37,6 +37,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
@@ -241,6 +242,7 @@ var Metrics = tcpip.Stats{
FastRetransmit: mustCreateMetric("/netstack/tcp/fast_retransmit", "Number of TCP segments which were fast retransmitted."),
Timeouts: mustCreateMetric("/netstack/tcp/timeouts", "Number of times RTO expired."),
ChecksumErrors: mustCreateMetric("/netstack/tcp/checksum_errors", "Number of segments dropped due to bad checksums."),
+ FailedPortReservations: mustCreateMetric("/netstack/tcp/failed_port_reservations", "Number of time TCP failed to reserve a port."),
},
UDP: tcpip.UDPStats{
PacketsReceived: mustCreateMetric("/netstack/udp/packets_received", "Number of UDP datagrams received via HandlePacket."),
@@ -600,7 +602,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
return syserr.ErrInvalidArgument
}
- family := usermem.ByteOrder.Uint16(sockaddr)
+ family := hostarch.ByteOrder.Uint16(sockaddr)
var addr tcpip.FullAddress
// Bind for AF_PACKET requires only family, protocol and ifindex.
@@ -611,7 +613,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
if len(sockaddr) < sockAddrLinkSize {
return syserr.ErrInvalidArgument
}
- binary.Unmarshal(sockaddr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+ binary.Unmarshal(sockaddr[:sockAddrLinkSize], hostarch.ByteOrder, &a)
if a.Protocol != uint16(s.protocol) {
return syserr.ErrInvalidArgument
@@ -757,7 +759,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
// implemented specifically for netstack.SocketOperations rather than
// commonEndpoint. commonEndpoint should be extended to support socket
@@ -793,7 +795,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
// GetSockOpt can be used to implement the linux syscall getsockopt(2) for
// sockets backed by a commonEndpoint.
-func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
+func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
return getSockOptSocket(t, s, ep, family, skType, name, outLen)
@@ -884,10 +886,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ size := ep.SocketOptions().GetReceiveBufferSize()
if size > math.MaxInt32 {
size = math.MaxInt32
@@ -1244,7 +1243,7 @@ func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name,
}
// getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6.
-func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
+func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
if _, ok := ep.(tcpip.Endpoint); !ok {
log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
return nil, syserr.ErrUnknownProtocolOption
@@ -1392,7 +1391,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// getSockOptIP implements GetSockOpt when level is SOL_IP.
-func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) {
+func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr hostarch.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) {
if _, ok := ep.(tcpip.Endpoint); !ok {
log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
return nil, syserr.ErrUnknownProtocolOption
@@ -1602,7 +1601,7 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa
}
s.readMu.Lock()
defer s.readMu.Unlock()
- s.sockOptTimestamp = usermem.ByteOrder.Uint32(optVal) != 0
+ s.sockOptTimestamp = hostarch.ByteOrder.Uint32(optVal) != 0
return nil
}
if level == linux.SOL_TCP && name == linux.TCP_INQ {
@@ -1611,7 +1610,7 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa
}
s.readMu.Lock()
defer s.readMu.Unlock()
- s.sockOptInq = usermem.ByteOrder.Uint32(optVal) != 0
+ s.sockOptInq = hostarch.ByteOrder.Uint32(optVal) != 0
return nil
}
@@ -1659,8 +1658,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
return syserr.ErrInvalidArgument
}
- v := usermem.ByteOrder.Uint32(optVal)
- ep.SocketOptions().SetSendBufferSize(int64(v), true)
+ v := hostarch.ByteOrder.Uint32(optVal)
+ ep.SocketOptions().SetSendBufferSize(int64(v), true /* notify */)
return nil
case linux.SO_RCVBUF:
@@ -1668,15 +1667,16 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
return syserr.ErrInvalidArgument
}
- v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v)))
+ v := hostarch.ByteOrder.Uint32(optVal)
+ ep.SocketOptions().SetReceiveBufferSize(int64(v), true /* notify */)
+ return nil
case linux.SO_REUSEADDR:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
}
- v := usermem.ByteOrder.Uint32(optVal)
+ v := hostarch.ByteOrder.Uint32(optVal)
ep.SocketOptions().SetReuseAddress(v != 0)
return nil
@@ -1685,7 +1685,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
return syserr.ErrInvalidArgument
}
- v := usermem.ByteOrder.Uint32(optVal)
+ v := hostarch.ByteOrder.Uint32(optVal)
ep.SocketOptions().SetReusePort(v != 0)
return nil
@@ -1714,7 +1714,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
return syserr.ErrInvalidArgument
}
- v := usermem.ByteOrder.Uint32(optVal)
+ v := hostarch.ByteOrder.Uint32(optVal)
ep.SocketOptions().SetBroadcast(v != 0)
return nil
@@ -1723,7 +1723,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
return syserr.ErrInvalidArgument
}
- v := usermem.ByteOrder.Uint32(optVal)
+ v := hostarch.ByteOrder.Uint32(optVal)
ep.SocketOptions().SetPassCred(v != 0)
return nil
@@ -1732,7 +1732,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
return syserr.ErrInvalidArgument
}
- v := usermem.ByteOrder.Uint32(optVal)
+ v := hostarch.ByteOrder.Uint32(optVal)
ep.SocketOptions().SetKeepAlive(v != 0)
return nil
@@ -1742,7 +1742,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
var v linux.Timeval
- binary.Unmarshal(optVal[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ binary.Unmarshal(optVal[:linux.SizeOfTimeval], hostarch.ByteOrder, &v)
if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
return syserr.ErrDomain
}
@@ -1755,7 +1755,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
var v linux.Timeval
- binary.Unmarshal(optVal[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ binary.Unmarshal(optVal[:linux.SizeOfTimeval], hostarch.ByteOrder, &v)
if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
return syserr.ErrDomain
}
@@ -1767,7 +1767,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
return syserr.ErrInvalidArgument
}
- v := usermem.ByteOrder.Uint32(optVal)
+ v := hostarch.ByteOrder.Uint32(optVal)
if v == 0 {
socket.SetSockOptEmitUnimplementedEvent(t, name)
@@ -1781,7 +1781,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
return syserr.ErrInvalidArgument
}
- v := usermem.ByteOrder.Uint32(optVal)
+ v := hostarch.ByteOrder.Uint32(optVal)
ep.SocketOptions().SetNoChecksum(v != 0)
return nil
@@ -1791,7 +1791,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
var v linux.Linger
- binary.Unmarshal(optVal[:linux.SizeOfLinger], usermem.ByteOrder, &v)
+ binary.Unmarshal(optVal[:linux.SizeOfLinger], hostarch.ByteOrder, &v)
ep.SocketOptions().SetLinger(tcpip.LingerOption{
Enabled: v.OnOff != 0,
@@ -1824,7 +1824,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i
return syserr.ErrInvalidArgument
}
- v := usermem.ByteOrder.Uint32(optVal)
+ v := hostarch.ByteOrder.Uint32(optVal)
ep.SocketOptions().SetDelayOption(v == 0)
return nil
@@ -1833,7 +1833,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i
return syserr.ErrInvalidArgument
}
- v := usermem.ByteOrder.Uint32(optVal)
+ v := hostarch.ByteOrder.Uint32(optVal)
ep.SocketOptions().SetCorkOption(v != 0)
return nil
@@ -1842,7 +1842,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i
return syserr.ErrInvalidArgument
}
- v := usermem.ByteOrder.Uint32(optVal)
+ v := hostarch.ByteOrder.Uint32(optVal)
ep.SocketOptions().SetQuickAck(v != 0)
return nil
@@ -1851,7 +1851,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i
return syserr.ErrInvalidArgument
}
- v := usermem.ByteOrder.Uint32(optVal)
+ v := hostarch.ByteOrder.Uint32(optVal)
return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MaxSegOption, int(v)))
case linux.TCP_KEEPIDLE:
@@ -1859,7 +1859,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i
return syserr.ErrInvalidArgument
}
- v := usermem.ByteOrder.Uint32(optVal)
+ v := hostarch.ByteOrder.Uint32(optVal)
if v < 1 || v > linux.MAX_TCP_KEEPIDLE {
return syserr.ErrInvalidArgument
}
@@ -1871,7 +1871,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i
return syserr.ErrInvalidArgument
}
- v := usermem.ByteOrder.Uint32(optVal)
+ v := hostarch.ByteOrder.Uint32(optVal)
if v < 1 || v > linux.MAX_TCP_KEEPINTVL {
return syserr.ErrInvalidArgument
}
@@ -1883,7 +1883,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i
return syserr.ErrInvalidArgument
}
- v := usermem.ByteOrder.Uint32(optVal)
+ v := hostarch.ByteOrder.Uint32(optVal)
if v < 1 || v > linux.MAX_TCP_KEEPCNT {
return syserr.ErrInvalidArgument
}
@@ -1894,7 +1894,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i
return syserr.ErrInvalidArgument
}
- v := int32(usermem.ByteOrder.Uint32(optVal))
+ v := int32(hostarch.ByteOrder.Uint32(optVal))
if v < 0 {
return syserr.ErrInvalidArgument
}
@@ -1913,7 +1913,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i
return syserr.ErrInvalidArgument
}
- v := int32(usermem.ByteOrder.Uint32(optVal))
+ v := int32(hostarch.ByteOrder.Uint32(optVal))
opt := tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v))
return syserr.TranslateNetstackError(ep.SetSockOpt(&opt))
@@ -1921,7 +1921,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
}
- v := int32(usermem.ByteOrder.Uint32(optVal))
+ v := int32(hostarch.ByteOrder.Uint32(optVal))
if v < 0 {
v = 0
}
@@ -1932,7 +1932,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
}
- v := usermem.ByteOrder.Uint32(optVal)
+ v := hostarch.ByteOrder.Uint32(optVal)
return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TCPSynCountOption, int(v)))
@@ -1940,7 +1940,7 @@ func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name i
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
}
- v := usermem.ByteOrder.Uint32(optVal)
+ v := hostarch.ByteOrder.Uint32(optVal)
return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TCPWindowClampOption, int(v)))
@@ -1978,7 +1978,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return syserr.ErrInvalidEndpointState
}
- v := usermem.ByteOrder.Uint32(optVal)
+ v := hostarch.ByteOrder.Uint32(optVal)
ep.SocketOptions().SetV6Only(v != 0)
return nil
@@ -2024,7 +2024,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
}
- v := int32(usermem.ByteOrder.Uint32(optVal))
+ v := int32(hostarch.ByteOrder.Uint32(optVal))
ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
return nil
@@ -2033,7 +2033,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
}
- v := int32(usermem.ByteOrder.Uint32(optVal))
+ v := int32(hostarch.ByteOrder.Uint32(optVal))
if v < -1 || v > 255 {
return syserr.ErrInvalidArgument
}
@@ -2117,12 +2117,12 @@ func copyInMulticastRequest(optVal []byte, allowAddr bool) (linux.InetMulticastR
if len(optVal) >= inetMulticastRequestWithNICSize {
var req linux.InetMulticastRequestWithNIC
- binary.Unmarshal(optVal[:inetMulticastRequestWithNICSize], usermem.ByteOrder, &req)
+ binary.Unmarshal(optVal[:inetMulticastRequestWithNICSize], hostarch.ByteOrder, &req)
return req, nil
}
var req linux.InetMulticastRequestWithNIC
- binary.Unmarshal(optVal[:inetMulticastRequestSize], usermem.ByteOrder, &req.InetMulticastRequest)
+ binary.Unmarshal(optVal[:inetMulticastRequestSize], hostarch.ByteOrder, &req.InetMulticastRequest)
return req, nil
}
@@ -2132,7 +2132,7 @@ func copyInMulticastV6Request(optVal []byte) (linux.Inet6MulticastRequest, *syse
}
var req linux.Inet6MulticastRequest
- binary.Unmarshal(optVal[:inet6MulticastRequestSize], usermem.ByteOrder, &req)
+ binary.Unmarshal(optVal[:inet6MulticastRequestSize], hostarch.ByteOrder, &req)
return req, nil
}
@@ -2145,7 +2145,7 @@ func parseIntOrChar(buf []byte) (int32, *syserr.Error) {
}
if len(buf) >= sizeOfInt32 {
- return int32(usermem.ByteOrder.Uint32(buf)), nil
+ return int32(hostarch.ByteOrder.Uint32(buf)), nil
}
return int32(buf[0]), nil
@@ -3007,7 +3007,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
if arg == linux.SIOCGIFNAME {
// Gets the name of the interface given the interface index
// stored in ifr_ifindex.
- index = int32(usermem.ByteOrder.Uint32(ifr.Data[:4]))
+ index = int32(hostarch.ByteOrder.Uint32(ifr.Data[:4]))
if iface, ok := stack.Interfaces()[index]; ok {
ifr.SetName(iface.Name)
return nil
@@ -3029,7 +3029,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
switch arg {
case linux.SIOCGIFINDEX:
// Copy out the index to the data.
- usermem.ByteOrder.PutUint32(ifr.Data[:], uint32(index))
+ hostarch.ByteOrder.PutUint32(ifr.Data[:], uint32(index))
case linux.SIOCGIFHWADDR:
// Copy the hardware address out.
@@ -3042,7 +3042,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
// sockaddr. sa_family contains the ARPHRD_* device type,
// sa_data the L2 hardware address starting from byte 0. Setting
// the hardware address is a privileged operation.
- usermem.ByteOrder.PutUint16(ifr.Data[:], iface.DeviceType)
+ hostarch.ByteOrder.PutUint16(ifr.Data[:], iface.DeviceType)
n := copy(ifr.Data[2:], iface.Addr)
for i := 2 + n; i < len(ifr.Data); i++ {
ifr.Data[i] = 0 // Clear padding.
@@ -3055,7 +3055,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
}
// Drop the flags that don't fit in the size that we need to return. This
// matches Linux behavior.
- usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(f))
+ hostarch.ByteOrder.PutUint16(ifr.Data[:2], uint16(f))
case linux.SIOCGIFADDR:
// Copy the IPv4 address out.
@@ -3071,11 +3071,11 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
case linux.SIOCGIFMETRIC:
// Gets the metric of the device. As per netdevice(7), this
// always just sets ifr_metric to 0.
- usermem.ByteOrder.PutUint32(ifr.Data[:4], 0)
+ hostarch.ByteOrder.PutUint32(ifr.Data[:4], 0)
case linux.SIOCGIFMTU:
// Gets the MTU of the device.
- usermem.ByteOrder.PutUint32(ifr.Data[:4], iface.MTU)
+ hostarch.ByteOrder.PutUint32(ifr.Data[:4], iface.MTU)
case linux.SIOCGIFMAP:
// Gets the hardware parameters of the device.
@@ -3101,8 +3101,8 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
continue
}
// Populate ifr.ifr_netmask (type sockaddr).
- usermem.ByteOrder.PutUint16(ifr.Data[0:2], uint16(linux.AF_INET))
- usermem.ByteOrder.PutUint16(ifr.Data[2:4], 0)
+ hostarch.ByteOrder.PutUint16(ifr.Data[0:2], uint16(linux.AF_INET))
+ hostarch.ByteOrder.PutUint16(ifr.Data[2:4], 0)
var mask uint32 = 0xffffffff << (32 - addr.PrefixLen)
// Netmask is expected to be returned as a big endian
// value.
@@ -3157,14 +3157,14 @@ func ifconfIoctl(ctx context.Context, t *kernel.Task, io usermem.IO, ifc *linux.
// Populate ifr.ifr_addr.
ifr := linux.IFReq{}
ifr.SetName(iface.Name)
- usermem.ByteOrder.PutUint16(ifr.Data[0:2], uint16(ifaceAddr.Family))
- usermem.ByteOrder.PutUint16(ifr.Data[2:4], 0)
+ hostarch.ByteOrder.PutUint16(ifr.Data[0:2], uint16(ifaceAddr.Family))
+ hostarch.ByteOrder.PutUint16(ifr.Data[2:4], 0)
copy(ifr.Data[4:8], ifaceAddr.Addr[:4])
// Copy the ifr to userspace.
dst := uintptr(ifc.Ptr) + uintptr(ifc.Len)
ifc.Len += int32(linux.SizeOfIFReq)
- if _, err := ifr.CopyOut(t, usermem.Addr(dst)); err != nil {
+ if _, err := ifr.CopyOut(t, hostarch.Addr(dst)); err != nil {
return err
}
}
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index fc29f8f13..30f3ad153 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -17,6 +17,7 @@ package netstack
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -197,7 +198,7 @@ func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
// implemented specifically for netstack.SocketVFS2 rather than
// commonEndpoint. commonEndpoint should be extended to support socket
@@ -245,7 +246,7 @@ func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []by
}
s.readMu.Lock()
defer s.readMu.Unlock()
- s.sockOptTimestamp = usermem.ByteOrder.Uint32(optVal) != 0
+ s.sockOptTimestamp = hostarch.ByteOrder.Uint32(optVal) != 0
return nil
}
if level == linux.SOL_TCP && name == linux.TCP_INQ {
@@ -254,7 +255,7 @@ func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []by
}
s.readMu.Lock()
defer s.readMu.Unlock()
- s.sockOptInq = usermem.ByteOrder.Uint32(optVal) != 0
+ s.sockOptInq = hostarch.ByteOrder.Uint32(optVal) != 0
return nil
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 909341dcf..4c3d48096 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -216,7 +217,7 @@ type SocketOps interface {
Shutdown(t *kernel.Task, how int) *syserr.Error
// GetSockOpt implements the getsockopt(2) linux unix.
- GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error)
+ GetSockOpt(t *kernel.Task, level int, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error)
// SetSockOpt implements the setsockopt(2) linux unix.
SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error
@@ -356,7 +357,7 @@ func NewDirent(ctx context.Context, d *device.Device) *fs.Dirent {
Type: fs.Socket,
DeviceID: d.DeviceID(),
InodeID: ino,
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
})
// Dirent name matches net/socket.c:sockfs_dname.
@@ -571,19 +572,19 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr {
switch family {
case unix.AF_INET:
var addr linux.SockAddrInet
- binary.Unmarshal(data[:unix.SizeofSockaddrInet4], usermem.ByteOrder, &addr)
+ binary.Unmarshal(data[:unix.SizeofSockaddrInet4], hostarch.ByteOrder, &addr)
return &addr
case unix.AF_INET6:
var addr linux.SockAddrInet6
- binary.Unmarshal(data[:unix.SizeofSockaddrInet6], usermem.ByteOrder, &addr)
+ binary.Unmarshal(data[:unix.SizeofSockaddrInet6], hostarch.ByteOrder, &addr)
return &addr
case unix.AF_UNIX:
var addr linux.SockAddrUnix
- binary.Unmarshal(data[:unix.SizeofSockaddrUnix], usermem.ByteOrder, &addr)
+ binary.Unmarshal(data[:unix.SizeofSockaddrUnix], hostarch.ByteOrder, &addr)
return &addr
case unix.AF_NETLINK:
var addr linux.SockAddrNetlink
- binary.Unmarshal(data[:unix.SizeofSockaddrNetlink], usermem.ByteOrder, &addr)
+ binary.Unmarshal(data[:unix.SizeofSockaddrNetlink], hostarch.ByteOrder, &addr)
return &addr
default:
panic(fmt.Sprintf("Unsupported socket family %v", family))
@@ -693,7 +694,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
}
// Get the rest of the fields based on the address family.
- switch family := usermem.ByteOrder.Uint16(addr); family {
+ switch family := hostarch.ByteOrder.Uint16(addr); family {
case linux.AF_UNIX:
path := addr[2:]
if len(path) > linux.UnixPathMax {
@@ -715,7 +716,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
if len(addr) < sockAddrInetSize {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
- binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
+ binary.Unmarshal(addr[:sockAddrInetSize], hostarch.ByteOrder, &a)
out := tcpip.FullAddress{
Addr: BytesToIPAddress(a.Addr[:]),
@@ -728,7 +729,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
if len(addr) < sockAddrInet6Size {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
- binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
+ binary.Unmarshal(addr[:sockAddrInet6Size], hostarch.ByteOrder, &a)
out := tcpip.FullAddress{
Addr: BytesToIPAddress(a.Addr[:]),
@@ -744,7 +745,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
if len(addr) < sockAddrLinkSize {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
- binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+ binary.Unmarshal(addr[:sockAddrLinkSize], hostarch.ByteOrder, &a)
if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index ff53a26b7..c9cbefb3a 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -40,6 +40,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/fspath",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/marshal",
"//pkg/refs",
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index 159b8f90f..408dfb08d 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -130,7 +130,8 @@ func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProv
}
ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */)
- ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits)
+ ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */)
+ ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
return ep
}
@@ -175,8 +176,9 @@ func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider
idGenerator: uid,
stype: stype,
}
- ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits)
+ ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
ep.ops.SetSendBufferSize(connected.SendMaxQueueSize(), false /* notify */)
+ ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */)
return ep
}
@@ -299,8 +301,9 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn
idGenerator: e.idGenerator,
stype: e.stype,
}
- ne.ops.InitHandler(ne, &stackHandler{}, getSendBufferLimits)
+ ne.ops.InitHandler(ne, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
ne.ops.SetSendBufferSize(defaultBufferSize, false /* notify */)
+ ne.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */)
readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: defaultBufferSize}
readQueue.InitRefs()
@@ -366,6 +369,7 @@ func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint
// to reflect this endpoint's send buffer size.
if bufSz := e.connected.SetSendBufferSize(e.ops.GetSendBufferSize()); bufSz != e.ops.GetSendBufferSize() {
e.ops.SetSendBufferSize(bufSz, false /* notify */)
+ e.ops.SetReceiveBufferSize(bufSz, false /* notify */)
}
}
diff --git a/pkg/sentry/socket/unix/transport/connectioned_state.go b/pkg/sentry/socket/unix/transport/connectioned_state.go
index 590b0bd01..b20334d4f 100644
--- a/pkg/sentry/socket/unix/transport/connectioned_state.go
+++ b/pkg/sentry/socket/unix/transport/connectioned_state.go
@@ -54,5 +54,5 @@ func (e *connectionedEndpoint) loadAcceptedChan(acceptedSlice []*connectionedEnd
// afterLoad is invoked by stateify.
func (e *connectionedEndpoint) afterLoad() {
- e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits)
+ e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
}
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index d0df28b59..61338728a 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -45,7 +45,8 @@ func NewConnectionless(ctx context.Context) Endpoint {
q.InitRefs()
ep.receiver = &queueReceiver{readQueue: &q}
ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */)
- ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits)
+ ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */)
+ ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
return ep
}
diff --git a/pkg/sentry/socket/unix/transport/connectionless_state.go b/pkg/sentry/socket/unix/transport/connectionless_state.go
index 2ef337ec8..1bb71baf7 100644
--- a/pkg/sentry/socket/unix/transport/connectionless_state.go
+++ b/pkg/sentry/socket/unix/transport/connectionless_state.go
@@ -16,5 +16,5 @@ package transport
// afterLoad is invoked by stateify.
func (e *connectionlessEndpoint) afterLoad() {
- e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits)
+ e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 0c5f5ab42..837ab4fde 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -868,11 +868,7 @@ func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
}
func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
- switch opt {
- case tcpip.ReceiveBufferSizeOption:
- default:
- log.Warningf("Unsupported socket option: %d", opt)
- }
+ log.Warningf("Unsupported socket option: %d", opt)
return nil
}
@@ -905,19 +901,6 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
}
return int(v), nil
- case tcpip.ReceiveBufferSizeOption:
- e.Lock()
- if e.receiver == nil {
- e.Unlock()
- return -1, &tcpip.ErrNotConnected{}
- }
- v := e.receiver.RecvMaxQueueSize()
- e.Unlock()
- if v < 0 {
- return -1, &tcpip.ErrQueueSizeNotSupported{}
- }
- return int(v), nil
-
default:
log.Warningf("Unsupported socket option: %d", opt)
return -1, &tcpip.ErrUnknownProtocolOption{}
@@ -1029,3 +1012,15 @@ func getSendBufferLimits(tcpip.StackHandler) tcpip.SendBufferSizeOption {
Max: maxBufferSize,
}
}
+
+// getReceiveBufferLimits implements tcpip.GetReceiveBufferLimits.
+//
+// We define min, max and default values for unix socket implementation. Unix
+// sockets do not use receive buffer.
+func getReceiveBufferLimits(tcpip.StackHandler) tcpip.ReceiveBufferSizeOption {
+ return tcpip.ReceiveBufferSizeOption{
+ Min: minimumBufferSize,
+ Default: defaultBufferSize,
+ Max: maxBufferSize,
+ }
+}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index b22f7973a..db7b1affe 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -192,7 +193,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outPtr, outLen)
}
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index 7890d1048..c39e317ff 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -18,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
@@ -112,7 +113,7 @@ func (s *SocketVFS2) Release(ctx context.Context) {
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outPtr, outLen)
}
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 1b7fd2232..2ebd77f82 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -28,6 +28,7 @@ go_library(
"//pkg/binary",
"//pkg/bits",
"//pkg/eventchannel",
+ "//pkg/hostarch",
"//pkg/marshal/primitive",
"//pkg/seccomp",
"//pkg/sentry/arch",
@@ -35,7 +36,6 @@ go_library(
"//pkg/sentry/socket",
"//pkg/sentry/socket/netlink",
"//pkg/sentry/syscalls/linux",
- "//pkg/usermem",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/strace/epoll.go b/pkg/sentry/strace/epoll.go
index ae3b998c8..48650e3f9 100644
--- a/pkg/sentry/strace/epoll.go
+++ b/pkg/sentry/strace/epoll.go
@@ -21,10 +21,11 @@ import (
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
-func epollEvent(t *kernel.Task, eventAddr usermem.Addr) string {
+func epollEvent(t *kernel.Task, eventAddr hostarch.Addr) string {
var e linux.EpollEvent
if _, err := e.CopyIn(t, eventAddr); err != nil {
return fmt.Sprintf("%#x {error reading event: %v}", eventAddr, err)
@@ -35,7 +36,7 @@ func epollEvent(t *kernel.Task, eventAddr usermem.Addr) string {
return sb.String()
}
-func epollEvents(t *kernel.Task, eventsAddr usermem.Addr, numEvents, maxBytes uint64) string {
+func epollEvents(t *kernel.Task, eventsAddr hostarch.Addr, numEvents, maxBytes uint64) string {
var sb strings.Builder
fmt.Fprintf(&sb, "%#x {", eventsAddr)
addr := eventsAddr
diff --git a/pkg/sentry/strace/poll.go b/pkg/sentry/strace/poll.go
index 074e80f9b..572a8b50b 100644
--- a/pkg/sentry/strace/poll.go
+++ b/pkg/sentry/strace/poll.go
@@ -22,7 +22,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// PollEventSet is the set of poll(2) event flags.
@@ -52,7 +53,7 @@ func pollFD(t *kernel.Task, pfd *linux.PollFD, post bool) string {
return fmt.Sprintf("{FD: %s, Events: %s, REvents: %s}", fd(t, pfd.FD), PollEventSet.Parse(uint64(pfd.Events)), revents)
}
-func pollFDs(t *kernel.Task, addr usermem.Addr, nfds uint, post bool) string {
+func pollFDs(t *kernel.Task, addr hostarch.Addr, nfds uint, post bool) string {
if addr == 0 {
return "null"
}
diff --git a/pkg/sentry/strace/select.go b/pkg/sentry/strace/select.go
index 3a4c32aa0..e6e928157 100644
--- a/pkg/sentry/strace/select.go
+++ b/pkg/sentry/strace/select.go
@@ -19,7 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
func fdsFromSet(t *kernel.Task, set []byte) []int {
@@ -35,7 +36,7 @@ func fdsFromSet(t *kernel.Task, set []byte) []int {
return fds
}
-func fdSet(t *kernel.Task, nfds int, addr usermem.Addr) string {
+func fdSet(t *kernel.Task, nfds int, addr hostarch.Addr) string {
if nfds < 0 {
return fmt.Sprintf("%#x (negative nfds)", addr)
}
diff --git a/pkg/sentry/strace/signal.go b/pkg/sentry/strace/signal.go
index c41f36e3f..e5b379a20 100644
--- a/pkg/sentry/strace/signal.go
+++ b/pkg/sentry/strace/signal.go
@@ -21,7 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// signalNames contains the names of all named signals.
@@ -100,7 +101,7 @@ var sigActionFlags = abi.FlagSet{
},
}
-func sigSet(t *kernel.Task, addr usermem.Addr) string {
+func sigSet(t *kernel.Task, addr hostarch.Addr) string {
if addr == 0 {
return "null"
}
@@ -110,7 +111,7 @@ func sigSet(t *kernel.Task, addr usermem.Addr) string {
return fmt.Sprintf("%#x (error copying sigset: %v)", addr, err)
}
- set := linux.SignalSet(usermem.ByteOrder.Uint64(b[:]))
+ set := linux.SignalSet(hostarch.ByteOrder.Uint64(b[:]))
return fmt.Sprintf("%#x %s", addr, formatSigSet(set))
}
@@ -124,7 +125,7 @@ func formatSigSet(set linux.SignalSet) string {
return fmt.Sprintf("[%v]", strings.Join(signals, " "))
}
-func sigAction(t *kernel.Task, addr usermem.Addr) string {
+func sigAction(t *kernel.Task, addr hostarch.Addr) string {
if addr == 0 {
return "null"
}
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index d943a7cb1..e5b7f9b96 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -26,7 +26,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// SocketFamily are the possible socket(2) families.
@@ -161,7 +162,7 @@ var controlMessageType = map[int32]string{
linux.SO_TIMESTAMP: "SO_TIMESTAMP",
}
-func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64) string {
+func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) string {
if length > maxBytes {
return fmt.Sprintf("%#x (error decoding control: invalid length (%d))", addr, length)
}
@@ -180,7 +181,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64)
}
var h linux.ControlMessageHeader
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], usermem.ByteOrder, &h)
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], hostarch.ByteOrder, &h)
var skipData bool
level := "SOL_SOCKET"
@@ -230,7 +231,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64)
numRights := rightsSize / linux.SizeOfControlMessageRight
fds := make(linux.ControlMessageRights, numRights)
- binary.Unmarshal(buf[i:i+rightsSize], usermem.ByteOrder, &fds)
+ binary.Unmarshal(buf[i:i+rightsSize], hostarch.ByteOrder, &fds)
rights := make([]string, 0, len(fds))
for _, fd := range fds {
@@ -257,7 +258,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64)
}
var creds linux.ControlMessageCredentials
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds)
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], hostarch.ByteOrder, &creds)
strs = append(strs, fmt.Sprintf(
"{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}",
@@ -281,7 +282,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64)
}
var tv linux.Timeval
- binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], usermem.ByteOrder, &tv)
+ binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], hostarch.ByteOrder, &tv)
strs = append(strs, fmt.Sprintf(
"{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}",
@@ -301,7 +302,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64)
return fmt.Sprintf("%#x %s", addr, strings.Join(strs, ", "))
}
-func msghdr(t *kernel.Task, addr usermem.Addr, printContent bool, maxBytes uint64) string {
+func msghdr(t *kernel.Task, addr hostarch.Addr, printContent bool, maxBytes uint64) string {
var msg slinux.MessageHeader64
if _, err := msg.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error decoding msghdr: %v)", addr, err)
@@ -311,17 +312,17 @@ func msghdr(t *kernel.Task, addr usermem.Addr, printContent bool, maxBytes uint6
addr,
msg.Name,
msg.NameLen,
- iovecs(t, usermem.Addr(msg.Iov), int(msg.IovLen), printContent, maxBytes),
+ iovecs(t, hostarch.Addr(msg.Iov), int(msg.IovLen), printContent, maxBytes),
)
if printContent {
- s = fmt.Sprintf("%s, control={%s}", s, cmsghdr(t, usermem.Addr(msg.Control), msg.ControlLen, maxBytes))
+ s = fmt.Sprintf("%s, control={%s}", s, cmsghdr(t, hostarch.Addr(msg.Control), msg.ControlLen, maxBytes))
} else {
s = fmt.Sprintf("%s, control=%#x, control_len=%d", s, msg.Control, msg.ControlLen)
}
return fmt.Sprintf("%s, flags=%d}", s, msg.Flags)
}
-func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string {
+func sockAddr(t *kernel.Task, addr hostarch.Addr, length uint32) string {
if addr == 0 {
return "null"
}
@@ -335,7 +336,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string {
if len(b) < 2 {
return fmt.Sprintf("%#x {address too short: %d bytes}", addr, len(b))
}
- family := usermem.ByteOrder.Uint16(b)
+ family := hostarch.ByteOrder.Uint16(b)
familyStr := SocketFamily.Parse(uint64(family))
@@ -362,7 +363,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string {
}
}
-func postSockAddr(t *kernel.Task, addr usermem.Addr, lengthPtr usermem.Addr) string {
+func postSockAddr(t *kernel.Task, addr hostarch.Addr, lengthPtr hostarch.Addr) string {
if addr == 0 {
return "null"
}
@@ -379,14 +380,14 @@ func postSockAddr(t *kernel.Task, addr usermem.Addr, lengthPtr usermem.Addr) str
return sockAddr(t, addr, l)
}
-func copySockLen(t *kernel.Task, addr usermem.Addr) (uint32, error) {
+func copySockLen(t *kernel.Task, addr hostarch.Addr) (uint32, error) {
// socklen_t is 32-bits.
var l primitive.Uint32
_, err := l.CopyIn(t, addr)
return uint32(l), err
}
-func sockLenPointer(t *kernel.Task, addr usermem.Addr) string {
+func sockLenPointer(t *kernel.Task, addr hostarch.Addr) string {
if addr == 0 {
return "null"
}
@@ -420,7 +421,7 @@ func sockFlags(flags int32) string {
return SocketFlagSet.Parse(uint64(flags))
}
-func getSockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, optLen usermem.Addr, maximumBlobSize uint, rval uintptr) string {
+func getSockOptVal(t *kernel.Task, level, optname uint64, optVal hostarch.Addr, optLen hostarch.Addr, maximumBlobSize uint, rval uintptr) string {
if int(rval) < 0 {
return hexNum(uint64(optVal))
}
@@ -434,7 +435,7 @@ func getSockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, o
return sockOptVal(t, level, optname, optVal, uint64(l), maximumBlobSize)
}
-func sockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, optLen uint64, maximumBlobSize uint) string {
+func sockOptVal(t *kernel.Task, level, optname uint64, optVal hostarch.Addr, optLen uint64, maximumBlobSize uint) string {
switch optLen {
case 1:
var v primitive.Uint8
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
index 396744597..ec5d5f846 100644
--- a/pkg/sentry/strace/strace.go
+++ b/pkg/sentry/strace/strace.go
@@ -32,7 +32,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
pb "gvisor.dev/gvisor/pkg/sentry/strace/strace_go_proto"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// DefaultLogMaximumSize is the default LogMaximumSize.
@@ -62,7 +63,7 @@ func hexArg(arg arch.SyscallArgument) string {
return hexNum(arg.Uint64())
}
-func iovecs(t *kernel.Task, addr usermem.Addr, iovcnt int, printContent bool, maxBytes uint64) string {
+func iovecs(t *kernel.Task, addr hostarch.Addr, iovcnt int, printContent bool, maxBytes uint64) string {
if iovcnt < 0 || iovcnt > linux.UIO_MAXIOV {
return fmt.Sprintf("%#x (error decoding iovecs: invalid iovcnt)", addr)
}
@@ -107,7 +108,7 @@ func iovecs(t *kernel.Task, addr usermem.Addr, iovcnt int, printContent bool, ma
return fmt.Sprintf("%#x %s", addr, strings.Join(iovs, ", "))
}
-func dump(t *kernel.Task, addr usermem.Addr, size uint, maximumBlobSize uint) string {
+func dump(t *kernel.Task, addr hostarch.Addr, size uint, maximumBlobSize uint) string {
origSize := size
if size > maximumBlobSize {
size = maximumBlobSize
@@ -131,7 +132,7 @@ func dump(t *kernel.Task, addr usermem.Addr, size uint, maximumBlobSize uint) st
return fmt.Sprintf("%#x %q%s", addr, b[:amt], dot)
}
-func path(t *kernel.Task, addr usermem.Addr) string {
+func path(t *kernel.Task, addr hostarch.Addr) string {
path, err := t.CopyInString(addr, linux.PATH_MAX)
if err != nil {
return fmt.Sprintf("%#x (error decoding path: %s)", addr, err)
@@ -196,7 +197,7 @@ func fdVFS2(t *kernel.Task, fd int32) string {
return fmt.Sprintf("%#x %s", fd, name)
}
-func fdpair(t *kernel.Task, addr usermem.Addr) string {
+func fdpair(t *kernel.Task, addr hostarch.Addr) string {
var fds [2]int32
_, err := primitive.CopyInt32SliceIn(t, addr, fds[:])
if err != nil {
@@ -206,7 +207,7 @@ func fdpair(t *kernel.Task, addr usermem.Addr) string {
return fmt.Sprintf("%#x [%d %d]", addr, fds[0], fds[1])
}
-func uname(t *kernel.Task, addr usermem.Addr) string {
+func uname(t *kernel.Task, addr hostarch.Addr) string {
var u linux.UtsName
if _, err := u.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error decoding utsname: %s)", addr, err)
@@ -215,7 +216,7 @@ func uname(t *kernel.Task, addr usermem.Addr) string {
return fmt.Sprintf("%#x %s", addr, u)
}
-func utimensTimespec(t *kernel.Task, addr usermem.Addr) string {
+func utimensTimespec(t *kernel.Task, addr hostarch.Addr) string {
if addr == 0 {
return "null"
}
@@ -237,7 +238,7 @@ func utimensTimespec(t *kernel.Task, addr usermem.Addr) string {
return fmt.Sprintf("%#x {sec=%v nsec=%s}", addr, tim.Sec, ns)
}
-func timespec(t *kernel.Task, addr usermem.Addr) string {
+func timespec(t *kernel.Task, addr hostarch.Addr) string {
if addr == 0 {
return "null"
}
@@ -249,7 +250,7 @@ func timespec(t *kernel.Task, addr usermem.Addr) string {
return fmt.Sprintf("%#x {sec=%v nsec=%v}", addr, tim.Sec, tim.Nsec)
}
-func timeval(t *kernel.Task, addr usermem.Addr) string {
+func timeval(t *kernel.Task, addr hostarch.Addr) string {
if addr == 0 {
return "null"
}
@@ -262,7 +263,7 @@ func timeval(t *kernel.Task, addr usermem.Addr) string {
return fmt.Sprintf("%#x {sec=%v usec=%v}", addr, tim.Sec, tim.Usec)
}
-func utimbuf(t *kernel.Task, addr usermem.Addr) string {
+func utimbuf(t *kernel.Task, addr hostarch.Addr) string {
if addr == 0 {
return "null"
}
@@ -275,7 +276,7 @@ func utimbuf(t *kernel.Task, addr usermem.Addr) string {
return fmt.Sprintf("%#x {actime=%v, modtime=%v}", addr, utim.Actime, utim.Modtime)
}
-func stat(t *kernel.Task, addr usermem.Addr) string {
+func stat(t *kernel.Task, addr hostarch.Addr) string {
if addr == 0 {
return "null"
}
@@ -287,27 +288,27 @@ func stat(t *kernel.Task, addr usermem.Addr) string {
return fmt.Sprintf("%#x {dev=%d, ino=%d, mode=%s, nlink=%d, uid=%d, gid=%d, rdev=%d, size=%d, blksize=%d, blocks=%d, atime=%s, mtime=%s, ctime=%s}", addr, stat.Dev, stat.Ino, linux.FileMode(stat.Mode), stat.Nlink, stat.UID, stat.GID, stat.Rdev, stat.Size, stat.Blksize, stat.Blocks, time.Unix(stat.ATime.Sec, stat.ATime.Nsec), time.Unix(stat.MTime.Sec, stat.MTime.Nsec), time.Unix(stat.CTime.Sec, stat.CTime.Nsec))
}
-func itimerval(t *kernel.Task, addr usermem.Addr) string {
+func itimerval(t *kernel.Task, addr hostarch.Addr) string {
if addr == 0 {
return "null"
}
interval := timeval(t, addr)
- value := timeval(t, addr+usermem.Addr((*linux.Timeval)(nil).SizeBytes()))
+ value := timeval(t, addr+hostarch.Addr((*linux.Timeval)(nil).SizeBytes()))
return fmt.Sprintf("%#x {interval=%s, value=%s}", addr, interval, value)
}
-func itimerspec(t *kernel.Task, addr usermem.Addr) string {
+func itimerspec(t *kernel.Task, addr hostarch.Addr) string {
if addr == 0 {
return "null"
}
interval := timespec(t, addr)
- value := timespec(t, addr+usermem.Addr((*linux.Timespec)(nil).SizeBytes()))
+ value := timespec(t, addr+hostarch.Addr((*linux.Timespec)(nil).SizeBytes()))
return fmt.Sprintf("%#x {interval=%s, value=%s}", addr, interval, value)
}
-func stringVector(t *kernel.Task, addr usermem.Addr) string {
+func stringVector(t *kernel.Task, addr hostarch.Addr) string {
vec, err := t.CopyInVector(addr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize)
if err != nil {
return fmt.Sprintf("%#x {error copying vector: %v}", addr, err)
@@ -323,7 +324,7 @@ func stringVector(t *kernel.Task, addr usermem.Addr) string {
return s
}
-func rusage(t *kernel.Task, addr usermem.Addr) string {
+func rusage(t *kernel.Task, addr hostarch.Addr) string {
if addr == 0 {
return "null"
}
@@ -335,7 +336,7 @@ func rusage(t *kernel.Task, addr usermem.Addr) string {
return fmt.Sprintf("%#x %+v", addr, ru)
}
-func capHeader(t *kernel.Task, addr usermem.Addr) string {
+func capHeader(t *kernel.Task, addr hostarch.Addr) string {
if addr == 0 {
return "null"
}
@@ -360,7 +361,7 @@ func capHeader(t *kernel.Task, addr usermem.Addr) string {
return fmt.Sprintf("%#x {Version: %s, Pid: %d}", addr, version, hdr.Pid)
}
-func capData(t *kernel.Task, hdrAddr, dataAddr usermem.Addr) string {
+func capData(t *kernel.Task, hdrAddr, dataAddr hostarch.Addr) string {
if dataAddr == 0 {
return "null"
}
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index 3dcf36a96..408a6c422 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -64,6 +64,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/bpf",
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/marshal",
"//pkg/marshal/primitive",
diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go
index efec93f73..37121186a 100644
--- a/pkg/sentry/syscalls/linux/error.go
+++ b/pkg/sentry/syscalls/linux/error.go
@@ -33,6 +33,14 @@ var (
partialResultOnce sync.Once
)
+// incrementPartialResultMetric increments PartialResultMetric by calling
+// Increment(). This is added as the func Do() which is called below requires
+// us to pass a function which does not take any arguments, whereas Increment()
+// takes a variadic number of arguments.
+func incrementPartialResultMetric() {
+ partialResultMetric.Increment()
+}
+
// HandleIOErrorVFS2 handles special error cases for partial results. For some
// errors, we may consume the error and return only the partial read/write.
//
@@ -48,7 +56,7 @@ func HandleIOErrorVFS2(ctx context.Context, partialResult bool, ioerr, intr erro
root := vfs.RootFromContext(ctx)
name, _ := fs.PathnameWithDeleted(ctx, root, f.VirtualDentry())
log.Traceback("Invalid request partialResult %v and err (type %T) %v for %s operation on %q", partialResult, ioerr, ioerr, op, name)
- partialResultOnce.Do(partialResultMetric.Increment)
+ partialResultOnce.Do(incrementPartialResultMetric)
}
return nil
}
@@ -66,7 +74,7 @@ func handleIOError(ctx context.Context, partialResult bool, ioerr, intr error, o
// An unknown error is encountered with a partial read/write.
name, _ := f.Dirent.FullName(nil /* ignore chroot */)
log.Traceback("Invalid request partialResult %v and err (type %T) %v for %s operation on %q, %T", partialResult, ioerr, ioerr, op, name, f.FileOperations)
- partialResultOnce.Do(partialResultMetric.Increment)
+ partialResultOnce.Do(incrementPartialResultMetric)
}
return nil
}
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index ac53a0c0e..2d2212605 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -18,11 +18,11 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/syscalls"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
const (
@@ -405,7 +405,7 @@ var AMD64 = &kernel.SyscallTable{
434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil),
435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil),
},
- Emulate: map[usermem.Addr]uintptr{
+ Emulate: map[hostarch.Addr]uintptr{
0xffffffffff600000: 96, // vsyscall gettimeofday(2)
0xffffffffff600400: 201, // vsyscall time(2)
0xffffffffff600800: 309, // vsyscall getcpu(2)
@@ -723,7 +723,7 @@ var ARM64 = &kernel.SyscallTable{
434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil),
435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil),
},
- Emulate: map[usermem.Addr]uintptr{},
+ Emulate: map[hostarch.Addr]uintptr{},
Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
t.Kernel().EmitUnimplementedEvent(t)
return 0, syserror.ENOSYS
diff --git a/pkg/sentry/syscalls/linux/sigset.go b/pkg/sentry/syscalls/linux/sigset.go
index 434559b80..e8c2d8f9e 100644
--- a/pkg/sentry/syscalls/linux/sigset.go
+++ b/pkg/sentry/syscalls/linux/sigset.go
@@ -16,9 +16,9 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// CopyInSigSet copies in a sigset_t, checks its size, and ensures that KILL and
@@ -27,7 +27,7 @@ import (
// TODO(gvisor.dev/issue/1624): This is only exported because
// syscalls/vfs2/signal.go depends on it. Once vfs1 is deleted and the vfs2
// syscalls are moved into this package, then they can be unexported.
-func CopyInSigSet(t *kernel.Task, sigSetAddr usermem.Addr, size uint) (linux.SignalSet, error) {
+func CopyInSigSet(t *kernel.Task, sigSetAddr hostarch.Addr, size uint) (linux.SignalSet, error) {
if size != linux.SignalSetSize {
return 0, syserror.EINVAL
}
@@ -35,14 +35,14 @@ func CopyInSigSet(t *kernel.Task, sigSetAddr usermem.Addr, size uint) (linux.Sig
if _, err := t.CopyInBytes(sigSetAddr, b); err != nil {
return 0, err
}
- mask := usermem.ByteOrder.Uint64(b[:])
+ mask := hostarch.ByteOrder.Uint64(b[:])
return linux.SignalSet(mask) &^ kernel.UnblockableSignals, nil
}
// copyOutSigSet copies out a sigset_t.
-func copyOutSigSet(t *kernel.Task, sigSetAddr usermem.Addr, mask linux.SignalSet) error {
+func copyOutSigSet(t *kernel.Task, sigSetAddr hostarch.Addr, mask linux.SignalSet) error {
b := t.CopyScratchBuffer(8)
- usermem.ByteOrder.PutUint64(b, uint64(mask))
+ hostarch.ByteOrder.PutUint64(b, uint64(mask))
_, err := t.CopyOutBytes(sigSetAddr, b)
return err
}
@@ -55,15 +55,15 @@ func copyOutSigSet(t *kernel.Task, sigSetAddr usermem.Addr, mask linux.SignalSet
// };
//
// and returns sigset_addr and size.
-func copyInSigSetWithSize(t *kernel.Task, addr usermem.Addr) (usermem.Addr, uint, error) {
+func copyInSigSetWithSize(t *kernel.Task, addr hostarch.Addr) (hostarch.Addr, uint, error) {
switch t.Arch().Width() {
case 8:
in := t.CopyScratchBuffer(16)
if _, err := t.CopyInBytes(addr, in); err != nil {
return 0, 0, err
}
- maskAddr := usermem.Addr(usermem.ByteOrder.Uint64(in[0:]))
- maskSize := uint(usermem.ByteOrder.Uint64(in[8:]))
+ maskAddr := hostarch.Addr(hostarch.ByteOrder.Uint64(in[0:]))
+ maskSize := uint(hostarch.ByteOrder.Uint64(in[8:]))
return maskAddr, maskSize, nil
default:
return 0, 0, syserror.ENOSYS
diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go
index c2285f796..70e8569a8 100644
--- a/pkg/sentry/syscalls/linux/sys_aio.go
+++ b/pkg/sentry/syscalls/linux/sys_aio.go
@@ -17,6 +17,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -152,7 +153,7 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
}
// Keep rolling.
- eventsAddr += usermem.Addr(linux.IOEventSize)
+ eventsAddr += hostarch.Addr(linux.IOEventSize)
}
// Everything finished.
@@ -191,12 +192,12 @@ func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error)
// I/O.
switch cb.OpCode {
case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PWRITE:
- return t.SingleIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{
+ return t.SingleIOSequence(hostarch.Addr(cb.Buf), bytes, usermem.IOOpts{
AddressSpaceActive: false,
})
case linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITEV:
- return t.IovecsIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{
+ return t.IovecsIOSequence(hostarch.Addr(cb.Buf), bytes, usermem.IOOpts{
AddressSpaceActive: false,
})
@@ -219,7 +220,7 @@ func IoCancel(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// LINT.IfChange
-func getAIOCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *linux.IOCallback, ioseq usermem.IOSequence, actx *mm.AIOContext, eventFile *fs.File) kernel.AIOCallback {
+func getAIOCallback(t *kernel.Task, file *fs.File, cbAddr hostarch.Addr, cb *linux.IOCallback, ioseq usermem.IOSequence, actx *mm.AIOContext, eventFile *fs.File) kernel.AIOCallback {
return func(ctx context.Context) {
if actx.Dead() {
actx.CancelPendingRequest()
@@ -264,7 +265,7 @@ func getAIOCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *linu
}
// submitCallback processes a single callback.
-func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr usermem.Addr) error {
+func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr hostarch.Addr) error {
file := t.GetFile(cb.FD)
if file == nil {
// File not found.
@@ -339,7 +340,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
for i := int32(0); i < nrEvents; i++ {
// Copy in the callback address.
- var cbAddr usermem.Addr
+ var cbAddr hostarch.Addr
switch t.Arch().Width() {
case 8:
var cbAddrP primitive.Uint64
@@ -351,7 +352,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Nothing done.
return 0, nil, err
}
- cbAddr = usermem.Addr(cbAddrP)
+ cbAddr = hostarch.Addr(cbAddrP)
default:
return 0, nil, syserror.ENOSYS
}
@@ -379,7 +380,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
// Advance to the next one.
- addr += usermem.Addr(t.Arch().Width())
+ addr += hostarch.Addr(t.Arch().Width())
}
return uintptr(nrEvents), nil, nil
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index fd9649340..9cd238efd 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -18,6 +18,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -29,7 +30,6 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// fileOpAt performs an operation on the second last component in the path.
@@ -115,7 +115,7 @@ func fileOpOn(t *kernel.Task, dirFD int32, path string, resolve bool, fn func(ro
}
// copyInPath copies a path in.
-func copyInPath(t *kernel.Task, addr usermem.Addr, allowEmpty bool) (path string, dirPath bool, err error) {
+func copyInPath(t *kernel.Task, addr hostarch.Addr, allowEmpty bool) (path string, dirPath bool, err error) {
path, err = t.CopyInString(addr, linux.PATH_MAX)
if err != nil {
return "", false, err
@@ -133,7 +133,7 @@ func copyInPath(t *kernel.Task, addr usermem.Addr, allowEmpty bool) (path string
// LINT.IfChange
-func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uintptr, err error) {
+func openAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, flags uint) (fd uintptr, err error) {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
return 0, err
@@ -208,7 +208,7 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint
return fd, err // Use result in frame.
}
-func mknodAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode) error {
+func mknodAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, mode linux.FileMode) error {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
return err
@@ -301,7 +301,7 @@ func Mknodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
return 0, nil, mknodAt(t, dirFD, path, mode)
}
-func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode linux.FileMode) (fd uintptr, err error) {
+func createAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, flags uint, mode linux.FileMode) (fd uintptr, err error) {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
return 0, err
@@ -515,7 +515,7 @@ func (ac accessContext) Value(key interface{}) interface{} {
}
}
-func accessAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode uint) error {
+func accessAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, mode uint) error {
const rOK = 4
const wOK = 2
const xOK = 1
@@ -694,7 +694,7 @@ func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
// Top it off with a terminator.
- _, err = t.CopyOutBytes(addr+usermem.Addr(bytes), []byte("\x00"))
+ _, err = t.CopyOutBytes(addr+hostarch.Addr(bytes), []byte("\x00"))
return uintptr(bytes + 1), nil, err
}
@@ -1164,7 +1164,7 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
// LINT.IfChange
-func mkdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode) error {
+func mkdirAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, mode linux.FileMode) error {
path, _, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
return err
@@ -1216,7 +1216,7 @@ func Mkdirat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
return 0, nil, mkdirAt(t, dirFD, addr, mode)
}
-func rmdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
+func rmdirAt(t *kernel.Task, dirFD int32, addr hostarch.Addr) error {
path, _, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
return err
@@ -1256,7 +1256,7 @@ func Rmdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, rmdirAt(t, linux.AT_FDCWD, addr)
}
-func symlinkAt(t *kernel.Task, dirFD int32, newAddr usermem.Addr, oldAddr usermem.Addr) error {
+func symlinkAt(t *kernel.Task, dirFD int32, newAddr hostarch.Addr, oldAddr hostarch.Addr) error {
newPath, dirPath, err := copyInPath(t, newAddr, false /* allowEmpty */)
if err != nil {
return err
@@ -1341,7 +1341,7 @@ func mayLinkAt(t *kernel.Task, target *fs.Inode) error {
// linkAt creates a hard link to the target specified by oldDirFD and oldAddr,
// specified by newDirFD and newAddr. If resolve is true, then the symlinks
// will be followed when evaluating the target.
-func linkAt(t *kernel.Task, oldDirFD int32, oldAddr usermem.Addr, newDirFD int32, newAddr usermem.Addr, resolve, allowEmpty bool) error {
+func linkAt(t *kernel.Task, oldDirFD int32, oldAddr hostarch.Addr, newDirFD int32, newAddr hostarch.Addr, resolve, allowEmpty bool) error {
oldPath, _, err := copyInPath(t, oldAddr, allowEmpty)
if err != nil {
return err
@@ -1448,7 +1448,7 @@ func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// LINT.IfChange
-func readlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr, bufAddr usermem.Addr, size uint) (copied uintptr, err error) {
+func readlinkAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, bufAddr hostarch.Addr, size uint) (copied uintptr, err error) {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
return 0, err
@@ -1511,7 +1511,7 @@ func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// LINT.IfChange
-func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
+func unlinkAt(t *kernel.Task, dirFD int32, addr hostarch.Addr) error {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
return err
@@ -1728,7 +1728,7 @@ func chown(t *kernel.Task, d *fs.Dirent, uid auth.UID, gid auth.GID) error {
return nil
}
-func chownAt(t *kernel.Task, fd int32, addr usermem.Addr, resolve, allowEmpty bool, uid auth.UID, gid auth.GID) error {
+func chownAt(t *kernel.Task, fd int32, addr hostarch.Addr, resolve, allowEmpty bool, uid auth.UID, gid auth.GID) error {
path, _, err := copyInPath(t, addr, allowEmpty)
if err != nil {
return err
@@ -1815,7 +1815,7 @@ func chmod(t *kernel.Task, d *fs.Dirent, mode linux.FileMode) error {
return nil
}
-func chmodAt(t *kernel.Task, fd int32, addr usermem.Addr, mode linux.FileMode) error {
+func chmodAt(t *kernel.Task, fd int32, addr hostarch.Addr, mode linux.FileMode) error {
path, _, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
return err
@@ -1866,7 +1866,7 @@ func defaultSetToSystemTimeSpec() fs.TimeSpec {
}
}
-func utimes(t *kernel.Task, dirFD int32, addr usermem.Addr, ts fs.TimeSpec, resolve bool) error {
+func utimes(t *kernel.Task, dirFD int32, addr hostarch.Addr, ts fs.TimeSpec, resolve bool) error {
setTimestamp := func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
// Does the task own the file?
if !d.Inode.CheckOwnership(t) {
@@ -2030,7 +2030,7 @@ func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
// LINT.IfChange
-func renameAt(t *kernel.Task, oldDirFD int32, oldAddr usermem.Addr, newDirFD int32, newAddr usermem.Addr) error {
+func renameAt(t *kernel.Task, oldDirFD int32, oldAddr hostarch.Addr, newDirFD int32, newAddr hostarch.Addr) error {
newPath, _, err := copyInPath(t, newAddr, false /* allowEmpty */)
if err != nil {
return err
diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go
index f39ce0639..eeea1613b 100644
--- a/pkg/sentry/syscalls/linux/sys_futex.go
+++ b/pkg/sentry/syscalls/linux/sys_futex.go
@@ -18,11 +18,11 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// futexWaitRestartBlock encapsulates the state required to restart futex(2)
@@ -41,7 +41,7 @@ type futexWaitRestartBlock struct {
// Restart implements kernel.SyscallRestartBlock.Restart.
func (f *futexWaitRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
- return futexWaitDuration(t, f.duration, false, usermem.Addr(f.addr), f.private, f.val, f.mask)
+ return futexWaitDuration(t, f.duration, false, hostarch.Addr(f.addr), f.private, f.val, f.mask)
}
// futexWaitAbsolute performs a FUTEX_WAIT_BITSET, blocking until the wait is
@@ -51,7 +51,7 @@ func (f *futexWaitRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
//
// If blocking is interrupted, the syscall is restarted with the original
// arguments.
-func futexWaitAbsolute(t *kernel.Task, clockRealtime bool, ts linux.Timespec, forever bool, addr usermem.Addr, private bool, val, mask uint32) (uintptr, error) {
+func futexWaitAbsolute(t *kernel.Task, clockRealtime bool, ts linux.Timespec, forever bool, addr hostarch.Addr, private bool, val, mask uint32) (uintptr, error) {
w := t.FutexWaiter()
err := t.Futex().WaitPrepare(w, t, addr, private, val, mask)
if err != nil {
@@ -87,7 +87,7 @@ func futexWaitAbsolute(t *kernel.Task, clockRealtime bool, ts linux.Timespec, fo
// syscall. If forever is true, the syscall is restarted with the original
// arguments. If forever is false, duration is a relative timeout and the
// syscall is restarted with the remaining timeout.
-func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, addr usermem.Addr, private bool, val, mask uint32) (uintptr, error) {
+func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, addr hostarch.Addr, private bool, val, mask uint32) (uintptr, error) {
w := t.FutexWaiter()
err := t.Futex().WaitPrepare(w, t, addr, private, val, mask)
if err != nil {
@@ -124,7 +124,7 @@ func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, add
return 0, syserror.ERESTART_RESTARTBLOCK
}
-func futexLockPI(t *kernel.Task, ts linux.Timespec, forever bool, addr usermem.Addr, private bool) error {
+func futexLockPI(t *kernel.Task, ts linux.Timespec, forever bool, addr hostarch.Addr, private bool) error {
w := t.FutexWaiter()
locked, err := t.Futex().LockPI(w, t, addr, uint32(t.ThreadID()), private, false)
if err != nil {
@@ -152,7 +152,7 @@ func futexLockPI(t *kernel.Task, ts linux.Timespec, forever bool, addr usermem.A
return syserror.ConvertIntr(err, syserror.ERESTARTSYS)
}
-func tryLockPI(t *kernel.Task, addr usermem.Addr, private bool) error {
+func tryLockPI(t *kernel.Task, addr hostarch.Addr, private bool) error {
w := t.FutexWaiter()
locked, err := t.Futex().LockPI(w, t, addr, uint32(t.ThreadID()), private, true)
if err != nil {
diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go
index b25f7d881..bbba71d8f 100644
--- a/pkg/sentry/syscalls/linux/sys_getdents.go
+++ b/pkg/sentry/syscalls/linux/sys_getdents.go
@@ -19,6 +19,7 @@ import (
"io"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -62,7 +63,7 @@ func Getdents64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// getdents implements the core of getdents(2)/getdents64(2).
// f is the syscall implementation dirent serialization function.
-func getdents(t *kernel.Task, fd int32, addr usermem.Addr, size int, f func(*dirent, io.Writer) (int, error)) (uintptr, error) {
+func getdents(t *kernel.Task, fd int32, addr hostarch.Addr, size int, f func(*dirent, io.Writer) (int, error)) (uintptr, error) {
dir := t.GetFile(fd)
if dir == nil {
return 0, syserror.EBADF
diff --git a/pkg/sentry/syscalls/linux/sys_mempolicy.go b/pkg/sentry/syscalls/linux/sys_mempolicy.go
index 9b4a5c3f1..6d27f4292 100644
--- a/pkg/sentry/syscalls/linux/sys_mempolicy.go
+++ b/pkg/sentry/syscalls/linux/sys_mempolicy.go
@@ -18,6 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
@@ -31,7 +32,7 @@ const (
allowedNodemask = (1 << maxNodes) - 1
)
-func copyInNodemask(t *kernel.Task, addr usermem.Addr, maxnode uint32) (uint64, error) {
+func copyInNodemask(t *kernel.Task, addr hostarch.Addr, maxnode uint32) (uint64, error) {
// "nodemask points to a bit mask of node IDs that contains up to maxnode
// bits. The bit mask size is rounded to the next multiple of
// sizeof(unsigned long), but the kernel will use bits only up to maxnode.
@@ -41,7 +42,7 @@ func copyInNodemask(t *kernel.Task, addr usermem.Addr, maxnode uint32) (uint64,
// because of what appears to be a bug: mm/mempolicy.c:get_nodes() uses
// maxnode-1, not maxnode, as the number of bits.
bits := maxnode - 1
- if bits > usermem.PageSize*8 { // also handles overflow from maxnode == 0
+ if bits > hostarch.PageSize*8 { // also handles overflow from maxnode == 0
return 0, syserror.EINVAL
}
if bits == 0 {
@@ -53,7 +54,7 @@ func copyInNodemask(t *kernel.Task, addr usermem.Addr, maxnode uint32) (uint64,
if _, err := t.CopyInBytes(addr, buf); err != nil {
return 0, err
}
- val := usermem.ByteOrder.Uint64(buf)
+ val := hostarch.ByteOrder.Uint64(buf)
// Check that only allowed bits in the first unsigned long in the nodemask
// are set.
if val&^allowedNodemask != 0 {
@@ -68,11 +69,11 @@ func copyInNodemask(t *kernel.Task, addr usermem.Addr, maxnode uint32) (uint64,
return val, nil
}
-func copyOutNodemask(t *kernel.Task, addr usermem.Addr, maxnode uint32, val uint64) error {
+func copyOutNodemask(t *kernel.Task, addr hostarch.Addr, maxnode uint32, val uint64) error {
// mm/mempolicy.c:copy_nodes_to_user() also uses maxnode-1 as the number of
// bits.
bits := maxnode - 1
- if bits > usermem.PageSize*8 { // also handles overflow from maxnode == 0
+ if bits > hostarch.PageSize*8 { // also handles overflow from maxnode == 0
return syserror.EINVAL
}
if bits == 0 {
@@ -80,7 +81,7 @@ func copyOutNodemask(t *kernel.Task, addr usermem.Addr, maxnode uint32, val uint
}
// Copy out the first unsigned long in the nodemask.
buf := t.CopyScratchBuffer(8)
- usermem.ByteOrder.PutUint64(buf, val)
+ hostarch.ByteOrder.PutUint64(buf, val)
if _, err := t.CopyOutBytes(addr, buf); err != nil {
return err
}
@@ -258,7 +259,7 @@ func Mbind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, err
}
-func copyInMempolicyNodemask(t *kernel.Task, modeWithFlags linux.NumaPolicy, nodemask usermem.Addr, maxnode uint32) (linux.NumaPolicy, uint64, error) {
+func copyInMempolicyNodemask(t *kernel.Task, modeWithFlags linux.NumaPolicy, nodemask hostarch.Addr, maxnode uint32) (linux.NumaPolicy, uint64, error) {
flags := linux.NumaPolicy(modeWithFlags & linux.MPOL_MODE_FLAGS)
mode := linux.NumaPolicy(modeWithFlags &^ linux.MPOL_MODE_FLAGS)
if flags == linux.MPOL_MODE_FLAGS {
diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go
index cd8dfdfa4..70da0707d 100644
--- a/pkg/sentry/syscalls/linux/sys_mmap.go
+++ b/pkg/sentry/syscalls/linux/sys_mmap.go
@@ -23,7 +23,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// Brk implements linux syscall brk(2).
@@ -61,12 +62,12 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
Unmap: fixed,
Map32Bit: map32bit,
Private: private,
- Perms: usermem.AccessType{
+ Perms: hostarch.AccessType{
Read: linux.PROT_READ&prot != 0,
Write: linux.PROT_WRITE&prot != 0,
Execute: linux.PROT_EXEC&prot != 0,
},
- MaxPerms: usermem.AnyAccess,
+ MaxPerms: hostarch.AnyAccess,
GrowsDown: linux.MAP_GROWSDOWN&flags != 0,
Precommit: linux.MAP_POPULATE&flags != 0,
}
@@ -160,7 +161,7 @@ func Mremap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
func Mprotect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
length := args[1].Uint64()
prot := args[2].Int()
- err := t.MemoryManager().MProtect(args[0].Pointer(), length, usermem.AccessType{
+ err := t.MemoryManager().MProtect(args[0].Pointer(), length, hostarch.AccessType{
Read: linux.PROT_READ&prot != 0,
Write: linux.PROT_WRITE&prot != 0,
Execute: linux.PROT_EXEC&prot != 0,
@@ -183,7 +184,7 @@ func Madvise(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
return 0, nil, nil
}
// Not explicitly stated: length need not be page-aligned.
- lenAddr, ok := usermem.Addr(length).RoundUp()
+ lenAddr, ok := hostarch.Addr(length).RoundUp()
if !ok {
return 0, nil, syserror.EINVAL
}
@@ -232,7 +233,7 @@ func Mincore(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// "The length argument need not be a multiple of the page size, but since
// residency information is returned for whole pages, length is effectively
// rounded up to the next multiple of the page size." - mincore(2)
- la, ok := usermem.Addr(length).RoundUp()
+ la, ok := hostarch.Addr(length).RoundUp()
if !ok {
return 0, nil, syserror.ENOMEM
}
@@ -247,7 +248,7 @@ func Mincore(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if mapped != uint64(la) {
return 0, nil, syserror.ENOMEM
}
- resident := bytes.Repeat([]byte{1}, int(mapped/usermem.PageSize))
+ resident := bytes.Repeat([]byte{1}, int(mapped/hostarch.PageSize))
_, err := t.CopyOutBytes(vec, resident)
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_mount.go b/pkg/sentry/syscalls/linux/sys_mount.go
index bd0633564..864d2138c 100644
--- a/pkg/sentry/syscalls/linux/sys_mount.go
+++ b/pkg/sentry/syscalls/linux/sys_mount.go
@@ -20,7 +20,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// Mount implements Linux syscall mount(2).
@@ -31,7 +32,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
flags := args[3].Uint64()
dataAddr := args[4].Pointer()
- fsType, err := t.CopyInString(typeAddr, usermem.PageSize)
+ fsType, err := t.CopyInString(typeAddr, hostarch.PageSize)
if err != nil {
return 0, nil, err
}
@@ -52,7 +53,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// character placement, and the address is passed to each file system.
// Most file systems always treat this data as a string, though, and so
// do all of the ones we implement.
- data, err = t.CopyInString(dataAddr, usermem.PageSize)
+ data, err = t.CopyInString(dataAddr, hostarch.PageSize)
if err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go
index f7135ea46..d95034347 100644
--- a/pkg/sentry/syscalls/linux/sys_pipe.go
+++ b/pkg/sentry/syscalls/linux/sys_pipe.go
@@ -16,19 +16,19 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// LINT.IfChange
// pipe2 implements the actual system call with flags.
-func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) {
+func pipe2(t *kernel.Task, addr hostarch.Addr, flags uint) (uintptr, error) {
if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 {
return 0, syserror.EINVAL
}
diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go
index 254f4c9f9..da548a14a 100644
--- a/pkg/sentry/syscalls/linux/sys_poll.go
+++ b/pkg/sentry/syscalls/linux/sys_poll.go
@@ -18,13 +18,13 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -155,7 +155,7 @@ func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.
}
// CopyInPollFDs copies an array of struct pollfd unless nfds exceeds the max.
-func CopyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD, error) {
+func CopyInPollFDs(t *kernel.Task, addr hostarch.Addr, nfds uint) ([]linux.PollFD, error) {
if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) {
return nil, syserror.EINVAL
}
@@ -170,7 +170,7 @@ func CopyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD
return pfd, nil
}
-func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) {
+func doPoll(t *kernel.Task, addr hostarch.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) {
pfd, err := CopyInPollFDs(t, addr, nfds)
if err != nil {
return timeout, 0, err
@@ -198,7 +198,7 @@ func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration)
}
// CopyInFDSet copies an fd set from select(2)/pselect(2).
-func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) {
+func CopyInFDSet(t *kernel.Task, addr hostarch.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) {
set := make([]byte, nBytes)
if addr != 0 {
@@ -215,7 +215,7 @@ func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialBy
return set, nil
}
-func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) {
+func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs hostarch.Addr, timeout time.Duration) (uintptr, error) {
if nfds < 0 || nfds > fileCap {
return 0, syserror.EINVAL
}
@@ -365,7 +365,7 @@ func timeoutRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration)
// copyOutTimespecRemaining copies the time remaining in timeout to timespecAddr.
//
// startNs must be from CLOCK_MONOTONIC.
-func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr usermem.Addr) error {
+func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr hostarch.Addr) error {
if timeout <= 0 {
return nil
}
@@ -377,7 +377,7 @@ func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.D
// copyOutTimevalRemaining copies the time remaining in timeout to timevalAddr.
//
// startNs must be from CLOCK_MONOTONIC.
-func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr usermem.Addr) error {
+func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr hostarch.Addr) error {
if timeout <= 0 {
return nil
}
@@ -391,7 +391,7 @@ func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Du
//
// +stateify savable
type pollRestartBlock struct {
- pfdAddr usermem.Addr
+ pfdAddr hostarch.Addr
nfds uint
timeout time.Duration
}
@@ -401,7 +401,7 @@ func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
return poll(t, p.pfdAddr, p.nfds, p.timeout)
}
-func poll(t *kernel.Task, pfdAddr usermem.Addr, nfds uint, timeout time.Duration) (uintptr, error) {
+func poll(t *kernel.Task, pfdAddr hostarch.Addr, nfds uint, timeout time.Duration) (uintptr, error) {
remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout)
// On an interrupt poll(2) is restarted with the remaining timeout.
if err == syserror.EINTR {
diff --git a/pkg/sentry/syscalls/linux/sys_random.go b/pkg/sentry/syscalls/linux/sys_random.go
index c0aa0fd60..ae545f80f 100644
--- a/pkg/sentry/syscalls/linux/sys_random.go
+++ b/pkg/sentry/syscalls/linux/sys_random.go
@@ -24,6 +24,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
const (
@@ -64,7 +66,7 @@ func GetRandom(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if min > 256 {
min = 256
}
- n, err := t.MemoryManager().CopyOutFrom(t, usermem.AddrRangeSeqOf(ar), safemem.FromIOReader{&randReader{-1, min}}, usermem.IOOpts{
+ n, err := t.MemoryManager().CopyOutFrom(t, hostarch.AddrRangeSeqOf(ar), safemem.FromIOReader{&randReader{-1, min}}, usermem.IOOpts{
AddressSpaceActive: true,
})
if n >= int64(min) {
diff --git a/pkg/sentry/syscalls/linux/sys_rlimit.go b/pkg/sentry/syscalls/linux/sys_rlimit.go
index 88cd234d1..e64246d57 100644
--- a/pkg/sentry/syscalls/linux/sys_rlimit.go
+++ b/pkg/sentry/syscalls/linux/sys_rlimit.go
@@ -16,12 +16,12 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// rlimit describes an implementation of 'struct rlimit', which may vary from
@@ -67,12 +67,12 @@ func (r *rlimit64) fromLimit(lim limits.Limit) {
}
}
-func (r *rlimit64) copyIn(t *kernel.Task, addr usermem.Addr) error {
+func (r *rlimit64) copyIn(t *kernel.Task, addr hostarch.Addr) error {
_, err := r.CopyIn(t, addr)
return err
}
-func (r *rlimit64) copyOut(t *kernel.Task, addr usermem.Addr) error {
+func (r *rlimit64) copyOut(t *kernel.Task, addr hostarch.Addr) error {
_, err := r.CopyOut(t, addr)
return err
}
diff --git a/pkg/sentry/syscalls/linux/sys_seccomp.go b/pkg/sentry/syscalls/linux/sys_seccomp.go
index 4fdb4463c..e16d6ff3f 100644
--- a/pkg/sentry/syscalls/linux/sys_seccomp.go
+++ b/pkg/sentry/syscalls/linux/sys_seccomp.go
@@ -17,10 +17,10 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// userSockFprog is equivalent to Linux's struct sock_fprog on amd64.
@@ -33,14 +33,14 @@ type userSockFprog struct {
_ [6]byte // padding for alignment
// Filter is a user pointer to the struct sock_filter array that makes up
- // the filter program. Filter is a uint64 rather than a usermem.Addr
- // because usermem.Addr is actually uintptr, which is not a fixed-size
+ // the filter program. Filter is a uint64 rather than a hostarch.Addr
+ // because hostarch.Addr is actually uintptr, which is not a fixed-size
// type.
Filter uint64
}
// seccomp applies a seccomp policy to the current task.
-func seccomp(t *kernel.Task, mode, flags uint64, addr usermem.Addr) error {
+func seccomp(t *kernel.Task, mode, flags uint64, addr hostarch.Addr) error {
// We only support SECCOMP_SET_MODE_FILTER at the moment.
if mode != linux.SECCOMP_SET_MODE_FILTER {
// Unsupported mode.
@@ -60,7 +60,7 @@ func seccomp(t *kernel.Task, mode, flags uint64, addr usermem.Addr) error {
return err
}
filter := make([]linux.BPFInstruction, int(fprog.Len))
- if _, err := linux.CopyBPFInstructionSliceIn(t, usermem.Addr(fprog.Filter), filter); err != nil {
+ if _, err := linux.CopyBPFInstructionSliceIn(t, hostarch.Addr(fprog.Filter), filter); err != nil {
return err
}
compiledFilter, err := bpf.Compile(filter)
diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go
index f0570d927..c84260080 100644
--- a/pkg/sentry/syscalls/linux/sys_sem.go
+++ b/pkg/sentry/syscalls/linux/sys_sem.go
@@ -19,13 +19,13 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
const opsMax = 500 // SEMOPM
@@ -310,7 +310,7 @@ func setVal(t *kernel.Task, id int32, num int32, val int16) error {
return set.SetVal(t, num, val, creds, int32(pid))
}
-func setValAll(t *kernel.Task, id int32, array usermem.Addr) error {
+func setValAll(t *kernel.Task, id int32, array hostarch.Addr) error {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
@@ -335,7 +335,7 @@ func getVal(t *kernel.Task, id int32, num int32) (int16, error) {
return set.GetVal(num, creds)
}
-func getValAll(t *kernel.Task, id int32, array usermem.Addr) error {
+func getValAll(t *kernel.Task, id int32, array hostarch.Addr) error {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go
index d639c9bf7..53b12dc41 100644
--- a/pkg/sentry/syscalls/linux/sys_signal.go
+++ b/pkg/sentry/syscalls/linux/sys_signal.go
@@ -19,12 +19,12 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/signalfd"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// "For a process to have permission to send a signal it must
@@ -516,7 +516,7 @@ func RestartSyscall(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
}
// sharedSignalfd is shared between the two calls.
-func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize uint, flags int32) (uintptr, *kernel.SyscallControl, error) {
+func sharedSignalfd(t *kernel.Task, fd int32, sigset hostarch.Addr, sigsetsize uint, flags int32) (uintptr, *kernel.SyscallControl, error) {
// Copy in the signal mask.
mask, err := CopyInSigSet(t, sigset, sigsetsize)
if err != nil {
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index c6adfe06b..eff251cec 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -18,6 +18,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -34,12 +35,6 @@ import (
// LINT.IfChange
-// minListenBacklog is the minimum reasonable backlog for listening sockets.
-const minListenBacklog = 8
-
-// maxListenBacklog is the maximum allowed backlog for listening sockets.
-const maxListenBacklog = 1024
-
// maxAddrLen is the maximum socket address length we're willing to accept.
const maxAddrLen = 200
@@ -51,6 +46,9 @@ const maxOptLen = 1024 * 8
// buffers upto INT_MAX.
const maxControlLen = 10 * 1024 * 1024
+// maxListenBacklog is the maximum limit of listen backlog supported.
+const maxListenBacklog = 1024
+
// nameLenOffset is the offset from the start of the MessageHeader64 struct to
// the NameLen field.
const nameLenOffset = 8
@@ -117,7 +115,7 @@ type multipleMessageHeader64 struct {
// CaptureAddress allocates memory for and copies a socket address structure
// from the untrusted address space range.
-func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) {
+func CaptureAddress(t *kernel.Task, addr hostarch.Addr, addrlen uint32) ([]byte, error) {
if addrlen > maxAddrLen {
return nil, syserror.EINVAL
}
@@ -133,7 +131,7 @@ func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte,
// writeAddress writes a sockaddr structure and its length to an output buffer
// in the unstrusted address space range. If the address is bigger than the
// buffer, it is truncated.
-func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error {
+func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr hostarch.Addr, addrLenPtr hostarch.Addr) error {
// Get the buffer length.
var bufLen uint32
if _, err := primitive.CopyUint32In(t, addrLenPtr, &bufLen); err != nil {
@@ -276,7 +274,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// accept is the implementation of the accept syscall. It is called by accept
// and accept4 syscall handlers.
-func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, flags int) (uintptr, error) {
+func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr, flags int) (uintptr, error) {
// Check that no unsupported flags are passed in.
if flags & ^(linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
return 0, syserror.EINVAL
@@ -366,7 +364,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
// Listen implements the linux syscall listen(2).
func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
- backlog := args[1].Int()
+ backlog := args[1].Uint()
// Get socket from the file descriptor.
file := t.GetFile(fd)
@@ -381,11 +379,13 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, syserror.ENOTSOCK
}
- // Per Linux, the backlog is silently capped to reasonable values.
- if backlog <= 0 {
- backlog = minListenBacklog
- }
if backlog > maxListenBacklog {
+ // Linux treats incoming backlog as uint with a limit defined by
+ // sysctl_somaxconn.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/net/socket.c#L1666
+ //
+ // We use the backlog to allocate a channel of that size, hence enforce
+ // a hard limit for the backlog.
backlog = maxListenBacklog
}
@@ -472,7 +472,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// getSockOpt tries to handle common socket options, or dispatches to a specific
// socket implementation.
-func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) {
+func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr hostarch.Addr, len int) (marshal.Marshallable, *syserr.Error) {
if level == linux.SOL_SOCKET {
switch name {
case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL:
@@ -735,7 +735,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return uintptr(count), nil, nil
}
-func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) {
+func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr hostarch.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) {
// Capture the message header and io vectors.
var msg MessageHeader64
if _, err := msg.CopyIn(t, msgPtr); err != nil {
@@ -745,7 +745,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
if msg.IovLen > linux.UIO_MAXIOV {
return 0, syserror.EMSGSIZE
}
- dst, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
+ dst, err := t.IovecsIOSequence(hostarch.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
AddressSpaceActive: true,
})
if err != nil {
@@ -796,7 +796,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
// Copy the address to the caller.
if msg.NameLen != 0 {
- if err := writeAddress(t, sender, senderLen, usermem.Addr(msg.Name), usermem.Addr(msgPtr+nameLenOffset)); err != nil {
+ if err := writeAddress(t, sender, senderLen, hostarch.Addr(msg.Name), hostarch.Addr(msgPtr+nameLenOffset)); err != nil {
return 0, err
}
}
@@ -806,7 +806,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
return 0, err
}
if len(controlData) > 0 {
- if _, err := t.CopyOutBytes(usermem.Addr(msg.Control), controlData); err != nil {
+ if _, err := t.CopyOutBytes(hostarch.Addr(msg.Control), controlData); err != nil {
return 0, err
}
}
@@ -821,7 +821,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
// recvFrom is the implementation of the recvfrom syscall. It is called by
// recvfrom and recv syscall handlers.
-func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLenPtr usermem.Addr) (uintptr, error) {
+func recvFrom(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags int32, namePtr hostarch.Addr, nameLenPtr hostarch.Addr) (uintptr, error) {
if int(bufLen) < 0 {
return 0, syserror.EINVAL
}
@@ -997,7 +997,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return uintptr(count), nil, nil
}
-func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr usermem.Addr, flags int32) (uintptr, error) {
+func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr hostarch.Addr, flags int32) (uintptr, error) {
// Capture the message header.
var msg MessageHeader64
if _, err := msg.CopyIn(t, msgPtr); err != nil {
@@ -1011,7 +1011,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme
return 0, syserror.ENOBUFS
}
controlData = make([]byte, msg.ControlLen)
- if _, err := t.CopyInBytes(usermem.Addr(msg.Control), controlData); err != nil {
+ if _, err := t.CopyInBytes(hostarch.Addr(msg.Control), controlData); err != nil {
return 0, err
}
}
@@ -1020,7 +1020,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme
var to []byte
if msg.NameLen != 0 {
var err error
- to, err = CaptureAddress(t, usermem.Addr(msg.Name), msg.NameLen)
+ to, err = CaptureAddress(t, hostarch.Addr(msg.Name), msg.NameLen)
if err != nil {
return 0, err
}
@@ -1030,7 +1030,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme
if msg.IovLen > linux.UIO_MAXIOV {
return 0, syserror.EMSGSIZE
}
- src, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
+ src, err := t.IovecsIOSequence(hostarch.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
AddressSpaceActive: true,
})
if err != nil {
@@ -1064,7 +1064,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme
// sendTo is the implementation of the sendto syscall. It is called by sendto
// and send syscall handlers.
-func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLen uint32) (uintptr, error) {
+func sendTo(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags int32, namePtr hostarch.Addr, nameLen uint32) (uintptr, error) {
bl := int(bufLen)
if bl < 0 {
return 0, syserror.EINVAL
diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go
index cda29a8b5..2338ba44b 100644
--- a/pkg/sentry/syscalls/linux/sys_stat.go
+++ b/pkg/sentry/syscalls/linux/sys_stat.go
@@ -16,11 +16,11 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// LINT.IfChange
@@ -106,7 +106,7 @@ func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
// stat implements stat from the given *fs.Dirent.
-func stat(t *kernel.Task, d *fs.Dirent, dirPath bool, statAddr usermem.Addr) error {
+func stat(t *kernel.Task, d *fs.Dirent, dirPath bool, statAddr hostarch.Addr) error {
if dirPath && !fs.IsDir(d.Inode.StableAttr) {
return syserror.ENOTDIR
}
@@ -120,7 +120,7 @@ func stat(t *kernel.Task, d *fs.Dirent, dirPath bool, statAddr usermem.Addr) err
}
// fstat implements fstat for the given *fs.File.
-func fstat(t *kernel.Task, f *fs.File, statAddr usermem.Addr) error {
+func fstat(t *kernel.Task, f *fs.File, statAddr hostarch.Addr) error {
uattr, err := f.UnstableAttr(t)
if err != nil {
return err
@@ -180,7 +180,7 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
})
}
-func statx(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr, statxAddr usermem.Addr) error {
+func statx(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr, statxAddr hostarch.Addr) error {
// "[T]he kernel may return fields that weren't requested and may fail to
// return fields that were requested, depending on what the backing
// filesystem supports.
@@ -257,7 +257,7 @@ func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// statfsImpl implements the linux syscall statfs and fstatfs based on a Dirent,
// copying the statfs structure out to addr on success, otherwise an error is
// returned.
-func statfsImpl(t *kernel.Task, d *fs.Dirent, addr usermem.Addr) error {
+func statfsImpl(t *kernel.Task, d *fs.Dirent, addr hostarch.Addr) error {
info, err := d.Inode.StatFS(t)
if err != nil {
return err
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
index b5f920949..3185ea527 100644
--- a/pkg/sentry/syscalls/linux/sys_thread.go
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -19,6 +19,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -46,7 +47,7 @@ var (
ExecMaxTotalSize = 2 * 1024 * 1024
// ExecMaxElemSize is the maximum length of a single argv or envv entry.
- ExecMaxElemSize = 32 * usermem.PageSize
+ ExecMaxElemSize = 32 * hostarch.PageSize
)
// Getppid implements linux syscall getppid(2).
@@ -88,7 +89,7 @@ func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return execveat(t, dirFD, pathnameAddr, argvAddr, envvAddr, flags)
}
-func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr usermem.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) {
+func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr hostarch.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) {
pathname, err := t.CopyInString(pathnameAddr, linux.PATH_MAX)
if err != nil {
return 0, nil, err
@@ -199,7 +200,7 @@ func ExitGroup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
// clone is used by Clone, Fork, and VFork.
-func clone(t *kernel.Task, flags int, stack usermem.Addr, parentTID usermem.Addr, childTID usermem.Addr, tls usermem.Addr) (uintptr, *kernel.SyscallControl, error) {
+func clone(t *kernel.Task, flags int, stack hostarch.Addr, parentTID hostarch.Addr, childTID hostarch.Addr, tls hostarch.Addr) (uintptr, *kernel.SyscallControl, error) {
opts := kernel.CloneOptions{
SharingOptions: kernel.SharingOptions{
NewAddressSpace: flags&linux.CLONE_VM == 0,
@@ -274,7 +275,7 @@ func parseCommonWaitOptions(wopts *kernel.WaitOptions, options int) error {
}
// wait4 waits for the given child process to exit.
-func wait4(t *kernel.Task, pid int, statusAddr usermem.Addr, options int, rusageAddr usermem.Addr) (uintptr, error) {
+func wait4(t *kernel.Task, pid int, statusAddr hostarch.Addr, options int, rusageAddr hostarch.Addr) (uintptr, error) {
if options&^(linux.WNOHANG|linux.WUNTRACED|linux.WCONTINUED|linux.WNOTHREAD|linux.WALL|linux.WCLONE) != 0 {
return 0, syserror.EINVAL
}
diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go
index c5054d2f1..83b777bbd 100644
--- a/pkg/sentry/syscalls/linux/sys_time.go
+++ b/pkg/sentry/syscalls/linux/sys_time.go
@@ -19,12 +19,12 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// The most significant 29 bits hold either a pid or a file descriptor.
@@ -165,7 +165,7 @@ func Time(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
addr := args[0].Pointer()
r := t.Kernel().RealtimeClock().Now().TimeT()
- if addr == usermem.Addr(0) {
+ if addr == hostarch.Addr(0) {
return uintptr(r), nil, nil
}
@@ -182,7 +182,7 @@ func Time(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
type clockNanosleepRestartBlock struct {
c ktime.Clock
duration time.Duration
- rem usermem.Addr
+ rem hostarch.Addr
}
// Restart implements kernel.SyscallRestartBlock.Restart.
@@ -221,7 +221,7 @@ func clockNanosleepUntil(t *kernel.Task, c ktime.Clock, ts linux.Timespec) error
//
// If blocking is interrupted, the syscall is restarted with the remaining
// duration timeout.
-func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem usermem.Addr) error {
+func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem hostarch.Addr) error {
timer, start, tchan := ktime.After(c, dur)
err := t.BlockWithTimer(nil, tchan)
@@ -324,14 +324,14 @@ func Gettimeofday(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
tv := args[0].Pointer()
tz := args[1].Pointer()
- if tv != usermem.Addr(0) {
+ if tv != hostarch.Addr(0) {
nowTv := t.Kernel().RealtimeClock().Now().Timeval()
if err := copyTimevalOut(t, tv, &nowTv); err != nil {
return 0, nil, err
}
}
- if tz != usermem.Addr(0) {
+ if tz != hostarch.Addr(0) {
// Ask the time package for the timezone.
_, offset := time.Now().Zone()
// This int32 array mimics linux's struct timezone.
diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go
index 97474fd3c..28ad6a60e 100644
--- a/pkg/sentry/syscalls/linux/sys_xattr.go
+++ b/pkg/sentry/syscalls/linux/sys_xattr.go
@@ -18,11 +18,11 @@ import (
"strings"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// LINT.IfChange
@@ -87,7 +87,7 @@ func getXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSymlink
}
// getXattr implements getxattr(2) from the given *fs.Dirent.
-func getXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr usermem.Addr, size uint64) (int, error) {
+func getXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr hostarch.Addr, size uint64) (int, error) {
name, err := copyInXattrName(t, nameAddr)
if err != nil {
return 0, err
@@ -180,7 +180,7 @@ func setXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSymlink
}
// setXattr implements setxattr(2) from the given *fs.Dirent.
-func setXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr usermem.Addr, size uint64, flags uint32) error {
+func setXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr hostarch.Addr, size uint64, flags uint32) error {
if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 {
return syserror.EINVAL
}
@@ -214,7 +214,7 @@ func setXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr usermem.Addr, si
return nil
}
-func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) {
+func copyInXattrName(t *kernel.Task, nameAddr hostarch.Addr) (string, error) {
name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1)
if err != nil {
if err == syserror.ENAMETOOLONG {
@@ -306,7 +306,7 @@ func listXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSymlin
return uintptr(n), nil, nil
}
-func listXattr(t *kernel.Task, d *fs.Dirent, addr usermem.Addr, size uint64) (int, error) {
+func listXattr(t *kernel.Task, d *fs.Dirent, addr hostarch.Addr, size uint64) (int, error) {
if !xattrFileTypeOk(d.Inode) {
return 0, nil
}
@@ -408,7 +408,7 @@ func removeXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSyml
}
// removeXattr implements removexattr(2) from the given *fs.Dirent.
-func removeXattr(t *kernel.Task, d *fs.Dirent, nameAddr usermem.Addr) error {
+func removeXattr(t *kernel.Task, d *fs.Dirent, nameAddr hostarch.Addr) error {
name, err := copyInXattrName(t, nameAddr)
if err != nil {
return err
diff --git a/pkg/sentry/syscalls/linux/timespec.go b/pkg/sentry/syscalls/linux/timespec.go
index ddc3ee26e..3edc922eb 100644
--- a/pkg/sentry/syscalls/linux/timespec.go
+++ b/pkg/sentry/syscalls/linux/timespec.go
@@ -18,13 +18,13 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// copyTimespecIn copies a Timespec from the untrusted app range to the kernel.
-func copyTimespecIn(t *kernel.Task, addr usermem.Addr) (linux.Timespec, error) {
+func copyTimespecIn(t *kernel.Task, addr hostarch.Addr) (linux.Timespec, error) {
switch t.Arch().Width() {
case 8:
ts := linux.Timespec{}
@@ -33,8 +33,8 @@ func copyTimespecIn(t *kernel.Task, addr usermem.Addr) (linux.Timespec, error) {
if err != nil {
return ts, err
}
- ts.Sec = int64(usermem.ByteOrder.Uint64(in[0:]))
- ts.Nsec = int64(usermem.ByteOrder.Uint64(in[8:]))
+ ts.Sec = int64(hostarch.ByteOrder.Uint64(in[0:]))
+ ts.Nsec = int64(hostarch.ByteOrder.Uint64(in[8:]))
return ts, nil
default:
return linux.Timespec{}, syserror.ENOSYS
@@ -42,12 +42,12 @@ func copyTimespecIn(t *kernel.Task, addr usermem.Addr) (linux.Timespec, error) {
}
// copyTimespecOut copies a Timespec to the untrusted app range.
-func copyTimespecOut(t *kernel.Task, addr usermem.Addr, ts *linux.Timespec) error {
+func copyTimespecOut(t *kernel.Task, addr hostarch.Addr, ts *linux.Timespec) error {
switch t.Arch().Width() {
case 8:
out := t.CopyScratchBuffer(16)
- usermem.ByteOrder.PutUint64(out[0:], uint64(ts.Sec))
- usermem.ByteOrder.PutUint64(out[8:], uint64(ts.Nsec))
+ hostarch.ByteOrder.PutUint64(out[0:], uint64(ts.Sec))
+ hostarch.ByteOrder.PutUint64(out[8:], uint64(ts.Nsec))
_, err := t.CopyOutBytes(addr, out)
return err
default:
@@ -56,7 +56,7 @@ func copyTimespecOut(t *kernel.Task, addr usermem.Addr, ts *linux.Timespec) erro
}
// copyTimevalIn copies a Timeval from the untrusted app range to the kernel.
-func copyTimevalIn(t *kernel.Task, addr usermem.Addr) (linux.Timeval, error) {
+func copyTimevalIn(t *kernel.Task, addr hostarch.Addr) (linux.Timeval, error) {
switch t.Arch().Width() {
case 8:
tv := linux.Timeval{}
@@ -65,8 +65,8 @@ func copyTimevalIn(t *kernel.Task, addr usermem.Addr) (linux.Timeval, error) {
if err != nil {
return tv, err
}
- tv.Sec = int64(usermem.ByteOrder.Uint64(in[0:]))
- tv.Usec = int64(usermem.ByteOrder.Uint64(in[8:]))
+ tv.Sec = int64(hostarch.ByteOrder.Uint64(in[0:]))
+ tv.Usec = int64(hostarch.ByteOrder.Uint64(in[8:]))
return tv, nil
default:
return linux.Timeval{}, syserror.ENOSYS
@@ -74,12 +74,12 @@ func copyTimevalIn(t *kernel.Task, addr usermem.Addr) (linux.Timeval, error) {
}
// copyTimevalOut copies a Timeval to the untrusted app range.
-func copyTimevalOut(t *kernel.Task, addr usermem.Addr, tv *linux.Timeval) error {
+func copyTimevalOut(t *kernel.Task, addr hostarch.Addr, tv *linux.Timeval) error {
switch t.Arch().Width() {
case 8:
out := t.CopyScratchBuffer(16)
- usermem.ByteOrder.PutUint64(out[0:], uint64(tv.Sec))
- usermem.ByteOrder.PutUint64(out[8:], uint64(tv.Usec))
+ hostarch.ByteOrder.PutUint64(out[0:], uint64(tv.Sec))
+ hostarch.ByteOrder.PutUint64(out[8:], uint64(tv.Usec))
_, err := t.CopyOutBytes(addr, out)
return err
default:
@@ -94,7 +94,7 @@ func copyTimevalOut(t *kernel.Task, addr usermem.Addr, tv *linux.Timeval) error
// returned value is the maximum that Duration will allow.
//
// If timespecAddr is NULL, the returned value is negative.
-func copyTimespecInToDuration(t *kernel.Task, timespecAddr usermem.Addr) (time.Duration, error) {
+func copyTimespecInToDuration(t *kernel.Task, timespecAddr hostarch.Addr) (time.Duration, error) {
// Use a negative Duration to indicate "no timeout".
timeout := time.Duration(-1)
if timespecAddr != 0 {
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index 2e59bd5b1..5ce0bc714 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -43,6 +43,7 @@ go_library(
"//pkg/context",
"//pkg/fspath",
"//pkg/gohacks",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/marshal",
"//pkg/marshal/primitive",
diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go
index de6789a65..fd1863ef3 100644
--- a/pkg/sentry/syscalls/linux/vfs2/aio.go
+++ b/pkg/sentry/syscalls/linux/vfs2/aio.go
@@ -26,6 +26,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// IoSubmit implements linux syscall io_submit(2).
@@ -40,7 +42,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
for i := int32(0); i < nrEvents; i++ {
// Copy in the callback address.
- var cbAddr usermem.Addr
+ var cbAddr hostarch.Addr
switch t.Arch().Width() {
case 8:
var cbAddrP primitive.Uint64
@@ -52,7 +54,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Nothing done.
return 0, nil, err
}
- cbAddr = usermem.Addr(cbAddrP)
+ cbAddr = hostarch.Addr(cbAddrP)
default:
return 0, nil, syserror.ENOSYS
}
@@ -79,14 +81,14 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
// Advance to the next one.
- addr += usermem.Addr(t.Arch().Width())
+ addr += hostarch.Addr(t.Arch().Width())
}
return uintptr(nrEvents), nil, nil
}
// submitCallback processes a single callback.
-func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr usermem.Addr) error {
+func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr hostarch.Addr) error {
if cb.Reserved2 != 0 {
return syserror.EINVAL
}
@@ -148,7 +150,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user
return nil
}
-func getAIOCallback(t *kernel.Task, fd, eventFD *vfs.FileDescription, cbAddr usermem.Addr, cb *linux.IOCallback, ioseq usermem.IOSequence, aioCtx *mm.AIOContext) kernel.AIOCallback {
+func getAIOCallback(t *kernel.Task, fd, eventFD *vfs.FileDescription, cbAddr hostarch.Addr, cb *linux.IOCallback, ioseq usermem.IOSequence, aioCtx *mm.AIOContext) kernel.AIOCallback {
return func(ctx context.Context) {
// Release references after completing the callback.
defer fd.DecRef(ctx)
@@ -206,12 +208,12 @@ func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error)
// I/O.
switch cb.OpCode {
case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PWRITE:
- return t.SingleIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{
+ return t.SingleIOSequence(hostarch.Addr(cb.Buf), bytes, usermem.IOOpts{
AddressSpaceActive: false,
})
case linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITEV:
- return t.IovecsIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{
+ return t.IovecsIOSequence(hostarch.Addr(cb.Buf), bytes, usermem.IOOpts{
AddressSpaceActive: false,
})
diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go
index 7a409620d..3315398a4 100644
--- a/pkg/sentry/syscalls/linux/vfs2/execve.go
+++ b/pkg/sentry/syscalls/linux/vfs2/execve.go
@@ -24,7 +24,8 @@ import (
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// Execve implements linux syscall execve(2).
@@ -45,7 +46,7 @@ func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return execveat(t, dirfd, pathnameAddr, argvAddr, envvAddr, flags)
}
-func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr usermem.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) {
+func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr hostarch.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) {
if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
return 0, nil, syserror.EINVAL
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
index 01e0f9010..36aa1d3ae 100644
--- a/pkg/sentry/syscalls/linux/vfs2/filesystem.go
+++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
@@ -20,7 +20,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// Link implements Linux syscall link(2).
@@ -40,7 +41,7 @@ func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, linkat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags)
}
-func linkat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags int32) error {
+func linkat(t *kernel.Task, olddirfd int32, oldpathAddr hostarch.Addr, newdirfd int32, newpathAddr hostarch.Addr, flags int32) error {
if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_FOLLOW) != 0 {
return syserror.EINVAL
}
@@ -86,7 +87,7 @@ func Mkdirat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
return 0, nil, mkdirat(t, dirfd, addr, mode)
}
-func mkdirat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint) error {
+func mkdirat(t *kernel.Task, dirfd int32, addr hostarch.Addr, mode uint) error {
path, err := copyInPath(t, addr)
if err != nil {
return err
@@ -118,7 +119,7 @@ func Mknodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
return 0, nil, mknodat(t, dirfd, addr, linux.FileMode(mode), dev)
}
-func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode linux.FileMode, dev uint32) error {
+func mknodat(t *kernel.Task, dirfd int32, addr hostarch.Addr, mode linux.FileMode, dev uint32) error {
path, err := copyInPath(t, addr)
if err != nil {
return err
@@ -165,7 +166,7 @@ func Creat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return openat(t, linux.AT_FDCWD, addr, linux.O_WRONLY|linux.O_CREAT|linux.O_TRUNC, mode)
}
-func openat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, flags uint32, mode uint) (uintptr, *kernel.SyscallControl, error) {
+func openat(t *kernel.Task, dirfd int32, pathAddr hostarch.Addr, flags uint32, mode uint) (uintptr, *kernel.SyscallControl, error) {
path, err := copyInPath(t, pathAddr)
if err != nil {
return 0, nil, err
@@ -217,7 +218,7 @@ func Renameat2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, renameat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags)
}
-func renameat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags uint32) error {
+func renameat(t *kernel.Task, olddirfd int32, oldpathAddr hostarch.Addr, newdirfd int32, newpathAddr hostarch.Addr, flags uint32) error {
oldpath, err := copyInPath(t, oldpathAddr)
if err != nil {
return err
@@ -250,7 +251,7 @@ func Rmdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, rmdirat(t, linux.AT_FDCWD, pathAddr)
}
-func rmdirat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error {
+func rmdirat(t *kernel.Task, dirfd int32, pathAddr hostarch.Addr) error {
path, err := copyInPath(t, pathAddr)
if err != nil {
return err
@@ -269,7 +270,7 @@ func Unlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, unlinkat(t, linux.AT_FDCWD, pathAddr)
}
-func unlinkat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error {
+func unlinkat(t *kernel.Task, dirfd int32, pathAddr hostarch.Addr) error {
path, err := copyInPath(t, pathAddr)
if err != nil {
return err
@@ -313,7 +314,7 @@ func Symlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, symlinkat(t, targetAddr, newdirfd, linkpathAddr)
}
-func symlinkat(t *kernel.Task, targetAddr usermem.Addr, newdirfd int32, linkpathAddr usermem.Addr) error {
+func symlinkat(t *kernel.Task, targetAddr hostarch.Addr, newdirfd int32, linkpathAddr hostarch.Addr) error {
target, err := t.CopyInString(targetAddr, linux.PATH_MAX)
if err != nil {
return err
diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go
index 5517595b5..b41a3056a 100644
--- a/pkg/sentry/syscalls/linux/vfs2/getdents.go
+++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go
@@ -22,7 +22,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// Getdents implements Linux syscall getdents(2).
@@ -58,7 +59,7 @@ func getdents(t *kernel.Task, args arch.SyscallArguments, isGetdents64 bool) (ui
type getdentsCallback struct {
t *kernel.Task
- addr usermem.Addr
+ addr hostarch.Addr
remaining int
isGetdents64 bool
}
@@ -69,7 +70,7 @@ var getdentsCallbackPool = sync.Pool{
},
}
-func getGetdentsCallback(t *kernel.Task, addr usermem.Addr, size int, isGetdents64 bool) *getdentsCallback {
+func getGetdentsCallback(t *kernel.Task, addr hostarch.Addr, size int, isGetdents64 bool) *getdentsCallback {
cb := getdentsCallbackPool.Get().(*getdentsCallback)
*cb = getdentsCallback{
t: t,
@@ -102,9 +103,9 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
return syserror.EINVAL
}
buf = cb.t.CopyScratchBuffer(size)
- usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino)
- usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff))
- usermem.ByteOrder.PutUint16(buf[16:18], uint16(size))
+ hostarch.ByteOrder.PutUint64(buf[0:8], dirent.Ino)
+ hostarch.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff))
+ hostarch.ByteOrder.PutUint16(buf[16:18], uint16(size))
buf[18] = dirent.Type
copy(buf[19:], dirent.Name)
// Zero out all remaining bytes in buf, including the NUL terminator
@@ -136,9 +137,9 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
return syserror.EINVAL
}
buf = cb.t.CopyScratchBuffer(size)
- usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino)
- usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff))
- usermem.ByteOrder.PutUint16(buf[16:18], uint16(size))
+ hostarch.ByteOrder.PutUint64(buf[0:8], dirent.Ino)
+ hostarch.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff))
+ hostarch.ByteOrder.PutUint16(buf[16:18], uint16(size))
copy(buf[18:], dirent.Name)
// Zero out all remaining bytes in buf, including the NUL terminator
// after dirent.Name and the zero padding byte between the name and
@@ -155,7 +156,7 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
// cb.remaining.
return err
}
- cb.addr += usermem.Addr(n)
+ cb.addr += hostarch.Addr(n)
cb.remaining -= n
return nil
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/mmap.go b/pkg/sentry/syscalls/linux/vfs2/mmap.go
index 9d9dbf775..c961545f6 100644
--- a/pkg/sentry/syscalls/linux/vfs2/mmap.go
+++ b/pkg/sentry/syscalls/linux/vfs2/mmap.go
@@ -21,7 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// Mmap implements Linux syscall mmap(2).
@@ -48,12 +49,12 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
Unmap: fixed,
Map32Bit: map32bit,
Private: private,
- Perms: usermem.AccessType{
+ Perms: hostarch.AccessType{
Read: linux.PROT_READ&prot != 0,
Write: linux.PROT_WRITE&prot != 0,
Execute: linux.PROT_EXEC&prot != 0,
},
- MaxPerms: usermem.AnyAccess,
+ MaxPerms: hostarch.AnyAccess,
GrowsDown: linux.MAP_GROWSDOWN&flags != 0,
Precommit: linux.MAP_POPULATE&flags != 0,
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go
index 769c9b92f..dd93430e2 100644
--- a/pkg/sentry/syscalls/linux/vfs2/mount.go
+++ b/pkg/sentry/syscalls/linux/vfs2/mount.go
@@ -20,7 +20,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// Mount implements Linux syscall mount(2).
@@ -33,11 +34,11 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// For null-terminated strings related to mount(2), Linux copies in at most
// a page worth of data. See fs/namespace.c:copy_mount_string().
- fsType, err := t.CopyInString(typeAddr, usermem.PageSize)
+ fsType, err := t.CopyInString(typeAddr, hostarch.PageSize)
if err != nil {
return 0, nil, err
}
- source, err := t.CopyInString(sourceAddr, usermem.PageSize)
+ source, err := t.CopyInString(sourceAddr, hostarch.PageSize)
if err != nil {
return 0, nil, err
}
@@ -53,7 +54,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// character placement, and the address is passed to each file system.
// Most file systems always treat this data as a string, though, and so
// do all of the ones we implement.
- data, err = t.CopyInString(dataAddr, usermem.PageSize)
+ data, err = t.CopyInString(dataAddr, hostarch.PageSize)
if err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/path.go b/pkg/sentry/syscalls/linux/vfs2/path.go
index 90a511d9a..2aaf1ed74 100644
--- a/pkg/sentry/syscalls/linux/vfs2/path.go
+++ b/pkg/sentry/syscalls/linux/vfs2/path.go
@@ -20,10 +20,11 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
-func copyInPath(t *kernel.Task, addr usermem.Addr) (fspath.Path, error) {
+func copyInPath(t *kernel.Task, addr hostarch.Addr) (fspath.Path, error) {
pathname, err := t.CopyInString(addr, linux.PATH_MAX)
if err != nil {
return fspath.Path{}, err
diff --git a/pkg/sentry/syscalls/linux/vfs2/pipe.go b/pkg/sentry/syscalls/linux/vfs2/pipe.go
index 6986e39fe..c6fc1954c 100644
--- a/pkg/sentry/syscalls/linux/vfs2/pipe.go
+++ b/pkg/sentry/syscalls/linux/vfs2/pipe.go
@@ -22,7 +22,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// Pipe implements Linux syscall pipe(2).
@@ -38,7 +39,7 @@ func Pipe2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, pipe2(t, addr, flags)
}
-func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error {
+func pipe2(t *kernel.Task, addr hostarch.Addr, flags int32) error {
if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 {
return syserror.EINVAL
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go
index c22e4ce54..a69c80edd 100644
--- a/pkg/sentry/syscalls/linux/vfs2/poll.go
+++ b/pkg/sentry/syscalls/linux/vfs2/poll.go
@@ -25,8 +25,9 @@ import (
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// fileCap is the maximum allowable files for poll & select. This has no
@@ -158,7 +159,7 @@ func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.
}
// copyInPollFDs copies an array of struct pollfd unless nfds exceeds the max.
-func copyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD, error) {
+func copyInPollFDs(t *kernel.Task, addr hostarch.Addr, nfds uint) ([]linux.PollFD, error) {
if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) {
return nil, syserror.EINVAL
}
@@ -173,7 +174,7 @@ func copyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD
return pfd, nil
}
-func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) {
+func doPoll(t *kernel.Task, addr hostarch.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) {
pfd, err := copyInPollFDs(t, addr, nfds)
if err != nil {
return timeout, 0, err
@@ -201,7 +202,7 @@ func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration)
}
// CopyInFDSet copies an fd set from select(2)/pselect(2).
-func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) {
+func CopyInFDSet(t *kernel.Task, addr hostarch.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) {
set := make([]byte, nBytes)
if addr != 0 {
@@ -218,7 +219,7 @@ func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialBy
return set, nil
}
-func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) {
+func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs hostarch.Addr, timeout time.Duration) (uintptr, error) {
if nfds < 0 || nfds > fileCap {
return 0, syserror.EINVAL
}
@@ -368,7 +369,7 @@ func timeoutRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration)
// copyOutTimespecRemaining copies the time remaining in timeout to timespecAddr.
//
// startNs must be from CLOCK_MONOTONIC.
-func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr usermem.Addr) error {
+func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr hostarch.Addr) error {
if timeout <= 0 {
return nil
}
@@ -381,7 +382,7 @@ func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.D
// copyOutTimevalRemaining copies the time remaining in timeout to timevalAddr.
//
// startNs must be from CLOCK_MONOTONIC.
-func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr usermem.Addr) error {
+func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr hostarch.Addr) error {
if timeout <= 0 {
return nil
}
@@ -396,7 +397,7 @@ func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Du
//
// +stateify savable
type pollRestartBlock struct {
- pfdAddr usermem.Addr
+ pfdAddr hostarch.Addr
nfds uint
timeout time.Duration
}
@@ -406,7 +407,7 @@ func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
return poll(t, p.pfdAddr, p.nfds, p.timeout)
}
-func poll(t *kernel.Task, pfdAddr usermem.Addr, nfds uint, timeout time.Duration) (uintptr, error) {
+func poll(t *kernel.Task, pfdAddr hostarch.Addr, nfds uint, timeout time.Duration) (uintptr, error) {
remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout)
// On an interrupt poll(2) is restarted with the remaining timeout.
if err == syserror.EINTR {
@@ -530,7 +531,7 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if _, err := maskStruct.CopyIn(t, maskWithSizeAddr); err != nil {
return 0, nil, err
}
- if err := setTempSignalSet(t, usermem.Addr(maskStruct.sigsetAddr), uint(maskStruct.sizeofSigset)); err != nil {
+ if err := setTempSignalSet(t, hostarch.Addr(maskStruct.sigsetAddr), uint(maskStruct.sizeofSigset)); err != nil {
return 0, nil, err
}
}
@@ -551,7 +552,7 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// returned value is the maximum that Duration will allow.
//
// If timespecAddr is NULL, the returned value is negative.
-func copyTimespecInToDuration(t *kernel.Task, timespecAddr usermem.Addr) (time.Duration, error) {
+func copyTimespecInToDuration(t *kernel.Task, timespecAddr hostarch.Addr) (time.Duration, error) {
// Use a negative Duration to indicate "no timeout".
timeout := time.Duration(-1)
if timespecAddr != 0 {
@@ -567,7 +568,7 @@ func copyTimespecInToDuration(t *kernel.Task, timespecAddr usermem.Addr) (time.D
return timeout, nil
}
-func setTempSignalSet(t *kernel.Task, maskAddr usermem.Addr, maskSize uint) error {
+func setTempSignalSet(t *kernel.Task, maskAddr hostarch.Addr, maskSize uint) error {
if maskAddr == 0 {
return nil
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
index 903169dc2..c6330c21a 100644
--- a/pkg/sentry/syscalls/linux/vfs2/setstat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -23,7 +23,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
const chmodMask = 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX
@@ -43,7 +44,7 @@ func Fchmodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, fchmodat(t, dirfd, pathAddr, mode)
}
-func fchmodat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, mode uint) error {
+func fchmodat(t *kernel.Task, dirfd int32, pathAddr hostarch.Addr, mode uint) error {
path, err := copyInPath(t, pathAddr)
if err != nil {
return err
@@ -102,7 +103,7 @@ func Fchownat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, fchownat(t, dirfd, pathAddr, owner, group, flags)
}
-func fchownat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, owner, group, flags int32) error {
+func fchownat(t *kernel.Task, dirfd int32, pathAddr hostarch.Addr, owner, group, flags int32) error {
if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
return syserror.EINVAL
}
@@ -327,7 +328,7 @@ func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, setstatat(t, dirfd, path, shouldAllowEmptyPath, followFinalSymlink, &opts)
}
-func populateSetStatOptionsForUtimes(t *kernel.Task, timesAddr usermem.Addr, opts *vfs.SetStatOptions) error {
+func populateSetStatOptionsForUtimes(t *kernel.Task, timesAddr hostarch.Addr, opts *vfs.SetStatOptions) error {
if timesAddr == 0 {
opts.Stat.Mask = linux.STATX_ATIME | linux.STATX_MTIME
opts.Stat.Atime.Nsec = linux.UTIME_NOW
@@ -391,7 +392,7 @@ func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, setstatat(t, dirfd, path, shouldAllowEmptyPath, shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_NOFOLLOW == 0), &opts)
}
-func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, opts *vfs.SetStatOptions) error {
+func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr hostarch.Addr, opts *vfs.SetStatOptions) error {
if timesAddr == 0 {
opts.Stat.Mask = linux.STATX_ATIME | linux.STATX_MTIME
opts.Stat.Atime.Nsec = linux.UTIME_NOW
diff --git a/pkg/sentry/syscalls/linux/vfs2/signal.go b/pkg/sentry/syscalls/linux/vfs2/signal.go
index b89f34cdb..6163da103 100644
--- a/pkg/sentry/syscalls/linux/vfs2/signal.go
+++ b/pkg/sentry/syscalls/linux/vfs2/signal.go
@@ -21,11 +21,12 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// sharedSignalfd is shared between the two calls.
-func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize uint, flags int32) (uintptr, *kernel.SyscallControl, error) {
+func sharedSignalfd(t *kernel.Task, fd int32, sigset hostarch.Addr, sigsetsize uint, flags int32) (uintptr, *kernel.SyscallControl, error) {
// Copy in the signal mask.
mask, err := slinux.CopyInSigSet(t, sigset, sigsetsize)
if err != nil {
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
index 346fd1cea..936614eab 100644
--- a/pkg/sentry/syscalls/linux/vfs2/socket.go
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -31,13 +31,9 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
-)
-
-// minListenBacklog is the minimum reasonable backlog for listening sockets.
-const minListenBacklog = 8
-// maxListenBacklog is the maximum allowed backlog for listening sockets.
-const maxListenBacklog = 1024
+ "gvisor.dev/gvisor/pkg/hostarch"
+)
// maxAddrLen is the maximum socket address length we're willing to accept.
const maxAddrLen = 200
@@ -50,6 +46,9 @@ const maxOptLen = 1024 * 8
// buffers upto INT_MAX.
const maxControlLen = 10 * 1024 * 1024
+// maxListenBacklog is the maximum limit of listen backlog supported.
+const maxListenBacklog = 1024
+
// nameLenOffset is the offset from the start of the MessageHeader64 struct to
// the NameLen field.
const nameLenOffset = 8
@@ -116,7 +115,7 @@ type multipleMessageHeader64 struct {
// CaptureAddress allocates memory for and copies a socket address structure
// from the untrusted address space range.
-func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) {
+func CaptureAddress(t *kernel.Task, addr hostarch.Addr, addrlen uint32) ([]byte, error) {
if addrlen > maxAddrLen {
return nil, syserror.EINVAL
}
@@ -132,7 +131,7 @@ func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte,
// writeAddress writes a sockaddr structure and its length to an output buffer
// in the unstrusted address space range. If the address is bigger than the
// buffer, it is truncated.
-func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error {
+func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr hostarch.Addr, addrLenPtr hostarch.Addr) error {
// Get the buffer length.
var bufLen uint32
if _, err := primitive.CopyUint32In(t, addrLenPtr, &bufLen); err != nil {
@@ -279,7 +278,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// accept is the implementation of the accept syscall. It is called by accept
// and accept4 syscall handlers.
-func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, flags int) (uintptr, error) {
+func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr, flags int) (uintptr, error) {
// Check that no unsupported flags are passed in.
if flags & ^(linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
return 0, syserror.EINVAL
@@ -369,7 +368,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
// Listen implements the linux syscall listen(2).
func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
- backlog := args[1].Int()
+ backlog := args[1].Uint()
// Get socket from the file descriptor.
file := t.GetFileVFS2(fd)
@@ -384,11 +383,13 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, syserror.ENOTSOCK
}
- // Per Linux, the backlog is silently capped to reasonable values.
- if backlog <= 0 {
- backlog = minListenBacklog
- }
if backlog > maxListenBacklog {
+ // Linux treats incoming backlog as uint with a limit defined by
+ // sysctl_somaxconn.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/net/socket.c#L1666
+ //
+ // We use the backlog to allocate a channel of that size, hence enforce
+ // a hard limit for the backlog.
backlog = maxListenBacklog
}
@@ -475,7 +476,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// getSockOpt tries to handle common socket options, or dispatches to a specific
// socket implementation.
-func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) {
+func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr hostarch.Addr, len int) (marshal.Marshallable, *syserr.Error) {
if level == linux.SOL_SOCKET {
switch name {
case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL:
@@ -738,7 +739,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return uintptr(count), nil, nil
}
-func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) {
+func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr hostarch.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) {
// Capture the message header and io vectors.
var msg MessageHeader64
if _, err := msg.CopyIn(t, msgPtr); err != nil {
@@ -748,7 +749,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla
if msg.IovLen > linux.UIO_MAXIOV {
return 0, syserror.EMSGSIZE
}
- dst, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
+ dst, err := t.IovecsIOSequence(hostarch.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
AddressSpaceActive: true,
})
if err != nil {
@@ -799,7 +800,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla
// Copy the address to the caller.
if msg.NameLen != 0 {
- if err := writeAddress(t, sender, senderLen, usermem.Addr(msg.Name), usermem.Addr(msgPtr+nameLenOffset)); err != nil {
+ if err := writeAddress(t, sender, senderLen, hostarch.Addr(msg.Name), hostarch.Addr(msgPtr+nameLenOffset)); err != nil {
return 0, err
}
}
@@ -809,7 +810,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla
return 0, err
}
if len(controlData) > 0 {
- if _, err := t.CopyOutBytes(usermem.Addr(msg.Control), controlData); err != nil {
+ if _, err := t.CopyOutBytes(hostarch.Addr(msg.Control), controlData); err != nil {
return 0, err
}
}
@@ -824,7 +825,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla
// recvFrom is the implementation of the recvfrom syscall. It is called by
// recvfrom and recv syscall handlers.
-func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLenPtr usermem.Addr) (uintptr, error) {
+func recvFrom(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags int32, namePtr hostarch.Addr, nameLenPtr hostarch.Addr) (uintptr, error) {
if int(bufLen) < 0 {
return 0, syserror.EINVAL
}
@@ -1000,7 +1001,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return uintptr(count), nil, nil
}
-func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescription, msgPtr usermem.Addr, flags int32) (uintptr, error) {
+func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescription, msgPtr hostarch.Addr, flags int32) (uintptr, error) {
// Capture the message header.
var msg MessageHeader64
if _, err := msg.CopyIn(t, msgPtr); err != nil {
@@ -1014,7 +1015,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio
return 0, syserror.ENOBUFS
}
controlData = make([]byte, msg.ControlLen)
- if _, err := t.CopyInBytes(usermem.Addr(msg.Control), controlData); err != nil {
+ if _, err := t.CopyInBytes(hostarch.Addr(msg.Control), controlData); err != nil {
return 0, err
}
}
@@ -1023,7 +1024,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio
var to []byte
if msg.NameLen != 0 {
var err error
- to, err = CaptureAddress(t, usermem.Addr(msg.Name), msg.NameLen)
+ to, err = CaptureAddress(t, hostarch.Addr(msg.Name), msg.NameLen)
if err != nil {
return 0, err
}
@@ -1033,7 +1034,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio
if msg.IovLen > linux.UIO_MAXIOV {
return 0, syserror.EMSGSIZE
}
- src, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
+ src, err := t.IovecsIOSequence(hostarch.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
AddressSpaceActive: true,
})
if err != nil {
@@ -1067,7 +1068,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio
// sendTo is the implementation of the sendto syscall. It is called by sendto
// and send syscall handlers.
-func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLen uint32) (uintptr, error) {
+func sendTo(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags int32, namePtr hostarch.Addr, nameLen uint32) (uintptr, error) {
bl := int(bufLen)
if bl < 0 {
return 0, syserror.EINVAL
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go
index 0f5d5189c..69e77fa99 100644
--- a/pkg/sentry/syscalls/linux/vfs2/stat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/stat.go
@@ -24,7 +24,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// Stat implements Linux syscall stat(2).
@@ -50,7 +51,7 @@ func Newfstatat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, fstatat(t, dirfd, pathAddr, statAddr, flags)
}
-func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags int32) error {
+func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr hostarch.Addr, flags int32) error {
if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
return syserror.EINVAL
}
@@ -264,7 +265,7 @@ func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, accessAt(t, dirfd, addr, mode)
}
-func accessAt(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, mode uint) error {
+func accessAt(t *kernel.Task, dirfd int32, pathAddr hostarch.Addr, mode uint) error {
const rOK = 4
const wOK = 2
const xOK = 1
@@ -312,7 +313,7 @@ func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return readlinkat(t, dirfd, pathAddr, bufAddr, size)
}
-func readlinkat(t *kernel.Task, dirfd int32, pathAddr, bufAddr usermem.Addr, size uint) (uintptr, *kernel.SyscallControl, error) {
+func readlinkat(t *kernel.Task, dirfd int32, pathAddr, bufAddr hostarch.Addr, size uint) (uintptr, *kernel.SyscallControl, error) {
if int(size) <= 0 {
return 0, nil, syserror.EINVAL
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go
index e05723ef9..c261050c6 100644
--- a/pkg/sentry/syscalls/linux/vfs2/xattr.go
+++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go
@@ -23,7 +23,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// ListXattr implements Linux syscall listxattr(2).
@@ -291,7 +292,7 @@ func Fremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
return 0, nil, file.RemoveXattr(t, name)
}
-func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) {
+func copyInXattrName(t *kernel.Task, nameAddr hostarch.Addr) (string, error) {
name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1)
if err != nil {
if err == syserror.ENAMETOOLONG {
@@ -305,7 +306,7 @@ func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) {
return name, nil
}
-func copyOutXattrNameList(t *kernel.Task, listAddr usermem.Addr, size uint, names []string) (int, error) {
+func copyOutXattrNameList(t *kernel.Task, listAddr hostarch.Addr, size uint, names []string) (int, error) {
if size > linux.XATTR_LIST_MAX {
size = linux.XATTR_LIST_MAX
}
@@ -327,7 +328,7 @@ func copyOutXattrNameList(t *kernel.Task, listAddr usermem.Addr, size uint, name
return t.CopyOutBytes(listAddr, buf.Bytes())
}
-func copyInXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint) (string, error) {
+func copyInXattrValue(t *kernel.Task, valueAddr hostarch.Addr, size uint) (string, error) {
if size > linux.XATTR_SIZE_MAX {
return "", syserror.E2BIG
}
@@ -338,7 +339,7 @@ func copyInXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint) (string
return gohacks.StringFromImmutableBytes(buf), nil
}
-func copyOutXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint, value string) (int, error) {
+func copyOutXattrValue(t *kernel.Task, valueAddr hostarch.Addr, size uint, value string) (int, error) {
if size > linux.XATTR_SIZE_MAX {
size = linux.XATTR_SIZE_MAX
}
diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD
index 87d8687ce..1f617ca8f 100644
--- a/pkg/sentry/time/BUILD
+++ b/pkg/sentry/time/BUILD
@@ -32,6 +32,7 @@ go_library(
],
visibility = ["//:sandbox"],
deps = [
+ "//pkg/gohacks",
"//pkg/log",
"//pkg/metric",
"//pkg/sync",
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index df4990854..ac60fe8bf 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -99,6 +99,7 @@ go_library(
"//pkg/fdnotifier",
"//pkg/fspath",
"//pkg/gohacks",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/refs",
"//pkg/refsvfs2",
diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go
index 3caf417ca..f48817132 100644
--- a/pkg/sentry/vfs/anonfs.go
+++ b/pkg/sentry/vfs/anonfs.go
@@ -20,10 +20,10 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// NewAnonVirtualDentry returns a VirtualDentry with the given synthetic name,
@@ -43,7 +43,7 @@ func (vfs *VirtualFilesystem) NewAnonVirtualDentry(name string) VirtualDentry {
}
const (
- anonfsBlockSize = usermem.PageSize // via fs/libfs.c:pseudo_fs_fill_super()
+ anonfsBlockSize = hostarch.PageSize // via fs/libfs.c:pseudo_fs_fill_super()
// Mode, UID, and GID for a generic anonfs file.
anonFileMode = 0600 // no type is correct
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index 1556b41a3..b87d9690a 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -252,6 +252,9 @@ type WritableDynamicBytesSource interface {
// are backed by a bytes.Buffer that is regenerated when necessary, consistent
// with Linux's fs/seq_file.c:single_open().
//
+// If data additionally implements WritableDynamicBytesSource, writes are
+// dispatched to the implementer. The source data is not automatically modified.
+//
// DynamicBytesFileDescriptionImpl.SetDataSource() must be called before first
// use.
//
diff --git a/pkg/sentry/vfs/filesystem_impl_util.go b/pkg/sentry/vfs/filesystem_impl_util.go
index 2620cf975..15b234d61 100644
--- a/pkg/sentry/vfs/filesystem_impl_util.go
+++ b/pkg/sentry/vfs/filesystem_impl_util.go
@@ -18,7 +18,7 @@ import (
"strings"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// GenericParseMountOptions parses a comma-separated list of options of the
@@ -50,7 +50,7 @@ func GenericParseMountOptions(str string) map[string]string {
func GenericStatFS(fsMagic uint64) linux.Statfs {
return linux.Statfs{
Type: fsMagic,
- BlockSize: usermem.PageSize,
+ BlockSize: hostarch.PageSize,
NameLength: linux.NAME_MAX,
}
}
diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go
index 32fa01578..49d29e20b 100644
--- a/pkg/sentry/vfs/inotify.go
+++ b/pkg/sentry/vfs/inotify.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/sync"
@@ -256,7 +257,7 @@ func (i *Inotify) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallAr
n += uint32(e.sizeOf())
}
var buf [4]byte
- usermem.ByteOrder.PutUint32(buf[:], n)
+ hostarch.ByteOrder.PutUint32(buf[:], n)
_, err := uio.CopyOut(ctx, args[2].Pointer(), buf[:], usermem.IOOpts{})
return 0, err
@@ -683,10 +684,10 @@ func (e *Event) sizeOf() int {
// construct the output. We use a buffer allocated ahead of time for
// performance. buf must be at least inotifyEventBaseSize bytes.
func (e *Event) CopyTo(ctx context.Context, buf []byte, dst usermem.IOSequence) (int64, error) {
- usermem.ByteOrder.PutUint32(buf[0:], uint32(e.wd))
- usermem.ByteOrder.PutUint32(buf[4:], e.mask)
- usermem.ByteOrder.PutUint32(buf[8:], e.cookie)
- usermem.ByteOrder.PutUint32(buf[12:], e.len)
+ hostarch.ByteOrder.PutUint32(buf[0:], uint32(e.wd))
+ hostarch.ByteOrder.PutUint32(buf[4:], e.mask)
+ hostarch.ByteOrder.PutUint32(buf[8:], e.cookie)
+ hostarch.ByteOrder.PutUint32(buf[12:], e.len)
writeLen := 0
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 922f9e697..7cdab6945 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -970,17 +970,22 @@ func superBlockOpts(mountPath string, mnt *Mount) string {
opts += "," + mopts
}
- // NOTE(b/147673608): If the mount is a cgroup, we also need to include
- // the cgroup name in the options. For now we just read that from the
- // path.
+ // NOTE(b/147673608): If the mount is a ramdisk-based fake cgroupfs, we also
+ // need to include the cgroup name in the options. For now we just read that
+ // from the path. Note that this is only possible when "cgroup" isn't
+ // registered as a valid filesystem type.
//
- // TODO(gvisor.dev/issue/190): Once gVisor has full cgroup support, we
- // should get this value from the cgroup itself, and not rely on the
- // path.
+ // TODO(gvisor.dev/issue/190): Once we removed fake cgroupfs support, we
+ // should remove this.
+ if cgroupfs := mnt.vfs.getFilesystemType("cgroup"); cgroupfs != nil && cgroupfs.opts.AllowUserMount {
+ // Real cgroupfs available.
+ return opts
+ }
if mnt.fs.FilesystemType().Name() == "cgroup" {
splitPath := strings.Split(mountPath, "/")
cgroupType := splitPath[len(splitPath)-1]
opts += "," + cgroupType
}
+
return opts
}
diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD
index b2c5229e7..8b3a11c64 100644
--- a/pkg/sync/BUILD
+++ b/pkg/sync/BUILD
@@ -43,6 +43,7 @@ go_template(
],
deps = [
":sync",
+ "//pkg/gohacks",
],
)
diff --git a/pkg/sync/generic_seqatomic_unsafe.go b/pkg/sync/generic_seqatomic_unsafe.go
index 82b676abf..9578c9c52 100644
--- a/pkg/sync/generic_seqatomic_unsafe.go
+++ b/pkg/sync/generic_seqatomic_unsafe.go
@@ -10,6 +10,7 @@ package seqatomic
import (
"unsafe"
+ "gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -39,7 +40,7 @@ func SeqAtomicTryLoad(seq *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value)
// runtime.RaceDisable() doesn't actually stop the race detector, so it
// can't help us here. Instead, call runtime.memmove directly, which is
// not instrumented by the race detector.
- sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
+ gohacks.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
} else {
// This is ~40% faster for short reads than going through memmove.
val = *ptr
diff --git a/pkg/sync/runtime_unsafe.go b/pkg/sync/runtime_unsafe.go
index 158985709..39c766331 100644
--- a/pkg/sync/runtime_unsafe.go
+++ b/pkg/sync/runtime_unsafe.go
@@ -17,20 +17,6 @@ import (
"unsafe"
)
-// Note that go:linkname silently doesn't work if the local name is exported,
-// necessitating an indirection for exported functions.
-
-// Memmove is runtime.memmove, exported for SeqAtomicLoad/SeqAtomicTryLoad<T>.
-//
-//go:nosplit
-func Memmove(to, from unsafe.Pointer, n uintptr) {
- memmove(to, from, n)
-}
-
-//go:linkname memmove runtime.memmove
-//go:noescape
-func memmove(to, from unsafe.Pointer, n uintptr)
-
// Gopark is runtime.gopark. Gopark calls unlockf(pointer to runtime.g, lock);
// if unlockf returns true, Gopark blocks until Goready(pointer to runtime.g)
// is called. unlockf and its callees must be nosplit and norace, since stack
diff --git a/pkg/sync/seqatomictest/BUILD b/pkg/sync/seqatomictest/BUILD
index 5c38c783e..5f9164117 100644
--- a/pkg/sync/seqatomictest/BUILD
+++ b/pkg/sync/seqatomictest/BUILD
@@ -18,6 +18,7 @@ go_library(
name = "seqatomic",
srcs = ["seqatomic_int_unsafe.go"],
deps = [
+ "//pkg/gohacks",
"//pkg/sync",
],
)
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index f979d22f0..aa30cfc85 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -1,4 +1,5 @@
load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:deps.bzl", "deps_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
@@ -33,6 +34,36 @@ go_library(
],
)
+deps_test(
+ name = "netstack_deps_test",
+ allowed = [
+ "@com_github_google_btree//:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ "@org_golang_x_time//rate:go_default_library",
+ ],
+ allowed_prefixes = [
+ "//",
+ "@org_golang_x_sys//internal/unsafeheader",
+ ],
+ targets = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/fdbased",
+ "//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/link/packetsocket",
+ "//pkg/tcpip/link/qdisc/fifo",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/arp",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/raw",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ ],
+)
+
go_test(
name = "tcpip_test",
size = "small",
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index fef065b05..12c39dfa3 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -53,9 +53,8 @@ func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) {
t.Error("Not a valid IPv4 packet")
}
- xsum := ipv4.CalculateChecksum()
- if xsum != 0 && xsum != 0xffff {
- t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum())
+ if !ipv4.IsChecksumValid() {
+ t.Errorf("Bad checksum, got = %d", ipv4.Checksum())
}
for _, f := range checkers {
@@ -400,18 +399,11 @@ func TCP(checkers ...TransportChecker) NetworkChecker {
t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber)
}
- // Verify the checksum.
tcp := header.TCP(last.Payload())
- l := uint16(len(tcp))
-
- xsum := header.Checksum([]byte(first.SourceAddress()), 0)
- xsum = header.Checksum([]byte(first.DestinationAddress()), xsum)
- xsum = header.Checksum([]byte{0, byte(last.TransportProtocol())}, xsum)
- xsum = header.Checksum([]byte{byte(l >> 8), byte(l)}, xsum)
- xsum = header.Checksum(tcp, xsum)
-
- if xsum != 0 && xsum != 0xffff {
- t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum())
+ payload := tcp.Payload()
+ payloadChecksum := header.Checksum(payload, 0)
+ if !tcp.IsChecksumValid(first.SourceAddress(), first.DestinationAddress(), payloadChecksum, uint16(len(payload))) {
+ t.Errorf("Bad checksum, got = %d", tcp.Checksum())
}
// Run the transport checkers.
diff --git a/pkg/tcpip/hash/jenkins/jenkins.go b/pkg/tcpip/hash/jenkins/jenkins.go
index 52c22230e..33ff22a7b 100644
--- a/pkg/tcpip/hash/jenkins/jenkins.go
+++ b/pkg/tcpip/hash/jenkins/jenkins.go
@@ -42,26 +42,26 @@ func (s *Sum32) Reset() { *s = 0 }
// Sum32 returns the hash value
func (s *Sum32) Sum32() uint32 {
- hash := *s
+ sCopy := *s
- hash += (hash << 3)
- hash ^= hash >> 11
- hash += hash << 15
+ sCopy += sCopy << 3
+ sCopy ^= sCopy >> 11
+ sCopy += sCopy << 15
- return uint32(hash)
+ return uint32(sCopy)
}
// Write adds more data to the running hash.
//
// It never returns an error.
func (s *Sum32) Write(data []byte) (int, error) {
- hash := *s
+ sCopy := *s
for _, b := range data {
- hash += Sum32(b)
- hash += hash << 10
- hash ^= hash >> 6
+ sCopy += Sum32(b)
+ sCopy += sCopy << 10
+ sCopy ^= sCopy >> 6
}
- *s = hash
+ *s = sCopy
return len(data), nil
}
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index 0bdc12d53..01240f5d0 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -52,6 +52,7 @@ go_test(
"//pkg/rand",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/testutil",
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
@@ -69,6 +70,7 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/testutil",
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
index 3bc8b2b21..bf9ccbf1a 100644
--- a/pkg/tcpip/header/eth_test.go
+++ b/pkg/tcpip/header/eth_test.go
@@ -18,6 +18,7 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
func TestIsValidUnicastEthernetAddress(t *testing.T) {
@@ -142,7 +143,7 @@ func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) {
}
func TestEthernetAddressFromMulticastIPv6Address(t *testing.T) {
- addr := tcpip.Address("\xff\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x1a")
+ addr := testutil.MustParse6("ff02:304:506:708:90a:b0c:d0e:f1a")
if got, want := EthernetAddressFromMulticastIPv6Address(addr), tcpip.LinkAddress("\x33\x33\x0d\x0e\x0f\x1a"); got != want {
t.Fatalf("got EthernetAddressFromMulticastIPv6Address(%s) = %s, want = %s", addr, got, want)
}
diff --git a/pkg/tcpip/header/igmp_test.go b/pkg/tcpip/header/igmp_test.go
index b6126d29a..575604928 100644
--- a/pkg/tcpip/header/igmp_test.go
+++ b/pkg/tcpip/header/igmp_test.go
@@ -18,8 +18,8 @@ import (
"testing"
"time"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
// TestIGMPHeader tests the functions within header.igmp
@@ -46,7 +46,7 @@ func TestIGMPHeader(t *testing.T) {
t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, want)
}
- if got, want := igmpHeader.GroupAddress(), tcpip.Address("\x01\x02\x03\x04"); got != want {
+ if got, want := igmpHeader.GroupAddress(), testutil.MustParse4("1.2.3.4"); got != want {
t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, want)
}
@@ -71,7 +71,7 @@ func TestIGMPHeader(t *testing.T) {
t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, checksum)
}
- groupAddress := tcpip.Address("\x04\x03\x02\x01")
+ groupAddress := testutil.MustParse4("4.3.2.1")
igmpHeader.SetGroupAddress(groupAddress)
if got := igmpHeader.GroupAddress(); got != groupAddress {
t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, groupAddress)
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index f588311e0..2be21ec75 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -178,6 +178,26 @@ const (
IPv4FlagDontFragment
)
+// ipv4LinkLocalUnicastSubnet is the IPv4 link local unicast subnet as defined
+// by RFC 3927 section 1.
+var ipv4LinkLocalUnicastSubnet = func() tcpip.Subnet {
+ subnet, err := tcpip.NewSubnet("\xa9\xfe\x00\x00", tcpip.AddressMask("\xff\xff\x00\x00"))
+ if err != nil {
+ panic(err)
+ }
+ return subnet
+}()
+
+// ipv4LinkLocalMulticastSubnet is the IPv4 link local multicast subnet as
+// defined by RFC 5771 section 4.
+var ipv4LinkLocalMulticastSubnet = func() tcpip.Subnet {
+ subnet, err := tcpip.NewSubnet("\xe0\x00\x00\x00", tcpip.AddressMask("\xff\xff\xff\x00"))
+ if err != nil {
+ panic(err)
+ }
+ return subnet
+}()
+
// IPv4EmptySubnet is the empty IPv4 subnet.
var IPv4EmptySubnet = func() tcpip.Subnet {
subnet, err := tcpip.NewSubnet(IPv4Any, tcpip.AddressMask(IPv4Any))
@@ -423,6 +443,44 @@ func (b IPv4) IsValid(pktSize int) bool {
return true
}
+// IsV4LinkLocalUnicastAddress determines if the provided address is an IPv4
+// link-local unicast address.
+func IsV4LinkLocalUnicastAddress(addr tcpip.Address) bool {
+ return ipv4LinkLocalUnicastSubnet.Contains(addr)
+}
+
+// IsV4LinkLocalMulticastAddress determines if the provided address is an IPv4
+// link-local multicast address.
+func IsV4LinkLocalMulticastAddress(addr tcpip.Address) bool {
+ return ipv4LinkLocalMulticastSubnet.Contains(addr)
+}
+
+// IsChecksumValid returns true iff the IPv4 header's checksum is valid.
+func (b IPv4) IsChecksumValid() bool {
+ // There has been some confusion regarding verifying checksums. We need
+ // just look for negative 0 (0xffff) as the checksum, as it's not possible to
+ // get positive 0 (0) for the checksum. Some bad implementations could get it
+ // when doing entry replacement in the early days of the Internet,
+ // however the lore that one needs to check for both persists.
+ //
+ // RFC 1624 section 1 describes the source of this confusion as:
+ // [the partial recalculation method described in RFC 1071] computes a
+ // result for certain cases that differs from the one obtained from
+ // scratch (one's complement of one's complement sum of the original
+ // fields).
+ //
+ // However RFC 1624 section 5 clarifies that if using the verification method
+ // "recommended by RFC 1071, it does not matter if an intermediate system
+ // generated a -0 instead of +0".
+ //
+ // RFC1071 page 1 specifies the verification method as:
+ // (3) To check a checksum, the 1's complement sum is computed over the
+ // same set of octets, including the checksum field. If the result
+ // is all 1 bits (-0 in 1's complement arithmetic), the check
+ // succeeds.
+ return b.CalculateChecksum() == 0xffff
+}
+
// IsV4MulticastAddress determines if the provided address is an IPv4 multicast
// address (range 224.0.0.0 to 239.255.255.255). The four most significant bits
// will be 1110 = 0xe0.
diff --git a/pkg/tcpip/header/ipv4_test.go b/pkg/tcpip/header/ipv4_test.go
index 6475cd694..c02fe898b 100644
--- a/pkg/tcpip/header/ipv4_test.go
+++ b/pkg/tcpip/header/ipv4_test.go
@@ -18,6 +18,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -177,3 +178,77 @@ func TestIPv4EncodeOptions(t *testing.T) {
})
}
}
+
+func TestIsV4LinkLocalUnicastAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expected bool
+ }{
+ {
+ name: "Valid (lowest)",
+ addr: "\xa9\xfe\x00\x00",
+ expected: true,
+ },
+ {
+ name: "Valid (highest)",
+ addr: "\xa9\xfe\xff\xff",
+ expected: true,
+ },
+ {
+ name: "Invalid (before subnet)",
+ addr: "\xa9\xfd\xff\xff",
+ expected: false,
+ },
+ {
+ name: "Invalid (after subnet)",
+ addr: "\xa9\xff\x00\x00",
+ expected: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := header.IsV4LinkLocalUnicastAddress(test.addr); got != test.expected {
+ t.Errorf("got header.IsV4LinkLocalUnicastAddress(%s) = %t, want = %t", test.addr, got, test.expected)
+ }
+ })
+ }
+}
+
+func TestIsV4LinkLocalMulticastAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expected bool
+ }{
+ {
+ name: "Valid (lowest)",
+ addr: "\xe0\x00\x00\x00",
+ expected: true,
+ },
+ {
+ name: "Valid (highest)",
+ addr: "\xe0\x00\x00\xff",
+ expected: true,
+ },
+ {
+ name: "Invalid (before subnet)",
+ addr: "\xdf\xff\xff\xff",
+ expected: false,
+ },
+ {
+ name: "Invalid (after subnet)",
+ addr: "\xe0\x00\x01\x00",
+ expected: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := header.IsV4LinkLocalMulticastAddress(test.addr); got != test.expected {
+ t.Errorf("got header.IsV4LinkLocalMulticastAddress(%s) = %t, want = %t", test.addr, got, test.expected)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index f2403978c..c3a0407ac 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -98,12 +98,27 @@ const (
// The address is ff02::1.
IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- // IPv6AllRoutersMulticastAddress is a link-local multicast group that
- // all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets
+ // IPv6AllRoutersInterfaceLocalMulticastAddress is an interface-local
+ // multicast group that all IPv6 routers MUST join, as per RFC 4291, section
+ // 2.8. Packets destined to this address will reach the router on an
+ // interface.
+ //
+ // The address is ff01::2.
+ IPv6AllRoutersInterfaceLocalMulticastAddress tcpip.Address = "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+
+ // IPv6AllRoutersLinkLocalMulticastAddress is a link-local multicast group
+ // that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets
// destined to this address will reach all routers on a link.
//
// The address is ff02::2.
- IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ IPv6AllRoutersLinkLocalMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+
+ // IPv6AllRoutersSiteLocalMulticastAddress is a site-local multicast group
+ // that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets
+ // destined to this address will reach all routers in a site.
+ //
+ // The address is ff05::2.
+ IPv6AllRoutersSiteLocalMulticastAddress tcpip.Address = "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
// IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 8200,
// section 5:
@@ -142,11 +157,6 @@ const (
// ipv6MulticastAddressScopeMask is the mask for the scope (scop) field,
// within the byte holding the field, as per RFC 4291 section 2.7.
ipv6MulticastAddressScopeMask = 0xF
-
- // ipv6LinkLocalMulticastScope is the value of the scope (scop) field within
- // a multicast IPv6 address that indicates the address has link-local scope,
- // as per RFC 4291 section 2.7.
- ipv6LinkLocalMulticastScope = 2
)
// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the
@@ -381,25 +391,25 @@ func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address {
return tcpip.Address(lladdrb[:])
}
-// IsV6LinkLocalAddress determines if the provided address is an IPv6
-// link-local address (fe80::/10).
-func IsV6LinkLocalAddress(addr tcpip.Address) bool {
+// IsV6LinkLocalUnicastAddress returns true iff the provided address is an IPv6
+// link-local unicast address, as defined by RFC 4291 section 2.5.6.
+func IsV6LinkLocalUnicastAddress(addr tcpip.Address) bool {
if len(addr) != IPv6AddressSize {
return false
}
return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80
}
-// IsV6LoopbackAddress determines if the provided address is an IPv6 loopback
-// address.
+// IsV6LoopbackAddress returns true iff the provided address is an IPv6 loopback
+// address, as defined by RFC 4291 section 2.5.3.
func IsV6LoopbackAddress(addr tcpip.Address) bool {
return addr == IPv6Loopback
}
-// IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6
-// link-local multicast address.
+// IsV6LinkLocalMulticastAddress returns true iff the provided address is an
+// IPv6 link-local multicast address, as defined by RFC 4291 section 2.7.
func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool {
- return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope
+ return IsV6MulticastAddress(addr) && V6MulticastScope(addr) == IPv6LinkLocalMulticastScope
}
// AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier
@@ -462,7 +472,7 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, tcpip.Error) {
case IsV6LinkLocalMulticastAddress(addr):
return LinkLocalScope, nil
- case IsV6LinkLocalAddress(addr):
+ case IsV6LinkLocalUnicastAddress(addr):
return LinkLocalScope, nil
default:
@@ -520,3 +530,46 @@ func GenerateTempIPv6SLAACAddr(tempIIDHistory []byte, stableAddr tcpip.Address)
PrefixLen: IIDOffsetInIPv6Address * 8,
}
}
+
+// IPv6MulticastScope is the scope of a multicast IPv6 address, as defined by
+// RFC 7346 section 2.
+type IPv6MulticastScope uint8
+
+// The various values for IPv6 multicast scopes, as per RFC 7346 section 2:
+//
+// +------+--------------------------+-------------------------+
+// | scop | NAME | REFERENCE |
+// +------+--------------------------+-------------------------+
+// | 0 | Reserved | [RFC4291], RFC 7346 |
+// | 1 | Interface-Local scope | [RFC4291], RFC 7346 |
+// | 2 | Link-Local scope | [RFC4291], RFC 7346 |
+// | 3 | Realm-Local scope | [RFC4291], RFC 7346 |
+// | 4 | Admin-Local scope | [RFC4291], RFC 7346 |
+// | 5 | Site-Local scope | [RFC4291], RFC 7346 |
+// | 6 | Unassigned | |
+// | 7 | Unassigned | |
+// | 8 | Organization-Local scope | [RFC4291], RFC 7346 |
+// | 9 | Unassigned | |
+// | A | Unassigned | |
+// | B | Unassigned | |
+// | C | Unassigned | |
+// | D | Unassigned | |
+// | E | Global scope | [RFC4291], RFC 7346 |
+// | F | Reserved | [RFC4291], RFC 7346 |
+// +------+--------------------------+-------------------------+
+const (
+ IPv6Reserved0MulticastScope = IPv6MulticastScope(0x0)
+ IPv6InterfaceLocalMulticastScope = IPv6MulticastScope(0x1)
+ IPv6LinkLocalMulticastScope = IPv6MulticastScope(0x2)
+ IPv6RealmLocalMulticastScope = IPv6MulticastScope(0x3)
+ IPv6AdminLocalMulticastScope = IPv6MulticastScope(0x4)
+ IPv6SiteLocalMulticastScope = IPv6MulticastScope(0x5)
+ IPv6OrganizationLocalMulticastScope = IPv6MulticastScope(0x8)
+ IPv6GlobalMulticastScope = IPv6MulticastScope(0xE)
+ IPv6ReservedFMulticastScope = IPv6MulticastScope(0xF)
+)
+
+// V6MulticastScope returns the scope of a multicast address.
+func V6MulticastScope(addr tcpip.Address) IPv6MulticastScope {
+ return IPv6MulticastScope(addr[ipv6MulticastAddressScopeByteIdx] & ipv6MulticastAddressScopeMask)
+}
diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go
index f10f446a6..89be84068 100644
--- a/pkg/tcpip/header/ipv6_test.go
+++ b/pkg/tcpip/header/ipv6_test.go
@@ -24,15 +24,17 @@ import (
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
-const (
- linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+const linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+
+var (
+ linkLocalAddr = testutil.MustParse6("fe80::1")
+ linkLocalMulticastAddr = testutil.MustParse6("ff02::1")
+ uniqueLocalAddr1 = testutil.MustParse6("fc00::1")
+ uniqueLocalAddr2 = testutil.MustParse6("fd00::2")
+ globalAddr = testutil.MustParse6("a000::1")
)
func TestEthernetAdddressToModifiedEUI64(t *testing.T) {
@@ -50,7 +52,7 @@ func TestEthernetAdddressToModifiedEUI64(t *testing.T) {
}
func TestLinkLocalAddr(t *testing.T) {
- if got, want := header.LinkLocalAddr(linkAddr), tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x02\x03\xff\xfe\x04\x05\x06"); got != want {
+ if got, want := header.LinkLocalAddr(linkAddr), testutil.MustParse6("fe80::2:3ff:fe04:506"); got != want {
t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want)
}
}
@@ -252,7 +254,7 @@ func TestIsV6LinkLocalMulticastAddress(t *testing.T) {
}
}
-func TestIsV6LinkLocalAddress(t *testing.T) {
+func TestIsV6LinkLocalUnicastAddress(t *testing.T) {
tests := []struct {
name string
addr tcpip.Address
@@ -287,8 +289,8 @@ func TestIsV6LinkLocalAddress(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- if got := header.IsV6LinkLocalAddress(test.addr); got != test.expected {
- t.Errorf("got header.IsV6LinkLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected)
+ if got := header.IsV6LinkLocalUnicastAddress(test.addr); got != test.expected {
+ t.Errorf("got header.IsV6LinkLocalUnicastAddress(%s) = %t, want = %t", test.addr, got, test.expected)
}
})
}
@@ -373,3 +375,83 @@ func TestSolicitedNodeAddr(t *testing.T) {
})
}
}
+
+func TestV6MulticastScope(t *testing.T) {
+ tests := []struct {
+ addr tcpip.Address
+ want header.IPv6MulticastScope
+ }{
+ {
+ addr: "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6Reserved0MulticastScope,
+ },
+ {
+ addr: "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6InterfaceLocalMulticastScope,
+ },
+ {
+ addr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6LinkLocalMulticastScope,
+ },
+ {
+ addr: "\xff\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6RealmLocalMulticastScope,
+ },
+ {
+ addr: "\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6AdminLocalMulticastScope,
+ },
+ {
+ addr: "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6SiteLocalMulticastScope,
+ },
+ {
+ addr: "\xff\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6MulticastScope(6),
+ },
+ {
+ addr: "\xff\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6MulticastScope(7),
+ },
+ {
+ addr: "\xff\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6OrganizationLocalMulticastScope,
+ },
+ {
+ addr: "\xff\x09\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6MulticastScope(9),
+ },
+ {
+ addr: "\xff\x0a\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6MulticastScope(10),
+ },
+ {
+ addr: "\xff\x0b\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6MulticastScope(11),
+ },
+ {
+ addr: "\xff\x0c\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6MulticastScope(12),
+ },
+ {
+ addr: "\xff\x0d\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6MulticastScope(13),
+ },
+ {
+ addr: "\xff\x0e\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6GlobalMulticastScope,
+ },
+ {
+ addr: "\xff\x0f\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6ReservedFMulticastScope,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(fmt.Sprintf("%s", test.addr), func(t *testing.T) {
+ if got := header.V6MulticastScope(test.addr); got != test.want {
+ t.Fatalf("got header.V6MulticastScope(%s) = %d, want = %d", test.addr, got, test.want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
index d0a1a2492..1b5093e58 100644
--- a/pkg/tcpip/header/ndp_test.go
+++ b/pkg/tcpip/header/ndp_test.go
@@ -26,6 +26,7 @@ import (
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
// TestNDPNeighborSolicit tests the functions of NDPNeighborSolicit.
@@ -40,13 +41,13 @@ func TestNDPNeighborSolicit(t *testing.T) {
// Test getting the Target Address.
ns := NDPNeighborSolicit(b)
- addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")
+ addr := testutil.MustParse6("102:304:506:708:90a:b0c:d0e:f10")
if got := ns.TargetAddress(); got != addr {
t.Errorf("got ns.TargetAddress = %s, want %s", got, addr)
}
// Test updating the Target Address.
- addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11")
+ addr2 := testutil.MustParse6("1112:1314:1516:1718:191a:1b1c:1d1e:1f11")
ns.SetTargetAddress(addr2)
if got := ns.TargetAddress(); got != addr2 {
t.Errorf("got ns.TargetAddress = %s, want %s", got, addr2)
@@ -69,7 +70,7 @@ func TestNDPNeighborAdvert(t *testing.T) {
// Test getting the Target Address.
na := NDPNeighborAdvert(b)
- addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")
+ addr := testutil.MustParse6("102:304:506:708:90a:b0c:d0e:f10")
if got := na.TargetAddress(); got != addr {
t.Errorf("got TargetAddress = %s, want %s", got, addr)
}
@@ -90,7 +91,7 @@ func TestNDPNeighborAdvert(t *testing.T) {
}
// Test updating the Target Address.
- addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11")
+ addr2 := testutil.MustParse6("1112:1314:1516:1718:191a:1b1c:1d1e:1f11")
na.SetTargetAddress(addr2)
if got := na.TargetAddress(); got != addr2 {
t.Errorf("got TargetAddress = %s, want %s", got, addr2)
@@ -277,7 +278,7 @@ func TestOpts(t *testing.T) {
}
const validLifetimeSeconds = 16909060
- const address = tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18")
+ address := testutil.MustParse6("90a:b0c:d0e:f10:1112:1314:1516:1718")
expectedRDNSSBytes := [...]byte{
// Type, Length
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
index adc835d30..0df517000 100644
--- a/pkg/tcpip/header/tcp.go
+++ b/pkg/tcpip/header/tcp.go
@@ -216,104 +216,104 @@ const (
TCPDefaultMSS = 536
)
-// SourcePort returns the "source port" field of the tcp header.
+// SourcePort returns the "source port" field of the TCP header.
func (b TCP) SourcePort() uint16 {
return binary.BigEndian.Uint16(b[TCPSrcPortOffset:])
}
-// DestinationPort returns the "destination port" field of the tcp header.
+// DestinationPort returns the "destination port" field of the TCP header.
func (b TCP) DestinationPort() uint16 {
return binary.BigEndian.Uint16(b[TCPDstPortOffset:])
}
-// SequenceNumber returns the "sequence number" field of the tcp header.
+// SequenceNumber returns the "sequence number" field of the TCP header.
func (b TCP) SequenceNumber() uint32 {
return binary.BigEndian.Uint32(b[TCPSeqNumOffset:])
}
-// AckNumber returns the "ack number" field of the tcp header.
+// AckNumber returns the "ack number" field of the TCP header.
func (b TCP) AckNumber() uint32 {
return binary.BigEndian.Uint32(b[TCPAckNumOffset:])
}
-// DataOffset returns the "data offset" field of the tcp header. The return
+// DataOffset returns the "data offset" field of the TCP header. The return
// value is the length of the TCP header in bytes.
func (b TCP) DataOffset() uint8 {
return (b[TCPDataOffset] >> 4) * 4
}
-// Payload returns the data in the tcp packet.
+// Payload returns the data in the TCP packet.
func (b TCP) Payload() []byte {
return b[b.DataOffset():]
}
-// Flags returns the flags field of the tcp header.
+// Flags returns the flags field of the TCP header.
func (b TCP) Flags() TCPFlags {
return TCPFlags(b[TCPFlagsOffset])
}
-// WindowSize returns the "window size" field of the tcp header.
+// WindowSize returns the "window size" field of the TCP header.
func (b TCP) WindowSize() uint16 {
return binary.BigEndian.Uint16(b[TCPWinSizeOffset:])
}
-// Checksum returns the "checksum" field of the tcp header.
+// Checksum returns the "checksum" field of the TCP header.
func (b TCP) Checksum() uint16 {
return binary.BigEndian.Uint16(b[TCPChecksumOffset:])
}
-// UrgentPointer returns the "urgent pointer" field of the tcp header.
+// UrgentPointer returns the "urgent pointer" field of the TCP header.
func (b TCP) UrgentPointer() uint16 {
return binary.BigEndian.Uint16(b[TCPUrgentPtrOffset:])
}
-// SetSourcePort sets the "source port" field of the tcp header.
+// SetSourcePort sets the "source port" field of the TCP header.
func (b TCP) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], port)
}
-// SetDestinationPort sets the "destination port" field of the tcp header.
+// SetDestinationPort sets the "destination port" field of the TCP header.
func (b TCP) SetDestinationPort(port uint16) {
binary.BigEndian.PutUint16(b[TCPDstPortOffset:], port)
}
-// SetChecksum sets the checksum field of the tcp header.
+// SetChecksum sets the checksum field of the TCP header.
func (b TCP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[TCPChecksumOffset:], checksum)
}
-// SetDataOffset sets the data offset field of the tcp header. headerLen should
+// SetDataOffset sets the data offset field of the TCP header. headerLen should
// be the length of the TCP header in bytes.
func (b TCP) SetDataOffset(headerLen uint8) {
b[TCPDataOffset] = (headerLen / 4) << 4
}
-// SetSequenceNumber sets the sequence number field of the tcp header.
+// SetSequenceNumber sets the sequence number field of the TCP header.
func (b TCP) SetSequenceNumber(seqNum uint32) {
binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seqNum)
}
-// SetAckNumber sets the ack number field of the tcp header.
+// SetAckNumber sets the ack number field of the TCP header.
func (b TCP) SetAckNumber(ackNum uint32) {
binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ackNum)
}
-// SetFlags sets the flags field of the tcp header.
+// SetFlags sets the flags field of the TCP header.
func (b TCP) SetFlags(flags uint8) {
b[TCPFlagsOffset] = flags
}
-// SetWindowSize sets the window size field of the tcp header.
+// SetWindowSize sets the window size field of the TCP header.
func (b TCP) SetWindowSize(rcvwnd uint16) {
binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd)
}
-// SetUrgentPoiner sets the window size field of the tcp header.
+// SetUrgentPoiner sets the window size field of the TCP header.
func (b TCP) SetUrgentPoiner(urgentPointer uint16) {
binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], urgentPointer)
}
-// CalculateChecksum calculates the checksum of the tcp segment.
+// CalculateChecksum calculates the checksum of the TCP segment.
// partialChecksum is the checksum of the network-layer pseudo-header
// and the checksum of the segment data.
func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 {
@@ -321,6 +321,13 @@ func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 {
return Checksum(b[:b.DataOffset()], partialChecksum)
}
+// IsChecksumValid returns true iff the TCP header's checksum is valid.
+func (b TCP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum, payloadLength uint16) bool {
+ xsum := PseudoHeaderChecksum(TCPProtocolNumber, src, dst, uint16(b.DataOffset())+payloadLength)
+ xsum = ChecksumCombine(xsum, payloadChecksum)
+ return b.CalculateChecksum(xsum) == 0xffff
+}
+
// Options returns a slice that holds the unparsed TCP options in the segment.
func (b TCP) Options() []byte {
return b[TCPMinimumSize:b.DataOffset()]
@@ -340,7 +347,7 @@ func (b TCP) encodeSubset(seq, ack uint32, flags TCPFlags, rcvwnd uint16) {
binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd)
}
-// Encode encodes all the fields of the tcp header.
+// Encode encodes all the fields of the TCP header.
func (b TCP) Encode(t *TCPFields) {
b.encodeSubset(t.SeqNum, t.AckNum, t.Flags, t.WindowSize)
binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], t.SrcPort)
@@ -350,7 +357,7 @@ func (b TCP) Encode(t *TCPFields) {
binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], t.UrgentPointer)
}
-// EncodePartial updates a subset of the fields of the tcp header. It is useful
+// EncodePartial updates a subset of the fields of the TCP header. It is useful
// in cases when similar segments are produced.
func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags TCPFlags, rcvwnd uint16) {
// Add the total length and "flags" field contributions to the checksum.
@@ -374,7 +381,7 @@ func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32
}
// ParseSynOptions parses the options received in a SYN segment and returns the
-// relevant ones. opts should point to the option part of the TCP Header.
+// relevant ones. opts should point to the option part of the TCP header.
func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions {
limit := len(opts)
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index 98bdd29db..ae9d167ff 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -64,17 +64,17 @@ const (
UDPProtocolNumber tcpip.TransportProtocolNumber = 17
)
-// SourcePort returns the "source port" field of the udp header.
+// SourcePort returns the "source port" field of the UDP header.
func (b UDP) SourcePort() uint16 {
return binary.BigEndian.Uint16(b[udpSrcPort:])
}
-// DestinationPort returns the "destination port" field of the udp header.
+// DestinationPort returns the "destination port" field of the UDP header.
func (b UDP) DestinationPort() uint16 {
return binary.BigEndian.Uint16(b[udpDstPort:])
}
-// Length returns the "length" field of the udp header.
+// Length returns the "length" field of the UDP header.
func (b UDP) Length() uint16 {
return binary.BigEndian.Uint16(b[udpLength:])
}
@@ -84,39 +84,46 @@ func (b UDP) Payload() []byte {
return b[UDPMinimumSize:]
}
-// Checksum returns the "checksum" field of the udp header.
+// Checksum returns the "checksum" field of the UDP header.
func (b UDP) Checksum() uint16 {
return binary.BigEndian.Uint16(b[udpChecksum:])
}
-// SetSourcePort sets the "source port" field of the udp header.
+// SetSourcePort sets the "source port" field of the UDP header.
func (b UDP) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(b[udpSrcPort:], port)
}
-// SetDestinationPort sets the "destination port" field of the udp header.
+// SetDestinationPort sets the "destination port" field of the UDP header.
func (b UDP) SetDestinationPort(port uint16) {
binary.BigEndian.PutUint16(b[udpDstPort:], port)
}
-// SetChecksum sets the "checksum" field of the udp header.
+// SetChecksum sets the "checksum" field of the UDP header.
func (b UDP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
}
-// SetLength sets the "length" field of the udp header.
+// SetLength sets the "length" field of the UDP header.
func (b UDP) SetLength(length uint16) {
binary.BigEndian.PutUint16(b[udpLength:], length)
}
-// CalculateChecksum calculates the checksum of the udp packet, given the
+// CalculateChecksum calculates the checksum of the UDP packet, given the
// checksum of the network-layer pseudo-header and the checksum of the payload.
func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
// Calculate the rest of the checksum.
return Checksum(b[:UDPMinimumSize], partialChecksum)
}
-// Encode encodes all the fields of the udp header.
+// IsChecksumValid returns true iff the UDP header's checksum is valid.
+func (b UDP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum uint16) bool {
+ xsum := PseudoHeaderChecksum(UDPProtocolNumber, dst, src, b.Length())
+ xsum = ChecksumCombine(xsum, payloadChecksum)
+ return b.CalculateChecksum(xsum) == 0xffff
+}
+
+// Encode encodes all the fields of the UDP header.
func (b UDP) Encode(u *UDPFields) {
binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort)
binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort)
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index fa8814bac..7b1ff44f4 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -21,6 +21,7 @@ go_test(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index d59d678b2..6905b9ccb 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -33,6 +33,7 @@ go_test(
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"@com_github_google_go_cmp//cmp:go_default_library",
"@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 018d6a578..9b3714f9e 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -30,20 +30,16 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
)
const (
nicID = 1
- stackAddr = tcpip.Address("\x0a\x00\x00\x01")
- stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
-
- remoteAddr = tcpip.Address("\x0a\x00\x00\x02")
+ stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06")
- unknownAddr = tcpip.Address("\x0a\x00\x00\x03")
-
defaultChannelSize = 1
defaultMTU = 65536
@@ -54,6 +50,12 @@ const (
eventChanSize = 32
)
+var (
+ stackAddr = testutil.MustParse4("10.0.0.1")
+ remoteAddr = testutil.MustParse4("10.0.0.2")
+ unknownAddr = testutil.MustParse4("10.0.0.3")
+)
+
type eventType uint8
const (
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
index b9f129728..ac35d81e7 100644
--- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
+++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
@@ -156,14 +156,6 @@ type GenericMulticastProtocolOptions struct {
//
// Unsolicited reports are transmitted when a group is newly joined.
MaxUnsolicitedReportDelay time.Duration
-
- // AllNodesAddress is a multicast address that all nodes on a network should
- // be a member of.
- //
- // This address will not have the generic multicast protocol performed on it;
- // it will be left in the non member/listener state, and packets will never
- // be sent for it.
- AllNodesAddress tcpip.Address
}
// MulticastGroupProtocol is a multicast group protocol whose core state machine
@@ -188,6 +180,10 @@ type MulticastGroupProtocol interface {
// SendLeave sends a multicast leave for the specified group address.
SendLeave(groupAddress tcpip.Address) tcpip.Error
+
+ // ShouldPerformProtocol returns true iff the protocol should be performed for
+ // the specified group.
+ ShouldPerformProtocol(tcpip.Address) bool
}
// GenericMulticastProtocolState is the per interface generic multicast protocol
@@ -455,20 +451,7 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t
info.lastToSendReport = false
- if groupAddress == g.opts.AllNodesAddress {
- // As per RFC 2236 section 6 page 10 (for IGMPv2),
- //
- // The all-systems group (address 224.0.0.1) is handled as a special
- // case. The host starts in Idle Member state for that group on every
- // interface, never transitions to another state, and never sends a
- // report for that group.
- //
- // As per RFC 2710 section 5 page 10 (for MLDv1),
- //
- // The link-scope all-nodes address (FF02::1) is handled as a special
- // case. The node starts in Idle Listener state for that address on
- // every interface, never transitions to another state, and never sends
- // a Report or Done for that address.
+ if !g.opts.Protocol.ShouldPerformProtocol(groupAddress) {
info.state = idleMember
return
}
@@ -537,20 +520,7 @@ func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Addres
return
}
- if groupAddress == g.opts.AllNodesAddress {
- // As per RFC 2236 section 6 page 10 (for IGMPv2),
- //
- // The all-systems group (address 224.0.0.1) is handled as a special
- // case. The host starts in Idle Member state for that group on every
- // interface, never transitions to another state, and never sends a
- // report for that group.
- //
- // As per RFC 2710 section 5 page 10 (for MLDv1),
- //
- // The link-scope all-nodes address (FF02::1) is handled as a special
- // case. The node starts in Idle Listener state for that address on
- // every interface, never transitions to another state, and never sends
- // a Report or Done for that address.
+ if !g.opts.Protocol.ShouldPerformProtocol(groupAddress) {
return
}
@@ -627,20 +597,7 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr
return
}
- if groupAddress == g.opts.AllNodesAddress {
- // As per RFC 2236 section 6 page 10 (for IGMPv2),
- //
- // The all-systems group (address 224.0.0.1) is handled as a special
- // case. The host starts in Idle Member state for that group on every
- // interface, never transitions to another state, and never sends a
- // report for that group.
- //
- // As per RFC 2710 section 5 page 10 (for MLDv1),
- //
- // The link-scope all-nodes address (FF02::1) is handled as a special
- // case. The node starts in Idle Listener state for that address on
- // every interface, never transitions to another state, and never sends
- // a Report or Done for that address.
+ if !g.opts.Protocol.ShouldPerformProtocol(groupAddress) {
return
}
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
index 381460c82..0b51563cd 100644
--- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
+++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
@@ -43,6 +43,8 @@ type mockMulticastGroupProtocolProtectedFields struct {
type mockMulticastGroupProtocol struct {
t *testing.T
+ skipProtocolAddress tcpip.Address
+
mu mockMulticastGroupProtocolProtectedFields
}
@@ -165,6 +167,11 @@ func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip
return nil
}
+// ShouldPerformProtocol implements ip.MulticastGroupProtocol.
+func (m *mockMulticastGroupProtocol) ShouldPerformProtocol(groupAddress tcpip.Address) bool {
+ return groupAddress != m.skipProtocolAddress
+}
+
func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string {
m.mu.Lock()
defer m.mu.Unlock()
@@ -193,10 +200,11 @@ func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Addr
cmp.FilterPath(
func(p cmp.Path) bool {
switch p.Last().String() {
- case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup":
+ case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup", ".skipProtocolAddress":
return true
+ default:
+ return false
}
- return false
},
cmp.Ignore(),
),
@@ -225,14 +233,13 @@ func TestJoinGroup(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t}
+ mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2}
clock := faketime.NewManualClock()
mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(0)),
Clock: clock,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- AllNodesAddress: addr2,
})
// Joining a group should send a report immediately and another after
@@ -279,14 +286,13 @@ func TestLeaveGroup(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t}
+ mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2}
clock := faketime.NewManualClock()
mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(1)),
Clock: clock,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- AllNodesAddress: addr2,
})
mgp.joinGroup(test.addr)
@@ -356,14 +362,13 @@ func TestHandleReport(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t}
+ mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
clock := faketime.NewManualClock()
mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(2)),
Clock: clock,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- AllNodesAddress: addr3,
})
mgp.joinGroup(addr1)
@@ -446,14 +451,13 @@ func TestHandleQuery(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t}
+ mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
clock := faketime.NewManualClock()
mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(3)),
Clock: clock,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- AllNodesAddress: addr3,
})
mgp.joinGroup(addr1)
@@ -574,14 +578,13 @@ func TestJoinCount(t *testing.T) {
}
func TestMakeAllNonMemberAndInitialize(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t}
+ mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
clock := faketime.NewManualClock()
mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(3)),
Clock: clock,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- AllNodesAddress: addr3,
})
mgp.joinGroup(addr1)
diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go
index b6f39ddb1..d06b26309 100644
--- a/pkg/tcpip/network/internal/ip/stats.go
+++ b/pkg/tcpip/network/internal/ip/stats.go
@@ -21,53 +21,56 @@ import "gvisor.dev/gvisor/pkg/tcpip"
// MultiCounterIPStats holds IP statistics, each counter may have several
// versions.
type MultiCounterIPStats struct {
- // PacketsReceived is the total number of IP packets received from the link
- // layer.
+ // PacketsReceived is the number of IP packets received from the link layer.
PacketsReceived tcpip.MultiCounterStat
- // DisabledPacketsReceived is the total number of IP packets received from the
- // link layer when the IP layer is disabled.
+ // DisabledPacketsReceived is the number of IP packets received from the link
+ // layer when the IP layer is disabled.
DisabledPacketsReceived tcpip.MultiCounterStat
- // InvalidDestinationAddressesReceived is the total number of IP packets
- // received with an unknown or invalid destination address.
+ // InvalidDestinationAddressesReceived is the number of IP packets received
+ // with an unknown or invalid destination address.
InvalidDestinationAddressesReceived tcpip.MultiCounterStat
- // InvalidSourceAddressesReceived is the total number of IP packets received
- // with a source address that should never have been received on the wire.
+ // InvalidSourceAddressesReceived is the number of IP packets received with a
+ // source address that should never have been received on the wire.
InvalidSourceAddressesReceived tcpip.MultiCounterStat
- // PacketsDelivered is the total number of incoming IP packets that are
- // successfully delivered to the transport layer.
+ // PacketsDelivered is the number of incoming IP packets that are successfully
+ // delivered to the transport layer.
PacketsDelivered tcpip.MultiCounterStat
- // PacketsSent is the total number of IP packets sent via WritePacket.
+ // PacketsSent is the number of IP packets sent via WritePacket.
PacketsSent tcpip.MultiCounterStat
- // OutgoingPacketErrors is the total number of IP packets which failed to
- // write to a link-layer endpoint.
+ // OutgoingPacketErrors is the number of IP packets which failed to write to a
+ // link-layer endpoint.
OutgoingPacketErrors tcpip.MultiCounterStat
- // MalformedPacketsReceived is the total number of IP Packets that were
- // dropped due to the IP packet header failing validation checks.
+ // MalformedPacketsReceived is the number of IP Packets that were dropped due
+ // to the IP packet header failing validation checks.
MalformedPacketsReceived tcpip.MultiCounterStat
- // MalformedFragmentsReceived is the total number of IP Fragments that were
- // dropped due to the fragment failing validation checks.
+ // MalformedFragmentsReceived is the number of IP Fragments that were dropped
+ // due to the fragment failing validation checks.
MalformedFragmentsReceived tcpip.MultiCounterStat
- // IPTablesPreroutingDropped is the total number of IP packets dropped in the
+ // IPTablesPreroutingDropped is the number of IP packets dropped in the
// Prerouting chain.
IPTablesPreroutingDropped tcpip.MultiCounterStat
- // IPTablesInputDropped is the total number of IP packets dropped in the Input
+ // IPTablesInputDropped is the number of IP packets dropped in the Input
// chain.
IPTablesInputDropped tcpip.MultiCounterStat
- // IPTablesOutputDropped is the total number of IP packets dropped in the
- // Output chain.
+ // IPTablesOutputDropped is the number of IP packets dropped in the Output
+ // chain.
IPTablesOutputDropped tcpip.MultiCounterStat
+ // IPTablesPostroutingDropped is the number of IP packets dropped in the
+ // Postrouting chain.
+ IPTablesPostroutingDropped tcpip.MultiCounterStat
+
// TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out
// of IPStats.
@@ -98,6 +101,7 @@ func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) {
m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped)
m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped)
m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped)
+ m.IPTablesPostroutingDropped.Init(a.IPTablesPostroutingDropped, b.IPTablesPostroutingDropped)
m.OptionTimestampReceived.Init(a.OptionTimestampReceived, b.OptionTimestampReceived)
m.OptionRecordRouteReceived.Init(a.OptionRecordRouteReceived, b.OptionRecordRouteReceived)
m.OptionRouterAlertReceived.Init(a.OptionRouterAlertReceived, b.OptionRouterAlertReceived)
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index a4edc69c7..dbd674634 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -15,6 +15,7 @@
package ip_test
import (
+ "fmt"
"strings"
"testing"
@@ -29,23 +30,25 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
-const (
- localIPv4Addr = tcpip.Address("\x0a\x00\x00\x01")
- remoteIPv4Addr = tcpip.Address("\x0a\x00\x00\x02")
- ipv4SubnetAddr = tcpip.Address("\x0a\x00\x00\x00")
- ipv4SubnetMask = tcpip.Address("\xff\xff\xff\x00")
- ipv4Gateway = tcpip.Address("\x0a\x00\x00\x03")
- localIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- remoteIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- ipv6SubnetAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00")
- ipv6SubnetMask = tcpip.Address("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00")
- ipv6Gateway = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
- nicID = 1
+const nicID = 1
+
+var (
+ localIPv4Addr = testutil.MustParse4("10.0.0.1")
+ remoteIPv4Addr = testutil.MustParse4("10.0.0.2")
+ ipv4SubnetAddr = testutil.MustParse4("10.0.0.0")
+ ipv4SubnetMask = testutil.MustParse4("255.255.255.0")
+ ipv4Gateway = testutil.MustParse4("10.0.0.3")
+ localIPv6Addr = testutil.MustParse6("a00::1")
+ remoteIPv6Addr = testutil.MustParse6("a00::2")
+ ipv6SubnetAddr = testutil.MustParse6("a00::")
+ ipv6SubnetMask = testutil.MustParse6("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ff00")
+ ipv6Gateway = testutil.MustParse6("a00::3")
)
var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{
@@ -1938,3 +1941,80 @@ func TestICMPInclusionSize(t *testing.T) {
})
}
}
+
+func TestJoinLeaveAllRoutersGroup(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ protoFactory stack.NetworkProtocolFactory
+ allRoutersAddr tcpip.Address
+ }{
+ {
+ name: "IPv4",
+ netProto: ipv4.ProtocolNumber,
+ protoFactory: ipv4.NewProtocol,
+ allRoutersAddr: header.IPv4AllRoutersGroup,
+ },
+ {
+ name: "IPv6 Interface Local",
+ netProto: ipv6.ProtocolNumber,
+ protoFactory: ipv6.NewProtocol,
+ allRoutersAddr: header.IPv6AllRoutersInterfaceLocalMulticastAddress,
+ },
+ {
+ name: "IPv6 Link Local",
+ netProto: ipv6.ProtocolNumber,
+ protoFactory: ipv6.NewProtocol,
+ allRoutersAddr: header.IPv6AllRoutersLinkLocalMulticastAddress,
+ },
+ {
+ name: "IPv6 Site Local",
+ netProto: ipv6.ProtocolNumber,
+ protoFactory: ipv6.NewProtocol,
+ allRoutersAddr: header.IPv6AllRoutersSiteLocalMulticastAddress,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, nicDisabled := range [...]bool{true, false} {
+ t.Run(fmt.Sprintf("NIC Disabled = %t", nicDisabled), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
+ })
+ opts := stack.NICOptions{Disabled: nicDisabled}
+ if err := s.CreateNICWithOptions(nicID, channel.New(0, 0, ""), opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err)
+ }
+
+ if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
+ t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
+ } else if got {
+ t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr)
+ }
+
+ if err := s.SetForwarding(test.netProto, true); err != nil {
+ t.Fatalf("s.SetForwarding(%d, true): %s", test.netProto, err)
+ }
+ if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
+ t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
+ } else if !got {
+ t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr)
+ }
+
+ if err := s.SetForwarding(test.netProto, false); err != nil {
+ t.Fatalf("s.SetForwarding(%d, false): %s", test.netProto, err)
+ }
+ if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
+ t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
+ } else if got {
+ t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 5e7f10f4b..7ee0495d9 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -45,6 +45,7 @@ go_test(
"//pkg/tcpip/network/internal/testutil",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
index f3fc1c87e..b1ac29294 100644
--- a/pkg/tcpip/network/ipv4/igmp.go
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -126,6 +126,17 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) tcpip.Error {
return err
}
+// ShouldPerformProtocol implements ip.MulticastGroupProtocol.
+func (igmp *igmpState) ShouldPerformProtocol(groupAddress tcpip.Address) bool {
+ // As per RFC 2236 section 6 page 10,
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ return groupAddress != header.IPv4AllSystems
+}
+
// init sets up an igmpState struct, and is required to be called before using
// a new igmpState.
//
@@ -137,7 +148,6 @@ func (igmp *igmpState) init(ep *endpoint) {
Clock: ep.protocol.stack.Clock(),
Protocol: igmp,
MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax,
- AllNodesAddress: header.IPv4AllSystems,
})
igmp.igmpV1Present = igmpV1PresentDefault
igmp.igmpV1Job = ep.protocol.stack.NewJob(&ep.mu, func() {
diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go
index e5e1b89cc..4bd6f462e 100644
--- a/pkg/tcpip/network/ipv4/igmp_test.go
+++ b/pkg/tcpip/network/ipv4/igmp_test.go
@@ -26,18 +26,22 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
const (
linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- stackAddr = tcpip.Address("\x0a\x00\x00\x01")
- remoteAddr = tcpip.Address("\x0a\x00\x00\x02")
- multicastAddr = tcpip.Address("\xe0\x00\x00\x03")
nicID = 1
defaultTTL = 1
defaultPrefixLength = 24
)
+var (
+ stackAddr = testutil.MustParse4("10.0.0.1")
+ remoteAddr = testutil.MustParse4("10.0.0.2")
+ multicastAddr = testutil.MustParse4("224.0.0.3")
+)
+
// validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet
// sent to the provided address with the passed fields set. Raises a t.Error if
// any field does not match.
@@ -292,7 +296,7 @@ func TestIGMPPacketValidation(t *testing.T) {
messageType: header.IGMPLeaveGroup,
includeRouterAlertOption: true,
stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: tcpip.Address("\x0a\x00\x01\x02"),
+ srcAddr: testutil.MustParse4("10.0.1.2"),
ttl: 1,
expectValidIGMP: false,
getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.LeaveGroup.Value() },
@@ -302,7 +306,7 @@ func TestIGMPPacketValidation(t *testing.T) {
messageType: header.IGMPMembershipQuery,
includeRouterAlertOption: true,
stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: tcpip.Address("\x0a\x00\x01\x02"),
+ srcAddr: testutil.MustParse4("10.0.1.2"),
ttl: 1,
expectValidIGMP: true,
getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.MembershipQuery.Value() },
@@ -312,7 +316,7 @@ func TestIGMPPacketValidation(t *testing.T) {
messageType: header.IGMPv1MembershipReport,
includeRouterAlertOption: true,
stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: tcpip.Address("\x0a\x00\x01\x02"),
+ srcAddr: testutil.MustParse4("10.0.1.2"),
ttl: 1,
expectValidIGMP: false,
getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V1MembershipReport.Value() },
@@ -322,7 +326,7 @@ func TestIGMPPacketValidation(t *testing.T) {
messageType: header.IGMPv2MembershipReport,
includeRouterAlertOption: true,
stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: tcpip.Address("\x0a\x00\x01\x02"),
+ srcAddr: testutil.MustParse4("10.0.1.2"),
ttl: 1,
expectValidIGMP: false,
getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() },
@@ -332,7 +336,7 @@ func TestIGMPPacketValidation(t *testing.T) {
messageType: header.IGMPv2MembershipReport,
includeRouterAlertOption: true,
stackAddresses: []tcpip.AddressWithPrefix{
- {Address: tcpip.Address("\x0a\x00\x0f\x01"), PrefixLen: 24},
+ {Address: testutil.MustParse4("10.0.15.1"), PrefixLen: 24},
{Address: stackAddr, PrefixLen: 24},
},
srcAddr: remoteAddr,
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 1a5661ca4..a82a5790d 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -150,6 +150,38 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
delete(p.mu.eps, nicID)
}
+// transitionForwarding transitions the endpoint's forwarding status to
+// forwarding.
+//
+// Must only be called when the forwarding status changes.
+func (e *endpoint) transitionForwarding(forwarding bool) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if forwarding {
+ // There does not seem to be an RFC requirement for a node to join the all
+ // routers multicast address but
+ // https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml
+ // specifies the address as a group for all routers on a subnet so we join
+ // the group here.
+ if err := e.joinGroupLocked(header.IPv4AllRoutersGroup); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a
+ // valid IPv4 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv4AllRoutersGroup, err))
+ }
+
+ return
+ }
+
+ switch err := e.leaveGroupLocked(header.IPv4AllRoutersGroup).(type) {
+ case nil:
+ case *tcpip.ErrBadLocalAddress:
+ // The endpoint may have already left the multicast group.
+ default:
+ panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", header.IPv4AllRoutersGroup, err))
+ }
+}
+
// Enable implements stack.NetworkEndpoint.
func (e *endpoint) Enable() tcpip.Error {
e.mu.Lock()
@@ -226,7 +258,7 @@ func (e *endpoint) disableLocked() {
}
// The endpoint may have already left the multicast group.
- switch err := e.leaveGroupLocked(header.IPv4AllSystems); err.(type) {
+ switch err := e.leaveGroupLocked(header.IPv4AllSystems).(type) {
case nil, *tcpip.ErrBadLocalAddress:
default:
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err))
@@ -383,6 +415,15 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
return nil
}
+ // Postrouting NAT can only change the source address, and does not alter the
+ // route or outgoing interface of the packet.
+ outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ // iptables is telling us to drop the packet.
+ e.stats.ip.IPTablesPostroutingDropped.Increment()
+ return nil
+ }
+
stats := e.stats.ip
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
@@ -454,9 +495,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName)
- stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
- for pkt := range dropped {
+ outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName)
+ stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped)))
+ for pkt := range outputDropped {
pkts.Remove(pkt)
}
@@ -478,6 +519,15 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
+ // We ignore the list of NAT-ed packets here because Postrouting NAT can only
+ // change the source address, and does not alter the route or outgoing
+ // interface of the packet.
+ postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, gso, r, "" /* inNicName */, outNicName)
+ stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
+ for pkt := range postroutingDropped {
+ pkts.Remove(pkt)
+ }
+
// The rest of the packets can be delivered to the NIC as a batch.
pktsLen := pkts.Len()
written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
@@ -485,7 +535,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written))
// Dropped packets aren't errors, so include them in the return value.
- return locallyDelivered + written + len(dropped), err
+ return locallyDelivered + written + len(outputDropped) + len(postroutingDropped), err
}
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
@@ -551,6 +601,22 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// forwardPacket attempts to forward a packet to its final destination.
func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
h := header.IPv4(pkt.NetworkHeader().View())
+
+ dstAddr := h.DestinationAddress()
+ if header.IsV4LinkLocalUnicastAddress(h.SourceAddress()) || header.IsV4LinkLocalUnicastAddress(dstAddr) || header.IsV4LinkLocalMulticastAddress(dstAddr) {
+ // As per RFC 3927 section 7,
+ //
+ // A router MUST NOT forward a packet with an IPv4 Link-Local source or
+ // destination address, irrespective of the router's default route
+ // configuration or routes obtained from dynamic routing protocols.
+ //
+ // A router which receives a packet with an IPv4 Link-Local source or
+ // destination address MUST NOT forward the packet. This prevents
+ // forwarding of packets back onto the network segment from which they
+ // originated, or to any other segment.
+ return nil
+ }
+
ttl := h.TTL()
if ttl == 0 {
// As per RFC 792 page 6, Time Exceeded Message,
@@ -589,8 +655,6 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
}
}
- dstAddr := h.DestinationAddress()
-
// Check if the destination is owned by the stack.
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
ep.handleValidatedPacket(h, pkt)
@@ -1114,28 +1178,7 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool)
return nil, false
}
- // There has been some confusion regarding verifying checksums. We need
- // just look for negative 0 (0xffff) as the checksum, as it's not possible to
- // get positive 0 (0) for the checksum. Some bad implementations could get it
- // when doing entry replacement in the early days of the Internet,
- // however the lore that one needs to check for both persists.
- //
- // RFC 1624 section 1 describes the source of this confusion as:
- // [the partial recalculation method described in RFC 1071] computes a
- // result for certain cases that differs from the one obtained from
- // scratch (one's complement of one's complement sum of the original
- // fields).
- //
- // However RFC 1624 section 5 clarifies that if using the verification method
- // "recommended by RFC 1071, it does not matter if an intermediate system
- // generated a -0 instead of +0".
- //
- // RFC1071 page 1 specifies the verification method as:
- // (3) To check a checksum, the 1's complement sum is computed over the
- // same set of octets, including the checksum field. If the result
- // is all 1 bits (-0 in 1's complement arithmetic), the check
- // succeeds.
- if h.CalculateChecksum() != 0xffff {
+ if !h.IsChecksumValid() {
return nil, false
}
@@ -1168,12 +1211,27 @@ func (p *protocol) Forwarding() bool {
return uint8(atomic.LoadUint32(&p.forwarding)) == 1
}
+// setForwarding sets the forwarding status for the protocol.
+//
+// Returns true if the forwarding status was updated.
+func (p *protocol) setForwarding(v bool) bool {
+ if v {
+ return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */)
+ }
+ return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */)
+}
+
// SetForwarding implements stack.ForwardingNetworkProtocol.
func (p *protocol) SetForwarding(v bool) {
- if v {
- atomic.StoreUint32(&p.forwarding, 1)
- } else {
- atomic.StoreUint32(&p.forwarding, 0)
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if !p.setForwarding(v) {
+ return
+ }
+
+ for _, ep := range p.mu.eps {
+ ep.transitionForwarding(v)
}
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index eba91c68c..d49dff4d5 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -39,6 +39,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
@@ -2612,34 +2613,36 @@ func TestWriteStats(t *testing.T) {
const nPackets = 3
tests := []struct {
- name string
- setup func(*testing.T, *stack.Stack)
- allowPackets int
- expectSent int
- expectDropped int
- expectWritten int
+ name string
+ setup func(*testing.T, *stack.Stack)
+ allowPackets int
+ expectSent int
+ expectOutputDropped int
+ expectPostroutingDropped int
+ expectWritten int
}{
{
name: "Accept all",
// No setup needed, tables accept everything by default.
- setup: func(*testing.T, *stack.Stack) {},
- allowPackets: math.MaxInt32,
- expectSent: nPackets,
- expectDropped: 0,
- expectWritten: nPackets,
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets,
}, {
name: "Accept all with error",
// No setup needed, tables accept everything by default.
- setup: func(*testing.T, *stack.Stack) {},
- allowPackets: nPackets - 1,
- expectSent: nPackets - 1,
- expectDropped: 0,
- expectWritten: nPackets - 1,
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: nPackets - 1,
+ expectSent: nPackets - 1,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets - 1,
}, {
- name: "Drop all",
+ name: "Drop all with Output chain",
setup: func(t *testing.T, stk *stack.Stack) {
// Install Output DROP rule.
- t.Helper()
ipt := stk.IPTables()
filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
ruleIdx := filter.BuiltinChains[stack.Output]
@@ -2648,16 +2651,32 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to replace table: %s", err)
}
},
- allowPackets: math.MaxInt32,
- expectSent: 0,
- expectDropped: nPackets,
- expectWritten: nPackets,
+ allowPackets: math.MaxInt32,
+ expectSent: 0,
+ expectOutputDropped: nPackets,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets,
}, {
- name: "Drop some",
+ name: "Drop all with Postrouting chain",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ ipt := stk.IPTables()
+ filter := ipt.GetTable(stack.NATID, false /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Postrouting]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %s", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: 0,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: nPackets,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some with Output chain",
setup: func(t *testing.T, stk *stack.Stack) {
// Install Output DROP rule that matches only 1
// of the 3 packets.
- t.Helper()
ipt := stk.IPTables()
filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
// We'll match and DROP the last packet.
@@ -2670,10 +2689,33 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to replace table: %s", err)
}
},
- allowPackets: math.MaxInt32,
- expectSent: nPackets - 1,
- expectDropped: 1,
- expectWritten: nPackets,
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets - 1,
+ expectOutputDropped: 1,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some with Postrouting chain",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Postrouting DROP rule that matches only 1
+ // of the 3 packets.
+ ipt := stk.IPTables()
+ filter := ipt.GetTable(stack.NATID, false /* ipv6 */)
+ // We'll match and DROP the last packet.
+ ruleIdx := filter.BuiltinChains[stack.Postrouting]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
+ // Make sure the next rule is ACCEPT.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %s", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets - 1,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: 1,
+ expectWritten: nPackets,
},
}
@@ -2724,13 +2766,16 @@ func TestWriteStats(t *testing.T) {
nWritten, _ := writer.writePackets(rt, pkts)
if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
- t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
+ t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent)
+ }
+ if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped {
+ t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped)
}
- if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped {
- t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped)
+ if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped {
+ t.Errorf("got rt.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, test.expectPostroutingDropped)
}
if nWritten != test.expectWritten {
- t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten)
+ t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten)
}
})
}
@@ -2995,12 +3040,14 @@ func TestCloseLocking(t *testing.T) {
nicID1 = 1
nicID2 = 2
- src = tcpip.Address("\x10\x00\x00\x01")
- dst = tcpip.Address("\x10\x00\x00\x02")
-
iterations = 1000
)
+ var (
+ src = tcptestutil.MustParse4("16.0.0.1")
+ dst = tcptestutil.MustParse4("16.0.0.2")
+ )
+
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index bb9a02ed0..db998e83e 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -66,5 +66,6 @@ go_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
],
)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index a142b76c1..b2a80e1e9 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -273,7 +273,7 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP
if iph.HopLimit() != header.MLDHopLimit {
return false
}
- if !header.IsV6LinkLocalAddress(iph.SourceAddress()) {
+ if !header.IsV6LinkLocalUnicastAddress(iph.SourceAddress()) {
return false
}
return true
@@ -804,7 +804,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
routerAddr := srcAddr
// Is the IP Source Address a link-local address?
- if !header.IsV6LinkLocalAddress(routerAddr) {
+ if !header.IsV6LinkLocalUnicastAddress(routerAddr) {
// ...No, silently drop the packet.
received.invalid.Increment()
return
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index c6d9d8f0d..2e515379c 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -314,7 +314,7 @@ func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) {
// Snooping switches MUST manage multicast forwarding state based on MLD
// Report and Done messages sent with the unspecified address as the
// IPv6 source address.
- if header.IsV6LinkLocalAddress(addr) {
+ if header.IsV6LinkLocalUnicastAddress(addr) {
e.mu.mld.sendQueuedReports()
}
}
@@ -410,22 +410,65 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t
//
// Must only be called when the forwarding status changes.
func (e *endpoint) transitionForwarding(forwarding bool) {
+ allRoutersGroups := [...]tcpip.Address{
+ header.IPv6AllRoutersInterfaceLocalMulticastAddress,
+ header.IPv6AllRoutersLinkLocalMulticastAddress,
+ header.IPv6AllRoutersSiteLocalMulticastAddress,
+ }
+
e.mu.Lock()
defer e.mu.Unlock()
- if !e.Enabled() {
- return
- }
-
if forwarding {
// When transitioning into an IPv6 router, host-only state (NDP discovered
// routers, discovered on-link prefixes, and auto-generated addresses) is
// cleaned up/invalidated and NDP router solicitations are stopped.
e.mu.ndp.stopSolicitingRouters()
e.mu.ndp.cleanupState(true /* hostOnly */)
- } else {
- // When transitioning into an IPv6 host, NDP router solicitations are
- // started.
+
+ // As per RFC 4291 section 2.8:
+ //
+ // A router is required to recognize all addresses that a host is
+ // required to recognize, plus the following addresses as identifying
+ // itself:
+ //
+ // o The All-Routers multicast addresses defined in Section 2.7.1.
+ //
+ // As per RFC 4291 section 2.7.1,
+ //
+ // All Routers Addresses: FF01:0:0:0:0:0:0:2
+ // FF02:0:0:0:0:0:0:2
+ // FF05:0:0:0:0:0:0:2
+ //
+ // The above multicast addresses identify the group of all IPv6 routers,
+ // within scope 1 (interface-local), 2 (link-local), or 5 (site-local).
+ for _, g := range allRoutersGroups {
+ if err := e.joinGroupLocked(g); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a
+ // valid IPv6 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", g, err))
+ }
+ }
+
+ return
+ }
+
+ for _, g := range allRoutersGroups {
+ switch err := e.leaveGroupLocked(g).(type) {
+ case nil:
+ case *tcpip.ErrBadLocalAddress:
+ // The endpoint may have already left the multicast group.
+ default:
+ panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", g, err))
+ }
+ }
+
+ // When transitioning into an IPv6 host, NDP router solicitations are
+ // started if the endpoint is enabled.
+ //
+ // If the endpoint is not currently enabled, routers will be solicited when
+ // the endpoint becomes enabled (if it is still a host).
+ if e.Enabled() {
e.mu.ndp.startSolicitingRouters()
}
}
@@ -573,7 +616,7 @@ func (e *endpoint) disableLocked() {
e.mu.ndp.cleanupState(false /* hostOnly */)
// The endpoint may have already left the multicast group.
- switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err.(type) {
+ switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress).(type) {
case nil, *tcpip.ErrBadLocalAddress:
default:
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err))
@@ -726,6 +769,15 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
return nil
}
+ // Postrouting NAT can only change the source address, and does not alter the
+ // route or outgoing interface of the packet.
+ outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ // iptables is telling us to drop the packet.
+ e.stats.ip.IPTablesPostroutingDropped.Increment()
+ return nil
+ }
+
stats := e.stats.ip
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
if err != nil {
@@ -797,9 +849,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// iptables filtering. All packets that reach here are locally
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName)
- stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
- for pkt := range dropped {
+ outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName)
+ stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped)))
+ for pkt := range outputDropped {
pkts.Remove(pkt)
}
@@ -820,6 +872,15 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
locallyDelivered++
}
+ // We ignore the list of NAT-ed packets here because Postrouting NAT can only
+ // change the source address, and does not alter the route or outgoing
+ // interface of the packet.
+ postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, gso, r, "" /* inNicName */, outNicName)
+ stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
+ for pkt := range postroutingDropped {
+ pkts.Remove(pkt)
+ }
+
// The rest of the packets can be delivered to the NIC as a batch.
pktsLen := pkts.Len()
written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
@@ -827,7 +888,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written))
// Dropped packets aren't errors, so include them in the return value.
- return locallyDelivered + written + len(dropped), err
+ return locallyDelivered + written + len(outputDropped) + len(postroutingDropped), err
}
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
@@ -869,6 +930,16 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// forwardPacket attempts to forward a packet to its final destination.
func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
h := header.IPv6(pkt.NetworkHeader().View())
+
+ dstAddr := h.DestinationAddress()
+ if header.IsV6LinkLocalUnicastAddress(h.SourceAddress()) || header.IsV6LinkLocalUnicastAddress(dstAddr) || header.IsV6LinkLocalMulticastAddress(dstAddr) {
+ // As per RFC 4291 section 2.5.6,
+ //
+ // Routers must not forward any packets with Link-Local source or
+ // destination addresses to other links.
+ return nil
+ }
+
hopLimit := h.HopLimit()
if hopLimit <= 1 {
// As per RFC 4443 section 3.3,
@@ -881,8 +952,6 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
return e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt)
}
- dstAddr := h.DestinationAddress()
-
// Check if the destination is owned by the stack.
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
ep.handleValidatedPacket(h, pkt)
@@ -1571,7 +1640,7 @@ func (e *endpoint) getLinkLocalAddressRLocked() tcpip.Address {
var linkLocalAddr tcpip.Address
e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
if addressEndpoint.IsAssigned(false /* allowExpired */) {
- if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalAddress(addr) {
+ if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalUnicastAddress(addr) {
linkLocalAddr = addr
return false
}
@@ -1979,9 +2048,9 @@ func (p *protocol) Forwarding() bool {
// Returns true if the forwarding status was updated.
func (p *protocol) setForwarding(v bool) bool {
if v {
- return atomic.SwapUint32(&p.forwarding, 1) == 0
+ return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */)
}
- return atomic.SwapUint32(&p.forwarding, 0) == 1
+ return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */)
}
// SetForwarding implements stack.ForwardingNetworkProtocol.
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index c206cebeb..a620e9ad9 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -2468,34 +2468,36 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
func TestWriteStats(t *testing.T) {
const nPackets = 3
tests := []struct {
- name string
- setup func(*testing.T, *stack.Stack)
- allowPackets int
- expectSent int
- expectDropped int
- expectWritten int
+ name string
+ setup func(*testing.T, *stack.Stack)
+ allowPackets int
+ expectSent int
+ expectOutputDropped int
+ expectPostroutingDropped int
+ expectWritten int
}{
{
name: "Accept all",
// No setup needed, tables accept everything by default.
- setup: func(*testing.T, *stack.Stack) {},
- allowPackets: math.MaxInt32,
- expectSent: nPackets,
- expectDropped: 0,
- expectWritten: nPackets,
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets,
}, {
name: "Accept all with error",
// No setup needed, tables accept everything by default.
- setup: func(*testing.T, *stack.Stack) {},
- allowPackets: nPackets - 1,
- expectSent: nPackets - 1,
- expectDropped: 0,
- expectWritten: nPackets - 1,
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: nPackets - 1,
+ expectSent: nPackets - 1,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets - 1,
}, {
- name: "Drop all",
+ name: "Drop all with Output chain",
setup: func(t *testing.T, stk *stack.Stack) {
// Install Output DROP rule.
- t.Helper()
ipt := stk.IPTables()
filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
ruleIdx := filter.BuiltinChains[stack.Output]
@@ -2504,16 +2506,33 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to replace table: %v", err)
}
},
- allowPackets: math.MaxInt32,
- expectSent: 0,
- expectDropped: nPackets,
- expectWritten: nPackets,
+ allowPackets: math.MaxInt32,
+ expectSent: 0,
+ expectOutputDropped: nPackets,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets,
}, {
- name: "Drop some",
+ name: "Drop all with Postrouting chain",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule.
+ ipt := stk.IPTables()
+ filter := ipt.GetTable(stack.NATID, true /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Postrouting]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: 0,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: nPackets,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some with Output chain",
setup: func(t *testing.T, stk *stack.Stack) {
// Install Output DROP rule that matches only 1
// of the 3 packets.
- t.Helper()
ipt := stk.IPTables()
filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
// We'll match and DROP the last packet.
@@ -2526,10 +2545,33 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to replace table: %v", err)
}
},
- allowPackets: math.MaxInt32,
- expectSent: nPackets - 1,
- expectDropped: 1,
- expectWritten: nPackets,
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets - 1,
+ expectOutputDropped: 1,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some with Postrouting chain",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Postrouting DROP rule that matches only 1
+ // of the 3 packets.
+ ipt := stk.IPTables()
+ filter := ipt.GetTable(stack.NATID, true /* ipv6 */)
+ // We'll match and DROP the last packet.
+ ruleIdx := filter.BuiltinChains[stack.Postrouting]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
+ // Make sure the next rule is ACCEPT.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets - 1,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: 1,
+ expectWritten: nPackets,
},
}
@@ -2578,13 +2620,16 @@ func TestWriteStats(t *testing.T) {
nWritten, _ := writer.writePackets(rt, pkts)
if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
- t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
+ t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent)
+ }
+ if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped {
+ t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped)
}
- if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped {
- t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped)
+ if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped {
+ t.Errorf("got r.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, test.expectPostroutingDropped)
}
if nWritten != test.expectWritten {
- t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten)
+ t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten)
}
})
}
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
index dd153466d..165b7d2d2 100644
--- a/pkg/tcpip/network/ipv6/mld.go
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -76,10 +76,29 @@ func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error)
//
// Precondition: mld.ep.mu must be read locked.
func (mld *mldState) SendLeave(groupAddress tcpip.Address) tcpip.Error {
- _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone)
+ _, err := mld.writePacket(header.IPv6AllRoutersLinkLocalMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone)
return err
}
+// ShouldPerformProtocol implements ip.MulticastGroupProtocol.
+func (mld *mldState) ShouldPerformProtocol(groupAddress tcpip.Address) bool {
+ // As per RFC 2710 section 5 page 10,
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ //
+ // MLD messages are never sent for multicast addresses whose scope is 0
+ // (reserved) or 1 (node-local).
+ if groupAddress == header.IPv6AllNodesMulticastAddress {
+ return false
+ }
+
+ scope := header.V6MulticastScope(groupAddress)
+ return scope != header.IPv6Reserved0MulticastScope && scope != header.IPv6InterfaceLocalMulticastScope
+}
+
// init sets up an mldState struct, and is required to be called before using
// a new mldState.
//
@@ -91,7 +110,6 @@ func (mld *mldState) init(ep *endpoint) {
Clock: ep.protocol.stack.Clock(),
Protocol: mld,
MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax,
- AllNodesAddress: header.IPv6AllNodesMulticastAddress,
})
}
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
index 85a8f9944..71d1c3e28 100644
--- a/pkg/tcpip/network/ipv6/mld_test.go
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -27,15 +27,14 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-const (
- linkLocalAddr = "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- globalAddr = "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- globalMulticastAddr = "\xff\x05\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
var (
+ linkLocalAddr = testutil.MustParse6("fe80::1")
+ globalAddr = testutil.MustParse6("a80::1")
+ globalMulticastAddr = testutil.MustParse6("ff05:100::2")
+
linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr)
globalAddrSNMC = header.SolicitedNodeAddr(globalAddr)
)
@@ -93,7 +92,7 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
if p, ok := e.Read(); !ok {
t.Fatal("expected a done message to be sent")
} else {
- validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC)
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersLinkLocalMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC)
}
}
@@ -354,10 +353,8 @@ func createAndInjectMLDPacket(e *channel.Endpoint, mldType header.ICMPv6Type, ho
}
func TestMLDPacketValidation(t *testing.T) {
- const (
- nicID = 1
- linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- )
+ const nicID = 1
+ linkLocalAddr2 := testutil.MustParse6("fe80::2")
tests := []struct {
name string
@@ -464,3 +461,141 @@ func TestMLDPacketValidation(t *testing.T) {
})
}
}
+
+func TestMLDSkipProtocol(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ group tcpip.Address
+ expectReport bool
+ }{
+ {
+ name: "Reserverd0",
+ group: "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: false,
+ },
+ {
+ name: "Interface Local",
+ group: "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: false,
+ },
+ {
+ name: "Link Local",
+ group: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Realm Local",
+ group: "\xff\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Admin Local",
+ group: "\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Site Local",
+ group: "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(6)",
+ group: "\xff\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(7)",
+ group: "\xff\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Organization Local",
+ group: "\xff\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(9)",
+ group: "\xff\x09\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(A)",
+ group: "\xff\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(B)",
+ group: "\xff\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(C)",
+ group: "\xff\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(D)",
+ group: "\xff\x0d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Global",
+ group: "\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "ReservedF",
+ group: "\xff\x0f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ MLD: ipv6.MLDOptions{
+ Enabled: true,
+ },
+ })},
+ })
+ e := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC)
+ }
+
+ if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, test.group); err != nil {
+ t.Fatalf("s.JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, test.group, err)
+ }
+ if isInGroup, err := s.IsInGroup(nicID, test.group); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.group, err)
+ } else if !isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.group)
+ }
+
+ if !test.expectReport {
+ if p, ok := e.Read(); ok {
+ t.Fatalf("got e.Read() = (%#v, true), want = (_, false)", p)
+ }
+
+ return
+ }
+
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, test.group, header.ICMPv6MulticastListenerReport, test.group)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 536493f87..a110faa54 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -737,7 +737,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
prefix := opt.Subnet()
// Is the prefix a link-local?
- if header.IsV6LinkLocalAddress(prefix.ID()) {
+ if header.IsV6LinkLocalUnicastAddress(prefix.ID()) {
// ...Yes, skip as per RFC 4861 section 6.3.4,
// and RFC 4862 section 5.5.3.b (for SLAAC).
continue
@@ -1703,7 +1703,7 @@ func (ndp *ndpState) startSolicitingRouters() {
// the unspecified address if no address is assigned
// to the sending interface.
localAddr := header.IPv6Any
- if addressEndpoint := ndp.ep.AcquireOutgoingPrimaryAddress(header.IPv6AllRoutersMulticastAddress, false); addressEndpoint != nil {
+ if addressEndpoint := ndp.ep.AcquireOutgoingPrimaryAddress(header.IPv6AllRoutersLinkLocalMulticastAddress, false); addressEndpoint != nil {
localAddr = addressEndpoint.AddressWithPrefix().Address
addressEndpoint.DecRef()
}
@@ -1730,7 +1730,7 @@ func (ndp *ndpState) startSolicitingRouters() {
icmpData.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpData,
Src: localAddr,
- Dst: header.IPv6AllRoutersMulticastAddress,
+ Dst: header.IPv6AllRoutersLinkLocalMulticastAddress,
}))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -1739,14 +1739,14 @@ func (ndp *ndpState) startSolicitingRouters() {
})
sent := ndp.ep.stats.icmp.packetsSent
- if err := addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{
+ if err := addIPHeader(localAddr, header.IPv6AllRoutersLinkLocalMulticastAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
}, nil /* extensionHeaders */); err != nil {
panic(fmt.Sprintf("failed to add IP header: %s", err))
}
- if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
+ if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
sent.dropped.Increment()
// Don't send any more messages if we had an error.
remaining = 0
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go
index ecd5003a7..1b96b1fb8 100644
--- a/pkg/tcpip/network/multicast_group_test.go
+++ b/pkg/tcpip/network/multicast_group_test.go
@@ -30,22 +30,13 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
const (
linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- stackIPv4Addr = tcpip.Address("\x0a\x00\x00\x01")
defaultIPv4PrefixLength = 24
- linkLocalIPv6Addr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- linkLocalIPv6Addr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
-
- ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03")
- ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04")
- ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05")
- ipv6MulticastAddr1 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
- ipv6MulticastAddr2 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04")
- ipv6MulticastAddr3 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05")
igmpMembershipQuery = uint8(header.IGMPMembershipQuery)
igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport)
@@ -59,6 +50,19 @@ const (
)
var (
+ stackIPv4Addr = testutil.MustParse4("10.0.0.1")
+ linkLocalIPv6Addr1 = testutil.MustParse6("fe80::1")
+ linkLocalIPv6Addr2 = testutil.MustParse6("fe80::2")
+
+ ipv4MulticastAddr1 = testutil.MustParse4("224.0.0.3")
+ ipv4MulticastAddr2 = testutil.MustParse4("224.0.0.4")
+ ipv4MulticastAddr3 = testutil.MustParse4("224.0.0.5")
+ ipv6MulticastAddr1 = testutil.MustParse6("ff02::3")
+ ipv6MulticastAddr2 = testutil.MustParse6("ff02::4")
+ ipv6MulticastAddr3 = testutil.MustParse6("ff02::5")
+)
+
+var (
// unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the
// NIC will wait before sending an unsolicited report after joining a
// multicast group, in deciseconds.
@@ -194,7 +198,7 @@ func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, c
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
} else {
- validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC)
+ validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6AddrSNMC)
}
// Should not send any more packets.
@@ -606,7 +610,7 @@ func TestMGPLeaveGroup(t *testing.T) {
validateLeave: func(t *testing.T, p channel.PacketInfo) {
t.Helper()
- validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1)
+ validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6MulticastAddr1)
},
checkInitialGroups: checkInitialIPv6Groups,
},
@@ -1014,7 +1018,7 @@ func TestMGPWithNICLifecycle(t *testing.T) {
validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
t.Helper()
- validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, addr)
+ validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, addr)
},
getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address {
t.Helper()
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index 210262703..b7f6d52ae 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -21,6 +21,7 @@ go_test(
library = ":ports",
deps = [
"//pkg/tcpip",
+ "//pkg/tcpip/testutil",
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index 678199371..b5b013b64 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -17,6 +17,7 @@
package ports
import (
+ "math"
"math/rand"
"sync/atomic"
@@ -24,7 +25,10 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-const anyIPAddress tcpip.Address = ""
+const (
+ firstEphemeral = 16000
+ anyIPAddress tcpip.Address = ""
+)
// Reservation describes a port reservation.
type Reservation struct {
@@ -220,10 +224,8 @@ type PortManager struct {
func NewPortManager() *PortManager {
return &PortManager{
allocatedPorts: make(map[portDescriptor]addrToDevice),
- // Match Linux's default ephemeral range. See:
- // https://github.com/torvalds/linux/blob/e54937963fa249595824439dc839c948188dea83/net/ipv4/af_inet.c#L1842
- firstEphemeral: 32768,
- numEphemeral: 28232,
+ firstEphemeral: firstEphemeral,
+ numEphemeral: math.MaxUint16 - firstEphemeral + 1,
}
}
@@ -242,13 +244,13 @@ func (pm *PortManager) PickEphemeralPort(testPort PortTester) (port uint16, err
numEphemeral := pm.numEphemeral
pm.ephemeralMu.RUnlock()
- offset := uint16(rand.Int31n(int32(numEphemeral)))
+ offset := uint32(rand.Int31n(int32(numEphemeral)))
return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort)
}
// portHint atomically reads and returns the pm.hint value.
-func (pm *PortManager) portHint() uint16 {
- return uint16(atomic.LoadUint32(&pm.hint))
+func (pm *PortManager) portHint() uint32 {
+ return atomic.LoadUint32(&pm.hint)
}
// incPortHint atomically increments pm.hint by 1.
@@ -260,7 +262,7 @@ func (pm *PortManager) incPortHint() {
// iterates over all ephemeral ports, allowing the caller to decide whether a
// given port is suitable for its needs and stopping when a port is found or an
// error occurs.
-func (pm *PortManager) PickEphemeralPortStable(offset uint16, testPort PortTester) (port uint16, err tcpip.Error) {
+func (pm *PortManager) PickEphemeralPortStable(offset uint32, testPort PortTester) (port uint16, err tcpip.Error) {
pm.ephemeralMu.RLock()
firstEphemeral := pm.firstEphemeral
numEphemeral := pm.numEphemeral
@@ -277,9 +279,9 @@ func (pm *PortManager) PickEphemeralPortStable(offset uint16, testPort PortTeste
// and iterates over the number of ports specified by count and allows the
// caller to decide whether a given port is suitable for its needs, and stopping
// when a port is found or an error occurs.
-func pickEphemeralPort(offset, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) {
- for i := uint16(0); i < count; i++ {
- port = first + (offset+i)%count
+func pickEphemeralPort(offset uint32, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) {
+ for i := uint32(0); i < uint32(count); i++ {
+ port := uint16(uint32(first) + (offset+i)%uint32(count))
ok, err := testPort(port)
if err != nil {
return 0, err
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index 0f43dc8f8..6c4fb8c68 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -15,19 +15,23 @@
package ports
import (
+ "math"
"math/rand"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
const (
fakeTransNumber tcpip.TransportProtocolNumber = 1
fakeNetworkNumber tcpip.NetworkProtocolNumber = 2
+)
- fakeIPAddress = tcpip.Address("\x08\x08\x08\x08")
- fakeIPAddress1 = tcpip.Address("\x08\x08\x08\x09")
+var (
+ fakeIPAddress = testutil.MustParse4("8.8.8.8")
+ fakeIPAddress1 = testutil.MustParse4("8.8.8.9")
)
type portReserveTestAction struct {
@@ -479,7 +483,7 @@ func TestPickEphemeralPortStable(t *testing.T) {
if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil {
t.Fatalf("failed to set ephemeral port range: %s", err)
}
- portOffset := uint16(rand.Int31n(int32(numEphemeralPorts)))
+ portOffset := uint32(rand.Int31n(int32(numEphemeralPorts)))
port, err := pm.PickEphemeralPortStable(portOffset, test.f)
if diff := cmp.Diff(test.wantErr, err); diff != "" {
t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff)
@@ -490,3 +494,29 @@ func TestPickEphemeralPortStable(t *testing.T) {
})
}
}
+
+// TestOverflow addresses b/183593432, wherein an overflowing uint16 causes a
+// port allocation failure.
+func TestOverflow(t *testing.T) {
+ // Use a small range and start at offsets that will cause an overflow.
+ count := uint16(50)
+ for offset := uint32(math.MaxUint16 - count); offset < math.MaxUint16; offset++ {
+ reservedPorts := make(map[uint16]struct{})
+ // Ensure we can reserve everything in the allowed range.
+ for i := uint16(0); i < count; i++ {
+ port, err := pickEphemeralPort(offset, firstEphemeral, count, func(port uint16) (bool, tcpip.Error) {
+ if _, ok := reservedPorts[port]; !ok {
+ reservedPorts[port] = struct{}{}
+ return true, nil
+ }
+ return false, nil
+ })
+ if err != nil {
+ t.Fatalf("port picking failed at iteration %d, for offset %d, len(reserved): %+v", i, offset, len(reservedPorts))
+ }
+ if port < firstEphemeral || port > firstEphemeral+count {
+ t.Fatalf("reserved port %d, which is not in range [%d, %d]", port, firstEphemeral, firstEphemeral+count-1)
+ }
+ }
+ }
+}
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index dc37e61a4..a6c877158 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -58,6 +58,9 @@ type SocketOptionsHandler interface {
// changed. The handler is invoked with the new value for the socket send
// buffer size. It also returns the newly set value.
OnSetSendBufferSize(v int64) (newSz int64)
+
+ // OnSetReceiveBufferSize is invoked to set the SO_RCVBUFSIZE.
+ OnSetReceiveBufferSize(v, oldSz int64) (newSz int64)
}
// DefaultSocketOptionsHandler is an embeddable type that implements no-op
@@ -99,6 +102,11 @@ func (*DefaultSocketOptionsHandler) OnSetSendBufferSize(v int64) (newSz int64) {
return v
}
+// OnSetReceiveBufferSize implements SocketOptionsHandler.OnSetReceiveBufferSize.
+func (*DefaultSocketOptionsHandler) OnSetReceiveBufferSize(v, oldSz int64) (newSz int64) {
+ return v
+}
+
// StackHandler holds methods to access the stack options. These must be
// implemented by the stack.
type StackHandler interface {
@@ -207,6 +215,14 @@ type SocketOptions struct {
// sendBufferSize determines the send buffer size for this socket.
sendBufferSize int64
+ // getReceiveBufferLimits provides the handler to get the min, default and
+ // max size for receive buffer. It is initialized at the creation time and
+ // will not change.
+ getReceiveBufferLimits GetReceiveBufferLimits `state:"manual"`
+
+ // receiveBufferSize determines the receive buffer size for this socket.
+ receiveBufferSize int64
+
// mu protects the access to the below fields.
mu sync.Mutex `state:"nosave"`
@@ -217,10 +233,11 @@ type SocketOptions struct {
// InitHandler initializes the handler. This must be called before using the
// socket options utility.
-func (so *SocketOptions) InitHandler(handler SocketOptionsHandler, stack StackHandler, getSendBufferLimits GetSendBufferLimits) {
+func (so *SocketOptions) InitHandler(handler SocketOptionsHandler, stack StackHandler, getSendBufferLimits GetSendBufferLimits, getReceiveBufferLimits GetReceiveBufferLimits) {
so.handler = handler
so.stackHandler = stack
so.getSendBufferLimits = getSendBufferLimits
+ so.getReceiveBufferLimits = getReceiveBufferLimits
}
func storeAtomicBool(addr *uint32, v bool) {
@@ -632,3 +649,42 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) {
newSz := so.handler.OnSetSendBufferSize(v)
atomic.StoreInt64(&so.sendBufferSize, newSz)
}
+
+// GetReceiveBufferSize gets value for SO_RCVBUF option.
+func (so *SocketOptions) GetReceiveBufferSize() int64 {
+ return atomic.LoadInt64(&so.receiveBufferSize)
+}
+
+// SetReceiveBufferSize sets value for SO_RCVBUF option.
+func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bool) {
+ if !notify {
+ atomic.StoreInt64(&so.receiveBufferSize, receiveBufferSize)
+ return
+ }
+
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ v := receiveBufferSize
+ ss := so.getReceiveBufferLimits(so.stackHandler)
+ min := int64(ss.Min)
+ max := int64(ss.Max)
+ // Validate the send buffer size with min and max values.
+ if v > max {
+ v = max
+ }
+
+ // Multiply it by factor of 2.
+ if v < math.MaxInt32/PacketOverheadFactor {
+ v *= PacketOverheadFactor
+ if v < min {
+ v = min
+ }
+ } else {
+ v = math.MaxInt32
+ }
+
+ oldSz := atomic.LoadInt64(&so.receiveBufferSize)
+ // Notify endpoint about change in buffer size.
+ newSz := so.handler.OnSetReceiveBufferSize(v, oldSz)
+ atomic.StoreInt64(&so.receiveBufferSize, newSz)
+}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 49362333a..2bd6a67f5 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -45,6 +45,7 @@ go_library(
"addressable_endpoint_state.go",
"conntrack.go",
"headertype_string.go",
+ "hook_string.go",
"icmp_rate_limit.go",
"iptables.go",
"iptables_state.go",
@@ -66,6 +67,7 @@ go_library(
"stack.go",
"stack_global_state.go",
"stack_options.go",
+ "tcp.go",
"transport_demuxer.go",
"tuple_list.go",
],
@@ -115,6 +117,7 @@ go_test(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/ports",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
@@ -139,6 +142,7 @@ go_test(
"//pkg/tcpip/buffer",
"//pkg/tcpip/faketime",
"//pkg/tcpip/header",
+ "//pkg/tcpip/testutil",
"@com_github_google_go_cmp//cmp:go_default_library",
"@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 3f083928f..41e964cf3 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -16,6 +16,7 @@ package stack
import (
"encoding/binary"
+ "fmt"
"sync"
"time"
@@ -29,7 +30,7 @@ import (
// The connection is created for a packet if it does not exist. Every
// connection contains two tuples (original and reply). The tuples are
// manipulated if there is a matching NAT rule. The packet is modified by
-// looking at the tuples in the Prerouting and Output hooks.
+// looking at the tuples in each hook.
//
// Currently, only TCP tracking is supported.
@@ -46,12 +47,14 @@ const (
)
// Manipulation type for the connection.
+// TODO(gvisor.dev/issue/5696): Define this as a bit set and support SNAT and
+// DNAT at the same time.
type manipType int
const (
manipNone manipType = iota
- manipDstPrerouting
- manipDstOutput
+ manipSource
+ manipDestination
)
// tuple holds a connection's identifying and manipulating data in one
@@ -108,6 +111,7 @@ type conn struct {
reply tuple
// manip indicates if the packet should be manipulated. It is immutable.
+ // TODO(gvisor.dev/issue/5696): Support updating manipulation type.
manip manipType
// tcbHook indicates if the packet is inbound or outbound to
@@ -124,6 +128,18 @@ type conn struct {
lastUsed time.Time `state:".(unixTime)"`
}
+// newConn creates new connection.
+func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
+ conn := conn{
+ manip: manip,
+ tcbHook: hook,
+ lastUsed: time.Now(),
+ }
+ conn.original = tuple{conn: &conn, tupleID: orig}
+ conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
+ return &conn
+}
+
// timedOut returns whether the connection timed out based on its state.
func (cn *conn) timedOut(now time.Time) bool {
const establishedTimeout = 5 * 24 * time.Hour
@@ -219,18 +235,6 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) {
}, nil
}
-// newConn creates new connection.
-func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
- conn := conn{
- manip: manip,
- tcbHook: hook,
- lastUsed: time.Now(),
- }
- conn.original = tuple{conn: &conn, tupleID: orig}
- conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
- return &conn
-}
-
func (ct *ConnTrack) init() {
ct.mu.Lock()
defer ct.mu.Unlock()
@@ -284,20 +288,41 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint1
return nil
}
- // Create a new connection and change the port as per the iptables
- // rule. This tuple will be used to manipulate the packet in
- // handlePacket.
replyTID := tid.reply()
replyTID.srcAddr = address
replyTID.srcPort = port
- var manip manipType
- switch hook {
- case Prerouting:
- manip = manipDstPrerouting
- case Output:
- manip = manipDstOutput
+
+ conn, _ := ct.connForTID(tid)
+ if conn != nil {
+ // The connection is already tracked.
+ // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
+ return nil
}
- conn := newConn(tid, replyTID, manip, hook)
+ conn = newConn(tid, replyTID, manipDestination, hook)
+ ct.insertConn(conn)
+ return conn
+}
+
+func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return nil
+ }
+ if hook != Input && hook != Postrouting {
+ return nil
+ }
+
+ replyTID := tid.reply()
+ replyTID.dstAddr = address
+ replyTID.dstPort = port
+
+ conn, _ := ct.connForTID(tid)
+ if conn != nil {
+ // The connection is already tracked.
+ // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
+ return nil
+ }
+ conn = newConn(tid, replyTID, manipSource, hook)
ct.insertConn(conn)
return conn
}
@@ -322,6 +347,7 @@ func (ct *ConnTrack) insertConn(conn *conn) {
// Now that we hold the locks, ensure the tuple hasn't been inserted by
// another thread.
+ // TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too?
alreadyInserted := false
for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
if other.tupleID == conn.original.tupleID {
@@ -343,86 +369,6 @@ func (ct *ConnTrack) insertConn(conn *conn) {
}
}
-// handlePacketPrerouting manipulates ports for packets in Prerouting hook.
-// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.
-func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
- // If this is a noop entry, don't do anything.
- if conn.manip == manipNone {
- return
- }
-
- netHeader := pkt.Network()
- tcpHeader := header.TCP(pkt.TransportHeader().View())
-
- // For prerouting redirection, packets going in the original direction
- // have their destinations modified and replies have their sources
- // modified.
- switch dir {
- case dirOriginal:
- port := conn.reply.srcPort
- tcpHeader.SetDestinationPort(port)
- netHeader.SetDestinationAddress(conn.reply.srcAddr)
- case dirReply:
- port := conn.original.dstPort
- tcpHeader.SetSourcePort(port)
- netHeader.SetSourceAddress(conn.original.dstAddr)
- }
-
- // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated
- // on inbound packets, so we don't recalculate them. However, we should
- // support cases when they are validated, e.g. when we can't offload
- // receive checksumming.
-
- // After modification, IPv4 packets need a valid checksum.
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
-}
-
-// handlePacketOutput manipulates ports for packets in Output hook.
-func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) {
- // If this is a noop entry, don't do anything.
- if conn.manip == manipNone {
- return
- }
-
- netHeader := pkt.Network()
- tcpHeader := header.TCP(pkt.TransportHeader().View())
-
- // For output redirection, packets going in the original direction
- // have their destinations modified and replies have their sources
- // modified. For prerouting redirection, we only reach this point
- // when replying, so packet sources are modified.
- if conn.manip == manipDstOutput && dir == dirOriginal {
- port := conn.reply.srcPort
- tcpHeader.SetDestinationPort(port)
- netHeader.SetDestinationAddress(conn.reply.srcAddr)
- } else {
- port := conn.original.dstPort
- tcpHeader.SetSourcePort(port)
- netHeader.SetSourceAddress(conn.original.dstAddr)
- }
-
- // Calculate the TCP checksum and set it.
- tcpHeader.SetChecksum(0)
- length := uint16(len(tcpHeader) + pkt.Data().Size())
- xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
- if gso != nil && gso.NeedsCsum {
- tcpHeader.SetChecksum(xsum)
- } else if r.RequiresTXTransportChecksum() {
- xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
- tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
- }
-
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
-}
-
// handlePacket will manipulate the port and address of the packet if the
// connection exists. Returns whether, after the packet traverses the tables,
// it should create a new entry in the table.
@@ -431,7 +377,9 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
return false
}
- if hook != Prerouting && hook != Output {
+ switch hook {
+ case Prerouting, Input, Output, Postrouting:
+ default:
return false
}
@@ -441,23 +389,79 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
}
conn, dir := ct.connFor(pkt)
- // Connection or Rule not found for the packet.
+ // Connection not found for the packet.
if conn == nil {
- return true
+ // If this is the last hook in the data path for this packet (Input if
+ // incoming, Postrouting if outgoing), indicate that a connection should be
+ // inserted by the end of this hook.
+ return hook == Input || hook == Postrouting
}
+ netHeader := pkt.Network()
tcpHeader := header.TCP(pkt.TransportHeader().View())
if len(tcpHeader) < header.TCPMinimumSize {
return false
}
+ // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
+ // validated if checksum offloading is off. It may require IP defrag if the
+ // packets are fragmented.
+
+ switch hook {
+ case Prerouting, Output:
+ if conn.manip == manipDestination {
+ switch dir {
+ case dirOriginal:
+ tcpHeader.SetDestinationPort(conn.reply.srcPort)
+ netHeader.SetDestinationAddress(conn.reply.srcAddr)
+ case dirReply:
+ tcpHeader.SetSourcePort(conn.original.dstPort)
+ netHeader.SetSourceAddress(conn.original.dstAddr)
+ }
+ pkt.NatDone = true
+ }
+ case Input, Postrouting:
+ if conn.manip == manipSource {
+ switch dir {
+ case dirOriginal:
+ tcpHeader.SetSourcePort(conn.reply.dstPort)
+ netHeader.SetSourceAddress(conn.reply.dstAddr)
+ case dirReply:
+ tcpHeader.SetDestinationPort(conn.original.srcPort)
+ netHeader.SetDestinationAddress(conn.original.srcAddr)
+ }
+ pkt.NatDone = true
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized hook = %s", hook))
+ }
+ if !pkt.NatDone {
+ return false
+ }
+
switch hook {
- case Prerouting:
- handlePacketPrerouting(pkt, conn, dir)
- case Output:
- handlePacketOutput(pkt, conn, gso, r, dir)
+ case Prerouting, Input:
+ case Output, Postrouting:
+ // Calculate the TCP checksum and set it.
+ tcpHeader.SetChecksum(0)
+ length := uint16(len(tcpHeader) + pkt.Data().Size())
+ xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
+ if gso != nil && gso.NeedsCsum {
+ tcpHeader.SetChecksum(xsum)
+ } else if r.RequiresTXTransportChecksum() {
+ xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
+ tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized hook = %s", hook))
+ }
+
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
}
- pkt.NatDone = true
// Update the state of tcb.
// TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle
@@ -638,8 +642,8 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ
if conn == nil {
// Not a tracked connection.
return "", 0, &tcpip.ErrNotConnected{}
- } else if conn.manip == manipNone {
- // Unmanipulated connection.
+ } else if conn.manip != manipDestination {
+ // Unmanipulated destination.
return "", 0, &tcpip.ErrInvalidOptionValue{}
}
diff --git a/pkg/tcpip/stack/hook_string.go b/pkg/tcpip/stack/hook_string.go
new file mode 100644
index 000000000..3dc8a7b02
--- /dev/null
+++ b/pkg/tcpip/stack/hook_string.go
@@ -0,0 +1,41 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at //
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by "stringer -type Hook ."; DO NOT EDIT.
+
+package stack
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[Prerouting-0]
+ _ = x[Input-1]
+ _ = x[Forward-2]
+ _ = x[Output-3]
+ _ = x[Postrouting-4]
+ _ = x[NumHooks-5]
+}
+
+const _Hook_name = "PreroutingInputForwardOutputPostroutingNumHooks"
+
+var _Hook_index = [...]uint8{0, 10, 15, 22, 28, 39, 47}
+
+func (i Hook) String() string {
+ if i >= Hook(len(_Hook_index)-1) {
+ return "Hook(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _Hook_name[_Hook_index[i]:_Hook_index[i+1]]
+}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 52890f6eb..7ea87d325 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -175,9 +175,10 @@ func DefaultTables() *IPTables {
},
},
priorities: [NumHooks][]TableID{
- Prerouting: {MangleID, NATID},
- Input: {NATID, FilterID},
- Output: {MangleID, NATID, FilterID},
+ Prerouting: {MangleID, NATID},
+ Input: {NATID, FilterID},
+ Output: {MangleID, NATID, FilterID},
+ Postrouting: {MangleID, NATID},
},
connections: ConnTrack{
seed: generateRandUint32(),
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 0e8b90c9b..317efe754 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -182,3 +182,81 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
return RuleAccept, 0
}
+
+// SNATTarget modifies the source port/IP in the outgoing packets.
+type SNATTarget struct {
+ Addr tcpip.Address
+ Port uint16
+
+ // NetworkProtocol is the network protocol the target is used with. It
+ // is immutable.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// Action implements Target.Action.
+func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+ // Sanity check.
+ if st.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ st.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
+ // Packet is already manipulated.
+ if pkt.NatDone {
+ return RuleAccept, 0
+ }
+
+ // Drop the packet if network and transport header are not set.
+ if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() {
+ return RuleDrop, 0
+ }
+
+ switch hook {
+ case Postrouting, Input:
+ case Prerouting, Output, Forward:
+ panic(fmt.Sprintf("%s not supported", hook))
+ default:
+ panic(fmt.Sprintf("%s unrecognized", hook))
+ }
+
+ switch protocol := pkt.TransportProtocolNumber; protocol {
+ case header.UDPProtocolNumber:
+ udpHeader := header.UDP(pkt.TransportHeader().View())
+ udpHeader.SetChecksum(0)
+ udpHeader.SetSourcePort(st.Port)
+ netHeader := pkt.Network()
+ netHeader.SetSourceAddress(st.Addr)
+
+ // Only calculate the checksum if offloading isn't supported.
+ if r.RequiresTXTransportChecksum() {
+ length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
+ xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
+ xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
+ udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
+ }
+
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
+ pkt.NatDone = true
+ case header.TCPProtocolNumber:
+ if ct == nil {
+ return RuleAccept, 0
+ }
+
+ // Set up conection for matching NAT rule. Only the first
+ // packet of the connection comes here. Other packets will be
+ // manipulated in connection tracking.
+ if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil {
+ ct.handlePacket(pkt, hook, gso, r)
+ }
+ default:
+ return RuleDrop, 0
+ }
+
+ return RuleAccept, 0
+}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 14124ae66..b6cf24739 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -33,15 +33,19 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
+var (
+ addr1 = testutil.MustParse6("a00::1")
+ addr2 = testutil.MustParse6("a00::2")
+ addr3 = testutil.MustParse6("a00::3")
+)
+
const (
- addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
@@ -1390,7 +1394,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
// configured not to.
func TestNoPrefixDiscovery(t *testing.T) {
prefix := tcpip.AddressWithPrefix{
- Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"),
+ Address: testutil.MustParse6("102:304:506:708::"),
PrefixLen: 64,
}
@@ -1590,7 +1594,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
}()
prefix := tcpip.AddressWithPrefix{
- Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"),
+ Address: testutil.MustParse6("102:304:506:708::"),
PrefixLen: 64,
}
subnet := prefix.Subnet()
@@ -5204,13 +5208,13 @@ func TestRouterSolicitation(t *testing.T) {
}
// Make sure the right remote link address is used.
- if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want {
+ if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want {
t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(test.expectedSrcAddr),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress),
checker.TTL(header.NDPHopLimit),
checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
)
@@ -5362,7 +5366,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress),
checker.TTL(header.NDPHopLimit),
checker.NDPRS())
}
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index bb2b2d705..1d39ee73d 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -26,14 +26,13 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
const (
entryTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
entryTestNICID tcpip.NICID = 1
- entryTestAddr1 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- entryTestAddr2 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01")
entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02")
@@ -44,6 +43,11 @@ const (
entryTestNetDefaultMTU = 65536
)
+var (
+ entryTestAddr1 = testutil.MustParse6("a::1")
+ entryTestAddr2 = testutil.MustParse6("a::2")
+)
+
// runImmediatelyScheduledJobs runs all jobs scheduled to run at the current
// time.
func runImmediatelyScheduledJobs(clock *faketime.ManualClock) {
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 8f288675d..c10304d5f 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -299,9 +299,18 @@ func (pk *PacketBuffer) Network() header.Network {
// See PacketBuffer.Data for details about how a packet buffer holds an inbound
// packet.
func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
- return NewPacketBuffer(PacketBufferOptions{
+ newPk := NewPacketBuffer(PacketBufferOptions{
Data: buffer.NewVectorisedView(pk.Size(), pk.Views()),
})
+ // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to
+ // maintain this flag in the packet. Currently conntrack needs this flag to
+ // tell if a noop connection should be inserted at Input hook. Once conntrack
+ // redefines the manipulation field as mutable, we won't need the special noop
+ // connection.
+ if pk.NatDone {
+ newPk.NatDone = true
+ }
+ return newPk
}
// headerInfo stores metadata about a header in a packet.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 39344808d..4ae6bed5a 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -132,7 +132,7 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp
localAddr = addressEndpoint.AddressWithPrefix().Address
}
- if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) {
+ if localAddressNIC != outgoingNIC && header.IsV6LinkLocalUnicastAddress(localAddr) {
addressEndpoint.DecRef()
return nil
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 931a97ddc..21cfbad71 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -35,7 +35,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/ports"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -56,306 +55,6 @@ type transportProtocolState struct {
defaultHandler func(id TransportEndpointID, pkt *PacketBuffer) bool
}
-// TCPProbeFunc is the expected function type for a TCP probe function to be
-// passed to stack.AddTCPProbe.
-type TCPProbeFunc func(s TCPEndpointState)
-
-// TCPCubicState is used to hold a copy of the internal cubic state when the
-// TCPProbeFunc is invoked.
-type TCPCubicState struct {
- WLastMax float64
- WMax float64
- T time.Time
- TimeSinceLastCongestion time.Duration
- C float64
- K float64
- Beta float64
- WC float64
- WEst float64
-}
-
-// TCPRACKState is used to hold a copy of the internal RACK state when the
-// TCPProbeFunc is invoked.
-type TCPRACKState struct {
- XmitTime time.Time
- EndSequence seqnum.Value
- FACK seqnum.Value
- RTT time.Duration
- Reord bool
- DSACKSeen bool
- ReoWnd time.Duration
- ReoWndIncr uint8
- ReoWndPersist int8
- RTTSeq seqnum.Value
-}
-
-// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
-type TCPEndpointID struct {
- // LocalPort is the local port associated with the endpoint.
- LocalPort uint16
-
- // LocalAddress is the local [network layer] address associated with
- // the endpoint.
- LocalAddress tcpip.Address
-
- // RemotePort is the remote port associated with the endpoint.
- RemotePort uint16
-
- // RemoteAddress it the remote [network layer] address associated with
- // the endpoint.
- RemoteAddress tcpip.Address
-}
-
-// TCPFastRecoveryState holds a copy of the internal fast recovery state of a
-// TCP endpoint.
-type TCPFastRecoveryState struct {
- // Active if true indicates the endpoint is in fast recovery.
- Active bool
-
- // First is the first unacknowledged sequence number being recovered.
- First seqnum.Value
-
- // Last is the 'recover' sequence number that indicates the point at
- // which we should exit recovery barring any timeouts etc.
- Last seqnum.Value
-
- // MaxCwnd is the maximum value we are permitted to grow the congestion
- // window during recovery. This is set at the time we enter recovery.
- MaxCwnd int
-
- // HighRxt is the highest sequence number which has been retransmitted
- // during the current loss recovery phase.
- // See: RFC 6675 Section 2 for details.
- HighRxt seqnum.Value
-
- // RescueRxt is the highest sequence number which has been
- // optimistically retransmitted to prevent stalling of the ACK clock
- // when there is loss at the end of the window and no new data is
- // available for transmission.
- // See: RFC 6675 Section 2 for details.
- RescueRxt seqnum.Value
-}
-
-// TCPReceiverState holds a copy of the internal state of the receiver for
-// a given TCP endpoint.
-type TCPReceiverState struct {
- // RcvNxt is the TCP variable RCV.NXT.
- RcvNxt seqnum.Value
-
- // RcvAcc is the TCP variable RCV.ACC.
- RcvAcc seqnum.Value
-
- // RcvWndScale is the window scaling to use for inbound segments.
- RcvWndScale uint8
-
- // PendingBufUsed is the number of bytes pending in the receive
- // queue.
- PendingBufUsed int
-}
-
-// TCPSenderState holds a copy of the internal state of the sender for
-// a given TCP Endpoint.
-type TCPSenderState struct {
- // LastSendTime is the time at which we sent the last segment.
- LastSendTime time.Time
-
- // DupAckCount is the number of Duplicate ACK's received.
- DupAckCount int
-
- // SndCwnd is the size of the sending congestion window in packets.
- SndCwnd int
-
- // Ssthresh is the slow start threshold in packets.
- Ssthresh int
-
- // SndCAAckCount is the number of packets consumed in congestion
- // avoidance mode.
- SndCAAckCount int
-
- // Outstanding is the number of packets in flight.
- Outstanding int
-
- // SackedOut is the number of packets which have been selectively acked.
- SackedOut int
-
- // SndWnd is the send window size in bytes.
- SndWnd seqnum.Size
-
- // SndUna is the next unacknowledged sequence number.
- SndUna seqnum.Value
-
- // SndNxt is the sequence number of the next segment to be sent.
- SndNxt seqnum.Value
-
- // RTTMeasureSeqNum is the sequence number being used for the latest RTT
- // measurement.
- RTTMeasureSeqNum seqnum.Value
-
- // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent.
- RTTMeasureTime time.Time
-
- // Closed indicates that the caller has closed the endpoint for sending.
- Closed bool
-
- // SRTT is the smoothed round-trip time as defined in section 2 of
- // RFC 6298.
- SRTT time.Duration
-
- // RTO is the retransmit timeout as defined in section of 2 of RFC 6298.
- RTO time.Duration
-
- // RTTVar is the round-trip time variation as defined in section 2 of
- // RFC 6298.
- RTTVar time.Duration
-
- // SRTTInited if true indicates take a valid RTT measurement has been
- // completed.
- SRTTInited bool
-
- // MaxPayloadSize is the maximum size of the payload of a given segment.
- // It is initialized on demand.
- MaxPayloadSize int
-
- // SndWndScale is the number of bits to shift left when reading the send
- // window size from a segment.
- SndWndScale uint8
-
- // MaxSentAck is the highest acknowledgement number sent till now.
- MaxSentAck seqnum.Value
-
- // FastRecovery holds the fast recovery state for the endpoint.
- FastRecovery TCPFastRecoveryState
-
- // Cubic holds the state related to CUBIC congestion control.
- Cubic TCPCubicState
-
- // RACKState holds the state related to RACK loss detection algorithm.
- RACKState TCPRACKState
-}
-
-// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
-type TCPSACKInfo struct {
- // Blocks is the list of SACK Blocks that identify the out of order segments
- // held by a given TCP endpoint.
- Blocks []header.SACKBlock
-
- // ReceivedBlocks are the SACK blocks received by this endpoint
- // from the peer endpoint.
- ReceivedBlocks []header.SACKBlock
-
- // MaxSACKED is the highest sequence number that has been SACKED
- // by the peer.
- MaxSACKED seqnum.Value
-}
-
-// RcvBufAutoTuneParams holds state related to TCP receive buffer auto-tuning.
-type RcvBufAutoTuneParams struct {
- // MeasureTime is the time at which the current measurement
- // was started.
- MeasureTime time.Time
-
- // CopiedBytes is the number of bytes copied to user space since
- // this measure began.
- CopiedBytes int
-
- // PrevCopiedBytes is the number of bytes copied to userspace in
- // the previous RTT period.
- PrevCopiedBytes int
-
- // RcvBufSize is the auto tuned receive buffer size.
- RcvBufSize int
-
- // RTT is the smoothed RTT as measured by observing the time between
- // when a byte is first acknowledged and the receipt of data that is at
- // least one window beyond the sequence number that was acknowledged.
- RTT time.Duration
-
- // RTTVar is the "round-trip time variation" as defined in section 2
- // of RFC6298.
- RTTVar time.Duration
-
- // RTTMeasureSeqNumber is the highest acceptable sequence number at the
- // time this RTT measurement period began.
- RTTMeasureSeqNumber seqnum.Value
-
- // RTTMeasureTime is the absolute time at which the current RTT
- // measurement period began.
- RTTMeasureTime time.Time
-
- // Disabled is true if an explicit receive buffer is set for the
- // endpoint.
- Disabled bool
-}
-
-// TCPEndpointState is a copy of the internal state of a TCP endpoint.
-type TCPEndpointState struct {
- // ID is a copy of the TransportEndpointID for the endpoint.
- ID TCPEndpointID
-
- // SegTime denotes the absolute time when this segment was received.
- SegTime time.Time
-
- // RcvBufSize is the size of the receive socket buffer for the endpoint.
- RcvBufSize int
-
- // RcvBufUsed is the amount of bytes actually held in the receive socket
- // buffer for the endpoint.
- RcvBufUsed int
-
- // RcvBufAutoTuneParams is used to hold state variables to compute
- // the auto tuned receive buffer size.
- RcvAutoParams RcvBufAutoTuneParams
-
- // RcvClosed if true, indicates the endpoint has been closed for reading.
- RcvClosed bool
-
- // SendTSOk is used to indicate when the TS Option has been negotiated.
- // When sendTSOk is true every non-RST segment should carry a TS as per
- // RFC7323#section-1.1.
- SendTSOk bool
-
- // RecentTS is the timestamp that should be sent in the TSEcr field of
- // the timestamp for future segments sent by the endpoint. This field is
- // updated if required when a new segment is received by this endpoint.
- RecentTS uint32
-
- // TSOffset is a randomized offset added to the value of the TSVal field
- // in the timestamp option.
- TSOffset uint32
-
- // SACKPermitted is set to true if the peer sends the TCPSACKPermitted
- // option in the SYN/SYN-ACK.
- SACKPermitted bool
-
- // SACK holds TCP SACK related information for this endpoint.
- SACK TCPSACKInfo
-
- // SndBufSize is the size of the socket send buffer.
- SndBufSize int
-
- // SndBufUsed is the number of bytes held in the socket send buffer.
- SndBufUsed int
-
- // SndClosed indicates that the endpoint has been closed for sends.
- SndClosed bool
-
- // SndBufInQueue is the number of bytes in the send queue.
- SndBufInQueue seqnum.Size
-
- // PacketTooBigCount is used to notify the main protocol routine how
- // many times a "packet too big" control packet is received.
- PacketTooBigCount int
-
- // SndMTU is the smallest MTU seen in the control packets received.
- SndMTU int
-
- // Receiver holds variables related to the TCP receiver for the endpoint.
- Receiver TCPReceiverState
-
- // Sender holds state related to the TCP Sender for the endpoint.
- Sender TCPSenderState
-}
-
// ResumableEndpoint is an endpoint that needs to be resumed after restore.
type ResumableEndpoint interface {
// Resume resumes an endpoint after restore. This can be used to restart
@@ -455,7 +154,7 @@ type Stack struct {
// receiveBufferSize holds the min/default/max receive buffer sizes for
// endpoints other than TCP.
- receiveBufferSize ReceiveBufferSizeOption
+ receiveBufferSize tcpip.ReceiveBufferSizeOption
// tcpInvalidRateLimit is the maximal rate for sending duplicate
// acknowledgements in response to incoming TCP packets that are for an existing
@@ -669,7 +368,7 @@ func New(opts Options) *Stack {
Default: DefaultBufferSize,
Max: DefaultMaxBufferSize,
},
- receiveBufferSize: ReceiveBufferSizeOption{
+ receiveBufferSize: tcpip.ReceiveBufferSizeOption{
Min: MinBufferSize,
Default: DefaultBufferSize,
Max: DefaultMaxBufferSize,
@@ -1344,7 +1043,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
s.mu.RLock()
defer s.mu.RUnlock()
- isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr)
+ isLinkLocal := header.IsV6LinkLocalUnicastAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr)
isLocalBroadcast := remoteAddr == header.IPv4Broadcast
isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr)
@@ -1381,7 +1080,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
return nil, &tcpip.ErrNetworkUnreachable{}
}
- canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal
+ canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal
// Find a route to the remote with the route table.
var chosenRoute tcpip.Route
diff --git a/pkg/tcpip/stack/stack_global_state.go b/pkg/tcpip/stack/stack_global_state.go
index dfec4258a..33824afd0 100644
--- a/pkg/tcpip/stack/stack_global_state.go
+++ b/pkg/tcpip/stack/stack_global_state.go
@@ -14,6 +14,78 @@
package stack
+import "time"
+
// StackFromEnv is the global stack created in restore run.
// FIXME(b/36201077)
var StackFromEnv *Stack
+
+// saveT is invoked by stateify.
+func (t *TCPCubicState) saveT() unixTime {
+ return unixTime{t.T.Unix(), t.T.UnixNano()}
+}
+
+// loadT is invoked by stateify.
+func (t *TCPCubicState) loadT(unix unixTime) {
+ t.T = time.Unix(unix.second, unix.nano)
+}
+
+// saveXmitTime is invoked by stateify.
+func (t *TCPRACKState) saveXmitTime() unixTime {
+ return unixTime{t.XmitTime.Unix(), t.XmitTime.UnixNano()}
+}
+
+// loadXmitTime is invoked by stateify.
+func (t *TCPRACKState) loadXmitTime(unix unixTime) {
+ t.XmitTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveLastSendTime is invoked by stateify.
+func (t *TCPSenderState) saveLastSendTime() unixTime {
+ return unixTime{t.LastSendTime.Unix(), t.LastSendTime.UnixNano()}
+}
+
+// loadLastSendTime is invoked by stateify.
+func (t *TCPSenderState) loadLastSendTime(unix unixTime) {
+ t.LastSendTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveRTTMeasureTime is invoked by stateify.
+func (t *TCPSenderState) saveRTTMeasureTime() unixTime {
+ return unixTime{t.RTTMeasureTime.Unix(), t.RTTMeasureTime.UnixNano()}
+}
+
+// loadRTTMeasureTime is invoked by stateify.
+func (t *TCPSenderState) loadRTTMeasureTime(unix unixTime) {
+ t.RTTMeasureTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveMeasureTime is invoked by stateify.
+func (r *RcvBufAutoTuneParams) saveMeasureTime() unixTime {
+ return unixTime{r.MeasureTime.Unix(), r.MeasureTime.UnixNano()}
+}
+
+// loadMeasureTime is invoked by stateify.
+func (r *RcvBufAutoTuneParams) loadMeasureTime(unix unixTime) {
+ r.MeasureTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveRTTMeasureTime is invoked by stateify.
+func (r *RcvBufAutoTuneParams) saveRTTMeasureTime() unixTime {
+ return unixTime{r.RTTMeasureTime.Unix(), r.RTTMeasureTime.UnixNano()}
+}
+
+// loadRTTMeasureTime is invoked by stateify.
+func (r *RcvBufAutoTuneParams) loadRTTMeasureTime(unix unixTime) {
+ r.RTTMeasureTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveSegTime is invoked by stateify.
+func (t *TCPEndpointState) saveSegTime() unixTime {
+ return unixTime{t.SegTime.Unix(), t.SegTime.UnixNano()}
+}
+
+// loadSegTime is invoked by stateify.
+func (t *TCPEndpointState) loadSegTime(unix unixTime) {
+ t.SegTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go
index 3066f4ffd..80e8e0089 100644
--- a/pkg/tcpip/stack/stack_options.go
+++ b/pkg/tcpip/stack/stack_options.go
@@ -68,7 +68,7 @@ func (s *Stack) SetOption(option interface{}) tcpip.Error {
s.mu.Unlock()
return nil
- case ReceiveBufferSizeOption:
+ case tcpip.ReceiveBufferSizeOption:
// Make sure we don't allow lowering the buffer below minimum
// required for stack to work.
if v.Min < MinBufferSize {
@@ -107,7 +107,7 @@ func (s *Stack) Option(option interface{}) tcpip.Error {
s.mu.RUnlock()
return nil
- case *ReceiveBufferSizeOption:
+ case *tcpip.ReceiveBufferSizeOption:
s.mu.RLock()
*v = s.receiveBufferSize
s.mu.RUnlock()
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 2814b94b4..a0bd69d9a 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -39,6 +39,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
@@ -1645,10 +1646,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0}
// Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1.
nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24}
- nic1Gateway := tcpip.Address("\xc0\xa8\x01\x01")
+ nic1Gateway := testutil.MustParse4("192.168.1.1")
// Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1.
nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24}
- nic2Gateway := tcpip.Address("\x0a\x0a\x0a\x01")
+ nic2Gateway := testutil.MustParse4("10.10.10.1")
// Create a new stack with two NICs.
s := stack.New(stack.Options{
@@ -2789,25 +2790,27 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
const (
- linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- globalAddr3 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
- ipv4MappedIPv6Addr1 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01")
- ipv4MappedIPv6Addr2 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x02")
- toredoAddr1 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- toredoAddr2 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- ipv6ToIPv4Addr1 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- ipv6ToIPv4Addr2 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
-
nicID = 1
lifetimeSeconds = 9999
)
+ var (
+ linkLocalAddr1 = testutil.MustParse6("fe80::1")
+ linkLocalAddr2 = testutil.MustParse6("fe80::2")
+ linkLocalMulticastAddr = testutil.MustParse6("ff02::1")
+ uniqueLocalAddr1 = testutil.MustParse6("fc00::1")
+ uniqueLocalAddr2 = testutil.MustParse6("fd00::2")
+ globalAddr1 = testutil.MustParse6("a000::1")
+ globalAddr2 = testutil.MustParse6("a000::2")
+ globalAddr3 = testutil.MustParse6("a000::3")
+ ipv4MappedIPv6Addr1 = testutil.MustParse6("::ffff:0.0.0.1")
+ ipv4MappedIPv6Addr2 = testutil.MustParse6("::ffff:0.0.0.2")
+ toredoAddr1 = testutil.MustParse6("2001::1")
+ toredoAddr2 = testutil.MustParse6("2001::2")
+ ipv6ToIPv4Addr1 = testutil.MustParse6("2002::1")
+ ipv6ToIPv4Addr2 = testutil.MustParse6("2002::2")
+ )
+
prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, stableGlobalAddr2 := prefixSubnetAddr(1, linkAddr1)
@@ -3354,21 +3357,21 @@ func TestStackReceiveBufferSizeOption(t *testing.T) {
const sMin = stack.MinBufferSize
testCases := []struct {
name string
- rs stack.ReceiveBufferSizeOption
+ rs tcpip.ReceiveBufferSizeOption
err tcpip.Error
}{
// Invalid configurations.
- {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
- {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
- {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
- {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
- {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
+ {"min_below_zero", tcpip.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
+ {"min_zero", tcpip.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
+ {"default_below_min", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
+ {"default_above_max", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
+ {"max_below_min", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
// Valid Configurations
- {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
- {"all_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
- {"min_default_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
- {"default_max_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
+ {"in_ascending_order", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
+ {"all_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
+ {"min_default_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
+ {"default_max_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
@@ -3377,7 +3380,7 @@ func TestStackReceiveBufferSizeOption(t *testing.T) {
if err := s.SetOption(tc.rs); err != tc.err {
t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err)
}
- var rs stack.ReceiveBufferSizeOption
+ var rs tcpip.ReceiveBufferSizeOption
if tc.err == nil {
if err := s.Option(&rs); err != nil {
t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err)
@@ -3448,7 +3451,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
}
ipv4Subnet := ipv4Addr.Subnet()
ipv4SubnetBcast := ipv4Subnet.Broadcast()
- ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01")
+ ipv4Gateway := testutil.MustParse4("192.168.1.1")
ipv4AddrPrefix31 := tcpip.AddressWithPrefix{
Address: "\xc0\xa8\x01\x3a",
PrefixLen: 31,
@@ -4352,13 +4355,15 @@ func TestWritePacketToRemote(t *testing.T) {
func TestClearNeighborCacheOnNICDisable(t *testing.T) {
const (
- nicID = 1
-
- ipv4Addr = tcpip.Address("\x01\x02\x03\x04")
- ipv6Addr = tcpip.Address("\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04")
+ nicID = 1
linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
)
+ var (
+ ipv4Addr = testutil.MustParse4("1.2.3.4")
+ ipv6Addr = testutil.MustParse6("102:304:102:304:102:304:102:304")
+ )
+
clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go
new file mode 100644
index 000000000..ddff6e2d6
--- /dev/null
+++ b/pkg/tcpip/stack/tcp.go
@@ -0,0 +1,451 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+// TCPProbeFunc is the expected function type for a TCP probe function to be
+// passed to stack.AddTCPProbe.
+type TCPProbeFunc func(s TCPEndpointState)
+
+// TCPCubicState is used to hold a copy of the internal cubic state when the
+// TCPProbeFunc is invoked.
+//
+// +stateify savable
+type TCPCubicState struct {
+ // WLastMax is the previous wMax value.
+ WLastMax float64
+
+ // WMax is the value of the congestion window at the time of the last
+ // congestion event.
+ WMax float64
+
+ // T is the time when the current congestion avoidance was entered.
+ T time.Time `state:".(unixTime)"`
+
+ // TimeSinceLastCongestion denotes the time since the current
+ // congestion avoidance was entered.
+ TimeSinceLastCongestion time.Duration
+
+ // C is the cubic constant as specified in RFC8312, page 11.
+ C float64
+
+ // K is the time period (in seconds) that the above function takes to
+ // increase the current window size to WMax if there are no further
+ // congestion events and is calculated using the following equation:
+ //
+ // K = cubic_root(WMax*(1-beta_cubic)/C) (Eq. 2, page 5)
+ K float64
+
+ // Beta is the CUBIC multiplication decrease factor. That is, when a
+ // congestion event is detected, CUBIC reduces its cwnd to
+ // WC(0)=WMax*beta_cubic.
+ Beta float64
+
+ // WC is window computed by CUBIC at time TimeSinceLastCongestion. It's
+ // calculated using the formula:
+ //
+ // WC(TimeSinceLastCongestion) = C*(t-K)^3 + WMax (Eq. 1)
+ WC float64
+
+ // WEst is the window computed by CUBIC at time
+ // TimeSinceLastCongestion+RTT i.e WC(TimeSinceLastCongestion+RTT).
+ WEst float64
+}
+
+// TCPRACKState is used to hold a copy of the internal RACK state when the
+// TCPProbeFunc is invoked.
+//
+// +stateify savable
+type TCPRACKState struct {
+ // XmitTime is the transmission timestamp of the most recent
+ // acknowledged segment.
+ XmitTime time.Time `state:".(unixTime)"`
+
+ // EndSequence is the ending TCP sequence number of the most recent
+ // acknowledged segment.
+ EndSequence seqnum.Value
+
+ // FACK is the highest selectively or cumulatively acknowledged
+ // sequence.
+ FACK seqnum.Value
+
+ // RTT is the round trip time of the most recently delivered packet on
+ // the connection (either cumulatively acknowledged or selectively
+ // acknowledged) that was not marked invalid as a possible spurious
+ // retransmission.
+ RTT time.Duration
+
+ // Reord is true iff reordering has been detected on this connection.
+ Reord bool
+
+ // DSACKSeen is true iff the connection has seen a DSACK.
+ DSACKSeen bool
+
+ // ReoWnd is the reordering window time used for recording packet
+ // transmission times. It is used to defer the moment at which RACK
+ // marks a packet lost.
+ ReoWnd time.Duration
+
+ // ReoWndIncr is the multiplier applied to adjust reorder window.
+ ReoWndIncr uint8
+
+ // ReoWndPersist is the number of loss recoveries before resetting
+ // reorder window.
+ ReoWndPersist int8
+
+ // RTTSeq is the SND.NXT when RTT is updated.
+ RTTSeq seqnum.Value
+}
+
+// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
+//
+// +stateify savable
+type TCPEndpointID struct {
+ // LocalPort is the local port associated with the endpoint.
+ LocalPort uint16
+
+ // LocalAddress is the local [network layer] address associated with
+ // the endpoint.
+ LocalAddress tcpip.Address
+
+ // RemotePort is the remote port associated with the endpoint.
+ RemotePort uint16
+
+ // RemoteAddress it the remote [network layer] address associated with
+ // the endpoint.
+ RemoteAddress tcpip.Address
+}
+
+// TCPFastRecoveryState holds a copy of the internal fast recovery state of a
+// TCP endpoint.
+//
+// +stateify savable
+type TCPFastRecoveryState struct {
+ // Active if true indicates the endpoint is in fast recovery. The
+ // following fields are only meaningful when Active is true.
+ Active bool
+
+ // First is the first unacknowledged sequence number being recovered.
+ First seqnum.Value
+
+ // Last is the 'recover' sequence number that indicates the point at
+ // which we should exit recovery barring any timeouts etc.
+ Last seqnum.Value
+
+ // MaxCwnd is the maximum value we are permitted to grow the congestion
+ // window during recovery. This is set at the time we enter recovery.
+ // It exists to avoid attacks where the receiver intentionally sends
+ // duplicate acks to artificially inflate the sender's cwnd.
+ MaxCwnd int
+
+ // HighRxt is the highest sequence number which has been retransmitted
+ // during the current loss recovery phase. See: RFC 6675 Section 2 for
+ // details.
+ HighRxt seqnum.Value
+
+ // RescueRxt is the highest sequence number which has been
+ // optimistically retransmitted to prevent stalling of the ACK clock
+ // when there is loss at the end of the window and no new data is
+ // available for transmission. See: RFC 6675 Section 2 for details.
+ RescueRxt seqnum.Value
+}
+
+// TCPReceiverState holds a copy of the internal state of the receiver for a
+// given TCP endpoint.
+//
+// +stateify savable
+type TCPReceiverState struct {
+ // RcvNxt is the TCP variable RCV.NXT.
+ RcvNxt seqnum.Value
+
+ // RcvAcc is one beyond the last acceptable sequence number. That is,
+ // the "largest" sequence value that the receiver has announced to its
+ // peer that it's willing to accept. This may be different than RcvNxt
+ // + (last advertised receive window) if the receive window is reduced;
+ // in that case we have to reduce the window as we receive more data
+ // instead of shrinking it.
+ RcvAcc seqnum.Value
+
+ // RcvWndScale is the window scaling to use for inbound segments.
+ RcvWndScale uint8
+
+ // PendingBufUsed is the number of bytes pending in the receive queue.
+ PendingBufUsed int
+}
+
+// TCPRTTState holds a copy of information about the endpoint's round trip
+// time.
+//
+// +stateify savable
+type TCPRTTState struct {
+ // SRTT is the smoothed round trip time defined in section 2 of RFC
+ // 6298.
+ SRTT time.Duration
+
+ // RTTVar is the round-trip time variation as defined in section 2 of
+ // RFC 6298.
+ RTTVar time.Duration
+
+ // SRTTInited if true indicates that a valid RTT measurement has been
+ // completed.
+ SRTTInited bool
+}
+
+// TCPSenderState holds a copy of the internal state of the sender for a given
+// TCP Endpoint.
+//
+// +stateify savable
+type TCPSenderState struct {
+ // LastSendTime is the timestamp at which we sent the last segment.
+ LastSendTime time.Time `state:".(unixTime)"`
+
+ // DupAckCount is the number of Duplicate ACKs received. It is used for
+ // fast retransmit.
+ DupAckCount int
+
+ // SndCwnd is the size of the sending congestion window in packets.
+ SndCwnd int
+
+ // Ssthresh is the threshold between slow start and congestion
+ // avoidance.
+ Ssthresh int
+
+ // SndCAAckCount is the number of packets acknowledged during
+ // congestion avoidance. When enough packets have been ack'd (typically
+ // cwnd packets), the congestion window is incremented by one.
+ SndCAAckCount int
+
+ // Outstanding is the number of packets that have been sent but not yet
+ // acknowledged.
+ Outstanding int
+
+ // SackedOut is the number of packets which have been selectively
+ // acked.
+ SackedOut int
+
+ // SndWnd is the send window size in bytes.
+ SndWnd seqnum.Size
+
+ // SndUna is the next unacknowledged sequence number.
+ SndUna seqnum.Value
+
+ // SndNxt is the sequence number of the next segment to be sent.
+ SndNxt seqnum.Value
+
+ // RTTMeasureSeqNum is the sequence number being used for the latest
+ // RTT measurement.
+ RTTMeasureSeqNum seqnum.Value
+
+ // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent.
+ RTTMeasureTime time.Time `state:".(unixTime)"`
+
+ // Closed indicates that the caller has closed the endpoint for
+ // sending.
+ Closed bool
+
+ // RTO is the retransmit timeout as defined in section of 2 of RFC
+ // 6298.
+ RTO time.Duration
+
+ // RTTState holds information about the endpoint's round trip time.
+ RTTState TCPRTTState
+
+ // MaxPayloadSize is the maximum size of the payload of a given
+ // segment. It is initialized on demand.
+ MaxPayloadSize int
+
+ // SndWndScale is the number of bits to shift left when reading the
+ // send window size from a segment.
+ SndWndScale uint8
+
+ // MaxSentAck is the highest acknowledgement number sent till now.
+ MaxSentAck seqnum.Value
+
+ // FastRecovery holds the fast recovery state for the endpoint.
+ FastRecovery TCPFastRecoveryState
+
+ // Cubic holds the state related to CUBIC congestion control.
+ Cubic TCPCubicState
+
+ // RACKState holds the state related to RACK loss detection algorithm.
+ RACKState TCPRACKState
+}
+
+// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
+//
+// +stateify savable
+type TCPSACKInfo struct {
+ // Blocks is the list of SACK Blocks that identify the out of order
+ // segments held by a given TCP endpoint.
+ Blocks []header.SACKBlock
+
+ // ReceivedBlocks are the SACK blocks received by this endpoint from
+ // the peer endpoint.
+ ReceivedBlocks []header.SACKBlock
+
+ // MaxSACKED is the highest sequence number that has been SACKED by the
+ // peer.
+ MaxSACKED seqnum.Value
+}
+
+// RcvBufAutoTuneParams holds state related to TCP receive buffer auto-tuning.
+//
+// +stateify savable
+type RcvBufAutoTuneParams struct {
+ // MeasureTime is the time at which the current measurement was
+ // started.
+ MeasureTime time.Time `state:".(unixTime)"`
+
+ // CopiedBytes is the number of bytes copied to user space since this
+ // measure began.
+ CopiedBytes int
+
+ // PrevCopiedBytes is the number of bytes copied to userspace in the
+ // previous RTT period.
+ PrevCopiedBytes int
+
+ // RcvBufSize is the auto tuned receive buffer size.
+ RcvBufSize int
+
+ // RTT is the smoothed RTT as measured by observing the time between
+ // when a byte is first acknowledged and the receipt of data that is at
+ // least one window beyond the sequence number that was acknowledged.
+ RTT time.Duration
+
+ // RTTVar is the "round-trip time variation" as defined in section 2 of
+ // RFC6298.
+ RTTVar time.Duration
+
+ // RTTMeasureSeqNumber is the highest acceptable sequence number at the
+ // time this RTT measurement period began.
+ RTTMeasureSeqNumber seqnum.Value
+
+ // RTTMeasureTime is the absolute time at which the current RTT
+ // measurement period began.
+ RTTMeasureTime time.Time `state:".(unixTime)"`
+
+ // Disabled is true if an explicit receive buffer is set for the
+ // endpoint.
+ Disabled bool
+}
+
+// TCPRcvBufState contains information about the state of an endpoint's receive
+// socket buffer.
+//
+// +stateify savable
+type TCPRcvBufState struct {
+ // RcvBufUsed is the amount of bytes actually held in the receive
+ // socket buffer for the endpoint.
+ RcvBufUsed int
+
+ // RcvBufAutoTuneParams is used to hold state variables to compute the
+ // auto tuned receive buffer size.
+ RcvAutoParams RcvBufAutoTuneParams
+
+ // RcvClosed if true, indicates the endpoint has been closed for
+ // reading.
+ RcvClosed bool
+}
+
+// TCPSndBufState contains information about the state of an endpoint's send
+// socket buffer.
+//
+// +stateify savable
+type TCPSndBufState struct {
+ // SndBufSize is the size of the socket send buffer.
+ SndBufSize int
+
+ // SndBufUsed is the number of bytes held in the socket send buffer.
+ SndBufUsed int
+
+ // SndClosed indicates that the endpoint has been closed for sends.
+ SndClosed bool
+
+ // SndBufInQueue is the number of bytes in the send queue.
+ SndBufInQueue seqnum.Size
+
+ // PacketTooBigCount is used to notify the main protocol routine how
+ // many times a "packet too big" control packet is received.
+ PacketTooBigCount int
+
+ // SndMTU is the smallest MTU seen in the control packets received.
+ SndMTU int
+}
+
+// TCPEndpointStateInner contains the members of TCPEndpointState used directly
+// (that is, not within another containing struct) within the endpoint's
+// internal implementation.
+//
+// +stateify savable
+type TCPEndpointStateInner struct {
+ // TSOffset is a randomized offset added to the value of the TSVal
+ // field in the timestamp option.
+ TSOffset uint32
+
+ // SACKPermitted is set to true if the peer sends the TCPSACKPermitted
+ // option in the SYN/SYN-ACK.
+ SACKPermitted bool
+
+ // SendTSOk is used to indicate when the TS Option has been negotiated.
+ // When sendTSOk is true every non-RST segment should carry a TS as per
+ // RFC7323#section-1.1.
+ SendTSOk bool
+
+ // RecentTS is the timestamp that should be sent in the TSEcr field of
+ // the timestamp for future segments sent by the endpoint. This field
+ // is updated if required when a new segment is received by this
+ // endpoint.
+ RecentTS uint32
+}
+
+// TCPEndpointState is a copy of the internal state of a TCP endpoint.
+//
+// +stateify savable
+type TCPEndpointState struct {
+ // TCPEndpointStateInner contains the members of TCPEndpointState used
+ // by the endpoint's internal implementation.
+ TCPEndpointStateInner
+
+ // ID is a copy of the TransportEndpointID for the endpoint.
+ ID TCPEndpointID
+
+ // SegTime denotes the absolute time when this segment was received.
+ SegTime time.Time `state:".(unixTime)"`
+
+ // RcvBufState contains information about the state of the endpoint's
+ // receive socket buffer.
+ RcvBufState TCPRcvBufState
+
+ // SndBufState contains information about the state of the endpoint's
+ // send socket buffer.
+ SndBufState TCPSndBufState
+
+ // SACK holds TCP SACK related information for this endpoint.
+ SACK TCPSACKInfo
+
+ // Receiver holds variables related to the TCP receiver for the
+ // endpoint.
+ Receiver TCPReceiverState
+
+ // Sender holds state related to the TCP Sender for the endpoint.
+ Sender TCPSenderState
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index e188efccb..80ad1a9d4 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -150,16 +150,17 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
return eps
}
-// HandlePacket is called by the stack when new packets arrive to this transport
-// endpoint.
-func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) {
+// handlePacket is called by the stack when new packets arrive to this transport
+// endpoint. It returns false if the packet could not be matched to any
+// transport endpoint, true otherwise.
+func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) bool {
epsByNIC.mu.RLock()
mpep, ok := epsByNIC.endpoints[pkt.NICID]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
- return
+ return false
}
}
@@ -168,18 +169,19 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet
if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
mpep.handlePacketAll(id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
- return
+ return true
}
// multiPortEndpoints are guaranteed to have at least one element.
transEP := selectEndpoint(id, mpep, epsByNIC.seed)
if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
queuedProtocol.QueuePacket(transEP, id, pkt)
epsByNIC.mu.RUnlock()
- return
+ return true
}
transEP.HandlePacket(id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
+ return true
}
// handleError delivers an error to the transport endpoint identified by id.
@@ -567,8 +569,7 @@ func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber,
}
return false
}
- ep.handlePacket(id, pkt)
- return true
+ return ep.handlePacket(id, pkt)
}
// deliverRawPacket attempts to deliver the given packet and returns whether it
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 054cced0c..0adedd7c0 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -70,7 +70,7 @@ func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions {
func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, s *stack.Stack) tcpip.Endpoint {
ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: s.UniqueID()}
- ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits)
+ ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
return ep
}
@@ -233,7 +233,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *
peerAddr: route.RemoteAddress(),
route: route,
}
- ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits)
+ ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
f.acceptQueue = append(f.acceptQueue, ep)
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 87ea09a5e..0ba71b62e 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -691,10 +691,6 @@ const (
// number of unread bytes in the input buffer should be returned.
ReceiveQueueSizeOption
- // ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
- // specify the receive buffer size option.
- ReceiveBufferSizeOption
-
// SendQueueSizeOption is used in GetSockOptInt to specify that the
// number of unread bytes in the output buffer should be returned.
SendQueueSizeOption
@@ -786,6 +782,13 @@ func (*TCPRecovery) isGettableTransportProtocolOption() {}
func (*TCPRecovery) isSettableTransportProtocolOption() {}
+// TCPAlwaysUseSynCookies indicates unconditional usage of syncookies.
+type TCPAlwaysUseSynCookies bool
+
+func (*TCPAlwaysUseSynCookies) isGettableTransportProtocolOption() {}
+
+func (*TCPAlwaysUseSynCookies) isSettableTransportProtocolOption() {}
+
const (
// TCPRACKLossDetection indicates RACK is used for loss detection and
// recovery.
@@ -1020,19 +1023,6 @@ func (*TCPMaxRetriesOption) isGettableTransportProtocolOption() {}
func (*TCPMaxRetriesOption) isSettableTransportProtocolOption() {}
-// TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify
-// the number of endpoints that can be in SYN-RCVD state before the stack
-// switches to using SYN cookies.
-type TCPSynRcvdCountThresholdOption uint64
-
-func (*TCPSynRcvdCountThresholdOption) isGettableSocketOption() {}
-
-func (*TCPSynRcvdCountThresholdOption) isSettableSocketOption() {}
-
-func (*TCPSynRcvdCountThresholdOption) isGettableTransportProtocolOption() {}
-
-func (*TCPSynRcvdCountThresholdOption) isSettableTransportProtocolOption() {}
-
// TCPSynRetriesOption is used by SetSockOpt/GetSockOpt to specify stack-wide
// default for number of times SYN is retransmitted before aborting a connect.
type TCPSynRetriesOption uint8
@@ -1150,6 +1140,19 @@ type SendBufferSizeOption struct {
Max int
}
+// ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to
+// get/set the default, min and max receive buffer sizes.
+type ReceiveBufferSizeOption struct {
+ // Min is the minimum size for send buffer.
+ Min int
+
+ // Default is the default size for send buffer.
+ Default int
+
+ // Max is the maximum size for send buffer.
+ Max int
+}
+
// GetSendBufferLimits is used to get the send buffer size limits.
type GetSendBufferLimits func(StackHandler) SendBufferSizeOption
@@ -1162,6 +1165,18 @@ func GetStackSendBufferLimits(so StackHandler) SendBufferSizeOption {
return ss
}
+// GetReceiveBufferLimits is used to get the send buffer size limits.
+type GetReceiveBufferLimits func(StackHandler) ReceiveBufferSizeOption
+
+// GetStackReceiveBufferLimits is used to get default, min and max send buffer size.
+func GetStackReceiveBufferLimits(so StackHandler) ReceiveBufferSizeOption {
+ var ss ReceiveBufferSizeOption
+ if err := so.Option(&ss); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
+ }
+ return ss
+}
+
// Route is a row in the routing table. It specifies through which NIC (and
// gateway) sets of packets should be routed. A row is considered viable if the
// masked target address matches the destination address in the row.
@@ -1218,7 +1233,7 @@ func (s *StatCounter) Decrement() {
}
// Value returns the current value of the counter.
-func (s *StatCounter) Value() uint64 {
+func (s *StatCounter) Value(name ...string) uint64 {
return atomic.LoadUint64(&s.count)
}
@@ -1562,6 +1577,10 @@ type IPStats struct {
// chain.
IPTablesOutputDropped *StatCounter
+ // IPTablesPostroutingDropped is the number of IP packets dropped in the
+ // Postrouting chain.
+ IPTablesPostroutingDropped *StatCounter
+
// TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out
// of IPStats.
// OptionTimestampReceived is the number of Timestamp options seen.
@@ -1734,6 +1753,10 @@ type TCPStats struct {
// ChecksumErrors is the number of segments dropped due to bad checksums.
ChecksumErrors *StatCounter
+
+ // FailedPortReservations is the number of times TCP failed to reserve
+ // a port.
+ FailedPortReservations *StatCounter
}
// UDPStats collects UDP-specific stats.
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 3cc8c36f1..d4f7bb5ff 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -9,11 +9,14 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
"//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
@@ -78,6 +81,7 @@ go_test(
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
@@ -101,6 +105,7 @@ go_test(
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
@@ -123,6 +128,7 @@ go_test(
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index d10ae05c2..dbd279c94 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -21,11 +21,14 @@ import (
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -312,3 +315,194 @@ func TestForwarding(t *testing.T) {
})
}
}
+
+func TestMulticastForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+ ttl = 64
+ )
+
+ var (
+ ipv4LinkLocalUnicastAddr = testutil.MustParse4("169.254.0.10")
+ ipv4LinkLocalMulticastAddr = testutil.MustParse4("224.0.0.10")
+ ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
+
+ ipv6LinkLocalUnicastAddr = testutil.MustParse6("fe80::a")
+ ipv6LinkLocalMulticastAddr = testutil.MustParse6("ff02::a")
+ ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
+ )
+
+ rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv4EchoRequest(e, src, dst, ttl)
+ }
+
+ rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv6EchoRequest(e, src, dst, ttl)
+ }
+
+ v4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv4(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ttl-1),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4Echo)))
+ }
+
+ v6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv6(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ttl-1),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoRequest)))
+ }
+
+ tests := []struct {
+ name string
+ srcAddr, dstAddr tcpip.Address
+ rx func(*channel.Endpoint, tcpip.Address, tcpip.Address)
+ expectForward bool
+ checker func(*testing.T, []byte)
+ }{
+ {
+ name: "IPv4 link-local multicast destination",
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: ipv4LinkLocalMulticastAddr,
+ rx: rxICMPv4EchoRequest,
+ expectForward: false,
+ },
+ {
+ name: "IPv4 link-local source",
+ srcAddr: ipv4LinkLocalUnicastAddr,
+ dstAddr: utils.RemoteIPv4Addr,
+ rx: rxICMPv4EchoRequest,
+ expectForward: false,
+ },
+ {
+ name: "IPv4 link-local destination",
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: ipv4LinkLocalUnicastAddr,
+ rx: rxICMPv4EchoRequest,
+ expectForward: false,
+ },
+ {
+ name: "IPv4 non-link-local unicast",
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
+ rx: rxICMPv4EchoRequest,
+ expectForward: true,
+ checker: func(t *testing.T, b []byte) {
+ v4Checker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
+ },
+ },
+ {
+ name: "IPv4 non-link-local multicast",
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: ipv4GlobalMulticastAddr,
+ rx: rxICMPv4EchoRequest,
+ expectForward: true,
+ checker: func(t *testing.T, b []byte) {
+ v4Checker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
+ },
+ },
+
+ {
+ name: "IPv6 link-local multicast destination",
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: ipv6LinkLocalMulticastAddr,
+ rx: rxICMPv6EchoRequest,
+ expectForward: false,
+ },
+ {
+ name: "IPv6 link-local source",
+ srcAddr: ipv6LinkLocalUnicastAddr,
+ dstAddr: utils.RemoteIPv6Addr,
+ rx: rxICMPv6EchoRequest,
+ expectForward: false,
+ },
+ {
+ name: "IPv6 link-local destination",
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: ipv6LinkLocalUnicastAddr,
+ rx: rxICMPv6EchoRequest,
+ expectForward: false,
+ },
+ {
+ name: "IPv6 non-link-local unicast",
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
+ rx: rxICMPv6EchoRequest,
+ expectForward: true,
+ checker: func(t *testing.T, b []byte) {
+ v6Checker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
+ },
+ },
+ {
+ name: "IPv6 non-link-local multicast",
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: ipv6GlobalMulticastAddr,
+ rx: rxICMPv6EchoRequest,
+ expectForward: true,
+ checker: func(t *testing.T, b []byte) {
+ v6Checker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ e1 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ e2 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
+ }
+
+ if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err)
+ }
+ if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err)
+ }
+
+ if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID2,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID2,
+ },
+ })
+
+ test.rx(e1, test.srcAddr, test.dstAddr)
+
+ p, ok := e2.Read()
+ if ok != test.expectForward {
+ t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, test.expectForward)
+ }
+
+ if test.expectForward {
+ test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 2c538a43e..b04169751 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -30,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -510,25 +511,25 @@ func TestExternalLoopbackTraffic(t *testing.T) {
nicID1 = 1
nicID2 = 2
- ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01")
-
numPackets = 1
+ ttl = 64
)
+ ipv4Loopback := testutil.MustParse4("127.0.0.1")
loopbackSourcedICMPv4 := func(e *channel.Endpoint) {
- utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address)
+ utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address, ttl)
}
loopbackSourcedICMPv6 := func(e *channel.Endpoint) {
- utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address)
+ utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address, ttl)
}
loopbackDestinedICMPv4 := func(e *channel.Endpoint) {
- utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback)
+ utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback, ttl)
}
loopbackDestinedICMPv6 := func(e *channel.Endpoint) {
- utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback)
+ utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback, ttl)
}
invalidSrcAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter {
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index c6a9c2393..2d0a6e6a7 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -29,6 +29,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -43,12 +44,15 @@ const (
// to a multicast or broadcast address uses a unicast source address for the
// reply.
func TestPingMulticastBroadcast(t *testing.T) {
- const nicID = 1
+ const (
+ nicID = 1
+ ttl = 64
+ )
tests := []struct {
name string
protoNum tcpip.NetworkProtocolNumber
- rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address)
+ rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address, uint8)
srcAddr tcpip.Address
dstAddr tcpip.Address
expectedSrc tcpip.Address
@@ -136,7 +140,7 @@ func TestPingMulticastBroadcast(t *testing.T) {
},
})
- test.rxICMP(e, test.srcAddr, test.dstAddr)
+ test.rxICMP(e, test.srcAddr, test.dstAddr, ttl)
pkt, ok := e.Read()
if !ok {
t.Fatal("expected ICMP response")
@@ -435,10 +439,10 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
// interested endpoints.
func TestReuseAddrAndBroadcast(t *testing.T) {
const (
- nicID = 1
- localPort = 9000
- loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff")
+ nicID = 1
+ localPort = 9000
)
+ loopbackBroadcast := testutil.MustParse4("127.255.255.255")
tests := []struct {
name string
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
index 78244f4eb..ac3c703d4 100644
--- a/pkg/tcpip/tests/integration/route_test.go
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -30,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -40,13 +41,13 @@ import (
// This tests that a local route is created and packets do not leave the stack.
func TestLocalPing(t *testing.T) {
const (
- nicID = 1
- ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01")
+ nicID = 1
// icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo
// request/reply packets.
icmpDataOffset = 8
)
+ ipv4Loopback := testutil.MustParse4("127.0.0.1")
channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") }
channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) {
diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go
index d1c9f3a94..8fd9be32b 100644
--- a/pkg/tcpip/tests/utils/utils.go
+++ b/pkg/tcpip/tests/utils/utils.go
@@ -48,10 +48,6 @@ const (
LinkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
)
-const (
- ttl = 255
-)
-
// Common IP addresses used by tests.
var (
Ipv4Addr = tcpip.AddressWithPrefix{
@@ -322,7 +318,7 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on
// the provided endpoint.
-func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
hdr := buffer.NewPrependable(totalLen)
pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
@@ -347,7 +343,7 @@ func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on
// the provided endpoint.
-func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
hdr := buffer.NewPrependable(totalLen)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
diff --git a/pkg/tcpip/testutil/BUILD b/pkg/tcpip/testutil/BUILD
new file mode 100644
index 000000000..472545a5d
--- /dev/null
+++ b/pkg/tcpip/testutil/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "testutil",
+ testonly = True,
+ srcs = ["testutil.go"],
+ visibility = ["//visibility:public"],
+ deps = ["//pkg/tcpip"],
+)
+
+go_test(
+ name = "testutil_test",
+ srcs = ["testutil_test.go"],
+ library = ":testutil",
+ deps = ["//pkg/tcpip"],
+)
diff --git a/pkg/tcpip/testutil/testutil.go b/pkg/tcpip/testutil/testutil.go
new file mode 100644
index 000000000..1aaed590f
--- /dev/null
+++ b/pkg/tcpip/testutil/testutil.go
@@ -0,0 +1,43 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testutil provides helper functions for netstack unit tests.
+package testutil
+
+import (
+ "fmt"
+ "net"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// MustParse4 parses an IPv4 string (e.g. "192.168.1.1") into a tcpip.Address.
+// Passing an IPv4-mapped IPv6 address will yield only the 4 IPv4 bytes.
+func MustParse4(addr string) tcpip.Address {
+ ip := net.ParseIP(addr).To4()
+ if ip == nil {
+ panic(fmt.Sprintf("Parse4 expects IPv4 addresses, but was passed %q", addr))
+ }
+ return tcpip.Address(ip)
+}
+
+// MustParse6 parses an IPv6 string (e.g. "fe80::1") into a tcpip.Address. Passing
+// an IPv4 address will yield an IPv4-mapped IPv6 address.
+func MustParse6(addr string) tcpip.Address {
+ ip := net.ParseIP(addr).To16()
+ if ip == nil {
+ panic(fmt.Sprintf("Parse6 was passed malformed address %q", addr))
+ }
+ return tcpip.Address(ip)
+}
diff --git a/pkg/tcpip/testutil/testutil_test.go b/pkg/tcpip/testutil/testutil_test.go
new file mode 100644
index 000000000..6aad9585d
--- /dev/null
+++ b/pkg/tcpip/testutil/testutil_test.go
@@ -0,0 +1,103 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testutil
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// Who tests the testutils?
+
+func TestMustParse4(t *testing.T) {
+ tcs := []struct {
+ str string
+ addr tcpip.Address
+ shouldPanic bool
+ }{
+ {
+ str: "127.0.0.1",
+ addr: "\x7f\x00\x00\x01",
+ }, {
+ str: "",
+ shouldPanic: true,
+ }, {
+ str: "fe80::1",
+ shouldPanic: true,
+ }, {
+ // In an ideal world this panics too, but net.IP
+ // doesn't distinguish between IPv4 and IPv4-mapped
+ // addresses.
+ str: "::ffff:0.0.0.1",
+ addr: "\x00\x00\x00\x01",
+ },
+ }
+
+ for _, tc := range tcs {
+ t.Run(tc.str, func(t *testing.T) {
+ if tc.shouldPanic {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("panic expected, but did not occur")
+ }
+ }()
+ }
+ if got := MustParse4(tc.str); got != tc.addr {
+ t.Errorf("got MustParse4(%s) = %s, want = %s", tc.str, got, tc.addr)
+ }
+ })
+ }
+}
+
+func TestMustParse6(t *testing.T) {
+ tcs := []struct {
+ str string
+ addr tcpip.Address
+ shouldPanic bool
+ }{
+ {
+ // In an ideal world this panics too, but net.IP
+ // doesn't distinguish between IPv4 and IPv4-mapped
+ // addresses.
+ str: "127.0.0.1",
+ addr: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x7f\x00\x00\x01",
+ }, {
+ str: "",
+ shouldPanic: true,
+ }, {
+ str: "fe80::1",
+ addr: "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ }, {
+ str: "::ffff:0.0.0.1",
+ addr: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01",
+ },
+ }
+
+ for _, tc := range tcs {
+ t.Run(tc.str, func(t *testing.T) {
+ if tc.shouldPanic {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("panic expected, but did not occur")
+ }
+ }()
+ }
+ if got := MustParse6(tc.str); got != tc.addr {
+ t.Errorf("got MustParse6(%s) = %s, want = %s", tc.str, got, tc.addr)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 50991c3c0..33ed78f54 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -63,12 +63,11 @@ type endpoint struct {
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
- rcvMu sync.Mutex `state:"nosave"`
- rcvReady bool
- rcvList icmpPacketList
- rcvBufSizeMax int `state:".(int)"`
- rcvBufSize int
- rcvClosed bool
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvReady bool
+ rcvList icmpPacketList
+ rcvBufSize int
+ rcvClosed bool
// The following fields are protected by the mu mutex.
mu sync.RWMutex `state:"nosave"`
@@ -84,6 +83,10 @@ type endpoint struct {
// ops is used to get socket level options.
ops tcpip.SocketOptions
+
+ // frozen indicates if the packets should be delivered to the endpoint
+ // during restore.
+ frozen bool
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
@@ -93,19 +96,23 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
NetProto: netProto,
TransProto: transProto,
},
- waiterQueue: waiterQueue,
- rcvBufSizeMax: 32 * 1024,
- state: stateInitial,
- uniqueID: s.UniqueID(),
+ waiterQueue: waiterQueue,
+ state: stateInitial,
+ uniqueID: s.UniqueID(),
}
- ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits)
+ ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
ep.ops.SetSendBufferSize(32*1024, false /* notify */)
+ ep.ops.SetReceiveBufferSize(32*1024, false /* notify */)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
if err := s.Option(&ss); err == nil {
ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
}
+ var rs tcpip.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ ep.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
+ }
return ep, nil
}
@@ -371,12 +378,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- v := e.rcvBufSizeMax
- e.rcvMu.Unlock()
- return v, nil
-
case tcpip.TTLOption:
e.rcvMu.Lock()
v := int(e.ttl)
@@ -774,7 +775,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
return
}
- if e.rcvBufSize >= e.rcvBufSizeMax {
+ rcvBufSize := e.ops.GetReceiveBufferSize()
+ if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
e.rcvMu.Unlock()
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
@@ -843,3 +845,18 @@ func (*endpoint) LastError() tcpip.Error {
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}
+
+// freeze prevents any more packets from being delivered to the endpoint.
+func (e *endpoint) freeze() {
+ e.mu.Lock()
+ e.frozen = true
+ e.mu.Unlock()
+}
+
+// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows
+// new packets to be delivered again.
+func (e *endpoint) thaw() {
+ e.mu.Lock()
+ e.frozen = false
+ e.mu.Unlock()
+}
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
index a3c6db5a8..28a56a2d5 100644
--- a/pkg/tcpip/transport/icmp/endpoint_state.go
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -36,40 +36,21 @@ func (p *icmpPacket) loadData(data buffer.VectorisedView) {
p.data = data
}
-// beforeSave is invoked by stateify.
-func (e *endpoint) beforeSave() {
- // Stop incoming packets from being handled (and mutate endpoint state).
- // The lock will be released after savercvBufSizeMax(), which would have
- // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming
- // packets.
- e.rcvMu.Lock()
-}
-
-// saveRcvBufSizeMax is invoked by stateify.
-func (e *endpoint) saveRcvBufSizeMax() int {
- max := e.rcvBufSizeMax
- // Make sure no new packets will be handled regardless of the lock.
- e.rcvBufSizeMax = 0
- // Release the lock acquired in beforeSave() so regular endpoint closing
- // logic can proceed after save.
- e.rcvMu.Unlock()
- return max
-}
-
-// loadRcvBufSizeMax is invoked by stateify.
-func (e *endpoint) loadRcvBufSizeMax(max int) {
- e.rcvBufSizeMax = max
-}
-
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+ e.freeze()
+}
+
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
+ e.thaw()
e.stack = s
- e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
if e.state != stateBound && e.state != stateConnected {
return
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 52ed9560c..496eca581 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -72,11 +72,10 @@ type endpoint struct {
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
- rcvMu sync.Mutex `state:"nosave"`
- rcvList packetList
- rcvBufSizeMax int `state:".(int)"`
- rcvBufSize int
- rcvClosed bool
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvList packetList
+ rcvBufSize int
+ rcvClosed bool
// The following fields are protected by mu.
mu sync.RWMutex `state:"nosave"`
@@ -91,6 +90,10 @@ type endpoint struct {
// ops is used to get socket level options.
ops tcpip.SocketOptions
+
+ // frozen indicates if the packets should be delivered to the endpoint
+ // during restore.
+ frozen bool
}
// NewEndpoint returns a new packet endpoint.
@@ -100,12 +103,12 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
},
- cooked: cooked,
- netProto: netProto,
- waiterQueue: waiterQueue,
- rcvBufSizeMax: 32 * 1024,
+ cooked: cooked,
+ netProto: netProto,
+ waiterQueue: waiterQueue,
}
- ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits)
+ ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
+ ep.ops.SetReceiveBufferSize(32*1024, false /* notify */)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -113,9 +116,9 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
}
- var rs stack.ReceiveBufferSizeOption
+ var rs tcpip.ReceiveBufferSizeOption
if err := s.Option(&rs); err == nil {
- ep.rcvBufSizeMax = rs.Default
+ ep.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
}
if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil {
@@ -316,28 +319,7 @@ func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
- switch opt {
- case tcpip.ReceiveBufferSizeOption:
- // Make sure the receive buffer size is within the min and max
- // allowed.
- var rs stack.ReceiveBufferSizeOption
- if err := ep.stack.Option(&rs); err != nil {
- panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err))
- }
- if v > rs.Max {
- v = rs.Max
- }
- if v < rs.Min {
- v = rs.Min
- }
- ep.rcvMu.Lock()
- ep.rcvBufSizeMax = v
- ep.rcvMu.Unlock()
- return nil
-
- default:
- return &tcpip.ErrUnknownProtocolOption{}
- }
+ return &tcpip.ErrUnknownProtocolOption{}
}
func (ep *endpoint) LastError() tcpip.Error {
@@ -374,12 +356,6 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
ep.rcvMu.Unlock()
return v, nil
- case tcpip.ReceiveBufferSizeOption:
- ep.rcvMu.Lock()
- v := ep.rcvBufSizeMax
- ep.rcvMu.Unlock()
- return v, nil
-
default:
return -1, &tcpip.ErrUnknownProtocolOption{}
}
@@ -397,7 +373,8 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
return
}
- if ep.rcvBufSize >= ep.rcvBufSizeMax {
+ rcvBufSize := ep.ops.GetReceiveBufferSize()
+ if ep.frozen || ep.rcvBufSize >= int(rcvBufSize) {
ep.rcvMu.Unlock()
ep.stack.Stats().DroppedPackets.Increment()
ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
@@ -513,3 +490,18 @@ func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {}
func (ep *endpoint) SocketOptions() *tcpip.SocketOptions {
return &ep.ops
}
+
+// freeze prevents any more packets from being delivered to the endpoint.
+func (ep *endpoint) freeze() {
+ ep.mu.Lock()
+ ep.frozen = true
+ ep.mu.Unlock()
+}
+
+// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows
+// new packets to be delivered again.
+func (ep *endpoint) thaw() {
+ ep.mu.Lock()
+ ep.frozen = false
+ ep.mu.Unlock()
+}
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index ece662c0d..5bd860d20 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -38,33 +38,14 @@ func (p *packet) loadData(data buffer.VectorisedView) {
// beforeSave is invoked by stateify.
func (ep *endpoint) beforeSave() {
- // Stop incoming packets from being handled (and mutate endpoint state).
- // The lock will be released after saveRcvBufSizeMax(), which would have
- // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming
- // packets.
- ep.rcvMu.Lock()
-}
-
-// saveRcvBufSizeMax is invoked by stateify.
-func (ep *endpoint) saveRcvBufSizeMax() int {
- max := ep.rcvBufSizeMax
- // Make sure no new packets will be handled regardless of the lock.
- ep.rcvBufSizeMax = 0
- // Release the lock acquired in beforeSave() so regular endpoint closing
- // logic can proceed after save.
- ep.rcvMu.Unlock()
- return max
-}
-
-// loadRcvBufSizeMax is invoked by stateify.
-func (ep *endpoint) loadRcvBufSizeMax(max int) {
- ep.rcvBufSizeMax = max
+ ep.freeze()
}
// afterLoad is invoked by stateify.
func (ep *endpoint) afterLoad() {
+ ep.thaw()
ep.stack = stack.StackFromEnv
- ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits)
+ ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
// TODO(gvisor.dev/173): Once bind is supported, choose the right NIC.
if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil {
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index e27a249cd..10453a42a 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -26,7 +26,6 @@
package raw
import (
- "fmt"
"io"
"gvisor.dev/gvisor/pkg/sync"
@@ -69,11 +68,10 @@ type endpoint struct {
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
- rcvMu sync.Mutex `state:"nosave"`
- rcvList rawPacketList
- rcvBufSize int
- rcvBufSizeMax int `state:".(int)"`
- rcvClosed bool
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvList rawPacketList
+ rcvBufSize int
+ rcvClosed bool
// The following fields are protected by mu.
mu sync.RWMutex `state:"nosave"`
@@ -89,6 +87,10 @@ type endpoint struct {
// ops is used to get socket level options.
ops tcpip.SocketOptions
+
+ // frozen indicates if the packets should be delivered to the endpoint
+ // during restore.
+ frozen bool
}
// NewEndpoint returns a raw endpoint for the given protocols.
@@ -107,13 +109,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
NetProto: netProto,
TransProto: transProto,
},
- waiterQueue: waiterQueue,
- rcvBufSizeMax: 32 * 1024,
- associated: associated,
+ waiterQueue: waiterQueue,
+ associated: associated,
}
- e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
e.ops.SetHeaderIncluded(!associated)
e.ops.SetSendBufferSize(32*1024, false /* notify */)
+ e.ops.SetReceiveBufferSize(32*1024, false /* notify */)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -121,16 +123,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
}
- var rs stack.ReceiveBufferSizeOption
+ var rs tcpip.ReceiveBufferSizeOption
if err := s.Option(&rs); err == nil {
- e.rcvBufSizeMax = rs.Default
+ e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
}
// Unassociated endpoints are write-only and users call Write() with IP
// headers included. Because they're write-only, We don't need to
// register with the stack.
if !associated {
- e.rcvBufSizeMax = 0
+ e.ops.SetReceiveBufferSize(0, false)
e.waiterQueue = nil
return e, nil
}
@@ -511,30 +513,8 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
}
}
-// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
- switch opt {
- case tcpip.ReceiveBufferSizeOption:
- // Make sure the receive buffer size is within the min and max
- // allowed.
- var rs stack.ReceiveBufferSizeOption
- if err := e.stack.Option(&rs); err != nil {
- panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err))
- }
- if v > rs.Max {
- v = rs.Max
- }
- if v < rs.Min {
- v = rs.Min
- }
- e.rcvMu.Lock()
- e.rcvBufSizeMax = v
- e.rcvMu.Unlock()
- return nil
-
- default:
- return &tcpip.ErrUnknownProtocolOption{}
- }
+ return &tcpip.ErrUnknownProtocolOption{}
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
@@ -555,12 +535,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- v := e.rcvBufSizeMax
- e.rcvMu.Unlock()
- return v, nil
-
default:
return -1, &tcpip.ErrUnknownProtocolOption{}
}
@@ -587,7 +561,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
- if e.rcvBufSize >= e.rcvBufSizeMax {
+ rcvBufSize := e.ops.GetReceiveBufferSize()
+ if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
e.rcvMu.Unlock()
e.mu.RUnlock()
e.stack.Stats().DroppedPackets.Increment()
@@ -690,3 +665,18 @@ func (*endpoint) LastError() tcpip.Error {
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}
+
+// freeze prevents any more packets from being delivered to the endpoint.
+func (e *endpoint) freeze() {
+ e.mu.Lock()
+ e.frozen = true
+ e.mu.Unlock()
+}
+
+// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows
+// new packets to be delivered again.
+func (e *endpoint) thaw() {
+ e.mu.Lock()
+ e.frozen = false
+ e.mu.Unlock()
+}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 263ec5146..5d6f2709c 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -36,40 +36,21 @@ func (p *rawPacket) loadData(data buffer.VectorisedView) {
p.data = data
}
-// beforeSave is invoked by stateify.
-func (e *endpoint) beforeSave() {
- // Stop incoming packets from being handled (and mutate endpoint state).
- // The lock will be released after saveRcvBufSizeMax(), which would have
- // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming
- // packets.
- e.rcvMu.Lock()
-}
-
-// saveRcvBufSizeMax is invoked by stateify.
-func (e *endpoint) saveRcvBufSizeMax() int {
- max := e.rcvBufSizeMax
- // Make sure no new packets will be handled regardless of the lock.
- e.rcvBufSizeMax = 0
- // Release the lock acquired in beforeSave() so regular endpoint closing
- // logic can proceed after save.
- e.rcvMu.Unlock()
- return max
-}
-
-// loadRcvBufSizeMax is invoked by stateify.
-func (e *endpoint) loadRcvBufSizeMax(max int) {
- e.rcvBufSizeMax = max
-}
-
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+ e.freeze()
+}
+
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
+ e.thaw()
e.stack = s
- e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
// If the endpoint is connected, re-connect.
if e.connected {
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index a69d6624d..48417f192 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -34,14 +34,12 @@ go_library(
"connect.go",
"connect_unsafe.go",
"cubic.go",
- "cubic_state.go",
"dispatcher.go",
"endpoint.go",
"endpoint_state.go",
"forwarder.go",
"protocol.go",
"rack.go",
- "rack_state.go",
"rcv.go",
"rcv_state.go",
"reno.go",
@@ -107,6 +105,7 @@ go_test(
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/tcp/testing/context",
"//pkg/test/testutil",
"//pkg/waiter",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 025b134e2..d4bd4e80e 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -23,7 +23,6 @@ import (
"sync/atomic"
"time"
- "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -51,11 +50,6 @@ const (
// timestamp and the current timestamp. If the difference is greater
// than maxTSDiff, the cookie is expired.
maxTSDiff = 2
-
- // SynRcvdCountThreshold is the default global maximum number of
- // connections that are allowed to be in SYN-RCVD state before TCP
- // starts using SYN cookies to accept connections.
- SynRcvdCountThreshold uint64 = 1000
)
var (
@@ -80,9 +74,6 @@ func encodeMSS(mss uint16) uint32 {
type listenContext struct {
stack *stack.Stack
- // synRcvdCount is a reference to the stack level synRcvdCount.
- synRcvdCount *synRcvdCounter
-
// rcvWnd is the receive window that is sent by this listening context
// in the initial SYN-ACK.
rcvWnd seqnum.Size
@@ -138,14 +129,12 @@ func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size,
listenEP: listenEP,
pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
}
- p, ok := stk.TransportProtocolInstance(ProtocolNumber).(*protocol)
- if !ok {
- panic(fmt.Sprintf("unable to get TCP protocol instance from stack: %+v", stk))
- }
- l.synRcvdCount = p.SynRcvdCounter()
- rand.Read(l.nonce[0][:])
- rand.Read(l.nonce[1][:])
+ for i := range l.nonce {
+ if _, err := io.ReadFull(stk.SecureRNG(), l.nonce[i][:]); err != nil {
+ panic(err)
+ }
+ }
return l
}
@@ -163,14 +152,17 @@ func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonc
// Feed everything to the hasher.
l.hasherMu.Lock()
l.hasher.Reset()
+
+ // Per hash.Hash.Writer:
+ //
+ // It never returns an error.
l.hasher.Write(payload[:])
l.hasher.Write(l.nonce[nonceIndex][:])
- io.WriteString(l.hasher, string(id.LocalAddress))
- io.WriteString(l.hasher, string(id.RemoteAddress))
+ l.hasher.Write([]byte(id.LocalAddress))
+ l.hasher.Write([]byte(id.RemoteAddress))
// Finalize the calculation of the hash and return the first 4 bytes.
- h := make([]byte, 0, sha1.Size)
- h = l.hasher.Sum(h)
+ h := l.hasher.Sum(nil)
l.hasherMu.Unlock()
return binary.BigEndian.Uint32(h[:])
@@ -199,9 +191,17 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true
}
+func (l *listenContext) useSynCookies() bool {
+ var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies
+ if err := l.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil {
+ panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err))
+ }
+ return bool(alwaysUseSynCookies) || (l.listenEP != nil && l.listenEP.synRcvdBacklogFull())
+}
+
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
-func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) {
+func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
@@ -215,11 +215,11 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n := newEndpoint(l.stack, netProto, queue)
n.ops.SetV6Only(l.v6Only)
- n.ID = s.id
+ n.TransportEndpointInfo.ID = s.id
n.boundNICID = s.nicID
n.route = route
n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.netProto}
- n.rcvBufSize = int(l.rcvWnd)
+ n.ops.SetReceiveBufferSize(int64(l.rcvWnd), false /* notify */)
n.amss = calculateAdvertisedMSS(n.userMSS, n.route)
n.setEndpointState(StateConnecting)
@@ -231,7 +231,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
// Bootstrap the auto tuning algorithm. Starting at zero will result in
// a large step function on the first window adjustment causing the
// window to grow to a really large value.
- n.rcvAutoParams.prevCopied = n.initialReceiveWindow()
+ n.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = n.initialReceiveWindow()
return n, nil
}
@@ -248,7 +248,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
- ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue)
+ ep, err := l.createConnectingEndpoint(s, opts, queue)
if err != nil {
return nil, err
}
@@ -290,7 +290,14 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
}
// Register new endpoint so that packets are routed to it.
- if err := ep.stack.RegisterTransportEndpoint(ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil {
+ if err := ep.stack.RegisterTransportEndpoint(
+ ep.effectiveNetProtos,
+ ProtocolNumber,
+ ep.TransportEndpointInfo.ID,
+ ep,
+ ep.boundPortFlags,
+ ep.boundBindToDevice,
+ ); err != nil {
ep.mu.Unlock()
ep.Close()
@@ -307,6 +314,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
// Initialize and start the handshake.
h := ep.newPassiveHandshake(isn, irs, opts, deferAccept)
+ h.listenEP = l.listenEP
h.start()
return h, nil
}
@@ -334,14 +342,14 @@ func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions,
func (l *listenContext) addPendingEndpoint(n *endpoint) {
l.pendingMu.Lock()
- l.pendingEndpoints[n.ID] = n
+ l.pendingEndpoints[n.TransportEndpointInfo.ID] = n
l.pending.Add(1)
l.pendingMu.Unlock()
}
func (l *listenContext) removePendingEndpoint(n *endpoint) {
l.pendingMu.Lock()
- delete(l.pendingEndpoints, n.ID)
+ delete(l.pendingEndpoints, n.TransportEndpointInfo.ID)
l.pending.Done()
l.pendingMu.Unlock()
}
@@ -382,39 +390,46 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
// Update the receive window scaling. We can't do it before the
// handshake because it's possible that the peer doesn't support window
// scaling.
- e.rcv.rcvWndScale = e.h.effectiveRcvWndScale()
+ e.rcv.RcvWndScale = e.h.effectiveRcvWndScale()
// Clean up handshake state stored in the endpoint so that it can be GCed.
e.h = nil
}
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
-// endpoint has transitioned out of the listen state (acceptedChan is nil),
-// the new endpoint is closed instead.
+// listener has transitioned out of the listen state (accepted is the zero
+// value), the new endpoint is reset instead.
func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) {
e.mu.Lock()
e.pendingAccepted.Add(1)
e.mu.Unlock()
defer e.pendingAccepted.Done()
- e.acceptMu.Lock()
- for {
- if e.acceptedChan == nil {
- e.acceptMu.Unlock()
- n.notifyProtocolGoroutine(notifyReset)
- return
- }
- select {
- case e.acceptedChan <- n:
+ // Drop the lock before notifying to avoid deadlock in user-specified
+ // callbacks.
+ delivered := func() bool {
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+ for {
+ if e.accepted == (accepted{}) {
+ return false
+ }
+ if e.accepted.endpoints.Len() == e.accepted.cap {
+ e.acceptCond.Wait()
+ continue
+ }
+
+ e.accepted.endpoints.PushBack(n)
if !withSynCookie {
atomic.AddInt32(&e.synRcvdCount, -1)
}
- e.acceptMu.Unlock()
- e.waiterQueue.Notify(waiter.ReadableEvents)
- return
- default:
- e.acceptCond.Wait()
+ return true
}
+ }()
+ if delivered {
+ e.waiterQueue.Notify(waiter.ReadableEvents)
+ } else {
+ n.notifyProtocolGoroutine(notifyReset)
}
}
@@ -436,17 +451,21 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
// * propagateInheritableOptionsLocked has been called.
// * e.mu is held.
func (e *endpoint) reserveTupleLocked() bool {
- dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort}
+ dest := tcpip.FullAddress{
+ Addr: e.TransportEndpointInfo.ID.RemoteAddress,
+ Port: e.TransportEndpointInfo.ID.RemotePort,
+ }
portRes := ports.Reservation{
Networks: e.effectiveNetProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: e.TransportEndpointInfo.ID.LocalAddress,
+ Port: e.TransportEndpointInfo.ID.LocalPort,
Flags: e.boundPortFlags,
BindToDevice: e.boundBindToDevice,
Dest: dest,
}
if !e.stack.ReserveTuple(portRes) {
+ e.stack.Stats().TCP.FailedPortReservations.Increment()
return false
}
@@ -485,7 +504,6 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
}
go func() {
- defer ctx.synRcvdCount.dec()
if err := h.complete(); err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
@@ -497,24 +515,29 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
h.ep.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
e.deliverAccepted(h.ep, false /*withSynCookie*/)
- }() // S/R-SAFE: synRcvdCount is the barrier.
+ }()
return nil
}
-func (e *endpoint) incSynRcvdCount() bool {
+func (e *endpoint) synRcvdBacklogFull() bool {
e.acceptMu.Lock()
- canInc := int(atomic.LoadInt32(&e.synRcvdCount)) < cap(e.acceptedChan)
+ acceptedCap := e.accepted.cap
e.acceptMu.Unlock()
- if canInc {
- atomic.AddInt32(&e.synRcvdCount, 1)
- }
- return canInc
+ // The capacity of the accepted queue would always be one greater than the
+ // listen backlog. But, the SYNRCVD connections count is always checked
+ // against the listen backlog value for Linux parity reason.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
+ //
+ // We maintain an equality check here as the synRcvdCount is incremented
+ // and compared only from a single listener context and the capacity of
+ // the accepted queue can only increase by a new listen call.
+ return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1
}
func (e *endpoint) acceptQueueIsFull() bool {
e.acceptMu.Lock()
- full := len(e.acceptedChan)+int(atomic.LoadInt32(&e.synRcvdCount)) >= cap(e.acceptedChan)
+ full := e.accepted != (accepted{}) && e.accepted.endpoints.Len() == e.accepted.cap
e.acceptMu.Unlock()
return full
}
@@ -524,9 +547,9 @@ func (e *endpoint) acceptQueueIsFull() bool {
//
// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error {
- e.rcvListMu.Lock()
- rcvClosed := e.rcvClosed
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ rcvClosed := e.rcvQueueInfo.RcvClosed
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) {
// If the endpoint is shutdown, reply with reset.
//
@@ -538,69 +561,55 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
switch {
case s.flags == header.TCPFlagSyn:
- opts := parseSynSegmentOptions(s)
- if ctx.synRcvdCount.inc() {
- // Only handle the syn if the following conditions hold
- // - accept queue is not full.
- // - number of connections in synRcvd state is less than the
- // backlog.
- if !e.acceptQueueIsFull() && e.incSynRcvdCount() {
- s.incRef()
- _ = e.handleSynSegment(ctx, s, &opts)
- return nil
- }
- ctx.synRcvdCount.dec()
+ if e.acceptQueueIsFull() {
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
return nil
- } else {
- // If cookies are in use but the endpoint accept queue
- // is full then drop the syn.
- if e.acceptQueueIsFull() {
- e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
- e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
- e.stack.Stats().DroppedPackets.Increment()
- return nil
- }
- cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
+ }
- route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
- if err != nil {
- return err
- }
- defer route.Release()
+ opts := parseSynSegmentOptions(s)
+ if !ctx.useSynCookies() {
+ s.incRef()
+ atomic.AddInt32(&e.synRcvdCount, 1)
+ return e.handleSynSegment(ctx, s, &opts)
+ }
+ route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
- // Send SYN without window scaling because we currently
- // don't encode this information in the cookie.
- //
- // Enable Timestamp option if the original syn did have
- // the timestamp option specified.
- //
- // Use the user supplied MSS on the listening socket for
- // new connections, if available.
- synOpts := header.TCPSynOptions{
- WS: -1,
- TS: opts.TS,
- TSVal: tcpTimeStamp(time.Now(), timeStampOffset()),
- TSEcr: opts.TSVal,
- MSS: calculateAdvertisedMSS(e.userMSS, route),
- }
- fields := tcpFields{
- id: s.id,
- ttl: e.ttl,
- tos: e.sendTOS,
- flags: header.TCPFlagSyn | header.TCPFlagAck,
- seq: cookie,
- ack: s.sequenceNumber + 1,
- rcvWnd: ctx.rcvWnd,
- }
- if err := e.sendSynTCP(route, fields, synOpts); err != nil {
- return err
- }
- e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
- return nil
+ // Send SYN without window scaling because we currently
+ // don't encode this information in the cookie.
+ //
+ // Enable Timestamp option if the original syn did have
+ // the timestamp option specified.
+ //
+ // Use the user supplied MSS on the listening socket for
+ // new connections, if available.
+ synOpts := header.TCPSynOptions{
+ WS: -1,
+ TS: opts.TS,
+ TSVal: tcpTimeStamp(time.Now(), timeStampOffset()),
+ TSEcr: opts.TSVal,
+ MSS: calculateAdvertisedMSS(e.userMSS, route),
+ }
+ cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
+ fields := tcpFields{
+ id: s.id,
+ ttl: e.ttl,
+ tos: e.sendTOS,
+ flags: header.TCPFlagSyn | header.TCPFlagAck,
+ seq: cookie,
+ ack: s.sequenceNumber + 1,
+ rcvWnd: ctx.rcvWnd,
+ }
+ if err := e.sendSynTCP(route, fields, synOpts); err != nil {
+ return err
}
+ e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
+ return nil
case (s.flags & header.TCPFlagAck) != 0:
if e.acceptQueueIsFull() {
@@ -615,25 +624,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
return nil
}
- if !ctx.synRcvdCount.synCookiesInUse() {
- // When not using SYN cookies, as per RFC 793, section 3.9, page 64:
- // Any acknowledgment is bad if it arrives on a connection still in
- // the LISTEN state. An acceptable reset segment should be formed
- // for any arriving ACK-bearing segment. The RST should be
- // formatted as follows:
- //
- // <SEQ=SEG.ACK><CTL=RST>
- //
- // Send a reset as this is an ACK for which there is no
- // half open connections and we are not using cookies
- // yet.
- //
- // The only time we should reach here when a connection
- // was opened and closed really quickly and a delayed
- // ACK was received from the sender.
- return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
- }
-
iss := s.ackNumber - 1
irs := s.sequenceNumber - 1
@@ -651,7 +641,23 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
if !ok || int(data) >= len(mssTable) {
e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
e.stack.Stats().DroppedPackets.Increment()
- return nil
+
+ // When not using SYN cookies, as per RFC 793, section 3.9, page 64:
+ // Any acknowledgment is bad if it arrives on a connection still in
+ // the LISTEN state. An acceptable reset segment should be formed
+ // for any arriving ACK-bearing segment. The RST should be
+ // formatted as follows:
+ //
+ // <SEQ=SEG.ACK><CTL=RST>
+ //
+ // Send a reset as this is an ACK for which there is no
+ // half open connections and we are not using cookies
+ // yet.
+ //
+ // The only time we should reach here when a connection
+ // was opened and closed really quickly and a delayed
+ // ACK was received from the sender.
+ return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
}
e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment()
// Create newly accepted endpoint and deliver it.
@@ -672,7 +678,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
}
- n, err := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{})
+ n, err := ctx.createConnectingEndpoint(s, rcvdSynOptions, &waiter.Queue{})
if err != nil {
return err
}
@@ -693,7 +699,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
// Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil {
+ if err := n.stack.RegisterTransportEndpoint(
+ n.effectiveNetProtos,
+ ProtocolNumber,
+ n.TransportEndpointInfo.ID,
+ n,
+ n.boundPortFlags,
+ n.boundBindToDevice,
+ ); err != nil {
n.mu.Unlock()
n.Close()
@@ -708,7 +721,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
// endpoint as the Timestamp was already
// randomly offset when the original SYN-ACK was
// sent above.
- n.tsOffset = 0
+ n.TSOffset = 0
// Switch state to connected.
n.isConnectNotified = true
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index a9e978cf6..7bc6b08f0 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -65,11 +65,12 @@ const (
// NOTE: handshake.ep.mu is held during handshake processing. It is released if
// we are going to block and reacquired when we start processing an event.
type handshake struct {
- ep *endpoint
- state handshakeState
- active bool
- flags header.TCPFlags
- ackNum seqnum.Value
+ ep *endpoint
+ listenEP *endpoint
+ state handshakeState
+ active bool
+ flags header.TCPFlags
+ ackNum seqnum.Value
// iss is the initial send sequence number, as defined in RFC 793.
iss seqnum.Value
@@ -155,7 +156,7 @@ func (h *handshake) resetState() {
h.flags = header.TCPFlagSyn
h.ackNum = 0
h.mss = 0
- h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed())
+ h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Seed())
}
// generateSecureISN generates a secure Initial Sequence number based on the
@@ -301,7 +302,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
ttl = h.ep.route.DefaultTTL()
}
h.ep.sendSynTCP(h.ep.route, tcpFields{
- id: h.ep.ID,
+ id: h.ep.TransportEndpointInfo.ID,
ttl: ttl,
tos: h.ep.sendTOS,
flags: h.flags,
@@ -357,14 +358,14 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
h.resetState()
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
- TS: h.ep.sendTSOk,
+ TS: h.ep.SendTSOk,
TSVal: h.ep.timestamp(),
TSEcr: h.ep.recentTimestamp(),
- SACKPermitted: h.ep.sackPermitted,
+ SACKPermitted: h.ep.SACKPermitted,
MSS: h.ep.amss,
}
h.ep.sendSynTCP(h.ep.route, tcpFields{
- id: h.ep.ID,
+ id: h.ep.TransportEndpointInfo.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
flags: h.flags,
@@ -389,13 +390,22 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
// If the timestamp option is negotiated and the segment does
// not carry a timestamp option then the segment must be dropped
// as per https://tools.ietf.org/html/rfc7323#section-3.2.
- if h.ep.sendTSOk && !s.parsedOptions.TS {
+ if h.ep.SendTSOk && !s.parsedOptions.TS {
h.ep.stack.Stats().DroppedPackets.Increment()
return nil
}
+ // Drop the ACK if the accept queue is full.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/net/ipv4/tcp_ipv4.c#L1523
+ // We could abort the connection as well with a tunable as in
+ // https://github.com/torvalds/linux/blob/7acac4b3196/net/ipv4/tcp_minisocks.c#L788
+ if listenEP := h.listenEP; listenEP != nil && listenEP.acceptQueueIsFull() {
+ listenEP.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+
// Update timestamp if required. See RFC7323, section-4.3.
- if h.ep.sendTSOk && s.parsedOptions.TS {
+ if h.ep.SendTSOk && s.parsedOptions.TS {
h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
}
h.state = handshakeCompleted
@@ -485,8 +495,8 @@ func (h *handshake) start() {
// start() is also called in a listen context so we want to make sure we only
// send the TS/SACK option when we received the TS/SACK in the initial SYN.
if h.state == handshakeSynRcvd {
- synOpts.TS = h.ep.sendTSOk
- synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled)
+ synOpts.TS = h.ep.SendTSOk
+ synOpts.SACKPermitted = h.ep.SACKPermitted && bool(sackEnabled)
if h.sndWndScale < 0 {
// Disable window scaling if the peer did not send us
// the window scaling option.
@@ -496,7 +506,7 @@ func (h *handshake) start() {
h.sendSYNOpts = synOpts
h.ep.sendSynTCP(h.ep.route, tcpFields{
- id: h.ep.ID,
+ id: h.ep.TransportEndpointInfo.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
flags: h.flags,
@@ -544,7 +554,7 @@ func (h *handshake) complete() tcpip.Error {
// retransmitted on their own).
if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept {
h.ep.sendSynTCP(h.ep.route, tcpFields{
- id: h.ep.ID,
+ id: h.ep.TransportEndpointInfo.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
flags: h.flags,
@@ -845,7 +855,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// N.B. the ordering here matches the ordering used by Linux internally
// and described in the raw makeOptions function. We don't include
// unnecessary cases here (post connection.)
- if e.sendTSOk {
+ if e.SendTSOk {
// Embed the timestamp if timestamp has been enabled.
//
// We only use the lower 32 bits of the unix time in
@@ -862,7 +872,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:])
}
- if e.sackPermitted && len(sackBlocks) > 0 {
+ if e.SACKPermitted && len(sackBlocks) > 0 {
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeSACKBlocks(sackBlocks, options[offset:])
@@ -884,7 +894,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, se
}
options := e.makeOptions(sackBlocks)
err := e.sendTCP(e.route, tcpFields{
- id: e.ID,
+ id: e.TransportEndpointInfo.ID,
ttl: e.ttl,
tos: e.sendTOS,
flags: flags,
@@ -898,9 +908,9 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, se
}
func (e *endpoint) handleWrite() {
- e.sndBufMu.Lock()
+ e.sndQueueInfo.sndQueueMu.Lock()
next := e.drainSendQueueLocked()
- e.sndBufMu.Unlock()
+ e.sndQueueInfo.sndQueueMu.Unlock()
e.sendData(next)
}
@@ -909,10 +919,10 @@ func (e *endpoint) handleWrite() {
//
// Precondition: e.sndBufMu must be locked.
func (e *endpoint) drainSendQueueLocked() *segment {
- first := e.sndQueue.Front()
+ first := e.sndQueueInfo.sndQueue.Front()
if first != nil {
- e.snd.writeList.PushBackList(&e.sndQueue)
- e.sndBufInQueue = 0
+ e.snd.writeList.PushBackList(&e.sndQueueInfo.sndQueue)
+ e.sndQueueInfo.SndBufInQueue = 0
}
return first
}
@@ -936,7 +946,7 @@ func (e *endpoint) handleClose() {
e.handleWrite()
// Mark send side as closed.
- e.snd.closed = true
+ e.snd.Closed = true
}
// resetConnectionLocked puts the endpoint in an error state with the given
@@ -958,12 +968,12 @@ func (e *endpoint) resetConnectionLocked(err tcpip.Error) {
//
// See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more
// information.
- sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd)
+ sndWndEnd := e.snd.SndUna.Add(e.snd.SndWnd)
resetSeqNum := sndWndEnd
- if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1<<e.snd.sndWndScale) {
- resetSeqNum = e.snd.sndNxt
+ if !sndWndEnd.LessThan(e.snd.SndNxt) || e.snd.SndNxt.Size(sndWndEnd) < (1<<e.snd.SndWndScale) {
+ resetSeqNum = e.snd.SndNxt
}
- e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.rcvNxt, 0)
+ e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.RcvNxt, 0)
}
}
@@ -989,13 +999,13 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
// (indicated by a negative send window scale).
e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
- e.rcvListMu.Lock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
// Bootstrap the auto tuning algorithm. Starting at zero will
// result in a really large receive window after the first auto
// tuning adjustment.
- e.rcvAutoParams.prevCopied = int(h.rcvWnd)
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = int(h.rcvWnd)
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
e.setEndpointState(StateEstablished)
}
@@ -1026,10 +1036,15 @@ func (e *endpoint) transitionToStateCloseLocked() {
// only when the endpoint is in StateClose and we want to deliver the segment
// to any other listening endpoint. We reply with RST if we cannot find one.
func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
- ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID)
+ ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.TransportEndpointInfo.ID, s.nicID)
if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.TransportEndpointInfo.ID.LocalAddress.To4() != "" {
// Dual-stack socket, try IPv4.
- ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID)
+ ep = e.stack.FindTransportEndpoint(
+ header.IPv4ProtocolNumber,
+ e.TransProto,
+ e.TransportEndpointInfo.ID,
+ s.nicID,
+ )
}
if ep == nil {
replyWithReset(e.stack, s, stack.DefaultTOS, 0 /* ttl */)
@@ -1108,7 +1123,9 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err tcpip.Error) {
}
// handleSegments processes all inbound segments.
-func (e *endpoint) handleSegments(fastPath bool) tcpip.Error {
+//
+// Precondition: e.mu must be held.
+func (e *endpoint) handleSegmentsLocked(fastPath bool) tcpip.Error {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
if e.EndpointState().closed() {
@@ -1120,7 +1137,7 @@ func (e *endpoint) handleSegments(fastPath bool) tcpip.Error {
break
}
- cont, err := e.handleSegment(s)
+ cont, err := e.handleSegmentLocked(s)
s.decRef()
if err != nil {
return err
@@ -1138,7 +1155,7 @@ func (e *endpoint) handleSegments(fastPath bool) tcpip.Error {
}
// Send an ACK for all processed packets if needed.
- if e.rcv.rcvNxt != e.snd.maxSentAck {
+ if e.rcv.RcvNxt != e.snd.MaxSentAck {
e.snd.sendAck()
}
@@ -1147,18 +1164,21 @@ func (e *endpoint) handleSegments(fastPath bool) tcpip.Error {
return nil
}
-func (e *endpoint) probeSegment() {
- if e.probe != nil {
- e.probe(e.completeState())
+// Precondition: e.mu must be held.
+func (e *endpoint) probeSegmentLocked() {
+ if fn := e.probe; fn != nil {
+ fn(e.completeStateLocked())
}
}
// handleSegment handles a given segment and notifies the worker goroutine if
// if the connection should be terminated.
-func (e *endpoint) handleSegment(s *segment) (cont bool, err tcpip.Error) {
+//
+// Precondition: e.mu must be held.
+func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error) {
// Invoke the tcp probe if installed. The tcp probe function will update
// the TCPEndpointState after the segment is processed.
- defer e.probeSegment()
+ defer e.probeSegmentLocked()
if s.flagIsSet(header.TCPFlagRst) {
if ok, err := e.handleReset(s); !ok {
@@ -1191,7 +1211,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err tcpip.Error) {
} else if s.flagIsSet(header.TCPFlagAck) {
// Patch the window size in the segment according to the
// send window scale.
- s.window <<= e.snd.sndWndScale
+ s.window <<= e.snd.SndWndScale
// RFC 793, page 41 states that "once in the ESTABLISHED
// state all segments must carry current acknowledgment
@@ -1255,7 +1275,7 @@ func (e *endpoint) keepaliveTimerExpired() tcpip.Error {
// seg.seq = snd.nxt-1.
e.keepalive.unacked++
e.keepalive.Unlock()
- e.snd.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.sndNxt-1)
+ e.snd.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.SndNxt-1)
e.resetKeepaliveTimer(false)
return nil
}
@@ -1269,7 +1289,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
}
// Start the keepalive timer IFF it's enabled and there is no pending
// data to send.
- if !e.SocketOptions().GetKeepAlive() || e.snd == nil || e.snd.sndUna != e.snd.sndNxt {
+ if !e.SocketOptions().GetKeepAlive() || e.snd == nil || e.snd.SndUna != e.snd.SndNxt {
e.keepalive.timer.disable()
e.keepalive.Unlock()
return
@@ -1362,14 +1382,14 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
f func() tcpip.Error
}{
{
- w: &e.sndWaker,
+ w: &e.sndQueueInfo.sndWaker,
f: func() tcpip.Error {
e.handleWrite()
return nil
},
},
{
- w: &e.sndCloseWaker,
+ w: &e.sndQueueInfo.sndCloseWaker,
f: func() tcpip.Error {
e.handleClose()
return nil
@@ -1403,7 +1423,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
{
w: &e.newSegmentWaker,
f: func() tcpip.Error {
- return e.handleSegments(false /* fastPath */)
+ return e.handleSegmentsLocked(false /* fastPath */)
},
},
{
@@ -1419,11 +1439,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
if n&notifyMTUChanged != 0 {
- e.sndBufMu.Lock()
- count := e.packetTooBigCount
- e.packetTooBigCount = 0
- mtu := e.sndMTU
- e.sndBufMu.Unlock()
+ e.sndQueueInfo.sndQueueMu.Lock()
+ count := e.sndQueueInfo.PacketTooBigCount
+ e.sndQueueInfo.PacketTooBigCount = 0
+ mtu := e.sndQueueInfo.SndMTU
+ e.sndQueueInfo.sndQueueMu.Unlock()
e.snd.updateMaxPayloadSize(mtu, count)
}
@@ -1453,7 +1473,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
if n&notifyDrain != 0 {
for !e.segmentQueue.empty() {
- if err := e.handleSegments(false /* fastPath */); err != nil {
+ if err := e.handleSegmentsLocked(false /* fastPath */); err != nil {
return err
}
}
@@ -1504,11 +1524,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.newSegmentWaker.Assert()
}
- e.rcvListMu.Lock()
- if !e.rcvList.Empty() {
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ if !e.rcvQueueInfo.rcvQueue.Empty() {
e.waiterQueue.Notify(waiter.ReadableEvents)
}
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
if e.workerCleanup {
e.notifyProtocolGoroutine(notifyClose)
diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go
index 1975f1a44..962f1d687 100644
--- a/pkg/tcpip/transport/tcp/cubic.go
+++ b/pkg/tcpip/transport/tcp/cubic.go
@@ -17,6 +17,8 @@ package tcp
import (
"math"
"time"
+
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
// cubicState stores the variables related to TCP CUBIC congestion
@@ -25,47 +27,12 @@ import (
// See: https://tools.ietf.org/html/rfc8312.
// +stateify savable
type cubicState struct {
- // wLastMax is the previous wMax value.
- wLastMax float64
-
- // wMax is the value of the congestion window at the
- // time of last congestion event.
- wMax float64
-
- // t denotes the time when the current congestion avoidance
- // was entered.
- t time.Time `state:".(unixTime)"`
+ stack.TCPCubicState
// numCongestionEvents tracks the number of congestion events since last
// RTO.
numCongestionEvents int
- // c is the cubic constant as specified in RFC8312. It's fixed at 0.4 as
- // per RFC.
- c float64
-
- // k is the time period that the above function takes to increase the
- // current window size to W_max if there are no further congestion
- // events and is calculated using the following equation:
- //
- // K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2)
- k float64
-
- // beta is the CUBIC multiplication decrease factor. that is, when a
- // congestion event is detected, CUBIC reduces its cwnd to
- // W_cubic(0)=W_max*beta_cubic.
- beta float64
-
- // wC is window computed by CUBIC at time t. It's calculated using the
- // formula:
- //
- // W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1)
- wC float64
-
- // wEst is the window computed by CUBIC at time t+RTT i.e
- // W_cubic(t+RTT).
- wEst float64
-
s *sender
}
@@ -73,10 +40,12 @@ type cubicState struct {
// beta and c set and t set to current time.
func newCubicCC(s *sender) *cubicState {
return &cubicState{
- t: time.Now(),
- beta: 0.7,
- c: 0.4,
- s: s,
+ TCPCubicState: stack.TCPCubicState{
+ T: time.Now(),
+ Beta: 0.7,
+ C: 0.4,
+ },
+ s: s,
}
}
@@ -90,10 +59,10 @@ func (c *cubicState) enterCongestionAvoidance() {
// See: https://tools.ietf.org/html/rfc8312#section-4.7 &
// https://tools.ietf.org/html/rfc8312#section-4.8
if c.numCongestionEvents == 0 {
- c.k = 0
- c.t = time.Now()
- c.wLastMax = c.wMax
- c.wMax = float64(c.s.sndCwnd)
+ c.K = 0
+ c.T = time.Now()
+ c.WLastMax = c.WMax
+ c.WMax = float64(c.s.SndCwnd)
}
}
@@ -104,16 +73,16 @@ func (c *cubicState) enterCongestionAvoidance() {
func (c *cubicState) updateSlowStart(packetsAcked int) int {
// Don't let the congestion window cross into the congestion
// avoidance range.
- newcwnd := c.s.sndCwnd + packetsAcked
+ newcwnd := c.s.SndCwnd + packetsAcked
enterCA := false
- if newcwnd >= c.s.sndSsthresh {
- newcwnd = c.s.sndSsthresh
- c.s.sndCAAckCount = 0
+ if newcwnd >= c.s.Ssthresh {
+ newcwnd = c.s.Ssthresh
+ c.s.SndCAAckCount = 0
enterCA = true
}
- packetsAcked -= newcwnd - c.s.sndCwnd
- c.s.sndCwnd = newcwnd
+ packetsAcked -= newcwnd - c.s.SndCwnd
+ c.s.SndCwnd = newcwnd
if enterCA {
c.enterCongestionAvoidance()
}
@@ -124,49 +93,49 @@ func (c *cubicState) updateSlowStart(packetsAcked int) int {
// ACK received.
// Refer: https://tools.ietf.org/html/rfc8312#section-4
func (c *cubicState) Update(packetsAcked int) {
- if c.s.sndCwnd < c.s.sndSsthresh {
+ if c.s.SndCwnd < c.s.Ssthresh {
packetsAcked = c.updateSlowStart(packetsAcked)
if packetsAcked == 0 {
return
}
} else {
c.s.rtt.Lock()
- srtt := c.s.rtt.srtt
+ srtt := c.s.rtt.TCPRTTState.SRTT
c.s.rtt.Unlock()
- c.s.sndCwnd = c.getCwnd(packetsAcked, c.s.sndCwnd, srtt)
+ c.s.SndCwnd = c.getCwnd(packetsAcked, c.s.SndCwnd, srtt)
}
}
// cubicCwnd computes the CUBIC congestion window after t seconds from last
// congestion event.
func (c *cubicState) cubicCwnd(t float64) float64 {
- return c.c*math.Pow(t, 3.0) + c.wMax
+ return c.C*math.Pow(t, 3.0) + c.WMax
}
// getCwnd returns the current congestion window as computed by CUBIC.
// Refer: https://tools.ietf.org/html/rfc8312#section-4
func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int {
- elapsed := time.Since(c.t).Seconds()
+ elapsed := time.Since(c.T).Seconds()
// Compute the window as per Cubic after 'elapsed' time
// since last congestion event.
- c.wC = c.cubicCwnd(elapsed - c.k)
+ c.WC = c.cubicCwnd(elapsed - c.K)
// Compute the TCP friendly estimate of the congestion window.
- c.wEst = c.wMax*c.beta + (3.0*((1.0-c.beta)/(1.0+c.beta)))*(elapsed/srtt.Seconds())
+ c.WEst = c.WMax*c.Beta + (3.0*((1.0-c.Beta)/(1.0+c.Beta)))*(elapsed/srtt.Seconds())
// Make sure in the TCP friendly region CUBIC performs at least
// as well as Reno.
- if c.wC < c.wEst && float64(sndCwnd) < c.wEst {
+ if c.WC < c.WEst && float64(sndCwnd) < c.WEst {
// TCP Friendly region of cubic.
- return int(c.wEst)
+ return int(c.WEst)
}
// In Concave/Convex region of CUBIC, calculate what CUBIC window
// will be after 1 RTT and use that to grow congestion window
// for every ack.
- tEst := (time.Since(c.t) + srtt).Seconds()
- wtRtt := c.cubicCwnd(tEst - c.k)
+ tEst := (time.Since(c.T) + srtt).Seconds()
+ wtRtt := c.cubicCwnd(tEst - c.K)
// As per 4.3 for each received ACK cwnd must be incremented
// by (w_cubic(t+RTT) - cwnd/cwnd.
cwnd := float64(sndCwnd)
@@ -182,9 +151,9 @@ func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int
func (c *cubicState) HandleLossDetected() {
// See: https://tools.ietf.org/html/rfc8312#section-4.5
c.numCongestionEvents++
- c.t = time.Now()
- c.wLastMax = c.wMax
- c.wMax = float64(c.s.sndCwnd)
+ c.T = time.Now()
+ c.WLastMax = c.WMax
+ c.WMax = float64(c.s.SndCwnd)
c.fastConvergence()
c.reduceSlowStartThreshold()
@@ -193,10 +162,10 @@ func (c *cubicState) HandleLossDetected() {
// HandleRTOExpired implements congestionContrl.HandleRTOExpired.
func (c *cubicState) HandleRTOExpired() {
// See: https://tools.ietf.org/html/rfc8312#section-4.6
- c.t = time.Now()
+ c.T = time.Now()
c.numCongestionEvents = 0
- c.wLastMax = c.wMax
- c.wMax = float64(c.s.sndCwnd)
+ c.WLastMax = c.WMax
+ c.WMax = float64(c.s.SndCwnd)
c.fastConvergence()
@@ -206,29 +175,29 @@ func (c *cubicState) HandleRTOExpired() {
// Reduce the congestion window to 1, i.e., enter slow-start. Per
// RFC 5681, page 7, we must use 1 regardless of the value of the
// initial congestion window.
- c.s.sndCwnd = 1
+ c.s.SndCwnd = 1
}
// fastConvergence implements the logic for Fast Convergence algorithm as
// described in https://tools.ietf.org/html/rfc8312#section-4.6.
func (c *cubicState) fastConvergence() {
- if c.wMax < c.wLastMax {
- c.wLastMax = c.wMax
- c.wMax = c.wMax * (1.0 + c.beta) / 2.0
+ if c.WMax < c.WLastMax {
+ c.WLastMax = c.WMax
+ c.WMax = c.WMax * (1.0 + c.Beta) / 2.0
} else {
- c.wLastMax = c.wMax
+ c.WLastMax = c.WMax
}
// Recompute k as wMax may have changed.
- c.k = math.Cbrt(c.wMax * (1 - c.beta) / c.c)
+ c.K = math.Cbrt(c.WMax * (1 - c.Beta) / c.C)
}
// PostRecovery implemements congestionControl.PostRecovery.
func (c *cubicState) PostRecovery() {
- c.t = time.Now()
+ c.T = time.Now()
}
// reduceSlowStartThreshold returns new SsThresh as described in
// https://tools.ietf.org/html/rfc8312#section-4.7.
func (c *cubicState) reduceSlowStartThreshold() {
- c.s.sndSsthresh = int(math.Max(float64(c.s.sndCwnd)*c.beta, 2.0))
+ c.s.Ssthresh = int(math.Max(float64(c.s.SndCwnd)*c.Beta, 2.0))
}
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index 21162f01a..512053a04 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -116,7 +116,7 @@ func (p *processor) start(wg *sync.WaitGroup) {
if ep.EndpointState() == StateEstablished && ep.mu.TryLock() {
// If the endpoint is in a connected state then we do direct delivery
// to ensure low latency and avoid scheduler interactions.
- switch err := ep.handleSegments(true /* fastPath */); {
+ switch err := ep.handleSegmentsLocked(true /* fastPath */); {
case err != nil:
// Send any active resets if required.
ep.resetConnectionLocked(err)
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index f6a16f96e..d6d68f128 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -565,17 +565,15 @@ func TestV4AcceptOnV4(t *testing.T) {
}
func testV4ListenClose(t *testing.T, c *context.Context) {
- // Set the SynRcvd threshold to zero to force a syn cookie based accept
- // to happen.
- var opt tcpip.TCPSynRcvdCountThresholdOption
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
}
- const n = uint16(32)
+ const n = 32
// Start listening.
- if err := c.EP.Listen(int(tcp.SynRcvdCountThreshold + 1)); err != nil {
+ if err := c.EP.Listen(n); err != nil {
t.Fatalf("Listen failed: %v", err)
}
@@ -591,9 +589,9 @@ func testV4ListenClose(t *testing.T, c *context.Context) {
})
}
- // Each of these ACK's will cause a syn-cookie based connection to be
+ // Each of these ACKs will cause a syn-cookie based connection to be
// accepted and delivered to the listening endpoint.
- for i := uint16(0); i < n; i++ {
+ for i := 0; i < n; i++ {
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
iss := seqnum.Value(tcp.SequenceNumber())
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index c5daba232..f25dc781a 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "container/list"
"encoding/binary"
"fmt"
"io"
@@ -190,42 +191,6 @@ type SACKInfo struct {
NumBlocks int
}
-// rcvBufAutoTuneParams are used to hold state variables to compute
-// the auto tuned recv buffer size.
-//
-// +stateify savable
-type rcvBufAutoTuneParams struct {
- // measureTime is the time at which the current measurement
- // was started.
- measureTime time.Time `state:".(unixTime)"`
-
- // copied is the number of bytes copied out of the receive
- // buffers since this measure began.
- copied int
-
- // prevCopied is the number of bytes copied out of the receive
- // buffers in the previous RTT period.
- prevCopied int
-
- // rtt is the non-smoothed minimum RTT as measured by observing the time
- // between when a byte is first acknowledged and the receipt of data
- // that is at least one window beyond the sequence number that was
- // acknowledged.
- rtt time.Duration
-
- // rttMeasureSeqNumber is the highest acceptable sequence number at the
- // time this RTT measurement period began.
- rttMeasureSeqNumber seqnum.Value
-
- // rttMeasureTime is the absolute time at which the current rtt
- // measurement period began.
- rttMeasureTime time.Time `state:".(unixTime)"`
-
- // disabled is true if an explicit receive buffer is set for the
- // endpoint.
- disabled bool
-}
-
// ReceiveErrors collect segment receive errors within transport layer.
type ReceiveErrors struct {
tcpip.ReceiveErrors
@@ -246,7 +211,7 @@ type ReceiveErrors struct {
ListenOverflowAckDrop tcpip.StatCounter
// ZeroRcvWindowState is the number of times we advertised
- // a zero receive window when rcvList is full.
+ // a zero receive window when rcvQueue is full.
ZeroRcvWindowState tcpip.StatCounter
// WantZeroWindow is the number of times we wanted to advertise a
@@ -309,18 +274,45 @@ type Stats struct {
// marker interface.
func (*Stats) IsEndpointStats() {}
-// EndpointInfo holds useful information about a transport endpoint which
-// can be queried by monitoring tools. This exists to allow tcp-only state to
-// be exposed.
+// sndQueueInfo implements a send queue.
//
// +stateify savable
-type EndpointInfo struct {
- stack.TransportEndpointInfo
+type sndQueueInfo struct {
+ sndQueueMu sync.Mutex `state:"nosave"`
+ stack.TCPSndBufState
+
+ // sndQueue holds segments that are ready to be sent.
+ sndQueue segmentList `state:"wait"`
+
+ // sndWaker is used to signal the protocol goroutine when segments are
+ // added to the `sndQueue`.
+ sndWaker sleep.Waker `state:"manual"`
+
+ // sndCloseWaker is used to notify the protocol goroutine when the send
+ // side is closed.
+ sndCloseWaker sleep.Waker `state:"manual"`
}
-// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
-// marker interface.
-func (*EndpointInfo) IsEndpointInfo() {}
+// rcvQueueInfo contains the endpoint's rcvQueue and associated metadata.
+//
+// +stateify savable
+type rcvQueueInfo struct {
+ rcvQueueMu sync.Mutex `state:"nosave"`
+ stack.TCPRcvBufState
+
+ // rcvQueue is the queue for ready-for-delivery segments. This struct's
+ // mutex must be held in order append segments to list.
+ rcvQueue segmentList `state:"wait"`
+}
+
+// +stateify savable
+type accepted struct {
+ // NB: this could be an endpointList, but ilist only permits endpoints to
+ // belong to one list at a time, and endpoints are already stored in the
+ // dispatcher's list.
+ endpoints list.List `state:".([]*endpoint)"`
+ cap int
+}
// endpoint represents a TCP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
@@ -337,9 +329,9 @@ func (*EndpointInfo) IsEndpointInfo() {}
// The following three mutexes can be acquired independent of e.mu but if
// acquired with e.mu then e.mu must be acquired first.
//
-// e.acceptMu -> protects acceptedChan.
-// e.rcvListMu -> Protects the rcvList and associated fields.
-// e.sndBufMu -> Protects the sndQueue and associated fields.
+// e.acceptMu -> protects accepted.
+// e.rcvQueueMu -> Protects e.rcvQueue and associated fields.
+// e.sndQueueMu -> Protects the e.sndQueue and associated fields.
// e.lastErrorMu -> Protects the lastError field.
//
// LOCKING/UNLOCKING of the endpoint. The locking of an endpoint is different
@@ -362,7 +354,8 @@ func (*EndpointInfo) IsEndpointInfo() {}
//
// +stateify savable
type endpoint struct {
- EndpointInfo
+ stack.TCPEndpointStateInner
+ stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// endpointEntry is used to queue endpoints for processing to the
@@ -395,38 +388,23 @@ type endpoint struct {
// rcvReadMu synchronizes calls to Read.
//
- // mu and rcvListMu are temporarily released during data copying. rcvReadMu
+ // mu and rcvQueueMu are temporarily released during data copying. rcvReadMu
// must be held during each read to ensure atomicity, so that multiple reads
// do not interleave.
//
// rcvReadMu should be held before holding mu.
rcvReadMu sync.Mutex `state:"nosave"`
- // rcvListMu synchronizes access to rcvList.
- //
- // rcvListMu can be taken after the endpoint mu below.
- rcvListMu sync.Mutex `state:"nosave"`
-
- // rcvList is the queue for ready-for-delivery segments.
- //
- // rcvReadMu, mu and rcvListMu must be held, in the stated order, to read data
- // and removing segments from list. A range of segment can be determined, then
- // temporarily release mu and rcvListMu while processing the segment range.
- // This allows new segments to be appended to the list while processing.
- //
- // rcvListMu must be held to append segments to list.
- rcvList segmentList `state:"wait"`
- rcvClosed bool
- // rcvBufSize is the total size of the receive buffer.
- rcvBufSize int
- // rcvBufUsed is the actual number of payload bytes held in the receive buffer
- // not counting any overheads of the segments itself. NOTE: This will always
- // be strictly <= rcvMemUsed below.
- rcvBufUsed int
- rcvAutoParams rcvBufAutoTuneParams
+ // rcvQueueInfo holds the implementation of the endpoint's receive buffer.
+ // The data within rcvQueueInfo should only be accessed while rcvReadMu, mu,
+ // and rcvQueueMu are held, in that stated order. While processing the segment
+ // range, you can determine a range and then temporarily release mu and
+ // rcvQueueMu, which allows new segments to be appended to the queue while
+ // processing.
+ rcvQueueInfo rcvQueueInfo
// rcvMemUsed tracks the total amount of memory in use by received segments
- // held in rcvList, pendingRcvdSegments and the segment queue. This is used to
+ // held in rcvQueue, pendingRcvdSegments and the segment queue. This is used to
// compute the window and the actual available buffer space. This is distinct
// from rcvBufUsed above which is the actual number of payload bytes held in
// the buffer not including any segment overheads.
@@ -488,33 +466,16 @@ type endpoint struct {
// also true, and they're both protected by the mutex.
workerCleanup bool
- // sendTSOk is used to indicate when the TS Option has been negotiated.
- // When sendTSOk is true every non-RST segment should carry a TS as per
- // RFC7323#section-1.1
- sendTSOk bool
-
- // recentTS is the timestamp that should be sent in the TSEcr field of
- // the timestamp for future segments sent by the endpoint. This field is
- // updated if required when a new segment is received by this endpoint.
- recentTS uint32
-
- // recentTSTime is the unix time when we updated recentTS last.
+ // recentTSTime is the unix time when we last updated
+ // TCPEndpointStateInner.RecentTS.
recentTSTime time.Time `state:".(unixTime)"`
- // tsOffset is a randomized offset added to the value of the
- // TSVal field in the timestamp option.
- tsOffset uint32
-
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
// tcpRecovery is the loss deteoction algorithm used by TCP.
tcpRecovery tcpip.TCPRecovery
- // sackPermitted is set to true if the peer sends the TCPSACKPermitted
- // option in the SYN/SYN-ACK.
- sackPermitted bool
-
// sack holds TCP SACK related information for this endpoint.
sack SACKInfo
@@ -550,32 +511,13 @@ type endpoint struct {
// this value.
windowClamp uint32
- // The following fields are used to manage the send buffer. When
- // segments are ready to be sent, they are added to sndQueue and the
- // protocol goroutine is signaled via sndWaker.
- //
- // When the send side is closed, the protocol goroutine is notified via
- // sndCloseWaker, and sndClosed is set to true.
- sndBufMu sync.Mutex `state:"nosave"`
- sndBufUsed int
- sndClosed bool
- sndBufInQueue seqnum.Size
- sndQueue segmentList `state:"wait"`
- sndWaker sleep.Waker `state:"manual"`
- sndCloseWaker sleep.Waker `state:"manual"`
+ // sndQueueInfo contains the implementation of the endpoint's send queue.
+ sndQueueInfo sndQueueInfo
// cc stores the name of the Congestion Control algorithm to use for
// this endpoint.
cc tcpip.CongestionControlOption
- // The following are used when a "packet too big" control packet is
- // received. They are protected by sndBufMu. They are used to
- // communicate to the main protocol goroutine how many such control
- // messages have been received since the last notification was processed
- // and what was the smallest MTU seen.
- packetTooBigCount int
- sndMTU int
-
// newSegmentWaker is used to indicate to the protocol goroutine that
// it needs to wake up and handle new segments queued to it.
newSegmentWaker sleep.Waker `state:"manual"`
@@ -607,33 +549,26 @@ type endpoint struct {
// listener.
deferAccept time.Duration
- // pendingAccepted is a synchronization primitive used to track number
- // of connections that are queued up to be delivered to the accepted
- // channel. We use this to ensure that all goroutines blocked on writing
- // to the acceptedChan below terminate before we close acceptedChan.
+ // pendingAccepted tracks connections queued to be accepted. It is used to
+ // ensure such queued connections are terminated before the accepted queue is
+ // marked closed (by setting its capacity to zero).
pendingAccepted sync.WaitGroup `state:"nosave"`
- // acceptMu protects acceptedChan.
+ // acceptMu protects accepted.
acceptMu sync.Mutex `state:"nosave"`
// acceptCond is a condition variable that can be used to block on when
- // acceptedChan is full and an endpoint is ready to be delivered.
- //
- // This condition variable is required because just blocking on sending
- // to acceptedChan does not work in cases where endpoint.Listen is
- // called twice with different backlog values. In such cases the channel
- // is closed and a new one created. Any pending goroutines blocking on
- // the write to the channel will panic.
+ // accepted is full and an endpoint is ready to be delivered.
//
// We use this condition variable to block/unblock goroutines which
// tried to deliver an endpoint but couldn't because accept backlog was
// full ( See: endpoint.deliverAccepted ).
acceptCond *sync.Cond `state:"nosave"`
- // acceptedChan is used by a listening endpoint protocol goroutine to
+ // accepted is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
- acceptedChan chan *endpoint `state:".([]*endpoint)"`
+ accepted accepted
// The following are only used from the protocol goroutine, and
// therefore don't need locks to protect them.
@@ -779,7 +714,7 @@ func (e *endpoint) UnlockUser() {
switch e.EndpointState() {
case StateEstablished:
- if err := e.handleSegments(true /* fastPath */); err != nil {
+ if err := e.handleSegmentsLocked(true /* fastPath */); err != nil {
e.notifyProtocolGoroutine(notifyTickleWorker)
}
default:
@@ -839,13 +774,13 @@ func (e *endpoint) EndpointState() EndpointState {
// setRecentTimestamp sets the recentTS field to the provided value.
func (e *endpoint) setRecentTimestamp(recentTS uint32) {
- e.recentTS = recentTS
+ e.RecentTS = recentTS
e.recentTSTime = time.Now()
}
// recentTimestamp returns the value of the recentTS field.
func (e *endpoint) recentTimestamp() uint32 {
- return e.recentTS
+ return e.RecentTS
}
// keepalive is a synchronization wrapper used to appease stateify. See the
@@ -865,16 +800,17 @@ type keepalive struct {
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
e := &endpoint{
stack: s,
- EndpointInfo: EndpointInfo{
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- TransProto: header.TCPProtocolNumber,
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ TransProto: header.TCPProtocolNumber,
+ },
+ sndQueueInfo: sndQueueInfo{
+ TCPSndBufState: stack.TCPSndBufState{
+ SndMTU: int(math.MaxInt32),
},
},
waiterQueue: waiterQueue,
state: StateInitial,
- rcvBufSize: DefaultReceiveBufferSize,
- sndMTU: int(math.MaxInt32),
keepalive: keepalive{
// Linux defaults.
idle: 2 * time.Hour,
@@ -886,10 +822,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
windowClamp: DefaultReceiveBufferSize,
maxSynRetries: DefaultSynRetries,
}
- e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits)
+ e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits)
e.ops.SetMulticastLoop(true)
e.ops.SetQuickAck(true)
e.ops.SetSendBufferSize(DefaultSendBufferSize, false /* notify */)
+ e.ops.SetReceiveBufferSize(DefaultReceiveBufferSize, false /* notify */)
var ss tcpip.TCPSendBufferSizeRangeOption
if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
@@ -898,7 +835,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
- e.rcvBufSize = rs.Default
+ e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
}
var cs tcpip.CongestionControlOption
@@ -908,7 +845,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
var mrb tcpip.TCPModerateReceiveBufferOption
if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil {
- e.rcvAutoParams.disabled = !bool(mrb)
+ e.rcvQueueInfo.RcvAutoParams.Disabled = !bool(mrb)
}
var de tcpip.TCPDelayEnabled
@@ -933,7 +870,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
}
e.segmentQueue.ep = e
- e.tsOffset = timeStampOffset()
+ e.TSOffset = timeStampOffset()
e.acceptCond = sync.NewCond(&e.acceptMu)
e.keepalive.timer.init(&e.keepalive.waker)
@@ -959,10 +896,10 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
result = mask
case StateListen:
- // Check if there's anything in the accepted channel.
+ // Check if there's anything in the accepted queue.
if (mask & waiter.ReadableEvents) != 0 {
e.acceptMu.Lock()
- if len(e.acceptedChan) > 0 {
+ if e.accepted.endpoints.Len() != 0 {
result |= waiter.ReadableEvents
}
e.acceptMu.Unlock()
@@ -971,21 +908,21 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
if e.EndpointState().connected() {
// Determine if the endpoint is writable if requested.
if (mask & waiter.WritableEvents) != 0 {
- e.sndBufMu.Lock()
+ e.sndQueueInfo.sndQueueMu.Lock()
sndBufSize := e.getSendBufferSize()
- if e.sndClosed || e.sndBufUsed < sndBufSize {
+ if e.sndQueueInfo.SndClosed || e.sndQueueInfo.SndBufUsed < sndBufSize {
result |= waiter.WritableEvents
}
- e.sndBufMu.Unlock()
+ e.sndQueueInfo.sndQueueMu.Unlock()
}
// Determine if the endpoint is readable if requested.
if (mask & waiter.ReadableEvents) != 0 {
- e.rcvListMu.Lock()
- if e.rcvBufUsed > 0 || e.rcvClosed {
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ if e.rcvQueueInfo.RcvBufUsed > 0 || e.rcvQueueInfo.RcvClosed {
result |= waiter.ReadableEvents
}
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
}
}
@@ -1093,15 +1030,15 @@ func (e *endpoint) closeNoShutdownLocked() {
// in Listen() when trying to register.
if e.EndpointState() == StateListen && e.isPortReserved {
if e.isRegistered {
- e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.isRegistered = false
}
portRes := ports.Reservation{
Networks: e.effectiveNetProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: e.TransportEndpointInfo.ID.LocalAddress,
+ Port: e.TransportEndpointInfo.ID.LocalPort,
Flags: e.boundPortFlags,
BindToDevice: e.boundBindToDevice,
Dest: e.boundDest,
@@ -1145,22 +1082,22 @@ func (e *endpoint) closeNoShutdownLocked() {
// handshake but not yet been delivered to the application.
func (e *endpoint) closePendingAcceptableConnectionsLocked() {
e.acceptMu.Lock()
- if e.acceptedChan == nil {
- e.acceptMu.Unlock()
+ acceptedCopy := e.accepted
+ e.accepted = accepted{}
+ e.acceptMu.Unlock()
+
+ if acceptedCopy == (accepted{}) {
return
}
- close(e.acceptedChan)
- ch := e.acceptedChan
- e.acceptedChan = nil
+
e.acceptCond.Broadcast()
- e.acceptMu.Unlock()
// Reset all connections that are waiting to be accepted.
- for n := range ch {
- n.notifyProtocolGoroutine(notifyReset)
+ for n := acceptedCopy.endpoints.Front(); n != nil; n = n.Next() {
+ n.Value.(*endpoint).notifyProtocolGoroutine(notifyReset)
}
// Wait for reset of all endpoints that are still waiting to be delivered to
- // the now closed acceptedChan.
+ // the now closed accepted.
e.pendingAccepted.Wait()
}
@@ -1176,7 +1113,7 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.isRegistered = false
}
@@ -1184,8 +1121,8 @@ func (e *endpoint) cleanupLocked() {
portRes := ports.Reservation{
Networks: e.effectiveNetProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: e.TransportEndpointInfo.ID.LocalAddress,
+ Port: e.TransportEndpointInfo.ID.LocalPort,
Flags: e.boundPortFlags,
BindToDevice: e.boundBindToDevice,
Dest: e.boundDest,
@@ -1247,19 +1184,19 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.LockUser()
defer e.UnlockUser()
- e.rcvListMu.Lock()
- if e.rcvAutoParams.disabled {
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ if e.rcvQueueInfo.RcvAutoParams.Disabled {
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
return
}
now := time.Now()
- if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt {
- e.rcvAutoParams.copied += copied
- e.rcvListMu.Unlock()
+ if rtt := e.rcvQueueInfo.RcvAutoParams.RTT; rtt == 0 || now.Sub(e.rcvQueueInfo.RcvAutoParams.MeasureTime) < rtt {
+ e.rcvQueueInfo.RcvAutoParams.CopiedBytes += copied
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
return
}
- prevRTTCopied := e.rcvAutoParams.copied + copied
- prevCopied := e.rcvAutoParams.prevCopied
+ prevRTTCopied := e.rcvQueueInfo.RcvAutoParams.CopiedBytes + copied
+ prevCopied := e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes
rcvWnd := 0
if prevRTTCopied > prevCopied {
// The minimal receive window based on what was copied by the app
@@ -1291,24 +1228,25 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
// We do not adjust downwards as that can cause the receiver to
// reject valid data that might already be in flight as the
// acceptable window will shrink.
- if rcvWnd > e.rcvBufSize {
- availBefore := wndFromSpace(e.receiveBufferAvailableLocked())
- e.rcvBufSize = rcvWnd
- availAfter := wndFromSpace(e.receiveBufferAvailableLocked())
- if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
+ rcvBufSize := int(e.ops.GetReceiveBufferSize())
+ if rcvWnd > rcvBufSize {
+ availBefore := wndFromSpace(e.receiveBufferAvailableLocked(rcvBufSize))
+ e.ops.SetReceiveBufferSize(int64(rcvWnd), false /* notify */)
+ availAfter := wndFromSpace(e.receiveBufferAvailableLocked(rcvWnd))
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter-availBefore, rcvBufSize); crossed && above {
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
}
- // We only update prevCopied when we grow the buffer because in cases
- // where prevCopied > prevRTTCopied the existing buffer is already big
+ // We only update PrevCopiedBytes when we grow the buffer because in cases
+ // where PrevCopiedBytes > prevRTTCopied the existing buffer is already big
// enough to handle the current rate and we don't need to do any
// adjustments.
- e.rcvAutoParams.prevCopied = prevRTTCopied
+ e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = prevRTTCopied
}
- e.rcvAutoParams.measureTime = now
- e.rcvAutoParams.copied = 0
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.RcvAutoParams.MeasureTime = now
+ e.rcvQueueInfo.RcvAutoParams.CopiedBytes = 0
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
}
// SetOwner implements tcpip.Endpoint.SetOwner.
@@ -1357,7 +1295,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
defer e.rcvReadMu.Unlock()
// N.B. Here we get a range of segments to be processed. It is safe to not
- // hold rcvListMu when processing, since we hold rcvReadMu to ensure only we
+ // hold rcvQueueMu when processing, since we hold rcvReadMu to ensure only we
// can remove segments from the list through commitRead().
first, last, serr := e.startRead()
if serr != nil {
@@ -1429,10 +1367,10 @@ func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) {
// but has some pending unread data. Also note that a RST being received
// would cause the state to become StateError so we should allow the
// reads to proceed before returning a ECONNRESET.
- e.rcvListMu.Lock()
- defer e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ defer e.rcvQueueInfo.rcvQueueMu.Unlock()
- bufUsed := e.rcvBufUsed
+ bufUsed := e.rcvQueueInfo.RcvBufUsed
if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
if s == StateError {
if err := e.hardErrorLocked(); err != nil {
@@ -1444,14 +1382,14 @@ func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) {
return nil, nil, &tcpip.ErrNotConnected{}
}
- if e.rcvBufUsed == 0 {
- if e.rcvClosed || !e.EndpointState().connected() {
+ if e.rcvQueueInfo.RcvBufUsed == 0 {
+ if e.rcvQueueInfo.RcvClosed || !e.EndpointState().connected() {
return nil, nil, &tcpip.ErrClosedForReceive{}
}
return nil, nil, &tcpip.ErrWouldBlock{}
}
- return e.rcvList.Front(), e.rcvList.Back(), nil
+ return e.rcvQueueInfo.rcvQueue.Front(), e.rcvQueueInfo.rcvQueue.Back(), nil
}
// commitRead commits a read of done bytes and returns the next non-empty
@@ -1467,39 +1405,39 @@ func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) {
func (e *endpoint) commitRead(done int) *segment {
e.LockUser()
defer e.UnlockUser()
- e.rcvListMu.Lock()
- defer e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ defer e.rcvQueueInfo.rcvQueueMu.Unlock()
memDelta := 0
- s := e.rcvList.Front()
+ s := e.rcvQueueInfo.rcvQueue.Front()
for s != nil && s.data.Size() == 0 {
- e.rcvList.Remove(s)
+ e.rcvQueueInfo.rcvQueue.Remove(s)
// Memory is only considered released when the whole segment has been
// read.
memDelta += s.segMemSize()
s.decRef()
- s = e.rcvList.Front()
+ s = e.rcvQueueInfo.rcvQueue.Front()
}
- e.rcvBufUsed -= done
+ e.rcvQueueInfo.RcvBufUsed -= done
if memDelta > 0 {
// If the window was small before this read and if the read freed up
// enough buffer space, to either fit an aMSS or half a receive buffer
// (whichever smaller), then notify the protocol goroutine to send a
// window update.
- if crossed, above := e.windowCrossedACKThresholdLocked(memDelta); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(memDelta, int(e.ops.GetReceiveBufferSize())); crossed && above {
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
}
- return e.rcvList.Front()
+ return e.rcvQueueInfo.rcvQueue.Front()
}
// isEndpointWritableLocked checks if a given endpoint is writable
// and also returns the number of bytes that can be written at this
// moment. If the endpoint is not writable then it returns an error
// indicating the reason why it's not writable.
-// Caller must hold e.mu and e.sndBufMu
+// Caller must hold e.mu and e.sndQueueMu
func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) {
// The endpoint cannot be written to if it's not connected.
switch s := e.EndpointState(); {
@@ -1519,12 +1457,12 @@ func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) {
}
// Check if the connection has already been closed for sends.
- if e.sndClosed {
+ if e.sndQueueInfo.SndClosed {
return 0, &tcpip.ErrClosedForSend{}
}
sndBufSize := e.getSendBufferSize()
- avail := sndBufSize - e.sndBufUsed
+ avail := sndBufSize - e.sndQueueInfo.SndBufUsed
if avail <= 0 {
return 0, &tcpip.ErrWouldBlock{}
}
@@ -1541,8 +1479,8 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
defer e.UnlockUser()
nextSeg, n, err := func() (*segment, int, tcpip.Error) {
- e.sndBufMu.Lock()
- defer e.sndBufMu.Unlock()
+ e.sndQueueInfo.sndQueueMu.Lock()
+ defer e.sndQueueInfo.sndQueueMu.Unlock()
avail, err := e.isEndpointWritableLocked()
if err != nil {
@@ -1557,8 +1495,8 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
// available buffer space to be consumed by some other caller while we
// are copying data in.
if !opts.Atomic {
- e.sndBufMu.Unlock()
- defer e.sndBufMu.Lock()
+ e.sndQueueInfo.sndQueueMu.Unlock()
+ defer e.sndQueueInfo.sndQueueMu.Lock()
e.UnlockUser()
defer e.LockUser()
@@ -1600,10 +1538,10 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
}
// Add data to the send queue.
- s := newOutgoingSegment(e.ID, v)
- e.sndBufUsed += len(v)
- e.sndBufInQueue += seqnum.Size(len(v))
- e.sndQueue.PushBack(s)
+ s := newOutgoingSegment(e.TransportEndpointInfo.ID, v)
+ e.sndQueueInfo.SndBufUsed += len(v)
+ e.sndQueueInfo.SndBufInQueue += seqnum.Size(len(v))
+ e.sndQueueInfo.sndQueue.PushBack(s)
return e.drainSendQueueLocked(), len(v), nil
}()
@@ -1618,11 +1556,11 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
// selectWindowLocked returns the new window without checking for shrinking or scaling
// applied.
-// Precondition: e.mu and e.rcvListMu must be held.
-func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) {
- wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked())
- maxWindow := wndFromSpace(e.rcvBufSize)
- wndFromUsedBytes := maxWindow - e.rcvBufUsed
+// Precondition: e.mu and e.rcvQueueMu must be held.
+func (e *endpoint) selectWindowLocked(rcvBufSize int) (wnd seqnum.Size) {
+ wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked(rcvBufSize))
+ maxWindow := wndFromSpace(rcvBufSize)
+ wndFromUsedBytes := maxWindow - e.rcvQueueInfo.RcvBufUsed
// We take the lesser of the wndFromAvailable and wndFromUsedBytes because in
// cases where we receive a lot of small segments the segment overhead is a
@@ -1640,11 +1578,11 @@ func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) {
return seqnum.Size(newWnd)
}
-// selectWindow invokes selectWindowLocked after acquiring e.rcvListMu.
+// selectWindow invokes selectWindowLocked after acquiring e.rcvQueueMu.
func (e *endpoint) selectWindow() (wnd seqnum.Size) {
- e.rcvListMu.Lock()
- wnd = e.selectWindowLocked()
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ wnd = e.selectWindowLocked(int(e.ops.GetReceiveBufferSize()))
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
return wnd
}
@@ -1662,9 +1600,9 @@ func (e *endpoint) selectWindow() (wnd seqnum.Size) {
// above will be true if the new window is >= ACK threshold and false
// otherwise.
//
-// Precondition: e.mu and e.rcvListMu must be held.
-func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
- newAvail := int(e.selectWindowLocked())
+// Precondition: e.mu and e.rcvQueueMu must be held.
+func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int, rcvBufSize int) (crossed bool, above bool) {
+ newAvail := int(e.selectWindowLocked(rcvBufSize))
oldAvail := newAvail - deltaBefore
if oldAvail < 0 {
oldAvail = 0
@@ -1673,7 +1611,7 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo
// rcvBufFraction is the inverse of the fraction of receive buffer size that
// is used to decide if the available buffer space is now above it.
const rcvBufFraction = 2
- if wndThreshold := wndFromSpace(e.rcvBufSize / rcvBufFraction); threshold > wndThreshold {
+ if wndThreshold := wndFromSpace(rcvBufSize / rcvBufFraction); threshold > wndThreshold {
threshold = wndThreshold
}
switch {
@@ -1700,7 +1638,7 @@ func (e *endpoint) OnReusePortSet(v bool) {
}
// OnKeepAliveSet implements tcpip.SocketOptionsHandler.OnKeepAliveSet.
-func (e *endpoint) OnKeepAliveSet(v bool) {
+func (e *endpoint) OnKeepAliveSet(bool) {
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
}
@@ -1708,7 +1646,7 @@ func (e *endpoint) OnKeepAliveSet(v bool) {
func (e *endpoint) OnDelayOptionSet(v bool) {
if !v {
// Handle delayed data.
- e.sndWaker.Assert()
+ e.sndQueueInfo.sndWaker.Assert()
}
}
@@ -1716,7 +1654,7 @@ func (e *endpoint) OnDelayOptionSet(v bool) {
func (e *endpoint) OnCorkOptionSet(v bool) {
if !v {
// Handle the corked data.
- e.sndWaker.Assert()
+ e.sndQueueInfo.sndWaker.Assert()
}
}
@@ -1724,6 +1662,37 @@ func (e *endpoint) getSendBufferSize() int {
return int(e.ops.GetSendBufferSize())
}
+// OnSetReceiveBufferSize implements tcpip.SocketOptionsHandler.OnSetReceiveBufferSize.
+func (e *endpoint) OnSetReceiveBufferSize(rcvBufSz, oldSz int64) (newSz int64) {
+ e.LockUser()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+
+ // Make sure the receive buffer size allows us to send a
+ // non-zero window size.
+ scale := uint8(0)
+ if e.rcv != nil {
+ scale = e.rcv.RcvWndScale
+ }
+ if rcvBufSz>>scale == 0 {
+ rcvBufSz = 1 << scale
+ }
+
+ availBefore := wndFromSpace(e.receiveBufferAvailableLocked(int(oldSz)))
+ availAfter := wndFromSpace(e.receiveBufferAvailableLocked(int(rcvBufSz)))
+ e.rcvQueueInfo.RcvAutoParams.Disabled = true
+
+ // Immediately send an ACK to uncork the sender silly window
+ // syndrome prevetion, when our available space grows above aMSS
+ // or half receive buffer, whichever smaller.
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter-availBefore, int(rcvBufSz)); crossed && above {
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
+ }
+
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
+ e.UnlockUser()
+ return rcvBufSz
+}
+
// SetSockOptInt sets a socket option.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
// Lower 2 bits represents ECN bits. RFC 3168, section 23.1
@@ -1767,56 +1736,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
return &tcpip.ErrNotSupported{}
}
- case tcpip.ReceiveBufferSizeOption:
- // Make sure the receive buffer size is within the min and max
- // allowed.
- var rs tcpip.TCPReceiveBufferSizeRangeOption
- if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil {
- panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &rs, err))
- }
-
- if v > rs.Max {
- v = rs.Max
- }
-
- if v < math.MaxInt32/SegOverheadFactor {
- v *= SegOverheadFactor
- if v < rs.Min {
- v = rs.Min
- }
- } else {
- v = math.MaxInt32
- }
-
- e.LockUser()
- e.rcvListMu.Lock()
-
- // Make sure the receive buffer size allows us to send a
- // non-zero window size.
- scale := uint8(0)
- if e.rcv != nil {
- scale = e.rcv.rcvWndScale
- }
- if v>>scale == 0 {
- v = 1 << scale
- }
-
- availBefore := wndFromSpace(e.receiveBufferAvailableLocked())
- e.rcvBufSize = v
- availAfter := wndFromSpace(e.receiveBufferAvailableLocked())
-
- e.rcvAutoParams.disabled = true
-
- // Immediately send an ACK to uncork the sender silly window
- // syndrome prevetion, when our available space grows above aMSS
- // or half receive buffer, whichever smaller.
- if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
- e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
- }
-
- e.rcvListMu.Unlock()
- e.UnlockUser()
-
case tcpip.TTLOption:
e.LockUser()
e.ttl = uint8(v)
@@ -1959,10 +1878,10 @@ func (e *endpoint) readyReceiveSize() (int, tcpip.Error) {
return 0, &tcpip.ErrInvalidEndpointState{}
}
- e.rcvListMu.Lock()
- defer e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ defer e.rcvQueueInfo.rcvQueueMu.Unlock()
- return e.rcvBufUsed, nil
+ return e.rcvQueueInfo.RcvBufUsed, nil
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
@@ -2002,12 +1921,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
- case tcpip.ReceiveBufferSizeOption:
- e.rcvListMu.Lock()
- v := e.rcvBufSize
- e.rcvListMu.Unlock()
- return v, nil
-
case tcpip.TTLOption:
e.LockUser()
v := int(e.ttl)
@@ -2043,15 +1956,15 @@ func (e *endpoint) getTCPInfo() tcpip.TCPInfoOption {
// the connection did not send and receive data, then RTT will
// be zero.
snd.rtt.Lock()
- info.RTT = snd.rtt.srtt
- info.RTTVar = snd.rtt.rttvar
+ info.RTT = snd.rtt.TCPRTTState.SRTT
+ info.RTTVar = snd.rtt.TCPRTTState.RTTVar
snd.rtt.Unlock()
- info.RTO = snd.rto
+ info.RTO = snd.RTO
info.CcState = snd.state
- info.SndSsthresh = uint32(snd.sndSsthresh)
- info.SndCwnd = uint32(snd.sndCwnd)
- info.ReorderSeen = snd.rc.reorderSeen
+ info.SndSsthresh = uint32(snd.Ssthresh)
+ info.SndCwnd = uint32(snd.SndCwnd)
+ info.ReorderSeen = snd.rc.Reord
}
e.UnlockUser()
return info
@@ -2096,7 +2009,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
case *tcpip.OriginalDestinationOption:
e.LockUser()
ipt := e.stack.IPTables()
- addr, port, err := ipt.OriginalDst(e.ID, e.NetProto)
+ addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto)
e.UnlockUser()
if err != nil {
return err
@@ -2204,20 +2117,20 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicID, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
+ r, err := e.stack.FindRoute(nicID, e.TransportEndpointInfo.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
if err != nil {
return err
}
defer r.Release()
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- e.ID.LocalAddress = r.LocalAddress()
- e.ID.RemoteAddress = r.RemoteAddress()
- e.ID.RemotePort = addr.Port
+ e.TransportEndpointInfo.ID.LocalAddress = r.LocalAddress()
+ e.TransportEndpointInfo.ID.RemoteAddress = r.RemoteAddress()
+ e.TransportEndpointInfo.ID.RemotePort = addr.Port
- if e.ID.LocalPort != 0 {
+ if e.TransportEndpointInfo.ID.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice)
if err != nil {
return err
}
@@ -2226,19 +2139,29 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
// one. Make sure that it isn't one that will result in the same
// address/port for both local and remote (otherwise this
// endpoint would be trying to connect to itself).
- sameAddr := e.ID.LocalAddress == e.ID.RemoteAddress
+ sameAddr := e.TransportEndpointInfo.ID.LocalAddress == e.TransportEndpointInfo.ID.RemoteAddress
// Calculate a port offset based on the destination IP/port and
// src IP to ensure that for a given tuple (srcIP, destIP,
// destPort) the offset used as a starting point is the same to
// ensure that we can cycle through the port space effectively.
- h := jenkins.Sum32(e.stack.Seed())
- h.Write([]byte(e.ID.LocalAddress))
- h.Write([]byte(e.ID.RemoteAddress))
portBuf := make([]byte, 2)
binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort)
- h.Write(portBuf)
- portOffset := uint16(h.Sum32())
+
+ h := jenkins.Sum32(e.stack.Seed())
+ for _, s := range [][]byte{
+ []byte(e.ID.LocalAddress),
+ []byte(e.ID.RemoteAddress),
+ portBuf,
+ } {
+ // Per io.Writer.Write:
+ //
+ // Write must return a non-nil error if it returns n < len(p).
+ if _, err := h.Write(s); err != nil {
+ panic(err)
+ }
+ }
+ portOffset := h.Sum32()
var twReuse tcpip.TCPTimeWaitReuseOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil {
@@ -2249,21 +2172,21 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
if twReuse == tcpip.TCPTimeWaitReuseLoopbackOnly {
switch netProto {
case header.IPv4ProtocolNumber:
- reuse = header.IsV4LoopbackAddress(e.ID.LocalAddress) && header.IsV4LoopbackAddress(e.ID.RemoteAddress)
+ reuse = header.IsV4LoopbackAddress(e.TransportEndpointInfo.ID.LocalAddress) && header.IsV4LoopbackAddress(e.TransportEndpointInfo.ID.RemoteAddress)
case header.IPv6ProtocolNumber:
- reuse = e.ID.LocalAddress == header.IPv6Loopback && e.ID.RemoteAddress == header.IPv6Loopback
+ reuse = e.TransportEndpointInfo.ID.LocalAddress == header.IPv6Loopback && e.TransportEndpointInfo.ID.RemoteAddress == header.IPv6Loopback
}
}
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, tcpip.Error) {
- if sameAddr && p == e.ID.RemotePort {
+ if sameAddr && p == e.TransportEndpointInfo.ID.RemotePort {
return false, nil
}
portRes := ports.Reservation{
Networks: netProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
+ Addr: e.TransportEndpointInfo.ID.LocalAddress,
Port: p,
Flags: e.portFlags,
BindToDevice: bindToDevice,
@@ -2273,7 +2196,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse {
return false, nil
}
- transEPID := e.ID
+ transEPID := e.TransportEndpointInfo.ID
transEPID.LocalPort = p
// Check if an endpoint is registered with demuxer in TIME-WAIT and if
// we can reuse it. If we can't find a transport endpoint then we just
@@ -2310,7 +2233,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
portRes := ports.Reservation{
Networks: netProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
+ Addr: e.TransportEndpointInfo.ID.LocalAddress,
Port: p,
Flags: e.portFlags,
BindToDevice: bindToDevice,
@@ -2321,13 +2244,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
}
}
- id := e.ID
+ id := e.TransportEndpointInfo.ID
id.LocalPort = p
if err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil {
portRes := ports.Reservation{
Networks: netProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
+ Addr: e.TransportEndpointInfo.ID.LocalAddress,
Port: p,
Flags: e.portFlags,
BindToDevice: bindToDevice,
@@ -2342,13 +2265,14 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
// Port picking successful. Save the details of
// the selected port.
- e.ID = id
+ e.TransportEndpointInfo.ID = id
e.isPortReserved = true
e.boundBindToDevice = bindToDevice
e.boundPortFlags = e.portFlags
e.boundDest = addr
return true, nil
}); err != nil {
+ e.stack.Stats().TCP.FailedPortReservations.Increment()
return err
}
}
@@ -2367,10 +2291,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
// connection setting here.
if !handshake {
e.segmentQueue.mu.Lock()
- for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} {
+ for _, l := range []segmentList{e.segmentQueue.list, e.sndQueueInfo.sndQueue, e.snd.writeList} {
for s := l.Front(); s != nil; s = s.Next() {
- s.id = e.ID
- e.sndWaker.Assert()
+ s.id = e.TransportEndpointInfo.ID
+ e.sndQueueInfo.sndWaker.Assert()
}
}
e.segmentQueue.mu.Unlock()
@@ -2412,10 +2336,10 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
// Close for read.
if e.shutdownFlags&tcpip.ShutdownRead != 0 {
// Mark read side as closed.
- e.rcvListMu.Lock()
- e.rcvClosed = true
- rcvBufUsed := e.rcvBufUsed
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ e.rcvQueueInfo.RcvClosed = true
+ rcvBufUsed := e.rcvQueueInfo.RcvBufUsed
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
// If we're fully closed and we have unread data we need to abort
// the connection with a RST.
@@ -2429,10 +2353,10 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
// Close for write.
if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- e.sndBufMu.Lock()
- if e.sndClosed {
+ e.sndQueueInfo.sndQueueMu.Lock()
+ if e.sndQueueInfo.SndClosed {
// Already closed.
- e.sndBufMu.Unlock()
+ e.sndQueueInfo.sndQueueMu.Unlock()
if e.EndpointState() == StateTimeWait {
return &tcpip.ErrNotConnected{}
}
@@ -2440,12 +2364,12 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
}
// Queue fin segment.
- s := newOutgoingSegment(e.ID, nil)
- e.sndQueue.PushBack(s)
- e.sndBufInQueue++
+ s := newOutgoingSegment(e.TransportEndpointInfo.ID, nil)
+ e.sndQueueInfo.sndQueue.PushBack(s)
+ e.sndQueueInfo.SndBufInQueue++
// Mark endpoint as closed.
- e.sndClosed = true
- e.sndBufMu.Unlock()
+ e.sndQueueInfo.SndClosed = true
+ e.sndQueueInfo.sndQueueMu.Unlock()
e.handleClose()
}
@@ -2458,9 +2382,9 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
//
// By not removing this endpoint from the demuxer mapping, we
// ensure that any other bind to the same port fails, as on Linux.
- e.rcvListMu.Lock()
- e.rcvClosed = true
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ e.rcvQueueInfo.RcvClosed = true
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
e.closePendingAcceptableConnectionsLocked()
// Notify waiters that the endpoint is shutdown.
e.waiterQueue.Notify(waiter.ReadableEvents | waiter.WritableEvents | waiter.EventHUp | waiter.EventErr)
@@ -2474,6 +2398,10 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
// Listen puts the endpoint in "listen" mode, which allows it to accept
// new connections.
func (e *endpoint) Listen(backlog int) tcpip.Error {
+ // Accept one more than the configured listen backlog to keep in parity with
+ // Linux. Ref, because of missing equality check here:
+ // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/sock.h#L937
+ backlog++
err := e.listen(backlog)
if err != nil {
if !err.IgnoreStats() {
@@ -2491,28 +2419,20 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
if e.EndpointState() == StateListen && !e.closed {
e.acceptMu.Lock()
defer e.acceptMu.Unlock()
- if e.acceptedChan == nil {
+ if e.accepted == (accepted{}) {
// listen is called after shutdown.
- e.acceptedChan = make(chan *endpoint, backlog)
+ e.accepted.cap = backlog
e.shutdownFlags = 0
- e.rcvListMu.Lock()
- e.rcvClosed = false
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ e.rcvQueueInfo.RcvClosed = false
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
} else {
- // Adjust the size of the channel iff we can fix
+ // Adjust the size of the backlog iff we can fit
// existing pending connections into the new one.
- if len(e.acceptedChan) > backlog {
+ if e.accepted.endpoints.Len() > backlog {
return &tcpip.ErrInvalidEndpointState{}
}
- if cap(e.acceptedChan) == backlog {
- return nil
- }
- origChan := e.acceptedChan
- e.acceptedChan = make(chan *endpoint, backlog)
- close(origChan)
- for ep := range origChan {
- e.acceptedChan <- ep
- }
+ e.accepted.cap = backlog
}
// Notify any blocked goroutines that they can attempt to
@@ -2538,19 +2458,19 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil {
return err
}
e.isRegistered = true
e.setEndpointState(StateListen)
- // The channel may be non-nil when we're restoring the endpoint, and it
+ // The queue may be non-zero when we're restoring the endpoint, and it
// may be pre-populated with some previously accepted (but not Accepted)
// endpoints.
e.acceptMu.Lock()
- if e.acceptedChan == nil {
- e.acceptedChan = make(chan *endpoint, backlog)
+ if e.accepted == (accepted{}) {
+ e.accepted.cap = backlog
}
e.acceptMu.Unlock()
@@ -2578,24 +2498,25 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.
e.LockUser()
defer e.UnlockUser()
- e.rcvListMu.Lock()
- rcvClosed := e.rcvClosed
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ rcvClosed := e.rcvQueueInfo.RcvClosed
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
// Endpoint must be in listen state before it can accept connections.
if rcvClosed || e.EndpointState() != StateListen {
return nil, nil, &tcpip.ErrInvalidEndpointState{}
}
// Get the new accepted endpoint.
- e.acceptMu.Lock()
- defer e.acceptMu.Unlock()
var n *endpoint
- select {
- case n = <-e.acceptedChan:
- e.acceptCond.Signal()
- default:
+ e.acceptMu.Lock()
+ if element := e.accepted.endpoints.Front(); element != nil {
+ n = e.accepted.endpoints.Remove(element).(*endpoint)
+ }
+ e.acceptMu.Unlock()
+ if n == nil {
return nil, nil, &tcpip.ErrWouldBlock{}
}
+ e.acceptCond.Signal()
if peerAddr != nil {
*peerAddr = n.getRemoteAddress()
}
@@ -2645,7 +2566,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) {
if nic == 0 {
return &tcpip.ErrBadLocalAddress{}
}
- e.ID.LocalAddress = addr.Addr
+ e.TransportEndpointInfo.ID.LocalAddress = addr.Addr
}
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
@@ -2659,7 +2580,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) {
Dest: tcpip.FullAddress{},
}
port, err := e.stack.ReservePort(portRes, func(p uint16) (bool, tcpip.Error) {
- id := e.ID
+ id := e.TransportEndpointInfo.ID
id.LocalPort = p
// CheckRegisterTransportEndpoint should only return an error if there is a
// listening endpoint bound with the same id and portFlags and bindToDevice
@@ -2675,6 +2596,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) {
return true, nil
})
if err != nil {
+ e.stack.Stats().TCP.FailedPortReservations.Increment()
return err
}
@@ -2684,7 +2606,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) {
e.boundNICID = nic
e.isPortReserved = true
e.effectiveNetProtos = netProtos
- e.ID.LocalPort = port
+ e.TransportEndpointInfo.ID.LocalPort = port
// Mark endpoint as bound.
e.setEndpointState(StateBound)
@@ -2698,8 +2620,8 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
defer e.UnlockUser()
return tcpip.FullAddress{
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: e.TransportEndpointInfo.ID.LocalAddress,
+ Port: e.TransportEndpointInfo.ID.LocalPort,
NIC: e.boundNICID,
}, nil
}
@@ -2718,8 +2640,8 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
func (e *endpoint) getRemoteAddress() tcpip.FullAddress {
return tcpip.FullAddress{
- Addr: e.ID.RemoteAddress,
- Port: e.ID.RemotePort,
+ Addr: e.TransportEndpointInfo.ID.RemoteAddress,
+ Port: e.TransportEndpointInfo.ID.RemotePort,
NIC: e.boundNICID,
}
}
@@ -2758,13 +2680,13 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p
Payload: pkt.Data().AsRange().ToOwnedView(),
Dst: tcpip.FullAddress{
NIC: pkt.NICID,
- Addr: e.ID.RemoteAddress,
- Port: e.ID.RemotePort,
+ Addr: e.TransportEndpointInfo.ID.RemoteAddress,
+ Port: e.TransportEndpointInfo.ID.RemotePort,
},
Offender: tcpip.FullAddress{
NIC: pkt.NICID,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: e.TransportEndpointInfo.ID.LocalAddress,
+ Port: e.TransportEndpointInfo.ID.LocalPort,
},
NetProto: pkt.NetworkProtocolNumber,
})
@@ -2777,12 +2699,12 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p
// HandleError implements stack.TransportEndpoint.
func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketBuffer) {
handlePacketTooBig := func(mtu uint32) {
- e.sndBufMu.Lock()
- e.packetTooBigCount++
- if v := int(mtu); v < e.sndMTU {
- e.sndMTU = v
+ e.sndQueueInfo.sndQueueMu.Lock()
+ e.sndQueueInfo.PacketTooBigCount++
+ if v := int(mtu); v < e.sndQueueInfo.SndMTU {
+ e.sndQueueInfo.SndMTU = v
}
- e.sndBufMu.Unlock()
+ e.sndQueueInfo.sndQueueMu.Unlock()
e.notifyProtocolGoroutine(notifyMTUChanged)
}
@@ -2801,14 +2723,14 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB
// in the send buffer. The number of newly available bytes is v.
func (e *endpoint) updateSndBufferUsage(v int) {
sendBufferSize := e.getSendBufferSize()
- e.sndBufMu.Lock()
- notify := e.sndBufUsed >= sendBufferSize>>1
- e.sndBufUsed -= v
+ e.sndQueueInfo.sndQueueMu.Lock()
+ notify := e.sndQueueInfo.SndBufUsed >= sendBufferSize>>1
+ e.sndQueueInfo.SndBufUsed -= v
// We only notify when there is half the sendBufferSize available after
// a full buffer event occurs. This ensures that we don't wake up
// writers to queue just 1-2 segments and go back to sleep.
- notify = notify && e.sndBufUsed < int(sendBufferSize)>>1
- e.sndBufMu.Unlock()
+ notify = notify && e.sndQueueInfo.SndBufUsed < int(sendBufferSize)>>1
+ e.sndQueueInfo.sndQueueMu.Unlock()
if notify {
e.waiterQueue.Notify(waiter.WritableEvents)
@@ -2819,58 +2741,50 @@ func (e *endpoint) updateSndBufferUsage(v int) {
// to be read, or when the connection is closed for receiving (in which case
// s will be nil).
func (e *endpoint) readyToRead(s *segment) {
- e.rcvListMu.Lock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
if s != nil {
- e.rcvBufUsed += s.payloadSize()
+ e.rcvQueueInfo.RcvBufUsed += s.payloadSize()
s.incRef()
- e.rcvList.PushBack(s)
+ e.rcvQueueInfo.rcvQueue.PushBack(s)
} else {
- e.rcvClosed = true
+ e.rcvQueueInfo.RcvClosed = true
}
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
e.waiterQueue.Notify(waiter.ReadableEvents)
}
// receiveBufferAvailableLocked calculates how many bytes are still available
// in the receive buffer.
-// rcvListMu must be held when this function is called.
-func (e *endpoint) receiveBufferAvailableLocked() int {
+// rcvQueueMu must be held when this function is called.
+func (e *endpoint) receiveBufferAvailableLocked(rcvBufSize int) int {
// We may use more bytes than the buffer size when the receive buffer
// shrinks.
memUsed := e.receiveMemUsed()
- if memUsed >= e.rcvBufSize {
+ if memUsed >= rcvBufSize {
return 0
}
- return e.rcvBufSize - memUsed
+ return rcvBufSize - memUsed
}
// receiveBufferAvailable calculates how many bytes are still available in the
// receive buffer based on the actual memory used by all segments held in
// receive buffer/pending and segment queue.
func (e *endpoint) receiveBufferAvailable() int {
- e.rcvListMu.Lock()
- available := e.receiveBufferAvailableLocked()
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ available := e.receiveBufferAvailableLocked(int(e.ops.GetReceiveBufferSize()))
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
return available
}
// receiveBufferUsed returns the amount of in-use receive buffer.
func (e *endpoint) receiveBufferUsed() int {
- e.rcvListMu.Lock()
- used := e.rcvBufUsed
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ used := e.rcvQueueInfo.RcvBufUsed
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
return used
}
-// receiveBufferSize returns the current size of the receive buffer.
-func (e *endpoint) receiveBufferSize() int {
- e.rcvListMu.Lock()
- size := e.rcvBufSize
- e.rcvListMu.Unlock()
- return size
-}
-
// receiveMemUsed returns the total memory in use by segments held by this
// endpoint.
func (e *endpoint) receiveMemUsed() int {
@@ -2899,11 +2813,11 @@ func (e *endpoint) maxReceiveBufferSize() int {
// receiveBuffer otherwise we use the max permissible receive buffer size to
// compute the scale.
func (e *endpoint) rcvWndScaleForHandshake() int {
- bufSizeForScale := e.receiveBufferSize()
+ bufSizeForScale := e.ops.GetReceiveBufferSize()
- e.rcvListMu.Lock()
- autoTuningDisabled := e.rcvAutoParams.disabled
- e.rcvListMu.Unlock()
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ autoTuningDisabled := e.rcvQueueInfo.RcvAutoParams.Disabled
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
if autoTuningDisabled {
return FindWndScale(seqnum.Size(bufSizeForScale))
}
@@ -2914,7 +2828,7 @@ func (e *endpoint) rcvWndScaleForHandshake() int {
// updateRecentTimestamp updates the recent timestamp using the algorithm
// described in https://tools.ietf.org/html/rfc7323#section-4.3
func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) {
- if e.sendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) {
+ if e.SendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) {
e.setRecentTimestamp(tsVal)
}
}
@@ -2924,7 +2838,7 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value,
// initializes the recentTS with the value provided in synOpts.TSval.
func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
if synOpts.TS {
- e.sendTSOk = true
+ e.SendTSOk = true
e.setRecentTimestamp(synOpts.TSVal)
}
}
@@ -2932,7 +2846,7 @@ func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
// timestamp returns the timestamp value to be used in the TSVal field of the
// timestamp option for outgoing TCP segments for a given endpoint.
func (e *endpoint) timestamp() uint32 {
- return tcpTimeStamp(time.Now(), e.tsOffset)
+ return tcpTimeStamp(time.Now(), e.TSOffset)
}
// tcpTimeStamp returns a timestamp offset by the provided offset. This is
@@ -2971,7 +2885,7 @@ func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
return
}
if bool(v) && synOpts.SACKPermitted {
- e.sackPermitted = true
+ e.SACKPermitted = true
}
}
@@ -2985,118 +2899,46 @@ func (e *endpoint) maxOptionSize() (size int) {
return size
}
-// completeState makes a full copy of the endpoint and returns it. This is used
-// before invoking the probe. The state returned may not be fully consistent if
-// there are intervening syscalls when the state is being copied.
-func (e *endpoint) completeState() stack.TCPEndpointState {
- var s stack.TCPEndpointState
- s.SegTime = time.Now()
-
- // Copy EndpointID.
- s.ID = stack.TCPEndpointID(e.ID)
-
- // Copy endpoint rcv state.
- e.rcvListMu.Lock()
- s.RcvBufSize = e.rcvBufSize
- s.RcvBufUsed = e.rcvBufUsed
- s.RcvClosed = e.rcvClosed
- s.RcvAutoParams.MeasureTime = e.rcvAutoParams.measureTime
- s.RcvAutoParams.CopiedBytes = e.rcvAutoParams.copied
- s.RcvAutoParams.PrevCopiedBytes = e.rcvAutoParams.prevCopied
- s.RcvAutoParams.RTT = e.rcvAutoParams.rtt
- s.RcvAutoParams.RTTMeasureSeqNumber = e.rcvAutoParams.rttMeasureSeqNumber
- s.RcvAutoParams.RTTMeasureTime = e.rcvAutoParams.rttMeasureTime
- s.RcvAutoParams.Disabled = e.rcvAutoParams.disabled
- e.rcvListMu.Unlock()
-
- // Endpoint TCP Option state.
- s.SendTSOk = e.sendTSOk
- s.RecentTS = e.recentTimestamp()
- s.TSOffset = e.tsOffset
- s.SACKPermitted = e.sackPermitted
+// completeStateLocked makes a full copy of the endpoint and returns it. This is
+// used before invoking the probe.
+//
+// Precondition: e.mu must be held.
+func (e *endpoint) completeStateLocked() stack.TCPEndpointState {
+ s := stack.TCPEndpointState{
+ TCPEndpointStateInner: e.TCPEndpointStateInner,
+ ID: stack.TCPEndpointID(e.TransportEndpointInfo.ID),
+ SegTime: time.Now(),
+ Receiver: e.rcv.TCPReceiverState,
+ Sender: e.snd.TCPSenderState,
+ }
+
+ sndBufSize := e.getSendBufferSize()
+ // Copy the send buffer atomically.
+ e.sndQueueInfo.sndQueueMu.Lock()
+ s.SndBufState = e.sndQueueInfo.TCPSndBufState
+ s.SndBufState.SndBufSize = sndBufSize
+ e.sndQueueInfo.sndQueueMu.Unlock()
+
+ // Copy the receive buffer atomically.
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ s.RcvBufState = e.rcvQueueInfo.TCPRcvBufState
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
+
+ // Copy the endpoint TCP Option state.
s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks)
copy(s.SACK.Blocks, e.sack.Blocks[:e.sack.NumBlocks])
s.SACK.ReceivedBlocks, s.SACK.MaxSACKED = e.scoreboard.Copy()
- // Copy endpoint send state.
- sndBufSize := e.getSendBufferSize()
- e.sndBufMu.Lock()
- s.SndBufSize = sndBufSize
- s.SndBufUsed = e.sndBufUsed
- s.SndClosed = e.sndClosed
- s.SndBufInQueue = e.sndBufInQueue
- s.PacketTooBigCount = e.packetTooBigCount
- s.SndMTU = e.sndMTU
- e.sndBufMu.Unlock()
-
- // Copy receiver state.
- s.Receiver = stack.TCPReceiverState{
- RcvNxt: e.rcv.rcvNxt,
- RcvAcc: e.rcv.rcvAcc,
- RcvWndScale: e.rcv.rcvWndScale,
- PendingBufUsed: e.rcv.pendingBufUsed,
- }
-
- // Copy sender state.
- s.Sender = stack.TCPSenderState{
- LastSendTime: e.snd.lastSendTime,
- DupAckCount: e.snd.dupAckCount,
- FastRecovery: stack.TCPFastRecoveryState{
- Active: e.snd.fr.active,
- First: e.snd.fr.first,
- Last: e.snd.fr.last,
- MaxCwnd: e.snd.fr.maxCwnd,
- HighRxt: e.snd.fr.highRxt,
- RescueRxt: e.snd.fr.rescueRxt,
- },
- SndCwnd: e.snd.sndCwnd,
- Ssthresh: e.snd.sndSsthresh,
- SndCAAckCount: e.snd.sndCAAckCount,
- Outstanding: e.snd.outstanding,
- SackedOut: e.snd.sackedOut,
- SndWnd: e.snd.sndWnd,
- SndUna: e.snd.sndUna,
- SndNxt: e.snd.sndNxt,
- RTTMeasureSeqNum: e.snd.rttMeasureSeqNum,
- RTTMeasureTime: e.snd.rttMeasureTime,
- Closed: e.snd.closed,
- RTO: e.snd.rto,
- MaxPayloadSize: e.snd.maxPayloadSize,
- SndWndScale: e.snd.sndWndScale,
- MaxSentAck: e.snd.maxSentAck,
- }
e.snd.rtt.Lock()
- s.Sender.SRTT = e.snd.rtt.srtt
- s.Sender.SRTTInited = e.snd.rtt.srttInited
+ s.Sender.RTTState = e.snd.rtt.TCPRTTState
e.snd.rtt.Unlock()
if cubic, ok := e.snd.cc.(*cubicState); ok {
- s.Sender.Cubic = stack.TCPCubicState{
- WMax: cubic.wMax,
- WLastMax: cubic.wLastMax,
- T: cubic.t,
- TimeSinceLastCongestion: time.Since(cubic.t),
- C: cubic.c,
- K: cubic.k,
- Beta: cubic.beta,
- WC: cubic.wC,
- WEst: cubic.wEst,
- }
+ s.Sender.Cubic = cubic.TCPCubicState
+ s.Sender.Cubic.TimeSinceLastCongestion = time.Since(s.Sender.Cubic.T)
}
- rc := &e.snd.rc
- s.Sender.RACKState = stack.TCPRACKState{
- XmitTime: rc.xmitTime,
- EndSequence: rc.endSequence,
- FACK: rc.fack,
- RTT: rc.rtt,
- Reord: rc.reorderSeen,
- DSACKSeen: rc.dsackSeen,
- ReoWnd: rc.reoWnd,
- ReoWndIncr: rc.reoWndIncr,
- ReoWndPersist: rc.reoWndPersist,
- RTTSeq: rc.rttSeq,
- }
+ s.Sender.RACKState = e.snd.rc.TCPRACKState
return s
}
@@ -3200,3 +3042,17 @@ func (e *endpoint) allowOutOfWindowAck() bool {
e.lastOutOfWindowAckTime = now
return true
}
+
+// GetTCPReceiveBufferLimits is used to get send buffer size limits for TCP.
+func GetTCPReceiveBufferLimits(s tcpip.StackHandler) tcpip.ReceiveBufferSizeOption {
+ var ss tcpip.TCPReceiveBufferSizeRangeOption
+ if err := s.TransportProtocolOption(header.TCPProtocolNumber, &ss); err != nil {
+ panic(fmt.Sprintf("s.TransportProtocolOption(%d, %#v) = %s", header.TCPProtocolNumber, ss, err))
+ }
+
+ return tcpip.ReceiveBufferSizeOption{
+ Min: ss.Min,
+ Default: ss.Default,
+ Max: ss.Max,
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index a53d76917..6e9777fe4 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -58,7 +58,7 @@ func (e *endpoint) beforeSave() {
if !e.route.HasSaveRestoreCapability() {
if !e.route.HasDisconncetOkCapability() {
panic(&tcpip.ErrSaveRejection{
- Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort),
+ Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.TransportEndpointInfo.ID.LocalAddress, e.TransportEndpointInfo.ID.LocalPort, e.TransportEndpointInfo.ID.RemoteAddress, e.TransportEndpointInfo.ID.RemotePort),
})
}
e.resetConnectionLocked(&tcpip.ErrConnectionAborted{})
@@ -67,7 +67,7 @@ func (e *endpoint) beforeSave() {
e.mu.Lock()
}
if !e.workerRunning {
- // The endpoint must be in acceptedChan or has been just
+ // The endpoint must be in the accepted queue or has been just
// disconnected and closed.
break
}
@@ -88,7 +88,7 @@ func (e *endpoint) beforeSave() {
e.mu.Lock()
}
if e.workerRunning {
- panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.ID))
+ panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.TransportEndpointInfo.ID))
}
default:
panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState()))
@@ -99,37 +99,19 @@ func (e *endpoint) beforeSave() {
}
}
-// saveAcceptedChan is invoked by stateify.
-func (e *endpoint) saveAcceptedChan() []*endpoint {
- if e.acceptedChan == nil {
- return nil
- }
- acceptedEndpoints := make([]*endpoint, len(e.acceptedChan), cap(e.acceptedChan))
- for i := 0; i < len(acceptedEndpoints); i++ {
- select {
- case ep := <-e.acceptedChan:
- acceptedEndpoints[i] = ep
- default:
- panic("endpoint acceptedChan buffer got consumed by background context")
- }
- }
- for i := 0; i < len(acceptedEndpoints); i++ {
- select {
- case e.acceptedChan <- acceptedEndpoints[i]:
- default:
- panic("endpoint acceptedChan buffer got populated by background context")
- }
+// saveEndpoints is invoked by stateify.
+func (a *accepted) saveEndpoints() []*endpoint {
+ acceptedEndpoints := make([]*endpoint, a.endpoints.Len())
+ for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() {
+ acceptedEndpoints[i] = e.Value.(*endpoint)
}
return acceptedEndpoints
}
-// loadAcceptedChan is invoked by stateify.
-func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) {
- if cap(acceptedEndpoints) > 0 {
- e.acceptedChan = make(chan *endpoint, cap(acceptedEndpoints))
- for _, ep := range acceptedEndpoints {
- e.acceptedChan <- ep
- }
+// loadEndpoints is invoked by stateify.
+func (a *accepted) loadEndpoints(acceptedEndpoints []*endpoint) {
+ for _, ep := range acceptedEndpoints {
+ a.endpoints.PushBack(ep)
}
}
@@ -183,7 +165,7 @@ func (e *endpoint) afterLoad() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
- e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits)
+ e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits)
e.segmentQueue.thaw()
epState := e.origEndpointState
switch epState {
@@ -198,14 +180,14 @@ func (e *endpoint) Resume(s *stack.Stack) {
var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
- if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max {
- panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max))
+ if rcvBufSize := e.ops.GetReceiveBufferSize(); rcvBufSize < int64(rs.Min) || rcvBufSize > int64(rs.Max) {
+ panic(fmt.Sprintf("endpoint rcvBufSize %d is outside the min and max allowed [%d, %d]", rcvBufSize, rs.Min, rs.Max))
}
}
}
bind := func() {
- addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort})
+ addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.TransportEndpointInfo.ID.LocalPort})
if err != nil {
panic("unable to parse BindAddr: " + err.String())
}
@@ -231,19 +213,19 @@ func (e *endpoint) Resume(s *stack.Stack) {
case epState.connected():
bind()
if len(e.connectingAddress) == 0 {
- e.connectingAddress = e.ID.RemoteAddress
+ e.connectingAddress = e.TransportEndpointInfo.ID.RemoteAddress
// This endpoint is accepted by netstack but not yet by
// the app. If the endpoint is IPv6 but the remote
// address is IPv4, we need to connect as IPv6 so that
// dual-stack mode can be properly activated.
- if e.NetProto == header.IPv6ProtocolNumber && len(e.ID.RemoteAddress) != header.IPv6AddressSize {
- e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.ID.RemoteAddress
+ if e.NetProto == header.IPv6ProtocolNumber && len(e.TransportEndpointInfo.ID.RemoteAddress) != header.IPv6AddressSize {
+ e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.TransportEndpointInfo.ID.RemoteAddress
}
}
// Reset the scoreboard to reinitialize the sack information as
// we do not restore SACK information.
e.scoreboard.Reset()
- err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning)
+ err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}, false, e.workerRunning)
if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
panic("endpoint connecting failed: " + err.String())
}
@@ -263,7 +245,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
go func() {
connectedLoading.Wait()
bind()
- backlog := cap(e.acceptedChan)
+ backlog := e.accepted.cap
if err := e.Listen(backlog); err != nil {
panic("endpoint listening failed: " + err.String())
}
@@ -281,7 +263,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
connectedLoading.Wait()
listenLoading.Wait()
bind()
- err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort})
+ err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort})
if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
panic("endpoint connecting failed: " + err.String())
}
@@ -328,23 +310,3 @@ func (e *endpoint) saveLastOutOfWindowAckTime() unixTime {
func (e *endpoint) loadLastOutOfWindowAckTime(unix unixTime) {
e.lastOutOfWindowAckTime = time.Unix(unix.second, unix.nano)
}
-
-// saveMeasureTime is invoked by stateify.
-func (r *rcvBufAutoTuneParams) saveMeasureTime() unixTime {
- return unixTime{r.measureTime.Unix(), r.measureTime.UnixNano()}
-}
-
-// loadMeasureTime is invoked by stateify.
-func (r *rcvBufAutoTuneParams) loadMeasureTime(unix unixTime) {
- r.measureTime = time.Unix(unix.second, unix.nano)
-}
-
-// saveRttMeasureTime is invoked by stateify.
-func (r *rcvBufAutoTuneParams) saveRttMeasureTime() unixTime {
- return unixTime{r.rttMeasureTime.Unix(), r.rttMeasureTime.UnixNano()}
-}
-
-// loadRttMeasureTime is invoked by stateify.
-func (r *rcvBufAutoTuneParams) loadRttMeasureTime(unix unixTime) {
- r.rttMeasureTime = time.Unix(unix.second, unix.nano)
-}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 2a4667906..fe0d7f10f 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -75,63 +75,6 @@ const (
ccCubic = "cubic"
)
-// syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The
-// value is protected by a mutex so that we can increment only when it's
-// guaranteed not to go above a threshold.
-type synRcvdCounter struct {
- sync.Mutex
- value uint64
- pending sync.WaitGroup
- threshold uint64
-}
-
-// inc tries to increment the global number of endpoints in SYN-RCVD state. It
-// succeeds if the increment doesn't make the count go beyond the threshold, and
-// fails otherwise.
-func (s *synRcvdCounter) inc() bool {
- s.Lock()
- defer s.Unlock()
- if s.value >= s.threshold {
- return false
- }
-
- s.pending.Add(1)
- s.value++
-
- return true
-}
-
-// dec atomically decrements the global number of endpoints in SYN-RCVD
-// state. It must only be called if a previous call to inc succeeded.
-func (s *synRcvdCounter) dec() {
- s.Lock()
- defer s.Unlock()
- s.value--
- s.pending.Done()
-}
-
-// synCookiesInUse returns true if the synRcvdCount is greater than
-// SynRcvdCountThreshold.
-func (s *synRcvdCounter) synCookiesInUse() bool {
- s.Lock()
- defer s.Unlock()
- return s.value >= s.threshold
-}
-
-// SetThreshold sets synRcvdCounter.Threshold to ths new threshold.
-func (s *synRcvdCounter) SetThreshold(threshold uint64) {
- s.Lock()
- defer s.Unlock()
- s.threshold = threshold
-}
-
-// Threshold returns the current value of synRcvdCounter.Threhsold.
-func (s *synRcvdCounter) Threshold() uint64 {
- s.Lock()
- defer s.Unlock()
- return s.threshold
-}
-
type protocol struct {
stack *stack.Stack
@@ -139,6 +82,7 @@ type protocol struct {
sackEnabled bool
recovery tcpip.TCPRecovery
delayEnabled bool
+ alwaysUseSynCookies bool
sendBufferSize tcpip.TCPSendBufferSizeRangeOption
recvBufferSize tcpip.TCPReceiveBufferSizeRangeOption
congestionControl string
@@ -150,7 +94,6 @@ type protocol struct {
minRTO time.Duration
maxRTO time.Duration
maxRetries uint32
- synRcvdCount synRcvdCounter
synRetries uint8
dispatcher dispatcher
}
@@ -373,9 +316,9 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip
p.mu.Unlock()
return nil
- case *tcpip.TCPSynRcvdCountThresholdOption:
+ case *tcpip.TCPAlwaysUseSynCookies:
p.mu.Lock()
- p.synRcvdCount.SetThreshold(uint64(*v))
+ p.alwaysUseSynCookies = bool(*v)
p.mu.Unlock()
return nil
@@ -480,9 +423,9 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Er
p.mu.RUnlock()
return nil
- case *tcpip.TCPSynRcvdCountThresholdOption:
+ case *tcpip.TCPAlwaysUseSynCookies:
p.mu.RLock()
- *v = tcpip.TCPSynRcvdCountThresholdOption(p.synRcvdCount.Threshold())
+ *v = tcpip.TCPAlwaysUseSynCookies(p.alwaysUseSynCookies)
p.mu.RUnlock()
return nil
@@ -507,12 +450,6 @@ func (p *protocol) Wait() {
p.dispatcher.wait()
}
-// SynRcvdCounter returns a reference to the synRcvdCount for this protocol
-// instance.
-func (p *protocol) SynRcvdCounter() *synRcvdCounter {
- return &p.synRcvdCount
-}
-
// Parse implements stack.TransportProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
return parse.TCP(pkt)
@@ -537,7 +474,6 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol {
lingerTimeout: DefaultTCPLingerTimeout,
timeWaitTimeout: DefaultTCPTimeWaitTimeout,
timeWaitReuse: tcpip.TCPTimeWaitReuseLoopbackOnly,
- synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold},
synRetries: DefaultSynRetries,
minRTO: MinRTO,
maxRTO: MaxRTO,
diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go
index 0a0d5f7a1..9e332dcf7 100644
--- a/pkg/tcpip/transport/tcp/rack.go
+++ b/pkg/tcpip/transport/tcp/rack.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
@@ -46,54 +47,16 @@ const (
//
// +stateify savable
type rackControl struct {
- // dsackSeen indicates if the connection has seen a DSACK.
- dsackSeen bool
-
- // endSequence is the ending TCP sequence number of the most recent
- // acknowledged segment.
- endSequence seqnum.Value
+ stack.TCPRACKState
// exitedRecovery indicates if the connection is exiting loss recovery.
// This flag is set if the sender is leaving the recovery after
// receiving an ACK and is reset during updating of reorder window.
exitedRecovery bool
- // fack is the highest selectively or cumulatively acknowledged
- // sequence.
- fack seqnum.Value
-
// minRTT is the estimated minimum RTT of the connection.
minRTT time.Duration
- // reorderSeen indicates if reordering has been detected on this
- // connection.
- reorderSeen bool
-
- // reoWnd is the reordering window time used for recording packet
- // transmission times. It is used to defer the moment at which RACK
- // marks a packet lost.
- reoWnd time.Duration
-
- // reoWndIncr is the multiplier applied to adjust reorder window.
- reoWndIncr uint8
-
- // reoWndPersist is the number of loss recoveries before resetting
- // reorder window.
- reoWndPersist int8
-
- // rtt is the RTT of the most recently delivered packet on the
- // connection (either cumulatively acknowledged or selectively
- // acknowledged) that was not marked invalid as a possible spurious
- // retransmission.
- rtt time.Duration
-
- // rttSeq is the SND.NXT when rtt is updated.
- rttSeq seqnum.Value
-
- // xmitTime is the latest transmission timestamp of the most recent
- // acknowledged segment.
- xmitTime time.Time `state:".(unixTime)"`
-
// tlpRxtOut indicates whether there is an unacknowledged
// TLP retransmission.
tlpRxtOut bool
@@ -108,8 +71,8 @@ type rackControl struct {
// init initializes RACK specific fields.
func (rc *rackControl) init(snd *sender, iss seqnum.Value) {
- rc.fack = iss
- rc.reoWndIncr = 1
+ rc.FACK = iss
+ rc.ReoWndIncr = 1
rc.snd = snd
}
@@ -117,7 +80,7 @@ func (rc *rackControl) init(snd *sender, iss seqnum.Value) {
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2
func (rc *rackControl) update(seg *segment, ackSeg *segment) {
rtt := time.Now().Sub(seg.xmitTime)
- tsOffset := rc.snd.ep.tsOffset
+ tsOffset := rc.snd.ep.TSOffset
// If the ACK is for a retransmitted packet, do not update if it is a
// spurious inference which is determined by below checks:
@@ -138,7 +101,7 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) {
}
}
- rc.rtt = rtt
+ rc.RTT = rtt
// The sender can either track a simple global minimum of all RTT
// measurements from the connection, or a windowed min-filtered value
@@ -152,9 +115,9 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) {
// ending sequence number of the packet which has been acknowledged
// most recently.
endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
- if rc.xmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) {
- rc.xmitTime = seg.xmitTime
- rc.endSequence = endSeq
+ if rc.XmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) {
+ rc.XmitTime = seg.xmitTime
+ rc.EndSequence = endSeq
}
}
@@ -171,18 +134,18 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) {
// is identified.
func (rc *rackControl) detectReorder(seg *segment) {
endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
- if rc.fack.LessThan(endSeq) {
- rc.fack = endSeq
+ if rc.FACK.LessThan(endSeq) {
+ rc.FACK = endSeq
return
}
- if endSeq.LessThan(rc.fack) && seg.xmitCount == 1 {
- rc.reorderSeen = true
+ if endSeq.LessThan(rc.FACK) && seg.xmitCount == 1 {
+ rc.Reord = true
}
}
func (rc *rackControl) setDSACKSeen(dsackSeen bool) {
- rc.dsackSeen = dsackSeen
+ rc.DSACKSeen = dsackSeen
}
// shouldSchedulePTO dictates whether we should schedule a PTO or not.
@@ -191,7 +154,7 @@ func (s *sender) shouldSchedulePTO() bool {
// Schedule PTO only if RACK loss detection is enabled.
return s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 &&
// The connection supports SACK.
- s.ep.sackPermitted &&
+ s.ep.SACKPermitted &&
// The connection is not in loss recovery.
(s.state != tcpip.RTORecovery && s.state != tcpip.SACKRecovery) &&
// The connection has no SACKed sequences in the SACK scoreboard.
@@ -203,9 +166,9 @@ func (s *sender) shouldSchedulePTO() bool {
func (s *sender) schedulePTO() {
pto := time.Second
s.rtt.Lock()
- if s.rtt.srttInited && s.rtt.srtt > 0 {
- pto = s.rtt.srtt * 2
- if s.outstanding == 1 {
+ if s.rtt.TCPRTTState.SRTTInited && s.rtt.TCPRTTState.SRTT > 0 {
+ pto = s.rtt.TCPRTTState.SRTT * 2
+ if s.Outstanding == 1 {
pto += wcDelayedACKTimeout
}
}
@@ -230,10 +193,10 @@ func (s *sender) probeTimerExpired() tcpip.Error {
}
var dataSent bool
- if s.writeNext != nil && s.writeNext.xmitCount == 0 && s.outstanding < s.sndCwnd {
- dataSent = s.maybeSendSegment(s.writeNext, int(s.ep.scoreboard.SMSS()), s.sndUna.Add(s.sndWnd))
+ if s.writeNext != nil && s.writeNext.xmitCount == 0 && s.Outstanding < s.SndCwnd {
+ dataSent = s.maybeSendSegment(s.writeNext, int(s.ep.scoreboard.SMSS()), s.SndUna.Add(s.SndWnd))
if dataSent {
- s.outstanding += s.pCount(s.writeNext, s.maxPayloadSize)
+ s.Outstanding += s.pCount(s.writeNext, s.MaxPayloadSize)
s.writeNext = s.writeNext.Next()
}
}
@@ -255,10 +218,10 @@ func (s *sender) probeTimerExpired() tcpip.Error {
}
if highestSeqXmit != nil {
- dataSent = s.maybeSendSegment(highestSeqXmit, int(s.ep.scoreboard.SMSS()), s.sndUna.Add(s.sndWnd))
+ dataSent = s.maybeSendSegment(highestSeqXmit, int(s.ep.scoreboard.SMSS()), s.SndUna.Add(s.SndWnd))
if dataSent {
s.rc.tlpRxtOut = true
- s.rc.tlpHighRxt = s.sndNxt
+ s.rc.tlpHighRxt = s.SndNxt
}
}
}
@@ -274,7 +237,7 @@ func (s *sender) probeTimerExpired() tcpip.Error {
// and updates TLP state accordingly.
// See https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.6.3.
func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) {
- if !(s.ep.sackPermitted && s.rc.tlpRxtOut) {
+ if !(s.ep.SACKPermitted && s.rc.tlpRxtOut) {
return
}
@@ -317,13 +280,13 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) {
// retransmit quickly, or when the number of DUPACKs exceeds the classic
// DUPACKthreshold.
func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) {
- dsackSeen := rc.dsackSeen
+ dsackSeen := rc.DSACKSeen
snd := rc.snd
// React to DSACK once per round trip.
// If SND.UNA < RACK.rtt_seq:
// RACK.dsack = false
- if snd.sndUna.LessThan(rc.rttSeq) {
+ if snd.SndUna.LessThan(rc.RTTSeq) {
dsackSeen = false
}
@@ -333,18 +296,18 @@ func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) {
// RACK.rtt_seq = SND.NXT
// RACK.reo_wnd_persist = 16
if dsackSeen {
- rc.reoWndIncr++
+ rc.ReoWndIncr++
dsackSeen = false
- rc.rttSeq = snd.sndNxt
- rc.reoWndPersist = tcpRACKRecoveryThreshold
+ rc.RTTSeq = snd.SndNxt
+ rc.ReoWndPersist = tcpRACKRecoveryThreshold
} else if rc.exitedRecovery {
// Else if exiting loss recovery:
// RACK.reo_wnd_persist -= 1
// If RACK.reo_wnd_persist <= 0:
// RACK.reo_wnd_incr = 1
- rc.reoWndPersist--
- if rc.reoWndPersist <= 0 {
- rc.reoWndIncr = 1
+ rc.ReoWndPersist--
+ if rc.ReoWndPersist <= 0 {
+ rc.ReoWndIncr = 1
}
rc.exitedRecovery = false
}
@@ -358,14 +321,14 @@ func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) {
// Else if RACK.pkts_sacked >= RACK.dupthresh:
// RACK.reo_wnd = 0
// return
- if !rc.reorderSeen {
+ if !rc.Reord {
if snd.state == tcpip.RTORecovery || snd.state == tcpip.SACKRecovery {
- rc.reoWnd = 0
+ rc.ReoWnd = 0
return
}
- if snd.sackedOut >= nDupAckThreshold {
- rc.reoWnd = 0
+ if snd.SackedOut >= nDupAckThreshold {
+ rc.ReoWnd = 0
return
}
}
@@ -374,11 +337,11 @@ func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) {
// RACK.reo_wnd = RACK.min_RTT / 4 * RACK.reo_wnd_incr
// RACK.reo_wnd = min(RACK.reo_wnd, SRTT)
snd.rtt.Lock()
- srtt := snd.rtt.srtt
+ srtt := snd.rtt.TCPRTTState.SRTT
snd.rtt.Unlock()
- rc.reoWnd = time.Duration((int64(rc.minRTT) / 4) * int64(rc.reoWndIncr))
- if srtt < rc.reoWnd {
- rc.reoWnd = srtt
+ rc.ReoWnd = time.Duration((int64(rc.minRTT) / 4) * int64(rc.ReoWndIncr))
+ if srtt < rc.ReoWnd {
+ rc.ReoWnd = srtt
}
}
@@ -403,8 +366,8 @@ func (rc *rackControl) detectLoss(rcvTime time.Time) int {
}
endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
- if seg.xmitTime.Before(rc.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) {
- timeRemaining := seg.xmitTime.Sub(rcvTime) + rc.rtt + rc.reoWnd
+ if seg.xmitTime.Before(rc.XmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) {
+ timeRemaining := seg.xmitTime.Sub(rcvTime) + rc.RTT + rc.ReoWnd
if timeRemaining <= 0 {
seg.lost = true
numLost++
@@ -435,7 +398,7 @@ func (rc *rackControl) reorderTimerExpired() tcpip.Error {
}
fastRetransmit := false
- if !rc.snd.fr.active {
+ if !rc.snd.FastRecovery.Active {
rc.snd.cc.HandleLossDetected()
rc.snd.enterRecovery()
fastRetransmit = true
@@ -471,15 +434,15 @@ func (rc *rackControl) DoRecovery(_ *segment, fastRetransmit bool) {
}
// Check the congestion window after entering recovery.
- if snd.outstanding >= snd.sndCwnd {
+ if snd.Outstanding >= snd.SndCwnd {
break
}
- if sent := snd.maybeSendSegment(seg, int(snd.ep.scoreboard.SMSS()), snd.sndUna.Add(snd.sndWnd)); !sent {
+ if sent := snd.maybeSendSegment(seg, int(snd.ep.scoreboard.SMSS()), snd.SndUna.Add(snd.SndWnd)); !sent {
break
}
dataSent = true
- snd.outstanding += snd.pCount(seg, snd.maxPayloadSize)
+ snd.Outstanding += snd.pCount(seg, snd.MaxPayloadSize)
}
snd.postXmit(dataSent, true /* shouldScheduleProbe */)
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index bc6793fc6..ee2c08cd6 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
// receiver holds the state necessary to receive TCP segments and turn them
@@ -29,26 +30,15 @@ import (
//
// +stateify savable
type receiver struct {
+ stack.TCPReceiverState
ep *endpoint
- rcvNxt seqnum.Value
-
- // rcvAcc is one beyond the last acceptable sequence number. That is,
- // the "largest" sequence value that the receiver has announced to the
- // its peer that it's willing to accept. This may be different than
- // rcvNxt + rcvWnd if the receive window is reduced; in that case we
- // have to reduce the window as we receive more data instead of
- // shrinking it.
- rcvAcc seqnum.Value
-
// rcvWnd is the non-scaled receive window last advertised to the peer.
rcvWnd seqnum.Size
- // rcvWUP is the rcvNxt value at the last window update sent.
+ // rcvWUP is the RcvNxt value at the last window update sent.
rcvWUP seqnum.Value
- rcvWndScale uint8
-
// prevBufused is the snapshot of endpoint rcvBufUsed taken when we
// advertise a receive window.
prevBufUsed int
@@ -58,9 +48,6 @@ type receiver struct {
// pendingRcvdSegments is bounded by the receive buffer size of the
// endpoint.
pendingRcvdSegments segmentHeap
- // pendingBufUsed tracks the total number of bytes (including segment
- // overhead) currently queued in pendingRcvdSegments.
- pendingBufUsed int
// Time when the last ack was received.
lastRcvdAckTime time.Time `state:".(unixTime)"`
@@ -68,12 +55,14 @@ type receiver struct {
func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
return &receiver{
- ep: ep,
- rcvNxt: irs + 1,
- rcvAcc: irs.Add(rcvWnd + 1),
+ ep: ep,
+ TCPReceiverState: stack.TCPReceiverState{
+ RcvNxt: irs + 1,
+ RcvAcc: irs.Add(rcvWnd + 1),
+ RcvWndScale: rcvWndScale,
+ },
rcvWnd: rcvWnd,
rcvWUP: irs + 1,
- rcvWndScale: rcvWndScale,
lastRcvdAckTime: time.Now(),
}
}
@@ -84,34 +73,34 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
// r.rcvWnd could be much larger than the window size we advertised in our
// outgoing packets, we should use what we have advertised for acceptability
// test.
- scaledWindowSize := r.rcvWnd >> r.rcvWndScale
+ scaledWindowSize := r.rcvWnd >> r.RcvWndScale
if scaledWindowSize > math.MaxUint16 {
// This is what we actually put in the Window field.
scaledWindowSize = math.MaxUint16
}
- advertisedWindowSize := scaledWindowSize << r.rcvWndScale
- return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize))
+ advertisedWindowSize := scaledWindowSize << r.RcvWndScale
+ return header.Acceptable(segSeq, segLen, r.RcvNxt, r.RcvNxt.Add(advertisedWindowSize))
}
// currentWindow returns the available space in the window that was advertised
// last to our peer.
func (r *receiver) currentWindow() (curWnd seqnum.Size) {
endOfWnd := r.rcvWUP.Add(r.rcvWnd)
- if endOfWnd.LessThan(r.rcvNxt) {
- // return 0 if r.rcvNxt is past the end of the previously advertised window.
+ if endOfWnd.LessThan(r.RcvNxt) {
+ // return 0 if r.RcvNxt is past the end of the previously advertised window.
// This can happen because we accept a large segment completely even if
// accepting it causes it to partially exceed the advertised window.
return 0
}
- return r.rcvNxt.Size(endOfWnd)
+ return r.RcvNxt.Size(endOfWnd)
}
// getSendParams returns the parameters needed by the sender when building
// segments to send.
-func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
+func (r *receiver) getSendParams() (RcvNxt seqnum.Value, rcvWnd seqnum.Size) {
newWnd := r.ep.selectWindow()
curWnd := r.currentWindow()
- unackLen := int(r.ep.snd.maxSentAck.Size(r.rcvNxt))
+ unackLen := int(r.ep.snd.MaxSentAck.Size(r.RcvNxt))
bufUsed := r.ep.receiveBufferUsed()
// Grow the right edge of the window only for payloads larger than the
@@ -139,18 +128,18 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// edge, as we are still advertising a window that we think can be serviced.
toGrow := unackLen >= SegSize || bufUsed <= r.prevBufUsed
- // Update rcvAcc only if new window is > previously advertised window. We
+ // Update RcvAcc only if new window is > previously advertised window. We
// should never shrink the acceptable sequence space once it has been
// advertised the peer. If we shrink the acceptable sequence space then we
// would end up dropping bytes that might already be in flight.
// ==================================================== sequence space.
// ^ ^ ^ ^
- // rcvWUP rcvNxt rcvAcc new rcvAcc
+ // rcvWUP RcvNxt RcvAcc new RcvAcc
// <=====curWnd ===>
// <========= newWnd > curWnd ========= >
- if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) && toGrow {
- // If the new window moves the right edge, then update rcvAcc.
- r.rcvAcc = r.rcvNxt.Add(seqnum.Size(newWnd))
+ if r.RcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.RcvNxt.Add(seqnum.Size(newWnd))) && toGrow {
+ // If the new window moves the right edge, then update RcvAcc.
+ r.RcvAcc = r.RcvNxt.Add(seqnum.Size(newWnd))
} else {
if newWnd == 0 {
// newWnd is zero but we can't advertise a zero as it would cause window
@@ -162,9 +151,9 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// Stash away the non-scaled receive window as we use it for measuring
// receiver's estimated RTT.
r.rcvWnd = newWnd
- r.rcvWUP = r.rcvNxt
+ r.rcvWUP = r.RcvNxt
r.prevBufUsed = bufUsed
- scaledWnd := r.rcvWnd >> r.rcvWndScale
+ scaledWnd := r.rcvWnd >> r.RcvWndScale
if scaledWnd == 0 {
// Increment a metric if we are advertising an actual zero window.
r.ep.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
@@ -177,9 +166,9 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// Ensure that the stashed receive window always reflects what
// is being advertised.
- r.rcvWnd = scaledWnd << r.rcvWndScale
+ r.rcvWnd = scaledWnd << r.RcvWndScale
}
- return r.rcvNxt, scaledWnd
+ return r.RcvNxt, scaledWnd
}
// nonZeroWindow is called when the receive window grows from zero to nonzero;
@@ -201,13 +190,13 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// If the segment doesn't include the seqnum we're expecting to
// consume now, we're missing a segment. We cannot proceed until
// we receive that segment though.
- if !r.rcvNxt.InWindow(segSeq, segLen) {
+ if !r.RcvNxt.InWindow(segSeq, segLen) {
return false
}
// Trim segment to eliminate already acknowledged data.
- if segSeq.LessThan(r.rcvNxt) {
- diff := segSeq.Size(r.rcvNxt)
+ if segSeq.LessThan(r.RcvNxt) {
+ diff := segSeq.Size(r.RcvNxt)
segLen -= diff
segSeq.UpdateForward(diff)
s.sequenceNumber.UpdateForward(diff)
@@ -217,35 +206,35 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// Move segment to ready-to-deliver list. Wakeup any waiters.
r.ep.readyToRead(s)
- } else if segSeq != r.rcvNxt {
+ } else if segSeq != r.RcvNxt {
return false
}
// Update the segment that we're expecting to consume.
- r.rcvNxt = segSeq.Add(segLen)
+ r.RcvNxt = segSeq.Add(segLen)
// In cases of a misbehaving sender which could send more than the
// advertised window, we could end up in a situation where we get a
// segment that exceeds the window advertised. Instead of partially
// accepting the segment and discarding bytes beyond the advertised
- // window, we accept the whole segment and make sure r.rcvAcc is moved
- // forward to match r.rcvNxt to indicate that the window is now closed.
+ // window, we accept the whole segment and make sure r.RcvAcc is moved
+ // forward to match r.RcvNxt to indicate that the window is now closed.
//
// In absence of this check the r.acceptable() check fails and accepts
// segments that should be dropped because rcvWnd is calculated as
- // the size of the interval (rcvNxt, rcvAcc] which becomes extremely
- // large if rcvAcc is ever less than rcvNxt.
- if r.rcvAcc.LessThan(r.rcvNxt) {
- r.rcvAcc = r.rcvNxt
+ // the size of the interval (RcvNxt, RcvAcc] which becomes extremely
+ // large if RcvAcc is ever less than RcvNxt.
+ if r.RcvAcc.LessThan(r.RcvNxt) {
+ r.RcvAcc = r.RcvNxt
}
// Trim SACK Blocks to remove any SACK information that covers
// sequence numbers that have been consumed.
- TrimSACKBlockList(&r.ep.sack, r.rcvNxt)
+ TrimSACKBlockList(&r.ep.sack, r.RcvNxt)
// Handle FIN or FIN-ACK.
if s.flagIsSet(header.TCPFlagFin) {
- r.rcvNxt++
+ r.RcvNxt++
// Send ACK immediately.
r.ep.snd.sendAck()
@@ -260,7 +249,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
case StateEstablished:
r.ep.setEndpointState(StateCloseWait)
case StateFinWait1:
- if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt {
+ if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt {
// FIN-ACK, transition to TIME-WAIT.
r.ep.setEndpointState(StateTimeWait)
} else {
@@ -280,7 +269,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
}
for i := first; i < len(r.pendingRcvdSegments); i++ {
- r.pendingBufUsed -= r.pendingRcvdSegments[i].segMemSize()
+ r.PendingBufUsed -= r.pendingRcvdSegments[i].segMemSize()
r.pendingRcvdSegments[i].decRef()
// Note that slice truncation does not allow garbage collection of
@@ -295,7 +284,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// Handle ACK (not FIN-ACK, which we handled above) during one of the
// shutdown states.
- if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt {
+ if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt {
switch r.ep.EndpointState() {
case StateFinWait1:
r.ep.setEndpointState(StateFinWait2)
@@ -323,40 +312,40 @@ func (r *receiver) updateRTT() {
// estimate the round-trip time by observing the time between when a byte
// is first acknowledged and the receipt of data that is at least one
// window beyond the sequence number that was acknowledged.
- r.ep.rcvListMu.Lock()
- if r.ep.rcvAutoParams.rttMeasureTime.IsZero() {
+ r.ep.rcvQueueInfo.rcvQueueMu.Lock()
+ if r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime.IsZero() {
// New measurement.
- r.ep.rcvAutoParams.rttMeasureTime = time.Now()
- r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd)
- r.ep.rcvListMu.Unlock()
+ r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now()
+ r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd)
+ r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
return
}
- if r.rcvNxt.LessThan(r.ep.rcvAutoParams.rttMeasureSeqNumber) {
- r.ep.rcvListMu.Unlock()
+ if r.RcvNxt.LessThan(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber) {
+ r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
return
}
- rtt := time.Since(r.ep.rcvAutoParams.rttMeasureTime)
+ rtt := time.Since(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime)
// We only store the minimum observed RTT here as this is only used in
// absence of a SRTT available from either timestamps or a sender
// measurement of RTT.
- if r.ep.rcvAutoParams.rtt == 0 || rtt < r.ep.rcvAutoParams.rtt {
- r.ep.rcvAutoParams.rtt = rtt
+ if r.ep.rcvQueueInfo.RcvAutoParams.RTT == 0 || rtt < r.ep.rcvQueueInfo.RcvAutoParams.RTT {
+ r.ep.rcvQueueInfo.RcvAutoParams.RTT = rtt
}
- r.ep.rcvAutoParams.rttMeasureTime = time.Now()
- r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd)
- r.ep.rcvListMu.Unlock()
+ r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now()
+ r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd)
+ r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
}
func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err tcpip.Error) {
- r.ep.rcvListMu.Lock()
- rcvClosed := r.ep.rcvClosed || r.closed
- r.ep.rcvListMu.Unlock()
+ r.ep.rcvQueueInfo.rcvQueueMu.Lock()
+ rcvClosed := r.ep.rcvQueueInfo.RcvClosed || r.closed
+ r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
// If we are in one of the shutdown states then we need to do
// additional checks before we try and process the segment.
switch state {
case StateCloseWait, StateClosing, StateLastAck:
- if !s.sequenceNumber.LessThanEq(r.rcvNxt) {
+ if !s.sequenceNumber.LessThanEq(r.RcvNxt) {
// Just drop the segment as we have
// already received a FIN and this
// segment is after the sequence number
@@ -384,17 +373,17 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// The ESTABLISHED state processing is here where if the ACK check
// fails, we ignore the packet:
// https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L5591
- if r.ep.snd.sndNxt.LessThan(s.ackNumber) {
+ if r.ep.snd.SndNxt.LessThan(s.ackNumber) {
r.ep.snd.maybeSendOutOfWindowAck(s)
return true, nil
}
// If we are closed for reads (either due to an
// incoming FIN or the user calling shutdown(..,
- // SHUT_RD) then any data past the rcvNxt should
+ // SHUT_RD) then any data past the RcvNxt should
// trigger a RST.
endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
- if state != StateCloseWait && rcvClosed && r.rcvNxt.LessThan(endDataSeq) {
+ if state != StateCloseWait && rcvClosed && r.RcvNxt.LessThan(endDataSeq) {
return true, &tcpip.ErrConnectionAborted{}
}
if state == StateFinWait1 {
@@ -403,7 +392,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// If it's a retransmission of an old data segment
// or a pure ACK then allow it.
- if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) ||
+ if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.RcvNxt) ||
s.logicalLen() == 0 {
break
}
@@ -413,7 +402,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// then the only acceptable segment is a
// FIN. Since FIN can technically also carry
// data we verify that the segment carrying a
- // FIN ends at exactly e.rcvNxt+1.
+ // FIN ends at exactly e.RcvNxt+1.
//
// From RFC793 page 25.
//
@@ -423,7 +412,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// while the FIN is considered to occur after
// the last actual data octet in a segment in
// which it occurs.
- if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) {
+ if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.RcvNxt+1) {
return true, &tcpip.ErrConnectionAborted{}
}
}
@@ -435,7 +424,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// end has closed and the peer is yet to send a FIN. Hence we
// compare only the payload.
segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
- if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) {
+ if rcvClosed && !segEnd.LessThanEq(r.RcvNxt) {
return true, nil
}
return false, nil
@@ -477,13 +466,13 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) {
// segments. This ensures that we always leave some space for the inorder
// segments to arrive allowing pending segments to be processed and
// delivered to the user.
- if r.ep.receiveBufferAvailable() > 0 && r.pendingBufUsed < r.ep.receiveBufferSize()>>2 {
- r.ep.rcvListMu.Lock()
- r.pendingBufUsed += s.segMemSize()
- r.ep.rcvListMu.Unlock()
+ if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && r.PendingBufUsed < int(rcvBufSize)>>2 {
+ r.ep.rcvQueueInfo.rcvQueueMu.Lock()
+ r.PendingBufUsed += s.segMemSize()
+ r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
s.incRef()
heap.Push(&r.pendingRcvdSegments, s)
- UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
+ UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.RcvNxt)
}
// Immediately send an ack so that the peer knows it may
@@ -508,15 +497,15 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) {
segSeq := s.sequenceNumber
// Skip segment altogether if it has already been acknowledged.
- if !segSeq.Add(segLen-1).LessThan(r.rcvNxt) &&
+ if !segSeq.Add(segLen-1).LessThan(r.RcvNxt) &&
!r.consumeSegment(s, segSeq, segLen) {
break
}
heap.Pop(&r.pendingRcvdSegments)
- r.ep.rcvListMu.Lock()
- r.pendingBufUsed -= s.segMemSize()
- r.ep.rcvListMu.Unlock()
+ r.ep.rcvQueueInfo.rcvQueueMu.Lock()
+ r.PendingBufUsed -= s.segMemSize()
+ r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
s.decRef()
}
return false, nil
@@ -558,7 +547,7 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn
// (2) returns to TIME-WAIT state if the SYN turns out
// to be an old duplicate".
- if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) {
+ if s.flagIsSet(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) {
return false, true
}
@@ -569,11 +558,11 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn
}
// Update Timestamp if required. See RFC7323, section-4.3.
- if r.ep.sendTSOk && s.parsedOptions.TS {
- r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq)
+ if r.ep.SendTSOk && s.parsedOptions.TS {
+ r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.MaxSentAck, segSeq)
}
- if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) {
+ if segSeq.Add(1) == r.RcvNxt && s.flagIsSet(header.TCPFlagFin) {
// If it's a FIN-ACK then resetTimeWait and send an ACK, as it
// indicates our final ACK could have been lost.
r.ep.snd.sendAck()
@@ -584,8 +573,8 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn
// carries data then just send an ACK. This is according to RFC 793,
// page 37.
//
- // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt.
- if segSeq != r.rcvNxt || segLen != 0 {
+ // NOTE: In TIME_WAIT the only acceptable sequence number is RcvNxt.
+ if segSeq != r.RcvNxt || segLen != 0 {
r.ep.snd.sendAck()
}
return false, false
diff --git a/pkg/tcpip/transport/tcp/reno.go b/pkg/tcpip/transport/tcp/reno.go
index ff39780a5..063552c7f 100644
--- a/pkg/tcpip/transport/tcp/reno.go
+++ b/pkg/tcpip/transport/tcp/reno.go
@@ -34,14 +34,14 @@ func newRenoCC(s *sender) *renoState {
func (r *renoState) updateSlowStart(packetsAcked int) int {
// Don't let the congestion window cross into the congestion
// avoidance range.
- newcwnd := r.s.sndCwnd + packetsAcked
- if newcwnd >= r.s.sndSsthresh {
- newcwnd = r.s.sndSsthresh
- r.s.sndCAAckCount = 0
+ newcwnd := r.s.SndCwnd + packetsAcked
+ if newcwnd >= r.s.Ssthresh {
+ newcwnd = r.s.Ssthresh
+ r.s.SndCAAckCount = 0
}
- packetsAcked -= newcwnd - r.s.sndCwnd
- r.s.sndCwnd = newcwnd
+ packetsAcked -= newcwnd - r.s.SndCwnd
+ r.s.SndCwnd = newcwnd
return packetsAcked
}
@@ -49,19 +49,19 @@ func (r *renoState) updateSlowStart(packetsAcked int) int {
// avoidance mode as described in RFC5681 section 3.1
func (r *renoState) updateCongestionAvoidance(packetsAcked int) {
// Consume the packets in congestion avoidance mode.
- r.s.sndCAAckCount += packetsAcked
- if r.s.sndCAAckCount >= r.s.sndCwnd {
- r.s.sndCwnd += r.s.sndCAAckCount / r.s.sndCwnd
- r.s.sndCAAckCount = r.s.sndCAAckCount % r.s.sndCwnd
+ r.s.SndCAAckCount += packetsAcked
+ if r.s.SndCAAckCount >= r.s.SndCwnd {
+ r.s.SndCwnd += r.s.SndCAAckCount / r.s.SndCwnd
+ r.s.SndCAAckCount = r.s.SndCAAckCount % r.s.SndCwnd
}
}
// reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681,
// page 6, eq. 4. It is called when we detect congestion in the network.
func (r *renoState) reduceSlowStartThreshold() {
- r.s.sndSsthresh = r.s.outstanding / 2
- if r.s.sndSsthresh < 2 {
- r.s.sndSsthresh = 2
+ r.s.Ssthresh = r.s.Outstanding / 2
+ if r.s.Ssthresh < 2 {
+ r.s.Ssthresh = 2
}
}
@@ -70,7 +70,7 @@ func (r *renoState) reduceSlowStartThreshold() {
// were acknowledged.
// Update implements congestionControl.Update.
func (r *renoState) Update(packetsAcked int) {
- if r.s.sndCwnd < r.s.sndSsthresh {
+ if r.s.SndCwnd < r.s.Ssthresh {
packetsAcked = r.updateSlowStart(packetsAcked)
if packetsAcked == 0 {
return
@@ -94,7 +94,7 @@ func (r *renoState) HandleRTOExpired() {
// Reduce the congestion window to 1, i.e., enter slow-start. Per
// RFC 5681, page 7, we must use 1 regardless of the value of the
// initial congestion window.
- r.s.sndCwnd = 1
+ r.s.SndCwnd = 1
}
// PostRecovery implements congestionControl.PostRecovery.
diff --git a/pkg/tcpip/transport/tcp/reno_recovery.go b/pkg/tcpip/transport/tcp/reno_recovery.go
index 2aa708e97..d368a29fc 100644
--- a/pkg/tcpip/transport/tcp/reno_recovery.go
+++ b/pkg/tcpip/transport/tcp/reno_recovery.go
@@ -31,25 +31,25 @@ func (rr *renoRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) {
snd := rr.s
// We are in fast recovery mode. Ignore the ack if it's out of range.
- if !ack.InRange(snd.sndUna, snd.sndNxt+1) {
+ if !ack.InRange(snd.SndUna, snd.SndNxt+1) {
return
}
// Don't count this as a duplicate if it is carrying data or
// updating the window.
- if rcvdSeg.logicalLen() != 0 || snd.sndWnd != rcvdSeg.window {
+ if rcvdSeg.logicalLen() != 0 || snd.SndWnd != rcvdSeg.window {
return
}
// Inflate the congestion window if we're getting duplicate acks
// for the packet we retransmitted.
- if !fastRetransmit && ack == snd.fr.first {
+ if !fastRetransmit && ack == snd.FastRecovery.First {
// We received a dup, inflate the congestion window by 1 packet
// if we're not at the max yet. Only inflate the window if
// regular FastRecovery is in use, RFC6675 does not require
// inflating cwnd on duplicate ACKs.
- if snd.sndCwnd < snd.fr.maxCwnd {
- snd.sndCwnd++
+ if snd.SndCwnd < snd.FastRecovery.MaxCwnd {
+ snd.SndCwnd++
}
return
}
@@ -61,7 +61,7 @@ func (rr *renoRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) {
// back onto the wire.
//
// N.B. The retransmit timer will be reset by the caller.
- snd.fr.first = ack
- snd.dupAckCount = 0
+ snd.FastRecovery.First = ack
+ snd.DupAckCount = 0
snd.resendSegment()
}
diff --git a/pkg/tcpip/transport/tcp/sack_recovery.go b/pkg/tcpip/transport/tcp/sack_recovery.go
index 9d406b0bc..cd860b5e8 100644
--- a/pkg/tcpip/transport/tcp/sack_recovery.go
+++ b/pkg/tcpip/transport/tcp/sack_recovery.go
@@ -42,14 +42,14 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen
}
nextSegHint := snd.writeList.Front()
- for snd.outstanding < snd.sndCwnd {
+ for snd.Outstanding < snd.SndCwnd {
var nextSeg *segment
var rescueRtx bool
nextSeg, nextSegHint, rescueRtx = snd.NextSeg(nextSegHint)
if nextSeg == nil {
return dataSent
}
- if !snd.isAssignedSequenceNumber(nextSeg) || snd.sndNxt.LessThanEq(nextSeg.sequenceNumber) {
+ if !snd.isAssignedSequenceNumber(nextSeg) || snd.SndNxt.LessThanEq(nextSeg.sequenceNumber) {
// New data being sent.
// Step C.3 described below is handled by
@@ -67,7 +67,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen
return dataSent
}
dataSent = true
- snd.outstanding++
+ snd.Outstanding++
snd.writeNext = nextSeg.Next()
continue
}
@@ -79,7 +79,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen
// "The estimate of the amount of data outstanding in the network
// must be updated by incrementing pipe by the number of octets
// transmitted in (C.1)."
- snd.outstanding++
+ snd.Outstanding++
dataSent = true
snd.sendSegment(nextSeg)
@@ -88,7 +88,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen
// We do the last part of rule (4) of NextSeg here to update
// RescueRxt as until this point we don't know if we are going
// to use the rescue transmission.
- snd.fr.rescueRxt = snd.fr.last
+ snd.FastRecovery.RescueRxt = snd.FastRecovery.Last
} else {
// RFC 6675, Step C.2
//
@@ -96,7 +96,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen
// HighData, HighRxt MUST be set to the highest sequence
// number of the retransmitted segment unless NextSeg ()
// rule (4) was invoked for this retransmission."
- snd.fr.highRxt = segEnd - 1
+ snd.FastRecovery.HighRxt = segEnd - 1
}
}
return dataSent
@@ -109,12 +109,12 @@ func (sr *sackRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) {
}
// We are in fast recovery mode. Ignore the ack if it's out of range.
- if ack := rcvdSeg.ackNumber; !ack.InRange(snd.sndUna, snd.sndNxt+1) {
+ if ack := rcvdSeg.ackNumber; !ack.InRange(snd.SndUna, snd.SndNxt+1) {
return
}
// RFC 6675 recovery algorithm step C 1-5.
- end := snd.sndUna.Add(snd.sndWnd)
- dataSent := sr.handleSACKRecovery(snd.maxPayloadSize, end)
+ end := snd.SndUna.Add(snd.SndWnd)
+ dataSent := sr.handleSACKRecovery(snd.MaxPayloadSize, end)
snd.postXmit(dataSent, true /* shouldScheduleProbe */)
}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 8edd6775b..c28641be3 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -236,20 +236,14 @@ func (s *segment) parse(skipChecksumValidation bool) bool {
s.options = []byte(s.hdr[header.TCPMinimumSize:])
s.parsedOptions = header.ParseTCPOptions(s.options)
-
- verifyChecksum := true
if skipChecksumValidation {
s.csumValid = true
- verifyChecksum = false
- }
- if verifyChecksum {
+ } else {
s.csum = s.hdr.Checksum()
- xsum := header.PseudoHeaderChecksum(ProtocolNumber, s.srcAddr, s.dstAddr, uint16(s.data.Size()+len(s.hdr)))
- xsum = s.hdr.CalculateChecksum(xsum)
- xsum = header.ChecksumVV(s.data, xsum)
- s.csumValid = xsum == 0xffff
+ payloadChecksum := header.ChecksumVV(s.data, 0)
+ payloadLength := uint16(s.data.Size())
+ s.csumValid = s.hdr.IsChecksumValid(s.srcAddr, s.dstAddr, payloadChecksum, payloadLength)
}
-
s.sequenceNumber = seqnum.Value(s.hdr.SequenceNumber())
s.ackNumber = seqnum.Value(s.hdr.AckNumber())
s.flags = s.hdr.Flags()
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
index 54545a1b1..d0d1b0b8a 100644
--- a/pkg/tcpip/transport/tcp/segment_queue.go
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -52,12 +52,12 @@ func (q *segmentQueue) empty() bool {
func (q *segmentQueue) enqueue(s *segment) bool {
// q.ep.receiveBufferParams() must be called without holding q.mu to
// avoid lock order inversion.
- bufSz := q.ep.receiveBufferSize()
+ bufSz := q.ep.ops.GetReceiveBufferSize()
used := q.ep.receiveMemUsed()
q.mu.Lock()
// Allow zero sized segments (ACK/FIN/RSTs etc even if the segment queue
// is currently full).
- allow := (used <= bufSz || s.payloadSize() == 0) && !q.frozen
+ allow := (used <= int(bufSz) || s.payloadSize() == 0) && !q.frozen
if allow {
q.list.PushBack(s)
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index faca35892..cf2e8dcd8 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
@@ -85,56 +86,12 @@ type lossRecovery interface {
//
// +stateify savable
type sender struct {
+ stack.TCPSenderState
ep *endpoint
- // lastSendTime is the timestamp when the last packet was sent.
- lastSendTime time.Time `state:".(unixTime)"`
-
- // dupAckCount is the number of duplicated acks received. It is used for
- // fast retransmit.
- dupAckCount int
-
- // fr holds state related to fast recovery.
- fr fastRecovery
-
// lr is the loss recovery algorithm used by the sender.
lr lossRecovery
- // sndCwnd is the congestion window, in packets.
- sndCwnd int
-
- // sndSsthresh is the threshold between slow start and congestion
- // avoidance.
- sndSsthresh int
-
- // sndCAAckCount is the number of packets acknowledged during congestion
- // avoidance. When enough packets have been ack'd (typically cwnd
- // packets), the congestion window is incremented by one.
- sndCAAckCount int
-
- // outstanding is the number of outstanding packets, that is, packets
- // that have been sent but not yet acknowledged.
- outstanding int
-
- // sackedOut is the number of packets which are selectively acked.
- sackedOut int
-
- // sndWnd is the send window size.
- sndWnd seqnum.Size
-
- // sndUna is the next unacknowledged sequence number.
- sndUna seqnum.Value
-
- // sndNxt is the sequence number of the next segment to be sent.
- sndNxt seqnum.Value
-
- // rttMeasureSeqNum is the sequence number being used for the latest RTT
- // measurement.
- rttMeasureSeqNum seqnum.Value
-
- // rttMeasureTime is the time when the rttMeasureSeqNum was sent.
- rttMeasureTime time.Time `state:".(unixTime)"`
-
// firstRetransmittedSegXmitTime is the original transmit time of
// the first segment that was retransmitted due to RTO expiration.
firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"`
@@ -147,17 +104,15 @@ type sender struct {
// window probes.
unackZeroWindowProbes uint32 `state:"nosave"`
- closed bool
writeNext *segment
writeList segmentList
resendTimer timer `state:"nosave"`
resendWaker sleep.Waker `state:"nosave"`
- // rtt.srtt, rtt.rttvar, and rto are the "smoothed round-trip time",
- // "round-trip time variation" and "retransmit timeout", as defined in
+ // rtt.TCPRTTState.SRTT and rtt.TCPRTTState.RTTVar are the "smoothed
+ // round-trip time", and "round-trip time variation", as defined in
// section 2 of RFC 6298.
rtt rtt
- rto time.Duration
// minRTO is the minimum permitted value for sender.rto.
minRTO time.Duration
@@ -168,20 +123,9 @@ type sender struct {
// maxRetries is the maximum permitted retransmissions.
maxRetries uint32
- // maxPayloadSize is the maximum size of the payload of a given segment.
- // It is initialized on demand.
- maxPayloadSize int
-
// gso is set if generic segmentation offload is enabled.
gso bool
- // sndWndScale is the number of bits to shift left when reading the send
- // window size from a segment.
- sndWndScale uint8
-
- // maxSentAck is the maxium acknowledgement actually sent.
- maxSentAck seqnum.Value
-
// state is the current state of congestion control for this endpoint.
state tcpip.CongestionControlState
@@ -209,41 +153,7 @@ type sender struct {
type rtt struct {
sync.Mutex `state:"nosave"`
- srtt time.Duration
- rttvar time.Duration
- srttInited bool
-}
-
-// fastRecovery holds information related to fast recovery from a packet loss.
-//
-// +stateify savable
-type fastRecovery struct {
- // active whether the endpoint is in fast recovery. The following fields
- // are only meaningful when active is true.
- active bool
-
- // first and last represent the inclusive sequence number range being
- // recovered.
- first seqnum.Value
- last seqnum.Value
-
- // maxCwnd is the maximum value the congestion window may be inflated to
- // due to duplicate acks. This exists to avoid attacks where the
- // receiver intentionally sends duplicate acks to artificially inflate
- // the sender's cwnd.
- maxCwnd int
-
- // highRxt is the highest sequence number which has been retransmitted
- // during the current loss recovery phase.
- // See: RFC 6675 Section 2 for details.
- highRxt seqnum.Value
-
- // rescueRxt is the highest sequence number which has been
- // optimistically retransmitted to prevent stalling of the ACK clock
- // when there is loss at the end of the window and no new data is
- // available for transmission.
- // See: RFC 6675 Section 2 for details.
- rescueRxt seqnum.Value
+ stack.TCPRTTState
}
func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender {
@@ -253,20 +163,22 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
maxPayloadSize := int(mss) - ep.maxOptionSize()
s := &sender{
- ep: ep,
- sndWnd: sndWnd,
- sndUna: iss + 1,
- sndNxt: iss + 1,
- rto: 1 * time.Second,
- rttMeasureSeqNum: iss + 1,
- lastSendTime: time.Now(),
- maxPayloadSize: maxPayloadSize,
- maxSentAck: irs + 1,
- fr: fastRecovery{
- // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1.
- last: iss,
- highRxt: iss,
- rescueRxt: iss,
+ ep: ep,
+ TCPSenderState: stack.TCPSenderState{
+ SndWnd: sndWnd,
+ SndUna: iss + 1,
+ SndNxt: iss + 1,
+ RTTMeasureSeqNum: iss + 1,
+ LastSendTime: time.Now(),
+ MaxPayloadSize: maxPayloadSize,
+ MaxSentAck: irs + 1,
+ FastRecovery: stack.TCPFastRecoveryState{
+ // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1.
+ Last: iss,
+ HighRxt: iss,
+ RescueRxt: iss,
+ },
+ RTO: 1 * time.Second,
},
gso: ep.gso != nil,
}
@@ -282,7 +194,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
// A negative sndWndScale means that no scaling is in use, otherwise we
// store the scaling value.
if sndWndScale > 0 {
- s.sndWndScale = uint8(sndWndScale)
+ s.SndWndScale = uint8(sndWndScale)
}
s.resendTimer.init(&s.resendWaker)
@@ -294,7 +206,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
// Initialize SACK Scoreboard after updating max payload size as we use
// the maxPayloadSize as the smss when determining if a segment is lost
// etc.
- s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss)
+ s.ep.scoreboard = NewSACKScoreboard(uint16(s.MaxPayloadSize), iss)
// Get Stack wide config.
var minRTO tcpip.TCPMinRTOOption
@@ -322,10 +234,10 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
// returns a handle to it. It also initializes the sndCwnd and sndSsThresh to
// their initial values.
func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionControlOption) congestionControl {
- s.sndCwnd = InitialCwnd
+ s.SndCwnd = InitialCwnd
// Set sndSsthresh to the maximum int value, which depends on the
// platform.
- s.sndSsthresh = int(^uint(0) >> 1)
+ s.Ssthresh = int(^uint(0) >> 1)
switch congestionControlName {
case ccCubic:
@@ -339,7 +251,7 @@ func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionCon
// initLossRecovery initiates the loss recovery algorithm for the sender.
func (s *sender) initLossRecovery() lossRecovery {
- if s.ep.sackPermitted {
+ if s.ep.SACKPermitted {
return newSACKRecovery(s)
}
return newRenoRecovery(s)
@@ -355,7 +267,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
m -= s.ep.maxOptionSize()
// We don't adjust up for now.
- if m >= s.maxPayloadSize {
+ if m >= s.MaxPayloadSize {
return
}
@@ -364,8 +276,8 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
m = 1
}
- oldMSS := s.maxPayloadSize
- s.maxPayloadSize = m
+ oldMSS := s.MaxPayloadSize
+ s.MaxPayloadSize = m
if s.gso {
s.ep.gso.MSS = uint16(m)
}
@@ -380,9 +292,9 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
// maxPayloadSize.
s.ep.scoreboard.smss = uint16(m)
- s.outstanding -= count
- if s.outstanding < 0 {
- s.outstanding = 0
+ s.Outstanding -= count
+ if s.Outstanding < 0 {
+ s.Outstanding = 0
}
// Rewind writeNext to the first segment exceeding the MTU. Do nothing
@@ -401,10 +313,10 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
nextSeg = seg
}
- if s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ if s.ep.SACKPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
// Update sackedOut for new maximum payload size.
- s.sackedOut -= s.pCount(seg, oldMSS)
- s.sackedOut += s.pCount(seg, s.maxPayloadSize)
+ s.SackedOut -= s.pCount(seg, oldMSS)
+ s.SackedOut += s.pCount(seg, s.MaxPayloadSize)
}
}
@@ -416,32 +328,32 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
// sendAck sends an ACK segment.
func (s *sender) sendAck() {
- s.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, s.sndNxt)
+ s.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, s.SndNxt)
}
// updateRTO updates the retransmit timeout when a new roud-trip time is
// available. This is done in accordance with section 2 of RFC 6298.
func (s *sender) updateRTO(rtt time.Duration) {
s.rtt.Lock()
- if !s.rtt.srttInited {
- s.rtt.rttvar = rtt / 2
- s.rtt.srtt = rtt
- s.rtt.srttInited = true
+ if !s.rtt.TCPRTTState.SRTTInited {
+ s.rtt.TCPRTTState.RTTVar = rtt / 2
+ s.rtt.TCPRTTState.SRTT = rtt
+ s.rtt.TCPRTTState.SRTTInited = true
} else {
- diff := s.rtt.srtt - rtt
+ diff := s.rtt.TCPRTTState.SRTT - rtt
if diff < 0 {
diff = -diff
}
- // Use RFC6298 standard algorithm to update rttvar and srtt when
+ // Use RFC6298 standard algorithm to update TCPRTTState.RTTVar and TCPRTTState.SRTT when
// no timestamps are available.
- if !s.ep.sendTSOk {
- s.rtt.rttvar = (3*s.rtt.rttvar + diff) / 4
- s.rtt.srtt = (7*s.rtt.srtt + rtt) / 8
+ if !s.ep.SendTSOk {
+ s.rtt.TCPRTTState.RTTVar = (3*s.rtt.TCPRTTState.RTTVar + diff) / 4
+ s.rtt.TCPRTTState.SRTT = (7*s.rtt.TCPRTTState.SRTT + rtt) / 8
} else {
// When we are taking RTT measurements of every ACK then
// we need to use a modified method as specified in
// https://tools.ietf.org/html/rfc7323#appendix-G
- if s.outstanding == 0 {
+ if s.Outstanding == 0 {
s.rtt.Unlock()
return
}
@@ -449,7 +361,7 @@ func (s *sender) updateRTO(rtt time.Duration) {
// terms of packets and not bytes. This is similar to
// how linux also does cwnd and inflight. In practice
// this approximation works as expected.
- expectedSamples := math.Ceil(float64(s.outstanding) / 2)
+ expectedSamples := math.Ceil(float64(s.Outstanding) / 2)
// alpha & beta values are the original values as recommended in
// https://tools.ietf.org/html/rfc6298#section-2.3.
@@ -458,17 +370,17 @@ func (s *sender) updateRTO(rtt time.Duration) {
alphaPrime := alpha / expectedSamples
betaPrime := beta / expectedSamples
- rttVar := (1-betaPrime)*s.rtt.rttvar.Seconds() + betaPrime*diff.Seconds()
- srtt := (1-alphaPrime)*s.rtt.srtt.Seconds() + alphaPrime*rtt.Seconds()
- s.rtt.rttvar = time.Duration(rttVar * float64(time.Second))
- s.rtt.srtt = time.Duration(srtt * float64(time.Second))
+ rttVar := (1-betaPrime)*s.rtt.TCPRTTState.RTTVar.Seconds() + betaPrime*diff.Seconds()
+ srtt := (1-alphaPrime)*s.rtt.TCPRTTState.SRTT.Seconds() + alphaPrime*rtt.Seconds()
+ s.rtt.TCPRTTState.RTTVar = time.Duration(rttVar * float64(time.Second))
+ s.rtt.TCPRTTState.SRTT = time.Duration(srtt * float64(time.Second))
}
}
- s.rto = s.rtt.srtt + 4*s.rtt.rttvar
+ s.RTO = s.rtt.TCPRTTState.SRTT + 4*s.rtt.TCPRTTState.RTTVar
s.rtt.Unlock()
- if s.rto < s.minRTO {
- s.rto = s.minRTO
+ if s.RTO < s.minRTO {
+ s.RTO = s.minRTO
}
}
@@ -476,20 +388,20 @@ func (s *sender) updateRTO(rtt time.Duration) {
func (s *sender) resendSegment() {
// Don't use any segments we already sent to measure RTT as they may
// have been affected by packets being lost.
- s.rttMeasureSeqNum = s.sndNxt
+ s.RTTMeasureSeqNum = s.SndNxt
// Resend the segment.
if seg := s.writeList.Front(); seg != nil {
- if seg.data.Size() > s.maxPayloadSize {
- s.splitSeg(seg, s.maxPayloadSize)
+ if seg.data.Size() > s.MaxPayloadSize {
+ s.splitSeg(seg, s.MaxPayloadSize)
}
// See: RFC 6675 section 5 Step 4.3
//
// To prevent retransmission, set both the HighRXT and RescueRXT
// to the highest sequence number in the retransmitted segment.
- s.fr.highRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1
- s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1
+ s.FastRecovery.HighRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1
+ s.FastRecovery.RescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1
s.sendSegment(seg)
s.ep.stack.Stats().TCP.FastRetransmit.Increment()
s.ep.stats.SendErrors.FastRetransmit.Increment()
@@ -554,15 +466,15 @@ func (s *sender) retransmitTimerExpired() bool {
// Set new timeout. The timer will be restarted by the call to sendData
// below.
- s.rto *= 2
+ s.RTO *= 2
// Cap the RTO as per RFC 1122 4.2.3.1, RFC 6298 5.5
- if s.rto > s.maxRTO {
- s.rto = s.maxRTO
+ if s.RTO > s.maxRTO {
+ s.RTO = s.maxRTO
}
// Cap RTO to remaining time.
- if s.rto > remaining {
- s.rto = remaining
+ if s.RTO > remaining {
+ s.RTO = remaining
}
// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4.
@@ -571,9 +483,9 @@ func (s *sender) retransmitTimerExpired() bool {
// After a retransmit timeout, record the highest sequence number
// transmitted in the variable recover, and exit the fast recovery
// procedure if applicable.
- s.fr.last = s.sndNxt - 1
+ s.FastRecovery.Last = s.SndNxt - 1
- if s.fr.active {
+ if s.FastRecovery.Active {
// We were attempting fast recovery but were not successful.
// Leave the state. We don't need to update ssthresh because it
// has already been updated when entered fast-recovery.
@@ -589,7 +501,7 @@ func (s *sender) retransmitTimerExpired() bool {
//
// We'll keep on transmitting (or retransmitting) as we get acks for
// the data we transmit.
- s.outstanding = 0
+ s.Outstanding = 0
// Expunge all SACK information as per https://tools.ietf.org/html/rfc6675#section-5.1
//
@@ -663,7 +575,7 @@ func (s *sender) splitSeg(seg *segment, size int) {
// window space.
// ref: net/ipv4/tcp_output.c::tcp_write_xmit(), tcp_mss_split_point()
// ref: net/ipv4/tcp_output.c::tcp_write_wakeup(), tcp_snd_wnd_test()
- if seg.data.Size() > s.maxPayloadSize {
+ if seg.data.Size() > s.MaxPayloadSize {
seg.flags ^= header.TCPFlagPsh
}
@@ -689,7 +601,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt
// transmitted (i.e. either it has no assigned sequence number
// or if it does have one, it's >= the next sequence number
// to be sent [i.e. >= s.sndNxt]).
- if !s.isAssignedSequenceNumber(seg) || s.sndNxt.LessThanEq(seg.sequenceNumber) {
+ if !s.isAssignedSequenceNumber(seg) || s.SndNxt.LessThanEq(seg.sequenceNumber) {
hint = nil
break
}
@@ -710,7 +622,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt
// (1.a) S2 is greater than HighRxt
// (1.b) S2 is less than highest octect covered by
// any received SACK.
- if s.fr.highRxt.LessThan(segSeq) && segSeq.LessThan(s.ep.scoreboard.maxSACKED) {
+ if s.FastRecovery.HighRxt.LessThan(segSeq) && segSeq.LessThan(s.ep.scoreboard.maxSACKED) {
// NextSeg():
// (1.c) IsLost(S2) returns true.
if s.ep.scoreboard.IsLost(segSeq) {
@@ -743,7 +655,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt
// unSACKed sequence number SHOULD be returned, and
// RescueRxt set to RecoveryPoint. HighRxt MUST NOT
// be updated.
- if s.fr.rescueRxt.LessThan(s.sndUna - 1) {
+ if s.FastRecovery.RescueRxt.LessThan(s.SndUna - 1) {
if s4 != nil {
if s4.sequenceNumber.LessThan(segSeq) {
s4 = seg
@@ -763,7 +675,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt
// previously unsent data starting with sequence number
// HighData+1 MUST be returned."
for seg := s.writeNext; seg != nil; seg = seg.Next() {
- if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.sndNxt) {
+ if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.SndNxt) {
continue
}
// We do not split the segment here to <= smss as it has
@@ -788,7 +700,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
if !s.isAssignedSequenceNumber(seg) {
// Merge segments if allowed.
if seg.data.Size() != 0 {
- available := int(s.sndNxt.Size(end))
+ available := int(s.SndNxt.Size(end))
if available > limit {
available = limit
}
@@ -816,7 +728,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
}
if !nextTooBig && seg.data.Size() < available {
// Segment is not full.
- if s.outstanding > 0 && s.ep.ops.GetDelayOption() {
+ if s.Outstanding > 0 && s.ep.ops.GetDelayOption() {
// Nagle's algorithm. From Wikipedia:
// Nagle's algorithm works by
// combining a number of small
@@ -835,7 +747,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// send space and MSS.
// TODO(gvisor.dev/issue/2833): Drain the held segments after a
// timeout.
- if seg.data.Size() < s.maxPayloadSize && s.ep.ops.GetCorkOption() {
+ if seg.data.Size() < s.MaxPayloadSize && s.ep.ops.GetCorkOption() {
return false
}
}
@@ -843,7 +755,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// Assign flags. We don't do it above so that we can merge
// additional data if Nagle holds the segment.
- seg.sequenceNumber = s.sndNxt
+ seg.sequenceNumber = s.SndNxt
seg.flags = header.TCPFlagAck | header.TCPFlagPsh
}
@@ -893,12 +805,12 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// the segment right here if there are no pending segments. If
// there are pending segments, segment transmits are deferred to
// the retransmit timer handler.
- if s.sndUna != s.sndNxt {
+ if s.SndUna != s.SndNxt {
switch {
case available >= seg.data.Size():
// OK to send, the whole segments fits in the
// receiver's advertised window.
- case available >= s.maxPayloadSize:
+ case available >= s.MaxPayloadSize:
// OK to send, at least 1 MSS sized segment fits
// in the receiver's advertised window.
default:
@@ -918,8 +830,8 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// If GSO is not in use then cap available to
// maxPayloadSize. When GSO is in use the gVisor GSO logic or
// the host GSO logic will cap the segment to the correct size.
- if s.ep.gso == nil && available > s.maxPayloadSize {
- available = s.maxPayloadSize
+ if s.ep.gso == nil && available > s.MaxPayloadSize {
+ available = s.MaxPayloadSize
}
if seg.data.Size() > available {
@@ -933,8 +845,8 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// Update sndNxt if we actually sent new data (as opposed to
// retransmitting some previously sent data).
- if s.sndNxt.LessThan(segEnd) {
- s.sndNxt = segEnd
+ if s.SndNxt.LessThan(segEnd) {
+ s.SndNxt = segEnd
}
return true
@@ -945,9 +857,9 @@ func (s *sender) sendZeroWindowProbe() {
s.unackZeroWindowProbes++
// Send a zero window probe with sequence number pointing to
// the last acknowledged byte.
- s.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, s.sndUna-1, ack, win)
+ s.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, s.SndUna-1, ack, win)
// Rearm the timer to continue probing.
- s.resendTimer.enable(s.rto)
+ s.resendTimer.enable(s.RTO)
}
func (s *sender) enableZeroWindowProbing() {
@@ -958,7 +870,7 @@ func (s *sender) enableZeroWindowProbing() {
if s.firstRetransmittedSegXmitTime.IsZero() {
s.firstRetransmittedSegXmitTime = time.Now()
}
- s.resendTimer.enable(s.rto)
+ s.resendTimer.enable(s.RTO)
}
func (s *sender) disableZeroWindowProbing() {
@@ -978,12 +890,12 @@ func (s *sender) postXmit(dataSent bool, shouldScheduleProbe bool) {
// If the sender has advertized zero receive window and we have
// data to be sent out, start zero window probing to query the
// the remote for it's receive window size.
- if s.writeNext != nil && s.sndWnd == 0 {
+ if s.writeNext != nil && s.SndWnd == 0 {
s.enableZeroWindowProbing()
}
// If we have no more pending data, start the keepalive timer.
- if s.sndUna == s.sndNxt {
+ if s.SndUna == s.SndNxt {
s.ep.resetKeepaliveTimer(false)
} else {
// Enable timers if we have pending data.
@@ -992,10 +904,10 @@ func (s *sender) postXmit(dataSent bool, shouldScheduleProbe bool) {
s.schedulePTO()
} else if !s.resendTimer.enabled() {
s.probeTimer.disable()
- if s.outstanding > 0 {
+ if s.Outstanding > 0 {
// Enable the resend timer if it's not enabled yet and there is
// outstanding data.
- s.resendTimer.enable(s.rto)
+ s.resendTimer.enable(s.RTO)
}
}
}
@@ -1004,29 +916,29 @@ func (s *sender) postXmit(dataSent bool, shouldScheduleProbe bool) {
// sendData sends new data segments. It is called when data becomes available or
// when the send window opens up.
func (s *sender) sendData() {
- limit := s.maxPayloadSize
+ limit := s.MaxPayloadSize
if s.gso {
limit = int(s.ep.gso.MaxSize - header.TCPHeaderMaximumSize)
}
- end := s.sndUna.Add(s.sndWnd)
+ end := s.SndUna.Add(s.SndWnd)
// Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10.
// "A TCP SHOULD set cwnd to no more than RW before beginning
// transmission if the TCP has not sent data in the interval exceeding
// the retrasmission timeout."
- if !s.fr.active && s.state != tcpip.RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto {
- if s.sndCwnd > InitialCwnd {
- s.sndCwnd = InitialCwnd
+ if !s.FastRecovery.Active && s.state != tcpip.RTORecovery && time.Now().Sub(s.LastSendTime) > s.RTO {
+ if s.SndCwnd > InitialCwnd {
+ s.SndCwnd = InitialCwnd
}
}
var dataSent bool
- for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() {
- cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize
+ for seg := s.writeNext; seg != nil && s.Outstanding < s.SndCwnd; seg = seg.Next() {
+ cwndLimit := (s.SndCwnd - s.Outstanding) * s.MaxPayloadSize
if cwndLimit < limit {
limit = cwndLimit
}
- if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ if s.isAssignedSequenceNumber(seg) && s.ep.SACKPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
// Move writeNext along so that we don't try and scan data that
// has already been SACKED.
s.writeNext = seg.Next()
@@ -1036,7 +948,7 @@ func (s *sender) sendData() {
break
}
dataSent = true
- s.outstanding += s.pCount(seg, s.maxPayloadSize)
+ s.Outstanding += s.pCount(seg, s.MaxPayloadSize)
s.writeNext = seg.Next()
}
@@ -1044,21 +956,21 @@ func (s *sender) sendData() {
}
func (s *sender) enterRecovery() {
- s.fr.active = true
+ s.FastRecovery.Active = true
// Save state to reflect we're now in fast recovery.
//
// See : https://tools.ietf.org/html/rfc5681#section-3.2 Step 3.
// We inflate the cwnd by 3 to account for the 3 packets which triggered
// the 3 duplicate ACKs and are now not in flight.
- s.sndCwnd = s.sndSsthresh + 3
- s.sackedOut = 0
- s.dupAckCount = 0
- s.fr.first = s.sndUna
- s.fr.last = s.sndNxt - 1
- s.fr.maxCwnd = s.sndCwnd + s.outstanding
- s.fr.highRxt = s.sndUna
- s.fr.rescueRxt = s.sndUna
- if s.ep.sackPermitted {
+ s.SndCwnd = s.Ssthresh + 3
+ s.SackedOut = 0
+ s.DupAckCount = 0
+ s.FastRecovery.First = s.SndUna
+ s.FastRecovery.Last = s.SndNxt - 1
+ s.FastRecovery.MaxCwnd = s.SndCwnd + s.Outstanding
+ s.FastRecovery.HighRxt = s.SndUna
+ s.FastRecovery.RescueRxt = s.SndUna
+ if s.ep.SACKPermitted {
s.state = tcpip.SACKRecovery
s.ep.stack.Stats().TCP.SACKRecovery.Increment()
// Set TLPRxtOut to false according to
@@ -1075,12 +987,12 @@ func (s *sender) enterRecovery() {
}
func (s *sender) leaveRecovery() {
- s.fr.active = false
- s.fr.maxCwnd = 0
- s.dupAckCount = 0
+ s.FastRecovery.Active = false
+ s.FastRecovery.MaxCwnd = 0
+ s.DupAckCount = 0
// Deflate cwnd. It had been artificially inflated when new dups arrived.
- s.sndCwnd = s.sndSsthresh
+ s.SndCwnd = s.Ssthresh
s.cc.PostRecovery()
}
@@ -1099,7 +1011,7 @@ func (s *sender) isAssignedSequenceNumber(seg *segment) bool {
func (s *sender) SetPipe() {
// If SACK isn't permitted or it is permitted but recovery is not active
// then ignore pipe calculations.
- if !s.ep.sackPermitted || !s.fr.active {
+ if !s.ep.SACKPermitted || !s.FastRecovery.Active {
return
}
pipe := 0
@@ -1119,7 +1031,7 @@ func (s *sender) SetPipe() {
// After initializing pipe to zero, the following steps are
// taken for each octet 'S1' in the sequence space between
// HighACK and HighData that has not been SACKed:
- if !s1.sequenceNumber.LessThan(s.sndNxt) {
+ if !s1.sequenceNumber.LessThan(s.SndNxt) {
break
}
if s.ep.scoreboard.IsSACKED(sb) {
@@ -1138,20 +1050,20 @@ func (s *sender) SetPipe() {
}
// SetPipe():
// (b) If S1 <= HighRxt, Pipe is incremented by 1.
- if s1.sequenceNumber.LessThanEq(s.fr.highRxt) {
+ if s1.sequenceNumber.LessThanEq(s.FastRecovery.HighRxt) {
pipe++
}
}
}
- s.outstanding = pipe
+ s.Outstanding = pipe
}
// shouldEnterRecovery returns true if the sender should enter fast recovery
// based on dupAck count and sack scoreboard.
// See RFC 6675 section 5.
func (s *sender) shouldEnterRecovery() bool {
- return s.dupAckCount >= nDupAckThreshold ||
- (s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 && s.ep.scoreboard.IsLost(s.sndUna))
+ return s.DupAckCount >= nDupAckThreshold ||
+ (s.ep.SACKPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 && s.ep.scoreboard.IsLost(s.SndUna))
}
// detectLoss is called when an ack is received and returns whether a loss is
@@ -1163,24 +1075,24 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) {
// If RACK is enabled and there is no reordering we should honor the
// three duplicate ACK rule to enter recovery.
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-4
- if s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
- if s.rc.reorderSeen {
+ if s.ep.SACKPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
+ if s.rc.Reord {
return false
}
}
if !s.isDupAck(seg) {
- s.dupAckCount = 0
+ s.DupAckCount = 0
return false
}
- s.dupAckCount++
+ s.DupAckCount++
// Do not enter fast recovery until we reach nDupAckThreshold or the
// first unacknowledged byte is considered lost as per SACK scoreboard.
if !s.shouldEnterRecovery() {
// RFC 6675 Step 3.
- s.fr.highRxt = s.sndUna - 1
+ s.FastRecovery.HighRxt = s.SndUna - 1
// Do run SetPipe() to calculate the outstanding segments.
s.SetPipe()
s.state = tcpip.Disorder
@@ -1196,8 +1108,8 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) {
// Note that we only enter recovery when at least one more byte of data
// beyond s.fr.last (the highest byte that was outstanding when fast
// retransmit was last entered) is acked.
- if !s.fr.last.LessThan(seg.ackNumber - 1) {
- s.dupAckCount = 0
+ if !s.FastRecovery.Last.LessThan(seg.ackNumber - 1) {
+ s.DupAckCount = 0
return false
}
s.cc.HandleLossDetected()
@@ -1212,22 +1124,22 @@ func (s *sender) isDupAck(seg *segment) bool {
// can leverage the SACK information to determine when an incoming ACK is a
// "duplicate" (e.g., if the ACK contains previously unknown SACK
// information).
- if s.ep.sackPermitted && !seg.hasNewSACKInfo {
+ if s.ep.SACKPermitted && !seg.hasNewSACKInfo {
return false
}
// (a) The receiver of the ACK has outstanding data.
- return s.sndUna != s.sndNxt &&
+ return s.SndUna != s.SndNxt &&
// (b) The incoming acknowledgment carries no data.
seg.logicalLen() == 0 &&
// (c) The SYN and FIN bits are both off.
!seg.flagIsSet(header.TCPFlagFin) && !seg.flagIsSet(header.TCPFlagSyn) &&
// (d) the ACK number is equal to the greatest acknowledgment received on
// the given connection (TCP.UNA from RFC793).
- seg.ackNumber == s.sndUna &&
+ seg.ackNumber == s.SndUna &&
// (e) the advertised window in the incoming acknowledgment equals the
// advertised window in the last incoming acknowledgment.
- s.sndWnd == seg.window
+ s.SndWnd == seg.window
}
// Iterate the writeList and update RACK for each segment which is newly acked
@@ -1267,7 +1179,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
s.rc.update(seg, rcvdSeg)
s.rc.detectReorder(seg)
seg.acked = true
- s.sackedOut += s.pCount(seg, s.maxPayloadSize)
+ s.SackedOut += s.pCount(seg, s.MaxPayloadSize)
}
seg = seg.Next()
}
@@ -1322,18 +1234,18 @@ func checkDSACK(rcvdSeg *segment) bool {
// updating the send-related state.
func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// Check if we can extract an RTT measurement from this ack.
- if !rcvdSeg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(rcvdSeg.ackNumber) {
- s.updateRTO(time.Now().Sub(s.rttMeasureTime))
- s.rttMeasureSeqNum = s.sndNxt
+ if !rcvdSeg.parsedOptions.TS && s.RTTMeasureSeqNum.LessThan(rcvdSeg.ackNumber) {
+ s.updateRTO(time.Now().Sub(s.RTTMeasureTime))
+ s.RTTMeasureSeqNum = s.SndNxt
}
// Update Timestamp if required. See RFC7323, section-4.3.
- if s.ep.sendTSOk && rcvdSeg.parsedOptions.TS {
- s.ep.updateRecentTimestamp(rcvdSeg.parsedOptions.TSVal, s.maxSentAck, rcvdSeg.sequenceNumber)
+ if s.ep.SendTSOk && rcvdSeg.parsedOptions.TS {
+ s.ep.updateRecentTimestamp(rcvdSeg.parsedOptions.TSVal, s.MaxSentAck, rcvdSeg.sequenceNumber)
}
// Insert SACKBlock information into our scoreboard.
- if s.ep.sackPermitted {
+ if s.ep.SACKPermitted {
for _, sb := range rcvdSeg.parsedOptions.SACKBlocks {
// Only insert the SACK block if the following holds
// true:
@@ -1347,7 +1259,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// NOTE: This check specifically excludes DSACK blocks
// which have start/end before sndUna and are used to
// indicate spurious retransmissions.
- if rcvdSeg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) {
+ if rcvdSeg.ackNumber.LessThan(sb.Start) && s.SndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.SndNxt) && !s.ep.scoreboard.IsSACKED(sb) {
s.ep.scoreboard.Insert(sb)
rcvdSeg.hasNewSACKInfo = true
}
@@ -1375,10 +1287,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
ack := rcvdSeg.ackNumber
fastRetransmit := false
// Do not leave fast recovery, if the ACK is out of range.
- if s.fr.active {
+ if s.FastRecovery.Active {
// Leave fast recovery if it acknowledges all the data covered by
// this fast recovery session.
- if (ack-1).InRange(s.sndUna, s.sndNxt) && s.fr.last.LessThan(ack) {
+ if (ack-1).InRange(s.SndUna, s.SndNxt) && s.FastRecovery.Last.LessThan(ack) {
s.leaveRecovery()
}
} else {
@@ -1392,28 +1304,28 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
// Stash away the current window size.
- s.sndWnd = rcvdSeg.window
+ s.SndWnd = rcvdSeg.window
// Disable zero window probing if remote advertizes a non-zero receive
// window. This can be with an ACK to the zero window probe (where the
// acknumber refers to the already acknowledged byte) OR to any previously
// unacknowledged segment.
if s.zeroWindowProbing && rcvdSeg.window > 0 &&
- (ack == s.sndUna || (ack-1).InRange(s.sndUna, s.sndNxt)) {
+ (ack == s.SndUna || (ack-1).InRange(s.SndUna, s.SndNxt)) {
s.disableZeroWindowProbing()
}
// On receiving the ACK for the zero window probe, account for it and
// skip trying to send any segment as we are still probing for
// receive window to become non-zero.
- if s.zeroWindowProbing && s.unackZeroWindowProbes > 0 && ack == s.sndUna {
+ if s.zeroWindowProbing && s.unackZeroWindowProbes > 0 && ack == s.SndUna {
s.unackZeroWindowProbes--
return
}
// Ignore ack if it doesn't acknowledge any new data.
- if (ack - 1).InRange(s.sndUna, s.sndNxt) {
- s.dupAckCount = 0
+ if (ack - 1).InRange(s.SndUna, s.SndNxt) {
+ s.DupAckCount = 0
// See : https://tools.ietf.org/html/rfc1323#section-3.3.
// Specifically we should only update the RTO using TSEcr if the
@@ -1423,7 +1335,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// averaged RTT measurement only if the segment acknowledges
// some new data, i.e., only if it advances the left edge of
// the send window.
- if s.ep.sendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 {
+ if s.ep.SendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 {
// TSVal/Ecr values sent by Netstack are at a millisecond
// granularity.
elapsed := time.Duration(s.ep.timestamp()-rcvdSeg.parsedOptions.TSEcr) * time.Millisecond
@@ -1438,12 +1350,12 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// When an ack is received we must rearm the timer.
// RFC 6298 5.3
s.probeTimer.disable()
- s.resendTimer.enable(s.rto)
+ s.resendTimer.enable(s.RTO)
}
// Remove all acknowledged data from the write list.
- acked := s.sndUna.Size(ack)
- s.sndUna = ack
+ acked := s.SndUna.Size(ack)
+ s.SndUna = ack
// The remote ACK-ing at least 1 byte is an indication that we have a
// full-duplex connection to the remote as the only way we will receive an
@@ -1457,7 +1369,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
ackLeft := acked
- originalOutstanding := s.outstanding
+ originalOutstanding := s.Outstanding
for ackLeft > 0 {
// We use logicalLen here because we can have FIN
// segments (which are always at the end of list) that
@@ -1466,10 +1378,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
datalen := seg.logicalLen()
if datalen > ackLeft {
- prevCount := s.pCount(seg, s.maxPayloadSize)
+ prevCount := s.pCount(seg, s.MaxPayloadSize)
seg.data.TrimFront(int(ackLeft))
seg.sequenceNumber.UpdateForward(ackLeft)
- s.outstanding -= prevCount - s.pCount(seg, s.maxPayloadSize)
+ s.Outstanding -= prevCount - s.pCount(seg, s.MaxPayloadSize)
break
}
@@ -1478,7 +1390,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
// Update the RACK fields if SACK is enabled.
- if s.ep.sackPermitted && !seg.acked && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
+ if s.ep.SACKPermitted && !seg.acked && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
s.rc.update(seg, rcvdSeg)
s.rc.detectReorder(seg)
}
@@ -1488,10 +1400,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// If SACK is enabled then only reduce outstanding if
// the segment was not previously SACKED as these have
// already been accounted for in SetPipe().
- if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
- s.outstanding -= s.pCount(seg, s.maxPayloadSize)
+ if !s.ep.SACKPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ s.Outstanding -= s.pCount(seg, s.MaxPayloadSize)
} else {
- s.sackedOut -= s.pCount(seg, s.maxPayloadSize)
+ s.SackedOut -= s.pCount(seg, s.MaxPayloadSize)
}
seg.decRef()
ackLeft -= datalen
@@ -1501,13 +1413,13 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
s.ep.updateSndBufferUsage(int(acked))
// Clear SACK information for all acked data.
- s.ep.scoreboard.Delete(s.sndUna)
+ s.ep.scoreboard.Delete(s.SndUna)
// If we are not in fast recovery then update the congestion
// window based on the number of acknowledged packets.
- if !s.fr.active {
- s.cc.Update(originalOutstanding - s.outstanding)
- if s.fr.last.LessThan(s.sndUna) {
+ if !s.FastRecovery.Active {
+ s.cc.Update(originalOutstanding - s.Outstanding)
+ if s.FastRecovery.Last.LessThan(s.SndUna) {
s.state = tcpip.Open
// Update RACK when we are exiting fast or RTO
// recovery as described in the RFC
@@ -1522,16 +1434,16 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// It is possible for s.outstanding to drop below zero if we get
// a retransmit timeout, reset outstanding to zero but later
// get an ack that cover previously sent data.
- if s.outstanding < 0 {
- s.outstanding = 0
+ if s.Outstanding < 0 {
+ s.Outstanding = 0
}
s.SetPipe()
// If all outstanding data was acknowledged the disable the timer.
// RFC 6298 Rule 5.3
- if s.sndUna == s.sndNxt {
- s.outstanding = 0
+ if s.SndUna == s.SndNxt {
+ s.Outstanding = 0
// Reset firstRetransmittedSegXmitTime to the zero value.
s.firstRetransmittedSegXmitTime = time.Time{}
s.resendTimer.disable()
@@ -1539,7 +1451,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
}
- if s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
+ if s.ep.SACKPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
// Update RACK reorder window.
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
// * Upon receiving an ACK:
@@ -1549,7 +1461,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// After the reorder window is calculated, detect any loss by checking
// if the time elapsed after the segments are sent is greater than the
// reorder window.
- if numLost := s.rc.detectLoss(rcvdSeg.rcvdTime); numLost > 0 && !s.fr.active {
+ if numLost := s.rc.detectLoss(rcvdSeg.rcvdTime); numLost > 0 && !s.FastRecovery.Active {
// If any segment is marked as lost by
// RACK, enter recovery and retransmit
// the lost segments.
@@ -1558,19 +1470,19 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
fastRetransmit = true
}
- if s.fr.active {
+ if s.FastRecovery.Active {
s.rc.DoRecovery(nil, fastRetransmit)
}
}
// Now that we've popped all acknowledged data from the retransmit
// queue, retransmit if needed.
- if s.fr.active && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 {
+ if s.FastRecovery.Active && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 {
s.lr.DoRecovery(rcvdSeg, fastRetransmit)
// When SACK is enabled data sending is governed by steps in
// RFC 6675 Section 5 recovery steps A-C.
// See: https://tools.ietf.org/html/rfc6675#section-5.
- if s.ep.sackPermitted {
+ if s.ep.SACKPermitted {
return
}
}
@@ -1587,7 +1499,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error {
if seg.xmitCount > 0 {
s.ep.stack.Stats().TCP.Retransmits.Increment()
s.ep.stats.SendErrors.Retransmits.Increment()
- if s.sndCwnd < s.sndSsthresh {
+ if s.SndCwnd < s.Ssthresh {
s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment()
}
}
@@ -1601,11 +1513,11 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error {
// then use the conservative timer described in RFC6675 Section 6.0,
// otherwise follow the standard time described in RFC6298 Section 5.1.
if err != nil && seg.data.Size() != 0 {
- if s.fr.active && seg.xmitCount > 1 && s.ep.sackPermitted {
- s.resendTimer.enable(s.rto)
+ if s.FastRecovery.Active && seg.xmitCount > 1 && s.ep.SACKPermitted {
+ s.resendTimer.enable(s.RTO)
} else {
if !s.resendTimer.enabled() {
- s.resendTimer.enable(s.rto)
+ s.resendTimer.enable(s.RTO)
}
}
}
@@ -1616,15 +1528,15 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error {
// sendSegmentFromView sends a new segment containing the given payload, flags
// and sequence number.
func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags header.TCPFlags, seq seqnum.Value) tcpip.Error {
- s.lastSendTime = time.Now()
- if seq == s.rttMeasureSeqNum {
- s.rttMeasureTime = s.lastSendTime
+ s.LastSendTime = time.Now()
+ if seq == s.RTTMeasureSeqNum {
+ s.RTTMeasureTime = s.LastSendTime
}
rcvNxt, rcvWnd := s.ep.rcv.getSendParams()
// Remember the max sent ack.
- s.maxSentAck = rcvNxt
+ s.MaxSentAck = rcvNxt
return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd)
}
diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go
index ba41cff6d..2f805d8ce 100644
--- a/pkg/tcpip/transport/tcp/snd_state.go
+++ b/pkg/tcpip/transport/tcp/snd_state.go
@@ -24,26 +24,6 @@ type unixTime struct {
nano int64
}
-// saveLastSendTime is invoked by stateify.
-func (s *sender) saveLastSendTime() unixTime {
- return unixTime{s.lastSendTime.Unix(), s.lastSendTime.UnixNano()}
-}
-
-// loadLastSendTime is invoked by stateify.
-func (s *sender) loadLastSendTime(unix unixTime) {
- s.lastSendTime = time.Unix(unix.second, unix.nano)
-}
-
-// saveRttMeasureTime is invoked by stateify.
-func (s *sender) saveRttMeasureTime() unixTime {
- return unixTime{s.rttMeasureTime.Unix(), s.rttMeasureTime.UnixNano()}
-}
-
-// loadRttMeasureTime is invoked by stateify.
-func (s *sender) loadRttMeasureTime(unix unixTime) {
- s.rttMeasureTime = time.Unix(unix.second, unix.nano)
-}
-
// afterLoad is invoked by stateify.
func (s *sender) afterLoad() {
s.resendTimer.init(&s.resendWaker)
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index 5cdd5b588..c58361bc1 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -33,6 +33,7 @@ const (
tsOptionSize = 12
maxTCPOptionSize = 40
mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload
+ latency = 5 * time.Millisecond
)
func setStackRACKPermitted(t *testing.T, c *context.Context) {
@@ -182,6 +183,9 @@ func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, en
for i := 0; i < numPackets; i++ {
c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
bytesRead += maxPayload
+ // This delay is added to increase RTT as low RTT can cause TLP
+ // before sending ACK.
+ time.Sleep(latency)
}
return data
@@ -479,7 +483,7 @@ func TestRACKOnePacketTailLoss(t *testing.T) {
}{
// #3 was retransmitted as TLP.
{tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0},
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 0},
+ {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
{tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0},
// RTO should not have fired.
{tcpStats.Timeouts, "stats.TCP.Timeouts", 0},
@@ -852,8 +856,8 @@ func addReorderWindowCheckerProbe(c *context.Context, numACK int, probeDone chan
return
}
- if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.SRTT {
- probeDone <- fmt.Errorf("got RACKState.ReoWnd: %v, expected it to be greater than 0 and less than %v", state.Sender.RACKState.ReoWnd, state.Sender.SRTT)
+ if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.RTTState.SRTT {
+ probeDone <- fmt.Errorf("got RACKState.ReoWnd: %d, expected it to be greater than 0 and less than %d", state.Sender.RACKState.ReoWnd, state.Sender.RTTState.SRTT)
return
}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 81f800cad..20c9761f2 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -160,12 +160,9 @@ func TestSackPermittedAccept(t *testing.T) {
defer c.Cleanup()
if tc.cookieEnabled {
- // Set the SynRcvd threshold to
- // zero to force a syn cookie
- // based accept to happen.
- var opt tcpip.TCPSynRcvdCountThresholdOption
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
setStackSACKPermitted(t, c, sackEnabled)
@@ -235,12 +232,9 @@ func TestSackDisabledAccept(t *testing.T) {
defer c.Cleanup()
if tc.cookieEnabled {
- // Set the SynRcvd threshold to
- // zero to force a syn cookie
- // based accept to happen.
- var opt tcpip.TCPSynRcvdCountThresholdOption
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 9c23469f2..9f29a48fb 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -35,6 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ tcpiptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
"gvisor.dev/gvisor/pkg/test/testutil"
@@ -929,10 +930,7 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) {
}
// Get expected window size.
- rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
- if err != nil {
- t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption): %s", err)
- }
+ rcvBufSize := c.EP.SocketOptions().GetReceiveBufferSize()
ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort}
@@ -955,11 +953,7 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) {
// when completing the handshake for a new TCP connection from a TCP
// listening socket. It should be present in the sent TCP SYN-ACK segment.
func TestUserSuppliedMSSOnListenAccept(t *testing.T) {
- const (
- nonSynCookieAccepts = 2
- totalAccepts = 4
- mtu = 5000
- )
+ const mtu = 5000
ips := []struct {
name string
@@ -1033,12 +1027,6 @@ func TestUserSuppliedMSSOnListenAccept(t *testing.T) {
ip.createEP(c)
- // Set the SynRcvd threshold to force a syn cookie based accept to happen.
- opt := tcpip.TCPSynRcvdCountThresholdOption(nonSynCookieAccepts)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
-
if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err)
}
@@ -1048,13 +1036,17 @@ func TestUserSuppliedMSSOnListenAccept(t *testing.T) {
t.Fatalf("Bind(%+v): %s:", bindAddr, err)
}
- if err := c.EP.Listen(totalAccepts); err != nil {
- t.Fatalf("Listen(%d): %s:", totalAccepts, err)
+ backlog := 5
+ // Keep the number of client requests twice to the backlog
+ // such that half of the connections do not use syncookies
+ // and the other half does.
+ clientConnects := backlog * 2
+
+ if err := c.EP.Listen(backlog); err != nil {
+ t.Fatalf("Listen(%d): %s:", backlog, err)
}
- // The first nonSynCookieAccepts packets sent will trigger a gorooutine
- // based accept. The rest will trigger a cookie based accept.
- for i := 0; i < totalAccepts; i++ {
+ for i := 0; i < clientConnects; i++ {
// Send a SYN requests.
iss := seqnum.Value(i)
srcPort := context.TestPort + uint16(i)
@@ -1297,6 +1289,98 @@ func TestListenShutdown(t *testing.T) {
))
}
+var _ waiter.EntryCallback = (callback)(nil)
+
+type callback func(*waiter.Entry, waiter.EventMask)
+
+func (cb callback) Callback(entry *waiter.Entry, mask waiter.EventMask) {
+ cb(entry, mask)
+}
+
+func TestListenerReadinessOnEvent(t *testing.T) {
+ s := stack.New(stack.Options{
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ })
+ {
+ ep := loopback.New()
+ if testing.Verbose() {
+ ep = sniffer.New(ep)
+ }
+ const id = 1
+ if err := s.CreateNIC(id, ep); err != nil {
+ t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err)
+ }
+ if err := s.AddAddress(id, ipv4.ProtocolNumber, context.StackAddr); err != nil {
+ t.Fatalf("AddAddress(%d, ipv4.ProtocolNumber, %s): %s", id, context.StackAddr, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: header.IPv4EmptySubnet, NIC: id},
+ })
+ }
+
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err)
+ }
+ defer ep.Close()
+
+ if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr}); err != nil {
+ t.Fatalf("Bind(%s): %s", context.StackAddr, err)
+ }
+ const backlog = 1
+ if err := ep.Listen(backlog); err != nil {
+ t.Fatalf("Listen(%d): %s", backlog, err)
+ }
+
+ address, err := ep.GetLocalAddress()
+ if err != nil {
+ t.Fatalf("GetLocalAddress(): %s", err)
+ }
+
+ conn, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err)
+ }
+ defer conn.Close()
+
+ events := make(chan waiter.EventMask)
+ // Scope `entry` to allow a binding of the same name below.
+ {
+ entry := waiter.Entry{Callback: callback(func(_ *waiter.Entry, mask waiter.EventMask) {
+ events <- ep.Readiness(mask)
+ })}
+ wq.EventRegister(&entry, waiter.EventIn)
+ defer wq.EventUnregister(&entry)
+ }
+
+ entry, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&entry, waiter.EventOut)
+ defer wq.EventUnregister(&entry)
+
+ switch err := conn.Connect(address).(type) {
+ case *tcpip.ErrConnectStarted:
+ default:
+ t.Fatalf("Connect(%#v): %v", address, err)
+ }
+
+ // Read at least one event.
+ got := <-events
+ for {
+ select {
+ case event := <-events:
+ got |= event
+ continue
+ case <-ch:
+ if want := waiter.ReadableEvents; got != want {
+ t.Errorf("observed events = %b, want %b", got, want)
+ }
+ }
+ break
+ }
+}
+
// TestListenCloseWhileConnect tests for the listening endpoint to
// drain the accept-queue when closed. This should reset all of the
// pending connections that are waiting to be accepted.
@@ -1993,9 +2077,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
// Bump up the receive buffer size such that, when the receive window grows,
// the scaled window exceeds maxUint16.
- if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, opt.Max); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed: %s", opt.Max, err)
- }
+ c.EP.SocketOptions().SetReceiveBufferSize(int64(opt.Max), true)
// Keep the payload size < segment overhead and such that it is a multiple
// of the window scaled value. This enables the test to perform equality
@@ -2115,9 +2197,7 @@ func TestNoWindowShrinking(t *testing.T) {
initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
initialLastAcceptableSeq := iss.Add(seqnum.Size(initialWnd))
// Now shrink the receive buffer to half its original size.
- if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufSize/2); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err)
- }
+ c.EP.SocketOptions().SetReceiveBufferSize(int64(rcvBufSize/2), true)
data := generateRandomPayload(t, rcvBufSize)
// Send a payload of half the size of rcvBufSize.
@@ -2373,9 +2453,7 @@ func TestScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err)
- }
+ ep.SocketOptions().SetReceiveBufferSize(65535*3, true)
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
@@ -2447,9 +2525,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err)
- }
+ ep.SocketOptions().SetReceiveBufferSize(65535*3, true)
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
@@ -3042,9 +3118,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
// Set the buffer size to a deterministic size so that we can check the
// window scaling option.
const rcvBufferSize = 0x20000
- if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err)
- }
+ ep.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true)
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
@@ -3087,11 +3161,9 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
c := context.New(t, mtu)
defer c.Cleanup()
- // Set the SynRcvd threshold to zero to force a syn cookie based accept
- // to happen.
- opt := tcpip.TCPSynRcvdCountThresholdOption(0)
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
}
// Create EP and start listening.
@@ -3185,9 +3257,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
// window scaling option.
const rcvBufferSize = 0x20000
const wndScale = 3
- if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err)
- }
+ c.EP.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true)
// Start connection attempt.
we, ch := waiter.NewChannelEntry(nil)
@@ -4411,11 +4481,7 @@ func TestReusePort(t *testing.T) {
func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
t.Helper()
- s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
- if err != nil {
- t.Fatalf("GetSockOpt failed: %s", err)
- }
-
+ s := ep.SocketOptions().GetReceiveBufferSize()
if int(s) != v {
t.Fatalf("got receive buffer size = %d, want = %d", s, v)
}
@@ -4521,10 +4587,7 @@ func TestMinMaxBufferSizes(t *testing.T) {
}
// Set values below the min/2.
- if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 99); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err)
- }
-
+ ep.SocketOptions().SetReceiveBufferSize(99, true)
checkRecvBufferSize(t, ep, 200)
ep.SocketOptions().SetSendBufferSize(149, true)
@@ -4532,15 +4595,11 @@ func TestMinMaxBufferSizes(t *testing.T) {
checkSendBufferSize(t, ep, 300)
// Set values above the max.
- if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
- }
-
+ ep.SocketOptions().SetReceiveBufferSize(1+tcp.DefaultReceiveBufferSize*20, true)
// Values above max are capped at max and then doubled.
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2)
ep.SocketOptions().SetSendBufferSize(1+tcp.DefaultSendBufferSize*30, true)
-
// Values above max are capped at max and then doubled.
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2)
}
@@ -4814,7 +4873,13 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
t.Fatalf("unknown address type: '%s'", candidateAddressType)
}
- start, end := s.PortRange()
+ const (
+ start = 16000
+ end = 16050
+ )
+ if err := s.SetPortRange(start, end); err != nil {
+ t.Fatalf("got s.SetPortRange(%d, %d) = %s, want = nil", start, end, err)
+ }
for i := start; i <= end; i++ {
if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
t.Fatalf("Bind(%d) failed: %s", i, err)
@@ -5363,7 +5428,7 @@ func TestListenBacklogFull(t *testing.T) {
}
lastPortOffset := uint16(0)
- for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ {
+ for ; int(lastPortOffset) < listenBacklog+1; lastPortOffset++ {
executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */)
}
@@ -5445,8 +5510,8 @@ func TestListenBacklogFull(t *testing.T) {
// TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a
// non unicast IPv4 address are not accepted.
func TestListenNoAcceptNonUnicastV4(t *testing.T) {
- multicastAddr := tcpip.Address("\xe0\x00\x01\x02")
- otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03")
+ multicastAddr := tcpiptestutil.MustParse4("224.0.1.2")
+ otherMulticastAddr := tcpiptestutil.MustParse4("224.0.1.3")
subnet := context.StackAddrWithPrefix.Subnet()
subnetBroadcastAddr := subnet.Broadcast()
@@ -5557,8 +5622,8 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) {
// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a
// non unicast IPv6 address are not accepted.
func TestListenNoAcceptNonUnicastV6(t *testing.T) {
- multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01")
- otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02")
+ multicastAddr := tcpiptestutil.MustParse6("ff0e::101")
+ otherMulticastAddr := tcpiptestutil.MustParse6("ff0e::102")
tests := []struct {
name string
@@ -5671,15 +5736,13 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
}
// Test acceptance.
- // Start listening.
- listenBacklog := 1
- if err := c.EP.Listen(listenBacklog); err != nil {
+ if err := c.EP.Listen(0); err != nil {
t.Fatalf("Listen failed: %s", err)
}
// Send two SYN's the first one should get a SYN-ACK, the
// second one should not get any response and is dropped as
- // the synRcvd count will be equal to backlog.
+ // the accept queue is full.
irs := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
@@ -5701,23 +5764,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
}
checker.IPv4(t, b, checker.TCP(tcpCheckers...))
- // Now execute send one more SYN. The stack should not respond as the backlog
- // is full at this point.
- //
- // NOTE: we did not complete the handshake for the previous one so the
- // accept backlog should be empty and there should be one connection in
- // synRcvd state.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + 1,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: seqnum.Value(889),
- RcvWnd: 30000,
- })
- c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
-
- // Now complete the previous connection and verify that there is a connection
- // to accept.
+ // Now complete the previous connection.
// Send ACK.
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
@@ -5728,11 +5775,24 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
RcvWnd: 30000,
})
- // Try to accept the connections in the backlog.
+ // Verify if that is delivered to the accept queue.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.ReadableEvents)
defer c.WQ.EventUnregister(&we)
+ <-ch
+
+ // Now execute send one more SYN. The stack should not respond as the backlog
+ // is full at this point.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort + 1,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: seqnum.Value(889),
+ RcvWnd: 30000,
+ })
+ c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
+ // Try to accept the connections in the backlog.
newEP, _, err := c.EP.Accept(nil)
if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
@@ -5764,11 +5824,6 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- opt := tcpip.TCPSynRcvdCountThresholdOption(1)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
-
// Create TCP endpoint.
var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
@@ -5781,9 +5836,8 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
t.Fatalf("Bind failed: %s", err)
}
- // Start listening.
- listenBacklog := 1
- if err := c.EP.Listen(listenBacklog); err != nil {
+ // Test for SynCookies usage after filling up the backlog.
+ if err := c.EP.Listen(0); err != nil {
t.Fatalf("Listen failed: %s", err)
}
@@ -6066,7 +6120,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
- if err := c.EP.Listen(1); err != nil {
+ if err := c.EP.Listen(0); err != nil {
t.Fatalf("Listen failed: %s", err)
}
@@ -7553,8 +7607,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
// Increasing the buffer from should generate an ACK,
// since window grew from small value to larger equal MSS
- c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2)
-
+ c.EP.SocketOptions().SetReceiveBufferSize(rcvBuf*2, true)
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 2949588ce..1deb1fe4d 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -139,9 +139,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
defer c.Cleanup()
if cookieEnabled {
- var opt tcpip.TCPSynRcvdCountThresholdOption
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
@@ -202,9 +202,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
defer c.Cleanup()
if cookieEnabled {
- var opt tcpip.TCPSynRcvdCountThresholdOption
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index e73f90bb0..7578d64ec 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -757,9 +757,7 @@ func (c *Context) Create(epRcvBuf int) {
}
if epRcvBuf != -1 {
- if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
- }
+ c.EP.SocketOptions().SetReceiveBufferSize(int64(epRcvBuf), true /* notify */)
}
}
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 153e8c950..dd5c910ae 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -56,6 +56,7 @@ go_test(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 956da0e0c..c9f2f3efc 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,7 +15,6 @@
package udp
import (
- "fmt"
"io"
"sync/atomic"
@@ -89,12 +88,11 @@ type endpoint struct {
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
- rcvMu sync.Mutex `state:"nosave"`
- rcvReady bool
- rcvList udpPacketList
- rcvBufSizeMax int `state:".(int)"`
- rcvBufSize int
- rcvClosed bool
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvReady bool
+ rcvList udpPacketList
+ rcvBufSize int
+ rcvClosed bool
// The following fields are protected by the mu mutex.
mu sync.RWMutex `state:"nosave"`
@@ -144,6 +142,10 @@ type endpoint struct {
// ops is used to get socket level options.
ops tcpip.SocketOptions
+
+ // frozen indicates if the packets should be delivered to the endpoint
+ // during restore.
+ frozen bool
}
// +stateify savable
@@ -173,14 +175,14 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
//
// Linux defaults to TTL=1.
multicastTTL: 1,
- rcvBufSizeMax: 32 * 1024,
multicastMemberships: make(map[multicastMembership]struct{}),
state: StateInitial,
uniqueID: s.UniqueID(),
}
- e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
e.ops.SetMulticastLoop(true)
e.ops.SetSendBufferSize(32*1024, false /* notify */)
+ e.ops.SetReceiveBufferSize(32*1024, false /* notify */)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -188,9 +190,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
}
- var rs stack.ReceiveBufferSizeOption
+ var rs tcpip.ReceiveBufferSizeOption
if err := s.Option(&rs); err == nil {
- e.rcvBufSizeMax = rs.Default
+ e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
}
return e
@@ -622,26 +624,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
e.mu.Lock()
e.sendTOS = uint8(v)
e.mu.Unlock()
-
- case tcpip.ReceiveBufferSizeOption:
- // Make sure the receive buffer size is within the min and max
- // allowed.
- var rs stack.ReceiveBufferSizeOption
- if err := e.stack.Option(&rs); err != nil {
- panic(fmt.Sprintf("e.stack.Option(%#v) = %s", rs, err))
- }
-
- if v < rs.Min {
- v = rs.Min
- }
- if v > rs.Max {
- v = rs.Max
- }
-
- e.mu.Lock()
- e.rcvBufSizeMax = v
- e.mu.Unlock()
- return nil
}
return nil
@@ -802,12 +784,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- v := e.rcvBufSizeMax
- e.rcvMu.Unlock()
- return v, nil
-
case tcpip.TTLOption:
e.mu.Lock()
v := int(e.ttl)
@@ -1255,20 +1231,29 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
}
// verifyChecksum verifies the checksum unless RX checksum offload is enabled.
-// On IPv4, UDP checksum is optional, and a zero value means the transmitter
-// omitted the checksum generation (RFC768).
-// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
- if !pkt.RXTransportChecksumValidated &&
- (hdr.Checksum() != 0 || pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber) {
- netHdr := pkt.Network()
- xsum := header.PseudoHeaderChecksum(ProtocolNumber, netHdr.DestinationAddress(), netHdr.SourceAddress(), hdr.Length())
- for _, v := range pkt.Data().Views() {
- xsum = header.Checksum(v, xsum)
- }
- return hdr.CalculateChecksum(xsum) == 0xffff
+ if pkt.RXTransportChecksumValidated {
+ return true
+ }
+
+ // On IPv4, UDP checksum is optional, and a zero value means the transmitter
+ // omitted the checksum generation, as per RFC 768:
+ //
+ // An all zero transmitted checksum value means that the transmitter
+ // generated no checksum (for debugging or for higher level protocols that
+ // don't care).
+ //
+ // On IPv6, UDP checksum is not optional, as per RFC 2460 Section 8.1:
+ //
+ // Unlike IPv4, when UDP packets are originated by an IPv6 node, the UDP
+ // checksum is not optional.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber && hdr.Checksum() == 0 {
+ return true
}
- return true
+
+ netHdr := pkt.Network()
+ payloadChecksum := pkt.Data().AsRange().Checksum()
+ return hdr.IsChecksumValid(netHdr.SourceAddress(), netHdr.DestinationAddress(), payloadChecksum)
}
// HandlePacket is called by the stack when new packets arrive to this transport
@@ -1284,7 +1269,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
}
if !verifyChecksum(hdr, pkt) {
- // Checksum Error.
e.stack.Stats().UDP.ChecksumErrors.Increment()
e.stats.ReceiveErrors.ChecksumErrors.Increment()
return
@@ -1302,7 +1286,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
return
}
- if e.rcvBufSize >= e.rcvBufSizeMax {
+ rcvBufSize := e.ops.GetReceiveBufferSize()
+ if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
e.rcvMu.Unlock()
e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
@@ -1436,3 +1421,18 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}
+
+// freeze prevents any more packets from being delivered to the endpoint.
+func (e *endpoint) freeze() {
+ e.mu.Lock()
+ e.frozen = true
+ e.mu.Unlock()
+}
+
+// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows
+// new packets to be delivered again.
+func (e *endpoint) thaw() {
+ e.mu.Lock()
+ e.frozen = false
+ e.mu.Unlock()
+}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 21a6aa460..4aba68b21 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -37,43 +37,25 @@ func (u *udpPacket) loadData(data buffer.VectorisedView) {
u.data = data
}
-// beforeSave is invoked by stateify.
-func (e *endpoint) beforeSave() {
- // Stop incoming packets from being handled (and mutate endpoint state).
- // The lock will be released after savercvBufSizeMax(), which would have
- // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming
- // packets.
- e.rcvMu.Lock()
-}
-
-// saveRcvBufSizeMax is invoked by stateify.
-func (e *endpoint) saveRcvBufSizeMax() int {
- max := e.rcvBufSizeMax
- // Make sure no new packets will be handled regardless of the lock.
- e.rcvBufSizeMax = 0
- // Release the lock acquired in beforeSave() so regular endpoint closing
- // logic can proceed after save.
- e.rcvMu.Unlock()
- return max
-}
-
-// loadRcvBufSizeMax is invoked by stateify.
-func (e *endpoint) loadRcvBufSizeMax(max int) {
- e.rcvBufSizeMax = max
-}
-
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+ e.freeze()
+}
+
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
+ e.thaw()
+
e.mu.Lock()
defer e.mu.Unlock()
e.stack = s
- e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
for m := range e.multicastMemberships {
if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil {
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 77ca70a04..dc2e3f493 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -34,6 +34,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -2364,7 +2365,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
}
ipv4Subnet := ipv4Addr.Subnet()
ipv4SubnetBcast := ipv4Subnet.Broadcast()
- ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01")
+ ipv4Gateway := testutil.MustParse4("192.168.1.1")
ipv4AddrPrefix31 := tcpip.AddressWithPrefix{
Address: "\xc0\xa8\x01\x3a",
PrefixLen: 31,
diff --git a/pkg/test/dockerutil/BUILD b/pkg/test/dockerutil/BUILD
index 7f983a0b3..366f068e3 100644
--- a/pkg/test/dockerutil/BUILD
+++ b/pkg/test/dockerutil/BUILD
@@ -36,8 +36,8 @@ go_test(
tags = [
# Requires docker and runsc to be configured before test runs.
# Also requires the test to be run as root.
- "manual",
"local",
+ "manual",
],
visibility = ["//:sandbox"],
)
diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go
index 41fcf4978..06152a444 100644
--- a/pkg/test/dockerutil/container.go
+++ b/pkg/test/dockerutil/container.go
@@ -434,7 +434,14 @@ func (c *Container) Wait(ctx context.Context) error {
select {
case err := <-errChan:
return err
- case <-statusChan:
+ case res := <-statusChan:
+ if res.StatusCode != 0 {
+ var msg string
+ if res.Error != nil {
+ msg = res.Error.Message
+ }
+ return fmt.Errorf("container returned non-zero status: %d, msg: %q", res.StatusCode, msg)
+ }
return nil
}
}
diff --git a/pkg/usermem/BUILD b/pkg/usermem/BUILD
index 054269b59..3dba36f12 100644
--- a/pkg/usermem/BUILD
+++ b/pkg/usermem/BUILD
@@ -1,42 +1,22 @@
load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
-go_template_instance(
- name = "addr_range",
- out = "addr_range.go",
- package = "usermem",
- prefix = "Addr",
- template = "//pkg/segment:generic_range",
- types = {
- "T": "Addr",
- },
-)
-
go_library(
name = "usermem",
srcs = [
- "access_type.go",
- "addr.go",
- "addr_range.go",
- "addr_range_seq_unsafe.go",
"bytes_io.go",
"bytes_io_unsafe.go",
"usermem.go",
- "usermem_arm64.go",
- "usermem_x86.go",
],
visibility = ["//:sandbox"],
deps = [
"//pkg/atomicbitops",
- "//pkg/binary",
"//pkg/context",
"//pkg/gohacks",
- "//pkg/log",
+ "//pkg/hostarch",
"//pkg/safemem",
"//pkg/syserror",
- "@org_golang_x_sys//unix:go_default_library",
],
)
@@ -44,12 +24,12 @@ go_test(
name = "usermem_test",
size = "small",
srcs = [
- "addr_range_seq_test.go",
"usermem_test.go",
],
library = ":usermem",
deps = [
"//pkg/context",
+ "//pkg/hostarch",
"//pkg/safemem",
"//pkg/syserror",
],
diff --git a/pkg/usermem/bytes_io.go b/pkg/usermem/bytes_io.go
index e177d30eb..3da3c0294 100644
--- a/pkg/usermem/bytes_io.go
+++ b/pkg/usermem/bytes_io.go
@@ -16,6 +16,7 @@ package usermem
import (
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -30,7 +31,7 @@ type BytesIO struct {
}
// CopyOut implements IO.CopyOut.
-func (b *BytesIO) CopyOut(ctx context.Context, addr Addr, src []byte, opts IOOpts) (int, error) {
+func (b *BytesIO) CopyOut(ctx context.Context, addr hostarch.Addr, src []byte, opts IOOpts) (int, error) {
rngN, rngErr := b.rangeCheck(addr, len(src))
if rngN == 0 {
return 0, rngErr
@@ -39,7 +40,7 @@ func (b *BytesIO) CopyOut(ctx context.Context, addr Addr, src []byte, opts IOOpt
}
// CopyIn implements IO.CopyIn.
-func (b *BytesIO) CopyIn(ctx context.Context, addr Addr, dst []byte, opts IOOpts) (int, error) {
+func (b *BytesIO) CopyIn(ctx context.Context, addr hostarch.Addr, dst []byte, opts IOOpts) (int, error) {
rngN, rngErr := b.rangeCheck(addr, len(dst))
if rngN == 0 {
return 0, rngErr
@@ -48,7 +49,7 @@ func (b *BytesIO) CopyIn(ctx context.Context, addr Addr, dst []byte, opts IOOpts
}
// ZeroOut implements IO.ZeroOut.
-func (b *BytesIO) ZeroOut(ctx context.Context, addr Addr, toZero int64, opts IOOpts) (int64, error) {
+func (b *BytesIO) ZeroOut(ctx context.Context, addr hostarch.Addr, toZero int64, opts IOOpts) (int64, error) {
if toZero > int64(maxInt) {
return 0, syserror.EINVAL
}
@@ -64,7 +65,7 @@ func (b *BytesIO) ZeroOut(ctx context.Context, addr Addr, toZero int64, opts IOO
}
// CopyOutFrom implements IO.CopyOutFrom.
-func (b *BytesIO) CopyOutFrom(ctx context.Context, ars AddrRangeSeq, src safemem.Reader, opts IOOpts) (int64, error) {
+func (b *BytesIO) CopyOutFrom(ctx context.Context, ars hostarch.AddrRangeSeq, src safemem.Reader, opts IOOpts) (int64, error) {
dsts, rngErr := b.blocksFromAddrRanges(ars)
n, err := src.ReadToBlocks(dsts)
if err != nil {
@@ -74,7 +75,7 @@ func (b *BytesIO) CopyOutFrom(ctx context.Context, ars AddrRangeSeq, src safemem
}
// CopyInTo implements IO.CopyInTo.
-func (b *BytesIO) CopyInTo(ctx context.Context, ars AddrRangeSeq, dst safemem.Writer, opts IOOpts) (int64, error) {
+func (b *BytesIO) CopyInTo(ctx context.Context, ars hostarch.AddrRangeSeq, dst safemem.Writer, opts IOOpts) (int64, error) {
srcs, rngErr := b.blocksFromAddrRanges(ars)
n, err := dst.WriteFromBlocks(srcs)
if err != nil {
@@ -83,14 +84,14 @@ func (b *BytesIO) CopyInTo(ctx context.Context, ars AddrRangeSeq, dst safemem.Wr
return int64(n), rngErr
}
-func (b *BytesIO) rangeCheck(addr Addr, length int) (int, error) {
+func (b *BytesIO) rangeCheck(addr hostarch.Addr, length int) (int, error) {
if length == 0 {
return 0, nil
}
if length < 0 {
return 0, syserror.EINVAL
}
- max := Addr(len(b.Bytes))
+ max := hostarch.Addr(len(b.Bytes))
if addr >= max {
return 0, syserror.EFAULT
}
@@ -101,7 +102,7 @@ func (b *BytesIO) rangeCheck(addr Addr, length int) (int, error) {
return length, nil
}
-func (b *BytesIO) blocksFromAddrRanges(ars AddrRangeSeq) (safemem.BlockSeq, error) {
+func (b *BytesIO) blocksFromAddrRanges(ars hostarch.AddrRangeSeq) (safemem.BlockSeq, error) {
switch ars.NumRanges() {
case 0:
return safemem.BlockSeq{}, nil
@@ -124,7 +125,7 @@ func (b *BytesIO) blocksFromAddrRanges(ars AddrRangeSeq) (safemem.BlockSeq, erro
}
}
-func (b *BytesIO) blockFromAddrRange(ar AddrRange) (safemem.Block, error) {
+func (b *BytesIO) blockFromAddrRange(ar hostarch.AddrRange) (safemem.Block, error) {
n, err := b.rangeCheck(ar.Start, int(ar.Length()))
if n == 0 {
return safemem.Block{}, err
@@ -136,6 +137,6 @@ func (b *BytesIO) blockFromAddrRange(ar AddrRange) (safemem.Block, error) {
func BytesIOSequence(buf []byte) IOSequence {
return IOSequence{
IO: &BytesIO{buf},
- Addrs: AddrRangeSeqOf(AddrRange{0, Addr(len(buf))}),
+ Addrs: hostarch.AddrRangeSeqOf(hostarch.AddrRange{0, hostarch.Addr(len(buf))}),
}
}
diff --git a/pkg/usermem/bytes_io_unsafe.go b/pkg/usermem/bytes_io_unsafe.go
index 20de5037d..dcd5c81d1 100644
--- a/pkg/usermem/bytes_io_unsafe.go
+++ b/pkg/usermem/bytes_io_unsafe.go
@@ -20,10 +20,11 @@ import (
"gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// SwapUint32 implements IO.SwapUint32.
-func (b *BytesIO) SwapUint32(ctx context.Context, addr Addr, new uint32, opts IOOpts) (uint32, error) {
+func (b *BytesIO) SwapUint32(ctx context.Context, addr hostarch.Addr, new uint32, opts IOOpts) (uint32, error) {
if _, rngErr := b.rangeCheck(addr, 4); rngErr != nil {
return 0, rngErr
}
@@ -31,7 +32,7 @@ func (b *BytesIO) SwapUint32(ctx context.Context, addr Addr, new uint32, opts IO
}
// CompareAndSwapUint32 implements IO.CompareAndSwapUint32.
-func (b *BytesIO) CompareAndSwapUint32(ctx context.Context, addr Addr, old, new uint32, opts IOOpts) (uint32, error) {
+func (b *BytesIO) CompareAndSwapUint32(ctx context.Context, addr hostarch.Addr, old, new uint32, opts IOOpts) (uint32, error) {
if _, rngErr := b.rangeCheck(addr, 4); rngErr != nil {
return 0, rngErr
}
@@ -39,7 +40,7 @@ func (b *BytesIO) CompareAndSwapUint32(ctx context.Context, addr Addr, old, new
}
// LoadUint32 implements IO.LoadUint32.
-func (b *BytesIO) LoadUint32(ctx context.Context, addr Addr, opts IOOpts) (uint32, error) {
+func (b *BytesIO) LoadUint32(ctx context.Context, addr hostarch.Addr, opts IOOpts) (uint32, error) {
if _, err := b.rangeCheck(addr, 4); err != nil {
return 0, err
}
diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go
index dc2571154..0d6d25e50 100644
--- a/pkg/usermem/usermem.go
+++ b/pkg/usermem/usermem.go
@@ -25,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/syserror"
+
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// IO provides access to the contents of a virtual memory space.
@@ -37,7 +39,7 @@ type IO interface {
// any following locks in the lock order.
//
// Postconditions: CopyOut does not retain src.
- CopyOut(ctx context.Context, addr Addr, src []byte, opts IOOpts) (int, error)
+ CopyOut(ctx context.Context, addr hostarch.Addr, src []byte, opts IOOpts) (int, error)
// CopyIn copies len(dst) bytes from the memory mapped at addr to dst.
// It returns the number of bytes copied. If the number of bytes copied is
@@ -47,7 +49,7 @@ type IO interface {
// any following locks in the lock order.
//
// Postconditions: CopyIn does not retain dst.
- CopyIn(ctx context.Context, addr Addr, dst []byte, opts IOOpts) (int, error)
+ CopyIn(ctx context.Context, addr hostarch.Addr, dst []byte, opts IOOpts) (int, error)
// ZeroOut sets toZero bytes to 0, starting at addr. It returns the number
// of bytes zeroed. If the number of bytes zeroed is < toZero, it returns a
@@ -57,7 +59,7 @@ type IO interface {
// * The caller must not hold mm.MemoryManager.mappingMu or any
// following locks in the lock order.
// * toZero >= 0.
- ZeroOut(ctx context.Context, addr Addr, toZero int64, opts IOOpts) (int64, error)
+ ZeroOut(ctx context.Context, addr hostarch.Addr, toZero int64, opts IOOpts) (int64, error)
// CopyOutFrom copies ars.NumBytes() bytes from src to the memory mapped at
// ars. It returns the number of bytes copied, which may be less than the
@@ -72,7 +74,7 @@ type IO interface {
// following locks in the lock order.
// * src.ReadToBlocks must not block on mm.MemoryManager.activeMu or
// any preceding locks in the lock order.
- CopyOutFrom(ctx context.Context, ars AddrRangeSeq, src safemem.Reader, opts IOOpts) (int64, error)
+ CopyOutFrom(ctx context.Context, ars hostarch.AddrRangeSeq, src safemem.Reader, opts IOOpts) (int64, error)
// CopyInTo copies ars.NumBytes() bytes from the memory mapped at ars to
// dst. It returns the number of bytes copied. CopyInTo may return a
@@ -86,7 +88,7 @@ type IO interface {
// following locks in the lock order.
// * dst.WriteFromBlocks must not block on mm.MemoryManager.activeMu or
// any preceding locks in the lock order.
- CopyInTo(ctx context.Context, ars AddrRangeSeq, dst safemem.Writer, opts IOOpts) (int64, error)
+ CopyInTo(ctx context.Context, ars hostarch.AddrRangeSeq, dst safemem.Writer, opts IOOpts) (int64, error)
// TODO(jamieliu): The requirement that CopyOutFrom/CopyInTo call src/dst
// at most once, which is unnecessary in most cases, forces implementations
@@ -101,7 +103,7 @@ type IO interface {
// * The caller must not hold mm.MemoryManager.mappingMu or any
// following locks in the lock order.
// * addr must be aligned to a 4-byte boundary.
- SwapUint32(ctx context.Context, addr Addr, new uint32, opts IOOpts) (uint32, error)
+ SwapUint32(ctx context.Context, addr hostarch.Addr, new uint32, opts IOOpts) (uint32, error)
// CompareAndSwapUint32 atomically compares the uint32 value at addr to
// old; if they are equal, the value in memory is replaced by new. In
@@ -111,7 +113,7 @@ type IO interface {
// * The caller must not hold mm.MemoryManager.mappingMu or any
// following locks in the lock order.
// * addr must be aligned to a 4-byte boundary.
- CompareAndSwapUint32(ctx context.Context, addr Addr, old, new uint32, opts IOOpts) (uint32, error)
+ CompareAndSwapUint32(ctx context.Context, addr hostarch.Addr, old, new uint32, opts IOOpts) (uint32, error)
// LoadUint32 atomically loads the uint32 value at addr and returns it.
//
@@ -119,7 +121,7 @@ type IO interface {
// * The caller must not hold mm.MemoryManager.mappingMu or any
// following locks in the lock order.
// * addr must be aligned to a 4-byte boundary.
- LoadUint32(ctx context.Context, addr Addr, opts IOOpts) (uint32, error)
+ LoadUint32(ctx context.Context, addr hostarch.Addr, opts IOOpts) (uint32, error)
}
// IOOpts contains options applicable to all IO methods.
@@ -142,7 +144,7 @@ type IOOpts struct {
type IOReadWriter struct {
Ctx context.Context
IO IO
- Addr Addr
+ Addr hostarch.Addr
Opts IOOpts
}
@@ -159,7 +161,7 @@ func (rw *IOReadWriter) Read(dst []byte) (int, error) {
rw.Addr = end
} else {
// Disallow wraparound.
- rw.Addr = ^Addr(0)
+ rw.Addr = ^hostarch.Addr(0)
if err != nil {
err = syserror.EFAULT
}
@@ -175,7 +177,7 @@ func (rw *IOReadWriter) Write(src []byte) (int, error) {
rw.Addr = end
} else {
// Disallow wraparound.
- rw.Addr = ^Addr(0)
+ rw.Addr = ^hostarch.Addr(0)
if err != nil {
err = syserror.EFAULT
}
@@ -197,7 +199,7 @@ const (
//
// Preconditions: Same as IO.CopyFromUser, plus:
// * maxlen >= 0.
-func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpts) (string, error) {
+func CopyStringIn(ctx context.Context, uio IO, addr hostarch.Addr, maxlen int, opts IOOpts) (string, error) {
initLen := maxlen
if initLen > copyStringMaxInitBufLen {
initLen = copyStringMaxInitBufLen
@@ -251,12 +253,12 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt
// the maximum, it returns a non-nil error explaining why.
//
// Preconditions: Same as IO.CopyOut.
-func CopyOutVec(ctx context.Context, uio IO, ars AddrRangeSeq, src []byte, opts IOOpts) (int, error) {
+func CopyOutVec(ctx context.Context, uio IO, ars hostarch.AddrRangeSeq, src []byte, opts IOOpts) (int, error) {
var done int
for !ars.IsEmpty() && done < len(src) {
ar := ars.Head()
cplen := len(src) - done
- if Addr(cplen) >= ar.Length() {
+ if hostarch.Addr(cplen) >= ar.Length() {
cplen = int(ar.Length())
}
n, err := uio.CopyOut(ctx, ar.Start, src[done:done+cplen], opts)
@@ -275,12 +277,12 @@ func CopyOutVec(ctx context.Context, uio IO, ars AddrRangeSeq, src []byte, opts
// maximum, it returns a non-nil error explaining why.
//
// Preconditions: Same as IO.CopyIn.
-func CopyInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dst []byte, opts IOOpts) (int, error) {
+func CopyInVec(ctx context.Context, uio IO, ars hostarch.AddrRangeSeq, dst []byte, opts IOOpts) (int, error) {
var done int
for !ars.IsEmpty() && done < len(dst) {
ar := ars.Head()
cplen := len(dst) - done
- if Addr(cplen) >= ar.Length() {
+ if hostarch.Addr(cplen) >= ar.Length() {
cplen = int(ar.Length())
}
n, err := uio.CopyIn(ctx, ar.Start, dst[done:done+cplen], opts)
@@ -299,12 +301,12 @@ func CopyInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dst []byte, opts I
// maximum, it returns a non-nil error explaining why.
//
// Preconditions: Same as IO.ZeroOut.
-func ZeroOutVec(ctx context.Context, uio IO, ars AddrRangeSeq, toZero int64, opts IOOpts) (int64, error) {
+func ZeroOutVec(ctx context.Context, uio IO, ars hostarch.AddrRangeSeq, toZero int64, opts IOOpts) (int64, error) {
var done int64
for !ars.IsEmpty() && done < toZero {
ar := ars.Head()
cplen := toZero - done
- if Addr(cplen) >= ar.Length() {
+ if hostarch.Addr(cplen) >= ar.Length() {
cplen = int64(ar.Length())
}
n, err := uio.ZeroOut(ctx, ar.Start, cplen, opts)
@@ -352,7 +354,7 @@ func isASCIIWhitespace(b byte) bool {
// - CopyInt32StringsInVec returns EINVAL if ars.NumBytes() == 0.
//
// Preconditions: Same as CopyInVec.
-func CopyInt32StringsInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dsts []int32, opts IOOpts) (int64, error) {
+func CopyInt32StringsInVec(ctx context.Context, uio IO, ars hostarch.AddrRangeSeq, dsts []int32, opts IOOpts) (int64, error) {
if len(dsts) == 0 {
return 0, nil
}
@@ -403,7 +405,7 @@ func CopyInt32StringsInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dsts [
// CopyInt32StringInVec is equivalent to CopyInt32StringsInVec, but copies at
// most one int32.
-func CopyInt32StringInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dst *int32, opts IOOpts) (int64, error) {
+func CopyInt32StringInVec(ctx context.Context, uio IO, ars hostarch.AddrRangeSeq, dst *int32, opts IOOpts) (int64, error) {
dsts := [1]int32{*dst}
n, err := CopyInt32StringsInVec(ctx, uio, ars, dsts[:], opts)
*dst = dsts[0]
@@ -413,7 +415,7 @@ func CopyInt32StringInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dst *in
// IOSequence holds arguments to IO methods.
type IOSequence struct {
IO IO
- Addrs AddrRangeSeq
+ Addrs hostarch.AddrRangeSeq
Opts IOOpts
}
@@ -444,28 +446,28 @@ func (s IOSequence) NumBytes() int64 {
// DropFirst returns a copy of s with s.Addrs.DropFirst(n).
//
-// Preconditions: Same as AddrRangeSeq.DropFirst.
+// Preconditions: Same as hostarch.AddrRangeSeq.DropFirst.
func (s IOSequence) DropFirst(n int) IOSequence {
return IOSequence{s.IO, s.Addrs.DropFirst(n), s.Opts}
}
// DropFirst64 returns a copy of s with s.Addrs.DropFirst64(n).
//
-// Preconditions: Same as AddrRangeSeq.DropFirst64.
+// Preconditions: Same as hostarch.AddrRangeSeq.DropFirst64.
func (s IOSequence) DropFirst64(n int64) IOSequence {
return IOSequence{s.IO, s.Addrs.DropFirst64(n), s.Opts}
}
// TakeFirst returns a copy of s with s.Addrs.TakeFirst(n).
//
-// Preconditions: Same as AddrRangeSeq.TakeFirst.
+// Preconditions: Same as hostarch.AddrRangeSeq.TakeFirst.
func (s IOSequence) TakeFirst(n int) IOSequence {
return IOSequence{s.IO, s.Addrs.TakeFirst(n), s.Opts}
}
// TakeFirst64 returns a copy of s with s.Addrs.TakeFirst64(n).
//
-// Preconditions: Same as AddrRangeSeq.TakeFirst64.
+// Preconditions: Same as hostarch.AddrRangeSeq.TakeFirst64.
func (s IOSequence) TakeFirst64(n int64) IOSequence {
return IOSequence{s.IO, s.Addrs.TakeFirst64(n), s.Opts}
}
diff --git a/pkg/usermem/usermem_test.go b/pkg/usermem/usermem_test.go
index da60b0cc7..9b697b593 100644
--- a/pkg/usermem/usermem_test.go
+++ b/pkg/usermem/usermem_test.go
@@ -22,6 +22,7 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -106,7 +107,7 @@ func TestBytesIOZeroOutFailure(t *testing.T) {
func TestBytesIOCopyOutFromSuccess(t *testing.T) {
b := newBytesIOString("ABCDEFGH")
- n, err := b.CopyOutFrom(newContext(), AddrRangeSeqFromSlice([]AddrRange{
+ n, err := b.CopyOutFrom(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{
{Start: 4, End: 7},
{Start: 1, End: 4},
}), safemem.FromIOReader{bytes.NewBufferString("barfoo")}, IOOpts{})
@@ -120,7 +121,7 @@ func TestBytesIOCopyOutFromSuccess(t *testing.T) {
func TestBytesIOCopyOutFromFailure(t *testing.T) {
b := newBytesIOString("ABCDE")
- n, err := b.CopyOutFrom(newContext(), AddrRangeSeqFromSlice([]AddrRange{
+ n, err := b.CopyOutFrom(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{
{Start: 1, End: 4},
{Start: 4, End: 7},
}), safemem.FromIOReader{bytes.NewBufferString("foobar")}, IOOpts{})
@@ -135,7 +136,7 @@ func TestBytesIOCopyOutFromFailure(t *testing.T) {
func TestBytesIOCopyInToSuccess(t *testing.T) {
b := newBytesIOString("AfoobarH")
var dst bytes.Buffer
- n, err := b.CopyInTo(newContext(), AddrRangeSeqFromSlice([]AddrRange{
+ n, err := b.CopyInTo(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{
{Start: 4, End: 7},
{Start: 1, End: 4},
}), safemem.FromIOWriter{&dst}, IOOpts{})
@@ -150,7 +151,7 @@ func TestBytesIOCopyInToSuccess(t *testing.T) {
func TestBytesIOCopyInToFailure(t *testing.T) {
b := newBytesIOString("Afoob")
var dst bytes.Buffer
- n, err := b.CopyInTo(newContext(), AddrRangeSeqFromSlice([]AddrRange{
+ n, err := b.CopyInTo(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{
{Start: 1, End: 4},
{Start: 4, End: 7},
}), safemem.FromIOWriter{&dst}, IOOpts{})
diff --git a/runsc/BUILD b/runsc/BUILD
index 3b91b984a..7a7dcc8d5 100644
--- a/runsc/BUILD
+++ b/runsc/BUILD
@@ -9,6 +9,7 @@ go_binary(
"version.go",
],
pure = True,
+ tags = ["staging"],
visibility = [
"//visibility:public",
],
@@ -49,5 +50,4 @@ sh_test(
srcs = ["version_test.sh"],
args = ["$(location :runsc)"],
data = [":runsc"],
- tags = ["noguitar"],
)
diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD
index 67307ab3c..a79afbdc4 100644
--- a/runsc/boot/BUILD
+++ b/runsc/boot/BUILD
@@ -30,6 +30,7 @@ go_library(
"//pkg/cleanup",
"//pkg/context",
"//pkg/control/server",
+ "//pkg/coverage",
"//pkg/cpuid",
"//pkg/eventchannel",
"//pkg/fd",
@@ -57,6 +58,7 @@ go_library(
"//pkg/sentry/fs/tmpfs",
"//pkg/sentry/fs/tty",
"//pkg/sentry/fs/user",
+ "//pkg/sentry/fsimpl/cgroupfs",
"//pkg/sentry/fsimpl/devpts",
"//pkg/sentry/fsimpl/devtmpfs",
"//pkg/sentry/fsimpl/fuse",
@@ -66,6 +68,7 @@ go_library(
"//pkg/sentry/fsimpl/proc",
"//pkg/sentry/fsimpl/sys",
"//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/fsimpl/verity",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel:uncaught_signal_go_proto",
diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go
index 1ae76d7d7..05b721b28 100644
--- a/runsc/boot/controller.go
+++ b/runsc/boot/controller.go
@@ -400,7 +400,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
// Set up the restore environment.
ctx := k.SupervisorContext()
- mntr := newContainerMounter(cm.l.root.spec, cm.l.root.goferFDs, cm.l.k, cm.l.mountHints, kernel.VFS2Enabled)
+ mntr := newContainerMounter(&cm.l.root, cm.l.k, cm.l.mountHints, kernel.VFS2Enabled)
if kernel.VFS2Enabled {
ctx, err = mntr.configureRestore(ctx, cm.l.root.conf)
if err != nil {
diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go
index 32adde643..3c0cef6db 100644
--- a/runsc/boot/fs.go
+++ b/runsc/boot/fs.go
@@ -31,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/gofer"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/fs/user"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/cgroupfs"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/devpts"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
gofervfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/gofer"
@@ -103,17 +104,22 @@ func addOverlay(ctx context.Context, conf *config.Config, lower *fs.Inode, name
// compileMounts returns the supported mounts from the mount spec, adding any
// mandatory mounts that are required by the OCI specification.
-func compileMounts(spec *specs.Spec, vfs2Enabled bool) []specs.Mount {
+func compileMounts(spec *specs.Spec, conf *config.Config, vfs2Enabled bool) []specs.Mount {
// Keep track of whether proc and sys were mounted.
var procMounted, sysMounted, devMounted, devptsMounted bool
var mounts []specs.Mount
// Mount all submounts from the spec.
for _, m := range spec.Mounts {
- if !vfs2Enabled && !specutils.IsVFS1SupportedDevMount(m) {
+ if !specutils.IsSupportedDevMount(m, vfs2Enabled) {
log.Warningf("ignoring dev mount at %q", m.Destination)
continue
}
+ // Unconditionally drop any cgroupfs mounts. If requested, we'll add our
+ // own below.
+ if m.Type == cgroupfs.Name {
+ continue
+ }
switch filepath.Clean(m.Destination) {
case "/proc":
procMounted = true
@@ -132,6 +138,24 @@ func compileMounts(spec *specs.Spec, vfs2Enabled bool) []specs.Mount {
// Mount proc and sys even if the user did not ask for it, as the spec
// says we SHOULD.
var mandatoryMounts []specs.Mount
+
+ if conf.Cgroupfs {
+ mandatoryMounts = append(mandatoryMounts, specs.Mount{
+ Type: tmpfsvfs2.Name,
+ Destination: "/sys/fs/cgroup",
+ })
+ mandatoryMounts = append(mandatoryMounts, specs.Mount{
+ Type: cgroupfs.Name,
+ Destination: "/sys/fs/cgroup/memory",
+ Options: []string{"memory"},
+ })
+ mandatoryMounts = append(mandatoryMounts, specs.Mount{
+ Type: cgroupfs.Name,
+ Destination: "/sys/fs/cgroup/cpu",
+ Options: []string{"cpu"},
+ })
+ }
+
if !procMounted {
mandatoryMounts = append(mandatoryMounts, specs.Mount{
Type: procvfs2.Name,
@@ -248,6 +272,10 @@ func isSupportedMountFlag(fstype, opt string) bool {
ok, err := parseMountOption(opt, tmpfsAllowedData...)
return ok && err == nil
}
+ if fstype == cgroupfs.Name {
+ ok, err := parseMountOption(opt, cgroupfs.SupportedMountOptions...)
+ return ok && err == nil
+ }
return false
}
@@ -572,11 +600,11 @@ type containerMounter struct {
hints *podMountHints
}
-func newContainerMounter(spec *specs.Spec, goferFDs []*fd.FD, k *kernel.Kernel, hints *podMountHints, vfs2Enabled bool) *containerMounter {
+func newContainerMounter(info *containerInfo, k *kernel.Kernel, hints *podMountHints, vfs2Enabled bool) *containerMounter {
return &containerMounter{
- root: spec.Root,
- mounts: compileMounts(spec, vfs2Enabled),
- fds: fdDispenser{fds: goferFDs},
+ root: info.spec.Root,
+ mounts: compileMounts(info.spec, info.conf, vfs2Enabled),
+ fds: fdDispenser{fds: info.goferFDs},
k: k,
hints: hints,
}
@@ -795,7 +823,13 @@ func (c *containerMounter) getMountNameAndOptions(conf *config.Config, m specs.M
opts = p9MountData(fd, c.getMountAccessType(conf, m), conf.VFS2)
// If configured, add overlay to all writable mounts.
useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly
-
+ case cgroupfs.Name:
+ fsName = m.Type
+ var err error
+ opts, err = parseAndFilterOptions(m.Options, cgroupfs.SupportedMountOptions...)
+ if err != nil {
+ return "", nil, false, err
+ }
default:
log.Warningf("ignoring unknown filesystem type %q", m.Type)
}
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index 774621970..798c1a7a7 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -29,6 +29,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/coverage"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
@@ -491,10 +492,6 @@ func (l *Loader) Destroy() {
// save/restore.
l.k.Release()
- // All sentry-created resources should have been released at this point;
- // check for reference leaks.
- refsvfs2.DoLeakCheck()
-
// In the success case, stdioFDs and goferFDs will only contain
// released/closed FDs that ownership has been passed over to host FDs and
// gofer sessions. Close them here in case of failure.
@@ -752,7 +749,7 @@ func (l *Loader) createContainerProcess(root bool, cid string, info *containerIn
// Setup the child container file system.
l.startGoferMonitor(cid, info.goferFDs)
- mntr := newContainerMounter(info.spec, info.goferFDs, l.k, l.mountHints, kernel.VFS2Enabled)
+ mntr := newContainerMounter(info, l.k, l.mountHints, kernel.VFS2Enabled)
if root {
if err := mntr.processHints(info.conf, info.procArgs.Credentials); err != nil {
return nil, nil, nil, err
@@ -1000,6 +997,15 @@ func (l *Loader) waitContainer(cid string, waitStatus *uint32) error {
// consider the container exited.
ws := l.wait(tg)
*waitStatus = ws
+
+ // Check for leaks and write coverage report after the root container has
+ // exited. This guarantees that the report is written in cases where the
+ // sandbox is killed by a signal after the ContainerWait request is completed.
+ if l.root.procArgs.ContainerID == cid {
+ // All sentry-created resources should have been released at this point.
+ refsvfs2.DoLeakCheck()
+ coverage.Report()
+ }
return nil
}
diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go
index 8b39bc59a..93c476971 100644
--- a/runsc/boot/loader_test.go
+++ b/runsc/boot/loader_test.go
@@ -439,7 +439,13 @@ func TestCreateMountNamespace(t *testing.T) {
}
defer cleanup()
- mntr := newContainerMounter(&tc.spec, []*fd.FD{fd.New(sandEnd)}, nil, &podMountHints{}, false /* vfs2Enabled */)
+ info := containerInfo{
+ conf: conf,
+ spec: &tc.spec,
+ goferFDs: []*fd.FD{fd.New(sandEnd)},
+ }
+
+ mntr := newContainerMounter(&info, nil, &podMountHints{}, false /* vfs2Enabled */)
mns, err := mntr.createMountNamespace(ctx, conf)
if err != nil {
t.Fatalf("failed to create mount namespace: %v", err)
@@ -479,7 +485,7 @@ func TestCreateMountNamespaceVFS2(t *testing.T) {
defer l.Destroy()
defer loaderCleanup()
- mntr := newContainerMounter(l.root.spec, l.root.goferFDs, l.k, l.mountHints, true /* vfs2Enabled */)
+ mntr := newContainerMounter(&l.root, l.k, l.mountHints, true /* vfs2Enabled */)
if err := mntr.processHints(l.root.conf, l.root.procArgs.Credentials); err != nil {
t.Fatalf("failed process hints: %v", err)
}
@@ -702,7 +708,12 @@ func TestRestoreEnvironment(t *testing.T) {
for _, ioFD := range tc.ioFDs {
ioFDs = append(ioFDs, fd.New(ioFD))
}
- mntr := newContainerMounter(tc.spec, ioFDs, nil, &podMountHints{}, false /* vfs2Enabled */)
+ info := containerInfo{
+ conf: conf,
+ spec: tc.spec,
+ goferFDs: ioFDs,
+ }
+ mntr := newContainerMounter(&info, nil, &podMountHints{}, false /* vfs2Enabled */)
actualRenv, err := mntr.createRestoreEnvironment(conf)
if !tc.errorExpected && err != nil {
t.Fatalf("could not create restore environment for test:%s", tc.name)
diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go
index 9b3dacf46..7d8fd0483 100644
--- a/runsc/boot/vfs.go
+++ b/runsc/boot/vfs.go
@@ -16,6 +16,7 @@ package boot
import (
"fmt"
+ "path"
"sort"
"strings"
@@ -29,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/devices/ttydev"
"gvisor.dev/gvisor/pkg/sentry/devices/tundev"
"gvisor.dev/gvisor/pkg/sentry/fs/user"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/cgroupfs"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/devpts"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/fuse"
@@ -37,6 +39,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fsimpl/proc"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/sys"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/verity"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -50,6 +53,10 @@ func registerFilesystems(k *kernel.Kernel) error {
creds := auth.NewRootCredentials(k.RootUserNamespace())
vfsObj := k.VFS()
+ vfsObj.MustRegisterFilesystemType(cgroupfs.Name, &cgroupfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
vfsObj.MustRegisterFilesystemType(devpts.Name, &devpts.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserList: true,
// TODO(b/29356795): Users may mount this once the terminals are in a
@@ -60,6 +67,10 @@ func registerFilesystems(k *kernel.Kernel) error {
AllowUserMount: true,
AllowUserList: true,
})
+ vfsObj.MustRegisterFilesystemType(fuse.Name, &fuse.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
vfsObj.MustRegisterFilesystemType(gofer.Name, &gofer.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserList: true,
})
@@ -79,9 +90,9 @@ func registerFilesystems(k *kernel.Kernel) error {
AllowUserMount: true,
AllowUserList: true,
})
- vfsObj.MustRegisterFilesystemType(fuse.Name, &fuse.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
- AllowUserMount: true,
+ vfsObj.MustRegisterFilesystemType(verity.Name, &verity.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserList: true,
+ AllowUserMount: true,
})
// Setup files in devtmpfs.
@@ -472,6 +483,12 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo
var data []string
var iopts interface{}
+ verityData, verityOpts, verityRequested, remainingMOpts, err := parseVerityMountOptions(m.Options)
+ if err != nil {
+ return "", nil, false, err
+ }
+ m.Options = remainingMOpts
+
// Find filesystem name and FS specific data field.
switch m.Type {
case devpts.Name, devtmpfs.Name, proc.Name, sys.Name:
@@ -502,6 +519,13 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo
// If configured, add overlay to all writable mounts.
useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly
+ case cgroupfs.Name:
+ var err error
+ data, err = parseAndFilterOptions(m.Options, cgroupfs.SupportedMountOptions...)
+ if err != nil {
+ return "", nil, false, err
+ }
+
default:
log.Warningf("ignoring unknown filesystem type %q", m.Type)
return "", nil, false, nil
@@ -530,9 +554,75 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo
}
}
+ if verityRequested {
+ verityData = verityData + "root_name=" + path.Base(m.Mount.Destination)
+ verityOpts.LowerName = fsName
+ verityOpts.LowerGetFSOptions = opts.GetFilesystemOptions
+ fsName = verity.Name
+ opts = &vfs.MountOptions{
+ GetFilesystemOptions: vfs.GetFilesystemOptions{
+ Data: verityData,
+ InternalData: verityOpts,
+ },
+ InternalMount: true,
+ }
+ }
+
return fsName, opts, useOverlay, nil
}
+func parseKeyValue(s string) (string, string, bool) {
+ tokens := strings.SplitN(s, "=", 2)
+ if len(tokens) < 2 {
+ return "", "", false
+ }
+ return strings.TrimSpace(tokens[0]), strings.TrimSpace(tokens[1]), true
+}
+
+// parseAndFilterOptions scans the provided mount options for verity-related
+// mount options. It returns the parsed set of verity mount options, as well as
+// the filtered set of mount options unrelated to verity.
+func parseVerityMountOptions(mopts []string) (string, verity.InternalFilesystemOptions, bool, []string, error) {
+ nonVerity := []string{}
+ found := false
+ var rootHash string
+ verityOpts := verity.InternalFilesystemOptions{
+ Action: verity.PanicOnViolation,
+ }
+ for _, o := range mopts {
+ if !strings.HasPrefix(o, "verity.") {
+ nonVerity = append(nonVerity, o)
+ continue
+ }
+
+ k, v, ok := parseKeyValue(o)
+ if !ok {
+ return "", verityOpts, found, nonVerity, fmt.Errorf("invalid verity mount option with no value: %q", o)
+ }
+
+ found = true
+ switch k {
+ case "verity.roothash":
+ rootHash = v
+ case "verity.action":
+ switch v {
+ case "error":
+ verityOpts.Action = verity.ErrorOnViolation
+ case "panic":
+ verityOpts.Action = verity.PanicOnViolation
+ default:
+ log.Warningf("Invalid verity action %q", v)
+ verityOpts.Action = verity.PanicOnViolation
+ }
+ default:
+ return "", verityOpts, found, nonVerity, fmt.Errorf("unknown verity mount option: %q", k)
+ }
+ }
+ verityOpts.AllowRuntimeEnable = len(rootHash) == 0
+ verityData := "root_hash=" + rootHash + ","
+ return verityData, verityOpts, found, nonVerity, nil
+}
+
// mountTmpVFS2 mounts an internal tmpfs at '/tmp' if it's safe to do so.
// Technically we don't have to mount tmpfs at /tmp, as we could just rely on
// the host /tmp, but this is a nice optimization, and fixes some apps that call
diff --git a/runsc/cli/BUILD b/runsc/cli/BUILD
index f1e3cce68..360e3cea6 100644
--- a/runsc/cli/BUILD
+++ b/runsc/cli/BUILD
@@ -10,8 +10,10 @@ go_library(
"//runsc:__pkg__",
],
deps = [
+ "//pkg/coverage",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/sentry/platform",
"//runsc/cmd",
"//runsc/config",
diff --git a/runsc/cli/main.go b/runsc/cli/main.go
index a3c515f4b..76184cd9c 100644
--- a/runsc/cli/main.go
+++ b/runsc/cli/main.go
@@ -27,8 +27,10 @@ import (
"github.com/google/subcommands"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/coverage"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/runsc/cmd"
"gvisor.dev/gvisor/runsc/config"
@@ -50,6 +52,7 @@ var (
logFD = flag.Int("log-fd", -1, "file descriptor to log to. If set, the 'log' flag is ignored.")
debugLogFD = flag.Int("debug-log-fd", -1, "file descriptor to write debug logs to. If set, the 'debug-log-dir' flag is ignored.")
panicLogFD = flag.Int("panic-log-fd", -1, "file descriptor to write Go's runtime messages.")
+ coverageFD = flag.Int("coverage-fd", -1, "file descriptor to write Go coverage output.")
)
// Main is the main entrypoint.
@@ -86,6 +89,7 @@ func Main(version string) {
subcommands.Register(new(cmd.Symbolize), "")
subcommands.Register(new(cmd.Wait), "")
subcommands.Register(new(cmd.Mitigate), "")
+ subcommands.Register(new(cmd.VerityPrepare), "")
// Register internal commands with the internal group name. This causes
// them to be sorted below the user-facing commands with empty group.
@@ -204,6 +208,10 @@ func Main(version string) {
} else if conf.AlsoLogToStderr {
e = &log.MultiEmitter{e, newEmitter(conf.DebugLogFormat, os.Stderr)}
}
+ if *coverageFD >= 0 {
+ f := os.NewFile(uintptr(*coverageFD), "coverage file")
+ coverage.EnableReport(f)
+ }
log.SetTarget(e)
@@ -233,6 +241,9 @@ func Main(version string) {
// Call the subcommand and pass in the configuration.
var ws unix.WaitStatus
subcmdCode := subcommands.Execute(context.Background(), conf, &ws)
+ // Check for leaks and write coverage report before os.Exit().
+ refsvfs2.DoLeakCheck()
+ coverage.Report()
if subcmdCode == subcommands.ExitSuccess {
log.Infof("Exiting with status: %v", ws)
if ws.Signaled() {
diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD
index 2c3b4058b..39c8ff603 100644
--- a/runsc/cmd/BUILD
+++ b/runsc/cmd/BUILD
@@ -23,6 +23,7 @@ go_library(
"kill.go",
"list.go",
"mitigate.go",
+ "mitigate_extras.go",
"path.go",
"pause.go",
"ps.go",
@@ -35,6 +36,7 @@ go_library(
"statefile.go",
"symbolize.go",
"syscalls.go",
+ "verity_prepare.go",
"wait.go",
],
visibility = [
diff --git a/runsc/cmd/do.go b/runsc/cmd/do.go
index 455c57692..5485db149 100644
--- a/runsc/cmd/do.go
+++ b/runsc/cmd/do.go
@@ -126,9 +126,8 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su
Hostname: hostname,
}
- specutils.LogSpec(spec)
-
cid := fmt.Sprintf("runsc-%06d", rand.Int31n(1000000))
+
if conf.Network == config.NetworkNone {
addNamespace(spec, specs.LinuxNamespace{Type: specs.NetworkNamespace})
@@ -154,55 +153,7 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su
}
}
- out, err := json.Marshal(spec)
- if err != nil {
- return Errorf("Error to marshal spec: %v", err)
- }
- tmpDir, err := ioutil.TempDir("", "runsc-do")
- if err != nil {
- return Errorf("Error to create tmp dir: %v", err)
- }
- defer os.RemoveAll(tmpDir)
-
- log.Infof("Changing configuration RootDir to %q", tmpDir)
- conf.RootDir = tmpDir
-
- cfgPath := filepath.Join(tmpDir, "config.json")
- if err := ioutil.WriteFile(cfgPath, out, 0755); err != nil {
- return Errorf("Error write spec: %v", err)
- }
-
- containerArgs := container.Args{
- ID: cid,
- Spec: spec,
- BundleDir: tmpDir,
- Attached: true,
- }
- ct, err := container.New(conf, containerArgs)
- if err != nil {
- return Errorf("creating container: %v", err)
- }
- defer ct.Destroy()
-
- if err := ct.Start(conf); err != nil {
- return Errorf("starting container: %v", err)
- }
-
- // Forward signals to init in the container. Thus if we get SIGINT from
- // ^C, the container gracefully exit, and we can clean up.
- //
- // N.B. There is a still a window before this where a signal may kill
- // this process, skipping cleanup.
- stopForwarding := ct.ForwardSignals(0 /* pid */, false /* fgProcess */)
- defer stopForwarding()
-
- ws, err := ct.Wait()
- if err != nil {
- return Errorf("waiting for container: %v", err)
- }
-
- *waitStatus = ws
- return subcommands.ExitSuccess
+ return startContainerAndWait(spec, conf, cid, waitStatus)
}
func addNamespace(spec *specs.Spec, ns specs.LinuxNamespace) {
@@ -397,3 +348,58 @@ func calculatePeerIP(ip string) (string, error) {
}
return fmt.Sprintf("%s.%s.%s.%d", parts[0], parts[1], parts[2], n), nil
}
+
+func startContainerAndWait(spec *specs.Spec, conf *config.Config, cid string, waitStatus *unix.WaitStatus) subcommands.ExitStatus {
+ specutils.LogSpec(spec)
+
+ out, err := json.Marshal(spec)
+ if err != nil {
+ return Errorf("Error to marshal spec: %v", err)
+ }
+ tmpDir, err := ioutil.TempDir("", "runsc-do")
+ if err != nil {
+ return Errorf("Error to create tmp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ log.Infof("Changing configuration RootDir to %q", tmpDir)
+ conf.RootDir = tmpDir
+
+ cfgPath := filepath.Join(tmpDir, "config.json")
+ if err := ioutil.WriteFile(cfgPath, out, 0755); err != nil {
+ return Errorf("Error write spec: %v", err)
+ }
+
+ containerArgs := container.Args{
+ ID: cid,
+ Spec: spec,
+ BundleDir: tmpDir,
+ Attached: true,
+ }
+
+ ct, err := container.New(conf, containerArgs)
+ if err != nil {
+ return Errorf("creating container: %v", err)
+ }
+ defer ct.Destroy()
+
+ if err := ct.Start(conf); err != nil {
+ return Errorf("starting container: %v", err)
+ }
+
+ // Forward signals to init in the container. Thus if we get SIGINT from
+ // ^C, the container gracefully exit, and we can clean up.
+ //
+ // N.B. There is a still a window before this where a signal may kill
+ // this process, skipping cleanup.
+ stopForwarding := ct.ForwardSignals(0 /* pid */, false /* fgProcess */)
+ defer stopForwarding()
+
+ ws, err := ct.Wait()
+ if err != nil {
+ return Errorf("waiting for container: %v", err)
+ }
+
+ *waitStatus = ws
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go
index 4cb0164dd..6a755ecb6 100644
--- a/runsc/cmd/gofer.go
+++ b/runsc/cmd/gofer.go
@@ -176,7 +176,7 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
mountIdx := 1 // first one is the root
for _, m := range spec.Mounts {
- if specutils.Is9PMount(m) {
+ if specutils.Is9PMount(m, conf.VFS2) {
cfg := fsgofer.Config{
ROMount: isReadonlyMount(m.Options) || conf.Overlay,
HostUDS: conf.FSGoferHostUDS,
@@ -350,7 +350,7 @@ func setupRootFS(spec *specs.Spec, conf *config.Config) error {
// creates directories as needed.
func setupMounts(conf *config.Config, mounts []specs.Mount, root string) error {
for _, m := range mounts {
- if m.Type != "bind" || !specutils.IsVFS1SupportedDevMount(m) {
+ if !specutils.Is9PMount(m, conf.VFS2) {
continue
}
@@ -390,7 +390,7 @@ func setupMounts(conf *config.Config, mounts []specs.Mount, root string) error {
func resolveMounts(conf *config.Config, mounts []specs.Mount, root string) ([]specs.Mount, error) {
cleanMounts := make([]specs.Mount, 0, len(mounts))
for _, m := range mounts {
- if m.Type != "bind" || !specutils.IsVFS1SupportedDevMount(m) {
+ if !specutils.Is9PMount(m, conf.VFS2) {
cleanMounts = append(cleanMounts, m)
continue
}
diff --git a/runsc/cmd/mitigate.go b/runsc/cmd/mitigate.go
index fddf0e0dd..d37ab80ba 100644
--- a/runsc/cmd/mitigate.go
+++ b/runsc/cmd/mitigate.go
@@ -40,8 +40,8 @@ type Mitigate struct {
reverse bool
// Path to file to read to create CPUSet.
path string
- // Callback to check if a given thread is vulnerable.
- vulnerable func(other mitigate.Thread) bool
+ // Extra data for post mitigate operations.
+ data string
}
// Name implements subcommands.command.name.
@@ -54,19 +54,20 @@ func (*Mitigate) Synopsis() string {
return "mitigate mitigates the underlying system against side channel attacks"
}
-// Usage implments Usage for cmd.Mitigate.
+// Usage implements Usage for cmd.Mitigate.
func (m Mitigate) Usage() string {
- return `mitigate [flags]
+ return fmt.Sprintf(`mitigate [flags]
mitigate mitigates a system to the "MDS" vulnerability by implementing a manual shutdown of SMT. The command checks /proc/cpuinfo for cpus having the MDS vulnerability, and if found, shutdown all but one CPU per hyperthread pair via /sys/devices/system/cpu/cpu{N}/online. CPUs can be restored by writing "2" to each file in /sys/devices/system/cpu/cpu{N}/online or performing a system reboot.
-The command can be reversed with --reverse, which reads the total CPUs from /sys/devices/system/cpu/possible and enables all with /sys/devices/system/cpu/cpu{N}/online.`
+The command can be reversed with --reverse, which reads the total CPUs from /sys/devices/system/cpu/possible and enables all with /sys/devices/system/cpu/cpu{N}/online.%s`, m.usage())
}
// SetFlags sets flags for the command Mitigate.
func (m *Mitigate) SetFlags(f *flag.FlagSet) {
f.BoolVar(&m.dryRun, "dryrun", false, "run the command without changing system")
f.BoolVar(&m.reverse, "reverse", false, "reverse mitigate by enabling all CPUs")
+ m.setFlags(f)
}
// Execute implements subcommands.Command.Execute.
@@ -81,13 +82,17 @@ func (m *Mitigate) Execute(_ context.Context, f *flag.FlagSet, args ...interface
m.path = allPossibleCPUs
}
- m.vulnerable = func(other mitigate.Thread) bool {
- return other.IsVulnerable()
+ set, err := m.doExecute()
+ if err != nil {
+ return Errorf("Execute failed: %v", err)
+ }
+
+ if m.data == "" {
+ return subcommands.ExitSuccess
}
- if _, err := m.doExecute(); err != nil {
- log.Warningf("Execute failed: %v", err)
- return subcommands.ExitFailure
+ if err = m.postMitigate(set); err != nil {
+ return Errorf("Post Mitigate failed: %v", err)
}
return subcommands.ExitSuccess
@@ -98,32 +103,26 @@ func (m *Mitigate) doExecute() (mitigate.CPUSet, error) {
if m.dryRun {
log.Infof("Running with DryRun. No cpu settings will be changed.")
}
+ data, err := ioutil.ReadFile(m.path)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read %s: %w", m.path, err)
+ }
if m.reverse {
- data, err := ioutil.ReadFile(m.path)
- if err != nil {
- return nil, fmt.Errorf("failed to read %s: %v", m.path, err)
- }
-
set, err := m.doReverse(data)
if err != nil {
- return nil, fmt.Errorf("reverse operation failed: %v", err)
+ return nil, fmt.Errorf("reverse operation failed: %w", err)
}
return set, nil
}
-
- data, err := ioutil.ReadFile(m.path)
- if err != nil {
- return nil, fmt.Errorf("failed to read %s: %v", m.path, err)
- }
set, err := m.doMitigate(data)
if err != nil {
- return nil, fmt.Errorf("mitigate operation failed: %v", err)
+ return nil, fmt.Errorf("mitigate operation failed: %w", err)
}
return set, nil
}
func (m *Mitigate) doMitigate(data []byte) (mitigate.CPUSet, error) {
- set, err := mitigate.NewCPUSet(data, m.vulnerable)
+ set, err := mitigate.NewCPUSet(data)
if err != nil {
return nil, err
}
@@ -139,7 +138,7 @@ func (m *Mitigate) doMitigate(data []byte) (mitigate.CPUSet, error) {
continue
}
if err := t.Disable(); err != nil {
- return nil, fmt.Errorf("error disabling thread: %s err: %v", t, err)
+ return nil, fmt.Errorf("error disabling thread: %s err: %w", t, err)
}
}
log.Infof("Shutdown successful.")
@@ -164,7 +163,7 @@ func (m *Mitigate) doReverse(data []byte) (mitigate.CPUSet, error) {
continue
}
if err := t.Enable(); err != nil {
- return nil, fmt.Errorf("error enabling thread: %s err: %v", t, err)
+ return nil, fmt.Errorf("error enabling thread: %s err: %w", t, err)
}
}
log.Infof("Enable successful.")
diff --git a/pkg/tcpip/transport/tcp/rack_state.go b/runsc/cmd/mitigate_extras.go
index c9dc7e773..2cb2833f0 100644
--- a/pkg/tcpip/transport/tcp/rack_state.go
+++ b/runsc/cmd/mitigate_extras.go
@@ -1,4 +1,4 @@
-// Copyright 2020 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,18 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package tcp
+package cmd
import (
- "time"
+ "gvisor.dev/gvisor/runsc/flag"
+ "gvisor.dev/gvisor/runsc/mitigate"
)
-// saveXmitTime is invoked by stateify.
-func (rc *rackControl) saveXmitTime() unixTime {
- return unixTime{rc.xmitTime.Unix(), rc.xmitTime.UnixNano()}
+// usage returns any extra bits of the usage string.
+func (m *Mitigate) usage() string {
+ return ""
}
-// loadXmitTime is invoked by stateify.
-func (rc *rackControl) loadXmitTime(unix unixTime) {
- rc.xmitTime = time.Unix(unix.second, unix.nano)
+// setFlags sets extra flags for the command Mitigate.
+func (m *Mitigate) setFlags(f *flag.FlagSet) {}
+
+// postMitigate handles any postMitigate actions.
+func (m *Mitigate) postMitigate(_ mitigate.CPUSet) error {
+ return nil
}
diff --git a/runsc/cmd/mitigate_test.go b/runsc/cmd/mitigate_test.go
index 163fece42..5a76667e3 100644
--- a/runsc/cmd/mitigate_test.go
+++ b/runsc/cmd/mitigate_test.go
@@ -21,7 +21,6 @@ import (
"strings"
"testing"
- "gvisor.dev/gvisor/runsc/mitigate"
"gvisor.dev/gvisor/runsc/mitigate/mock"
)
@@ -84,9 +83,6 @@ power management::84
t.Run(tc.name, func(t *testing.T) {
m := &Mitigate{
dryRun: true,
- vulnerable: func(other mitigate.Thread) bool {
- return other.IsVulnerable()
- },
}
m.doExecuteTest(t, "Mitigate", tc.mitigateData, tc.mitigateCPU, tc.mitigateError)
@@ -104,9 +100,6 @@ func TestExecuteSmoke(t *testing.T) {
m := &Mitigate{
dryRun: true,
- vulnerable: func(other mitigate.Thread) bool {
- return other.IsVulnerable()
- },
}
m.doExecuteTest(t, "Mitigate", string(smokeMitigate), 0, nil)
diff --git a/runsc/cmd/symbolize.go b/runsc/cmd/symbolize.go
index fc0c69358..0fa4bfda1 100644
--- a/runsc/cmd/symbolize.go
+++ b/runsc/cmd/symbolize.go
@@ -65,13 +65,15 @@ func (c *Symbolize) Execute(_ context.Context, f *flag.FlagSet, args ...interfac
f.Usage()
return subcommands.ExitUsageError
}
- if !coverage.KcovAvailable() {
+ if !coverage.Available() {
return Errorf("symbolize can only be used when coverage is available.")
}
coverage.InitCoverageData()
if c.dumpAll {
- coverage.WriteAllBlocks(os.Stdout)
+ if err := coverage.WriteAllBlocks(os.Stdout); err != nil {
+ return Errorf("Failed to write out blocks: %v", err)
+ }
return subcommands.ExitSuccess
}
diff --git a/runsc/cmd/verity_prepare.go b/runsc/cmd/verity_prepare.go
new file mode 100644
index 000000000..66128b2a3
--- /dev/null
+++ b/runsc/cmd/verity_prepare.go
@@ -0,0 +1,108 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "fmt"
+ "math/rand"
+ "os"
+
+ "github.com/google/subcommands"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/runsc/config"
+ "gvisor.dev/gvisor/runsc/flag"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// VerityPrepare implements subcommands.Commands for the "verity-prepare"
+// command. It sets up a sandbox with a writable verity mount mapped to "--dir",
+// and executes the verity measure tool specified by "--tool" in the sandbox. It
+// is intended to prepare --dir to be mounted as a verity filesystem.
+type VerityPrepare struct {
+ root string
+ tool string
+ dir string
+}
+
+// Name implements subcommands.Command.Name.
+func (*VerityPrepare) Name() string {
+ return "verity-prepare"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*VerityPrepare) Synopsis() string {
+ return "Generates the data structures necessary to enable verityfs on a filesystem."
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*VerityPrepare) Usage() string {
+ return "verity-prepare --tool=<measure_tool> --dir=<path>"
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (c *VerityPrepare) SetFlags(f *flag.FlagSet) {
+ f.StringVar(&c.root, "root", "/", `path to the root directory, defaults to "/"`)
+ f.StringVar(&c.tool, "tool", "", "path to the verity measure_tool")
+ f.StringVar(&c.dir, "dir", "", "path to the directory to be hashed")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (c *VerityPrepare) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ conf := args[0].(*config.Config)
+ waitStatus := args[1].(*unix.WaitStatus)
+
+ hostname, err := os.Hostname()
+ if err != nil {
+ return Errorf("Error to retrieve hostname: %v", err)
+ }
+
+ // Map the entire host file system.
+ absRoot, err := resolvePath(c.root)
+ if err != nil {
+ return Errorf("Error resolving root: %v", err)
+ }
+
+ spec := &specs.Spec{
+ Root: &specs.Root{
+ Path: absRoot,
+ },
+ Process: &specs.Process{
+ Cwd: absRoot,
+ Args: []string{c.tool, "--path", "/verityroot"},
+ Env: os.Environ(),
+ Capabilities: specutils.AllCapabilities(),
+ },
+ Hostname: hostname,
+ Mounts: []specs.Mount{
+ specs.Mount{
+ Source: c.dir,
+ Destination: "/verityroot",
+ Type: "bind",
+ Options: []string{"verity.roothash="},
+ },
+ },
+ }
+
+ cid := fmt.Sprintf("runsc-%06d", rand.Int31n(1000000))
+
+ // Force no networking, it is not necessary to run the verity measure tool.
+ conf.Network = config.NetworkNone
+
+ conf.Verity = true
+
+ return startContainerAndWait(spec, conf, cid, waitStatus)
+}
diff --git a/runsc/config/config.go b/runsc/config/config.go
index 1e5858837..fa550ebf7 100644
--- a/runsc/config/config.go
+++ b/runsc/config/config.go
@@ -55,6 +55,9 @@ type Config struct {
// PanicLog is the path to log GO's runtime messages, if not empty.
PanicLog string `flag:"panic-log"`
+ // CoverageReport is the path to write Go coverage information, if not empty.
+ CoverageReport string `flag:"coverage-report"`
+
// DebugLogFormat is the log format for debug.
DebugLogFormat string `flag:"debug-log-format"`
@@ -172,6 +175,9 @@ type Config struct {
// Enables seccomp inside the sandbox.
OCISeccomp bool `flag:"oci-seccomp"`
+ // Mounts the cgroup filesystem backed by the sentry's cgroupfs.
+ Cgroupfs bool `flag:"cgroupfs"`
+
// TestOnlyAllowRunAsCurrentUserWithoutChroot should only be used in
// tests. It allows runsc to start the sandbox process as the current
// user, and without chrooting the sandbox process. This can be
diff --git a/runsc/config/flags.go b/runsc/config/flags.go
index 1d996c841..c3dca2352 100644
--- a/runsc/config/flags.go
+++ b/runsc/config/flags.go
@@ -44,7 +44,8 @@ func RegisterFlags() {
// Debugging flags.
flag.String("debug-log", "", "additional location for logs. If it ends with '/', log files are created inside the directory with default names. The following variables are available: %TIMESTAMP%, %COMMAND%.")
- flag.String("panic-log", "", "file path were panic reports and other Go's runtime messages are written.")
+ flag.String("panic-log", "", "file path where panic reports and other Go's runtime messages are written.")
+ flag.String("coverage-report", "", "file path where Go coverage reports are written. Reports will only be generated if runsc is built with --collect_code_coverage and --instrumentation_filter Bazel flags.")
flag.Bool("log-packets", false, "enable network packet logging.")
flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s.")
flag.Bool("alsologtostderr", false, "send log messages to stderr.")
@@ -75,6 +76,7 @@ func RegisterFlags() {
flag.Bool("fsgofer-host-uds", false, "allow the gofer to mount Unix Domain Sockets.")
flag.Bool("vfs2", false, "enables VFSv2. This uses the new VFS layer that is faster than the previous one.")
flag.Bool("fuse", false, "TEST ONLY; use while FUSE in VFSv2 is landing. This allows the use of the new experimental FUSE filesystem.")
+ flag.Bool("cgroupfs", false, "Automatically mount cgroupfs.")
// Flags that control sandbox runtime behavior: network related.
flag.Var(networkTypePtr(NetworkSandbox), "network", "specifies which network to use: sandbox (default), host, none. Using network inside the sandbox is more secure because it's isolated from the host network.")
diff --git a/runsc/container/BUILD b/runsc/container/BUILD
index 3620dc8c3..5314549d6 100644
--- a/runsc/container/BUILD
+++ b/runsc/container/BUILD
@@ -51,9 +51,7 @@ go_test(
],
library = ":container",
shard_count = more_shards,
- tags = [
- "requires-kvm",
- ],
+ tags = ["requires-kvm"],
deps = [
"//pkg/abi/linux",
"//pkg/bits",
diff --git a/runsc/container/container.go b/runsc/container/container.go
index f9d83c118..e72ada311 100644
--- a/runsc/container/container.go
+++ b/runsc/container/container.go
@@ -886,7 +886,7 @@ func (c *Container) createGoferProcess(spec *specs.Spec, conf *config.Config, bu
// Add root mount and then add any other additional mounts.
mountCount := 1
for _, m := range spec.Mounts {
- if specutils.Is9PMount(m) {
+ if specutils.Is9PMount(m, conf.VFS2) {
mountCount++
}
}
diff --git a/runsc/mitigate/mitigate.go b/runsc/mitigate/mitigate.go
index 24f67414c..88409af8f 100644
--- a/runsc/mitigate/mitigate.go
+++ b/runsc/mitigate/mitigate.go
@@ -50,7 +50,7 @@ const (
type CPUSet map[threadID]*ThreadGroup
// NewCPUSet creates a CPUSet from data read from /proc/cpuinfo.
-func NewCPUSet(data []byte, vulnerable func(Thread) bool) (CPUSet, error) {
+func NewCPUSet(data []byte) (CPUSet, error) {
processors, err := getThreads(string(data))
if err != nil {
return nil, err
@@ -67,7 +67,7 @@ func NewCPUSet(data []byte, vulnerable func(Thread) bool) (CPUSet, error) {
core = &ThreadGroup{}
set[p.id] = core
}
- core.isVulnerable = core.isVulnerable || vulnerable(p)
+ core.isVulnerable = core.isVulnerable || p.IsVulnerable()
core.threads = append(core.threads, p)
}
@@ -446,6 +446,7 @@ func buildRegex(key, match string) *regexp.Regexp {
func parseRegex(data, key, match string) (string, error) {
r := buildRegex(key, match)
matches := r.FindStringSubmatch(data)
+
if len(matches) < 2 {
return "", fmt.Errorf("failed to match key %q: %q", key, data)
}
diff --git a/runsc/mitigate/mitigate_test.go b/runsc/mitigate/mitigate_test.go
index fbd8eb886..3bf9ef547 100644
--- a/runsc/mitigate/mitigate_test.go
+++ b/runsc/mitigate/mitigate_test.go
@@ -52,14 +52,13 @@ func TestMockCPUSet(t *testing.T) {
} {
t.Run(tc.testCase.Name, func(t *testing.T) {
data := tc.testCase.MakeCPUString()
- vulnerable := func(t Thread) bool {
- return t.IsVulnerable()
- }
- set, err := NewCPUSet([]byte(data), vulnerable)
+ set, err := NewCPUSet([]byte(data))
if err != nil {
t.Fatalf("Failed to create cpuSet: %v", err)
}
+ t.Logf("data: %s", data)
+
for _, tg := range set {
if err := checkSorted(tg.threads); err != nil {
t.Fatalf("Failed to sort cpuSet: %v", err)
@@ -258,11 +257,7 @@ func TestReadFile(t *testing.T) {
t.Fatalf("Failed to read cpuinfo: %v", err)
}
- vulnerable := func(t Thread) bool {
- return t.IsVulnerable()
- }
-
- set, err := NewCPUSet(data, vulnerable)
+ set, err := NewCPUSet(data)
if err != nil {
t.Fatalf("Failed to parse CPU data %v\n%s", err, data)
}
diff --git a/runsc/mitigate/mock/mock.go b/runsc/mitigate/mock/mock.go
index 2db718cb9..12c59e356 100644
--- a/runsc/mitigate/mock/mock.go
+++ b/runsc/mitigate/mock/mock.go
@@ -82,6 +82,19 @@ var Haswell2core = CPU{
ThreadsPerCore: 1,
}
+// AMD2 is an two core AMD machine.
+var AMD2 = CPU{
+ Name: "AMD",
+ VendorID: "AuthenticAMD",
+ Family: 23,
+ Model: 49,
+ ModelName: "AMD EPYC 7B12",
+ Bugs: "sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass",
+ PhysicalCores: 1,
+ Cores: 1,
+ ThreadsPerCore: 2,
+}
+
// AMD8 is an eight core AMD machine.
var AMD8 = CPU{
Name: "AMD",
@@ -115,15 +128,15 @@ bugs : %s
for k := 0; k < tc.ThreadsPerCore; k++ {
processorNum := (i*tc.Cores+j)*tc.ThreadsPerCore + k
ret += fmt.Sprintf(template,
- processorNum, /*processor*/
- tc.VendorID, /*vendor_id*/
- tc.Family, /*cpu family*/
- tc.Model, /*model*/
- tc.ModelName, /*model name*/
- i, /*physical id*/
- j, /*core id*/
- tc.Cores*tc.PhysicalCores, /*cpu cores*/
- tc.Bugs, /*bugs*/
+ processorNum, /*processor*/
+ tc.VendorID, /*vendor_id*/
+ tc.Family, /*cpu family*/
+ tc.Model, /*model*/
+ tc.ModelName, /*model name*/
+ i, /*physical id*/
+ j, /*core id*/
+ k, /*cpu cores*/
+ tc.Bugs, /*bugs*/
)
}
}
diff --git a/runsc/sandbox/BUILD b/runsc/sandbox/BUILD
index f0a551a1e..bc4a3fa32 100644
--- a/runsc/sandbox/BUILD
+++ b/runsc/sandbox/BUILD
@@ -16,6 +16,7 @@ go_library(
"//pkg/cleanup",
"//pkg/control/client",
"//pkg/control/server",
+ "//pkg/coverage",
"//pkg/log",
"//pkg/sentry/control",
"//pkg/sentry/platform",
diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go
index 450f92645..f3f60f116 100644
--- a/runsc/sandbox/sandbox.go
+++ b/runsc/sandbox/sandbox.go
@@ -34,6 +34,7 @@ import (
"gvisor.dev/gvisor/pkg/cleanup"
"gvisor.dev/gvisor/pkg/control/client"
"gvisor.dev/gvisor/pkg/control/server"
+ "gvisor.dev/gvisor/pkg/coverage"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/platform"
@@ -399,15 +400,15 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn
cmd.Args = append(cmd.Args, "--log-fd="+strconv.Itoa(nextFD))
nextFD++
}
- if conf.DebugLog != "" {
- test := ""
- if len(conf.TestOnlyTestNameEnv) != 0 {
- // Fetch test name if one is provided and the test only flag was set.
- if t, ok := specutils.EnvVar(args.Spec.Process.Env, conf.TestOnlyTestNameEnv); ok {
- test = t
- }
- }
+ test := ""
+ if len(conf.TestOnlyTestNameEnv) != 0 {
+ // Fetch test name if one is provided and the test only flag was set.
+ if t, ok := specutils.EnvVar(args.Spec.Process.Env, conf.TestOnlyTestNameEnv); ok {
+ test = t
+ }
+ }
+ if conf.DebugLog != "" {
debugLogFile, err := specutils.DebugLogFile(conf.DebugLog, "boot", test)
if err != nil {
return fmt.Errorf("opening debug log file in %q: %v", conf.DebugLog, err)
@@ -418,23 +419,29 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn
nextFD++
}
if conf.PanicLog != "" {
- test := ""
- if len(conf.TestOnlyTestNameEnv) != 0 {
- // Fetch test name if one is provided and the test only flag was set.
- if t, ok := specutils.EnvVar(args.Spec.Process.Env, conf.TestOnlyTestNameEnv); ok {
- test = t
- }
- }
-
panicLogFile, err := specutils.DebugLogFile(conf.PanicLog, "panic", test)
if err != nil {
- return fmt.Errorf("opening debug log file in %q: %v", conf.PanicLog, err)
+ return fmt.Errorf("opening panic log file in %q: %v", conf.PanicLog, err)
}
defer panicLogFile.Close()
cmd.ExtraFiles = append(cmd.ExtraFiles, panicLogFile)
cmd.Args = append(cmd.Args, "--panic-log-fd="+strconv.Itoa(nextFD))
nextFD++
}
+ covFilename := conf.CoverageReport
+ if covFilename == "" {
+ covFilename = os.Getenv("GO_COVERAGE_FILE")
+ }
+ if covFilename != "" && coverage.Available() {
+ covFile, err := specutils.DebugLogFile(covFilename, "cov", test)
+ if err != nil {
+ return fmt.Errorf("opening debug log file in %q: %v", covFilename, err)
+ }
+ defer covFile.Close()
+ cmd.ExtraFiles = append(cmd.ExtraFiles, covFile)
+ cmd.Args = append(cmd.Args, "--coverage-fd="+strconv.Itoa(nextFD))
+ nextFD++
+ }
// Add the "boot" command to the args.
//
@@ -486,7 +493,7 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn
}
if deviceFile, err := gPlatform.OpenDevice(); err != nil {
- return fmt.Errorf("opening device file for platform %q: %v", gPlatform, err)
+ return fmt.Errorf("opening device file for platform %q: %v", conf.Platform, err)
} else if deviceFile != nil {
defer deviceFile.Close()
cmd.ExtraFiles = append(cmd.ExtraFiles, deviceFile)
@@ -1174,7 +1181,7 @@ func deviceFileForPlatform(name string) (*os.File, error) {
f, err := p.OpenDevice()
if err != nil {
- return nil, fmt.Errorf("opening device file for platform %q: %v", p, err)
+ return nil, fmt.Errorf("opening device file for platform %q: %w", name, err)
}
return f, nil
}
diff --git a/runsc/specutils/fs.go b/runsc/specutils/fs.go
index b62504a8c..9ecd0fde6 100644
--- a/runsc/specutils/fs.go
+++ b/runsc/specutils/fs.go
@@ -18,6 +18,7 @@ import (
"fmt"
"math/bits"
"path"
+ "strings"
specs "github.com/opencontainers/runtime-spec/specs-go"
"golang.org/x/sys/unix"
@@ -64,6 +65,12 @@ var optionsMap = map[string]mapping{
"sync": {set: true, val: unix.MS_SYNCHRONOUS},
}
+// verityMountOptions is the set of valid verity mount option keys.
+var verityMountOptions = map[string]struct{}{
+ "verity.roothash": struct{}{},
+ "verity.action": struct{}{},
+}
+
// propOptionsMap is similar to optionsMap, but it lists propagation options
// that cannot be used together with other flags.
var propOptionsMap = map[string]mapping{
@@ -117,6 +124,14 @@ func validateMount(mnt *specs.Mount) error {
return nil
}
+func moptKey(opt string) string {
+ if len(opt) == 0 {
+ return opt
+ }
+ // Guaranteed to have at least one token, since opt is not empty.
+ return strings.SplitN(opt, "=", 2)[0]
+}
+
// ValidateMountOptions validates that mount options are correct.
func ValidateMountOptions(opts []string) error {
for _, o := range opts {
@@ -125,7 +140,8 @@ func ValidateMountOptions(opts []string) error {
}
_, ok1 := optionsMap[o]
_, ok2 := propOptionsMap[o]
- if !ok1 && !ok2 {
+ _, ok3 := verityMountOptions[moptKey(o)]
+ if !ok1 && !ok2 && !ok3 {
return fmt.Errorf("unknown mount option %q", o)
}
if err := validatePropagation(o); err != nil {
diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go
index 45856fd58..e5e66546c 100644
--- a/runsc/specutils/specutils.go
+++ b/runsc/specutils/specutils.go
@@ -332,14 +332,20 @@ func capsFromNames(names []string, skipSet map[linux.Capability]struct{}) (auth.
return auth.CapabilitySetOfMany(caps), nil
}
-// Is9PMount returns true if the given mount can be mounted as an external gofer.
-func Is9PMount(m specs.Mount) bool {
- return m.Type == "bind" && m.Source != "" && IsVFS1SupportedDevMount(m)
+// Is9PMount returns true if the given mount can be mounted as an external
+// gofer.
+func Is9PMount(m specs.Mount, vfs2Enabled bool) bool {
+ return m.Type == "bind" && m.Source != "" && IsSupportedDevMount(m, vfs2Enabled)
}
-// IsVFS1SupportedDevMount returns true if m.Destination does not specify a
+// IsSupportedDevMount returns true if m.Destination does not specify a
// path that is hardcoded by VFS1's implementation of /dev.
-func IsVFS1SupportedDevMount(m specs.Mount) bool {
+func IsSupportedDevMount(m specs.Mount, vfs2Enabled bool) bool {
+ // VFS2 has no hardcoded files under /dev, so everything is allowed.
+ if vfs2Enabled {
+ return true
+ }
+
// See pkg/sentry/fs/dev/dev.go.
var existingDevices = []string{
"/dev/fd", "/dev/stdin", "/dev/stdout", "/dev/stderr",
diff --git a/shim/BUILD b/shim/BUILD
index 434269d31..695f61eb9 100644
--- a/shim/BUILD
+++ b/shim/BUILD
@@ -6,6 +6,7 @@ go_binary(
name = "containerd-shim-runsc-v1",
srcs = ["main.go"],
static = True,
+ tags = ["staging"],
visibility = [
"//visibility:public",
],
diff --git a/test/benchmarks/base/BUILD b/test/benchmarks/base/BUILD
index 697ab5837..a5a3cf2c1 100644
--- a/test/benchmarks/base/BUILD
+++ b/test/benchmarks/base/BUILD
@@ -17,7 +17,6 @@ go_library(
benchmark_test(
name = "startup_test",
- size = "enormous",
srcs = ["startup_test.go"],
visibility = ["//:sandbox"],
deps = [
@@ -29,7 +28,6 @@ benchmark_test(
benchmark_test(
name = "size_test",
- size = "enormous",
srcs = ["size_test.go"],
visibility = ["//:sandbox"],
deps = [
@@ -42,7 +40,6 @@ benchmark_test(
benchmark_test(
name = "sysbench_test",
- size = "enormous",
srcs = ["sysbench_test.go"],
visibility = ["//:sandbox"],
deps = [
diff --git a/test/benchmarks/database/BUILD b/test/benchmarks/database/BUILD
index 0b1743603..fee2695ff 100644
--- a/test/benchmarks/database/BUILD
+++ b/test/benchmarks/database/BUILD
@@ -11,7 +11,6 @@ go_library(
benchmark_test(
name = "redis_test",
- size = "enormous",
srcs = ["redis_test.go"],
library = ":database",
visibility = ["//:sandbox"],
diff --git a/test/benchmarks/fs/BUILD b/test/benchmarks/fs/BUILD
index dc82e63b2..c2b981a07 100644
--- a/test/benchmarks/fs/BUILD
+++ b/test/benchmarks/fs/BUILD
@@ -4,7 +4,6 @@ package(licenses = ["notice"])
benchmark_test(
name = "bazel_test",
- size = "enormous",
srcs = ["bazel_test.go"],
visibility = ["//:sandbox"],
deps = [
@@ -18,7 +17,6 @@ benchmark_test(
benchmark_test(
name = "fio_test",
- size = "enormous",
srcs = ["fio_test.go"],
visibility = ["//:sandbox"],
deps = [
diff --git a/test/benchmarks/media/BUILD b/test/benchmarks/media/BUILD
index 380783f0b..ad2ef3a55 100644
--- a/test/benchmarks/media/BUILD
+++ b/test/benchmarks/media/BUILD
@@ -11,7 +11,6 @@ go_library(
benchmark_test(
name = "ffmpeg_test",
- size = "enormous",
srcs = ["ffmpeg_test.go"],
library = ":media",
visibility = ["//:sandbox"],
diff --git a/test/benchmarks/ml/BUILD b/test/benchmarks/ml/BUILD
index 3425b8dad..56a4d4f39 100644
--- a/test/benchmarks/ml/BUILD
+++ b/test/benchmarks/ml/BUILD
@@ -11,7 +11,6 @@ go_library(
benchmark_test(
name = "tensorflow_test",
- size = "enormous",
srcs = ["tensorflow_test.go"],
library = ":ml",
visibility = ["//:sandbox"],
diff --git a/test/benchmarks/network/BUILD b/test/benchmarks/network/BUILD
index 2741570f5..e047020bf 100644
--- a/test/benchmarks/network/BUILD
+++ b/test/benchmarks/network/BUILD
@@ -18,7 +18,6 @@ go_library(
benchmark_test(
name = "iperf_test",
- size = "enormous",
srcs = [
"iperf_test.go",
],
@@ -34,7 +33,6 @@ benchmark_test(
benchmark_test(
name = "node_test",
- size = "enormous",
srcs = [
"node_test.go",
],
@@ -49,7 +47,6 @@ benchmark_test(
benchmark_test(
name = "ruby_test",
- size = "enormous",
srcs = [
"ruby_test.go",
],
@@ -64,7 +61,6 @@ benchmark_test(
benchmark_test(
name = "nginx_test",
- size = "enormous",
srcs = [
"nginx_test.go",
],
@@ -79,7 +75,6 @@ benchmark_test(
benchmark_test(
name = "httpd_test",
- size = "enormous",
srcs = [
"httpd_test.go",
],
diff --git a/test/e2e/BUILD b/test/e2e/BUILD
index 29a84f184..1e9792b4f 100644
--- a/test/e2e/BUILD
+++ b/test/e2e/BUILD
@@ -8,13 +8,12 @@ go_test(
srcs = [
"exec_test.go",
"integration_test.go",
- "regression_test.go",
],
library = ":integration",
tags = [
# Requires docker and runsc to be configured before the test runs.
- "manual",
"local",
+ "manual",
],
visibility = ["//:sandbox"],
deps = [
diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go
index 49cd74887..1accc3b3b 100644
--- a/test/e2e/integration_test.go
+++ b/test/e2e/integration_test.go
@@ -168,13 +168,6 @@ func TestCheckpointRestore(t *testing.T) {
t.Skip("Pause/resume is not supported.")
}
- // TODO(gvisor.dev/issue/3373): Remove after implementing.
- if usingVFS2, err := dockerutil.UsingVFS2(); usingVFS2 {
- t.Skip("CheckpointRestore not implemented in VFS2.")
- } else if err != nil {
- t.Fatalf("failed to read config for runtime %s: %v", dockerutil.Runtime(), err)
- }
-
ctx := context.Background()
d := dockerutil.MakeContainer(ctx, t)
defer d.CleanUp(ctx)
@@ -399,15 +392,15 @@ func TestTmpFile(t *testing.T) {
// TestTmpMount checks that mounts inside '/tmp' are not overridden.
func TestTmpMount(t *testing.T) {
- ctx := context.Background()
dir, err := ioutil.TempDir(testutil.TmpDir(), "tmp-mount")
if err != nil {
t.Fatalf("TempDir(): %v", err)
}
- want := "123"
+ const want = "123"
if err := ioutil.WriteFile(filepath.Join(dir, "file.txt"), []byte("123"), 0666); err != nil {
t.Fatalf("WriteFile(): %v", err)
}
+ ctx := context.Background()
d := dockerutil.MakeContainer(ctx, t)
defer d.CleanUp(ctx)
@@ -430,6 +423,48 @@ func TestTmpMount(t *testing.T) {
}
}
+// Test that it is allowed to mount a file on top of /dev files, e.g.
+// /dev/random.
+func TestMountOverDev(t *testing.T) {
+ if usingVFS2, err := dockerutil.UsingVFS2(); !usingVFS2 {
+ t.Skip("VFS1 doesn't allow /dev/random to be mounted.")
+ } else if err != nil {
+ t.Fatalf("Failed to read config for runtime %s: %v", dockerutil.Runtime(), err)
+ }
+
+ random, err := ioutil.TempFile(testutil.TmpDir(), "random")
+ if err != nil {
+ t.Fatal("ioutil.TempFile() failed:", err)
+ }
+ const want = "123"
+ if _, err := random.WriteString(want); err != nil {
+ t.Fatalf("WriteString() to %q: %v", random.Name(), err)
+ }
+
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ opts := dockerutil.RunOpts{
+ Image: "basic/alpine",
+ Mounts: []mount.Mount{
+ {
+ Type: mount.TypeBind,
+ Source: random.Name(),
+ Target: "/dev/random",
+ },
+ },
+ }
+ cmd := "dd count=1 bs=5 if=/dev/random 2> /dev/null"
+ got, err := d.Run(ctx, opts, "sh", "-c", cmd)
+ if err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ if want != got {
+ t.Errorf("invalid file content, want: %q, got: %q", want, got)
+ }
+}
+
// TestSyntheticDirs checks that submounts can be created inside a readonly
// mount even if the target path does not exist.
func TestSyntheticDirs(t *testing.T) {
@@ -550,6 +585,30 @@ func runIntegrationTest(t *testing.T, capAdd []string, args ...string) {
}
}
+// Test that UDS can be created using overlay when parent directory is in lower
+// layer only (b/134090485).
+//
+// Prerequisite: the directory where the socket file is created must not have
+// been open for write before bind(2) is called.
+func TestBindOverlay(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Run the container.
+ got, err := d.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/ubuntu",
+ }, "bash", "-c", "nc -q -1 -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -q 0 -U /var/run/sock && wait $p")
+ if err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Check the output contains what we want.
+ if want := "foobar-asdf"; !strings.Contains(got, want) {
+ t.Fatalf("docker run output is missing %q: %s", want, got)
+ }
+}
+
func TestMain(m *testing.M) {
dockerutil.EnsureSupportedDockerVersion()
flag.Parse()
diff --git a/test/e2e/regression_test.go b/test/e2e/regression_test.go
deleted file mode 100644
index 84564cdaa..000000000
--- a/test/e2e/regression_test.go
+++ /dev/null
@@ -1,47 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package integration
-
-import (
- "context"
- "strings"
- "testing"
-
- "gvisor.dev/gvisor/pkg/test/dockerutil"
-)
-
-// Test that UDS can be created using overlay when parent directory is in lower
-// layer only (b/134090485).
-//
-// Prerequisite: the directory where the socket file is created must not have
-// been open for write before bind(2) is called.
-func TestBindOverlay(t *testing.T) {
- ctx := context.Background()
- d := dockerutil.MakeContainer(ctx, t)
- defer d.CleanUp(ctx)
-
- // Run the container.
- got, err := d.Run(ctx, dockerutil.RunOpts{
- Image: "basic/ubuntu",
- }, "bash", "-c", "nc -q -1 -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -q 0 -U /var/run/sock && wait $p")
- if err != nil {
- t.Fatalf("docker run failed: %v", err)
- }
-
- // Check the output contains what we want.
- if want := "foobar-asdf"; !strings.Contains(got, want) {
- t.Fatalf("docker run output is missing %q: %s", want, got)
- }
-}
diff --git a/test/fsstress/BUILD b/test/fsstress/BUILD
index d262c8554..e74e7fff2 100644
--- a/test/fsstress/BUILD
+++ b/test/fsstress/BUILD
@@ -14,9 +14,7 @@ go_test(
"manual",
"local",
],
- deps = [
- "//pkg/test/dockerutil",
- ],
+ deps = ["//pkg/test/dockerutil"],
)
go_library(
diff --git a/test/fsstress/fsstress_test.go b/test/fsstress/fsstress_test.go
index 300c21ceb..d53c8f90d 100644
--- a/test/fsstress/fsstress_test.go
+++ b/test/fsstress/fsstress_test.go
@@ -17,7 +17,9 @@ package fsstress
import (
"context"
+ "flag"
"math/rand"
+ "os"
"strconv"
"strings"
"testing"
@@ -30,33 +32,44 @@ func init() {
rand.Seed(int64(time.Now().Nanosecond()))
}
-func fsstress(t *testing.T, dir string) {
+func TestMain(m *testing.M) {
+ dockerutil.EnsureSupportedDockerVersion()
+ flag.Parse()
+ os.Exit(m.Run())
+}
+
+type config struct {
+ operations string
+ processes string
+ target string
+}
+
+func fsstress(t *testing.T, conf config) {
ctx := context.Background()
d := dockerutil.MakeContainer(ctx, t)
defer d.CleanUp(ctx)
- const (
- operations = "10000"
- processes = "100"
- image = "basic/fsstress"
- )
+ const image = "basic/fsstress"
seed := strconv.FormatUint(uint64(rand.Uint32()), 10)
- args := []string{"-d", dir, "-n", operations, "-p", processes, "-s", seed, "-X"}
- t.Logf("Repro: docker run --rm --runtime=runsc %s %s", image, strings.Join(args, ""))
+ args := []string{"-d", conf.target, "-n", conf.operations, "-p", conf.processes, "-s", seed, "-X"}
+ t.Logf("Repro: docker run --rm --runtime=%s gvisor.dev/images/%s %s", dockerutil.Runtime(), image, strings.Join(args, " "))
out, err := d.Run(ctx, dockerutil.RunOpts{Image: image}, args...)
if err != nil {
t.Fatalf("docker run failed: %v\noutput: %s", err, out)
}
- lines := strings.SplitN(out, "\n", 2)
- if len(lines) > 1 || !strings.HasPrefix(out, "seed =") {
+ // This is to catch cases where fsstress spews out error messages during clean
+ // up but doesn't return error.
+ if len(out) > 0 {
t.Fatalf("unexpected output: %s", out)
}
}
-func TestFsstressGofer(t *testing.T) {
- fsstress(t, "/test")
-}
-
func TestFsstressTmpfs(t *testing.T) {
- fsstress(t, "/tmp")
+ // This takes between 10s to run on my machine. Adjust as needed.
+ cfg := config{
+ operations: "5000",
+ processes: "20",
+ target: "/tmp",
+ }
+ fsstress(t, cfg)
}
diff --git a/test/image/image_test.go b/test/image/image_test.go
index 968e62f63..952264173 100644
--- a/test/image/image_test.go
+++ b/test/image/image_test.go
@@ -183,7 +183,10 @@ func TestMysql(t *testing.T) {
// Start the container.
if err := server.Spawn(ctx, dockerutil.RunOpts{
Image: "basic/mysql",
- Env: []string{"MYSQL_ROOT_PASSWORD=foobar123"},
+ Env: []string{
+ "MYSQL_ROOT_PASSWORD=foobar123",
+ "MYSQL_ROOT_HOST=%", // Allow anyone to connect to the server.
+ },
}); err != nil {
t.Fatalf("docker run failed: %v", err)
}
diff --git a/test/iptables/BUILD b/test/iptables/BUILD
index 94d4ca2d4..9805665ac 100644
--- a/test/iptables/BUILD
+++ b/test/iptables/BUILD
@@ -16,8 +16,8 @@ go_library(
visibility = ["//test/iptables:__subpackages__"],
deps = [
"//pkg/binary",
+ "//pkg/hostarch",
"//pkg/test/testutil",
- "//pkg/usermem",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go
index d6c69a319..04d112134 100644
--- a/test/iptables/iptables_test.go
+++ b/test/iptables/iptables_test.go
@@ -456,3 +456,11 @@ func TestNATPreRECVORIGDSTADDR(t *testing.T) {
func TestNATOutRECVORIGDSTADDR(t *testing.T) {
singleTest(t, &NATOutRECVORIGDSTADDR{})
}
+
+func TestNATPostSNATUDP(t *testing.T) {
+ singleTest(t, &NATPostSNATUDP{})
+}
+
+func TestNATPostSNATTCP(t *testing.T) {
+ singleTest(t, &NATPostSNATTCP{})
+}
diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go
index bba17b894..4590e169d 100644
--- a/test/iptables/iptables_util.go
+++ b/test/iptables/iptables_util.go
@@ -69,29 +69,41 @@ func tableRules(ipv6 bool, table string, argsList [][]string) error {
return nil
}
-// listenUDP listens on a UDP port and returns the value of net.Conn.Read() for
-// the first read on that port.
+// listenUDP listens on a UDP port and returns nil if the first read from that
+// port is successful.
func listenUDP(ctx context.Context, port int, ipv6 bool) error {
+ _, err := listenUDPFrom(ctx, port, ipv6)
+ return err
+}
+
+// listenUDPFrom listens on a UDP port and returns the sender's UDP address if
+// the first read from that port is successful.
+func listenUDPFrom(ctx context.Context, port int, ipv6 bool) (*net.UDPAddr, error) {
localAddr := net.UDPAddr{
Port: port,
}
conn, err := net.ListenUDP(udpNetwork(ipv6), &localAddr)
if err != nil {
- return err
+ return nil, err
}
defer conn.Close()
- ch := make(chan error)
+ type result struct {
+ remoteAddr *net.UDPAddr
+ err error
+ }
+
+ ch := make(chan result)
go func() {
- _, err = conn.Read([]byte{0})
- ch <- err
+ _, remoteAddr, err := conn.ReadFromUDP([]byte{0})
+ ch <- result{remoteAddr, err}
}()
select {
- case err := <-ch:
- return err
+ case res := <-ch:
+ return res.remoteAddr, res.err
case <-ctx.Done():
- return ctx.Err()
+ return nil, fmt.Errorf("timed out reading from %s: %w", &localAddr, ctx.Err())
}
}
@@ -125,8 +137,16 @@ func sendUDPLoop(ctx context.Context, ip net.IP, port int, ipv6 bool) error {
}
}
-// listenTCP listens for connections on a TCP port.
+// listenTCP listens for connections on a TCP port, and returns nil if a
+// connection is established.
func listenTCP(ctx context.Context, port int, ipv6 bool) error {
+ _, err := listenTCPFrom(ctx, port, ipv6)
+ return err
+}
+
+// listenTCP listens for connections on a TCP port, and returns the remote
+// TCP address if a connection is established.
+func listenTCPFrom(ctx context.Context, port int, ipv6 bool) (net.Addr, error) {
localAddr := net.TCPAddr{
Port: port,
}
@@ -134,23 +154,32 @@ func listenTCP(ctx context.Context, port int, ipv6 bool) error {
// Starts listening on port.
lConn, err := net.ListenTCP(tcpNetwork(ipv6), &localAddr)
if err != nil {
- return err
+ return nil, err
}
defer lConn.Close()
+ type result struct {
+ remoteAddr net.Addr
+ err error
+ }
+
// Accept connections on port.
- ch := make(chan error)
+ ch := make(chan result)
go func() {
conn, err := lConn.AcceptTCP()
- ch <- err
+ var remoteAddr net.Addr
+ if err == nil {
+ remoteAddr = conn.RemoteAddr()
+ }
+ ch <- result{remoteAddr, err}
conn.Close()
}()
select {
- case err := <-ch:
- return err
+ case res := <-ch:
+ return res.remoteAddr, res.err
case <-ctx.Done():
- return fmt.Errorf("timed out waiting for a connection at %#v: %w", localAddr, ctx.Err())
+ return nil, fmt.Errorf("timed out waiting for a connection at %s: %w", &localAddr, ctx.Err())
}
}
diff --git a/test/iptables/nat.go b/test/iptables/nat.go
index 70d8a1832..0f25b6a18 100644
--- a/test/iptables/nat.go
+++ b/test/iptables/nat.go
@@ -19,10 +19,11 @@ import (
"errors"
"fmt"
"net"
+ "strconv"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
const redirectPort = 42
@@ -48,6 +49,8 @@ func init() {
RegisterTestCase(&NATOutOriginalDst{})
RegisterTestCase(&NATPreRECVORIGDSTADDR{})
RegisterTestCase(&NATOutRECVORIGDSTADDR{})
+ RegisterTestCase(&NATPostSNATUDP{})
+ RegisterTestCase(&NATPostSNATTCP{})
}
// NATPreRedirectUDPPort tests that packets are redirected to different port.
@@ -486,7 +489,12 @@ func (*NATLoopbackSkipsPrerouting) Name() string {
// ContainerAction implements TestCase.ContainerAction.
func (*NATLoopbackSkipsPrerouting) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
// Redirect anything sent to localhost to an unused port.
- dest := []byte{127, 0, 0, 1}
+ var dest net.IP
+ if ipv6 {
+ dest = net.IPv6loopback
+ } else {
+ dest = net.IPv4(127, 0, 0, 1)
+ }
if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil {
return err
}
@@ -848,7 +856,7 @@ func recvOrigDstAddr4(sockfd int) (unix.RawSockaddrInet4, error) {
return unix.RawSockaddrInet4{}, err
}
var addr unix.RawSockaddrInet4
- binary.Unmarshal(buf, usermem.ByteOrder, &addr)
+ binary.Unmarshal(buf, hostarch.ByteOrder, &addr)
return addr, nil
}
@@ -858,7 +866,7 @@ func recvOrigDstAddr6(sockfd int) (unix.RawSockaddrInet6, error) {
return unix.RawSockaddrInet6{}, err
}
var addr unix.RawSockaddrInet6
- binary.Unmarshal(buf, usermem.ByteOrder, &addr)
+ binary.Unmarshal(buf, hostarch.ByteOrder, &addr)
return addr, nil
}
@@ -915,3 +923,115 @@ func addrMatches6(got unix.RawSockaddrInet6, wantAddrs []net.IP, port uint16) er
}
return fmt.Errorf("got %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, wantAddrs)
}
+
+const (
+ snatAddrV4 = "194.236.50.155"
+ snatAddrV6 = "2a0a::1"
+ snatPort = 43
+)
+
+// NATPostSNATUDP tests that the source port/IP in the packets are modified as expected.
+type NATPostSNATUDP struct{ localCase }
+
+var _ TestCase = (*NATPostSNATUDP)(nil)
+
+// Name implements TestCase.Name.
+func (*NATPostSNATUDP) Name() string {
+ return "NATPostSNATUDP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (*NATPostSNATUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ var source string
+ if ipv6 {
+ source = fmt.Sprintf("[%s]:%d", snatAddrV6, snatPort)
+ } else {
+ source = fmt.Sprintf("%s:%d", snatAddrV4, snatPort)
+ }
+
+ if err := natTable(ipv6, "-A", "POSTROUTING", "-p", "udp", "-j", "SNAT", "--to-source", source); err != nil {
+ return err
+ }
+ return sendUDPLoop(ctx, ip, acceptPort, ipv6)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (*NATPostSNATUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ remote, err := listenUDPFrom(ctx, acceptPort, ipv6)
+ if err != nil {
+ return err
+ }
+ var snatAddr string
+ if ipv6 {
+ snatAddr = snatAddrV6
+ } else {
+ snatAddr = snatAddrV4
+ }
+ if got, want := remote.IP, net.ParseIP(snatAddr); !got.Equal(want) {
+ return fmt.Errorf("got remote address = %s, want = %s", got, want)
+ }
+ if got, want := remote.Port, snatPort; got != want {
+ return fmt.Errorf("got remote port = %d, want = %d", got, want)
+ }
+ return nil
+}
+
+// NATPostSNATTCP tests that the source port/IP in the packets are modified as
+// expected.
+type NATPostSNATTCP struct{ localCase }
+
+var _ TestCase = (*NATPostSNATTCP)(nil)
+
+// Name implements TestCase.Name.
+func (*NATPostSNATTCP) Name() string {
+ return "NATPostSNATTCP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (*NATPostSNATTCP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ addrs, err := getInterfaceAddrs(ipv6)
+ if err != nil {
+ return err
+ }
+ var source string
+ for _, addr := range addrs {
+ if addr.To4() != nil {
+ if !ipv6 {
+ source = fmt.Sprintf("%s:%d", addr, snatPort)
+ }
+ } else if ipv6 && addr.IsGlobalUnicast() {
+ source = fmt.Sprintf("[%s]:%d", addr, snatPort)
+ }
+ }
+ if source == "" {
+ return fmt.Errorf("can't find any interface address to use")
+ }
+
+ if err := natTable(ipv6, "-A", "POSTROUTING", "-p", "tcp", "-j", "SNAT", "--to-source", source); err != nil {
+ return err
+ }
+ return connectTCP(ctx, ip, acceptPort, ipv6)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (*NATPostSNATTCP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ remote, err := listenTCPFrom(ctx, acceptPort, ipv6)
+ if err != nil {
+ return err
+ }
+ HostStr, portStr, err := net.SplitHostPort(remote.String())
+ if err != nil {
+ return err
+ }
+ if got, want := HostStr, ip.String(); got != want {
+ return fmt.Errorf("got remote address = %s, want = %s", got, want)
+ }
+ port, err := strconv.ParseInt(portStr, 10, 0)
+ if err != nil {
+ return err
+ }
+ if got, want := int(port), snatPort; got != want {
+ return fmt.Errorf("got remote port = %d, want = %d", got, want)
+ }
+ return nil
+}
diff --git a/test/packetdrill/BUILD b/test/packetdrill/BUILD
index 5d95516ee..de66cbe6d 100644
--- a/test/packetdrill/BUILD
+++ b/test/packetdrill/BUILD
@@ -41,6 +41,7 @@ packetdrill_test(
test_suite(
name = "all_tests",
tags = [
+ "local",
"manual",
"packetdrill",
],
diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl
index 34e83ec49..634c15727 100644
--- a/test/packetimpact/runner/defs.bzl
+++ b/test/packetimpact/runner/defs.bzl
@@ -246,6 +246,12 @@ ALL_TESTS = [
expect_netstack_failure = True,
),
PacketimpactTestInfo(
+ name = "tcp_listen_backlog",
+ ),
+ PacketimpactTestInfo(
+ name = "tcp_syncookie",
+ ),
+ PacketimpactTestInfo(
name = "icmpv6_param_problem",
),
PacketimpactTestInfo(
diff --git a/test/packetimpact/runner/dut.go b/test/packetimpact/runner/dut.go
index b271bd47e..4fb2f5c4b 100644
--- a/test/packetimpact/runner/dut.go
+++ b/test/packetimpact/runner/dut.go
@@ -369,30 +369,32 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co
"--dut_infos_json", string(dutInfosBytes),
)
testbenchLogs, err := testbenchContainer.Exec(ctx, dockerutil.ExecOpts{}, testArgs...)
- if (err != nil) != expectFailure {
- var dutLogs string
- for i, dut := range duts {
- logs, err := dut.Logs(ctx)
- if err != nil {
- logs = fmt.Sprintf("failed to fetch DUT logs: %s", err)
- }
- dutLogs = fmt.Sprintf(`%s====== Begin of DUT-%d Logs ======
+ var dutLogs string
+ for i, dut := range duts {
+ logs, err := dut.Logs(ctx)
+ if err != nil {
+ logs = fmt.Sprintf("failed to fetch DUT logs: %s", err)
+ }
+ dutLogs = fmt.Sprintf(`%s====== Begin of DUT-%d Logs ======
%s
====== End of DUT-%d Logs ======
`, dutLogs, i, logs, i)
- }
-
- t.Errorf(`test error: %v, expect failure: %t
-
+ }
+ testLogs := fmt.Sprintf(`
%s====== Begin of Testbench Logs ======
%s
-====== End of Testbench Logs ======`,
- err, expectFailure, dutLogs, testbenchLogs)
+====== End of Testbench Logs ======`, dutLogs, testbenchLogs)
+ if (err != nil) != expectFailure {
+ t.Errorf(`test error: %v, expect failure: %t
+%s`, err, expectFailure, testLogs)
+ } else if expectFailure {
+ t.Logf(`test failed as expected: %v
+%s`, err, testLogs)
}
}
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD
index 43b4c7ca1..616215dc3 100644
--- a/test/packetimpact/testbench/BUILD
+++ b/test/packetimpact/testbench/BUILD
@@ -16,11 +16,11 @@ go_library(
],
visibility = ["//test/packetimpact:__subpackages__"],
deps = [
+ "//pkg/hostarch",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/seqnum",
- "//pkg/usermem",
"//test/packetimpact/proto:posix_server_go_proto",
"@com_github_google_go_cmp//cmp:go_default_library",
"@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go
index 1ac96626a..feeb0888a 100644
--- a/test/packetimpact/testbench/rawsockets.go
+++ b/test/packetimpact/testbench/rawsockets.go
@@ -23,7 +23,7 @@ import (
"time"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
)
// Sniffer can sniff raw packets on the wire.
@@ -34,7 +34,7 @@ type Sniffer struct {
func htons(x uint16) uint16 {
buf := [2]byte{}
binary.BigEndian.PutUint16(buf[:], x)
- return usermem.ByteOrder.Uint16(buf[:])
+ return hostarch.ByteOrder.Uint16(buf[:])
}
// NewSniffer creates a Sniffer connected to *device.
diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD
index c0deb33e5..e015c1f0e 100644
--- a/test/packetimpact/tests/BUILD
+++ b/test/packetimpact/tests/BUILD
@@ -105,8 +105,8 @@ packetimpact_testbench(
deps = [
"//pkg/abi/linux",
"//pkg/binary",
+ "//pkg/hostarch",
"//pkg/tcpip/header",
- "//pkg/usermem",
"//test/packetimpact/testbench",
"@org_golang_x_sys//unix:go_default_library",
],
@@ -354,9 +354,9 @@ packetimpact_testbench(
deps = [
"//pkg/abi/linux",
"//pkg/binary",
+ "//pkg/hostarch",
"//pkg/tcpip/header",
"//pkg/tcpip/seqnum",
- "//pkg/usermem",
"//test/packetimpact/testbench",
"@org_golang_x_sys//unix:go_default_library",
],
@@ -368,8 +368,8 @@ packetimpact_testbench(
deps = [
"//pkg/abi/linux",
"//pkg/binary",
+ "//pkg/hostarch",
"//pkg/tcpip/header",
- "//pkg/usermem",
"//test/packetimpact/testbench",
"@org_golang_x_sys//unix:go_default_library",
],
@@ -385,6 +385,26 @@ packetimpact_testbench(
],
)
+packetimpact_testbench(
+ name = "tcp_listen_backlog",
+ srcs = ["tcp_listen_backlog_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_testbench(
+ name = "tcp_syncookie",
+ srcs = ["tcp_syncookie_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
validate_all_tests()
[packetimpact_go_test(
@@ -396,6 +416,7 @@ validate_all_tests()
test_suite(
name = "all_tests",
tags = [
+ "local",
"manual",
"packetimpact",
],
diff --git a/test/packetimpact/tests/tcp_info_test.go b/test/packetimpact/tests/tcp_info_test.go
index 3fc2c7fe5..93f58ec49 100644
--- a/test/packetimpact/tests/tcp_info_test.go
+++ b/test/packetimpact/tests/tcp_info_test.go
@@ -22,8 +22,8 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/test/packetimpact/testbench"
)
@@ -58,7 +58,7 @@ func TestTCPInfo(t *testing.T) {
if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want {
t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want)
}
- binary.Unmarshal(infoBytes, usermem.ByteOrder, &info)
+ binary.Unmarshal(infoBytes, hostarch.ByteOrder, &info)
rtt := time.Duration(info.RTT) * time.Microsecond
rttvar := time.Duration(info.RTTVar) * time.Microsecond
@@ -99,7 +99,7 @@ func TestTCPInfo(t *testing.T) {
if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want {
t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want)
}
- binary.Unmarshal(infoBytes, usermem.ByteOrder, &info)
+ binary.Unmarshal(infoBytes, hostarch.ByteOrder, &info)
if info.CaState != linux.TCP_CA_Loss {
t.Errorf("expected the connection to be in loss recovery, got: %v want: %v", info.CaState, linux.TCP_CA_Loss)
}
diff --git a/test/packetimpact/tests/tcp_listen_backlog_test.go b/test/packetimpact/tests/tcp_listen_backlog_test.go
new file mode 100644
index 000000000..26c812d0a
--- /dev/null
+++ b/test/packetimpact/tests/tcp_listen_backlog_test.go
@@ -0,0 +1,86 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_listen_backlog_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.Initialize(flag.CommandLine)
+}
+
+// TestTCPListenBacklog tests for a listening endpoint behavior:
+// (1) reply to more SYNs than what is configured as listen backlog
+// (2) ignore ACKs (that complete a handshake) when the accept queue is full
+// (3) ignore incoming SYNs when the accept queue is full
+func TestTCPListenBacklog(t *testing.T) {
+ dut := testbench.NewDUT(t)
+
+ // Listening endpoint accepts one more connection than the listen backlog.
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 0 /*backlog*/)
+
+ var establishedConn testbench.TCPIPv4
+ var incompleteConn testbench.TCPIPv4
+
+ // Test if the DUT listener replies to more SYNs than listen backlog+1
+ for i, conn := range []*testbench.TCPIPv4{&establishedConn, &incompleteConn} {
+ *conn = dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ // Expect dut connection to have transitioned to SYN-RCVD state.
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
+ t.Fatalf("expected SYN-ACK for %d connection, %s", i, err)
+ }
+ }
+ defer establishedConn.Close(t)
+ defer incompleteConn.Close(t)
+
+ // Send the ACK to complete handshake.
+ establishedConn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
+ dut.PollOne(t, listenFd, unix.POLLIN, time.Second)
+
+ // Send the ACK to complete handshake, expect this to be ignored by the
+ // listener.
+ incompleteConn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
+
+ // Drain the accept queue to enable poll for subsequent connections on the
+ // listener.
+ dut.Accept(t, listenFd)
+
+ // The ACK for the incomplete connection should be ignored by the
+ // listening endpoint and the poll on listener should now time out.
+ if pfds := dut.Poll(t, []unix.PollFd{{Fd: listenFd, Events: unix.POLLIN}}, time.Second); len(pfds) != 0 {
+ t.Fatalf("got dut.Poll(...) = %#v", pfds)
+ }
+
+ // Re-send the ACK to complete handshake and re-fill the accept-queue.
+ incompleteConn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
+ dut.PollOne(t, listenFd, unix.POLLIN, time.Second)
+
+ // Now initiate a new connection when the accept queue is full.
+ connectingConn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer connectingConn.Close(t)
+ // Expect dut connection to drop the SYN and let the client stay in SYN_SENT state.
+ connectingConn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)})
+ if got, err := connectingConn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err == nil {
+ t.Fatalf("expected no SYN-ACK, but got %s", got)
+ }
+}
diff --git a/test/packetimpact/tests/tcp_rack_test.go b/test/packetimpact/tests/tcp_rack_test.go
index 0a5b0f12b..ff1431bbf 100644
--- a/test/packetimpact/tests/tcp_rack_test.go
+++ b/test/packetimpact/tests/tcp_rack_test.go
@@ -22,9 +22,9 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/test/packetimpact/testbench"
)
@@ -74,7 +74,7 @@ func getRTTAndRTO(t *testing.T, dut testbench.DUT, acceptFd int32) (rtt, rto tim
if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want {
t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want)
}
- binary.Unmarshal(infoBytes, usermem.ByteOrder, &info)
+ binary.Unmarshal(infoBytes, hostarch.ByteOrder, &info)
return time.Duration(info.RTT) * time.Microsecond, time.Duration(info.RTO) * time.Microsecond
}
@@ -407,7 +407,7 @@ func TestRACKWithLostRetransmission(t *testing.T) {
if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want {
t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want)
}
- binary.Unmarshal(infoBytes, usermem.ByteOrder, &info)
+ binary.Unmarshal(infoBytes, hostarch.ByteOrder, &info)
if info.CaState != linux.TCP_CA_Recovery {
t.Fatalf("expected connection to be in fast recovery, want: %v got: %v", linux.TCP_CA_Recovery, info.CaState)
}
diff --git a/test/packetimpact/tests/tcp_retransmits_test.go b/test/packetimpact/tests/tcp_retransmits_test.go
index 3dc8f63ab..1eafe20c3 100644
--- a/test/packetimpact/tests/tcp_retransmits_test.go
+++ b/test/packetimpact/tests/tcp_retransmits_test.go
@@ -23,8 +23,8 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/test/packetimpact/testbench"
)
@@ -38,7 +38,7 @@ func getRTO(t *testing.T, dut testbench.DUT, acceptFd int32) (rto time.Duration)
if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want {
t.Fatalf("unexpected size for TCP_INFO, got %d bytes want %d bytes", got, want)
}
- binary.Unmarshal(infoBytes, usermem.ByteOrder, &info)
+ binary.Unmarshal(infoBytes, hostarch.ByteOrder, &info)
return time.Duration(info.RTO) * time.Microsecond
}
diff --git a/test/packetimpact/tests/tcp_syncookie_test.go b/test/packetimpact/tests/tcp_syncookie_test.go
new file mode 100644
index 000000000..1c21c62ff
--- /dev/null
+++ b/test/packetimpact/tests/tcp_syncookie_test.go
@@ -0,0 +1,70 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_syncookie_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.Initialize(flag.CommandLine)
+}
+
+// TestSynCookie test if the DUT listener is replying back using syn cookies.
+// The test does not complete the handshake by not sending the ACK to SYNACK.
+// When syncookies are not used, this forces the listener to retransmit SYNACK.
+// And when syncookies are being used, there is no such retransmit.
+func TestTCPSynCookie(t *testing.T) {
+ dut := testbench.NewDUT(t)
+
+ // Listening endpoint accepts one more connection than the listen backlog.
+ _, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/)
+
+ var withoutSynCookieConn testbench.TCPIPv4
+ var withSynCookieConn testbench.TCPIPv4
+
+ // Test if the DUT listener replies to more SYNs than listen backlog+1
+ for _, conn := range []*testbench.TCPIPv4{&withoutSynCookieConn, &withSynCookieConn} {
+ *conn = dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ }
+ defer withoutSynCookieConn.Close(t)
+ defer withSynCookieConn.Close(t)
+
+ checkSynAck := func(t *testing.T, conn *testbench.TCPIPv4, expectRetransmit bool) {
+ // Expect dut connection to have transitioned to SYN-RCVD state.
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
+ t.Fatalf("expected SYN-ACK, but got %s", err)
+ }
+
+ // If the DUT listener is using syn cookies, it will not retransmit SYNACK
+ got, err := conn.ExpectData(t, &testbench.TCP{SeqNum: testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1)), Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, 2*time.Second)
+ if expectRetransmit && err != nil {
+ t.Fatalf("expected retransmitted SYN-ACK, but got %s", err)
+ }
+ if !expectRetransmit && err == nil {
+ t.Fatalf("expected no retransmitted SYN-ACK, but got %s", got)
+ }
+ }
+
+ t.Run("without syncookies", func(t *testing.T) { checkSynAck(t, &withoutSynCookieConn, true /*expectRetransmit*/) })
+ t.Run("with syncookies", func(t *testing.T) { checkSynAck(t, &withSynCookieConn, false /*expectRetransmit*/) })
+}
diff --git a/test/perf/BUILD b/test/perf/BUILD
index ed899ac22..71982fc4d 100644
--- a/test/perf/BUILD
+++ b/test/perf/BUILD
@@ -35,7 +35,7 @@ syscall_test(
)
syscall_test(
- size = "enormous",
+ size = "large",
debug = False,
tags = ["nogotsan"],
test = "//test/perf/linux:getdents_benchmark",
@@ -48,7 +48,7 @@ syscall_test(
)
syscall_test(
- size = "enormous",
+ size = "large",
debug = False,
tags = ["nogotsan"],
test = "//test/perf/linux:gettid_benchmark",
@@ -106,7 +106,7 @@ syscall_test(
)
syscall_test(
- size = "enormous",
+ size = "large",
debug = False,
test = "//test/perf/linux:signal_benchmark",
)
@@ -124,9 +124,10 @@ syscall_test(
)
syscall_test(
- size = "enormous",
+ size = "large",
add_overlay = True,
debug = False,
+ tags = ["nogotsan"],
test = "//test/perf/linux:unlink_benchmark",
)
diff --git a/test/perf/linux/getpid_benchmark.cc b/test/perf/linux/getpid_benchmark.cc
index db74cb264..047a034bd 100644
--- a/test/perf/linux/getpid_benchmark.cc
+++ b/test/perf/linux/getpid_benchmark.cc
@@ -31,6 +31,24 @@ void BM_Getpid(benchmark::State& state) {
BENCHMARK(BM_Getpid);
+#ifdef __x86_64__
+
+#define SYSNO_STR1(x) #x
+#define SYSNO_STR(x) SYSNO_STR1(x)
+
+// BM_GetpidOpt uses the most often pattern of calling system calls:
+// mov $SYS_XXX, %eax; syscall.
+void BM_GetpidOpt(benchmark::State& state) {
+ for (auto s : state) {
+ __asm__("movl $" SYSNO_STR(SYS_getpid) ", %%eax\n"
+ "syscall\n"
+ : : : "rax", "rcx", "r11");
+ }
+}
+
+BENCHMARK(BM_GetpidOpt);
+#endif // __x86_64__
+
} // namespace
} // namespace testing
diff --git a/test/perf/linux/write_benchmark.cc b/test/perf/linux/write_benchmark.cc
index 7b060c70e..d495f3ddc 100644
--- a/test/perf/linux/write_benchmark.cc
+++ b/test/perf/linux/write_benchmark.cc
@@ -46,6 +46,18 @@ void BM_Write(benchmark::State& state) {
BENCHMARK(BM_Write)->Range(1, 1 << 26)->UseRealTime();
+void BM_Append(benchmark::State& state) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY | O_APPEND));
+
+ const char data = 'a';
+ for (auto _ : state) {
+ TEST_CHECK(WriteFd(fd.get(), &data, 1) == 1);
+ }
+}
+
+BENCHMARK(BM_Append);
+
} // namespace
} // namespace testing
diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl
index 829247657..2a0ef2cec 100644
--- a/test/runner/defs.bzl
+++ b/test/runner/defs.bzl
@@ -4,7 +4,7 @@ load("//tools:defs.bzl", "default_platform", "platforms")
def _runner_test_impl(ctx):
# Generate a runner binary.
- runner = ctx.actions.declare_file("%s-runner" % ctx.label.name)
+ runner = ctx.actions.declare_file(ctx.label.name)
runner_content = "\n".join([
"#!/bin/bash",
"set -euf -x -o pipefail",
@@ -85,18 +85,9 @@ def _syscall_test(
# Add the full_platform and file access in a tag to make it easier to run
# all the tests on a specific flavor. Use --test_tag_filters=ptrace,file_shared.
+ tags = list(tags)
tags += [full_platform, "file_" + file_access]
- # Hash this target into one of 15 buckets. This can be used to
- # randomly split targets between different workflows.
- hash15 = hash(native.package_name() + name) % 15
- tags.append("hash15:" + str(hash15))
-
- # TODO(b/139838000): Tests using hostinet must be disabled on Guitar until
- # we figure out how to request ipv4 sockets on Guitar machines.
- if network == "host":
- tags.append("noguitar")
-
# Disable off-host networking.
tags.append("requires-net:loopback")
tags.append("requires-net:ipv4")
@@ -157,116 +148,82 @@ def syscall_test(
if not tags:
tags = []
- vfs2_tags = list(tags)
- if vfs2:
- # Add tag to easily run VFS2 tests with --test_tag_filters=vfs2
- vfs2_tags.append("vfs2")
- if fuse:
- vfs2_tags.append("fuse")
-
- else:
- # Don't automatically run tests tests not yet passing.
- vfs2_tags.append("manual")
- vfs2_tags.append("noguitar")
- vfs2_tags.append("notap")
-
- _syscall_test(
- test = test,
- platform = default_platform,
- use_tmpfs = use_tmpfs,
- add_uds_tree = add_uds_tree,
- tags = platforms[default_platform] + vfs2_tags,
- debug = debug,
- vfs2 = True,
- fuse = fuse,
- **kwargs
- )
- if fuse:
- # Only generate *_vfs2_fuse target if fuse parameter is enabled.
- return
-
- _syscall_test(
- test = test,
- platform = "native",
- use_tmpfs = False,
- add_uds_tree = add_uds_tree,
- tags = list(tags),
- debug = debug,
- **kwargs
- )
-
- for (platform, platform_tags) in platforms.items():
+ if vfs2 and not fuse:
+ # Generate a vfs1 plain test. Most testing will now be
+ # biased towards vfs2, with only a single vfs1 case.
_syscall_test(
test = test,
- platform = platform,
+ platform = default_platform,
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
- tags = platform_tags + tags,
+ tags = tags + platforms[default_platform],
debug = debug,
+ vfs2 = False,
**kwargs
)
- if add_overlay:
+ if not fuse:
+ # Generate a native test if fuse is not required.
_syscall_test(
test = test,
- platform = default_platform,
- use_tmpfs = use_tmpfs,
+ platform = "native",
+ use_tmpfs = False,
add_uds_tree = add_uds_tree,
- tags = platforms[default_platform] + tags,
+ tags = tags,
debug = debug,
- overlay = True,
**kwargs
)
- # TODO(gvisor.dev/issue/4407): Remove tags to enable VFS2 overlay tests.
- overlay_vfs2_tags = list(vfs2_tags)
- overlay_vfs2_tags.append("manual")
- overlay_vfs2_tags.append("noguitar")
- overlay_vfs2_tags.append("notap")
+ for (platform, platform_tags) in platforms.items():
_syscall_test(
test = test,
- platform = default_platform,
+ platform = platform,
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
- tags = platforms[default_platform] + overlay_vfs2_tags,
+ tags = platform_tags + tags,
+ fuse = fuse,
+ vfs2 = vfs2,
debug = debug,
- overlay = True,
- vfs2 = True,
**kwargs
)
- if add_hostinet:
+ if add_overlay:
_syscall_test(
test = test,
platform = default_platform,
use_tmpfs = use_tmpfs,
- network = "host",
add_uds_tree = add_uds_tree,
tags = platforms[default_platform] + tags,
debug = debug,
+ fuse = fuse,
+ vfs2 = vfs2,
+ overlay = True,
**kwargs
)
-
- if not use_tmpfs:
- # Also test shared gofer access.
+ if add_hostinet:
_syscall_test(
test = test,
platform = default_platform,
use_tmpfs = use_tmpfs,
+ network = "host",
add_uds_tree = add_uds_tree,
tags = platforms[default_platform] + tags,
debug = debug,
- file_access = "shared",
+ fuse = fuse,
+ vfs2 = vfs2,
**kwargs
)
+ if not use_tmpfs:
+ # Also test shared gofer access.
_syscall_test(
test = test,
platform = default_platform,
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
- tags = platforms[default_platform] + vfs2_tags,
+ tags = platforms[default_platform] + tags,
debug = debug,
file_access = "shared",
- vfs2 = True,
+ fuse = fuse,
+ vfs2 = vfs2,
**kwargs
)
diff --git a/test/runner/runner.go b/test/runner/runner.go
index a8a134fe2..d314a5036 100644
--- a/test/runner/runner.go
+++ b/test/runner/runner.go
@@ -252,6 +252,7 @@ func runRunsc(spec *specs.Spec) error {
debugLogDir += "/"
log.Infof("runsc logs: %s", debugLogDir)
args = append(args, "-debug-log", debugLogDir)
+ args = append(args, "-coverage-report", debugLogDir)
// Default -log sends messages to stderr which makes reading the test log
// difficult. Instead, drop them when debug log is enabled given it's a
diff --git a/test/runtimes/defs.bzl b/test/runtimes/defs.bzl
index 702522d86..2550b61a3 100644
--- a/test/runtimes/defs.bzl
+++ b/test/runtimes/defs.bzl
@@ -75,7 +75,6 @@ def runtime_test(name, **kwargs):
"local",
"manual",
],
- size = "enormous",
**kwargs
)
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index ef299799e..affcae8fd 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -244,6 +244,10 @@ syscall_test(
)
syscall_test(
+ test = "//test/syscalls/linux:verity_ioctl_test",
+)
+
+syscall_test(
test = "//test/syscalls/linux:iptables_test",
)
@@ -318,6 +322,10 @@ syscall_test(
)
syscall_test(
+ test = "//test/syscalls/linux:verity_mount_test",
+)
+
+syscall_test(
size = "medium",
test = "//test/syscalls/linux:mremap_test",
)
@@ -772,8 +780,7 @@ syscall_test(
)
syscall_test(
- # NOTE(b/116636318): Large sendmsg may stall a long time.
- size = "enormous",
+ flaky = 1, # NOTE(b/116636318): Large sendmsg may stall a long time.
shard_count = more_shards,
test = "//test/syscalls/linux:socket_unix_dgram_local_test",
)
@@ -791,8 +798,7 @@ syscall_test(
)
syscall_test(
- # NOTE(b/116636318): Large sendmsg may stall a long time.
- size = "enormous",
+ flaky = 1, # NOTE(b/116636318): Large sendmsg may stall a long time.
shard_count = more_shards,
test = "//test/syscalls/linux:socket_unix_seqpacket_local_test",
)
@@ -995,3 +1001,7 @@ syscall_test(
syscall_test(
test = "//test/syscalls/linux:processes_test",
)
+
+syscall_test(
+ test = "//test/syscalls/linux:cgroup_test",
+)
diff --git a/test/syscalls/linux/32bit.cc b/test/syscalls/linux/32bit.cc
index 3c825477c..6080a59b7 100644
--- a/test/syscalls/linux/32bit.cc
+++ b/test/syscalls/linux/32bit.cc
@@ -22,15 +22,13 @@
#include "test/util/posix_error.h"
#include "test/util/test_util.h"
-#ifndef __x86_64__
-#error "This test is x86-64 specific."
-#endif
-
namespace gvisor {
namespace testing {
namespace {
+#ifdef __x86_64__
+
constexpr char kInt3 = '\xcc';
constexpr char kInt80[2] = {'\xcd', '\x80'};
constexpr char kSyscall[2] = {'\x0f', '\x05'};
@@ -244,5 +242,7 @@ TEST(Call32Bit, Disallowed) {
} // namespace
+#endif
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 043ada583..55f3fc4ae 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -212,10 +212,7 @@ cc_binary(
cc_binary(
name = "32bit_test",
testonly = 1,
- srcs = select_arch(
- amd64 = ["32bit.cc"],
- arm64 = [],
- ),
+ srcs = ["32bit.cc"],
linkstatic = 1,
deps = [
"@com_google_absl//absl/base:core_headers",
@@ -1014,6 +1011,22 @@ cc_binary(
],
)
+cc_binary(
+ name = "verity_ioctl_test",
+ testonly = 1,
+ srcs = ["verity_ioctl.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ gtest,
+ "//test/util:fs_util",
+ "//test/util:mount_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
cc_library(
name = "iptables_types",
testonly = 1,
@@ -1304,6 +1317,20 @@ cc_binary(
)
cc_binary(
+ name = "verity_mount_test",
+ testonly = 1,
+ srcs = ["verity_mount.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:capability_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
name = "mremap_test",
testonly = 1,
srcs = ["mremap.cc"],
@@ -4205,3 +4232,24 @@ cc_binary(
"//test/util:test_util",
],
)
+
+cc_binary(
+ name = "cgroup_test",
+ testonly = 1,
+ srcs = ["cgroup.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:cgroup_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ ],
+)
diff --git a/test/syscalls/linux/accept_bind.cc b/test/syscalls/linux/accept_bind.cc
index f65a14fb8..fe560cfc5 100644
--- a/test/syscalls/linux/accept_bind.cc
+++ b/test/syscalls/linux/accept_bind.cc
@@ -67,6 +67,42 @@ TEST_P(AllSocketPairTest, ListenDecreaseBacklog) {
SyscallSucceeds());
}
+TEST_P(AllSocketPairTest, ListenBacklogSizes) {
+ DisableSave ds;
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ int type;
+ socklen_t typelen = sizeof(type);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_TYPE, &type, &typelen),
+ SyscallSucceeds());
+
+ std::array<int, 3> backlogs = {-1, 0, 1};
+ for (auto& backlog : backlogs) {
+ ASSERT_THAT(listen(sockets->first_fd(), backlog), SyscallSucceeds());
+
+ int expected_accepts = backlog;
+ if (backlog < 0) {
+ expected_accepts = 1024;
+ }
+ for (int i = 0; i < expected_accepts; i++) {
+ SCOPED_TRACE(absl::StrCat("i=", i));
+ // Connect to the listening socket.
+ const FileDescriptor client =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, type, 0));
+ ASSERT_THAT(connect(client.get(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+ const FileDescriptor accepted = ASSERT_NO_ERRNO_AND_VALUE(
+ Accept(sockets->first_fd(), nullptr, nullptr));
+ }
+ }
+}
+
TEST_P(AllSocketPairTest, ListenWithoutBind) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
ASSERT_THAT(listen(sockets->first_fd(), 0), SyscallFailsWithErrno(EINVAL));
@@ -285,8 +321,7 @@ TEST_P(AllSocketPairTest, AcceptValidAddrLen) {
struct sockaddr_un addr = {};
socklen_t addr_len = sizeof(addr);
ASSERT_THAT(
- accepted = accept(sockets->first_fd(),
- reinterpret_cast<struct sockaddr*>(&addr), &addr_len),
+ accepted = accept(sockets->first_fd(), AsSockAddr(&addr), &addr_len),
SyscallSucceeds());
ASSERT_THAT(close(accepted), SyscallSucceeds());
}
@@ -307,8 +342,7 @@ TEST_P(AllSocketPairTest, AcceptNegativeAddrLen) {
// With a negative addr_len, accept returns EINVAL,
struct sockaddr_un addr = {};
socklen_t addr_len = -1;
- ASSERT_THAT(accept(sockets->first_fd(),
- reinterpret_cast<struct sockaddr*>(&addr), &addr_len),
+ ASSERT_THAT(accept(sockets->first_fd(), AsSockAddr(&addr), &addr_len),
SyscallFailsWithErrno(EINVAL));
}
@@ -499,10 +533,9 @@ TEST_P(AllSocketPairTest, UnboundSenderAddr) {
struct sockaddr_storage addr;
socklen_t addr_len = sizeof(addr);
- ASSERT_THAT(
- RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0,
- reinterpret_cast<sockaddr*>(&addr), &addr_len),
- SyscallSucceedsWithValue(sizeof(i)));
+ ASSERT_THAT(RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0,
+ AsSockAddr(&addr), &addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
EXPECT_EQ(addr_len, 0);
}
@@ -534,10 +567,9 @@ TEST_P(AllSocketPairTest, BoundSenderAddr) {
struct sockaddr_storage addr;
socklen_t addr_len = sizeof(addr);
- ASSERT_THAT(
- RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0,
- reinterpret_cast<sockaddr*>(&addr), &addr_len),
- SyscallSucceedsWithValue(sizeof(i)));
+ ASSERT_THAT(RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0,
+ AsSockAddr(&addr), &addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
EXPECT_EQ(addr_len, sockets->second_addr_len());
EXPECT_EQ(
memcmp(&addr, sockets->second_addr(),
@@ -573,10 +605,9 @@ TEST_P(AllSocketPairTest, BindAfterConnectSenderAddr) {
struct sockaddr_storage addr;
socklen_t addr_len = sizeof(addr);
- ASSERT_THAT(
- RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0,
- reinterpret_cast<sockaddr*>(&addr), &addr_len),
- SyscallSucceedsWithValue(sizeof(i)));
+ ASSERT_THAT(RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0,
+ AsSockAddr(&addr), &addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
EXPECT_EQ(addr_len, sockets->second_addr_len());
EXPECT_EQ(
memcmp(&addr, sockets->second_addr(),
@@ -612,10 +643,9 @@ TEST_P(AllSocketPairTest, BindAfterAcceptSenderAddr) {
struct sockaddr_storage addr;
socklen_t addr_len = sizeof(addr);
- ASSERT_THAT(
- RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0,
- reinterpret_cast<sockaddr*>(&addr), &addr_len),
- SyscallSucceedsWithValue(sizeof(i)));
+ ASSERT_THAT(RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0,
+ AsSockAddr(&addr), &addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
EXPECT_EQ(addr_len, sockets->second_addr_len());
EXPECT_EQ(
memcmp(&addr, sockets->second_addr(),
diff --git a/test/syscalls/linux/alarm.cc b/test/syscalls/linux/alarm.cc
index 940c97285..cd0704334 100644
--- a/test/syscalls/linux/alarm.cc
+++ b/test/syscalls/linux/alarm.cc
@@ -36,7 +36,7 @@ void do_nothing_handler(int sig, siginfo_t* siginfo, void* arg) {}
// No random save as the test relies on alarm timing. Cooperative save tests
// already cover the save between alarm and read.
-TEST(AlarmTest, Interrupt_NoRandomSave) {
+TEST(AlarmTest, Interrupt) {
int pipe_fds[2];
ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds());
@@ -71,7 +71,7 @@ void inc_alarms_handler(int sig, siginfo_t* siginfo, void* arg) {
// No random save as the test relies on alarm timing. Cooperative save tests
// already cover the save between alarm and read.
-TEST(AlarmTest, Restart_NoRandomSave) {
+TEST(AlarmTest, Restart) {
alarms_received = 0;
int pipe_fds[2];
@@ -114,7 +114,7 @@ TEST(AlarmTest, Restart_NoRandomSave) {
// No random save as the test relies on alarm timing. Cooperative save tests
// already cover the save between alarm and pause.
-TEST(AlarmTest, SaSiginfo_NoRandomSave) {
+TEST(AlarmTest, SaSiginfo) {
// Use a signal handler that interrupts but does nothing rather than using the
// default terminate action.
struct sigaction sa;
@@ -134,7 +134,7 @@ TEST(AlarmTest, SaSiginfo_NoRandomSave) {
// No random save as the test relies on alarm timing. Cooperative save tests
// already cover the save between alarm and pause.
-TEST(AlarmTest, SaInterrupt_NoRandomSave) {
+TEST(AlarmTest, SaInterrupt) {
// Use a signal handler that interrupts but does nothing rather than using the
// default terminate action.
struct sigaction sa;
diff --git a/test/syscalls/linux/cgroup.cc b/test/syscalls/linux/cgroup.cc
new file mode 100644
index 000000000..862328f5b
--- /dev/null
+++ b/test/syscalls/linux/cgroup.cc
@@ -0,0 +1,452 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// All tests in this file rely on being about to mount and unmount cgroupfs,
+// which isn't expected to work, or be safe on a general linux system.
+
+#include <limits.h>
+#include <sys/mount.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/str_split.h"
+#include "test/util/capability_util.h"
+#include "test/util/cgroup_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+using ::testing::_;
+using ::testing::Ge;
+using ::testing::Gt;
+
+std::vector<std::string> known_controllers = {
+ "cpu", "cpuset", "cpuacct", "job", "memory",
+};
+
+bool CgroupsAvailable() {
+ return IsRunningOnGvisor() && !IsRunningWithVFS1() &&
+ TEST_CHECK_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN));
+}
+
+TEST(Cgroup, MountSucceeds) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs(""));
+ EXPECT_NO_ERRNO(c.ContainsCallingProcess());
+}
+
+TEST(Cgroup, SeparateMounts) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+
+ for (const auto& ctl : known_controllers) {
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs(ctl));
+ EXPECT_NO_ERRNO(c.ContainsCallingProcess());
+ }
+}
+
+TEST(Cgroup, AllControllersImplicit) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs(""));
+
+ absl::flat_hash_map<std::string, CgroupsEntry> cgroups_entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries());
+ for (const auto& ctl : known_controllers) {
+ EXPECT_TRUE(cgroups_entries.contains(ctl))
+ << absl::StreamFormat("ctl=%s", ctl);
+ }
+ EXPECT_EQ(cgroups_entries.size(), known_controllers.size());
+}
+
+TEST(Cgroup, AllControllersExplicit) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("all"));
+
+ absl::flat_hash_map<std::string, CgroupsEntry> cgroups_entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries());
+ for (const auto& ctl : known_controllers) {
+ EXPECT_TRUE(cgroups_entries.contains(ctl))
+ << absl::StreamFormat("ctl=%s", ctl);
+ }
+ EXPECT_EQ(cgroups_entries.size(), known_controllers.size());
+}
+
+TEST(Cgroup, ProcsAndTasks) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs(""));
+ absl::flat_hash_set<pid_t> pids = ASSERT_NO_ERRNO_AND_VALUE(c.Procs());
+ absl::flat_hash_set<pid_t> tids = ASSERT_NO_ERRNO_AND_VALUE(c.Tasks());
+
+ EXPECT_GE(tids.size(), pids.size()) << "Found more processes than threads";
+
+ // Pids should be a strict subset of tids.
+ for (auto it = pids.begin(); it != pids.end(); ++it) {
+ EXPECT_TRUE(tids.contains(*it))
+ << absl::StreamFormat("Have pid %d, but no such tid", *it);
+ }
+}
+
+TEST(Cgroup, ControllersMustBeInUniqueHierarchy) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ // Hierarchy #1: all controllers.
+ Cgroup all = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs(""));
+ // Hierarchy #2: memory.
+ //
+ // This should conflict since memory is already in hierarchy #1, and the two
+ // hierarchies have different sets of controllers, so this mount can't be a
+ // view into hierarchy #1.
+ EXPECT_THAT(m.MountCgroupfs("memory"), PosixErrorIs(EBUSY, _))
+ << "Memory controller mounted on two hierarchies";
+ EXPECT_THAT(m.MountCgroupfs("cpu"), PosixErrorIs(EBUSY, _))
+ << "CPU controller mounted on two hierarchies";
+}
+
+TEST(Cgroup, UnmountFreesControllers) {
+ SKIP_IF(!CgroupsAvailable());
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup all = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs(""));
+ // All controllers are now attached to all's hierarchy. Attempting new mount
+ // with any individual controller should fail.
+ EXPECT_THAT(m.MountCgroupfs("memory"), PosixErrorIs(EBUSY, _))
+ << "Memory controller mounted on two hierarchies";
+
+ // Unmount the "all" hierarchy. This should enable any controller to be
+ // mounted on a new hierarchy again.
+ ASSERT_NO_ERRNO(m.Unmount(all));
+ EXPECT_NO_ERRNO(m.MountCgroupfs("memory"));
+ EXPECT_NO_ERRNO(m.MountCgroupfs("cpu"));
+}
+
+TEST(Cgroup, OnlyContainsControllerSpecificFiles) {
+ SKIP_IF(!CgroupsAvailable());
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup mem = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory"));
+ EXPECT_THAT(Exists(mem.Relpath("memory.usage_in_bytes")),
+ IsPosixErrorOkAndHolds(true));
+ // CPU files shouldn't exist in memory cgroups.
+ EXPECT_THAT(Exists(mem.Relpath("cpu.cfs_period_us")),
+ IsPosixErrorOkAndHolds(false));
+ EXPECT_THAT(Exists(mem.Relpath("cpu.cfs_quota_us")),
+ IsPosixErrorOkAndHolds(false));
+ EXPECT_THAT(Exists(mem.Relpath("cpu.shares")), IsPosixErrorOkAndHolds(false));
+
+ Cgroup cpu = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu"));
+ EXPECT_THAT(Exists(cpu.Relpath("cpu.cfs_period_us")),
+ IsPosixErrorOkAndHolds(true));
+ EXPECT_THAT(Exists(cpu.Relpath("cpu.cfs_quota_us")),
+ IsPosixErrorOkAndHolds(true));
+ EXPECT_THAT(Exists(cpu.Relpath("cpu.shares")), IsPosixErrorOkAndHolds(true));
+ // Memory files shouldn't exist in cpu cgroups.
+ EXPECT_THAT(Exists(cpu.Relpath("memory.usage_in_bytes")),
+ IsPosixErrorOkAndHolds(false));
+}
+
+TEST(Cgroup, InvalidController) {
+ SKIP_IF(!CgroupsAvailable());
+
+ TempPath mountpoint = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ std::string mopts = "this-controller-is-invalid";
+ EXPECT_THAT(
+ mount("none", mountpoint.path().c_str(), "cgroup", 0, mopts.c_str()),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(Cgroup, MoptAllMustBeExclusive) {
+ SKIP_IF(!CgroupsAvailable());
+
+ TempPath mountpoint = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ std::string mopts = "all,cpu";
+ EXPECT_THAT(
+ mount("none", mountpoint.path().c_str(), "cgroup", 0, mopts.c_str()),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(MemoryCgroup, MemoryUsageInBytes) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory"));
+ EXPECT_THAT(c.ReadIntegerControlFile("memory.usage_in_bytes"),
+ IsPosixErrorOkAndHolds(Gt(0)));
+}
+
+TEST(CPUCgroup, ControlFilesHaveDefaultValues) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu"));
+ EXPECT_THAT(c.ReadIntegerControlFile("cpu.cfs_quota_us"),
+ IsPosixErrorOkAndHolds(-1));
+ EXPECT_THAT(c.ReadIntegerControlFile("cpu.cfs_period_us"),
+ IsPosixErrorOkAndHolds(100000));
+ EXPECT_THAT(c.ReadIntegerControlFile("cpu.shares"),
+ IsPosixErrorOkAndHolds(1024));
+}
+
+TEST(CPUAcctCgroup, CPUAcctUsage) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpuacct"));
+
+ const int64_t usage =
+ ASSERT_NO_ERRNO_AND_VALUE(c.ReadIntegerControlFile("cpuacct.usage"));
+ const int64_t usage_user =
+ ASSERT_NO_ERRNO_AND_VALUE(c.ReadIntegerControlFile("cpuacct.usage_user"));
+ const int64_t usage_sys =
+ ASSERT_NO_ERRNO_AND_VALUE(c.ReadIntegerControlFile("cpuacct.usage_sys"));
+
+ EXPECT_GE(usage, 0);
+ EXPECT_GE(usage_user, 0);
+ EXPECT_GE(usage_sys, 0);
+
+ EXPECT_GE(usage_user + usage_sys, usage);
+}
+
+TEST(CPUAcctCgroup, CPUAcctStat) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpuacct"));
+
+ std::string stat =
+ ASSERT_NO_ERRNO_AND_VALUE(c.ReadControlFile("cpuacct.stat"));
+
+ // We're expecting the contents of "cpuacct.stat" to look similar to this:
+ //
+ // user 377986
+ // system 220662
+
+ std::vector<absl::string_view> lines =
+ absl::StrSplit(stat, '\n', absl::SkipEmpty());
+ ASSERT_EQ(lines.size(), 2);
+
+ std::vector<absl::string_view> user_tokens =
+ StrSplit(lines[0], absl::ByChar(' '));
+ EXPECT_EQ(user_tokens[0], "user");
+ EXPECT_THAT(Atoi<int64_t>(user_tokens[1]), IsPosixErrorOkAndHolds(Ge(0)));
+
+ std::vector<absl::string_view> sys_tokens =
+ StrSplit(lines[1], absl::ByChar(' '));
+ EXPECT_EQ(sys_tokens[0], "system");
+ EXPECT_THAT(Atoi<int64_t>(sys_tokens[1]), IsPosixErrorOkAndHolds(Ge(0)));
+}
+
+// WriteAndVerifyControlValue attempts to write val to a cgroup file at path,
+// and verify the value by reading it afterwards.
+PosixError WriteAndVerifyControlValue(const Cgroup& c, std::string_view path,
+ int64_t val) {
+ RETURN_IF_ERRNO(c.WriteIntegerControlFile(path, val));
+ ASSIGN_OR_RETURN_ERRNO(int64_t newval, c.ReadIntegerControlFile(path));
+ if (newval != val) {
+ return PosixError(
+ EINVAL,
+ absl::StrFormat(
+ "Unexpected value for control file '%s': expected %d, got %d", path,
+ val, newval));
+ }
+ return NoError();
+}
+
+TEST(JobCgroup, ReadWriteRead) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("job"));
+
+ EXPECT_THAT(c.ReadIntegerControlFile("job.id"), IsPosixErrorOkAndHolds(0));
+ EXPECT_NO_ERRNO(WriteAndVerifyControlValue(c, "job.id", 1234));
+ EXPECT_NO_ERRNO(WriteAndVerifyControlValue(c, "job.id", -1));
+ EXPECT_NO_ERRNO(WriteAndVerifyControlValue(c, "job.id", LLONG_MIN));
+ EXPECT_NO_ERRNO(WriteAndVerifyControlValue(c, "job.id", LLONG_MAX));
+}
+
+TEST(ProcCgroups, Empty) {
+ SKIP_IF(!CgroupsAvailable());
+
+ absl::flat_hash_map<std::string, CgroupsEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries());
+ // No cgroups mounted yet, we should have no entries.
+ EXPECT_TRUE(entries.empty());
+}
+
+TEST(ProcCgroups, ProcCgroupsEntries) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+
+ Cgroup mem = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory"));
+ absl::flat_hash_map<std::string, CgroupsEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries());
+ EXPECT_EQ(entries.size(), 1);
+ ASSERT_TRUE(entries.contains("memory"));
+ CgroupsEntry mem_e = entries["memory"];
+ EXPECT_EQ(mem_e.subsys_name, "memory");
+ EXPECT_GE(mem_e.hierarchy, 1);
+ // Expect a single root cgroup.
+ EXPECT_EQ(mem_e.num_cgroups, 1);
+ // Cgroups are currently always enabled when mounted.
+ EXPECT_TRUE(mem_e.enabled);
+
+ // Add a second cgroup, and check for new entry.
+
+ Cgroup cpu = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu"));
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries());
+ EXPECT_EQ(entries.size(), 2);
+ EXPECT_TRUE(entries.contains("memory")); // Still have memory entry.
+ ASSERT_TRUE(entries.contains("cpu"));
+ CgroupsEntry cpu_e = entries["cpu"];
+ EXPECT_EQ(cpu_e.subsys_name, "cpu");
+ EXPECT_GE(cpu_e.hierarchy, 1);
+ EXPECT_EQ(cpu_e.num_cgroups, 1);
+ EXPECT_TRUE(cpu_e.enabled);
+
+ // Separate hierarchies, since controllers were mounted separately.
+ EXPECT_NE(mem_e.hierarchy, cpu_e.hierarchy);
+}
+
+TEST(ProcCgroups, UnmountRemovesEntries) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup cg = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu,memory"));
+ absl::flat_hash_map<std::string, CgroupsEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries());
+ EXPECT_EQ(entries.size(), 2);
+
+ ASSERT_NO_ERRNO(m.Unmount(cg));
+
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries());
+ EXPECT_TRUE(entries.empty());
+}
+
+TEST(ProcPIDCgroup, Empty) {
+ SKIP_IF(!CgroupsAvailable());
+
+ absl::flat_hash_map<std::string, PIDCgroupEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid()));
+ EXPECT_TRUE(entries.empty());
+}
+
+TEST(ProcPIDCgroup, Entries) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory"));
+
+ absl::flat_hash_map<std::string, PIDCgroupEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid()));
+ EXPECT_EQ(entries.size(), 1);
+ PIDCgroupEntry mem_e = entries["memory"];
+ EXPECT_GE(mem_e.hierarchy, 1);
+ EXPECT_EQ(mem_e.controllers, "memory");
+ EXPECT_EQ(mem_e.path, "/");
+
+ Cgroup c1 = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu"));
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid()));
+ EXPECT_EQ(entries.size(), 2);
+ EXPECT_TRUE(entries.contains("memory")); // Still have memory entry.
+ PIDCgroupEntry cpu_e = entries["cpu"];
+ EXPECT_GE(cpu_e.hierarchy, 1);
+ EXPECT_EQ(cpu_e.controllers, "cpu");
+ EXPECT_EQ(cpu_e.path, "/");
+
+ // Separate hierarchies, since controllers were mounted separately.
+ EXPECT_NE(mem_e.hierarchy, cpu_e.hierarchy);
+}
+
+TEST(ProcPIDCgroup, UnmountRemovesEntries) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup all = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs(""));
+
+ absl::flat_hash_map<std::string, PIDCgroupEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid()));
+ EXPECT_GT(entries.size(), 0);
+
+ ASSERT_NO_ERRNO(m.Unmount(all));
+
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid()));
+ EXPECT_TRUE(entries.empty());
+}
+
+TEST(ProcCgroup, PIDCgroupMatchesCgroups) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory"));
+ Cgroup c1 = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu"));
+
+ absl::flat_hash_map<std::string, CgroupsEntry> cgroups_entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries());
+ absl::flat_hash_map<std::string, PIDCgroupEntry> pid_entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid()));
+
+ CgroupsEntry cgroup_mem = cgroups_entries["memory"];
+ PIDCgroupEntry pid_mem = pid_entries["memory"];
+
+ EXPECT_EQ(cgroup_mem.hierarchy, pid_mem.hierarchy);
+
+ CgroupsEntry cgroup_cpu = cgroups_entries["cpu"];
+ PIDCgroupEntry pid_cpu = pid_entries["cpu"];
+
+ EXPECT_EQ(cgroup_cpu.hierarchy, pid_cpu.hierarchy);
+ EXPECT_NE(cgroup_mem.hierarchy, cgroup_cpu.hierarchy);
+ EXPECT_NE(pid_mem.hierarchy, pid_cpu.hierarchy);
+}
+
+TEST(ProcCgroup, MultiControllerHierarchy) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory,cpu"));
+
+ absl::flat_hash_map<std::string, CgroupsEntry> cgroups_entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries());
+
+ CgroupsEntry mem_e = cgroups_entries["memory"];
+ CgroupsEntry cpu_e = cgroups_entries["cpu"];
+
+ // Both controllers should have the same hierarchy ID.
+ EXPECT_EQ(mem_e.hierarchy, cpu_e.hierarchy);
+
+ absl::flat_hash_map<std::string, PIDCgroupEntry> pid_entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid()));
+
+ // Expecting an entry listing both controllers, that matches the previous
+ // hierarchy ID. Note that the controllers are listed in alphabetical order.
+ PIDCgroupEntry pid_e = pid_entries["cpu,memory"];
+ EXPECT_EQ(pid_e.hierarchy, mem_e.hierarchy);
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/chmod.cc b/test/syscalls/linux/chmod.cc
index 8233df0f8..4a5ea84d4 100644
--- a/test/syscalls/linux/chmod.cc
+++ b/test/syscalls/linux/chmod.cc
@@ -53,7 +53,7 @@ TEST(ChmodTest, ChmodDirSucceeds) {
EXPECT_THAT(open(fileInDir.c_str(), O_RDONLY), SyscallFailsWithErrno(EACCES));
}
-TEST(ChmodTest, FchmodFileSucceeds_NoRandomSave) {
+TEST(ChmodTest, FchmodFileSucceeds) {
// Drop capabilities that allow us to file directory permissions.
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
@@ -70,7 +70,7 @@ TEST(ChmodTest, FchmodFileSucceeds_NoRandomSave) {
EXPECT_THAT(open(file.path().c_str(), O_RDWR), SyscallFailsWithErrno(EACCES));
}
-TEST(ChmodTest, FchmodDirSucceeds_NoRandomSave) {
+TEST(ChmodTest, FchmodDirSucceeds) {
// Drop capabilities that allow us to override file and directory permissions.
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
@@ -212,7 +212,7 @@ TEST(ChmodTest, FchmodatDir) {
SyscallFailsWithErrno(EACCES));
}
-TEST(ChmodTest, ChmodDowngradeWritability_NoRandomSave) {
+TEST(ChmodTest, ChmodDowngradeWritability) {
auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0666));
int fd;
@@ -238,7 +238,7 @@ TEST(ChmodTest, ChmodFileToNoPermissionsSucceeds) {
SyscallFailsWithErrno(EACCES));
}
-TEST(ChmodTest, FchmodDowngradeWritability_NoRandomSave) {
+TEST(ChmodTest, FchmodDowngradeWritability) {
auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
int fd;
@@ -252,7 +252,7 @@ TEST(ChmodTest, FchmodDowngradeWritability_NoRandomSave) {
EXPECT_THAT(close(fd), SyscallSucceeds());
}
-TEST(ChmodTest, FchmodFileToNoPermissionsSucceeds_NoRandomSave) {
+TEST(ChmodTest, FchmodFileToNoPermissionsSucceeds) {
// Drop capabilities that allow us to override file permissions.
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
diff --git a/test/syscalls/linux/dev.cc b/test/syscalls/linux/dev.cc
index 1d0d584cd..32860aa21 100644
--- a/test/syscalls/linux/dev.cc
+++ b/test/syscalls/linux/dev.cc
@@ -117,7 +117,7 @@ TEST(DevTest, ReadDevNull) {
}
// Do not allow random save as it could lead to partial reads.
-TEST(DevTest, ReadDevZero_NoRandomSave) {
+TEST(DevTest, ReadDevZero) {
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDONLY));
diff --git a/test/syscalls/linux/epoll.cc b/test/syscalls/linux/epoll.cc
index 8a72ef10a..b180f633c 100644
--- a/test/syscalls/linux/epoll.cc
+++ b/test/syscalls/linux/epoll.cc
@@ -115,7 +115,7 @@ TEST(EpollTest, LastNonWritable) {
}
}
-TEST(EpollTest, Timeout_NoRandomSave) {
+TEST(EpollTest, Timeout) {
auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
std::vector<FileDescriptor> eventfds;
for (int i = 0; i < kFDsPerEpoll; i++) {
@@ -290,7 +290,7 @@ TEST(EpollTest, Oneshot) {
SyscallSucceedsWithValue(0));
}
-TEST(EpollTest, EdgeTriggered_NoRandomSave) {
+TEST(EpollTest, EdgeTriggered) {
// Test edge-triggered entry: make it edge-triggered, first wait should
// return it, second one should time out, make it writable again, third wait
// should return it, fourth wait should timeout.
diff --git a/test/syscalls/linux/eventfd.cc b/test/syscalls/linux/eventfd.cc
index dc794415e..8202d35fa 100644
--- a/test/syscalls/linux/eventfd.cc
+++ b/test/syscalls/linux/eventfd.cc
@@ -175,7 +175,7 @@ TEST(EventfdTest, SpliceFromPipePartialSucceeds) {
}
// NotifyNonZero is inherently racy, so random save is disabled.
-TEST(EventfdTest, NotifyNonZero_NoRandomSave) {
+TEST(EventfdTest, NotifyNonZero) {
// Waits will time out at 10 seconds.
constexpr int kEpollTimeoutMs = 10000;
// Create an eventfd descriptor.
diff --git a/test/syscalls/linux/flock.cc b/test/syscalls/linux/flock.cc
index b286e84fe..fd387aa45 100644
--- a/test/syscalls/linux/flock.cc
+++ b/test/syscalls/linux/flock.cc
@@ -205,7 +205,7 @@ TEST_F(FlockTest, TestSharedLockFailExclusiveHolderNonblocking) {
void trivial_handler(int signum) {}
-TEST_F(FlockTest, TestSharedLockFailExclusiveHolderBlocking_NoRandomSave) {
+TEST_F(FlockTest, TestSharedLockFailExclusiveHolderBlocking) {
const DisableSave ds; // Timing-related.
// This test will verify that a shared lock is denied while
@@ -262,7 +262,7 @@ TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolderNonblocking) {
ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0));
}
-TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolderBlocking_NoRandomSave) {
+TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolderBlocking) {
const DisableSave ds; // Timing-related.
// This test will verify that an exclusive lock is denied while
@@ -499,7 +499,7 @@ TEST_F(FlockTest, TestDupFdFollowedByLock) {
// NOTE: These blocking tests are not perfect. Unfortunately it's very hard to
// determine if a thread was actually blocked in the kernel so we're forced
// to use timing.
-TEST_F(FlockTest, BlockingLockNoBlockingForSharedLocks_NoRandomSave) {
+TEST_F(FlockTest, BlockingLockNoBlockingForSharedLocks) {
// This test will verify that although LOCK_NB isn't specified
// two different fds can obtain shared locks without blocking.
ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH), SyscallSucceeds());
@@ -539,7 +539,7 @@ TEST_F(FlockTest, BlockingLockNoBlockingForSharedLocks_NoRandomSave) {
EXPECT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceeds());
}
-TEST_F(FlockTest, BlockingLockFirstSharedSecondExclusive_NoRandomSave) {
+TEST_F(FlockTest, BlockingLockFirstSharedSecondExclusive) {
// This test will verify that if someone holds a shared lock any attempt to
// obtain an exclusive lock will result in blocking.
ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH), SyscallSucceeds());
@@ -576,7 +576,7 @@ TEST_F(FlockTest, BlockingLockFirstSharedSecondExclusive_NoRandomSave) {
EXPECT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceeds());
}
-TEST_F(FlockTest, BlockingLockFirstExclusiveSecondShared_NoRandomSave) {
+TEST_F(FlockTest, BlockingLockFirstExclusiveSecondShared) {
// This test will verify that if someone holds an exclusive lock any attempt
// to obtain a shared lock will result in blocking.
ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX), SyscallSucceeds());
@@ -613,7 +613,7 @@ TEST_F(FlockTest, BlockingLockFirstExclusiveSecondShared_NoRandomSave) {
EXPECT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceeds());
}
-TEST_F(FlockTest, BlockingLockFirstExclusiveSecondExclusive_NoRandomSave) {
+TEST_F(FlockTest, BlockingLockFirstExclusiveSecondExclusive) {
// This test will verify that if someone holds an exclusive lock any attempt
// to obtain another exclusive lock will result in blocking.
ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX), SyscallSucceeds());
diff --git a/test/syscalls/linux/fpsig_fork.cc b/test/syscalls/linux/fpsig_fork.cc
index c47567b4e..79b0596c4 100644
--- a/test/syscalls/linux/fpsig_fork.cc
+++ b/test/syscalls/linux/fpsig_fork.cc
@@ -44,6 +44,8 @@ namespace {
#define SET_FP0(var) SET_FPREG(var, d0)
#endif
+#define DEFAULT_MXCSR 0x1f80
+
int parent, child;
void sigusr1(int s, siginfo_t* siginfo, void* _uc) {
@@ -57,6 +59,12 @@ void sigusr1(int s, siginfo_t* siginfo, void* _uc) {
uint64_t got;
GET_FP0(got);
TEST_CHECK_MSG(val == got, "Basic FP check failed in sigusr1()");
+
+#ifdef __x86_64
+ uint32_t mxcsr;
+ __asm__("STMXCSR %0" : "=m"(mxcsr));
+ TEST_CHECK_MSG(mxcsr == DEFAULT_MXCSR, "Unexpected mxcsr");
+#endif
}
TEST(FPSigTest, Fork) {
@@ -125,6 +133,55 @@ TEST(FPSigTest, Fork) {
}
}
+#ifdef __x86_64__
+TEST(FPSigTest, ForkWithZeroMxcsr) {
+ parent = getpid();
+ pid_t parent_tid = gettid();
+
+ struct sigaction sa = {};
+ sigemptyset(&sa.sa_mask);
+ sa.sa_flags = SA_SIGINFO;
+ sa.sa_sigaction = sigusr1;
+ ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds());
+
+ // The control bits of the MXCSR register are callee-saved (preserved across
+ // calls), while the status bits are caller-saved (not preserved).
+ uint32_t expected = 0, origin;
+ __asm__("STMXCSR %0" : "=m"(origin));
+ __asm__("LDMXCSR %0" : : "m"(expected));
+
+ asm volatile(
+ "movl %[killnr], %%eax;"
+ "movl %[parent], %%edi;"
+ "movl %[tid], %%esi;"
+ "movl %[sig], %%edx;"
+ "syscall;"
+ :
+ : [killnr] "i"(__NR_tgkill), [parent] "rm"(parent),
+ [tid] "rm"(parent_tid), [sig] "i"(SIGUSR1)
+ : "rax", "rdi", "rsi", "rdx",
+ // Clobbered by syscall.
+ "rcx", "r11");
+
+ uint32_t got;
+ __asm__("STMXCSR %0" : "=m"(got));
+ __asm__("LDMXCSR %0" : : "m"(origin));
+
+ if (getpid() == parent) { // Parent.
+ int status;
+ ASSERT_THAT(waitpid(child, &status, 0), SyscallSucceedsWithValue(child));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);
+ }
+
+ // TEST_CHECK_MSG since this may run in the child.
+ TEST_CHECK_MSG(expected == got, "Bad mxcsr value");
+
+ if (getpid() != parent) { // Child.
+ _exit(0);
+ }
+}
+#endif
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/futex.cc b/test/syscalls/linux/futex.cc
index 90b1f0508..859f92b75 100644
--- a/test/syscalls/linux/futex.cc
+++ b/test/syscalls/linux/futex.cc
@@ -220,7 +220,7 @@ TEST_P(PrivateAndSharedFutexTest, Wait_ZeroBitset) {
SyscallFailsWithErrno(EINVAL));
}
-TEST_P(PrivateAndSharedFutexTest, Wake1_NoRandomSave) {
+TEST_P(PrivateAndSharedFutexTest, Wake1) {
constexpr int kInitialValue = 1;
std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
@@ -240,7 +240,7 @@ TEST_P(PrivateAndSharedFutexTest, Wake1_NoRandomSave) {
EXPECT_THAT(futex_wake(IsPrivate(), &a, 1), SyscallSucceedsWithValue(1));
}
-TEST_P(PrivateAndSharedFutexTest, Wake0_NoRandomSave) {
+TEST_P(PrivateAndSharedFutexTest, Wake0) {
constexpr int kInitialValue = 1;
std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
@@ -261,7 +261,7 @@ TEST_P(PrivateAndSharedFutexTest, Wake0_NoRandomSave) {
EXPECT_THAT(futex_wake(IsPrivate(), &a, 0), SyscallSucceedsWithValue(1));
}
-TEST_P(PrivateAndSharedFutexTest, WakeAll_NoRandomSave) {
+TEST_P(PrivateAndSharedFutexTest, WakeAll) {
constexpr int kInitialValue = 1;
std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
@@ -282,7 +282,7 @@ TEST_P(PrivateAndSharedFutexTest, WakeAll_NoRandomSave) {
SyscallSucceedsWithValue(kThreads));
}
-TEST_P(PrivateAndSharedFutexTest, WakeSome_NoRandomSave) {
+TEST_P(PrivateAndSharedFutexTest, WakeSome) {
constexpr int kInitialValue = 1;
std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
@@ -331,7 +331,7 @@ TEST_P(PrivateAndSharedFutexTest, WakeSome_NoRandomSave) {
EXPECT_EQ(timedout, kThreads - kWokenThreads);
}
-TEST_P(PrivateAndSharedFutexTest, WaitBitset_Wake_NoRandomSave) {
+TEST_P(PrivateAndSharedFutexTest, WaitBitset_Wake) {
constexpr int kInitialValue = 1;
std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
@@ -346,7 +346,7 @@ TEST_P(PrivateAndSharedFutexTest, WaitBitset_Wake_NoRandomSave) {
EXPECT_THAT(futex_wake(IsPrivate(), &a, 1), SyscallSucceedsWithValue(1));
}
-TEST_P(PrivateAndSharedFutexTest, Wait_WakeBitset_NoRandomSave) {
+TEST_P(PrivateAndSharedFutexTest, Wait_WakeBitset) {
constexpr int kInitialValue = 1;
std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
@@ -361,7 +361,7 @@ TEST_P(PrivateAndSharedFutexTest, Wait_WakeBitset_NoRandomSave) {
SyscallSucceedsWithValue(1));
}
-TEST_P(PrivateAndSharedFutexTest, WaitBitset_WakeBitsetMatch_NoRandomSave) {
+TEST_P(PrivateAndSharedFutexTest, WaitBitset_WakeBitsetMatch) {
constexpr int kInitialValue = 1;
std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
@@ -379,7 +379,7 @@ TEST_P(PrivateAndSharedFutexTest, WaitBitset_WakeBitsetMatch_NoRandomSave) {
SyscallSucceedsWithValue(1));
}
-TEST_P(PrivateAndSharedFutexTest, WaitBitset_WakeBitsetNoMatch_NoRandomSave) {
+TEST_P(PrivateAndSharedFutexTest, WaitBitset_WakeBitsetNoMatch) {
constexpr int kInitialValue = 1;
std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
@@ -401,7 +401,7 @@ TEST_P(PrivateAndSharedFutexTest, WaitBitset_WakeBitsetNoMatch_NoRandomSave) {
SyscallSucceedsWithValue(0));
}
-TEST_P(PrivateAndSharedFutexTest, WakeOpCondSuccess_NoRandomSave) {
+TEST_P(PrivateAndSharedFutexTest, WakeOpCondSuccess) {
constexpr int kInitialValue = 1;
std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
std::atomic<int> b = ATOMIC_VAR_INIT(kInitialValue);
@@ -428,7 +428,7 @@ TEST_P(PrivateAndSharedFutexTest, WakeOpCondSuccess_NoRandomSave) {
EXPECT_EQ(b, kInitialValue + 2);
}
-TEST_P(PrivateAndSharedFutexTest, WakeOpCondFailure_NoRandomSave) {
+TEST_P(PrivateAndSharedFutexTest, WakeOpCondFailure) {
constexpr int kInitialValue = 1;
std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
std::atomic<int> b = ATOMIC_VAR_INIT(kInitialValue);
@@ -457,7 +457,7 @@ TEST_P(PrivateAndSharedFutexTest, WakeOpCondFailure_NoRandomSave) {
EXPECT_EQ(b, kInitialValue + 2);
}
-TEST_P(PrivateAndSharedFutexTest, NoWakeInterprocessPrivateAnon_NoRandomSave) {
+TEST_P(PrivateAndSharedFutexTest, NoWakeInterprocessPrivateAnon) {
auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
auto const ptr = static_cast<std::atomic<int>*>(mapping.ptr());
@@ -484,7 +484,7 @@ TEST_P(PrivateAndSharedFutexTest, NoWakeInterprocessPrivateAnon_NoRandomSave) {
<< " status " << status;
}
-TEST_P(PrivateAndSharedFutexTest, WakeAfterCOWBreak_NoRandomSave) {
+TEST_P(PrivateAndSharedFutexTest, WakeAfterCOWBreak) {
// Use a futex on a non-stack mapping so we can be sure that the child process
// below isn't the one that breaks copy-on-write.
auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
@@ -520,7 +520,7 @@ TEST_P(PrivateAndSharedFutexTest, WakeAfterCOWBreak_NoRandomSave) {
EXPECT_THAT(futex_wake(IsPrivate(), ptr, 1), SyscallSucceedsWithValue(1));
}
-TEST_P(PrivateAndSharedFutexTest, WakeWrongKind_NoRandomSave) {
+TEST_P(PrivateAndSharedFutexTest, WakeWrongKind) {
constexpr int kInitialValue = 1;
std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
@@ -584,7 +584,7 @@ TEST(PrivateFutexTest, WakeOp0Xor) {
EXPECT_EQ(a, 0b0110);
}
-TEST(SharedFutexTest, WakeInterprocessSharedAnon_NoRandomSave) {
+TEST(SharedFutexTest, WakeInterprocessSharedAnon) {
auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED));
auto const ptr = static_cast<std::atomic<int>*>(mapping.ptr());
@@ -615,7 +615,7 @@ TEST(SharedFutexTest, WakeInterprocessSharedAnon_NoRandomSave) {
<< " status " << status;
}
-TEST(SharedFutexTest, WakeInterprocessFile_NoRandomSave) {
+TEST(SharedFutexTest, WakeInterprocessFile) {
auto const file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
ASSERT_THAT(truncate(file.path().c_str(), kPageSize), SyscallSucceeds());
auto const fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
@@ -661,7 +661,7 @@ TEST_P(PrivateAndSharedFutexTest, PIBasic) {
EXPECT_THAT(futex_unlock_pi(IsPrivate(), &a), SyscallFailsWithErrno(EPERM));
}
-TEST_P(PrivateAndSharedFutexTest, PIConcurrency_NoRandomSave) {
+TEST_P(PrivateAndSharedFutexTest, PIConcurrency) {
DisableSave ds; // Too many syscalls.
std::atomic<int> a = ATOMIC_VAR_INIT(0);
@@ -717,7 +717,7 @@ TEST_P(PrivateAndSharedFutexTest, PITryLock) {
ASSERT_THAT(futex_unlock_pi(IsPrivate(), &a), SyscallSucceeds());
}
-TEST_P(PrivateAndSharedFutexTest, PITryLockConcurrency_NoRandomSave) {
+TEST_P(PrivateAndSharedFutexTest, PITryLockConcurrency) {
DisableSave ds; // Too many syscalls.
std::atomic<int> a = ATOMIC_VAR_INIT(0);
diff --git a/test/syscalls/linux/inotify.cc b/test/syscalls/linux/inotify.cc
index a88c89e20..f6b78989b 100644
--- a/test/syscalls/linux/inotify.cc
+++ b/test/syscalls/linux/inotify.cc
@@ -1156,7 +1156,7 @@ TEST(Inotify, ZeroLengthReadWriteDoesNotGenerateEvent) {
EXPECT_TRUE(events.empty());
}
-TEST(Inotify, ChmodGeneratesAttribEvent_NoRandomSave) {
+TEST(Inotify, ChmodGeneratesAttribEvent) {
const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
const TempPath file1 =
ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
@@ -1999,7 +1999,7 @@ TEST(Inotify, Exec) {
//
// We need to disable S/R because there are filesystems where we cannot re-open
// fds to an unlinked file across S/R, e.g. gofer-backed filesytems.
-TEST(Inotify, IncludeUnlinkedFile_NoRandomSave) {
+TEST(Inotify, IncludeUnlinkedFile) {
const DisableSave ds;
const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
@@ -2052,7 +2052,7 @@ TEST(Inotify, IncludeUnlinkedFile_NoRandomSave) {
//
// We need to disable S/R because there are filesystems where we cannot re-open
// fds to an unlinked file across S/R, e.g. gofer-backed filesytems.
-TEST(Inotify, ExcludeUnlink_NoRandomSave) {
+TEST(Inotify, ExcludeUnlink) {
const DisableSave ds;
// TODO(gvisor.dev/issue/1624): This test fails on VFS1.
SKIP_IF(IsRunningWithVFS1());
@@ -2093,7 +2093,7 @@ TEST(Inotify, ExcludeUnlink_NoRandomSave) {
// We need to disable S/R because there are filesystems where we cannot re-open
// fds to an unlinked file across S/R, e.g. gofer-backed filesytems.
-TEST(Inotify, ExcludeUnlinkDirectory_NoRandomSave) {
+TEST(Inotify, ExcludeUnlinkDirectory) {
// TODO(gvisor.dev/issue/1624): This test fails on VFS1. Remove once VFS1 is
// deleted.
SKIP_IF(IsRunningWithVFS1());
@@ -2138,7 +2138,7 @@ TEST(Inotify, ExcludeUnlinkDirectory_NoRandomSave) {
//
// We need to disable S/R because there are filesystems where we cannot re-open
// fds to an unlinked file across S/R, e.g. gofer-backed filesytems.
-TEST(Inotify, ExcludeUnlinkMultipleChildren_NoRandomSave) {
+TEST(Inotify, ExcludeUnlinkMultipleChildren) {
// Inotify does not work properly with hard links in gofer and overlay fs.
SKIP_IF(IsRunningOnGvisor() &&
!ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(GetAbsoluteTestTmpdir())));
@@ -2184,7 +2184,7 @@ TEST(Inotify, ExcludeUnlinkMultipleChildren_NoRandomSave) {
//
// We need to disable S/R because there are filesystems where we cannot re-open
// fds to an unlinked file across S/R, e.g. gofer-backed filesytems.
-TEST(Inotify, ExcludeUnlinkInodeEvents_NoRandomSave) {
+TEST(Inotify, ExcludeUnlinkInodeEvents) {
// TODO(gvisor.dev/issue/1624): Fails on VFS1.
SKIP_IF(IsRunningWithVFS1());
@@ -2284,7 +2284,7 @@ TEST(Inotify, OneShot) {
// This test helps verify that the lock order of filesystem and inotify locks
// is respected when inotify instances and watch targets are concurrently being
// destroyed.
-TEST(InotifyTest, InotifyAndTargetDestructionDoNotDeadlock_NoRandomSave) {
+TEST(InotifyTest, InotifyAndTargetDestructionDoNotDeadlock) {
const DisableSave ds; // Too many syscalls.
// A file descriptor protected by a mutex. This ensures that while a
@@ -2350,7 +2350,7 @@ TEST(InotifyTest, InotifyAndTargetDestructionDoNotDeadlock_NoRandomSave) {
// This test helps verify that the lock order of filesystem and inotify locks
// is respected when adding/removing watches occurs concurrently with the
// removal of their targets.
-TEST(InotifyTest, AddRemoveUnlinkDoNotDeadlock_NoRandomSave) {
+TEST(InotifyTest, AddRemoveUnlinkDoNotDeadlock) {
const DisableSave ds; // Too many syscalls.
// Set up inotify instances.
@@ -2405,7 +2405,7 @@ TEST(InotifyTest, AddRemoveUnlinkDoNotDeadlock_NoRandomSave) {
// This test helps verify that the lock order of filesystem and inotify locks
// is respected when many inotify events and filesystem operations occur
// simultaneously.
-TEST(InotifyTest, NotifyNoDeadlock_NoRandomSave) {
+TEST(InotifyTest, NotifyNoDeadlock) {
const DisableSave ds; // Too many syscalls.
const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
diff --git a/test/syscalls/linux/itimer.cc b/test/syscalls/linux/itimer.cc
index e397d5f57..ac113e6da 100644
--- a/test/syscalls/linux/itimer.cc
+++ b/test/syscalls/linux/itimer.cc
@@ -215,7 +215,7 @@ int TestSIGALRMToMainThread() {
// Random save/restore is disabled as it introduces additional latency and
// unpredictable distribution patterns.
-TEST(ItimerTest, DeliversSIGALRMToMainThread_NoRandomSave) {
+TEST(ItimerTest, DeliversSIGALRMToMainThread) {
pid_t child;
int execve_errno;
auto kill = ASSERT_NO_ERRNO_AND_VALUE(
@@ -266,7 +266,7 @@ int TestSIGPROFFairness(absl::Duration sleep) {
// Random save/restore is disabled as it introduces additional latency and
// unpredictable distribution patterns.
-TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyActive_NoRandomSave) {
+TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyActive) {
// On the KVM and ptrace platforms, switches between sentry and application
// context are sometimes extremely slow, causing the itimer to send SIGPROF to
// a thread that either already has one pending or has had SIGPROF delivered,
@@ -301,7 +301,7 @@ TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyActive_NoRandomSave) {
// Random save/restore is disabled as it introduces additional latency and
// unpredictable distribution patterns.
-TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyIdle_NoRandomSave) {
+TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyIdle) {
// See comment in DeliversSIGPROFToThreadsRoughlyFairlyActive.
const auto gvisor_platform = GvisorPlatform();
SKIP_IF(gvisor_platform == Platform::kKVM ||
diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc
index e65ffee8f..4697c404c 100644
--- a/test/syscalls/linux/open.cc
+++ b/test/syscalls/linux/open.cc
@@ -431,7 +431,7 @@ TEST_F(OpenTest, CanTruncateReadOnly) {
// If we don't have read permission on the file, opening with
// O_TRUNC should fail.
-TEST_F(OpenTest, CanTruncateReadOnlyNoWritePermission_NoRandomSave) {
+TEST_F(OpenTest, CanTruncateReadOnlyNoWritePermission) {
// Drop capabilities that allow us to override file permissions.
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
@@ -452,7 +452,7 @@ TEST_F(OpenTest, CanTruncateReadOnlyNoWritePermission_NoRandomSave) {
// If we don't have read permission but have write permission, opening O_WRONLY
// and O_TRUNC should succeed.
-TEST_F(OpenTest, CanTruncateWriteOnlyNoReadPermission_NoRandomSave) {
+TEST_F(OpenTest, CanTruncateWriteOnlyNoReadPermission) {
const DisableSave ds; // Permissions are dropped.
EXPECT_THAT(fchmod(test_file_fd_.get(), S_IWUSR | S_IWGRP),
diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc
index 46f41de50..43d446926 100644
--- a/test/syscalls/linux/open_create.cc
+++ b/test/syscalls/linux/open_create.cc
@@ -52,7 +52,7 @@ TEST(CreateTest, CreateAtFile) {
EXPECT_THAT(close(fd), SyscallSucceeds());
}
-TEST(CreateTest, HonorsUmask_NoRandomSave) {
+TEST(CreateTest, HonorsUmask) {
const DisableSave ds; // file cannot be re-opened as writable.
auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
TempUmask mask(0222);
@@ -119,7 +119,7 @@ TEST(CreateTest, OpenCreateROThenRW) {
EXPECT_THAT(WriteFd(fd2.get(), &c, 1), SyscallSucceedsWithValue(1));
}
-TEST(CreateTest, ChmodReadToWriteBetweenOpens_NoRandomSave) {
+TEST(CreateTest, ChmodReadToWriteBetweenOpens) {
// Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to
// override file read/write permissions. CAP_DAC_READ_SEARCH needs to be
// cleared for the same reason.
@@ -149,7 +149,7 @@ TEST(CreateTest, ChmodReadToWriteBetweenOpens_NoRandomSave) {
EXPECT_EQ(c, 'x');
}
-TEST(CreateTest, ChmodWriteToReadBetweenOpens_NoRandomSave) {
+TEST(CreateTest, ChmodWriteToReadBetweenOpens) {
// Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to
// override file read/write permissions.
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
@@ -177,7 +177,7 @@ TEST(CreateTest, ChmodWriteToReadBetweenOpens_NoRandomSave) {
EXPECT_EQ(c, 'x');
}
-TEST(CreateTest, CreateWithReadFlagNotAllowedByMode_NoRandomSave) {
+TEST(CreateTest, CreateWithReadFlagNotAllowedByMode) {
// The only time we can open a file with flags forbidden by its permissions
// is when we are creating the file. We cannot re-open with the same flags,
// so we cannot restore an fd obtained from such an operation.
@@ -204,7 +204,7 @@ TEST(CreateTest, CreateWithReadFlagNotAllowedByMode_NoRandomSave) {
EXPECT_EQ(c, 'x');
}
-TEST(CreateTest, CreateWithWriteFlagNotAllowedByMode_NoRandomSave) {
+TEST(CreateTest, CreateWithWriteFlagNotAllowedByMode) {
// The only time we can open a file with flags forbidden by its permissions
// is when we are creating the file. We cannot re-open with the same flags,
// so we cannot restore an fd obtained from such an operation.
diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc
index d25be0e30..72080a272 100644
--- a/test/syscalls/linux/packet_socket_raw.cc
+++ b/test/syscalls/linux/packet_socket_raw.cc
@@ -440,11 +440,7 @@ TEST_P(RawPacketTest, SetSocketRecvBuf) {
ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len),
SyscallSucceeds());
- // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF.
- // TODO(gvisor.dev/issue/2926): Remove when Netstack matches linux behavior.
- if (!IsRunningOnGvisor()) {
- quarter_sz *= 2;
- }
+ quarter_sz *= 2;
ASSERT_EQ(quarter_sz, val);
}
diff --git a/test/syscalls/linux/partial_bad_buffer.cc b/test/syscalls/linux/partial_bad_buffer.cc
index 13afa0eaf..223ddc0c8 100644
--- a/test/syscalls/linux/partial_bad_buffer.cc
+++ b/test/syscalls/linux/partial_bad_buffer.cc
@@ -320,7 +320,7 @@ PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) {
// EFAULT. It also verifies that passing a buffer which is made up of 2
// pages one valid and one guard page succeeds as long as the write is
// for exactly the size of 1 page.
-TEST_F(PartialBadBufferTest, SendMsgTCP_NoRandomSave) {
+TEST_F(PartialBadBufferTest, SendMsgTCP) {
// FIXME(b/171436815): Netstack save/restore is broken.
const DisableSave ds;
diff --git a/test/syscalls/linux/ping_socket.cc b/test/syscalls/linux/ping_socket.cc
index 999c8ab6b..8b78e4b16 100644
--- a/test/syscalls/linux/ping_socket.cc
+++ b/test/syscalls/linux/ping_socket.cc
@@ -35,7 +35,7 @@ namespace {
//
// We disable both random/cooperative S/R for this test as it makes way too many
// syscalls.
-TEST(PingSocket, ICMPPortExhaustion_NoRandomSave) {
+TEST(PingSocket, ICMPPortExhaustion) {
DisableSave ds;
{
diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc
index 01ccbdcd2..96c454485 100644
--- a/test/syscalls/linux/pipe.cc
+++ b/test/syscalls/linux/pipe.cc
@@ -399,7 +399,7 @@ TEST_P(PipeTest, BlockPartialWriteClosed) {
t.Join();
}
-TEST_P(PipeTest, ReadFromClosedFd_NoRandomSave) {
+TEST_P(PipeTest, ReadFromClosedFd) {
SKIP_IF(!CreateBlocking());
absl::Notification notify;
diff --git a/test/syscalls/linux/poll.cc b/test/syscalls/linux/poll.cc
index 6f9a9498c..5ce7e8c8d 100644
--- a/test/syscalls/linux/poll.cc
+++ b/test/syscalls/linux/poll.cc
@@ -57,7 +57,7 @@ TEST_F(PollTest, ZeroTimeout) {
// If random S/R interrupts the poll, SIGALRM may be delivered before poll
// restarts, causing the poll to hang forever.
-TEST_F(PollTest, NegativeTimeout_NoRandomSave) {
+TEST_F(PollTest, NegativeTimeout) {
// Negative timeout mean wait forever so set a timer.
SetTimer(absl::Milliseconds(100));
EXPECT_THAT(poll(nullptr, 0, -1), SyscallFailsWithErrno(EINTR));
diff --git a/test/syscalls/linux/ppoll.cc b/test/syscalls/linux/ppoll.cc
index 8245a11e8..7f7d69731 100644
--- a/test/syscalls/linux/ppoll.cc
+++ b/test/syscalls/linux/ppoll.cc
@@ -76,7 +76,7 @@ TEST_F(PpollTest, ZeroTimeout) {
// If random S/R interrupts the ppoll, SIGALRM may be delivered before ppoll
// restarts, causing the ppoll to hang forever.
-TEST_F(PpollTest, NoTimeout_NoRandomSave) {
+TEST_F(PpollTest, NoTimeout) {
// When there's no timeout, ppoll may never return so set a timer.
SetTimer(absl::Milliseconds(100));
// See that we get interrupted by the timer.
diff --git a/test/syscalls/linux/pread64.cc b/test/syscalls/linux/pread64.cc
index c74990ba1..0a09259a3 100644
--- a/test/syscalls/linux/pread64.cc
+++ b/test/syscalls/linux/pread64.cc
@@ -144,7 +144,7 @@ TEST_F(Pread64Test, Overflow) {
SyscallFailsWithErrno(EINVAL));
}
-TEST(Pread64TestNoTempFile, CantReadSocketPair_NoRandomSave) {
+TEST(Pread64TestNoTempFile, CantReadSocketPair) {
int sock_fds[2];
EXPECT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, sock_fds), SyscallSucceeds());
diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc
index 493042dfc..6b055ea89 100644
--- a/test/syscalls/linux/proc.cc
+++ b/test/syscalls/linux/proc.cc
@@ -1629,7 +1629,7 @@ TEST(ProcPidStatusTest, StateRunning) {
IsPosixErrorOkAndHolds(Contains(Pair("State", "R (running)"))));
}
-TEST(ProcPidStatusTest, StateSleeping_NoRandomSave) {
+TEST(ProcPidStatusTest, StateSleeping) {
// Starts a child process that blocks and checks that State is sleeping.
auto res = WithSubprocess(
[&](int pid) -> PosixError {
diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc
index 20f1dc305..04fecc02e 100644
--- a/test/syscalls/linux/proc_net.cc
+++ b/test/syscalls/linux/proc_net.cc
@@ -189,7 +189,7 @@ PosixErrorOr<uint64_t> GetSNMPMetricFromProc(const std::string snmp,
EINVAL, absl::StrCat("failed to find ", type, "/", item, " in:", snmp));
}
-TEST(ProcNetSnmp, TcpReset_NoRandomSave) {
+TEST(ProcNetSnmp, TcpReset) {
// TODO(gvisor.dev/issue/866): epsocket metrics are not savable.
DisableSave ds;
@@ -231,7 +231,7 @@ TEST(ProcNetSnmp, TcpReset_NoRandomSave) {
EXPECT_EQ(oldAttemptFails, newAttemptFails - 1);
}
-TEST(ProcNetSnmp, TcpEstab_NoRandomSave) {
+TEST(ProcNetSnmp, TcpEstab) {
// TODO(gvisor.dev/issue/866): epsocket metrics are not savable.
DisableSave ds;
@@ -263,9 +263,8 @@ TEST(ProcNetSnmp, TcpEstab_NoRandomSave) {
// Get the port bound by the listening socket.
socklen_t addrlen = sizeof(sin);
- ASSERT_THAT(
- getsockname(s_listen.get(), reinterpret_cast<sockaddr*>(&sin), &addrlen),
- SyscallSucceeds());
+ ASSERT_THAT(getsockname(s_listen.get(), AsSockAddr(&sin), &addrlen),
+ SyscallSucceeds());
FileDescriptor s_connect =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0));
@@ -326,7 +325,7 @@ TEST(ProcNetSnmp, TcpEstab_NoRandomSave) {
EXPECT_EQ(oldEstabResets, newEstabResets - 2);
}
-TEST(ProcNetSnmp, UdpNoPorts_NoRandomSave) {
+TEST(ProcNetSnmp, UdpNoPorts) {
// TODO(gvisor.dev/issue/866): epsocket metrics are not savable.
DisableSave ds;
@@ -360,7 +359,7 @@ TEST(ProcNetSnmp, UdpNoPorts_NoRandomSave) {
EXPECT_EQ(oldNoPorts, newNoPorts - 1);
}
-TEST(ProcNetSnmp, UdpIn_NoRandomSave) {
+TEST(ProcNetSnmp, UdpIn) {
// TODO(gvisor.dev/issue/866): epsocket metrics are not savable.
const DisableSave ds;
@@ -384,9 +383,8 @@ TEST(ProcNetSnmp, UdpIn_NoRandomSave) {
SyscallSucceeds());
// Get the port bound by the server socket.
socklen_t addrlen = sizeof(sin);
- ASSERT_THAT(
- getsockname(server.get(), reinterpret_cast<sockaddr*>(&sin), &addrlen),
- SyscallSucceeds());
+ ASSERT_THAT(getsockname(server.get(), AsSockAddr(&sin), &addrlen),
+ SyscallSucceeds());
FileDescriptor client =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
@@ -421,14 +419,14 @@ TEST(ProcNetSnmp, CheckNetStat) {
int name_count = 0;
int value_count = 0;
std::vector<absl::string_view> lines = absl::StrSplit(contents, '\n');
- for (long unsigned int i = 0; i + 1 < lines.size(); i += 2) {
+ for (size_t i = 0; i + 1 < lines.size(); i += 2) {
std::vector<absl::string_view> names =
absl::StrSplit(lines[i], absl::ByAnyChar("\t "));
std::vector<absl::string_view> values =
absl::StrSplit(lines[i + 1], absl::ByAnyChar("\t "));
EXPECT_EQ(names.size(), values.size()) << " mismatch in lines '" << lines[i]
<< "' and '" << lines[i + 1] << "'";
- for (long unsigned int j = 0; j < names.size() && j < values.size(); ++j) {
+ for (size_t j = 0; j < names.size() && j < values.size(); ++j) {
if (names[j] == "TCPOrigDataSent" || names[j] == "TCPSynRetrans" ||
names[j] == "TCPDSACKRecv" || names[j] == "TCPDSACKOfoRecv") {
++name_count;
@@ -458,14 +456,14 @@ TEST(ProcNetSnmp, CheckSnmp) {
int name_count = 0;
int value_count = 0;
std::vector<absl::string_view> lines = absl::StrSplit(contents, '\n');
- for (long unsigned int i = 0; i + 1 < lines.size(); i += 2) {
+ for (size_t i = 0; i + 1 < lines.size(); i += 2) {
std::vector<absl::string_view> names =
absl::StrSplit(lines[i], absl::ByAnyChar("\t "));
std::vector<absl::string_view> values =
absl::StrSplit(lines[i + 1], absl::ByAnyChar("\t "));
EXPECT_EQ(names.size(), values.size()) << " mismatch in lines '" << lines[i]
<< "' and '" << lines[i + 1] << "'";
- for (long unsigned int j = 0; j < names.size() && j < values.size(); ++j) {
+ for (size_t j = 0; j < names.size() && j < values.size(); ++j) {
if (names[j] == "RetransSegs") {
++name_count;
int64_t val;
diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc
index d61d94309..f7ff65aad 100644
--- a/test/syscalls/linux/proc_net_unix.cc
+++ b/test/syscalls/linux/proc_net_unix.cc
@@ -182,7 +182,7 @@ PosixErrorOr<std::vector<UnixEntry>> ProcNetUnixEntries() {
// Returns true on match, and sets 'match' to point to the matching entry.
bool FindBy(std::vector<UnixEntry> entries, UnixEntry* match,
std::function<bool(const UnixEntry&)> predicate) {
- for (long unsigned int i = 0; i < entries.size(); ++i) {
+ for (size_t i = 0; i < entries.size(); ++i) {
if (predicate(entries[i])) {
*match = entries[i];
return true;
@@ -201,15 +201,8 @@ TEST(ProcNetUnix, Exists) {
const std::string content =
ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/unix"));
const std::string header_line = StrCat(kProcNetUnixHeader, "\n");
- if (IsRunningOnGvisor()) {
- // Should be just the header since we don't have any unix domain sockets
- // yet.
- EXPECT_EQ(content, header_line);
- } else {
- // However, on a general linux machine, we could have abitrary sockets on
- // the system, so just check the header.
- EXPECT_THAT(content, ::testing::StartsWith(header_line));
- }
+ // We could have abitrary sockets on the system, so just check the header.
+ EXPECT_THAT(content, ::testing::StartsWith(header_line));
}
TEST(ProcNetUnix, FilesystemBindAcceptConnect) {
@@ -223,9 +216,6 @@ TEST(ProcNetUnix, FilesystemBindAcceptConnect) {
std::vector<UnixEntry> entries =
ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
- if (IsRunningOnGvisor()) {
- EXPECT_EQ(entries.size(), 2);
- }
// The server-side socket's path is listed in the socket entry...
UnixEntry s1;
@@ -247,9 +237,6 @@ TEST(ProcNetUnix, AbstractBindAcceptConnect) {
std::vector<UnixEntry> entries =
ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
- if (IsRunningOnGvisor()) {
- EXPECT_EQ(entries.size(), 2);
- }
// The server-side socket's path is listed in the socket entry...
UnixEntry s1;
@@ -261,20 +248,12 @@ TEST(ProcNetUnix, AbstractBindAcceptConnect) {
}
TEST(ProcNetUnix, SocketPair) {
- // Under gvisor, ensure a socketpair() syscall creates exactly 2 new
- // entries. We have no way to verify this under Linux, as we have no control
- // over socket creation on a general Linux machine.
- SKIP_IF(!IsRunningOnGvisor());
-
- std::vector<UnixEntry> entries =
- ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
- ASSERT_EQ(entries.size(), 0);
-
auto sockets =
ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_STREAM).Create());
- entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
- EXPECT_EQ(entries.size(), 2);
+ std::vector<UnixEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
+ EXPECT_GE(entries.size(), 2);
}
TEST(ProcNetUnix, StreamSocketStateUnconnectedOnBind) {
@@ -368,25 +347,12 @@ TEST(ProcNetUnix, DgramSocketStateDisconnectingOnBind) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(
AbstractUnboundUnixDomainSocketPair(SOCK_DGRAM).Create());
- std::vector<UnixEntry> entries =
- ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
-
- // On gVisor, the only two UDS on the system are the ones we just created and
- // we rely on this to locate the test socket entries in the remainder of the
- // test. On a generic Linux system, we have no easy way to locate the
- // corresponding entries, as they don't have an address yet.
- if (IsRunningOnGvisor()) {
- ASSERT_EQ(entries.size(), 2);
- for (const auto& e : entries) {
- ASSERT_EQ(e.state, SS_DISCONNECTING);
- }
- }
-
ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
sockets->first_addr_size()),
SyscallSucceeds());
- entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
+ std::vector<UnixEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
const std::string address = ExtractPath(sockets->first_addr());
UnixEntry bind_entry;
ASSERT_TRUE(FindByPath(entries, &bind_entry, address));
@@ -397,25 +363,12 @@ TEST(ProcNetUnix, DgramSocketStateConnectingOnConnect) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(
AbstractUnboundUnixDomainSocketPair(SOCK_DGRAM).Create());
- std::vector<UnixEntry> entries =
- ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
-
- // On gVisor, the only two UDS on the system are the ones we just created and
- // we rely on this to locate the test socket entries in the remainder of the
- // test. On a generic Linux system, we have no easy way to locate the
- // corresponding entries, as they don't have an address yet.
- if (IsRunningOnGvisor()) {
- ASSERT_EQ(entries.size(), 2);
- for (const auto& e : entries) {
- ASSERT_EQ(e.state, SS_DISCONNECTING);
- }
- }
-
ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
sockets->first_addr_size()),
SyscallSucceeds());
- entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
+ std::vector<UnixEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
const std::string address = ExtractPath(sockets->first_addr());
UnixEntry bind_entry;
ASSERT_TRUE(FindByPath(entries, &bind_entry, address));
@@ -423,22 +376,6 @@ TEST(ProcNetUnix, DgramSocketStateConnectingOnConnect) {
ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
sockets->first_addr_size()),
SyscallSucceeds());
-
- entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
-
- // Once again, we have no easy way to identify the connecting socket as it has
- // no listed address. We can only identify the entry as the "non-bind socket
- // entry" on gVisor, where we're guaranteed to have only the two entries we
- // create during this test.
- if (IsRunningOnGvisor()) {
- ASSERT_EQ(entries.size(), 2);
- UnixEntry connect_entry;
- ASSERT_TRUE(
- FindBy(entries, &connect_entry, [bind_entry](const UnixEntry& e) {
- return e.inode != bind_entry.inode;
- }));
- EXPECT_EQ(connect_entry.state, SS_CONNECTING);
- }
}
} // namespace
diff --git a/test/syscalls/linux/proc_pid_uid_gid_map.cc b/test/syscalls/linux/proc_pid_uid_gid_map.cc
index af052a63c..c030592c8 100644
--- a/test/syscalls/linux/proc_pid_uid_gid_map.cc
+++ b/test/syscalls/linux/proc_pid_uid_gid_map.cc
@@ -203,8 +203,9 @@ TEST_P(ProcSelfUidGidMapTest, IdentityMapOwnID) {
EXPECT_THAT(
InNewUserNamespaceWithMapFD([&](int fd) {
DenySelfSetgroups();
- TEST_PCHECK(static_cast<long unsigned int>(
- write(fd, line.c_str(), line.size())) == line.size());
+ size_t n;
+ TEST_PCHECK((n = write(fd, line.c_str(), line.size())) != -1);
+ TEST_CHECK(n == line.size());
}),
IsPosixErrorOkAndHolds(0));
}
@@ -221,8 +222,9 @@ TEST_P(ProcSelfUidGidMapTest, TrailingNewlineAndNULIgnored) {
DenySelfSetgroups();
// The write should return the full size of the write, even though
// characters after the NUL were ignored.
- TEST_PCHECK(static_cast<long unsigned int>(
- write(fd, line.c_str(), line.size())) == line.size());
+ size_t n;
+ TEST_PCHECK((n = write(fd, line.c_str(), line.size())) != -1);
+ TEST_CHECK(n == line.size());
}),
IsPosixErrorOkAndHolds(0));
}
diff --git a/test/syscalls/linux/pselect.cc b/test/syscalls/linux/pselect.cc
index 4e43c4d7f..e490a987d 100644
--- a/test/syscalls/linux/pselect.cc
+++ b/test/syscalls/linux/pselect.cc
@@ -88,7 +88,7 @@ TEST_F(PselectTest, ZeroTimeout) {
// If random S/R interrupts the pselect, SIGALRM may be delivered before pselect
// restarts, causing the pselect to hang forever.
-TEST_F(PselectTest, NoTimeout_NoRandomSave) {
+TEST_F(PselectTest, NoTimeout) {
// When there's no timeout, pselect may never return so set a timer.
SetTimer(absl::Milliseconds(100));
// See that we get interrupted by the timer.
diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc
index d1d7c6f84..2d9fec371 100644
--- a/test/syscalls/linux/ptrace.cc
+++ b/test/syscalls/linux/ptrace.cc
@@ -1708,8 +1708,7 @@ INSTANTIATE_TEST_SUITE_P(TraceExec, PtraceExecveTest, ::testing::Bool());
// This test has expectations on when syscall-enter/exit-stops occur that are
// violated if saving occurs, since saving interrupts all syscalls, causing
// premature syscall-exit.
-TEST(PtraceTest,
- ExitWhenParentIsNotTracer_Syscall_TraceVfork_TraceVforkDone_NoRandomSave) {
+TEST(PtraceTest, ExitWhenParentIsNotTracer_Syscall_TraceVfork_TraceVforkDone) {
constexpr int kExitTraceeExitCode = 99;
pid_t const child_pid = fork();
@@ -2006,7 +2005,7 @@ TEST(PtraceTest, Sysemu_PokeUser) {
}
// This test also cares about syscall-exit-stop.
-TEST(PtraceTest, ERESTART_NoRandomSave) {
+TEST(PtraceTest, ERESTART) {
constexpr int kSigno = SIGUSR1;
pid_t const child_pid = fork();
diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc
index 32924466f..69616b400 100644
--- a/test/syscalls/linux/raw_socket.cc
+++ b/test/syscalls/linux/raw_socket.cc
@@ -514,10 +514,7 @@ TEST_P(RawSocketTest, SetSocketRecvBuf) {
SyscallSucceeds());
// Linux doubles the value set by SO_SNDBUF/SO_RCVBUF.
- // TODO(gvisor.dev/issue/2926): Remove when Netstack matches linux behavior.
- if (!IsRunningOnGvisor()) {
- quarter_sz *= 2;
- }
+ quarter_sz *= 2;
ASSERT_EQ(quarter_sz, val);
}
@@ -713,12 +710,7 @@ TEST_P(RawSocketTest, RecvBufLimits) {
}
// Now set the limit to min * 2.
- int new_rcv_buf_sz = min * 4;
- if (!IsRunningOnGvisor()) {
- // Linux doubles the value specified so just set to min.
- new_rcv_buf_sz = min * 2;
- }
-
+ int new_rcv_buf_sz = min * 2;
ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &new_rcv_buf_sz,
sizeof(new_rcv_buf_sz)),
SyscallSucceeds());
diff --git a/test/syscalls/linux/read.cc b/test/syscalls/linux/read.cc
index 087262535..7056342d7 100644
--- a/test/syscalls/linux/read.cc
+++ b/test/syscalls/linux/read.cc
@@ -97,7 +97,7 @@ TEST_F(ReadTest, DevNullReturnsEof) {
const int kReadSize = 128 * 1024;
// Do not allow random save as it could lead to partial reads.
-TEST_F(ReadTest, CanReadFullyFromDevZero_NoRandomSave) {
+TEST_F(ReadTest, CanReadFullyFromDevZero) {
int fd;
ASSERT_THAT(fd = open("/dev/zero", O_RDONLY), SyscallSucceeds());
diff --git a/test/syscalls/linux/readv.cc b/test/syscalls/linux/readv.cc
index 86808d255..a50d98d21 100644
--- a/test/syscalls/linux/readv.cc
+++ b/test/syscalls/linux/readv.cc
@@ -267,7 +267,7 @@ TEST_F(ReadvTest, ReadvWithOpath) {
// This test depends on the maximum extent of a single readv() syscall, so
// we can't tolerate interruption from saving.
-TEST(ReadvTestNoFixture, TruncatedAtMax_NoRandomSave) {
+TEST(ReadvTestNoFixture, TruncatedAtMax) {
// Ensure that we won't be interrupted by ITIMER_PROF. This is particularly
// important in environments where automated profiling tools may start
// ITIMER_PROF automatically.
diff --git a/test/syscalls/linux/select.cc b/test/syscalls/linux/select.cc
index be2364fb8..d74096ded 100644
--- a/test/syscalls/linux/select.cc
+++ b/test/syscalls/linux/select.cc
@@ -98,7 +98,7 @@ TEST_F(SelectTest, ZeroTimeout) {
// If random S/R interrupts the select, SIGALRM may be delivered before select
// restarts, causing the select to hang forever.
-TEST_F(SelectTest, NoTimeout_NoRandomSave) {
+TEST_F(SelectTest, NoTimeout) {
// When there's no timeout, select may never return so set a timer.
SetTimer(absl::Milliseconds(100));
// See that we get interrupted by the timer.
@@ -118,7 +118,7 @@ TEST_F(SelectTest, InvalidTimeoutNegative) {
//
// If random S/R interrupts the select, SIGALRM may be delivered before select
// restarts, causing the select to hang forever.
-TEST_F(SelectTest, InterruptedBySignal_NoRandomSave) {
+TEST_F(SelectTest, InterruptedBySignal) {
absl::Duration duration(absl::Seconds(5));
struct timeval timeout = absl::ToTimeval(duration);
SetTimer(absl::Milliseconds(100));
diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc
index 28f51a3bf..207377efb 100644
--- a/test/syscalls/linux/semaphore.cc
+++ b/test/syscalls/linux/semaphore.cc
@@ -234,14 +234,6 @@ TEST(SemaphoreTest, SemTimedOpBlock) {
AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT));
ASSERT_THAT(sem.get(), SyscallSucceeds());
- ScopedThread th([&sem] {
- absl::SleepFor(absl::Milliseconds(100));
-
- struct sembuf buf = {};
- buf.sem_op = 1;
- ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds());
- });
-
struct sembuf buf = {};
buf.sem_op = -1;
struct timespec timeout = {};
@@ -295,7 +287,7 @@ TEST(SemaphoreTest, SemOpSimple) {
// Tests that semaphore can be removed while there are waiters.
// NoRandomSave: Test relies on timing that random save throws off.
-TEST(SemaphoreTest, SemOpRemoveWithWaiter_NoRandomSave) {
+TEST(SemaphoreTest, SemOpRemoveWithWaiter) {
AutoSem sem(semget(IPC_PRIVATE, 2, 0600 | IPC_CREAT));
ASSERT_THAT(sem.get(), SyscallSucceeds());
@@ -716,7 +708,7 @@ TEST(SemaphoreTest, SemopGetzcntOnSetRemoval) {
EXPECT_THAT(semctl(semid, 0, GETZCNT), SyscallFailsWithErrno(EINVAL));
}
-TEST(SemaphoreTest, SemopGetzcntOnSignal_NoRandomSave) {
+TEST(SemaphoreTest, SemopGetzcntOnSignal) {
AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT));
ASSERT_THAT(sem.get(), SyscallSucceeds());
ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 1), SyscallSucceeds());
@@ -821,7 +813,7 @@ TEST(SemaphoreTest, SemopGetncntOnSetRemoval) {
EXPECT_THAT(semctl(semid, 0, GETNCNT), SyscallFailsWithErrno(EINVAL));
}
-TEST(SemaphoreTest, SemopGetncntOnSignal_NoRandomSave) {
+TEST(SemaphoreTest, SemopGetncntOnSignal) {
AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT));
ASSERT_THAT(sem.get(), SyscallSucceeds());
ASSERT_EQ(semctl(sem.get(), 0, GETNCNT), 0);
diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc
index 93b3a94f1..bea4ee71c 100644
--- a/test/syscalls/linux/sendfile.cc
+++ b/test/syscalls/linux/sendfile.cc
@@ -654,7 +654,7 @@ TEST(SendFileTest, SendFileToPipe) {
SyscallSucceedsWithValue(kDataSize));
}
-TEST(SendFileTest, SendFileToSelf_NoRandomSave) {
+TEST(SendFileTest, SendFileToSelf) {
int rawfd;
ASSERT_THAT(rawfd = memfd_create("memfd", 0), SyscallSucceeds());
const FileDescriptor fd(rawfd);
@@ -675,7 +675,7 @@ TEST(SendFileTest, SendFileToSelf_NoRandomSave) {
static volatile int signaled = 0;
void SigUsr1Handler(int sig, siginfo_t* info, void* context) { signaled = 1; }
-TEST(SendFileTest, ToEventFDDoesNotSpin_NoRandomSave) {
+TEST(SendFileTest, ToEventFDDoesNotSpin) {
FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0));
// Write the maximum value of an eventfd to a file.
diff --git a/test/syscalls/linux/sigtimedwait.cc b/test/syscalls/linux/sigtimedwait.cc
index 4f8afff15..21651a697 100644
--- a/test/syscalls/linux/sigtimedwait.cc
+++ b/test/syscalls/linux/sigtimedwait.cc
@@ -52,7 +52,7 @@ TEST(SigtimedwaitTest, InvalidTimeout) {
// No random save as the test relies on alarm timing. Cooperative save tests
// already cover the save between alarm and wait.
-TEST(SigtimedwaitTest, AlarmReturnsAlarm_NoRandomSave) {
+TEST(SigtimedwaitTest, AlarmReturnsAlarm) {
struct itimerval itv = {};
itv.it_value.tv_sec = kAlarmSecs;
const auto itimer_cleanup =
@@ -69,7 +69,7 @@ TEST(SigtimedwaitTest, AlarmReturnsAlarm_NoRandomSave) {
// No random save as the test relies on alarm timing. Cooperative save tests
// already cover the save between alarm and wait.
-TEST(SigtimedwaitTest, NullTimeoutReturnsEINTR_NoRandomSave) {
+TEST(SigtimedwaitTest, NullTimeoutReturnsEINTR) {
struct sigaction sa;
sa.sa_sigaction = NoopHandler;
sigfillset(&sa.sa_mask);
diff --git a/test/syscalls/linux/socket.cc b/test/syscalls/linux/socket.cc
index b616c2c87..7b966484d 100644
--- a/test/syscalls/linux/socket.cc
+++ b/test/syscalls/linux/socket.cc
@@ -47,7 +47,7 @@ TEST(SocketTest, ProtocolUnix) {
{AF_UNIX, SOCK_SEQPACKET, PF_UNIX},
{AF_UNIX, SOCK_DGRAM, PF_UNIX},
};
- for (long unsigned int i = 0; i < ABSL_ARRAYSIZE(tests); i++) {
+ for (size_t i = 0; i < ABSL_ARRAYSIZE(tests); i++) {
ASSERT_NO_ERRNO_AND_VALUE(
Socket(tests[i].domain, tests[i].type, tests[i].protocol));
}
@@ -60,7 +60,7 @@ TEST(SocketTest, ProtocolInet) {
{AF_INET, SOCK_DGRAM, IPPROTO_UDP},
{AF_INET, SOCK_STREAM, IPPROTO_TCP},
};
- for (long unsigned int i = 0; i < ABSL_ARRAYSIZE(tests); i++) {
+ for (size_t i = 0; i < ABSL_ARRAYSIZE(tests); i++) {
ASSERT_NO_ERRNO_AND_VALUE(
Socket(tests[i].domain, tests[i].type, tests[i].protocol));
}
@@ -111,7 +111,7 @@ TEST(SocketTest, UnixSocketStatFS) {
EXPECT_EQ(st.f_namelen, NAME_MAX);
}
-TEST(SocketTest, UnixSCMRightsOnlyPassedOnce_NoRandomSave) {
+TEST(SocketTest, UnixSCMRightsOnlyPassedOnce) {
const DisableSave ds;
int sockets[2];
diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc
index f8a0a80f2..3b108cbd3 100644
--- a/test/syscalls/linux/socket_bind_to_device_distribution.cc
+++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc
@@ -141,9 +141,8 @@ TEST_P(BindToDeviceDistributionTest, Tcp) {
endpoint.bind_to_device.c_str(),
endpoint.bind_to_device.size() + 1),
SyscallSucceeds());
- ASSERT_THAT(
- bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(fd, AsSockAddr(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
ASSERT_THAT(listen(fd, 40), SyscallSucceeds());
// On the first bind we need to determine which port was bound.
@@ -154,8 +153,7 @@ TEST_P(BindToDeviceDistributionTest, Tcp) {
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
ASSERT_THAT(
- getsockname(listener_fds[0].get(),
- reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ getsockname(listener_fds[0].get(), AsSockAddr(&listen_addr), &addrlen),
SyscallSucceeds());
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
@@ -168,7 +166,7 @@ TEST_P(BindToDeviceDistributionTest, Tcp) {
std::vector<std::unique_ptr<ScopedThread>> listen_threads(
listener_fds.size());
- for (long unsigned int i = 0; i < listener_fds.size(); i++) {
+ for (size_t i = 0; i < listener_fds.size(); i++) {
listen_threads[i] = absl::make_unique<ScopedThread>(
[&listener_fds, &accept_counts, &connects_received, i,
kConnectAttempts]() {
@@ -207,10 +205,9 @@ TEST_P(BindToDeviceDistributionTest, Tcp) {
for (int32_t i = 0; i < kConnectAttempts; i++) {
const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
- ASSERT_THAT(
- RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr),
- connector.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(connect)(fd.get(), AsSockAddr(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
SyscallSucceedsWithValue(sizeof(i)));
@@ -221,7 +218,7 @@ TEST_P(BindToDeviceDistributionTest, Tcp) {
listen_thread->Join();
}
// Check that connections are distributed correctly among listening sockets.
- for (long unsigned int i = 0; i < accept_counts.size(); i++) {
+ for (size_t i = 0; i < accept_counts.size(); i++) {
EXPECT_THAT(
accept_counts[i],
EquivalentWithin(static_cast<int>(kConnectAttempts *
@@ -267,9 +264,8 @@ TEST_P(BindToDeviceDistributionTest, Udp) {
endpoint.bind_to_device.c_str(),
endpoint.bind_to_device.size() + 1),
SyscallSucceeds());
- ASSERT_THAT(
- bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(fd, AsSockAddr(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
// On the first bind we need to determine which port was bound.
if (listener_fds.size() > 1) {
@@ -279,8 +275,7 @@ TEST_P(BindToDeviceDistributionTest, Udp) {
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
ASSERT_THAT(
- getsockname(listener_fds[0].get(),
- reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ getsockname(listener_fds[0].get(), AsSockAddr(&listen_addr), &addrlen),
SyscallSucceeds());
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
@@ -294,7 +289,7 @@ TEST_P(BindToDeviceDistributionTest, Udp) {
std::vector<std::unique_ptr<ScopedThread>> receiver_threads(
listener_fds.size());
- for (long unsigned int i = 0; i < listener_fds.size(); i++) {
+ for (size_t i = 0; i < listener_fds.size(); i++) {
receiver_threads[i] = absl::make_unique<ScopedThread>(
[&listener_fds, &packets_per_socket, &packets_received, i]() {
do {
@@ -302,9 +297,9 @@ TEST_P(BindToDeviceDistributionTest, Udp) {
socklen_t addrlen = sizeof(addr);
int data;
- auto ret = RetryEINTR(recvfrom)(
- listener_fds[i].get(), &data, sizeof(data), 0,
- reinterpret_cast<struct sockaddr*>(&addr), &addrlen);
+ auto ret =
+ RetryEINTR(recvfrom)(listener_fds[i].get(), &data, sizeof(data),
+ 0, AsSockAddr(&addr), &addrlen);
if (packets_received < kConnectAttempts) {
ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data)));
@@ -322,10 +317,10 @@ TEST_P(BindToDeviceDistributionTest, Udp) {
// A response is required to synchronize with the main thread,
// otherwise the main thread can send more than can fit into receive
// queues.
- EXPECT_THAT(RetryEINTR(sendto)(
- listener_fds[i].get(), &data, sizeof(data), 0,
- reinterpret_cast<sockaddr*>(&addr), addrlen),
- SyscallSucceedsWithValue(sizeof(data)));
+ EXPECT_THAT(
+ RetryEINTR(sendto)(listener_fds[i].get(), &data, sizeof(data),
+ 0, AsSockAddr(&addr), addrlen),
+ SyscallSucceedsWithValue(sizeof(data)));
} while (packets_received < kConnectAttempts);
// Shutdown all sockets to wake up other threads.
@@ -339,8 +334,7 @@ TEST_P(BindToDeviceDistributionTest, Udp) {
FileDescriptor const fd =
ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0));
EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0,
- reinterpret_cast<sockaddr*>(&conn_addr),
- connector.addr_len),
+ AsSockAddr(&conn_addr), connector.addr_len),
SyscallSucceedsWithValue(sizeof(i)));
int data;
EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0),
@@ -352,7 +346,7 @@ TEST_P(BindToDeviceDistributionTest, Udp) {
receiver_thread->Join();
}
// Check that packets are distributed correctly among listening sockets.
- for (long unsigned int i = 0; i < packets_per_socket.size(); i++) {
+ for (size_t i = 0; i < packets_per_socket.size(); i++) {
EXPECT_THAT(
packets_per_socket[i],
EquivalentWithin(static_cast<int>(kConnectAttempts *
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 597b5bcb1..9a6b089f6 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -190,8 +190,7 @@ TEST_P(DualStackSocketTest, AddressOperations) {
if (sockname) {
sockaddr_storage sock_addr;
socklen_t addrlen = sizeof(sock_addr);
- ASSERT_THAT(getsockname(fd.get(), reinterpret_cast<sockaddr*>(&sock_addr),
- &addrlen),
+ ASSERT_THAT(getsockname(fd.get(), AsSockAddr(&sock_addr), &addrlen),
SyscallSucceeds());
ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6));
@@ -200,24 +199,23 @@ TEST_P(DualStackSocketTest, AddressOperations) {
if (operation == Operation::SendTo) {
EXPECT_EQ(sock_addr_in6->sin6_family, AF_INET6);
EXPECT_TRUE(IN6_IS_ADDR_UNSPECIFIED(sock_addr_in6->sin6_addr.s6_addr32))
- << OperationToString(operation) << " getsocknam="
- << GetAddrStr(reinterpret_cast<sockaddr*>(&sock_addr));
+ << OperationToString(operation)
+ << " getsocknam=" << GetAddrStr(AsSockAddr(&sock_addr));
EXPECT_NE(sock_addr_in6->sin6_port, 0);
} else if (IN6_IS_ADDR_V4MAPPED(
reinterpret_cast<const sockaddr_in6*>(addr_in)
->sin6_addr.s6_addr32)) {
EXPECT_TRUE(IN6_IS_ADDR_V4MAPPED(sock_addr_in6->sin6_addr.s6_addr32))
- << OperationToString(operation) << " getsocknam="
- << GetAddrStr(reinterpret_cast<sockaddr*>(&sock_addr));
+ << OperationToString(operation)
+ << " getsocknam=" << GetAddrStr(AsSockAddr(&sock_addr));
}
}
if (peername) {
sockaddr_storage peer_addr;
socklen_t addrlen = sizeof(peer_addr);
- ASSERT_THAT(getpeername(fd.get(), reinterpret_cast<sockaddr*>(&peer_addr),
- &addrlen),
+ ASSERT_THAT(getpeername(fd.get(), AsSockAddr(&peer_addr), &addrlen),
SyscallSucceeds());
ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6));
@@ -227,8 +225,8 @@ TEST_P(DualStackSocketTest, AddressOperations) {
EXPECT_TRUE(IN6_IS_ADDR_V4MAPPED(
reinterpret_cast<const sockaddr_in6*>(&peer_addr)
->sin6_addr.s6_addr32))
- << OperationToString(operation) << " getpeername="
- << GetAddrStr(reinterpret_cast<sockaddr*>(&peer_addr));
+ << OperationToString(operation)
+ << " getpeername=" << GetAddrStr(AsSockAddr(&peer_addr));
}
}
}
@@ -265,16 +263,15 @@ void tcpSimpleConnectTest(TestAddress const& listener,
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
sockaddr_storage listen_addr = listener.addr;
if (!unbound) {
- ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- listener.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
}
ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
- ASSERT_THAT(getsockname(listen_fd.get(),
- reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen),
SyscallSucceeds());
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
@@ -284,8 +281,7 @@ void tcpSimpleConnectTest(TestAddress const& listener,
Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
sockaddr_storage conn_addr = connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
- ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
- reinterpret_cast<sockaddr*>(&conn_addr),
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr),
connector.addr_len),
SyscallSucceeds());
@@ -331,9 +327,9 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdownListen) {
FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
sockaddr_storage listen_addr = listener.addr;
- ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- listener.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
ASSERT_THAT(shutdown(listen_fd.get(), SHUT_RD), SyscallSucceeds());
@@ -341,8 +337,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdownListen) {
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
- ASSERT_THAT(getsockname(listen_fd.get(),
- reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen),
SyscallSucceeds());
const uint16_t port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
@@ -357,8 +352,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdownListen) {
for (int i = 0; i < kBacklog; i++) {
auto client = ASSERT_NO_ERRNO_AND_VALUE(
Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
- ASSERT_THAT(RetryEINTR(connect)(client.get(),
- reinterpret_cast<sockaddr*>(&conn_addr),
+ ASSERT_THAT(RetryEINTR(connect)(client.get(), AsSockAddr(&conn_addr),
connector.addr_len),
SyscallSucceeds());
}
@@ -380,15 +374,14 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdown) {
FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
sockaddr_storage listen_addr = listener.addr;
- ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- listener.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
- ASSERT_THAT(getsockname(listen_fd.get(),
- reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen),
SyscallSucceeds());
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
@@ -402,8 +395,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdown) {
for (int i = 0; i < kFDs; i++) {
auto client = ASSERT_NO_ERRNO_AND_VALUE(
Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
- ASSERT_THAT(RetryEINTR(connect)(client.get(),
- reinterpret_cast<sockaddr*>(&conn_addr),
+ ASSERT_THAT(RetryEINTR(connect)(client.get(), AsSockAddr(&conn_addr),
connector.addr_len),
SyscallSucceeds());
ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), SyscallSucceeds());
@@ -420,8 +412,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdown) {
FileDescriptor new_listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
ASSERT_THAT(
- bind(new_listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- listener.addr_len),
+ bind(new_listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
// Check that subsequent connection attempts receive a RST.
@@ -431,8 +422,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdown) {
for (int i = 0; i < kFDs; i++) {
auto client = ASSERT_NO_ERRNO_AND_VALUE(
Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
- ASSERT_THAT(RetryEINTR(connect)(client.get(),
- reinterpret_cast<sockaddr*>(&conn_addr),
+ ASSERT_THAT(RetryEINTR(connect)(client.get(), AsSockAddr(&conn_addr),
connector.addr_len),
SyscallFailsWithErrno(ECONNREFUSED));
}
@@ -452,15 +442,14 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) {
FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
sockaddr_storage listen_addr = listener.addr;
- ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- listener.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
- ASSERT_THAT(getsockname(listen_fd.get(),
- reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen),
SyscallSucceeds());
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
@@ -471,8 +460,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) {
for (int i = 0; i < kFDs; i++) {
auto client = ASSERT_NO_ERRNO_AND_VALUE(
Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
- int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
- connector.addr_len);
+ int ret = connect(client.get(), AsSockAddr(&conn_addr), connector.addr_len);
if (ret != 0) {
EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
}
@@ -484,93 +472,160 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) {
}
}
-void TestListenWhileConnect(const TestParam& param,
- void (*stopListen)(FileDescriptor&)) {
+void TestHangupDuringConnect(const TestParam& param,
+ void (*hangup)(FileDescriptor&)) {
TestAddress const& listener = param.listener;
TestAddress const& connector = param.connector;
- constexpr int kBacklog = 2;
- // Linux completes one more connection than the listen backlog argument.
- // To ensure that there is at least one client connection that stays in
- // connecting state, keep 2 more client connections than the listen backlog.
- // gVisor differs in this behavior though, gvisor.dev/issue/3153.
- constexpr int kClients = kBacklog + 2;
+ for (int i = 0; i < 100; i++) {
+ // Create the listening socket.
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), 0), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+
+ // Connect asynchronously and immediately hang up the listener.
+ FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
+ }
+
+ hangup(listen_fd);
+
+ // Wait for the connection to close.
+ struct pollfd pfd = {
+ .fd = client.get(),
+ };
+ constexpr int kTimeout = 10000;
+ int n = poll(&pfd, 1, kTimeout);
+ ASSERT_GE(n, 0) << strerror(errno);
+ ASSERT_EQ(n, 1);
+ ASSERT_EQ(pfd.revents, POLLHUP | POLLERR);
+ ASSERT_EQ(close(client.release()), 0) << strerror(errno);
+ }
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenCloseDuringConnect) {
+ TestHangupDuringConnect(GetParam(), [](FileDescriptor& f) {
+ ASSERT_THAT(close(f.release()), SyscallSucceeds());
+ });
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenShutdownDuringConnect) {
+ TestHangupDuringConnect(GetParam(), [](FileDescriptor& f) {
+ ASSERT_THAT(shutdown(f.get(), SHUT_RD), SyscallSucceeds());
+ });
+}
+
+void TestListenHangupConnectingRead(const TestParam& param,
+ void (*hangup)(FileDescriptor&)) {
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
// Create the listening socket.
FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
sockaddr_storage listen_addr = listener.addr;
- ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- listener.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+ // This test is only interested in deterministically getting a socket in
+ // connecting state. For that, we use a listen backlog of zero which would
+ // mean there is exactly one connection that gets established and is enqueued
+ // to the accept queue. We poll on the listener to ensure that is enqueued.
+ // After that the subsequent client connect will stay in connecting state as
+ // the accept queue is full.
+ constexpr int kBacklog = 0;
ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
- ASSERT_THAT(getsockname(listen_fd.get(),
- reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen),
SyscallSucceeds());
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
sockaddr_storage conn_addr = connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
- std::vector<FileDescriptor> clients;
- for (int i = 0; i < kClients; i++) {
- FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
- Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
- int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
- connector.addr_len);
- if (ret != 0) {
- EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
- clients.push_back(std::move(client));
- }
+ FileDescriptor established_client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(connect(established_client.get(), AsSockAddr(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Ensure that the accept queue has the completed connection.
+ constexpr int kTimeout = 10000;
+ pollfd pfd = {
+ .fd = listen_fd.get(),
+ .events = POLLIN,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ ASSERT_EQ(pfd.revents, POLLIN);
+
+ FileDescriptor connecting_client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ // Keep the last client in connecting state.
+ int ret = connect(connecting_client.get(), AsSockAddr(&conn_addr),
+ connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
}
- stopListen(listen_fd);
+ hangup(listen_fd);
- for (auto& client : clients) {
- constexpr int kTimeout = 10000;
+ std::array<std::pair<int, int>, 2> sockets = {
+ std::make_pair(established_client.get(), ECONNRESET),
+ std::make_pair(connecting_client.get(), ECONNREFUSED),
+ };
+ for (size_t i = 0; i < sockets.size(); i++) {
+ SCOPED_TRACE(absl::StrCat("i=", i));
+ auto [fd, expected_errno] = sockets[i];
pollfd pfd = {
- .fd = client.get(),
- .events = POLLIN,
+ .fd = fd,
};
- // When the listening socket is closed, then we expect the remote to reset
- // the connection.
- ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
- ASSERT_EQ(pfd.revents, POLLIN | POLLHUP | POLLERR);
+ // When the listening socket is closed, the peer would reset the connection.
+ EXPECT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ EXPECT_EQ(pfd.revents, POLLHUP | POLLERR);
char c;
- // Subsequent read can fail with:
- // ECONNRESET: If the client connection was established and was reset by the
- // remote.
- // ECONNREFUSED: If the client connection failed to be established.
- ASSERT_THAT(read(client.get(), &c, sizeof(c)),
- AnyOf(SyscallFailsWithErrno(ECONNRESET),
- SyscallFailsWithErrno(ECONNREFUSED)));
- // The last client connection would be in connecting (SYN_SENT) state.
- if (client.get() == clients[kClients - 1].get()) {
- ASSERT_EQ(errno, ECONNREFUSED) << strerror(errno);
- }
+ EXPECT_THAT(read(fd, &c, sizeof(c)), SyscallFailsWithErrno(expected_errno));
}
}
-TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) {
- TestListenWhileConnect(GetParam(), [](FileDescriptor& f) {
+TEST_P(SocketInetLoopbackTest, TCPListenCloseConnectingRead) {
+ TestListenHangupConnectingRead(GetParam(), [](FileDescriptor& f) {
ASSERT_THAT(close(f.release()), SyscallSucceeds());
});
}
-TEST_P(SocketInetLoopbackTest, TCPListenShutdownWhileConnect) {
- TestListenWhileConnect(GetParam(), [](FileDescriptor& f) {
+TEST_P(SocketInetLoopbackTest, TCPListenShutdownConnectingRead) {
+ TestListenHangupConnectingRead(GetParam(), [](FileDescriptor& f) {
ASSERT_THAT(shutdown(f.get(), SHUT_RD), SyscallSucceeds());
});
}
-// TODO(b/157236388): Remove _NoRandomSave once bug is fixed. Test fails w/
+// TODO(b/157236388): Remove once bug is fixed. Test fails w/
// random save as established connections which can't be delivered to the accept
// queue because the queue is full are not correctly delivered after restore
// causing the last accept to timeout on the restore.
-TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) {
+TEST_P(SocketInetLoopbackTest, TCPAcceptBacklogSizes) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
@@ -580,21 +635,70 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) {
const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
sockaddr_storage listen_addr = listener.addr;
- ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- listener.addr_len),
+ ASSERT_THAT(
+ bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen),
SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ std::array<int, 3> backlogs = {-1, 0, 1};
+ for (auto& backlog : backlogs) {
+ ASSERT_THAT(listen(listen_fd.get(), backlog), SyscallSucceeds());
+
+ int expected_accepts;
+ if (backlog < 0) {
+ expected_accepts = 1024;
+ } else {
+ expected_accepts = backlog + 1;
+ }
+ for (int i = 0; i < expected_accepts; i++) {
+ SCOPED_TRACE(absl::StrCat("i=", i));
+ // Connect to the listening socket.
+ const FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+ const FileDescriptor accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+ }
+ }
+}
+
+// TODO(b/157236388): Remove once bug is fixed. Test fails w/
+// random save as established connections which can't be delivered to the accept
+// queue because the queue is full are not correctly delivered after restore
+// causing the last accept to timeout on the restore.
+TEST_P(SocketInetLoopbackTest, TCPBacklog) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(
+ bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
constexpr int kBacklogSize = 2;
ASSERT_THAT(listen(listen_fd.get(), kBacklogSize), SyscallSucceeds());
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
- ASSERT_THAT(getsockname(listen_fd.get(),
- reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen),
SyscallSucceeds());
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
int i = 0;
while (1) {
+ SCOPED_TRACE(absl::StrCat("i=", i));
int ret;
// Connect to the listening socket.
@@ -602,8 +706,7 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) {
Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
sockaddr_storage conn_addr = connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
- ret = connect(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_addr),
- connector.addr_len);
+ ret = connect(conn_fd.get(), AsSockAddr(&conn_addr), connector.addr_len);
if (ret != 0) {
EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
pollfd pfd = {
@@ -620,103 +723,130 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) {
i++;
}
+ int client_conns = i;
+ int accepted_conns = 0;
for (; i != 0; i--) {
- // Accept the connection.
- //
- // We have to assign a name to the accepted socket, as unamed temporary
- // objects are destructed upon full evaluation of the expression it is in,
- // potentially causing the connecting socket to fail to shutdown properly.
- auto accepted =
- ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+ SCOPED_TRACE(absl::StrCat("i=", i));
+ pollfd pfd = {
+ .fd = listen_fd.get(),
+ .events = POLLIN,
+ };
+ // Look for incoming connections to accept. The last connect request could
+ // be established from the client side, but the ACK of the handshake could
+ // be dropped by the listener if the accept queue was filled up by the
+ // previous connect.
+ int ret;
+ ASSERT_THAT(ret = poll(&pfd, 1, 3000), SyscallSucceeds());
+ if (ret == 0) break;
+ if (pfd.revents == POLLIN) {
+ // Accept the connection.
+ //
+ // We have to assign a name to the accepted socket, as unamed temporary
+ // objects are destructed upon full evaluation of the expression it is in,
+ // potentially causing the connecting socket to fail to shutdown properly.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+ accepted_conns++;
+ }
}
+ // We should accept at least listen backlog + 1 connections. As the stack is
+ // enqueuing established connections to the accept queue, newer SYNs could
+ // still be replied to causing those client connections would be accepted as
+ // we start dequeuing the queue.
+ ASSERT_GE(accepted_conns, kBacklogSize + 1);
+ ASSERT_GE(client_conns, accepted_conns);
}
-// Test if the stack completes atmost listen backlog number of client
-// connections. It exercises the path of the stack that enqueues completed
-// connections to accept queue vs new incoming SYNs.
-TEST_P(SocketInetLoopbackTest, TCPConnectBacklog_NoRandomSave) {
- const auto& param = GetParam();
- const TestAddress& listener = param.listener;
- const TestAddress& connector = param.connector;
+// TODO(b/157236388): Remove once bug is fixed. Test fails w/
+// random save as established connections which can't be delivered to the accept
+// queue because the queue is full are not correctly delivered after restore
+// causing the last accept to timeout on the restore.
+TEST_P(SocketInetLoopbackTest, TCPBacklogAcceptAll) {
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+ // Create the listening socket.
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(
+ bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
constexpr int kBacklog = 1;
- // Keep the number of client connections more than the listen backlog.
- // Linux completes one more connection than the listen backlog argument.
- // gVisor differs in this behavior though, gvisor.dev/issue/3153.
- int kClients = kBacklog + 2;
- if (IsRunningOnGvisor()) {
- kClients--;
- }
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
- // Run the following test for few iterations to test race between accept queue
- // getting filled with incoming SYNs.
- for (int num = 0; num < 10; num++) {
- FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
- Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
- sockaddr_storage listen_addr = listener.addr;
- ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- listener.addr_len),
- SyscallSucceeds());
- ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
- socklen_t addrlen = listener.addr_len;
- ASSERT_THAT(
- getsockname(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- &addrlen),
- SyscallSucceeds());
- uint16_t const port =
- ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
- sockaddr_storage conn_addr = connector.addr;
- ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
- std::vector<FileDescriptor> clients;
- // Issue multiple non-blocking client connects.
- for (int i = 0; i < kClients; i++) {
- FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
- Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
- int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
- connector.addr_len);
- if (ret != 0) {
- EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
- }
- clients.push_back(std::move(client));
+ // Fill up the accept queue and trigger more client connections which would be
+ // waiting to be accepted.
+ std::array<FileDescriptor, kBacklog + 1> established_clients;
+ for (auto& fd : established_clients) {
+ fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(connect(fd.get(), AsSockAddr(&conn_addr), connector.addr_len),
+ SyscallSucceeds());
+ }
+ std::array<FileDescriptor, kBacklog> waiting_clients;
+ for (auto& fd : waiting_clients) {
+ fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ int ret = connect(fd.get(), AsSockAddr(&conn_addr), connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
}
+ }
- // Now that client connects are issued, wait for the accept queue to get
- // filled and ensure no new client connection is completed.
- for (int i = 0; i < kClients; i++) {
- pollfd pfd = {
- .fd = clients[i].get(),
- .events = POLLOUT,
- };
- if (i < kClients - 1) {
- // Poll for client side connection completions with a large timeout.
- // We cannot poll on the listener side without calling accept as poll
- // stays level triggered with non-zero accept queue length.
- //
- // Client side poll would not guarantee that the completed connection
- // has been enqueued in to the acccept queue, but the fact that the
- // listener ACKd the SYN, means that it cannot complete any new incoming
- // SYNs when it has already ACKd for > backlog number of SYNs.
- ASSERT_THAT(poll(&pfd, 1, 10000), SyscallSucceedsWithValue(1))
- << "num=" << num << " i=" << i << " kClients=" << kClients;
- ASSERT_EQ(pfd.revents, POLLOUT) << "num=" << num << " i=" << i;
- } else {
- // Now that we expect accept queue filled up, ensure that the last
- // client connection never completes with a smaller poll timeout.
- ASSERT_THAT(poll(&pfd, 1, 1000), SyscallSucceedsWithValue(0))
- << "num=" << num << " i=" << i;
- }
+ auto accept_connection = [&]() {
+ constexpr int kTimeout = 10000;
+ pollfd pfd = {
+ .fd = listen_fd.get(),
+ .events = POLLIN,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ ASSERT_EQ(pfd.revents, POLLIN);
+ // Accept the connection.
+ //
+ // We have to assign a name to the accepted socket, as unamed temporary
+ // objects are destructed upon full evaluation of the expression it is in,
+ // potentially causing the connecting socket to fail to shutdown properly.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+ };
- ASSERT_THAT(close(clients[i].release()), SyscallSucceedsWithValue(0))
- << "num=" << num << " i=" << i;
- }
- clients.clear();
- // We close the listening side and open a new listener. We could instead
- // drain the accept queue by calling accept() and reuse the listener, but
- // that is racy as the retransmitted SYNs could get ACKd as we make room in
- // the accept queue.
- ASSERT_THAT(close(listen_fd.release()), SyscallSucceedsWithValue(0));
+ // Ensure that we accept all client connections. The waiting connections would
+ // get enqueued as we drain the accept queue.
+ for (int i = 0; i < std::size(established_clients); i++) {
+ SCOPED_TRACE(absl::StrCat("established clients i=", i));
+ accept_connection();
+ }
+
+ // The waiting client connections could be in one of these 2 states:
+ // (1) SYN_SENT: if the SYN was dropped because accept queue was full
+ // (2) ESTABLISHED: if the listener sent back a SYNACK, but may have dropped
+ // the ACK from the client if the accept queue was full (send out a data to
+ // re-send that ACK, to address that case).
+ for (int i = 0; i < std::size(waiting_clients); i++) {
+ SCOPED_TRACE(absl::StrCat("waiting clients i=", i));
+ constexpr int kTimeout = 10000;
+ pollfd pfd = {
+ .fd = waiting_clients[i].get(),
+ .events = POLLOUT,
+ };
+ EXPECT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ EXPECT_EQ(pfd.revents, POLLOUT);
+ char c;
+ EXPECT_THAT(RetryEINTR(send)(waiting_clients[i].get(), &c, sizeof(c), 0),
+ SyscallSucceedsWithValue(sizeof(c)));
+ accept_connection();
}
}
@@ -728,7 +858,7 @@ TEST_P(SocketInetLoopbackTest, TCPConnectBacklog_NoRandomSave) {
//
// TCP timers are not S/R today, this can cause this test to be flaky when run
// under random S/R due to timer being reset on a restore.
-TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) {
+TEST_P(SocketInetLoopbackTest, TCPFinWait2Test) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
TestAddress const& connector = param.connector;
@@ -737,15 +867,14 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) {
const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
sockaddr_storage listen_addr = listener.addr;
- ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- listener.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
- ASSERT_THAT(getsockname(listen_fd.get(),
- reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen),
SyscallSucceeds());
uint16_t const port =
@@ -763,8 +892,7 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) {
sockaddr_storage conn_addr = connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
- ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
- reinterpret_cast<sockaddr*>(&conn_addr),
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr),
connector.addr_len),
SyscallSucceeds());
@@ -776,8 +904,7 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) {
sockaddr_storage conn_bound_addr;
socklen_t conn_addrlen = connector.addr_len;
ASSERT_THAT(
- getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr),
- &conn_addrlen),
+ getsockname(conn_fd.get(), AsSockAddr(&conn_bound_addr), &conn_addrlen),
SyscallSucceeds());
// close the connecting FD to trigger FIN_WAIT2 on the connected fd.
@@ -792,8 +919,7 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) {
// be restarted causing the final bind/connect to fail.
DisableSave ds;
- ASSERT_THAT(bind(conn_fd2.get(),
- reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen),
+ ASSERT_THAT(bind(conn_fd2.get(), AsSockAddr(&conn_bound_addr), conn_addrlen),
SyscallFailsWithErrno(EADDRINUSE));
// Sleep for a little over the linger timeout to reduce flakiness in
@@ -802,10 +928,9 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) {
ds.reset();
- ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(),
- reinterpret_cast<sockaddr*>(&conn_addr),
- conn_addrlen),
- SyscallSucceeds());
+ ASSERT_THAT(
+ RetryEINTR(connect)(conn_fd2.get(), AsSockAddr(&conn_addr), conn_addrlen),
+ SyscallSucceeds());
}
// TCPLinger2TimeoutAfterClose creates a pair of connected sockets
@@ -815,7 +940,7 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) {
//
// TCP timers are not S/R today, this can cause this test to be flaky when run
// under random S/R due to timer being reset on a restore.
-TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) {
+TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
TestAddress const& connector = param.connector;
@@ -824,15 +949,14 @@ TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) {
const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
sockaddr_storage listen_addr = listener.addr;
- ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- listener.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
- ASSERT_THAT(getsockname(listen_fd.get(),
- reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen),
SyscallSucceeds());
uint16_t const port =
@@ -844,8 +968,7 @@ TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) {
sockaddr_storage conn_addr = connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
- ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
- reinterpret_cast<sockaddr*>(&conn_addr),
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr),
connector.addr_len),
SyscallSucceeds());
@@ -857,8 +980,7 @@ TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) {
sockaddr_storage conn_bound_addr;
socklen_t conn_addrlen = connector.addr_len;
ASSERT_THAT(
- getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr),
- &conn_addrlen),
+ getsockname(conn_fd.get(), AsSockAddr(&conn_bound_addr), &conn_addrlen),
SyscallSucceeds());
// Disable cooperative saves after this point as TCP timers are not restored
@@ -884,13 +1006,11 @@ TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) {
const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE(
Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
- ASSERT_THAT(bind(conn_fd2.get(),
- reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen),
- SyscallSucceeds());
- ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(),
- reinterpret_cast<sockaddr*>(&conn_addr),
- conn_addrlen),
+ ASSERT_THAT(bind(conn_fd2.get(), AsSockAddr(&conn_bound_addr), conn_addrlen),
SyscallSucceeds());
+ ASSERT_THAT(
+ RetryEINTR(connect)(conn_fd2.get(), AsSockAddr(&conn_addr), conn_addrlen),
+ SyscallSucceeds());
}
// TCPResetAfterClose creates a pair of connected sockets then closes
@@ -906,15 +1026,14 @@ TEST_P(SocketInetLoopbackTest, TCPResetAfterClose) {
const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
sockaddr_storage listen_addr = listener.addr;
- ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- listener.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
- ASSERT_THAT(getsockname(listen_fd.get(),
- reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen),
SyscallSucceeds());
uint16_t const port =
@@ -926,8 +1045,7 @@ TEST_P(SocketInetLoopbackTest, TCPResetAfterClose) {
sockaddr_storage conn_addr = connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
- ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
- reinterpret_cast<sockaddr*>(&conn_addr),
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr),
connector.addr_len),
SyscallSucceeds());
@@ -975,15 +1093,14 @@ void setupTimeWaitClose(const TestAddress* listener,
&kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
}
- ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(listen_addr),
- listener->addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(listen_fd.get(), AsSockAddr(listen_addr), listener->addr_len),
+ SyscallSucceeds());
ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
// Get the port bound by the listening socket.
socklen_t addrlen = listener->addr_len;
- ASSERT_THAT(getsockname(listen_fd.get(),
- reinterpret_cast<sockaddr*>(listen_addr), &addrlen),
+ ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(listen_addr), &addrlen),
SyscallSucceeds());
uint16_t const port =
@@ -1005,8 +1122,7 @@ void setupTimeWaitClose(const TestAddress* listener,
sockaddr_storage conn_addr = connector->addr;
ASSERT_NO_ERRNO(SetAddrPort(connector->family(), &conn_addr, port));
- ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
- reinterpret_cast<sockaddr*>(&conn_addr),
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr),
connector->addr_len),
SyscallSucceeds());
@@ -1017,8 +1133,7 @@ void setupTimeWaitClose(const TestAddress* listener,
// Get the address/port bound by the connecting socket.
socklen_t conn_addrlen = connector->addr_len;
ASSERT_THAT(
- getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(conn_bound_addr),
- &conn_addrlen),
+ getsockname(conn_fd.get(), AsSockAddr(conn_bound_addr), &conn_addrlen),
SyscallSucceeds());
FileDescriptor active_closefd, passive_closefd;
@@ -1064,7 +1179,7 @@ void setupTimeWaitClose(const TestAddress* listener,
//
// Test re-binding of client and server bound addresses when the older
// connection is in TIME_WAIT.
-TEST_P(SocketInetLoopbackTest, TCPPassiveCloseNoTimeWaitTest_NoRandomSave) {
+TEST_P(SocketInetLoopbackTest, TCPPassiveCloseNoTimeWaitTest) {
auto const& param = GetParam();
sockaddr_storage listen_addr, conn_bound_addr;
listen_addr = param.listener.addr;
@@ -1075,19 +1190,18 @@ TEST_P(SocketInetLoopbackTest, TCPPassiveCloseNoTimeWaitTest_NoRandomSave) {
// bound by the conn_fd as it never entered TIME_WAIT.
const FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP));
- ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr),
+ ASSERT_THAT(bind(conn_fd.get(), AsSockAddr(&conn_bound_addr),
param.connector.addr_len),
SyscallSucceeds());
FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(param.listener.family(), SOCK_STREAM, IPPROTO_TCP));
- ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- param.listener.addr_len),
- SyscallFailsWithErrno(EADDRINUSE));
+ ASSERT_THAT(
+ bind(listen_fd.get(), AsSockAddr(&listen_addr), param.listener.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
}
-TEST_P(SocketInetLoopbackTest,
- TCPPassiveCloseNoTimeWaitReuseTest_NoRandomSave) {
+TEST_P(SocketInetLoopbackTest, TCPPassiveCloseNoTimeWaitReuseTest) {
auto const& param = GetParam();
sockaddr_storage listen_addr, conn_bound_addr;
listen_addr = param.listener.addr;
@@ -1099,9 +1213,9 @@ TEST_P(SocketInetLoopbackTest,
ASSERT_THAT(setsockopt(listen_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- param.listener.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(listen_fd.get(), AsSockAddr(&listen_addr), param.listener.addr_len),
+ SyscallSucceeds());
ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
// Now bind and connect new socket and verify that we can immediately rebind
@@ -1111,7 +1225,7 @@ TEST_P(SocketInetLoopbackTest,
ASSERT_THAT(setsockopt(conn_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr),
+ ASSERT_THAT(bind(conn_fd.get(), AsSockAddr(&conn_bound_addr),
param.connector.addr_len),
SyscallSucceeds());
@@ -1119,13 +1233,12 @@ TEST_P(SocketInetLoopbackTest,
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(param.listener.family(), listen_addr));
sockaddr_storage conn_addr = param.connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(param.connector.family(), &conn_addr, port));
- ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
- reinterpret_cast<sockaddr*>(&conn_addr),
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr),
param.connector.addr_len),
SyscallSucceeds());
}
-TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitTest_NoRandomSave) {
+TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitTest) {
auto const& param = GetParam();
sockaddr_storage listen_addr, conn_bound_addr;
listen_addr = param.listener.addr;
@@ -1134,12 +1247,12 @@ TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitTest_NoRandomSave) {
FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP));
- ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr),
+ ASSERT_THAT(bind(conn_fd.get(), AsSockAddr(&conn_bound_addr),
param.connector.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
}
-TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitReuseTest_NoRandomSave) {
+TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitReuseTest) {
auto const& param = GetParam();
sockaddr_storage listen_addr, conn_bound_addr;
listen_addr = param.listener.addr;
@@ -1150,7 +1263,7 @@ TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitReuseTest_NoRandomSave) {
ASSERT_THAT(setsockopt(conn_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr),
+ ASSERT_THAT(bind(conn_fd.get(), AsSockAddr(&conn_bound_addr),
param.connector.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
}
@@ -1164,15 +1277,14 @@ TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) {
const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
sockaddr_storage listen_addr = listener.addr;
- ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- listener.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
- ASSERT_THAT(getsockname(listen_fd.get(),
- reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen),
SyscallSucceeds());
const uint16_t port =
@@ -1190,8 +1302,7 @@ TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) {
sockaddr_storage conn_addr = connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
- ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
- reinterpret_cast<sockaddr*>(&conn_addr),
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr),
connector.addr_len),
SyscallSucceeds());
@@ -1218,17 +1329,16 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) {
const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
sockaddr_storage listen_addr = listener.addr;
- ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- listener.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
// Get the port bound by the listening socket.
{
socklen_t addrlen = listener.addr_len;
ASSERT_THAT(
- getsockname(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- &addrlen),
+ getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen),
SyscallSucceeds());
}
@@ -1244,8 +1354,7 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) {
// TODO(b/157236388): Reenable Cooperative S/R once bug is fixed.
DisableSave ds;
- ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
- reinterpret_cast<sockaddr*>(&conn_addr),
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr),
connector.addr_len),
SyscallSucceeds());
@@ -1272,8 +1381,8 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) {
sockaddr_storage accept_addr;
socklen_t addrlen = sizeof(accept_addr);
- auto accept_fd = ASSERT_NO_ERRNO_AND_VALUE(Accept(
- listen_fd.get(), reinterpret_cast<sockaddr*>(&accept_addr), &addrlen));
+ auto accept_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Accept(listen_fd.get(), AsSockAddr(&accept_addr), &addrlen));
ASSERT_EQ(addrlen, listener.addr_len);
// Wait for accept_fd to process the RST.
@@ -1311,15 +1420,14 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) {
sockaddr_storage peer_addr;
socklen_t addrlen = sizeof(peer_addr);
// The socket is not connected anymore and should return ENOTCONN.
- ASSERT_THAT(getpeername(accept_fd.get(),
- reinterpret_cast<sockaddr*>(&peer_addr), &addrlen),
+ ASSERT_THAT(getpeername(accept_fd.get(), AsSockAddr(&peer_addr), &addrlen),
SyscallFailsWithErrno(ENOTCONN));
}
}
// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not
// saved. Enable S/R once issue is fixed.
-TEST_P(SocketInetLoopbackTest, TCPDeferAccept_NoRandomSave) {
+TEST_P(SocketInetLoopbackTest, TCPDeferAccept) {
// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not
// saved. Enable S/R issue is fixed.
DisableSave ds;
@@ -1332,15 +1440,14 @@ TEST_P(SocketInetLoopbackTest, TCPDeferAccept_NoRandomSave) {
const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
sockaddr_storage listen_addr = listener.addr;
- ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- listener.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
- ASSERT_THAT(getsockname(listen_fd.get(),
- reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen),
SyscallSucceeds());
const uint16_t port =
@@ -1358,8 +1465,7 @@ TEST_P(SocketInetLoopbackTest, TCPDeferAccept_NoRandomSave) {
sockaddr_storage conn_addr = connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
- ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
- reinterpret_cast<sockaddr*>(&conn_addr),
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr),
connector.addr_len),
SyscallSucceeds());
@@ -1401,7 +1507,7 @@ TEST_P(SocketInetLoopbackTest, TCPDeferAccept_NoRandomSave) {
// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not
// saved. Enable S/R once issue is fixed.
-TEST_P(SocketInetLoopbackTest, TCPDeferAcceptTimeout_NoRandomSave) {
+TEST_P(SocketInetLoopbackTest, TCPDeferAcceptTimeout) {
// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not
// saved. Enable S/R once issue is fixed.
DisableSave ds;
@@ -1414,15 +1520,14 @@ TEST_P(SocketInetLoopbackTest, TCPDeferAcceptTimeout_NoRandomSave) {
const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
sockaddr_storage listen_addr = listener.addr;
- ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- listener.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
- ASSERT_THAT(getsockname(listen_fd.get(),
- reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen),
SyscallSucceeds());
const uint16_t port =
@@ -1440,8 +1545,7 @@ TEST_P(SocketInetLoopbackTest, TCPDeferAcceptTimeout_NoRandomSave) {
sockaddr_storage conn_addr = connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
- ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
- reinterpret_cast<sockaddr*>(&conn_addr),
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr),
connector.addr_len),
SyscallSucceeds());
@@ -1507,9 +1611,9 @@ INSTANTIATE_TEST_SUITE_P(
using SocketInetReusePortTest = ::testing::TestWithParam<TestParam>;
-// TODO(gvisor.dev/issue/940): Remove _NoRandomSave when portHint/stack.Seed is
+// TODO(gvisor.dev/issue/940): Remove when portHint/stack.Seed is
// saved/restored.
-TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
+TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
@@ -1529,9 +1633,8 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(
- bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(fd, AsSockAddr(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
ASSERT_THAT(listen(fd, 40), SyscallSucceeds());
// On the first bind we need to determine which port was bound.
@@ -1542,8 +1645,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
ASSERT_THAT(
- getsockname(listener_fds[0].get(),
- reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ getsockname(listener_fds[0].get(), AsSockAddr(&listen_addr), &addrlen),
SyscallSucceeds());
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
@@ -1601,10 +1703,9 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
for (int32_t i = 0; i < kConnectAttempts; i++) {
const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
- ASSERT_THAT(
- RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr),
- connector.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(connect)(fd.get(), AsSockAddr(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
SyscallSucceedsWithValue(sizeof(i)));
@@ -1622,7 +1723,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
EquivalentWithin((kConnectAttempts / kThreadCount), 0.10));
}
-TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread_NoRandomSave) {
+TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
@@ -1641,9 +1742,8 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread_NoRandomSave) {
ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(
- bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(fd, AsSockAddr(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
// On the first bind we need to determine which port was bound.
if (i != 0) {
@@ -1653,8 +1753,7 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread_NoRandomSave) {
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
ASSERT_THAT(
- getsockname(listener_fds[0].get(),
- reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ getsockname(listener_fds[0].get(), AsSockAddr(&listen_addr), &addrlen),
SyscallSucceeds());
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
@@ -1677,9 +1776,9 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread_NoRandomSave) {
socklen_t addrlen = sizeof(addr);
int data;
- auto ret = RetryEINTR(recvfrom)(
- listener_fds[i].get(), &data, sizeof(data), 0,
- reinterpret_cast<struct sockaddr*>(&addr), &addrlen);
+ auto ret =
+ RetryEINTR(recvfrom)(listener_fds[i].get(), &data, sizeof(data),
+ 0, AsSockAddr(&addr), &addrlen);
if (packets_received < kConnectAttempts) {
ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data)));
@@ -1697,10 +1796,10 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread_NoRandomSave) {
// A response is required to synchronize with the main thread,
// otherwise the main thread can send more than can fit into receive
// queues.
- EXPECT_THAT(RetryEINTR(sendto)(
- listener_fds[i].get(), &data, sizeof(data), 0,
- reinterpret_cast<sockaddr*>(&addr), addrlen),
- SyscallSucceedsWithValue(sizeof(data)));
+ EXPECT_THAT(
+ RetryEINTR(sendto)(listener_fds[i].get(), &data, sizeof(data),
+ 0, AsSockAddr(&addr), addrlen),
+ SyscallSucceedsWithValue(sizeof(data)));
} while (packets_received < kConnectAttempts);
// Shutdown all sockets to wake up other threads.
@@ -1713,10 +1812,10 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread_NoRandomSave) {
for (int i = 0; i < kConnectAttempts; i++) {
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0));
- EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0,
- reinterpret_cast<sockaddr*>(&conn_addr),
- connector.addr_len),
- SyscallSucceedsWithValue(sizeof(i)));
+ EXPECT_THAT(
+ RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0, AsSockAddr(&conn_addr),
+ connector.addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
int data;
EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0),
SyscallSucceedsWithValue(sizeof(data)));
@@ -1735,7 +1834,7 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread_NoRandomSave) {
EquivalentWithin((kConnectAttempts / kThreadCount), 0.10));
}
-TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) {
+TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
@@ -1757,9 +1856,8 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) {
ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(
- bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(fd, AsSockAddr(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
// On the first bind we need to determine which port was bound.
if (i != 0) {
@@ -1769,8 +1867,7 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) {
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
ASSERT_THAT(
- getsockname(listener_fds[0].get(),
- reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ getsockname(listener_fds[0].get(), AsSockAddr(&listen_addr), &addrlen),
SyscallSucceeds());
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
@@ -1787,8 +1884,7 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) {
client_fds[i] =
ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0));
EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0,
- reinterpret_cast<sockaddr*>(&conn_addr),
- connector.addr_len),
+ AsSockAddr(&conn_addr), connector.addr_len),
SyscallSucceedsWithValue(sizeof(i)));
}
ds.reset();
@@ -1797,8 +1893,7 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) {
// not been change after save/restore.
for (int i = 0; i < kConnectAttempts; i++) {
EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0,
- reinterpret_cast<sockaddr*>(&conn_addr),
- connector.addr_len),
+ AsSockAddr(&conn_addr), connector.addr_len),
SyscallSucceedsWithValue(sizeof(i)));
}
@@ -1826,9 +1921,8 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) {
struct sockaddr_storage addr = {};
socklen_t addrlen = sizeof(addr);
int data;
- EXPECT_THAT(RetryEINTR(recvfrom)(
- fd, &data, sizeof(data), 0,
- reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ EXPECT_THAT(RetryEINTR(recvfrom)(fd, &data, sizeof(data), 0,
+ AsSockAddr(&addr), &addrlen),
SyscallSucceedsWithValue(sizeof(data)));
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(connector.family(), addr));
@@ -1882,14 +1976,13 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedLoopbackOnlyReservesV4) {
sockaddr_storage addr_dual = test_addr_dual.addr;
const FileDescriptor fd_dual = ASSERT_NO_ERRNO_AND_VALUE(
Socket(test_addr_dual.family(), param.type, 0));
- ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual),
- test_addr_dual.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd_dual.get(), AsSockAddr(&addr_dual), test_addr_dual.addr_len),
+ SyscallSucceeds());
// Get the port that we bound.
socklen_t addrlen = test_addr_dual.addr_len;
- ASSERT_THAT(getsockname(fd_dual.get(),
- reinterpret_cast<sockaddr*>(&addr_dual), &addrlen),
+ ASSERT_THAT(getsockname(fd_dual.get(), AsSockAddr(&addr_dual), &addrlen),
SyscallSucceeds());
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual));
@@ -1900,8 +1993,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedLoopbackOnlyReservesV4) {
ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port));
const FileDescriptor fd_v6 =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0));
- int ret = bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6),
- test_addr_v6.addr_len);
+ int ret = bind(fd_v6.get(), AsSockAddr(&addr_v6), test_addr_v6.addr_len);
if (ret == -1 && errno == EADDRINUSE) {
// Port may have been in use.
ASSERT_LT(i, 100); // Give up after 100 tries.
@@ -1916,8 +2008,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedLoopbackOnlyReservesV4) {
ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4.family(), &addr_v4, port));
const FileDescriptor fd_v4 =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0));
- ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4),
- test_addr_v4.addr_len),
+ ASSERT_THAT(bind(fd_v4.get(), AsSockAddr(&addr_v4), test_addr_v4.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
// No need to try again.
@@ -1934,14 +2025,13 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedAnyOnlyReservesV4) {
sockaddr_storage addr_dual = test_addr_dual.addr;
const FileDescriptor fd_dual = ASSERT_NO_ERRNO_AND_VALUE(
Socket(test_addr_dual.family(), param.type, 0));
- ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual),
- test_addr_dual.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd_dual.get(), AsSockAddr(&addr_dual), test_addr_dual.addr_len),
+ SyscallSucceeds());
// Get the port that we bound.
socklen_t addrlen = test_addr_dual.addr_len;
- ASSERT_THAT(getsockname(fd_dual.get(),
- reinterpret_cast<sockaddr*>(&addr_dual), &addrlen),
+ ASSERT_THAT(getsockname(fd_dual.get(), AsSockAddr(&addr_dual), &addrlen),
SyscallSucceeds());
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual));
@@ -1952,8 +2042,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedAnyOnlyReservesV4) {
ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port));
const FileDescriptor fd_v6 =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0));
- int ret = bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6),
- test_addr_v6.addr_len);
+ int ret = bind(fd_v6.get(), AsSockAddr(&addr_v6), test_addr_v6.addr_len);
if (ret == -1 && errno == EADDRINUSE) {
// Port may have been in use.
ASSERT_LT(i, 100); // Give up after 100 tries.
@@ -1968,8 +2057,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedAnyOnlyReservesV4) {
ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4.family(), &addr_v4, port));
const FileDescriptor fd_v4 =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0));
- ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4),
- test_addr_v4.addr_len),
+ ASSERT_THAT(bind(fd_v4.get(), AsSockAddr(&addr_v4), test_addr_v4.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
// No need to try again.
@@ -1985,14 +2073,13 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, DualStackV6AnyReservesEverything) {
sockaddr_storage addr_dual = test_addr_dual.addr;
const FileDescriptor fd_dual =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_dual.family(), param.type, 0));
- ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual),
- test_addr_dual.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd_dual.get(), AsSockAddr(&addr_dual), test_addr_dual.addr_len),
+ SyscallSucceeds());
// Get the port that we bound.
socklen_t addrlen = test_addr_dual.addr_len;
- ASSERT_THAT(getsockname(fd_dual.get(),
- reinterpret_cast<sockaddr*>(&addr_dual), &addrlen),
+ ASSERT_THAT(getsockname(fd_dual.get(), AsSockAddr(&addr_dual), &addrlen),
SyscallSucceeds());
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual));
@@ -2003,8 +2090,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, DualStackV6AnyReservesEverything) {
ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port));
const FileDescriptor fd_v6 =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0));
- ASSERT_THAT(bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6),
- test_addr_v6.addr_len),
+ ASSERT_THAT(bind(fd_v6.get(), AsSockAddr(&addr_v6), test_addr_v6.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
// Verify that binding the v4 loopback on the same port with a v6 socket
@@ -2015,10 +2101,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, DualStackV6AnyReservesEverything) {
SetAddrPort(test_addr_v4_mapped.family(), &addr_v4_mapped, port));
const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE(
Socket(test_addr_v4_mapped.family(), param.type, 0));
- ASSERT_THAT(
- bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped),
- test_addr_v4_mapped.addr_len),
- SyscallFailsWithErrno(EADDRINUSE));
+ ASSERT_THAT(bind(fd_v4_mapped.get(), AsSockAddr(&addr_v4_mapped),
+ test_addr_v4_mapped.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
// Verify that binding the v4 loopback on the same port with a v4 socket
// fails.
@@ -2027,8 +2112,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, DualStackV6AnyReservesEverything) {
ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4.family(), &addr_v4, port));
const FileDescriptor fd_v4 =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0));
- ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4),
- test_addr_v4.addr_len),
+ ASSERT_THAT(bind(fd_v4.get(), AsSockAddr(&addr_v4), test_addr_v4.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
// Verify that binding the v4 any on the same port with a v4 socket
@@ -2038,7 +2122,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, DualStackV6AnyReservesEverything) {
ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port));
const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE(
Socket(test_addr_v4_any.family(), param.type, 0));
- ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any),
+ ASSERT_THAT(bind(fd_v4_any.get(), AsSockAddr(&addr_v4_any),
test_addr_v4_any.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
}
@@ -2055,14 +2139,13 @@ TEST_P(SocketMultiProtocolInetLoopbackTest,
ASSERT_THAT(setsockopt(fd_dual.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual),
- test_addr_dual.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd_dual.get(), AsSockAddr(&addr_dual), test_addr_dual.addr_len),
+ SyscallSucceeds());
// Get the port that we bound.
socklen_t addrlen = test_addr_dual.addr_len;
- ASSERT_THAT(getsockname(fd_dual.get(),
- reinterpret_cast<sockaddr*>(&addr_dual), &addrlen),
+ ASSERT_THAT(getsockname(fd_dual.get(), AsSockAddr(&addr_dual), &addrlen),
SyscallSucceeds());
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual));
@@ -2076,7 +2159,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest,
ASSERT_THAT(setsockopt(fd_v4_any.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any),
+ ASSERT_THAT(bind(fd_v4_any.get(), AsSockAddr(&addr_v4_any),
test_addr_v4_any.addr_len),
SyscallSucceeds());
}
@@ -2096,16 +2179,15 @@ TEST_P(SocketMultiProtocolInetLoopbackTest,
ASSERT_THAT(setsockopt(fd_dual.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual),
- test_addr_dual.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd_dual.get(), AsSockAddr(&addr_dual), test_addr_dual.addr_len),
+ SyscallSucceeds());
ASSERT_THAT(listen(fd_dual.get(), 5), SyscallSucceeds());
// Get the port that we bound.
socklen_t addrlen = test_addr_dual.addr_len;
- ASSERT_THAT(getsockname(fd_dual.get(),
- reinterpret_cast<sockaddr*>(&addr_dual), &addrlen),
+ ASSERT_THAT(getsockname(fd_dual.get(), AsSockAddr(&addr_dual), &addrlen),
SyscallSucceeds());
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual));
@@ -2120,7 +2202,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any),
+ ASSERT_THAT(bind(fd_v4_any.get(), AsSockAddr(&addr_v4_any),
test_addr_v4_any.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
}
@@ -2137,16 +2219,15 @@ TEST_P(SocketMultiProtocolInetLoopbackTest,
sockaddr_storage addr_dual = test_addr_dual.addr;
const FileDescriptor fd_dual =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_dual.family(), param.type, 0));
- ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual),
- test_addr_dual.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd_dual.get(), AsSockAddr(&addr_dual), test_addr_dual.addr_len),
+ SyscallSucceeds());
ASSERT_THAT(listen(fd_dual.get(), 5), SyscallSucceeds());
// Get the port that we bound.
socklen_t addrlen = test_addr_dual.addr_len;
- ASSERT_THAT(getsockname(fd_dual.get(),
- reinterpret_cast<sockaddr*>(&addr_dual), &addrlen),
+ ASSERT_THAT(getsockname(fd_dual.get(), AsSockAddr(&addr_dual), &addrlen),
SyscallSucceeds());
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual));
@@ -2157,8 +2238,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest,
ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port));
const FileDescriptor fd_v6 =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0));
- ASSERT_THAT(bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6),
- test_addr_v6.addr_len),
+ ASSERT_THAT(bind(fd_v6.get(), AsSockAddr(&addr_v6), test_addr_v6.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
// Verify that binding the v4 loopback on the same port with a v6 socket
@@ -2169,10 +2249,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest,
SetAddrPort(test_addr_v4_mapped.family(), &addr_v4_mapped, port));
const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE(
Socket(test_addr_v4_mapped.family(), param.type, 0));
- ASSERT_THAT(
- bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped),
- test_addr_v4_mapped.addr_len),
- SyscallFailsWithErrno(EADDRINUSE));
+ ASSERT_THAT(bind(fd_v4_mapped.get(), AsSockAddr(&addr_v4_mapped),
+ test_addr_v4_mapped.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
// Verify that binding the v4 loopback on the same port with a v4 socket
// fails.
@@ -2181,8 +2260,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest,
ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4.family(), &addr_v4, port));
const FileDescriptor fd_v4 =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0));
- ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4),
- test_addr_v4.addr_len),
+ ASSERT_THAT(bind(fd_v4.get(), AsSockAddr(&addr_v4), test_addr_v4.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
// Verify that binding the v4 any on the same port with a v4 socket
@@ -2192,7 +2270,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest,
ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port));
const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE(
Socket(test_addr_v4_any.family(), param.type, 0));
- ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any),
+ ASSERT_THAT(bind(fd_v4_any.get(), AsSockAddr(&addr_v4_any),
test_addr_v4_any.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
}
@@ -2209,14 +2287,13 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) {
EXPECT_THAT(setsockopt(fd_dual.get(), IPPROTO_IPV6, IPV6_V6ONLY,
&kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual),
- test_addr_dual.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd_dual.get(), AsSockAddr(&addr_dual), test_addr_dual.addr_len),
+ SyscallSucceeds());
// Get the port that we bound.
socklen_t addrlen = test_addr_dual.addr_len;
- ASSERT_THAT(getsockname(fd_dual.get(),
- reinterpret_cast<sockaddr*>(&addr_dual), &addrlen),
+ ASSERT_THAT(getsockname(fd_dual.get(), AsSockAddr(&addr_dual), &addrlen),
SyscallSucceeds());
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual));
@@ -2227,8 +2304,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) {
ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port));
const FileDescriptor fd_v6 =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0));
- ASSERT_THAT(bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6),
- test_addr_v6.addr_len),
+ ASSERT_THAT(bind(fd_v6.get(), AsSockAddr(&addr_v6), test_addr_v6.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
// Verify that we can still bind the v4 loopback on the same port.
@@ -2238,9 +2314,8 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) {
SetAddrPort(test_addr_v4_mapped.family(), &addr_v4_mapped, port));
const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE(
Socket(test_addr_v4_mapped.family(), param.type, 0));
- int ret =
- bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped),
- test_addr_v4_mapped.addr_len);
+ int ret = bind(fd_v4_mapped.get(), AsSockAddr(&addr_v4_mapped),
+ test_addr_v4_mapped.addr_len);
if (ret == -1 && errno == EADDRINUSE) {
// Port may have been in use.
ASSERT_LT(i, 100); // Give up after 100 tries.
@@ -2262,9 +2337,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) {
sockaddr_storage bound_addr = test_addr.addr;
const FileDescriptor bound_fd =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
- ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
- test_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len),
+ SyscallSucceeds());
// Listen iff TCP.
if (param.type == SOCK_STREAM) {
@@ -2274,23 +2349,20 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) {
// Get the port that we bound.
socklen_t bound_addr_len = test_addr.addr_len;
ASSERT_THAT(
- getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
- &bound_addr_len),
+ getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len),
SyscallSucceeds());
// Connect to bind an ephemeral port.
const FileDescriptor connected_fd =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
- ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(),
- reinterpret_cast<sockaddr*>(&bound_addr),
+ ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), AsSockAddr(&bound_addr),
bound_addr_len),
SyscallSucceeds());
// Get the ephemeral port.
sockaddr_storage connected_addr = {};
socklen_t connected_addr_len = sizeof(connected_addr);
- ASSERT_THAT(getsockname(connected_fd.get(),
- reinterpret_cast<sockaddr*>(&connected_addr),
+ ASSERT_THAT(getsockname(connected_fd.get(), AsSockAddr(&connected_addr),
&connected_addr_len),
SyscallSucceeds());
uint16_t const ephemeral_port =
@@ -2302,10 +2374,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) {
// Verify that the ephemeral port is reserved.
const FileDescriptor checking_fd =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
- EXPECT_THAT(
- bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr),
- connected_addr_len),
- SyscallFailsWithErrno(EADDRINUSE));
+ EXPECT_THAT(bind(checking_fd.get(), AsSockAddr(&connected_addr),
+ connected_addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
// Verify that binding the v6 loopback with the same port fails.
TestAddress const& test_addr_v6 = V6Loopback();
@@ -2314,8 +2385,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) {
SetAddrPort(test_addr_v6.family(), &addr_v6, ephemeral_port));
const FileDescriptor fd_v6 =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0));
- ASSERT_THAT(bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6),
- test_addr_v6.addr_len),
+ ASSERT_THAT(bind(fd_v6.get(), AsSockAddr(&addr_v6), test_addr_v6.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
// Verify that we can still bind the v4 loopback on the same port.
@@ -2325,9 +2395,8 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) {
ephemeral_port));
const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE(
Socket(test_addr_v4_mapped.family(), param.type, 0));
- int ret =
- bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped),
- test_addr_v4_mapped.addr_len);
+ int ret = bind(fd_v4_mapped.get(), AsSockAddr(&addr_v4_mapped),
+ test_addr_v4_mapped.addr_len);
if (ret == -1 && errno == EADDRINUSE) {
// Port may have been in use.
ASSERT_LT(i, 100); // Give up after 100 tries.
@@ -2348,8 +2417,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReservedReuseAddr) {
sockaddr_storage bound_addr = test_addr.addr;
const FileDescriptor bound_fd =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
- ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
- test_addr.addr_len),
+ ASSERT_THAT(bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len),
SyscallSucceeds());
ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
sizeof(kSockOptOn)),
@@ -2363,8 +2431,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReservedReuseAddr) {
// Get the port that we bound.
socklen_t bound_addr_len = test_addr.addr_len;
ASSERT_THAT(
- getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
- &bound_addr_len),
+ getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len),
SyscallSucceeds());
// Connect to bind an ephemeral port.
@@ -2373,16 +2440,14 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReservedReuseAddr) {
ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR,
&kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(),
- reinterpret_cast<sockaddr*>(&bound_addr),
+ ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), AsSockAddr(&bound_addr),
bound_addr_len),
SyscallSucceeds());
// Get the ephemeral port.
sockaddr_storage connected_addr = {};
socklen_t connected_addr_len = sizeof(connected_addr);
- ASSERT_THAT(getsockname(connected_fd.get(),
- reinterpret_cast<sockaddr*>(&connected_addr),
+ ASSERT_THAT(getsockname(connected_fd.get(), AsSockAddr(&connected_addr),
&connected_addr_len),
SyscallSucceeds());
uint16_t const ephemeral_port =
@@ -2398,8 +2463,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReservedReuseAddr) {
&kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
EXPECT_THAT(
- bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr),
- connected_addr_len),
+ bind(checking_fd.get(), AsSockAddr(&connected_addr), connected_addr_len),
SyscallSucceeds());
}
@@ -2412,9 +2476,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) {
sockaddr_storage bound_addr = test_addr.addr;
const FileDescriptor bound_fd =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
- ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
- test_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len),
+ SyscallSucceeds());
// Listen iff TCP.
if (param.type == SOCK_STREAM) {
@@ -2424,23 +2488,20 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) {
// Get the port that we bound.
socklen_t bound_addr_len = test_addr.addr_len;
ASSERT_THAT(
- getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
- &bound_addr_len),
+ getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len),
SyscallSucceeds());
// Connect to bind an ephemeral port.
const FileDescriptor connected_fd =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
- ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(),
- reinterpret_cast<sockaddr*>(&bound_addr),
+ ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), AsSockAddr(&bound_addr),
bound_addr_len),
SyscallSucceeds());
// Get the ephemeral port.
sockaddr_storage connected_addr = {};
socklen_t connected_addr_len = sizeof(connected_addr);
- ASSERT_THAT(getsockname(connected_fd.get(),
- reinterpret_cast<sockaddr*>(&connected_addr),
+ ASSERT_THAT(getsockname(connected_fd.get(), AsSockAddr(&connected_addr),
&connected_addr_len),
SyscallSucceeds());
uint16_t const ephemeral_port =
@@ -2452,10 +2513,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) {
// Verify that the ephemeral port is reserved.
const FileDescriptor checking_fd =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
- EXPECT_THAT(
- bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr),
- connected_addr_len),
- SyscallFailsWithErrno(EADDRINUSE));
+ EXPECT_THAT(bind(checking_fd.get(), AsSockAddr(&connected_addr),
+ connected_addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
// Verify that binding the v4 loopback on the same port with a v4 socket
// fails.
@@ -2465,8 +2525,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) {
SetAddrPort(test_addr_v4.family(), &addr_v4, ephemeral_port));
const FileDescriptor fd_v4 =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0));
- EXPECT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4),
- test_addr_v4.addr_len),
+ EXPECT_THAT(bind(fd_v4.get(), AsSockAddr(&addr_v4), test_addr_v4.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
// Verify that binding the v6 any on the same port with a dual-stack socket
@@ -2477,7 +2536,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) {
SetAddrPort(test_addr_v6_any.family(), &addr_v6_any, ephemeral_port));
const FileDescriptor fd_v6_any = ASSERT_NO_ERRNO_AND_VALUE(
Socket(test_addr_v6_any.family(), param.type, 0));
- ASSERT_THAT(bind(fd_v6_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any),
+ ASSERT_THAT(bind(fd_v6_any.get(), AsSockAddr(&addr_v6_any),
test_addr_v6_any.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
@@ -2496,8 +2555,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) {
SetAddrPort(test_addr_v6.family(), &addr_v6, ephemeral_port));
const FileDescriptor fd_v6 = ASSERT_NO_ERRNO_AND_VALUE(
Socket(test_addr_v6.family(), param.type, 0));
- ret = bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6),
- test_addr_v6.addr_len);
+ ret = bind(fd_v6.get(), AsSockAddr(&addr_v6), test_addr_v6.addr_len);
} else {
// Verify that we can still bind the v6 any on the same port with a
// v6-only socket.
@@ -2506,9 +2564,8 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) {
EXPECT_THAT(setsockopt(fd_v6_only_any.get(), IPPROTO_IPV6, IPV6_V6ONLY,
&kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
- ret =
- bind(fd_v6_only_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any),
- test_addr_v6_any.addr_len);
+ ret = bind(fd_v6_only_any.get(), AsSockAddr(&addr_v6_any),
+ test_addr_v6_any.addr_len);
}
if (ret == -1 && errno == EADDRINUSE) {
@@ -2532,8 +2589,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest,
sockaddr_storage bound_addr = test_addr.addr;
const FileDescriptor bound_fd =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
- ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
- test_addr.addr_len),
+ ASSERT_THAT(bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len),
SyscallSucceeds());
ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
@@ -2548,8 +2604,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest,
// Get the port that we bound.
socklen_t bound_addr_len = test_addr.addr_len;
ASSERT_THAT(
- getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
- &bound_addr_len),
+ getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len),
SyscallSucceeds());
// Connect to bind an ephemeral port.
@@ -2558,16 +2613,14 @@ TEST_P(SocketMultiProtocolInetLoopbackTest,
ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR,
&kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(),
- reinterpret_cast<sockaddr*>(&bound_addr),
+ ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), AsSockAddr(&bound_addr),
bound_addr_len),
SyscallSucceeds());
// Get the ephemeral port.
sockaddr_storage connected_addr = {};
socklen_t connected_addr_len = sizeof(connected_addr);
- ASSERT_THAT(getsockname(connected_fd.get(),
- reinterpret_cast<sockaddr*>(&connected_addr),
+ ASSERT_THAT(getsockname(connected_fd.get(), AsSockAddr(&connected_addr),
&connected_addr_len),
SyscallSucceeds());
uint16_t const ephemeral_port =
@@ -2583,8 +2636,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest,
&kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
EXPECT_THAT(
- bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr),
- connected_addr_len),
+ bind(checking_fd.get(), AsSockAddr(&connected_addr), connected_addr_len),
SyscallSucceeds());
}
@@ -2597,9 +2649,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) {
sockaddr_storage bound_addr = test_addr.addr;
const FileDescriptor bound_fd =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
- ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
- test_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len),
+ SyscallSucceeds());
// Listen iff TCP.
if (param.type == SOCK_STREAM) {
@@ -2609,23 +2661,20 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) {
// Get the port that we bound.
socklen_t bound_addr_len = test_addr.addr_len;
ASSERT_THAT(
- getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
- &bound_addr_len),
+ getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len),
SyscallSucceeds());
// Connect to bind an ephemeral port.
const FileDescriptor connected_fd =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
- ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(),
- reinterpret_cast<sockaddr*>(&bound_addr),
+ ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), AsSockAddr(&bound_addr),
bound_addr_len),
SyscallSucceeds());
// Get the ephemeral port.
sockaddr_storage connected_addr = {};
socklen_t connected_addr_len = sizeof(connected_addr);
- ASSERT_THAT(getsockname(connected_fd.get(),
- reinterpret_cast<sockaddr*>(&connected_addr),
+ ASSERT_THAT(getsockname(connected_fd.get(), AsSockAddr(&connected_addr),
&connected_addr_len),
SyscallSucceeds());
uint16_t const ephemeral_port =
@@ -2637,10 +2686,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) {
// Verify that the ephemeral port is reserved.
const FileDescriptor checking_fd =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
- EXPECT_THAT(
- bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr),
- connected_addr_len),
- SyscallFailsWithErrno(EADDRINUSE));
+ EXPECT_THAT(bind(checking_fd.get(), AsSockAddr(&connected_addr),
+ connected_addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
// Verify that binding the v4 loopback on the same port with a v6 socket
// fails.
@@ -2650,10 +2698,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) {
ephemeral_port));
const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE(
Socket(test_addr_v4_mapped.family(), param.type, 0));
- EXPECT_THAT(
- bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped),
- test_addr_v4_mapped.addr_len),
- SyscallFailsWithErrno(EADDRINUSE));
+ EXPECT_THAT(bind(fd_v4_mapped.get(), AsSockAddr(&addr_v4_mapped),
+ test_addr_v4_mapped.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
// Verify that binding the v6 any on the same port with a dual-stack socket
// fails.
@@ -2663,7 +2710,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) {
SetAddrPort(test_addr_v6_any.family(), &addr_v6_any, ephemeral_port));
const FileDescriptor fd_v6_any = ASSERT_NO_ERRNO_AND_VALUE(
Socket(test_addr_v6_any.family(), param.type, 0));
- ASSERT_THAT(bind(fd_v6_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any),
+ ASSERT_THAT(bind(fd_v6_any.get(), AsSockAddr(&addr_v6_any),
test_addr_v6_any.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
@@ -2682,8 +2729,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) {
SetAddrPort(test_addr_v6.family(), &addr_v6, ephemeral_port));
const FileDescriptor fd_v6 = ASSERT_NO_ERRNO_AND_VALUE(
Socket(test_addr_v6.family(), param.type, 0));
- ret = bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6),
- test_addr_v6.addr_len);
+ ret = bind(fd_v6.get(), AsSockAddr(&addr_v6), test_addr_v6.addr_len);
} else {
// Verify that we can still bind the v6 any on the same port with a
// v6-only socket.
@@ -2692,9 +2738,8 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) {
EXPECT_THAT(setsockopt(fd_v6_only_any.get(), IPPROTO_IPV6, IPV6_V6ONLY,
&kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
- ret =
- bind(fd_v6_only_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any),
- test_addr_v6_any.addr_len);
+ ret = bind(fd_v6_only_any.get(), AsSockAddr(&addr_v6_any),
+ test_addr_v6_any.addr_len);
}
if (ret == -1 && errno == EADDRINUSE) {
@@ -2722,8 +2767,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) {
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
- test_addr.addr_len),
+ ASSERT_THAT(bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len),
SyscallSucceeds());
// Listen iff TCP.
@@ -2734,8 +2778,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) {
// Get the port that we bound.
socklen_t bound_addr_len = test_addr.addr_len;
ASSERT_THAT(
- getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
- &bound_addr_len),
+ getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len),
SyscallSucceeds());
// Connect to bind an ephemeral port.
@@ -2746,16 +2789,14 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) {
&kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(),
- reinterpret_cast<sockaddr*>(&bound_addr),
+ ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), AsSockAddr(&bound_addr),
bound_addr_len),
SyscallSucceeds());
// Get the ephemeral port.
sockaddr_storage connected_addr = {};
socklen_t connected_addr_len = sizeof(connected_addr);
- ASSERT_THAT(getsockname(connected_fd.get(),
- reinterpret_cast<sockaddr*>(&connected_addr),
+ ASSERT_THAT(getsockname(connected_fd.get(), AsSockAddr(&connected_addr),
&connected_addr_len),
SyscallSucceeds());
uint16_t const ephemeral_port =
@@ -2771,8 +2812,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) {
&kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
EXPECT_THAT(
- bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr),
- connected_addr_len),
+ bind(checking_fd.get(), AsSockAddr(&connected_addr), connected_addr_len),
SyscallSucceeds());
}
@@ -2791,14 +2831,12 @@ TEST_P(SocketMultiProtocolInetLoopbackTest,
ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
- test_addr.addr_len),
+ ASSERT_THAT(bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len),
SyscallSucceeds());
// Get the port that we bound.
socklen_t bound_addr_len = test_addr.addr_len;
ASSERT_THAT(
- getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
- &bound_addr_len),
+ getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len),
SyscallSucceeds());
// Now create a socket and bind it to the same port, this should
@@ -2809,9 +2847,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest,
ASSERT_THAT(setsockopt(second_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(second_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
- test_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(second_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len),
+ SyscallSucceeds());
}
TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) {
@@ -2830,10 +2868,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) {
setsockopt(fd1, SOL_SOCKET, SO_REUSEPORT, &portreuse1, sizeof(int)),
SyscallSucceeds());
- ASSERT_THAT(bind(fd1, reinterpret_cast<sockaddr*>(&addr), addrlen),
- SyscallSucceeds());
+ ASSERT_THAT(bind(fd1, AsSockAddr(&addr), addrlen), SyscallSucceeds());
- ASSERT_THAT(getsockname(fd1, reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ ASSERT_THAT(getsockname(fd1, AsSockAddr(&addr), &addrlen),
SyscallSucceeds());
if (param.type == SOCK_STREAM) {
ASSERT_THAT(listen(fd1, 1), SyscallSucceeds());
@@ -2852,7 +2889,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) {
SyscallSucceeds());
std::cout << portreuse1 << " " << portreuse2 << std::endl;
- int ret = bind(fd2, reinterpret_cast<sockaddr*>(&addr), addrlen);
+ int ret = bind(fd2, AsSockAddr(&addr), addrlen);
// Verify that two sockets can be bound to the same port only if
// SO_REUSEPORT is set for both of them.
@@ -2880,10 +2917,8 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, NoReusePortFollowingReusePort) {
ASSERT_THAT(
setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &portreuse, sizeof(portreuse)),
SyscallSucceeds());
- ASSERT_THAT(bind(fd, reinterpret_cast<sockaddr*>(&addr), addrlen),
- SyscallSucceeds());
- ASSERT_THAT(getsockname(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
+ ASSERT_THAT(bind(fd, AsSockAddr(&addr), addrlen), SyscallSucceeds());
+ ASSERT_THAT(getsockname(fd, AsSockAddr(&addr), &addrlen), SyscallSucceeds());
ASSERT_EQ(addrlen, test_addr.addr_len);
s.reset();
@@ -2895,8 +2930,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, NoReusePortFollowingReusePort) {
ASSERT_THAT(
setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &portreuse, sizeof(portreuse)),
SyscallSucceeds());
- ASSERT_THAT(bind(fd, reinterpret_cast<sockaddr*>(&addr), addrlen),
- SyscallSucceeds());
+ ASSERT_THAT(bind(fd, AsSockAddr(&addr), addrlen), SyscallSucceeds());
}
INSTANTIATE_TEST_SUITE_P(
diff --git a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc
index 1a0b53394..601ae107b 100644
--- a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc
+++ b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc
@@ -86,7 +86,7 @@ using SocketInetLoopbackTest = ::testing::TestWithParam<TestParam>;
// We disable S/R because this test creates a large number of sockets.
//
// FIXME(b/162475855): This test is failing reliably.
-TEST_P(SocketInetLoopbackTest, DISABLED_TestTCPPortExhaustion_NoRandomSave) {
+TEST_P(SocketInetLoopbackTest, DISABLED_TestTCPPortExhaustion) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
TestAddress const& connector = param.connector;
@@ -98,15 +98,14 @@ TEST_P(SocketInetLoopbackTest, DISABLED_TestTCPPortExhaustion_NoRandomSave) {
auto listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
sockaddr_storage listen_addr = listener.addr;
- ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- listener.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
- ASSERT_THAT(getsockname(listen_fd.get(),
- reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen),
SyscallSucceeds());
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
@@ -124,8 +123,7 @@ TEST_P(SocketInetLoopbackTest, DISABLED_TestTCPPortExhaustion_NoRandomSave) {
for (int i = 0; i < kClients; i++) {
FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
- int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
- connector.addr_len);
+ int ret = connect(client.get(), AsSockAddr(&conn_addr), connector.addr_len);
if (ret == 0) {
clients.push_back(std::move(client));
FileDescriptor server =
@@ -181,8 +179,7 @@ std::string DescribeProtocolTestParam(
using SocketMultiProtocolInetLoopbackTest =
::testing::TestWithParam<ProtocolTestParam>;
-TEST_P(SocketMultiProtocolInetLoopbackTest,
- BindAvoidsListeningPortsReuseAddr_NoRandomSave) {
+TEST_P(SocketMultiProtocolInetLoopbackTest, BindAvoidsListeningPortsReuseAddr) {
const auto& param = GetParam();
// UDP sockets are allowed to bind/listen on the port w/ SO_REUSEADDR, for TCP
// this is only permitted if there is no other listening socket.
@@ -205,8 +202,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest,
&kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
- int ret = bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
- test_addr.addr_len);
+ int ret = bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len);
if (ret != 0) {
ASSERT_EQ(errno, EADDRINUSE);
break;
@@ -214,8 +210,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest,
// Get the port that we bound.
socklen_t bound_addr_len = test_addr.addr_len;
ASSERT_THAT(
- getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
- &bound_addr_len),
+ getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len),
SyscallSucceeds());
uint16_t port = reinterpret_cast<sockaddr_in*>(&bound_addr)->sin_port;
diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc
index f10f55b27..59b56dc1a 100644
--- a/test/syscalls/linux/socket_ip_tcp_generic.cc
+++ b/test/syscalls/linux/socket_ip_tcp_generic.cc
@@ -1153,7 +1153,7 @@ TEST_P(TCPSocketPairTest, IpMulticastLoopDefault) {
EXPECT_EQ(get, 1);
}
-TEST_P(TCPSocketPairTest, TCPResetDuringClose_NoRandomSave) {
+TEST_P(TCPSocketPairTest, TCPResetDuringClose) {
DisableSave ds; // Too many syscalls.
constexpr int kThreadCount = 1000;
std::unique_ptr<ScopedThread> instances[kThreadCount];
diff --git a/test/syscalls/linux/socket_ip_unbound_netlink.cc b/test/syscalls/linux/socket_ip_unbound_netlink.cc
index 7fb1c0faf..b02222999 100644
--- a/test/syscalls/linux/socket_ip_unbound_netlink.cc
+++ b/test/syscalls/linux/socket_ip_unbound_netlink.cc
@@ -35,7 +35,7 @@ namespace testing {
// Test fixture for tests that apply to pairs of IP sockets.
using IPv6UnboundSocketTest = SimpleSocketTest;
-TEST_P(IPv6UnboundSocketTest, ConnectToBadLocalAddress_NoRandomSave) {
+TEST_P(IPv6UnboundSocketTest, ConnectToBadLocalAddress) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
// TODO(gvisor.dev/issue/4595): Addresses on net devices are not saved
@@ -57,8 +57,7 @@ TEST_P(IPv6UnboundSocketTest, ConnectToBadLocalAddress_NoRandomSave) {
TestAddress addr = V6Loopback();
reinterpret_cast<sockaddr_in6*>(&addr.addr)->sin6_port = 65535;
auto sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- EXPECT_THAT(connect(sock->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ EXPECT_THAT(connect(sock->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallFailsWithErrno(EADDRNOTAVAIL));
}
@@ -69,7 +68,7 @@ INSTANTIATE_TEST_SUITE_P(IPUnboundSockets, IPv6UnboundSocketTest,
using IPv4UnboundSocketTest = SimpleSocketTest;
-TEST_P(IPv4UnboundSocketTest, ConnectToBadLocalAddress_NoRandomSave) {
+TEST_P(IPv4UnboundSocketTest, ConnectToBadLocalAddress) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
// TODO(gvisor.dev/issue/4595): Addresses on net devices are not saved
@@ -90,8 +89,7 @@ TEST_P(IPv4UnboundSocketTest, ConnectToBadLocalAddress_NoRandomSave) {
TestAddress addr = V4Loopback();
reinterpret_cast<sockaddr_in*>(&addr.addr)->sin_port = 65535;
auto sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- EXPECT_THAT(connect(sock->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ EXPECT_THAT(connect(sock->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallFailsWithErrno(ENETUNREACH));
}
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
index 8eec31a46..18be4dcc7 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
@@ -44,20 +44,17 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNoGroup) {
// IP_MULTICAST_IF for setting the default send interface.
auto sender_addr = V4Loopback();
EXPECT_THAT(
- bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
- sender_addr.addr_len),
+ bind(socket1->get(), AsSockAddr(&sender_addr.addr), sender_addr.addr_len),
SyscallSucceeds());
// Bind the second FD to the v4 any address. If multicast worked like unicast,
// this would ensure that we get the packet.
auto receiver_addr = V4Any();
- EXPECT_THAT(
- bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ EXPECT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket2->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -68,10 +65,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNoGroup) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ EXPECT_THAT(
+ RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -83,19 +80,19 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNoGroup) {
// Check that not setting a default send interface prevents multicast packets
// from being sent. Group membership interface configured by address.
TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddrNoDefaultSendIf) {
+ // TODO(b/185517803): Fix for native test.
+ SKIP_IF(!IsRunningOnGvisor());
auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Bind the second FD to the v4 any address to ensure that we can receive any
// unicast packet.
auto receiver_addr = V4Any();
- EXPECT_THAT(
- bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ EXPECT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket2->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -114,28 +111,28 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddrNoDefaultSendIf) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallFailsWithErrno(ENETUNREACH));
+ EXPECT_THAT(
+ RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallFailsWithErrno(ENETUNREACH));
}
// Check that not setting a default send interface prevents multicast packets
// from being sent. Group membership interface configured by NIC ID.
TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNicNoDefaultSendIf) {
+ // TODO(b/185517803): Fix for native test.
+ SKIP_IF(!IsRunningOnGvisor());
auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Bind the second FD to the v4 any address to ensure that we can receive any
// unicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(
- bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket2->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -154,10 +151,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNicNoDefaultSendIf) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallFailsWithErrno(ENETUNREACH));
+ EXPECT_THAT(
+ RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallFailsWithErrno(ENETUNREACH));
}
// Check that multicast works when the default send interface is configured by
@@ -170,20 +167,17 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddr) {
// IP_MULTICAST_IF for setting the default send interface.
auto sender_addr = V4Loopback();
ASSERT_THAT(
- bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
- sender_addr.addr_len),
+ bind(socket1->get(), AsSockAddr(&sender_addr.addr), sender_addr.addr_len),
SyscallSucceeds());
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(
- bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket2->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -202,10 +196,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddr) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -226,20 +220,17 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNic) {
// IP_MULTICAST_IF for setting the default send interface.
auto sender_addr = V4Loopback();
ASSERT_THAT(
- bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
- sender_addr.addr_len),
+ bind(socket1->get(), AsSockAddr(&sender_addr.addr), sender_addr.addr_len),
SyscallSucceeds());
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(
- bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket2->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -258,10 +249,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNic) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -289,13 +280,11 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddr) {
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(
- bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket2->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -314,10 +303,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddr) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -345,13 +334,11 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNic) {
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(
- bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket2->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -370,10 +357,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNic) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -401,13 +388,11 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrConnect) {
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(
- bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket2->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -425,8 +410,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrConnect) {
reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port =
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
ASSERT_THAT(
- RetryEINTR(connect)(socket1->get(),
- reinterpret_cast<sockaddr*>(&connect_addr.addr),
+ RetryEINTR(connect)(socket1->get(), AsSockAddr(&connect_addr.addr),
connect_addr.addr_len),
SyscallSucceeds());
@@ -461,13 +445,11 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicConnect) {
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(
- bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket2->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -485,8 +467,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicConnect) {
reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port =
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
ASSERT_THAT(
- RetryEINTR(connect)(socket1->get(),
- reinterpret_cast<sockaddr*>(&connect_addr.addr),
+ RetryEINTR(connect)(socket1->get(), AsSockAddr(&connect_addr.addr),
connect_addr.addr_len),
SyscallSucceeds());
@@ -521,13 +502,11 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelf) {
// Bind the first FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(
- bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(socket1->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket1->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -546,10 +525,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelf) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -577,13 +556,11 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelf) {
// Bind the first FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(
- bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(socket1->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket1->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -602,10 +579,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelf) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -633,13 +610,11 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfConnect) {
// Bind the first FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(
- bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(socket1->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket1->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -657,8 +632,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfConnect) {
reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port =
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
EXPECT_THAT(
- RetryEINTR(connect)(socket1->get(),
- reinterpret_cast<sockaddr*>(&connect_addr.addr),
+ RetryEINTR(connect)(socket1->get(), AsSockAddr(&connect_addr.addr),
connect_addr.addr_len),
SyscallSucceeds());
@@ -691,13 +665,11 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfConnect) {
// Bind the first FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(
- bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(socket1->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket1->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -715,8 +687,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfConnect) {
reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port =
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
ASSERT_THAT(
- RetryEINTR(connect)(socket1->get(),
- reinterpret_cast<sockaddr*>(&connect_addr.addr),
+ RetryEINTR(connect)(socket1->get(), AsSockAddr(&connect_addr.addr),
connect_addr.addr_len),
SyscallSucceeds());
@@ -753,13 +724,11 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfNoLoop) {
// Bind the first FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(
- bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(socket1->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket1->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -778,10 +747,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfNoLoop) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -813,13 +782,11 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfNoLoop) {
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(
- bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(socket1->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket1->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -838,10 +805,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfNoLoop) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -877,20 +844,17 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropAddr) {
// IP_MULTICAST_IF for setting the default send interface.
auto sender_addr = V4Loopback();
EXPECT_THAT(
- bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
- sender_addr.addr_len),
+ bind(socket1->get(), AsSockAddr(&sender_addr.addr), sender_addr.addr_len),
SyscallSucceeds());
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- EXPECT_THAT(
- bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ EXPECT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket2->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -912,10 +876,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropAddr) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ EXPECT_THAT(
+ RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -935,20 +899,17 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropNic) {
// IP_MULTICAST_IF for setting the default send interface.
auto sender_addr = V4Loopback();
EXPECT_THAT(
- bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
- sender_addr.addr_len),
+ bind(socket1->get(), AsSockAddr(&sender_addr.addr), sender_addr.addr_len),
SyscallSucceeds());
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- EXPECT_THAT(
- bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ EXPECT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket2->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -970,10 +931,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropNic) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ EXPECT_THAT(
+ RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -1194,6 +1155,8 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetNic) {
}
TEST_P(IPv4UDPUnboundSocketTest, TestJoinGroupNoIf) {
+ // TODO(b/185517803): Fix for native test.
+ SKIP_IF(!IsRunningOnGvisor());
auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -1292,16 +1255,15 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionOnTwoSockets) {
ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
&group, sizeof(group)),
SyscallSucceeds());
- ASSERT_THAT(bind(sockets->second_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(bind(sockets->second_fd(), AsSockAddr(&receiver_addr.addr),
receiver_addr.addr_len),
SyscallSucceeds());
// Get the port assigned.
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->second_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- &receiver_addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ getsockname(sockets->second_fd(), AsSockAddr(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
// On the first iteration, save the port we are bound to. On the second
// iteration, verify the port is the same as the one from the first
@@ -1324,8 +1286,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionOnTwoSockets) {
RandomizeBuffer(send_buf, sizeof(send_buf));
ASSERT_THAT(
RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet on both sockets.
@@ -1367,16 +1328,15 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) {
ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
&group, sizeof(group)),
SyscallSucceeds());
- ASSERT_THAT(bind(sockets->second_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(bind(sockets->second_fd(), AsSockAddr(&receiver_addr.addr),
receiver_addr.addr_len),
SyscallSucceeds());
// Get the port assigned.
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->second_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- &receiver_addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ getsockname(sockets->second_fd(), AsSockAddr(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
// On the first iteration, save the port we are bound to. On the second
// iteration, verify the port is the same as the one from the first
@@ -1403,8 +1363,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) {
RandomizeBuffer(send_buf, sizeof(send_buf));
ASSERT_THAT(
RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet on both sockets.
@@ -1427,8 +1386,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) {
char send_buf[200];
ASSERT_THAT(
RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
SyscallSucceedsWithValue(sizeof(send_buf)));
char recv_buf[sizeof(send_buf)] = {};
@@ -1448,14 +1406,12 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenJoinThenReceive) {
// Bind second socket (receiver) to the multicast address.
auto receiver_addr = V4Multicast();
- ASSERT_THAT(
- bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
// Update receiver_addr with the correct port number.
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket2->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -1479,10 +1435,10 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenJoinThenReceive) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&sendto_addr.addr),
- sendto_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&sendto_addr.addr), sendto_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -1500,14 +1456,12 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenNoJoinThenNoReceive) {
// Bind second socket (receiver) to the multicast address.
auto receiver_addr = V4Multicast();
- ASSERT_THAT(
- bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
// Update receiver_addr with the correct port number.
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket2->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -1523,10 +1477,10 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenNoJoinThenNoReceive) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&sendto_addr.addr),
- sendto_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&sendto_addr.addr), sendto_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we don't receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -1543,13 +1497,11 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenSend) {
// Bind second socket (receiver) to the ANY address.
auto receiver_addr = V4Any();
- ASSERT_THAT(
- bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket2->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -1557,12 +1509,10 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenSend) {
// Bind the first socket (sender) to the multicast address.
auto sender_addr = V4Multicast();
ASSERT_THAT(
- bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
- sender_addr.addr_len),
+ bind(socket1->get(), AsSockAddr(&sender_addr.addr), sender_addr.addr_len),
SyscallSucceeds());
socklen_t sender_addr_len = sender_addr.addr_len;
- ASSERT_THAT(getsockname(socket1->get(),
- reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&sender_addr.addr),
&sender_addr_len),
SyscallSucceeds());
EXPECT_EQ(sender_addr_len, sender_addr.addr_len);
@@ -1573,10 +1523,10 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenSend) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&sendto_addr.addr),
- sendto_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&sendto_addr.addr), sendto_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -1594,13 +1544,11 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenReceive) {
// Bind second socket (receiver) to the broadcast address.
auto receiver_addr = V4Broadcast();
- ASSERT_THAT(
- bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket2->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -1611,19 +1559,18 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenReceive) {
SyscallSucceedsWithValue(0));
// Note: Binding to the loopback interface makes the broadcast go out of it.
auto sender_bind_addr = V4Loopback();
- ASSERT_THAT(
- bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_bind_addr.addr),
- sender_bind_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(socket1->get(), AsSockAddr(&sender_bind_addr.addr),
+ sender_bind_addr.addr_len),
+ SyscallSucceeds());
auto sendto_addr = V4Broadcast();
reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port =
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&sendto_addr.addr),
- sendto_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&sendto_addr.addr), sendto_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -1641,13 +1588,11 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) {
// Bind second socket (receiver) to the ANY address.
auto receiver_addr = V4Any();
- ASSERT_THAT(
- bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(socket2->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -1655,12 +1600,10 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) {
// Bind the first socket (sender) to the broadcast address.
auto sender_addr = V4Broadcast();
ASSERT_THAT(
- bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
- sender_addr.addr_len),
+ bind(socket1->get(), AsSockAddr(&sender_addr.addr), sender_addr.addr_len),
SyscallSucceeds());
socklen_t sender_addr_len = sender_addr.addr_len;
- ASSERT_THAT(getsockname(socket1->get(),
- reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&sender_addr.addr),
&sender_addr_len),
SyscallSucceeds());
EXPECT_EQ(sender_addr_len, sender_addr.addr_len);
@@ -1671,10 +1614,10 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&sendto_addr.addr),
- sendto_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&sendto_addr.addr), sendto_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -1688,7 +1631,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) {
//
// FIXME(gvisor.dev/issue/873): Endpoint order is not restored correctly. Enable
// random and co-op save (below) once that is fixed.
-TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) {
+TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution) {
std::vector<std::unique_ptr<FileDescriptor>> sockets;
sockets.emplace_back(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()));
@@ -1698,12 +1641,10 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) {
// Bind the first socket to the loopback and take note of the selected port.
auto addr = V4Loopback();
- ASSERT_THAT(bind(sockets[0]->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(sockets[0]->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
socklen_t addr_len = addr.addr_len;
- ASSERT_THAT(getsockname(sockets[0]->get(),
- reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ ASSERT_THAT(getsockname(sockets[0]->get(), AsSockAddr(&addr.addr), &addr_len),
SyscallSucceeds());
EXPECT_EQ(addr_len, addr.addr_len);
@@ -1719,8 +1660,7 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) {
ASSERT_THAT(setsockopt(last->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(last->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(last->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
// Send a new message to the SO_REUSEADDR group. We use a new socket each
@@ -1730,8 +1670,7 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) {
char send_buf[kMessageSize];
RandomizeBuffer(send_buf, sizeof(send_buf));
EXPECT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceedsWithValue(sizeof(send_buf)));
// Verify that the most recent socket got the message. We don't expect any
@@ -1763,12 +1702,10 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrThenReusePort) {
// Bind the first socket to the loopback and take note of the selected port.
auto addr = V4Loopback();
- ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket1->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
socklen_t addr_len = addr.addr_len;
- ASSERT_THAT(getsockname(socket1->get(),
- reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&addr.addr), &addr_len),
SyscallSucceeds());
EXPECT_EQ(addr_len, addr.addr_len);
@@ -1776,8 +1713,7 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrThenReusePort) {
ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket2->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
}
@@ -1792,12 +1728,10 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReusePortThenReuseAddr) {
// Bind the first socket to the loopback and take note of the selected port.
auto addr = V4Loopback();
- ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket1->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
socklen_t addr_len = addr.addr_len;
- ASSERT_THAT(getsockname(socket1->get(),
- reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&addr.addr), &addr_len),
SyscallSucceeds());
EXPECT_EQ(addr_len, addr.addr_len);
@@ -1805,8 +1739,7 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReusePortThenReuseAddr) {
ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket2->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
}
@@ -1825,12 +1758,10 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReusePort) {
// Bind the first socket to the loopback and take note of the selected port.
auto addr = V4Loopback();
- ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket1->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
socklen_t addr_len = addr.addr_len;
- ASSERT_THAT(getsockname(socket1->get(),
- reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&addr.addr), &addr_len),
SyscallSucceeds());
EXPECT_EQ(addr_len, addr.addr_len);
@@ -1838,16 +1769,14 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReusePort) {
ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket2->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
// Bind socket3 to the same address as socket1, only with REUSEADDR.
ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket3->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
}
@@ -1866,12 +1795,10 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReuseAddr) {
// Bind the first socket to the loopback and take note of the selected port.
auto addr = V4Loopback();
- ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket1->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
socklen_t addr_len = addr.addr_len;
- ASSERT_THAT(getsockname(socket1->get(),
- reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&addr.addr), &addr_len),
SyscallSucceeds());
EXPECT_EQ(addr_len, addr.addr_len);
@@ -1879,16 +1806,14 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReuseAddr) {
ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket2->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
// Bind socket3 to the same address as socket1, only with REUSEPORT.
ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket3->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
}
@@ -1907,12 +1832,10 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable1) {
// Bind the first socket to the loopback and take note of the selected port.
auto addr = V4Loopback();
- ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket1->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
socklen_t addr_len = addr.addr_len;
- ASSERT_THAT(getsockname(socket1->get(),
- reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&addr.addr), &addr_len),
SyscallSucceeds());
EXPECT_EQ(addr_len, addr.addr_len);
@@ -1920,8 +1843,7 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable1) {
ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket2->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
// Close socket2 to revert to just socket1 with REUSEADDR and REUSEPORT.
@@ -1931,8 +1853,7 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable1) {
ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket3->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
}
@@ -1951,12 +1872,10 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable2) {
// Bind the first socket to the loopback and take note of the selected port.
auto addr = V4Loopback();
- ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket1->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
socklen_t addr_len = addr.addr_len;
- ASSERT_THAT(getsockname(socket1->get(),
- reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&addr.addr), &addr_len),
SyscallSucceeds());
EXPECT_EQ(addr_len, addr.addr_len);
@@ -1964,8 +1883,7 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable2) {
ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket2->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
// Close socket2 to revert to just socket1 with REUSEADDR and REUSEPORT.
@@ -1975,8 +1893,7 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable2) {
ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket3->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
}
@@ -1995,12 +1912,10 @@ TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReusePort) {
// Bind the first socket to the loopback and take note of the selected port.
auto addr = V4Loopback();
- ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket1->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
socklen_t addr_len = addr.addr_len;
- ASSERT_THAT(getsockname(socket1->get(),
- reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&addr.addr), &addr_len),
SyscallSucceeds());
EXPECT_EQ(addr_len, addr.addr_len);
@@ -2013,16 +1928,14 @@ TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReusePort) {
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket2->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
// Bind socket3 to the same address as socket1, only with REUSEPORT.
ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket3->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
}
@@ -2041,12 +1954,10 @@ TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReuseAddr) {
// Bind the first socket to the loopback and take note of the selected port.
auto addr = V4Loopback();
- ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket1->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
socklen_t addr_len = addr.addr_len;
- ASSERT_THAT(getsockname(socket1->get(),
- reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&addr.addr), &addr_len),
SyscallSucceeds());
EXPECT_EQ(addr_len, addr.addr_len);
@@ -2059,16 +1970,14 @@ TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReuseAddr) {
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket2->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
// Bind socket3 to the same address as socket1, only with REUSEADDR.
ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(socket3->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
}
@@ -2086,12 +1995,10 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) {
// Bind the first socket to the loopback and take note of the selected port.
auto addr = V4Loopback();
- ASSERT_THAT(bind(receiver1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(receiver1->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
socklen_t addr_len = addr.addr_len;
- ASSERT_THAT(getsockname(receiver1->get(),
- reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ ASSERT_THAT(getsockname(receiver1->get(), AsSockAddr(&addr.addr), &addr_len),
SyscallSucceeds());
EXPECT_EQ(addr_len, addr.addr_len);
@@ -2103,8 +2010,7 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) {
ASSERT_THAT(setsockopt(receiver2->get(), SOL_SOCKET, SO_REUSEPORT,
&kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(bind(receiver2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(receiver2->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
constexpr int kMessageSize = 10;
@@ -2119,8 +2025,7 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) {
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
char send_buf[kMessageSize] = {};
EXPECT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceedsWithValue(sizeof(send_buf)));
}
@@ -2149,13 +2054,11 @@ TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) {
int level = SOL_IP;
int type = IP_PKTINFO;
- ASSERT_THAT(
- bind(receiver->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
- sender_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(receiver->get(), AsSockAddr(&sender_addr.addr),
+ sender_addr.addr_len),
+ SyscallSucceeds());
socklen_t sender_addr_len = sender_addr.addr_len;
- ASSERT_THAT(getsockname(receiver->get(),
- reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&sender_addr.addr),
&sender_addr_len),
SyscallSucceeds());
EXPECT_EQ(sender_addr_len, sender_addr.addr_len);
@@ -2163,10 +2066,9 @@ TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) {
auto receiver_addr = V4Loopback();
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port =
reinterpret_cast<sockaddr_in*>(&sender_addr.addr)->sin_port;
- ASSERT_THAT(
- connect(sender->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(connect(sender->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
// Allow socket to receive control message.
ASSERT_THAT(
@@ -2230,29 +2132,25 @@ TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPReceiveOrigDstAddr) {
int level = SOL_IP;
int type = IP_RECVORIGDSTADDR;
- ASSERT_THAT(
- bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
// Retrieve the port bound by the receiver.
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(receiver->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
- ASSERT_THAT(
- connect(sender->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(connect(sender->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
// Get address and port bound by the sender.
sockaddr_storage sender_addr_storage;
socklen_t sender_addr_len = sizeof(sender_addr_storage);
- ASSERT_THAT(getsockname(sender->get(),
- reinterpret_cast<sockaddr*>(&sender_addr_storage),
+ ASSERT_THAT(getsockname(sender->get(), AsSockAddr(&sender_addr_storage),
&sender_addr_len),
SyscallSucceeds());
ASSERT_EQ(sender_addr_len, sizeof(struct sockaddr_in));
@@ -2407,9 +2305,7 @@ TEST_P(IPv4UDPUnboundSocketTest, SetSocketRecvBuf) {
SyscallSucceeds());
// Linux doubles the value set by SO_SNDBUF/SO_RCVBUF.
- if (!IsRunningOnGvisor()) {
- quarter_sz *= 2;
- }
+ quarter_sz *= 2;
ASSERT_EQ(quarter_sz, val);
}
@@ -2524,22 +2420,19 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIPPacketInfo) {
// Bind the first FD to the loopback. This is an alternative to
// IP_MULTICAST_IF for setting the default send interface.
auto sender_addr = V4Loopback();
- ASSERT_THAT(
- bind(sender_socket->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
- sender_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(sender_socket->get(), AsSockAddr(&sender_addr.addr),
+ sender_addr.addr_len),
+ SyscallSucceeds());
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(bind(receiver_socket->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(bind(receiver_socket->get(), AsSockAddr(&receiver_addr.addr),
receiver_addr.addr_len),
SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
ASSERT_THAT(getsockname(receiver_socket->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- &receiver_addr_len),
+ AsSockAddr(&receiver_addr.addr), &receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -2565,8 +2458,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIPPacketInfo) {
RandomizeBuffer(send_buf, sizeof(send_buf));
ASSERT_THAT(
RetryEINTR(sendto)(sender_socket->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
index 940289d15..c6e775b2a 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
@@ -50,38 +50,35 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Bind the first socket to the ANY address and let the system assign a port.
auto rcv1_addr = V4Any();
- ASSERT_THAT(bind(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr),
- rcv1_addr.addr_len),
- SyscallSucceedsWithValue(0));
+ ASSERT_THAT(
+ bind(rcvr1->get(), AsSockAddr(&rcv1_addr.addr), rcv1_addr.addr_len),
+ SyscallSucceedsWithValue(0));
// Retrieve port number from first socket so that it can be bound to the
// second socket.
socklen_t rcv_addr_sz = rcv1_addr.addr_len;
ASSERT_THAT(
- getsockname(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr),
- &rcv_addr_sz),
+ getsockname(rcvr1->get(), AsSockAddr(&rcv1_addr.addr), &rcv_addr_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(rcv_addr_sz, rcv1_addr.addr_len);
auto port = reinterpret_cast<sockaddr_in*>(&rcv1_addr.addr)->sin_port;
// Bind the second socket to the same address:port as the first.
- ASSERT_THAT(bind(rcvr2->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr),
- rcv_addr_sz),
+ ASSERT_THAT(bind(rcvr2->get(), AsSockAddr(&rcv1_addr.addr), rcv_addr_sz),
SyscallSucceedsWithValue(0));
// Bind the non-receiving socket to an ephemeral port.
auto norecv_addr = V4Any();
- ASSERT_THAT(bind(norcv->get(), reinterpret_cast<sockaddr*>(&norecv_addr.addr),
- norecv_addr.addr_len),
- SyscallSucceedsWithValue(0));
+ ASSERT_THAT(
+ bind(norcv->get(), AsSockAddr(&norecv_addr.addr), norecv_addr.addr_len),
+ SyscallSucceedsWithValue(0));
// Broadcast a test message.
auto dst_addr = V4Broadcast();
reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = port;
constexpr char kTestMsg[] = "hello, world";
- EXPECT_THAT(
- sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
- reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(kTestMsg)));
+ EXPECT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
+ AsSockAddr(&dst_addr.addr), dst_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
// Verify that the receiving sockets received the test message.
char buf[sizeof(kTestMsg)] = {};
@@ -130,15 +127,14 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Bind the first socket the ANY address and let the system assign a port.
auto rcv1_addr = V4Any();
- ASSERT_THAT(bind(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr),
- rcv1_addr.addr_len),
- SyscallSucceedsWithValue(0));
+ ASSERT_THAT(
+ bind(rcvr1->get(), AsSockAddr(&rcv1_addr.addr), rcv1_addr.addr_len),
+ SyscallSucceedsWithValue(0));
// Retrieve port number from first socket so that it can be bound to the
// second socket.
socklen_t rcv_addr_sz = rcv1_addr.addr_len;
ASSERT_THAT(
- getsockname(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr),
- &rcv_addr_sz),
+ getsockname(rcvr1->get(), AsSockAddr(&rcv1_addr.addr), &rcv_addr_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(rcv_addr_sz, rcv1_addr.addr_len);
auto port = reinterpret_cast<sockaddr_in*>(&rcv1_addr.addr)->sin_port;
@@ -146,26 +142,25 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Bind the second socket to the broadcast address.
auto rcv2_addr = V4Broadcast();
reinterpret_cast<sockaddr_in*>(&rcv2_addr.addr)->sin_port = port;
- ASSERT_THAT(bind(rcvr2->get(), reinterpret_cast<sockaddr*>(&rcv2_addr.addr),
- rcv2_addr.addr_len),
- SyscallSucceedsWithValue(0));
+ ASSERT_THAT(
+ bind(rcvr2->get(), AsSockAddr(&rcv2_addr.addr), rcv2_addr.addr_len),
+ SyscallSucceedsWithValue(0));
// Bind the non-receiving socket to the unicast ethernet address.
auto norecv_addr = rcv1_addr;
reinterpret_cast<sockaddr_in*>(&norecv_addr.addr)->sin_addr =
eth_if_addr_.sin_addr;
- ASSERT_THAT(bind(norcv->get(), reinterpret_cast<sockaddr*>(&norecv_addr.addr),
- norecv_addr.addr_len),
- SyscallSucceedsWithValue(0));
+ ASSERT_THAT(
+ bind(norcv->get(), AsSockAddr(&norecv_addr.addr), norecv_addr.addr_len),
+ SyscallSucceedsWithValue(0));
// Broadcast a test message.
auto dst_addr = V4Broadcast();
reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = port;
constexpr char kTestMsg[] = "hello, world";
- EXPECT_THAT(
- sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
- reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(kTestMsg)));
+ EXPECT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
+ AsSockAddr(&dst_addr.addr), dst_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
// Verify that the receiving sockets received the test message.
char buf[sizeof(kTestMsg)] = {};
@@ -199,12 +194,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Bind the sender to the broadcast address.
auto src_addr = V4Broadcast();
- ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(&src_addr.addr),
- src_addr.addr_len),
- SyscallSucceedsWithValue(0));
+ ASSERT_THAT(
+ bind(sender->get(), AsSockAddr(&src_addr.addr), src_addr.addr_len),
+ SyscallSucceedsWithValue(0));
socklen_t src_sz = src_addr.addr_len;
- ASSERT_THAT(getsockname(sender->get(),
- reinterpret_cast<sockaddr*>(&src_addr.addr), &src_sz),
+ ASSERT_THAT(getsockname(sender->get(), AsSockAddr(&src_addr.addr), &src_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(src_sz, src_addr.addr_len);
@@ -213,10 +207,9 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port =
reinterpret_cast<sockaddr_in*>(&src_addr.addr)->sin_port;
constexpr char kTestMsg[] = "hello, world";
- EXPECT_THAT(
- sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
- reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(kTestMsg)));
+ EXPECT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
+ AsSockAddr(&dst_addr.addr), dst_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
// Verify that the message was received.
char buf[sizeof(kTestMsg)] = {};
@@ -241,12 +234,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Bind the sender to the ANY address.
auto src_addr = V4Any();
- ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(&src_addr.addr),
- src_addr.addr_len),
- SyscallSucceedsWithValue(0));
+ ASSERT_THAT(
+ bind(sender->get(), AsSockAddr(&src_addr.addr), src_addr.addr_len),
+ SyscallSucceedsWithValue(0));
socklen_t src_sz = src_addr.addr_len;
- ASSERT_THAT(getsockname(sender->get(),
- reinterpret_cast<sockaddr*>(&src_addr.addr), &src_sz),
+ ASSERT_THAT(getsockname(sender->get(), AsSockAddr(&src_addr.addr), &src_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(src_sz, src_addr.addr_len);
@@ -255,10 +247,9 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port =
reinterpret_cast<sockaddr_in*>(&src_addr.addr)->sin_port;
constexpr char kTestMsg[] = "hello, world";
- EXPECT_THAT(
- sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
- reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(kTestMsg)));
+ EXPECT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
+ AsSockAddr(&dst_addr.addr), dst_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(kTestMsg)));
// Verify that the message was received.
char buf[sizeof(kTestMsg)] = {};
@@ -280,7 +271,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendBroadcast) {
constexpr char kTestMsg[] = "hello, world";
EXPECT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
- reinterpret_cast<sockaddr*>(&addr.addr), addr.addr_len),
+ AsSockAddr(&addr.addr), addr.addr_len),
SyscallFailsWithErrno(EACCES));
}
@@ -294,19 +285,17 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendUnicastOnUnbound) {
addr.sin_family = AF_INET;
addr.sin_addr.s_addr = htonl(INADDR_ANY);
addr.sin_port = htons(0);
- ASSERT_THAT(bind(rcvr->get(), reinterpret_cast<struct sockaddr*>(&addr),
- sizeof(addr)),
+ ASSERT_THAT(bind(rcvr->get(), AsSockAddr(&addr), sizeof(addr)),
SyscallSucceedsWithValue(0));
memset(&addr, 0, sizeof(addr));
socklen_t addr_sz = sizeof(addr);
- ASSERT_THAT(getsockname(rcvr->get(),
- reinterpret_cast<struct sockaddr*>(&addr), &addr_sz),
+ ASSERT_THAT(getsockname(rcvr->get(), AsSockAddr(&addr), &addr_sz),
SyscallSucceedsWithValue(0));
// Send a test message to the receiver.
constexpr char kTestMsg[] = "hello, world";
ASSERT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0,
- reinterpret_cast<struct sockaddr*>(&addr), addr_sz),
+ AsSockAddr(&addr), addr_sz),
SyscallSucceedsWithValue(sizeof(kTestMsg)));
char buf[sizeof(kTestMsg)] = {};
ASSERT_THAT(recv(rcvr->get(), buf, sizeof(buf), 0),
@@ -326,13 +315,12 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto bind_addr = V4Any();
- ASSERT_THAT(bind(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr),
- bind_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket->get(), AsSockAddr(&bind_addr.addr), bind_addr.addr_len),
+ SyscallSucceeds());
socklen_t bind_addr_len = bind_addr.addr_len;
ASSERT_THAT(
- getsockname(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr),
- &bind_addr_len),
+ getsockname(socket->get(), AsSockAddr(&bind_addr.addr), &bind_addr_len),
SyscallSucceeds());
EXPECT_EQ(bind_addr_len, bind_addr.addr_len);
@@ -342,10 +330,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
reinterpret_cast<sockaddr_in*>(&bind_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(socket->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(socket->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -361,13 +349,12 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSelf) {
auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto bind_addr = V4Any();
- ASSERT_THAT(bind(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr),
- bind_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket->get(), AsSockAddr(&bind_addr.addr), bind_addr.addr_len),
+ SyscallSucceeds());
socklen_t bind_addr_len = bind_addr.addr_len;
ASSERT_THAT(
- getsockname(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr),
- &bind_addr_len),
+ getsockname(socket->get(), AsSockAddr(&bind_addr.addr), &bind_addr_len),
SyscallSucceeds());
EXPECT_EQ(bind_addr_len, bind_addr.addr_len);
@@ -384,10 +371,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSelf) {
reinterpret_cast<sockaddr_in*>(&bind_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(socket->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(socket->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -405,13 +392,12 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto bind_addr = V4Any();
- ASSERT_THAT(bind(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr),
- bind_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket->get(), AsSockAddr(&bind_addr.addr), bind_addr.addr_len),
+ SyscallSucceeds());
socklen_t bind_addr_len = bind_addr.addr_len;
ASSERT_THAT(
- getsockname(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr),
- &bind_addr_len),
+ getsockname(socket->get(), AsSockAddr(&bind_addr.addr), &bind_addr_len),
SyscallSucceeds());
EXPECT_EQ(bind_addr_len, bind_addr.addr_len);
@@ -433,10 +419,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
reinterpret_cast<sockaddr_in*>(&bind_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(socket->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(socket->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -460,13 +446,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastNoGroup) {
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(
- bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(receiver->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -477,10 +461,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastNoGroup) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -499,13 +483,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticast) {
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(
- bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(receiver->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -523,10 +505,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticast) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -547,13 +529,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(
- bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(receiver->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -576,10 +556,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -600,13 +580,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(
- bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(receiver->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -629,10 +607,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -661,13 +639,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
&kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
// Bind to ANY to receive multicast packets.
- ASSERT_THAT(
- bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(receiver->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -696,10 +672,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
for (auto& receiver : receivers) {
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
@@ -727,13 +703,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT,
&kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(
- bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(receiver->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -765,10 +739,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
for (auto& receiver : receivers) {
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
@@ -798,13 +772,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT,
&kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
- ASSERT_THAT(
- bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(receiver->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -840,10 +812,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
for (auto& receiver : receivers) {
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
@@ -863,13 +835,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto receiver_addr = V4Any();
- ASSERT_THAT(
- bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(receiver->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -887,15 +857,13 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// receiver side).
auto sendto_addr = V4Multicast();
reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = receiver_port;
- ASSERT_THAT(RetryEINTR(connect)(
- sender->get(), reinterpret_cast<sockaddr*>(&sendto_addr.addr),
- sendto_addr.addr_len),
+ ASSERT_THAT(RetryEINTR(connect)(sender->get(), AsSockAddr(&sendto_addr.addr),
+ sendto_addr.addr_len),
SyscallSucceeds());
auto sender_addr = V4EmptyAddress();
- ASSERT_THAT(
- getsockname(sender->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
- &sender_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(getsockname(sender->get(), AsSockAddr(&sender_addr.addr),
+ &sender_addr.addr_len),
+ SyscallSucceeds());
ASSERT_EQ(sizeof(struct sockaddr_in), sender_addr.addr_len);
sockaddr_in* sender_addr_in =
reinterpret_cast<sockaddr_in*>(&sender_addr.addr);
@@ -910,8 +878,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
auto src_addr = V4EmptyAddress();
ASSERT_THAT(
RetryEINTR(recvfrom)(receiver->get(), recv_buf, sizeof(recv_buf), 0,
- reinterpret_cast<sockaddr*>(&src_addr.addr),
- &src_addr.addr_len),
+ AsSockAddr(&src_addr.addr), &src_addr.addr_len),
SyscallSucceedsWithValue(sizeof(recv_buf)));
ASSERT_EQ(sizeof(struct sockaddr_in), src_addr.addr_len);
sockaddr_in* src_addr_in = reinterpret_cast<sockaddr_in*>(&src_addr.addr);
@@ -931,13 +898,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Create receiver, bind to ANY and join the multicast group.
auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto receiver_addr = V4Any();
- ASSERT_THAT(
- bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(receiver->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -964,18 +929,17 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
auto sendto_addr = V4Multicast();
reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = receiver_port;
char send_buf[4] = {};
- ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&sendto_addr.addr),
- sendto_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&sendto_addr.addr), sendto_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Receive a multicast packet.
char recv_buf[sizeof(send_buf)] = {};
auto src_addr = V4EmptyAddress();
ASSERT_THAT(
RetryEINTR(recvfrom)(receiver->get(), recv_buf, sizeof(recv_buf), 0,
- reinterpret_cast<sockaddr*>(&src_addr.addr),
- &src_addr.addr_len),
+ AsSockAddr(&src_addr.addr), &src_addr.addr_len),
SyscallSucceedsWithValue(sizeof(recv_buf)));
ASSERT_EQ(sizeof(struct sockaddr_in), src_addr.addr_len);
sockaddr_in* src_addr_in = reinterpret_cast<sockaddr_in*>(&src_addr.addr);
@@ -1000,9 +964,9 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Create sender and bind to eth interface.
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(&eth_if_addr_),
- sizeof(eth_if_addr_)),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(sender->get(), AsSockAddr(&eth_if_addr_), sizeof(eth_if_addr_)),
+ SyscallSucceeds());
// Run through all possible combinations of index and address for
// IP_MULTICAST_IF that selects the loopback interface.
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_nogotsan.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_nogotsan.cc
index bcbd2feac..7ca6d52e4 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_nogotsan.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_nogotsan.cc
@@ -29,18 +29,15 @@ using IPv4UDPUnboundSocketNogotsanTest = SimpleSocketTest;
// Check that connect returns EAGAIN when out of local ephemeral ports.
// We disable S/R because this test creates a large number of sockets.
-TEST_P(IPv4UDPUnboundSocketNogotsanTest,
- UDPConnectPortExhaustion_NoRandomSave) {
+TEST_P(IPv4UDPUnboundSocketNogotsanTest, UDPConnectPortExhaustion) {
auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
constexpr int kClients = 65536;
// Bind the first socket to the loopback and take note of the selected port.
auto addr = V4Loopback();
- ASSERT_THAT(bind(receiver1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len),
+ ASSERT_THAT(bind(receiver1->get(), AsSockAddr(&addr.addr), addr.addr_len),
SyscallSucceeds());
socklen_t addr_len = addr.addr_len;
- ASSERT_THAT(getsockname(receiver1->get(),
- reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ ASSERT_THAT(getsockname(receiver1->get(), AsSockAddr(&addr.addr), &addr_len),
SyscallSucceeds());
EXPECT_EQ(addr_len, addr.addr_len);
@@ -50,8 +47,7 @@ TEST_P(IPv4UDPUnboundSocketNogotsanTest,
for (int i = 0; i < kClients; i++) {
auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- int ret = connect(s->get(), reinterpret_cast<sockaddr*>(&addr.addr),
- addr.addr_len);
+ int ret = connect(s->get(), AsSockAddr(&addr.addr), addr.addr_len);
if (ret == 0) {
sockets.push_back(std::move(s));
continue;
@@ -63,7 +59,7 @@ TEST_P(IPv4UDPUnboundSocketNogotsanTest,
// Check that bind returns EADDRINUSE when out of local ephemeral ports.
// We disable S/R because this test creates a large number of sockets.
-TEST_P(IPv4UDPUnboundSocketNogotsanTest, UDPBindPortExhaustion_NoRandomSave) {
+TEST_P(IPv4UDPUnboundSocketNogotsanTest, UDPBindPortExhaustion) {
auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
constexpr int kClients = 65536;
auto addr = V4Loopback();
@@ -73,8 +69,7 @@ TEST_P(IPv4UDPUnboundSocketNogotsanTest, UDPBindPortExhaustion_NoRandomSave) {
for (int i = 0; i < kClients; i++) {
auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- int ret =
- bind(s->get(), reinterpret_cast<sockaddr*>(&addr.addr), addr.addr_len);
+ int ret = bind(s->get(), AsSockAddr(&addr.addr), addr.addr_len);
if (ret == 0) {
sockets.push_back(std::move(s));
continue;
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc
index 9a9ddc297..020ce5d6e 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc
@@ -56,10 +56,9 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, JoinSubnet) {
ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.2",
&(reinterpret_cast<sockaddr_in*>(&sender_addr.addr)
->sin_addr.s_addr)));
- ASSERT_THAT(
- bind(snd_sock->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
- sender_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(snd_sock->get(), AsSockAddr(&sender_addr.addr),
+ sender_addr.addr_len),
+ SyscallSucceeds());
// Send the packet to an unassigned address but an address that is in the
// subnet associated with the loopback interface.
@@ -69,23 +68,20 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, JoinSubnet) {
ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.254",
&(reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)
->sin_addr.s_addr)));
- ASSERT_THAT(
- bind(rcv_sock->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(rcv_sock->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(rcv_sock->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(rcv_sock->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
ASSERT_EQ(receiver_addr_len, receiver_addr.addr_len);
char send_buf[kSendBufSize];
RandomizeBuffer(send_buf, kSendBufSize);
- ASSERT_THAT(
- RetryEINTR(sendto)(snd_sock->get(), send_buf, kSendBufSize, 0,
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceedsWithValue(kSendBufSize));
+ ASSERT_THAT(RetryEINTR(sendto)(snd_sock->get(), send_buf, kSendBufSize, 0,
+ AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceedsWithValue(kSendBufSize));
// Check that we received the packet.
char recv_buf[kSendBufSize] = {};
@@ -155,14 +151,12 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, ReuseAddrSubnetDirectedBroadcast) {
<< "socks[" << idx << "]";
if (bind_wildcard) {
- ASSERT_THAT(
- bind(sock->get(), reinterpret_cast<sockaddr*>(&any_address.addr),
- any_address.addr_len),
- SyscallSucceeds())
+ ASSERT_THAT(bind(sock->get(), AsSockAddr(&any_address.addr),
+ any_address.addr_len),
+ SyscallSucceeds())
<< "socks[" << idx << "]";
} else {
- ASSERT_THAT(bind(sock->get(),
- reinterpret_cast<sockaddr*>(&broadcast_address.addr),
+ ASSERT_THAT(bind(sock->get(), AsSockAddr(&broadcast_address.addr),
broadcast_address.addr_len),
SyscallSucceeds())
<< "socks[" << idx << "]";
@@ -177,17 +171,16 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, ReuseAddrSubnetDirectedBroadcast) {
// Broadcasts from each socket should be received by every socket (including
// the sending socket).
- for (long unsigned int w = 0; w < socks.size(); w++) {
+ for (size_t w = 0; w < socks.size(); w++) {
auto& w_sock = socks[w];
- ASSERT_THAT(
- RetryEINTR(sendto)(w_sock->get(), send_buf, kSendBufSize, 0,
- reinterpret_cast<sockaddr*>(&broadcast_address.addr),
- broadcast_address.addr_len),
- SyscallSucceedsWithValue(kSendBufSize))
+ ASSERT_THAT(RetryEINTR(sendto)(w_sock->get(), send_buf, kSendBufSize, 0,
+ AsSockAddr(&broadcast_address.addr),
+ broadcast_address.addr_len),
+ SyscallSucceedsWithValue(kSendBufSize))
<< "write socks[" << w << "]";
// Check that we received the packet on all sockets.
- for (long unsigned int r = 0; r < socks.size(); r++) {
+ for (size_t r = 0; r < socks.size(); r++) {
auto& r_sock = socks[r];
struct pollfd poll_fd = {r_sock->get(), POLLIN, 0};
diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound.cc b/test/syscalls/linux/socket_ipv6_udp_unbound.cc
index 08526468e..a4e3371f4 100644
--- a/test/syscalls/linux/socket_ipv6_udp_unbound.cc
+++ b/test/syscalls/linux/socket_ipv6_udp_unbound.cc
@@ -47,29 +47,25 @@ TEST_P(IPv6UDPUnboundSocketTest, SetAndReceiveIPReceiveOrigDstAddr) {
int level = SOL_IPV6;
int type = IPV6_RECVORIGDSTADDR;
- ASSERT_THAT(
- bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
// Retrieve the port bound by the receiver.
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(receiver->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
- ASSERT_THAT(
- connect(sender->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(connect(sender->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
// Get address and port bound by the sender.
sockaddr_storage sender_addr_storage;
socklen_t sender_addr_len = sizeof(sender_addr_storage);
- ASSERT_THAT(getsockname(sender->get(),
- reinterpret_cast<sockaddr*>(&sender_addr_storage),
+ ASSERT_THAT(getsockname(sender->get(), AsSockAddr(&sender_addr_storage),
&sender_addr_len),
SyscallSucceeds());
ASSERT_EQ(sender_addr_len, sizeof(struct sockaddr_in6));
diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc
index 7364a1ea5..8390f7c3b 100644
--- a/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc
+++ b/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc
@@ -24,13 +24,11 @@ TEST_P(IPv6UDPUnboundExternalNetworkingSocketTest, TestJoinLeaveMulticast) {
auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto receiver_addr = V6Any();
- ASSERT_THAT(
- bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(receiver->get(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
@@ -50,8 +48,7 @@ TEST_P(IPv6UDPUnboundExternalNetworkingSocketTest, TestJoinLeaveMulticast) {
// Set the sender to the loopback interface.
auto sender_addr = V6Loopback();
ASSERT_THAT(
- bind(sender->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
- sender_addr.addr_len),
+ bind(sender->get(), AsSockAddr(&sender_addr.addr), sender_addr.addr_len),
SyscallSucceeds());
// Send a multicast packet.
@@ -60,10 +57,10 @@ TEST_P(IPv6UDPUnboundExternalNetworkingSocketTest, TestJoinLeaveMulticast) {
reinterpret_cast<sockaddr_in6*>(&receiver_addr.addr)->sin6_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
@@ -77,10 +74,10 @@ TEST_P(IPv6UDPUnboundExternalNetworkingSocketTest, TestJoinLeaveMulticast) {
&group_req, sizeof(group_req)),
SyscallSucceeds());
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&send_addr.addr), send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
ASSERT_THAT(RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf),
MSG_DONTWAIT),
SyscallFailsWithErrno(EAGAIN));
diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc b/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc
index 2ee218231..48aace78a 100644
--- a/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc
+++ b/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc
@@ -44,9 +44,9 @@ TEST_P(IPv6UDPUnboundSocketNetlinkTest, JoinSubnet) {
reinterpret_cast<sockaddr_in6*>(&sender_addr.addr)
->sin6_addr.s6_addr));
auto sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- EXPECT_THAT(bind(sock->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
- sender_addr.addr_len),
- SyscallFailsWithErrno(EADDRNOTAVAIL));
+ EXPECT_THAT(
+ bind(sock->get(), AsSockAddr(&sender_addr.addr), sender_addr.addr_len),
+ SyscallFailsWithErrno(EADDRNOTAVAIL));
}
} // namespace testing
diff --git a/test/syscalls/linux/socket_stream_blocking.cc b/test/syscalls/linux/socket_stream_blocking.cc
index 538ee2268..0743322ac 100644
--- a/test/syscalls/linux/socket_stream_blocking.cc
+++ b/test/syscalls/linux/socket_stream_blocking.cc
@@ -68,7 +68,7 @@ TEST_P(BlockingStreamSocketPairTest, BlockPartialWriteClosed) {
// Random save may interrupt the call to sendmsg() in SendLargeSendMsg(),
// causing the write to be incomplete and the test to hang.
-TEST_P(BlockingStreamSocketPairTest, SendMsgTooLarge_NoRandomSave) {
+TEST_P(BlockingStreamSocketPairTest, SendMsgTooLarge) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
int sndbuf;
@@ -102,7 +102,7 @@ TEST_P(BlockingStreamSocketPairTest, RecvLessThanBuffer) {
// Test that MSG_WAITALL causes recv to block until all requested data is
// received. Random save can interrupt blocking and cause received data to be
// returned, even if the amount received is less than the full requested amount.
-TEST_P(BlockingStreamSocketPairTest, RecvLessThanBufferWaitAll_NoRandomSave) {
+TEST_P(BlockingStreamSocketPairTest, RecvLessThanBufferWaitAll) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
char sent_data[100];
diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc
index b2a96086c..9e3a129cf 100644
--- a/test/syscalls/linux/socket_test_util.cc
+++ b/test/syscalls/linux/socket_test_util.cc
@@ -82,8 +82,7 @@ Creator<SocketPair> AcceptBindSocketPairCreator(bool abstract, int domain,
RETURN_ERROR_IF_SYSCALL_FAIL(bound = socket(domain, type, protocol));
MaybeSave(); // Successful socket creation.
RETURN_ERROR_IF_SYSCALL_FAIL(
- bind(bound, reinterpret_cast<struct sockaddr*>(&bind_addr),
- sizeof(bind_addr)));
+ bind(bound, AsSockAddr(&bind_addr), sizeof(bind_addr)));
MaybeSave(); // Successful bind.
RETURN_ERROR_IF_SYSCALL_FAIL(listen(bound, /* backlog = */ 5));
MaybeSave(); // Successful listen.
@@ -92,8 +91,7 @@ Creator<SocketPair> AcceptBindSocketPairCreator(bool abstract, int domain,
RETURN_ERROR_IF_SYSCALL_FAIL(connected = socket(domain, type, protocol));
MaybeSave(); // Successful socket creation.
RETURN_ERROR_IF_SYSCALL_FAIL(
- connect(connected, reinterpret_cast<struct sockaddr*>(&bind_addr),
- sizeof(bind_addr)));
+ connect(connected, AsSockAddr(&bind_addr), sizeof(bind_addr)));
MaybeSave(); // Successful connect.
int accepted;
@@ -145,22 +143,22 @@ Creator<SocketPair> BidirectionalBindSocketPairCreator(bool abstract,
RETURN_ERROR_IF_SYSCALL_FAIL(sock1 = socket(domain, type, protocol));
MaybeSave(); // Successful socket creation.
RETURN_ERROR_IF_SYSCALL_FAIL(
- bind(sock1, reinterpret_cast<struct sockaddr*>(&addr1), sizeof(addr1)));
+ bind(sock1, AsSockAddr(&addr1), sizeof(addr1)));
MaybeSave(); // Successful bind.
int sock2;
RETURN_ERROR_IF_SYSCALL_FAIL(sock2 = socket(domain, type, protocol));
MaybeSave(); // Successful socket creation.
RETURN_ERROR_IF_SYSCALL_FAIL(
- bind(sock2, reinterpret_cast<struct sockaddr*>(&addr2), sizeof(addr2)));
+ bind(sock2, AsSockAddr(&addr2), sizeof(addr2)));
MaybeSave(); // Successful bind.
- RETURN_ERROR_IF_SYSCALL_FAIL(connect(
- sock1, reinterpret_cast<struct sockaddr*>(&addr2), sizeof(addr2)));
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ connect(sock1, AsSockAddr(&addr2), sizeof(addr2)));
MaybeSave(); // Successful connect.
- RETURN_ERROR_IF_SYSCALL_FAIL(connect(
- sock2, reinterpret_cast<struct sockaddr*>(&addr1), sizeof(addr1)));
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ connect(sock2, AsSockAddr(&addr1), sizeof(addr1)));
MaybeSave(); // Successful connect.
// Cleanup no longer needed resources.
@@ -206,15 +204,15 @@ Creator<SocketPair> SocketpairGoferSocketPairCreator(int domain, int type,
int sock1;
RETURN_ERROR_IF_SYSCALL_FAIL(sock1 = socket(domain, type, protocol));
MaybeSave(); // Successful socket creation.
- RETURN_ERROR_IF_SYSCALL_FAIL(connect(
- sock1, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)));
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ connect(sock1, AsSockAddr(&addr), sizeof(addr)));
MaybeSave(); // Successful connect.
int sock2;
RETURN_ERROR_IF_SYSCALL_FAIL(sock2 = socket(domain, type, protocol));
MaybeSave(); // Successful socket creation.
- RETURN_ERROR_IF_SYSCALL_FAIL(connect(
- sock2, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)));
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ connect(sock2, AsSockAddr(&addr), sizeof(addr)));
MaybeSave(); // Successful connect.
// Make and close another socketpair to ensure that the duped ends of the
@@ -228,8 +226,8 @@ Creator<SocketPair> SocketpairGoferSocketPairCreator(int domain, int type,
for (int i = 0; i < 2; i++) {
int sock;
RETURN_ERROR_IF_SYSCALL_FAIL(sock = socket(domain, type, protocol));
- RETURN_ERROR_IF_SYSCALL_FAIL(connect(
- sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)));
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ connect(sock, AsSockAddr(&addr), sizeof(addr)));
RETURN_ERROR_IF_SYSCALL_FAIL(close(sock));
}
@@ -308,11 +306,9 @@ template <typename T>
PosixErrorOr<T> BindIP(int fd, bool dual_stack) {
T addr = {};
LocalhostAddr(&addr, dual_stack);
- RETURN_ERROR_IF_SYSCALL_FAIL(
- bind(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)));
+ RETURN_ERROR_IF_SYSCALL_FAIL(bind(fd, AsSockAddr(&addr), sizeof(addr)));
socklen_t addrlen = sizeof(addr);
- RETURN_ERROR_IF_SYSCALL_FAIL(
- getsockname(fd, reinterpret_cast<struct sockaddr*>(&addr), &addrlen));
+ RETURN_ERROR_IF_SYSCALL_FAIL(getsockname(fd, AsSockAddr(&addr), &addrlen));
return addr;
}
@@ -329,9 +325,8 @@ CreateTCPConnectAcceptSocketPair(int bound, int connected, int type,
bool dual_stack, T bind_addr) {
int connect_result = 0;
RETURN_ERROR_IF_SYSCALL_FAIL(
- (connect_result = RetryEINTR(connect)(
- connected, reinterpret_cast<struct sockaddr*>(&bind_addr),
- sizeof(bind_addr))) == -1 &&
+ (connect_result = RetryEINTR(connect)(connected, AsSockAddr(&bind_addr),
+ sizeof(bind_addr))) == -1 &&
errno == EINPROGRESS
? 0
: connect_result);
@@ -703,7 +698,7 @@ PosixErrorOr<int> TryPortAvailable(int port, AddressFamily family,
}
RETURN_ERROR_IF_SYSCALL_FAIL(
- bind(fd.get(), reinterpret_cast<sockaddr*>(&storage), storage_size));
+ bind(fd.get(), AsSockAddr(&storage), storage_size));
// If the user specified 0 as the port, we will return the port that the
// kernel gave us, otherwise we will validate that this socket bound to the
@@ -711,8 +706,7 @@ PosixErrorOr<int> TryPortAvailable(int port, AddressFamily family,
sockaddr_storage bound_storage = {};
socklen_t bound_storage_size = sizeof(bound_storage);
RETURN_ERROR_IF_SYSCALL_FAIL(
- getsockname(fd.get(), reinterpret_cast<sockaddr*>(&bound_storage),
- &bound_storage_size));
+ getsockname(fd.get(), AsSockAddr(&bound_storage), &bound_storage_size));
int available_port = -1;
if (bound_storage.ss_family == AF_INET) {
diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h
index b3ab286b8..f7ba90130 100644
--- a/test/syscalls/linux/socket_test_util.h
+++ b/test/syscalls/linux/socket_test_util.h
@@ -520,6 +520,20 @@ uint16_t UDPChecksum(struct iphdr iphdr, struct udphdr udphdr,
uint16_t ICMPChecksum(struct icmphdr icmphdr, const char* payload,
ssize_t payload_len);
+// Convenient functions for reinterpreting common types to sockaddr pointer.
+inline sockaddr* AsSockAddr(sockaddr_storage* s) {
+ return reinterpret_cast<sockaddr*>(s);
+}
+inline sockaddr* AsSockAddr(sockaddr_in* s) {
+ return reinterpret_cast<sockaddr*>(s);
+}
+inline sockaddr* AsSockAddr(sockaddr_in6* s) {
+ return reinterpret_cast<sockaddr*>(s);
+}
+inline sockaddr* AsSockAddr(sockaddr_un* s) {
+ return reinterpret_cast<sockaddr*>(s);
+}
+
namespace internal {
PosixErrorOr<int> TryPortAvailable(int port, AddressFamily family,
SocketType type, bool reuse_addr);
diff --git a/test/syscalls/linux/socket_unix_non_stream.cc b/test/syscalls/linux/socket_unix_non_stream.cc
index 884319e1d..9425e87a6 100644
--- a/test/syscalls/linux/socket_unix_non_stream.cc
+++ b/test/syscalls/linux/socket_unix_non_stream.cc
@@ -239,7 +239,7 @@ TEST_P(UnixNonStreamSocketPairTest, SendTimeout) {
SyscallSucceeds());
// The buffer size should be big enough to avoid many iterations in the next
- // loop. Otherwise, this will slow down cooperative_save tests.
+ // loop. Otherwise, this will slow down save tests.
std::vector<char> buf(kPageSize);
for (;;) {
int ret;
diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc
index e5730a606..c85f6da0b 100644
--- a/test/syscalls/linux/splice.cc
+++ b/test/syscalls/linux/splice.cc
@@ -883,7 +883,7 @@ TEST(SpliceTest, FromPipeToDevZero) {
static volatile int signaled = 0;
void SigUsr1Handler(int sig, siginfo_t* info, void* context) { signaled = 1; }
-TEST(SpliceTest, ToPipeWithSmallCapacityDoesNotSpin_NoRandomSave) {
+TEST(SpliceTest, ToPipeWithSmallCapacityDoesNotSpin) {
// Writes to a pipe that are less than PIPE_BUF must be atomic. This test
// creates a pipe with only 128 bytes of capacity (< PIPE_BUF) and checks that
// splicing to the pipe does not spin. See b/170743336.
diff --git a/test/syscalls/linux/symlink.cc b/test/syscalls/linux/symlink.cc
index ea219a091..9f6c59446 100644
--- a/test/syscalls/linux/symlink.cc
+++ b/test/syscalls/linux/symlink.cc
@@ -248,7 +248,7 @@ TEST(SymlinkTest, PwriteToSymlink) {
EXPECT_THAT(unlink(linkname.c_str()), SyscallSucceeds());
}
-TEST(SymlinkTest, SymlinkAtDegradedPermissions_NoRandomSave) {
+TEST(SymlinkTest, SymlinkAtDegradedPermissions) {
// Drop capabilities that allow us to override file and directory permissions.
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
@@ -299,7 +299,7 @@ TEST(SymlinkTest, ReadlinkAtDirWithOpath) {
EXPECT_EQ(0, strncmp("/dangling", buf.data(), linksize));
}
-TEST(SymlinkTest, ReadlinkAtDegradedPermissions_NoRandomSave) {
+TEST(SymlinkTest, ReadlinkAtDegradedPermissions) {
// Drop capabilities that allow us to override file and directory permissions.
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index 7341cf1a6..011b60f0e 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -139,20 +139,16 @@ void TcpSocketTest::SetUp() {
socklen_t addrlen = sizeof(addr);
// Bind to some port then start listening.
- ASSERT_THAT(
- bind(listener_, reinterpret_cast<struct sockaddr*>(&addr), addrlen),
- SyscallSucceeds());
+ ASSERT_THAT(bind(listener_, AsSockAddr(&addr), addrlen), SyscallSucceeds());
ASSERT_THAT(listen(listener_, SOMAXCONN), SyscallSucceeds());
// Get the address we're listening on, then connect to it. We need to do this
// because we're allowing the stack to pick a port for us.
- ASSERT_THAT(getsockname(listener_, reinterpret_cast<struct sockaddr*>(&addr),
- &addrlen),
+ ASSERT_THAT(getsockname(listener_, AsSockAddr(&addr), &addrlen),
SyscallSucceeds());
- ASSERT_THAT(RetryEINTR(connect)(
- first_fd, reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ ASSERT_THAT(RetryEINTR(connect)(first_fd, AsSockAddr(&addr), addrlen),
SyscallSucceeds());
// Get the initial send buffer size.
@@ -229,10 +225,9 @@ TEST_P(TcpSocketTest, SenderAddressIgnored) {
socklen_t addrlen = sizeof(addr);
memset(&addr, 0, sizeof(addr));
- ASSERT_THAT(
- RetryEINTR(recvfrom)(second_fd, buf, sizeof(buf), 0,
- reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
- SyscallSucceedsWithValue(3));
+ ASSERT_THAT(RetryEINTR(recvfrom)(second_fd, buf, sizeof(buf), 0,
+ AsSockAddr(&addr), &addrlen),
+ SyscallSucceedsWithValue(3));
// Check that addr remains zeroed-out.
const char* ptr = reinterpret_cast<char*>(&addr);
@@ -250,10 +245,9 @@ TEST_P(TcpSocketTest, SenderAddressIgnoredOnPeek) {
socklen_t addrlen = sizeof(addr);
memset(&addr, 0, sizeof(addr));
- ASSERT_THAT(
- RetryEINTR(recvfrom)(second_fd, buf, sizeof(buf), MSG_PEEK,
- reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
- SyscallSucceedsWithValue(3));
+ ASSERT_THAT(RetryEINTR(recvfrom)(second_fd, buf, sizeof(buf), MSG_PEEK,
+ AsSockAddr(&addr), &addrlen),
+ SyscallSucceedsWithValue(3));
// Check that addr remains zeroed-out.
const char* ptr = reinterpret_cast<char*>(&addr);
@@ -268,10 +262,9 @@ TEST_P(TcpSocketTest, SendtoAddressIgnored) {
addr.ss_family = GetParam(); // FIXME(b/63803955)
char data = '\0';
- EXPECT_THAT(
- RetryEINTR(sendto)(first_fd, &data, sizeof(data), 0,
- reinterpret_cast<sockaddr*>(&addr), sizeof(addr)),
- SyscallSucceedsWithValue(1));
+ EXPECT_THAT(RetryEINTR(sendto)(first_fd, &data, sizeof(data), 0,
+ AsSockAddr(&addr), sizeof(addr)),
+ SyscallSucceedsWithValue(1));
}
TEST_P(TcpSocketTest, WritevZeroIovec) {
@@ -331,7 +324,7 @@ TEST_P(TcpSocketTest, NonblockingLargeWrite) {
// Test that a blocking write with a buffer that is larger than the send buffer
// will block until the entire buffer is sent.
-TEST_P(TcpSocketTest, BlockingLargeWrite_NoRandomSave) {
+TEST_P(TcpSocketTest, BlockingLargeWrite) {
// Allocate a buffer three times the size of the send buffer on the heap. We
// do this as a vector to avoid allocating on the stack.
int size = 3 * sendbuf_size_;
@@ -415,7 +408,7 @@ TEST_P(TcpSocketTest, NonblockingLargeSend) {
}
// Same test as above, but calls send instead of write.
-TEST_P(TcpSocketTest, BlockingLargeSend_NoRandomSave) {
+TEST_P(TcpSocketTest, BlockingLargeSend) {
// Allocate a buffer three times the size of the send buffer. We do this on
// with a vector to avoid allocating on the stack.
int size = 3 * sendbuf_size_;
@@ -869,10 +862,9 @@ TEST_P(SimpleTcpSocketTest, SendtoWithAddressUnconnected) {
sockaddr_storage addr =
ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
char data = '\0';
- EXPECT_THAT(
- RetryEINTR(sendto)(fd, &data, sizeof(data), 0,
- reinterpret_cast<sockaddr*>(&addr), sizeof(addr)),
- SyscallFailsWithErrno(EPIPE));
+ EXPECT_THAT(RetryEINTR(sendto)(fd, &data, sizeof(data), 0, AsSockAddr(&addr),
+ sizeof(addr)),
+ SyscallFailsWithErrno(EPIPE));
}
TEST_P(SimpleTcpSocketTest, GetPeerNameUnconnected) {
@@ -883,7 +875,7 @@ TEST_P(SimpleTcpSocketTest, GetPeerNameUnconnected) {
sockaddr_storage addr;
socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(getpeername(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ EXPECT_THAT(getpeername(fd, AsSockAddr(&addr), &addrlen),
SyscallFailsWithErrno(ENOTCONN));
}
@@ -974,24 +966,20 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnectRetry) {
socklen_t addrlen = sizeof(addr);
// Bind to some port but don't listen yet.
- ASSERT_THAT(
- bind(listener.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
- SyscallSucceeds());
+ ASSERT_THAT(bind(listener.get(), AsSockAddr(&addr), addrlen),
+ SyscallSucceeds());
// Get the address we're bound to, then connect to it. We need to do this
// because we're allowing the stack to pick a port for us.
- ASSERT_THAT(getsockname(listener.get(),
- reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ ASSERT_THAT(getsockname(listener.get(), AsSockAddr(&addr), &addrlen),
SyscallSucceeds());
FileDescriptor connector =
ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
// Verify that connect fails.
- ASSERT_THAT(
- RetryEINTR(connect)(connector.get(),
- reinterpret_cast<struct sockaddr*>(&addr), addrlen),
- SyscallFailsWithErrno(ECONNREFUSED));
+ ASSERT_THAT(RetryEINTR(connect)(connector.get(), AsSockAddr(&addr), addrlen),
+ SyscallFailsWithErrno(ECONNREFUSED));
// Now start listening
ASSERT_THAT(listen(listener.get(), SOMAXCONN), SyscallSucceeds());
@@ -1000,17 +988,14 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnectRetry) {
// failed first connect should succeed.
if (IsRunningOnGvisor()) {
ASSERT_THAT(
- RetryEINTR(connect)(connector.get(),
- reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ RetryEINTR(connect)(connector.get(), AsSockAddr(&addr), addrlen),
SyscallFailsWithErrno(ECONNABORTED));
return;
}
// Verify that connect now succeeds.
- ASSERT_THAT(
- RetryEINTR(connect)(connector.get(),
- reinterpret_cast<struct sockaddr*>(&addr), addrlen),
- SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(connect)(connector.get(), AsSockAddr(&addr), addrlen),
+ SyscallSucceeds());
// Accept the connection.
const FileDescriptor accepted =
@@ -1030,13 +1015,11 @@ PosixErrorOr<FileDescriptor> nonBlockingConnectNoListener(const int family,
int b_sock;
RETURN_ERROR_IF_SYSCALL_FAIL(b_sock = socket(family, sock_type, IPPROTO_TCP));
FileDescriptor b(b_sock);
- EXPECT_THAT(bind(b.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
- SyscallSucceeds());
+ EXPECT_THAT(bind(b.get(), AsSockAddr(&addr), addrlen), SyscallSucceeds());
// Get the address bound by the listening socket.
- EXPECT_THAT(
- getsockname(b.get(), reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
+ EXPECT_THAT(getsockname(b.get(), AsSockAddr(&addr), &addrlen),
+ SyscallSucceeds());
// Now create another socket and issue a connect on this one. This connect
// should fail as there is no listener.
@@ -1046,8 +1029,7 @@ PosixErrorOr<FileDescriptor> nonBlockingConnectNoListener(const int family,
// Now connect to the bound address and this should fail as nothing
// is listening on the bound address.
- EXPECT_THAT(RetryEINTR(connect)(
- s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ EXPECT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen),
SyscallFailsWithErrno(EINPROGRESS));
// Wait for the connect to fail.
@@ -1078,8 +1060,7 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListener) {
opts &= ~O_NONBLOCK;
EXPECT_THAT(fcntl(s.get(), F_SETFL, opts), SyscallSucceeds());
// Try connecting again.
- ASSERT_THAT(RetryEINTR(connect)(
- s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen),
SyscallFailsWithErrno(ECONNABORTED));
}
@@ -1094,8 +1075,7 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListenerRead) {
unsigned char c;
ASSERT_THAT(read(s.get(), &c, 1), SyscallFailsWithErrno(ECONNREFUSED));
ASSERT_THAT(read(s.get(), &c, 1), SyscallSucceedsWithValue(0));
- ASSERT_THAT(RetryEINTR(connect)(
- s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen),
SyscallFailsWithErrno(ECONNABORTED));
}
@@ -1111,12 +1091,11 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListenerPeek) {
ASSERT_THAT(recv(s.get(), &c, 1, MSG_PEEK),
SyscallFailsWithErrno(ECONNREFUSED));
ASSERT_THAT(recv(s.get(), &c, 1, MSG_PEEK), SyscallSucceedsWithValue(0));
- ASSERT_THAT(RetryEINTR(connect)(
- s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen),
SyscallFailsWithErrno(ECONNABORTED));
}
-TEST_P(SimpleTcpSocketTest, SelfConnectSendRecv_NoRandomSave) {
+TEST_P(SimpleTcpSocketTest, SelfConnectSendRecv) {
// Initialize address to the loopback one.
sockaddr_storage addr =
ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
@@ -1125,15 +1104,11 @@ TEST_P(SimpleTcpSocketTest, SelfConnectSendRecv_NoRandomSave) {
const FileDescriptor s =
ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
- ASSERT_THAT(
- (bind)(s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
- SyscallSucceeds());
+ ASSERT_THAT((bind)(s.get(), AsSockAddr(&addr), addrlen), SyscallSucceeds());
// Get the bound port.
- ASSERT_THAT(
- getsockname(s.get(), reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
- ASSERT_THAT(RetryEINTR(connect)(
- s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ ASSERT_THAT(getsockname(s.get(), AsSockAddr(&addr), &addrlen),
+ SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen),
SyscallSucceeds());
constexpr int kBufSz = 1 << 20; // 1 MiB
@@ -1168,7 +1143,7 @@ TEST_P(SimpleTcpSocketTest, SelfConnectSendRecv_NoRandomSave) {
EXPECT_EQ(read_bytes, kBufSz);
}
-TEST_P(SimpleTcpSocketTest, SelfConnectSend_NoRandomSave) {
+TEST_P(SimpleTcpSocketTest, SelfConnectSend) {
// Initialize address to the loopback one.
sockaddr_storage addr =
ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
@@ -1182,14 +1157,11 @@ TEST_P(SimpleTcpSocketTest, SelfConnectSend_NoRandomSave) {
setsockopt(s.get(), SOL_TCP, TCP_MAXSEG, &max_seg, sizeof(max_seg)),
SyscallSucceeds());
- ASSERT_THAT(bind(s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
- SyscallSucceeds());
+ ASSERT_THAT(bind(s.get(), AsSockAddr(&addr), addrlen), SyscallSucceeds());
// Get the bound port.
- ASSERT_THAT(
- getsockname(s.get(), reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
- ASSERT_THAT(RetryEINTR(connect)(
- s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ ASSERT_THAT(getsockname(s.get(), AsSockAddr(&addr), &addrlen),
+ SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen),
SyscallSucceeds());
std::vector<char> writebuf(512 << 10); // 512 KiB.
@@ -1213,9 +1185,8 @@ void NonBlockingConnect(int family, int16_t pollMask) {
socklen_t addrlen = sizeof(addr);
// Bind to some port then start listening.
- ASSERT_THAT(
- bind(listener.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
- SyscallSucceeds());
+ ASSERT_THAT(bind(listener.get(), AsSockAddr(&addr), addrlen),
+ SyscallSucceeds());
ASSERT_THAT(listen(listener.get(), SOMAXCONN), SyscallSucceeds());
@@ -1228,12 +1199,10 @@ void NonBlockingConnect(int family, int16_t pollMask) {
opts |= O_NONBLOCK;
ASSERT_THAT(fcntl(s.get(), F_SETFL, opts), SyscallSucceeds());
- ASSERT_THAT(getsockname(listener.get(),
- reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ ASSERT_THAT(getsockname(listener.get(), AsSockAddr(&addr), &addrlen),
SyscallSucceeds());
- ASSERT_THAT(RetryEINTR(connect)(
- s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen),
SyscallFailsWithErrno(EINPROGRESS));
int t;
@@ -1276,21 +1245,18 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnectRemoteClose) {
socklen_t addrlen = sizeof(addr);
// Bind to some port then start listening.
- ASSERT_THAT(
- bind(listener.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
- SyscallSucceeds());
+ ASSERT_THAT(bind(listener.get(), AsSockAddr(&addr), addrlen),
+ SyscallSucceeds());
ASSERT_THAT(listen(listener.get(), SOMAXCONN), SyscallSucceeds());
FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
- ASSERT_THAT(getsockname(listener.get(),
- reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ ASSERT_THAT(getsockname(listener.get(), AsSockAddr(&addr), &addrlen),
SyscallSucceeds());
- ASSERT_THAT(RetryEINTR(connect)(
- s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen),
SyscallFailsWithErrno(EINPROGRESS));
int t;
@@ -1305,12 +1271,10 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnectRemoteClose) {
EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000),
SyscallSucceedsWithValue(1));
- ASSERT_THAT(RetryEINTR(connect)(
- s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen),
SyscallSucceeds());
- ASSERT_THAT(RetryEINTR(connect)(
- s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen),
SyscallFailsWithErrno(EISCONN));
}
@@ -1325,8 +1289,7 @@ TEST_P(SimpleTcpSocketTest, BlockingConnectRefused) {
ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
socklen_t addrlen = sizeof(addr);
- ASSERT_THAT(RetryEINTR(connect)(
- s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen),
SyscallFailsWithErrno(ECONNREFUSED));
// Avoiding triggering save in destructor of s.
@@ -1346,17 +1309,14 @@ TEST_P(SimpleTcpSocketTest, CleanupOnConnectionRefused) {
ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
socklen_t bound_addrlen = sizeof(bound_addr);
- ASSERT_THAT(
- bind(bound_s.get(), reinterpret_cast<struct sockaddr*>(&bound_addr),
- bound_addrlen),
- SyscallSucceeds());
+ ASSERT_THAT(bind(bound_s.get(), AsSockAddr(&bound_addr), bound_addrlen),
+ SyscallSucceeds());
// Get the addresses the socket is bound to because the port is chosen by the
// stack.
- ASSERT_THAT(getsockname(bound_s.get(),
- reinterpret_cast<struct sockaddr*>(&bound_addr),
- &bound_addrlen),
- SyscallSucceeds());
+ ASSERT_THAT(
+ getsockname(bound_s.get(), AsSockAddr(&bound_addr), &bound_addrlen),
+ SyscallSucceeds());
// Create, initialize, and bind the socket that is used to test connecting to
// the non-listening port.
@@ -1367,16 +1327,13 @@ TEST_P(SimpleTcpSocketTest, CleanupOnConnectionRefused) {
ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
socklen_t client_addrlen = sizeof(client_addr);
+ ASSERT_THAT(bind(client_s.get(), AsSockAddr(&client_addr), client_addrlen),
+ SyscallSucceeds());
+
ASSERT_THAT(
- bind(client_s.get(), reinterpret_cast<struct sockaddr*>(&client_addr),
- client_addrlen),
+ getsockname(client_s.get(), AsSockAddr(&client_addr), &client_addrlen),
SyscallSucceeds());
- ASSERT_THAT(getsockname(client_s.get(),
- reinterpret_cast<struct sockaddr*>(&client_addr),
- &client_addrlen),
- SyscallSucceeds());
-
// Now the test: connect to the bound but not listening socket with the
// client socket. The bound socket should return a RST and cause the client
// socket to return an error and clean itself up immediately.
@@ -1392,10 +1349,8 @@ TEST_P(SimpleTcpSocketTest, CleanupOnConnectionRefused) {
// Test binding to the address from the client socket. This should be okay
// if it was dropped correctly.
- ASSERT_THAT(
- bind(new_s.get(), reinterpret_cast<struct sockaddr*>(&client_addr),
- client_addrlen),
- SyscallSucceeds());
+ ASSERT_THAT(bind(new_s.get(), AsSockAddr(&client_addr), client_addrlen),
+ SyscallSucceeds());
// Attempt #2, with the new socket and reused addr our connect should fail in
// the same way as before, not with an EADDRINUSE.
@@ -1428,8 +1383,7 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnectRefused) {
ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
socklen_t addrlen = sizeof(addr);
- ASSERT_THAT(RetryEINTR(connect)(
- s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen),
SyscallFailsWithErrno(EINPROGRESS));
// We don't need to specify any events to get POLLHUP or POLLERR as these
@@ -1720,8 +1674,7 @@ TEST_P(SimpleTcpSocketTest, TCPConnectSoRcvBufRace) {
ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
socklen_t addrlen = sizeof(addr);
- RetryEINTR(connect)(s.get(), reinterpret_cast<struct sockaddr*>(&addr),
- addrlen);
+ RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen);
int buf_sz = 1 << 18;
EXPECT_THAT(
setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)),
@@ -2034,8 +1987,7 @@ TEST_P(SimpleTcpSocketTest, GetSocketAcceptConnWithShutdown) {
socklen_t addrlen = sizeof(addr);
// Bind to some port then start listening.
- ASSERT_THAT(bind(s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
- SyscallSucceeds());
+ ASSERT_THAT(bind(s.get(), AsSockAddr(&addr), addrlen), SyscallSucceeds());
ASSERT_THAT(listen(s.get(), SOMAXCONN), SyscallSucceeds());
@@ -2062,10 +2014,8 @@ TEST_P(SimpleTcpSocketTest, ConnectUnspecifiedAddress) {
auto do_connect = [&addr, addrlen]() {
FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
Socket(addr.ss_family, SOCK_STREAM, IPPROTO_TCP));
- ASSERT_THAT(
- RetryEINTR(connect)(s.get(), reinterpret_cast<struct sockaddr*>(&addr),
- addrlen),
- SyscallFailsWithErrno(ECONNREFUSED));
+ ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen),
+ SyscallFailsWithErrno(ECONNREFUSED));
};
do_connect();
// Test the v4 mapped address as well.
diff --git a/test/syscalls/linux/timerfd.cc b/test/syscalls/linux/timerfd.cc
index c4f8fdd7a..072c92797 100644
--- a/test/syscalls/linux/timerfd.cc
+++ b/test/syscalls/linux/timerfd.cc
@@ -114,7 +114,7 @@ TEST_P(TimerfdTest, BlockingRead) {
EXPECT_GE((end_time - start_time) + TimerSlack(), kDelay);
}
-TEST_P(TimerfdTest, NonblockingRead_NoRandomSave) {
+TEST_P(TimerfdTest, NonblockingRead) {
constexpr absl::Duration kDelay = absl::Seconds(5);
auto const tfd =
diff --git a/test/syscalls/linux/truncate.cc b/test/syscalls/linux/truncate.cc
index 17832c47d..5db0b8276 100644
--- a/test/syscalls/linux/truncate.cc
+++ b/test/syscalls/linux/truncate.cc
@@ -208,7 +208,7 @@ TEST(TruncateTest, FtruncateWithOpath) {
// ftruncate(2) should succeed as long as the file descriptor is writeable,
// regardless of whether the file permissions allow writing.
-TEST(TruncateTest, FtruncateWithoutWritePermission_NoRandomSave) {
+TEST(TruncateTest, FtruncateWithoutWritePermission) {
// Drop capabilities that allow us to override file permissions.
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
@@ -230,7 +230,7 @@ TEST(TruncateTest, TruncateNonExist) {
EXPECT_THAT(truncate("/foo/bar", 0), SyscallFailsWithErrno(ENOENT));
}
-TEST(TruncateTest, FtruncateVirtualTmp_NoRandomSave) {
+TEST(TruncateTest, FtruncateVirtualTmp) {
auto temp_file = NewTempAbsPathInDir("/dev/shm");
const DisableSave ds; // Incompatible permissions.
const FileDescriptor fd =
diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc
index 13ed0d68a..6e3a00d2c 100644
--- a/test/syscalls/linux/tuntap.cc
+++ b/test/syscalls/linux/tuntap.cc
@@ -349,9 +349,8 @@ TEST_F(TuntapTest, PingKernel) {
};
while (1) {
inpkt r = {};
- int nread = read(fd.get(), &r, sizeof(r));
- EXPECT_THAT(nread, SyscallSucceeds());
- long unsigned int n = static_cast<long unsigned int>(nread);
+ size_t n;
+ EXPECT_THAT(n = read(fd.get(), &r, sizeof(r)), SyscallSucceeds());
if (n < sizeof(pihdr)) {
std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol
@@ -397,8 +396,7 @@ TEST_F(TuntapTest, SendUdpTriggersArpResolution) {
.sin_port = htons(42),
.sin_addr = {.s_addr = kTapPeerIPAddr},
};
- ASSERT_THAT(sendto(sock, "hello", 5, 0, reinterpret_cast<sockaddr*>(&remote),
- sizeof(remote)),
+ ASSERT_THAT(sendto(sock, "hello", 5, 0, AsSockAddr(&remote), sizeof(remote)),
SyscallSucceeds());
struct inpkt {
@@ -409,9 +407,8 @@ TEST_F(TuntapTest, SendUdpTriggersArpResolution) {
};
while (1) {
inpkt r = {};
- int nread = read(fd.get(), &r, sizeof(r));
- EXPECT_THAT(nread, SyscallSucceeds());
- long unsigned int n = static_cast<long unsigned int>(nread);
+ size_t n;
+ EXPECT_THAT(n = read(fd.get(), &r, sizeof(r)), SyscallSucceeds());
if (n < sizeof(pihdr)) {
std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol
@@ -498,7 +495,7 @@ TEST_F(TuntapTest, WriteHangBug155928773) {
.sin_addr = {.s_addr = kTapIPAddr},
};
// Return values do not matter in this test.
- connect(sock, reinterpret_cast<struct sockaddr*>(&remote), sizeof(remote));
+ connect(sock, AsSockAddr(&remote), sizeof(remote));
write(sock, "hello", 5);
}
diff --git a/test/syscalls/linux/udp_bind.cc b/test/syscalls/linux/udp_bind.cc
index 6d92bdbeb..f68d78aa2 100644
--- a/test/syscalls/linux/udp_bind.cc
+++ b/test/syscalls/linux/udp_bind.cc
@@ -83,27 +83,24 @@ TEST_P(SendtoTest, Sendto) {
ASSERT_NO_ERRNO_AND_VALUE(Socket(param.recv_domain, SOCK_DGRAM, 0));
if (param.send_addr_len > 0) {
- ASSERT_THAT(bind(s1.get(), reinterpret_cast<sockaddr*>(&param.send_addr),
- param.send_addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(s1.get(), AsSockAddr(&param.send_addr), param.send_addr_len),
+ SyscallSucceeds());
}
if (param.connect_addr_len > 0) {
- ASSERT_THAT(
- connect(s1.get(), reinterpret_cast<sockaddr*>(&param.connect_addr),
- param.connect_addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(connect(s1.get(), AsSockAddr(&param.connect_addr),
+ param.connect_addr_len),
+ SyscallSucceeds());
}
- ASSERT_THAT(bind(s2.get(), reinterpret_cast<sockaddr*>(&param.recv_addr),
- param.recv_addr_len),
+ ASSERT_THAT(bind(s2.get(), AsSockAddr(&param.recv_addr), param.recv_addr_len),
SyscallSucceeds());
struct sockaddr_storage real_recv_addr = {};
socklen_t real_recv_addr_len = param.recv_addr_len;
ASSERT_THAT(
- getsockname(s2.get(), reinterpret_cast<sockaddr*>(&real_recv_addr),
- &real_recv_addr_len),
+ getsockname(s2.get(), AsSockAddr(&real_recv_addr), &real_recv_addr_len),
SyscallSucceeds());
ASSERT_EQ(real_recv_addr_len, param.recv_addr_len);
@@ -116,23 +113,22 @@ TEST_P(SendtoTest, Sendto) {
char buf[20] = {};
if (!param.sendto_errnos.empty()) {
- ASSERT_THAT(RetryEINTR(sendto)(s1.get(), buf, sizeof(buf), 0,
- reinterpret_cast<sockaddr*>(&sendto_addr),
- param.sendto_addr_len),
- SyscallFailsWithErrno(ElementOf(param.sendto_errnos)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(s1.get(), buf, sizeof(buf), 0,
+ AsSockAddr(&sendto_addr), param.sendto_addr_len),
+ SyscallFailsWithErrno(ElementOf(param.sendto_errnos)));
return;
}
- ASSERT_THAT(RetryEINTR(sendto)(s1.get(), buf, sizeof(buf), 0,
- reinterpret_cast<sockaddr*>(&sendto_addr),
- param.sendto_addr_len),
- SyscallSucceedsWithValue(sizeof(buf)));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(s1.get(), buf, sizeof(buf), 0,
+ AsSockAddr(&sendto_addr), param.sendto_addr_len),
+ SyscallSucceedsWithValue(sizeof(buf)));
struct sockaddr_storage got_addr = {};
socklen_t got_addr_len = sizeof(sockaddr_storage);
ASSERT_THAT(RetryEINTR(recvfrom)(s2.get(), buf, sizeof(buf), 0,
- reinterpret_cast<sockaddr*>(&got_addr),
- &got_addr_len),
+ AsSockAddr(&got_addr), &got_addr_len),
SyscallSucceedsWithValue(sizeof(buf)));
ASSERT_GT(got_addr_len, sizeof(sockaddr_in_common));
@@ -140,8 +136,7 @@ TEST_P(SendtoTest, Sendto) {
struct sockaddr_storage sender_addr = {};
socklen_t sender_addr_len = sizeof(sockaddr_storage);
- ASSERT_THAT(getsockname(s1.get(), reinterpret_cast<sockaddr*>(&sender_addr),
- &sender_addr_len),
+ ASSERT_THAT(getsockname(s1.get(), AsSockAddr(&sender_addr), &sender_addr_len),
SyscallSucceeds());
ASSERT_GT(sender_addr_len, sizeof(sockaddr_in_common));
diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc
index 16eeeb5c6..18f566eec 100644
--- a/test/syscalls/linux/udp_socket.cc
+++ b/test/syscalls/linux/udp_socket.cc
@@ -138,7 +138,7 @@ void UdpSocketTest::SetUp() {
bind_ =
ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP));
memset(&bind_addr_storage_, 0, sizeof(bind_addr_storage_));
- bind_addr_ = reinterpret_cast<struct sockaddr*>(&bind_addr_storage_);
+ bind_addr_ = AsSockAddr(&bind_addr_storage_);
sock_ =
ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP));
@@ -153,15 +153,13 @@ int UdpSocketTest::GetFamily() {
PosixError UdpSocketTest::BindLoopback() {
bind_addr_storage_ = InetLoopbackAddr();
- struct sockaddr* bind_addr_ =
- reinterpret_cast<struct sockaddr*>(&bind_addr_storage_);
+ struct sockaddr* bind_addr_ = AsSockAddr(&bind_addr_storage_);
return BindSocket(bind_.get(), bind_addr_);
}
PosixError UdpSocketTest::BindAny() {
bind_addr_storage_ = InetAnyAddr();
- struct sockaddr* bind_addr_ =
- reinterpret_cast<struct sockaddr*>(&bind_addr_storage_);
+ struct sockaddr* bind_addr_ = AsSockAddr(&bind_addr_storage_);
return BindSocket(bind_.get(), bind_addr_);
}
@@ -195,7 +193,7 @@ socklen_t UdpSocketTest::GetAddrLength() {
sockaddr_storage UdpSocketTest::InetAnyAddr() {
struct sockaddr_storage addr;
memset(&addr, 0, sizeof(addr));
- reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily();
+ AsSockAddr(&addr)->sa_family = GetFamily();
if (GetFamily() == AF_INET) {
auto sin = reinterpret_cast<struct sockaddr_in*>(&addr);
@@ -213,7 +211,7 @@ sockaddr_storage UdpSocketTest::InetAnyAddr() {
sockaddr_storage UdpSocketTest::InetLoopbackAddr() {
struct sockaddr_storage addr;
memset(&addr, 0, sizeof(addr));
- reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily();
+ AsSockAddr(&addr)->sa_family = GetFamily();
if (GetFamily() == AF_INET) {
auto sin = reinterpret_cast<struct sockaddr_in*>(&addr);
@@ -229,7 +227,7 @@ sockaddr_storage UdpSocketTest::InetLoopbackAddr() {
void UdpSocketTest::Disconnect(int sockfd) {
sockaddr_storage addr_storage = InetAnyAddr();
- sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ sockaddr* addr = AsSockAddr(&addr_storage);
socklen_t addrlen = sizeof(addr_storage);
addr->sa_family = AF_UNSPEC;
@@ -265,19 +263,16 @@ TEST_P(UdpSocketTest, Getsockname) {
// Check that we're not bound.
struct sockaddr_storage addr;
socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(
- getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
+ EXPECT_THAT(getsockname(bind_.get(), AsSockAddr(&addr), &addrlen),
+ SyscallSucceeds());
EXPECT_EQ(addrlen, addrlen_);
struct sockaddr_storage any = InetAnyAddr();
- EXPECT_EQ(memcmp(&addr, reinterpret_cast<struct sockaddr*>(&any), addrlen_),
- 0);
+ EXPECT_EQ(memcmp(&addr, AsSockAddr(&any), addrlen_), 0);
ASSERT_NO_ERRNO(BindLoopback());
- EXPECT_THAT(
- getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
+ EXPECT_THAT(getsockname(bind_.get(), AsSockAddr(&addr), &addrlen),
+ SyscallSucceeds());
EXPECT_EQ(addrlen, addrlen_);
EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0);
@@ -289,17 +284,15 @@ TEST_P(UdpSocketTest, Getpeername) {
// Check that we're not connected.
struct sockaddr_storage addr;
socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(
- getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallFailsWithErrno(ENOTCONN));
+ EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&addr), &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
// Connect, then check that we get the right address.
ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
addrlen = sizeof(addr);
- EXPECT_THAT(
- getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
+ EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&addr), &addrlen),
+ SyscallSucceeds());
EXPECT_EQ(addrlen, addrlen_);
EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0);
}
@@ -322,9 +315,8 @@ TEST_P(UdpSocketTest, SendNotConnected) {
// Check that we're bound now.
struct sockaddr_storage addr;
socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(
- getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
+ EXPECT_THAT(getsockname(sock_.get(), AsSockAddr(&addr), &addrlen),
+ SyscallSucceeds());
EXPECT_EQ(addrlen, addrlen_);
EXPECT_NE(*Port(&addr), 0);
}
@@ -338,9 +330,8 @@ TEST_P(UdpSocketTest, ConnectBinds) {
// Check that we're bound now.
struct sockaddr_storage addr;
socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(
- getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
+ EXPECT_THAT(getsockname(sock_.get(), AsSockAddr(&addr), &addrlen),
+ SyscallSucceeds());
EXPECT_EQ(addrlen, addrlen_);
EXPECT_NE(*Port(&addr), 0);
}
@@ -361,9 +352,8 @@ TEST_P(UdpSocketTest, Bind) {
// Check that we're still bound to the original address.
struct sockaddr_storage addr;
socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(
- getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
+ EXPECT_THAT(getsockname(bind_.get(), AsSockAddr(&addr), &addrlen),
+ SyscallSucceeds());
EXPECT_EQ(addrlen, addrlen_);
EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0);
}
@@ -383,7 +373,7 @@ TEST_P(UdpSocketTest, ConnectWriteToInvalidPort) {
// same time.
struct sockaddr_storage addr_storage = InetLoopbackAddr();
socklen_t addrlen = sizeof(addr_storage);
- struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ struct sockaddr* addr = AsSockAddr(&addr_storage);
FileDescriptor s =
ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP));
ASSERT_THAT(bind(s.get(), addr, addrlen), SyscallSucceeds());
@@ -417,7 +407,7 @@ TEST_P(UdpSocketTest, ConnectSimultaneousWriteToInvalidPort) {
// same time.
struct sockaddr_storage addr_storage = InetLoopbackAddr();
socklen_t addrlen = sizeof(addr_storage);
- struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ struct sockaddr* addr = AsSockAddr(&addr_storage);
FileDescriptor s =
ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP));
ASSERT_THAT(bind(s.get(), addr, addrlen), SyscallSucceeds());
@@ -465,18 +455,17 @@ TEST_P(UdpSocketTest, ReceiveAfterDisconnect) {
struct sockaddr_storage addr;
socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(
- getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
+ EXPECT_THAT(getsockname(sock_.get(), AsSockAddr(&addr), &addrlen),
+ SyscallSucceeds());
EXPECT_EQ(addrlen, addrlen_);
// Send from sock to bind_.
char buf[512];
RandomizeBuffer(buf, sizeof(buf));
- ASSERT_THAT(sendto(bind_.get(), buf, sizeof(buf), 0,
- reinterpret_cast<sockaddr*>(&addr), addrlen),
- SyscallSucceedsWithValue(sizeof(buf)));
+ ASSERT_THAT(
+ sendto(bind_.get(), buf, sizeof(buf), 0, AsSockAddr(&addr), addrlen),
+ SyscallSucceedsWithValue(sizeof(buf)));
// Receive the data.
char received[sizeof(buf)];
@@ -499,21 +488,18 @@ TEST_P(UdpSocketTest, Connect) {
// Check that we're connected to the right peer.
struct sockaddr_storage peer;
socklen_t peerlen = sizeof(peer);
- EXPECT_THAT(
- getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen),
- SyscallSucceeds());
+ EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&peer), &peerlen),
+ SyscallSucceeds());
EXPECT_EQ(peerlen, addrlen_);
EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0);
// Try to bind after connect.
struct sockaddr_storage any = InetAnyAddr();
- EXPECT_THAT(
- bind(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_),
- SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(bind(sock_.get(), AsSockAddr(&any), addrlen_),
+ SyscallFailsWithErrno(EINVAL));
struct sockaddr_storage bind2_storage = InetLoopbackAddr();
- struct sockaddr* bind2_addr =
- reinterpret_cast<struct sockaddr*>(&bind2_storage);
+ struct sockaddr* bind2_addr = AsSockAddr(&bind2_storage);
FileDescriptor bind2 =
ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP));
ASSERT_NO_ERRNO(BindSocket(bind2.get(), bind2_addr));
@@ -523,9 +509,8 @@ TEST_P(UdpSocketTest, Connect) {
// Check that peer name changed.
peerlen = sizeof(peer);
- EXPECT_THAT(
- getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen),
- SyscallSucceeds());
+ EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&peer), &peerlen),
+ SyscallSucceeds());
EXPECT_EQ(peerlen, addrlen_);
EXPECT_EQ(memcmp(&peer, bind2_addr, addrlen_), 0);
}
@@ -535,15 +520,13 @@ TEST_P(UdpSocketTest, ConnectAnyZero) {
SKIP_IF(IsRunningOnGvisor());
struct sockaddr_storage any = InetAnyAddr();
- EXPECT_THAT(
- connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_),
- SyscallSucceeds());
+ EXPECT_THAT(connect(sock_.get(), AsSockAddr(&any), addrlen_),
+ SyscallSucceeds());
struct sockaddr_storage addr;
socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(
- getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallFailsWithErrno(ENOTCONN));
+ EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&addr), &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
}
TEST_P(UdpSocketTest, ConnectAnyWithPort) {
@@ -552,24 +535,21 @@ TEST_P(UdpSocketTest, ConnectAnyWithPort) {
struct sockaddr_storage addr;
socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(
- getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
+ EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&addr), &addrlen),
+ SyscallSucceeds());
}
TEST_P(UdpSocketTest, DisconnectAfterConnectAny) {
// TODO(138658473): Enable when we can connect to port 0 with gVisor.
SKIP_IF(IsRunningOnGvisor());
struct sockaddr_storage any = InetAnyAddr();
- EXPECT_THAT(
- connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_),
- SyscallSucceeds());
+ EXPECT_THAT(connect(sock_.get(), AsSockAddr(&any), addrlen_),
+ SyscallSucceeds());
struct sockaddr_storage addr;
socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(
- getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallFailsWithErrno(ENOTCONN));
+ EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&addr), &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
Disconnect(sock_.get());
}
@@ -580,9 +560,8 @@ TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) {
struct sockaddr_storage addr;
socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(
- getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
+ EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&addr), &addrlen),
+ SyscallSucceeds());
EXPECT_EQ(addrlen, addrlen_);
EXPECT_EQ(*Port(&bind_addr_storage_), *Port(&addr));
@@ -595,7 +574,7 @@ TEST_P(UdpSocketTest, DisconnectAfterBind) {
// Bind to the next port above bind_.
struct sockaddr_storage addr_storage = InetLoopbackAddr();
- struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ struct sockaddr* addr = AsSockAddr(&addr_storage);
SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
ASSERT_NO_ERRNO(BindSocket(sock_.get(), addr));
@@ -604,15 +583,14 @@ TEST_P(UdpSocketTest, DisconnectAfterBind) {
struct sockaddr_storage unspec = {};
unspec.ss_family = AF_UNSPEC;
- EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&unspec),
- sizeof(unspec.ss_family)),
- SyscallSucceeds());
+ EXPECT_THAT(
+ connect(sock_.get(), AsSockAddr(&unspec), sizeof(unspec.ss_family)),
+ SyscallSucceeds());
// Check that we're still bound.
socklen_t addrlen = sizeof(unspec);
- EXPECT_THAT(
- getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&unspec), &addrlen),
- SyscallSucceeds());
+ EXPECT_THAT(getsockname(sock_.get(), AsSockAddr(&unspec), &addrlen),
+ SyscallSucceeds());
EXPECT_EQ(addrlen, addrlen_);
EXPECT_EQ(memcmp(addr, &unspec, addrlen_), 0);
@@ -626,7 +604,7 @@ TEST_P(UdpSocketTest, BindToAnyConnnectToLocalhost) {
ASSERT_NO_ERRNO(BindAny());
struct sockaddr_storage addr_storage = InetLoopbackAddr();
- struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ struct sockaddr* addr = AsSockAddr(&addr_storage);
SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
socklen_t addrlen = sizeof(addr);
@@ -653,7 +631,7 @@ TEST_P(UdpSocketTest, DisconnectAfterBindToAny) {
ASSERT_NO_ERRNO(BindLoopback());
struct sockaddr_storage any_storage = InetAnyAddr();
- struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage);
+ struct sockaddr* any = AsSockAddr(&any_storage);
SetPort(&any_storage, *Port(&bind_addr_storage_) + 1);
ASSERT_NO_ERRNO(BindSocket(sock_.get(), any));
@@ -666,24 +644,22 @@ TEST_P(UdpSocketTest, DisconnectAfterBindToAny) {
// Check that we're still bound.
struct sockaddr_storage addr;
socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(
- getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
+ EXPECT_THAT(getsockname(sock_.get(), AsSockAddr(&addr), &addrlen),
+ SyscallSucceeds());
EXPECT_EQ(addrlen, addrlen_);
EXPECT_EQ(memcmp(&addr, any, addrlen), 0);
addrlen = sizeof(addr);
- EXPECT_THAT(
- getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallFailsWithErrno(ENOTCONN));
+ EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&addr), &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
}
TEST_P(UdpSocketTest, Disconnect) {
ASSERT_NO_ERRNO(BindLoopback());
struct sockaddr_storage any_storage = InetAnyAddr();
- struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage);
+ struct sockaddr* any = AsSockAddr(&any_storage);
SetPort(&any_storage, *Port(&bind_addr_storage_) + 1);
ASSERT_NO_ERRNO(BindSocket(sock_.get(), any));
@@ -694,29 +670,25 @@ TEST_P(UdpSocketTest, Disconnect) {
// Check that we're connected to the right peer.
struct sockaddr_storage peer;
socklen_t peerlen = sizeof(peer);
- EXPECT_THAT(
- getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen),
- SyscallSucceeds());
+ EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&peer), &peerlen),
+ SyscallSucceeds());
EXPECT_EQ(peerlen, addrlen_);
EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0);
// Try to disconnect.
struct sockaddr_storage addr = {};
addr.ss_family = AF_UNSPEC;
- EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&addr),
- sizeof(addr.ss_family)),
+ EXPECT_THAT(connect(sock_.get(), AsSockAddr(&addr), sizeof(addr.ss_family)),
SyscallSucceeds());
peerlen = sizeof(peer);
- EXPECT_THAT(
- getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen),
- SyscallFailsWithErrno(ENOTCONN));
+ EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&peer), &peerlen),
+ SyscallFailsWithErrno(ENOTCONN));
// Check that we're still bound.
socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(
- getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
+ EXPECT_THAT(getsockname(sock_.get(), AsSockAddr(&addr), &addrlen),
+ SyscallSucceeds());
EXPECT_EQ(addrlen, addrlen_);
EXPECT_EQ(*Port(&addr), *Port(&any_storage));
}
@@ -733,7 +705,7 @@ TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) {
ASSERT_NO_ERRNO(BindLoopback());
struct sockaddr_storage addr_storage = InetAnyAddr();
- struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ struct sockaddr* addr = AsSockAddr(&addr_storage);
SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
@@ -881,7 +853,7 @@ TEST_P(UdpSocketTest, ZerolengthWriteAllowed) {
ASSERT_NO_ERRNO(BindLoopback());
// Connect to loopback:bind_addr_+1.
struct sockaddr_storage addr_storage = InetLoopbackAddr();
- struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ struct sockaddr* addr = AsSockAddr(&addr_storage);
SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
@@ -910,7 +882,7 @@ TEST_P(UdpSocketTest, ZerolengthWriteAllowedNonBlockRead) {
// Connect to loopback:bind_addr_port+1.
struct sockaddr_storage addr_storage = InetLoopbackAddr();
- struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ struct sockaddr* addr = AsSockAddr(&addr_storage);
SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
@@ -961,7 +933,7 @@ TEST_P(UdpSocketTest, SendAndReceiveConnected) {
// Connect to loopback:bind_addr_port+1.
struct sockaddr_storage addr_storage = InetLoopbackAddr();
- struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ struct sockaddr* addr = AsSockAddr(&addr_storage);
SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
@@ -987,13 +959,13 @@ TEST_P(UdpSocketTest, ReceiveFromNotConnected) {
// Connect to loopback:bind_addr_port+1.
struct sockaddr_storage addr_storage = InetLoopbackAddr();
- struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ struct sockaddr* addr = AsSockAddr(&addr_storage);
SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
// Bind sock to loopback:bind_addr_port+2.
struct sockaddr_storage addr2_storage = InetLoopbackAddr();
- struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage);
+ struct sockaddr* addr2 = AsSockAddr(&addr2_storage);
SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2);
ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds());
@@ -1013,7 +985,7 @@ TEST_P(UdpSocketTest, ReceiveBeforeConnect) {
// Bind sock to loopback:bind_addr_port+2.
struct sockaddr_storage addr2_storage = InetLoopbackAddr();
- struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage);
+ struct sockaddr* addr2 = AsSockAddr(&addr2_storage);
SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2);
ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds());
@@ -1026,7 +998,7 @@ TEST_P(UdpSocketTest, ReceiveBeforeConnect) {
// Connect to loopback:bind_addr_port+1.
struct sockaddr_storage addr_storage = InetLoopbackAddr();
- struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ struct sockaddr* addr = AsSockAddr(&addr_storage);
SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
@@ -1050,7 +1022,7 @@ TEST_P(UdpSocketTest, ReceiveFrom) {
// Connect to loopback:bind_addr_port+1.
struct sockaddr_storage addr_storage = InetLoopbackAddr();
- struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ struct sockaddr* addr = AsSockAddr(&addr_storage);
SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
@@ -1069,7 +1041,7 @@ TEST_P(UdpSocketTest, ReceiveFrom) {
struct sockaddr_storage addr2;
socklen_t addr2len = sizeof(addr2);
EXPECT_THAT(recvfrom(bind_.get(), received, sizeof(received), 0,
- reinterpret_cast<sockaddr*>(&addr2), &addr2len),
+ AsSockAddr(&addr2), &addr2len),
SyscallSucceedsWithValue(sizeof(received)));
EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
EXPECT_EQ(addr2len, addrlen_);
@@ -1093,7 +1065,7 @@ TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) {
// Connect to loopback:bind_addr_port+1.
struct sockaddr_storage addr_storage = InetLoopbackAddr();
- struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ struct sockaddr* addr = AsSockAddr(&addr_storage);
SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
@@ -1149,7 +1121,7 @@ TEST_P(UdpSocketTest, ReadShutdownSameSocketResetsShutdownState) {
// Connect to loopback:bind_addr_port+1.
struct sockaddr_storage addr_storage = InetLoopbackAddr();
- struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ struct sockaddr* addr = AsSockAddr(&addr_storage);
SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
@@ -1932,13 +1904,8 @@ TEST_P(UdpSocketTest, RecvBufLimits) {
SyscallSucceeds());
}
- // Now set the limit to min * 4.
- int new_rcv_buf_sz = min * 4;
- if (!IsRunningOnGvisor() || IsRunningWithHostinet()) {
- // Linux doubles the value specified so just set to min * 2.
- new_rcv_buf_sz = min * 2;
- }
-
+ // Now set the limit to min * 2.
+ int new_rcv_buf_sz = min * 2;
ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &new_rcv_buf_sz,
sizeof(new_rcv_buf_sz)),
SyscallSucceeds());
@@ -2051,68 +2018,57 @@ TEST_P(UdpSocketTest, SendToZeroPort) {
// Sending to an invalid port should fail.
SetPort(&addr, 0);
- EXPECT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0,
- reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
- SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(
+ sendto(sock_.get(), buf, sizeof(buf), 0, AsSockAddr(&addr), sizeof(addr)),
+ SyscallFailsWithErrno(EINVAL));
SetPort(&addr, 1234);
- EXPECT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0,
- reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
- SyscallSucceedsWithValue(sizeof(buf)));
+ EXPECT_THAT(
+ sendto(sock_.get(), buf, sizeof(buf), 0, AsSockAddr(&addr), sizeof(addr)),
+ SyscallSucceedsWithValue(sizeof(buf)));
}
TEST_P(UdpSocketTest, ConnectToZeroPortUnbound) {
struct sockaddr_storage addr = InetLoopbackAddr();
SetPort(&addr, 0);
- ASSERT_THAT(
- connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen_),
- SyscallSucceeds());
+ ASSERT_THAT(connect(sock_.get(), AsSockAddr(&addr), addrlen_),
+ SyscallSucceeds());
}
TEST_P(UdpSocketTest, ConnectToZeroPortBound) {
struct sockaddr_storage addr = InetLoopbackAddr();
- ASSERT_NO_ERRNO(
- BindSocket(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr)));
+ ASSERT_NO_ERRNO(BindSocket(sock_.get(), AsSockAddr(&addr)));
SetPort(&addr, 0);
- ASSERT_THAT(
- connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen_),
- SyscallSucceeds());
+ ASSERT_THAT(connect(sock_.get(), AsSockAddr(&addr), addrlen_),
+ SyscallSucceeds());
socklen_t len = sizeof(sockaddr_storage);
- ASSERT_THAT(
- getsockname(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), &len),
- SyscallSucceeds());
+ ASSERT_THAT(getsockname(sock_.get(), AsSockAddr(&addr), &len),
+ SyscallSucceeds());
ASSERT_EQ(len, addrlen_);
}
TEST_P(UdpSocketTest, ConnectToZeroPortConnected) {
struct sockaddr_storage addr = InetLoopbackAddr();
- ASSERT_NO_ERRNO(
- BindSocket(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr)));
+ ASSERT_NO_ERRNO(BindSocket(sock_.get(), AsSockAddr(&addr)));
// Connect to an address with non-zero port should succeed.
- ASSERT_THAT(
- connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen_),
- SyscallSucceeds());
+ ASSERT_THAT(connect(sock_.get(), AsSockAddr(&addr), addrlen_),
+ SyscallSucceeds());
sockaddr_storage peername;
socklen_t peerlen = sizeof(peername);
- ASSERT_THAT(
- getpeername(sock_.get(), reinterpret_cast<struct sockaddr*>(&peername),
- &peerlen),
- SyscallSucceeds());
+ ASSERT_THAT(getpeername(sock_.get(), AsSockAddr(&peername), &peerlen),
+ SyscallSucceeds());
ASSERT_EQ(peerlen, addrlen_);
ASSERT_EQ(memcmp(&peername, &addr, addrlen_), 0);
// However connect() to an address with port 0 will make the following
// getpeername() fail.
SetPort(&addr, 0);
- ASSERT_THAT(
- connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen_),
- SyscallSucceeds());
- ASSERT_THAT(
- getpeername(sock_.get(), reinterpret_cast<struct sockaddr*>(&peername),
- &peerlen),
- SyscallFailsWithErrno(ENOTCONN));
+ ASSERT_THAT(connect(sock_.get(), AsSockAddr(&addr), addrlen_),
+ SyscallSucceeds());
+ ASSERT_THAT(getpeername(sock_.get(), AsSockAddr(&peername), &peerlen),
+ SyscallFailsWithErrno(ENOTCONN));
}
INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest,
@@ -2133,8 +2089,7 @@ TEST(UdpInet6SocketTest, ConnectInet4Sockaddr) {
SyscallSucceeds());
sockaddr_storage sockname;
socklen_t len = sizeof(sockaddr_storage);
- ASSERT_THAT(getsockname(sock_.get(),
- reinterpret_cast<struct sockaddr*>(&sockname), &len),
+ ASSERT_THAT(getsockname(sock_.get(), AsSockAddr(&sockname), &len),
SyscallSucceeds());
ASSERT_EQ(sockname.ss_family, AF_INET6);
ASSERT_EQ(len, sizeof(sockaddr_in6));
diff --git a/test/syscalls/linux/unlink.cc b/test/syscalls/linux/unlink.cc
index 061e2e0f1..7c301c305 100644
--- a/test/syscalls/linux/unlink.cc
+++ b/test/syscalls/linux/unlink.cc
@@ -64,7 +64,7 @@ TEST(UnlinkTest, AtDir) {
ASSERT_THAT(close(dirfd), SyscallSucceeds());
}
-TEST(UnlinkTest, AtDirDegradedPermissions_NoRandomSave) {
+TEST(UnlinkTest, AtDirDegradedPermissions) {
// Drop capabilities that allow us to override file and directory permissions.
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
@@ -162,7 +162,7 @@ TEST(UnlinkTest, AtFile) {
EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile", 0), SyscallSucceeds());
}
-TEST(UnlinkTest, OpenFile_NoRandomSave) {
+TEST(UnlinkTest, OpenFile) {
// We can't save unlinked file unless they are on tmpfs.
const DisableSave ds;
auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
diff --git a/test/syscalls/linux/verity_ioctl.cc b/test/syscalls/linux/verity_ioctl.cc
new file mode 100644
index 000000000..a81fe5724
--- /dev/null
+++ b/test/syscalls/linux/verity_ioctl.cc
@@ -0,0 +1,188 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdint.h>
+#include <sys/mount.h>
+
+#include <iomanip>
+#include <sstream>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/util/capability_util.h"
+#include "test/util/fs_util.h"
+#include "test/util/mount_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+#ifndef FS_IOC_ENABLE_VERITY
+#define FS_IOC_ENABLE_VERITY 1082156677
+#endif
+
+#ifndef FS_IOC_MEASURE_VERITY
+#define FS_IOC_MEASURE_VERITY 3221513862
+#endif
+
+#ifndef FS_VERITY_FL
+#define FS_VERITY_FL 1048576
+#endif
+
+#ifndef FS_IOC_GETFLAGS
+#define FS_IOC_GETFLAGS 2148034049
+#endif
+
+struct fsverity_digest {
+ __u16 digest_algorithm;
+ __u16 digest_size; /* input/output */
+ __u8 digest[];
+};
+
+constexpr int kMaxDigestSize = 64;
+constexpr int kDefaultDigestSize = 32;
+constexpr char kContents[] = "foobarbaz";
+
+class IoctlTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ // Verity is implemented in VFS2.
+ SKIP_IF(IsRunningWithVFS1());
+
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+ // Mount a tmpfs file system, to be wrapped by a verity fs.
+ tmpfs_dir_ = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ ASSERT_THAT(mount("", tmpfs_dir_.path().c_str(), "tmpfs", 0, ""),
+ SyscallSucceeds());
+
+ // Create a new file in the tmpfs mount.
+ file_ = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(tmpfs_dir_.path(), kContents, 0777));
+ filename_ = Basename(file_.path());
+ }
+
+ TempPath tmpfs_dir_;
+ TempPath file_;
+ std::string filename_;
+};
+
+// Provide a function to convert bytes to hex string, since
+// absl::BytesToHexString does not seem to be compatible with golang
+// hex.DecodeString used in verity due to zero-padding.
+std::string BytesToHexString(uint8_t bytes[], int size) {
+ std::stringstream ss;
+ ss << std::hex;
+ for (int i = 0; i < size; ++i) {
+ ss << std::setw(2) << std::setfill('0') << static_cast<int>(bytes[i]);
+ }
+ return ss.str();
+}
+
+TEST_F(IoctlTest, Enable) {
+ // Mount a verity fs on the existing tmpfs mount.
+ std::string mount_opts = "lower_path=" + tmpfs_dir_.path();
+ auto const verity_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ ASSERT_THAT(
+ mount("", verity_dir.path().c_str(), "verity", 0, mount_opts.c_str()),
+ SyscallSucceeds());
+
+ // Confirm that the verity flag is absent.
+ int flag = 0;
+ auto const fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(JoinPath(verity_dir.path(), filename_), O_RDONLY, 0777));
+ ASSERT_THAT(ioctl(fd.get(), FS_IOC_GETFLAGS, &flag), SyscallSucceeds());
+ EXPECT_EQ(flag & FS_VERITY_FL, 0);
+
+ // Enable the file and confirm that the verity flag is present.
+ ASSERT_THAT(ioctl(fd.get(), FS_IOC_ENABLE_VERITY), SyscallSucceeds());
+ ASSERT_THAT(ioctl(fd.get(), FS_IOC_GETFLAGS, &flag), SyscallSucceeds());
+ EXPECT_EQ(flag & FS_VERITY_FL, FS_VERITY_FL);
+}
+
+TEST_F(IoctlTest, Measure) {
+ // Mount a verity fs on the existing tmpfs mount.
+ std::string mount_opts = "lower_path=" + tmpfs_dir_.path();
+ auto const verity_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ ASSERT_THAT(
+ mount("", verity_dir.path().c_str(), "verity", 0, mount_opts.c_str()),
+ SyscallSucceeds());
+
+ // Confirm that the file cannot be measured.
+ auto const fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(JoinPath(verity_dir.path(), filename_), O_RDONLY, 0777));
+ uint8_t digest_array[sizeof(struct fsverity_digest) + kMaxDigestSize] = {0};
+ struct fsverity_digest* digest =
+ reinterpret_cast<struct fsverity_digest*>(digest_array);
+ digest->digest_size = kMaxDigestSize;
+ ASSERT_THAT(ioctl(fd.get(), FS_IOC_MEASURE_VERITY, digest),
+ SyscallFailsWithErrno(ENODATA));
+
+ // Enable the file and confirm that the file can be measured.
+ ASSERT_THAT(ioctl(fd.get(), FS_IOC_ENABLE_VERITY), SyscallSucceeds());
+ ASSERT_THAT(ioctl(fd.get(), FS_IOC_MEASURE_VERITY, digest),
+ SyscallSucceeds());
+ EXPECT_EQ(digest->digest_size, kDefaultDigestSize);
+}
+
+TEST_F(IoctlTest, Mount) {
+ // Mount a verity fs on the existing tmpfs mount.
+ std::string mount_opts = "lower_path=" + tmpfs_dir_.path();
+ auto verity_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ ASSERT_THAT(
+ mount("", verity_dir.path().c_str(), "verity", 0, mount_opts.c_str()),
+ SyscallSucceeds());
+
+ // Enable both the file and the directory.
+ auto const fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(JoinPath(verity_dir.path(), filename_), O_RDONLY, 0777));
+ ASSERT_THAT(ioctl(fd.get(), FS_IOC_ENABLE_VERITY), SyscallSucceeds());
+ auto const dir_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(verity_dir.path(), O_RDONLY, 0777));
+ ASSERT_THAT(ioctl(dir_fd.get(), FS_IOC_ENABLE_VERITY), SyscallSucceeds());
+
+ // Measure the root hash.
+ uint8_t digest_array[sizeof(struct fsverity_digest) + kMaxDigestSize] = {0};
+ struct fsverity_digest* digest =
+ reinterpret_cast<struct fsverity_digest*>(digest_array);
+ digest->digest_size = kMaxDigestSize;
+ ASSERT_THAT(ioctl(dir_fd.get(), FS_IOC_MEASURE_VERITY, digest),
+ SyscallSucceeds());
+
+ // Mount a verity fs with specified root hash.
+ mount_opts +=
+ ",root_hash=" + BytesToHexString(digest->digest, digest->digest_size);
+ auto verity_with_hash_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ ASSERT_THAT(mount("", verity_with_hash_dir.path().c_str(), "verity", 0,
+ mount_opts.c_str()),
+ SyscallSucceeds());
+
+ // Make sure the file can be open and read in the mounted verity fs.
+ auto const verity_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(JoinPath(verity_with_hash_dir.path(), filename_), O_RDONLY, 0777));
+ char buf[16];
+ EXPECT_THAT(ReadFd(fd.get(), buf, sizeof(kContents)), SyscallSucceeds());
+
+ // Verity directories should not be deleted. Release the TempPath objects to
+ // prevent those directories from being deleted by the destructor.
+ verity_dir.release();
+ verity_with_hash_dir.release();
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/verity_mount.cc b/test/syscalls/linux/verity_mount.cc
new file mode 100644
index 000000000..e73dd5599
--- /dev/null
+++ b/test/syscalls/linux/verity_mount.cc
@@ -0,0 +1,53 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/mount.h>
+
+#include <iomanip>
+#include <sstream>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/util/capability_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Mount verity file system on an existing gofer mount.
+TEST(MountTest, MountExisting) {
+ // Verity is implemented in VFS2.
+ SKIP_IF(IsRunningWithVFS1());
+
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ // Mount a new tmpfs file system.
+ auto const tmpfs_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ ASSERT_THAT(mount("", tmpfs_dir.path().c_str(), "tmpfs", 0, ""),
+ SyscallSucceeds());
+
+ // Mount a verity file system on the existing gofer mount.
+ auto const verity_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ std::string opts = "lower_path=" + tmpfs_dir.path();
+ EXPECT_THAT(mount("", verity_dir.path().c_str(), "verity", 0, opts.c_str()),
+ SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/vfork.cc b/test/syscalls/linux/vfork.cc
index 19d05998e..1a282e371 100644
--- a/test/syscalls/linux/vfork.cc
+++ b/test/syscalls/linux/vfork.cc
@@ -87,7 +87,7 @@ TEST(VforkTest, ParentStopsUntilChildExits) {
EXPECT_THAT(InForkedProcess(test), IsPosixErrorOkAndHolds(0));
}
-TEST(VforkTest, ParentStopsUntilChildExecves_NoRandomSave) {
+TEST(VforkTest, ParentStopsUntilChildExecves) {
ExecveArray const owned_child_argv = {"/proc/self/exe", "--vfork_test_child"};
char* const* const child_argv = owned_child_argv.get();
@@ -127,7 +127,7 @@ TEST(VforkTest, ParentStopsUntilChildExecves_NoRandomSave) {
// A vfork child does not unstop the parent a second time when it exits after
// exec.
-TEST(VforkTest, ExecedChildExitDoesntUnstopParent_NoRandomSave) {
+TEST(VforkTest, ExecedChildExitDoesntUnstopParent) {
ExecveArray const owned_child_argv = {"/proc/self/exe", "--vfork_test_child"};
char* const* const child_argv = owned_child_argv.get();
diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc
index a953a55fe..dd8067807 100644
--- a/test/syscalls/linux/xattr.cc
+++ b/test/syscalls/linux/xattr.cc
@@ -107,7 +107,7 @@ TEST_F(XattrTest, XattrInvalidPrefix) {
// Do not allow save/restore cycles after making the test file read-only, as
// the restore will fail to open it with r/w permissions.
-TEST_F(XattrTest, XattrReadOnly_NoRandomSave) {
+TEST_F(XattrTest, XattrReadOnly) {
// Drop capabilities that allow us to override file and directory permissions.
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
@@ -138,7 +138,7 @@ TEST_F(XattrTest, XattrReadOnly_NoRandomSave) {
// Do not allow save/restore cycles after making the test file write-only, as
// the restore will fail to open it with r/w permissions.
-TEST_F(XattrTest, XattrWriteOnly_NoRandomSave) {
+TEST_F(XattrTest, XattrWriteOnly) {
// Drop capabilities that allow us to override file and directory permissions.
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
diff --git a/test/util/BUILD b/test/util/BUILD
index e561f3daa..383de00ed 100644
--- a/test/util/BUILD
+++ b/test/util/BUILD
@@ -94,6 +94,7 @@ cc_library(
":file_descriptor",
":posix_error",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
gtest,
],
)
@@ -368,3 +369,20 @@ cc_library(
testonly = 1,
hdrs = ["temp_umask.h"],
)
+
+cc_library(
+ name = "cgroup_util",
+ testonly = 1,
+ srcs = ["cgroup_util.cc"],
+ hdrs = ["cgroup_util.h"],
+ deps = [
+ ":cleanup",
+ ":fs_util",
+ ":mount_util",
+ ":posix_error",
+ ":temp_path",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/strings",
+ ],
+)
diff --git a/test/util/cgroup_util.cc b/test/util/cgroup_util.cc
new file mode 100644
index 000000000..d8d3fe471
--- /dev/null
+++ b/test/util/cgroup_util.cc
@@ -0,0 +1,236 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/cgroup_util.h"
+
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+#include "test/util/fs_util.h"
+#include "test/util/mount_util.h"
+
+namespace gvisor {
+namespace testing {
+
+Cgroup::Cgroup(std::string path) : cgroup_path_(path) {
+ id_ = ++Cgroup::next_id_;
+ std::cerr << absl::StreamFormat("[cg#%d] <= %s", id_, cgroup_path_)
+ << std::endl;
+}
+
+PosixErrorOr<std::string> Cgroup::ReadControlFile(
+ absl::string_view name) const {
+ std::string buf;
+ RETURN_IF_ERRNO(GetContents(Relpath(name), &buf));
+
+ const std::string alias_path = absl::StrFormat("[cg#%d]/%s", id_, name);
+ std::cerr << absl::StreamFormat("<contents of %s>", alias_path) << std::endl;
+ std::cerr << buf;
+ std::cerr << absl::StreamFormat("<end of %s>", alias_path) << std::endl;
+
+ return buf;
+}
+
+PosixErrorOr<int64_t> Cgroup::ReadIntegerControlFile(
+ absl::string_view name) const {
+ ASSIGN_OR_RETURN_ERRNO(const std::string buf, ReadControlFile(name));
+ ASSIGN_OR_RETURN_ERRNO(const int64_t val, Atoi<int64_t>(buf));
+ return val;
+}
+
+PosixError Cgroup::WriteControlFile(absl::string_view name,
+ const std::string& value) const {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, Open(Relpath(name), O_WRONLY));
+ RETURN_ERROR_IF_SYSCALL_FAIL(WriteFd(fd.get(), value.c_str(), value.size()));
+ return NoError();
+}
+
+PosixError Cgroup::WriteIntegerControlFile(absl::string_view name,
+ int64_t value) const {
+ return WriteControlFile(name, absl::StrCat(value));
+}
+
+PosixErrorOr<absl::flat_hash_set<pid_t>> Cgroup::Procs() const {
+ ASSIGN_OR_RETURN_ERRNO(std::string buf, ReadControlFile("cgroup.procs"));
+ return ParsePIDList(buf);
+}
+
+PosixErrorOr<absl::flat_hash_set<pid_t>> Cgroup::Tasks() const {
+ ASSIGN_OR_RETURN_ERRNO(std::string buf, ReadControlFile("tasks"));
+ return ParsePIDList(buf);
+}
+
+PosixError Cgroup::ContainsCallingProcess() const {
+ ASSIGN_OR_RETURN_ERRNO(const absl::flat_hash_set<pid_t> procs, Procs());
+ ASSIGN_OR_RETURN_ERRNO(const absl::flat_hash_set<pid_t> tasks, Tasks());
+ const pid_t pid = getpid();
+ const pid_t tid = syscall(SYS_gettid);
+ if (!procs.contains(pid)) {
+ return PosixError(
+ ENOENT, absl::StrFormat("Cgroup doesn't contain process %d", pid));
+ }
+ if (!tasks.contains(tid)) {
+ return PosixError(ENOENT,
+ absl::StrFormat("Cgroup doesn't contain task %d", tid));
+ }
+ return NoError();
+}
+
+PosixErrorOr<absl::flat_hash_set<pid_t>> Cgroup::ParsePIDList(
+ absl::string_view data) const {
+ absl::flat_hash_set<pid_t> res;
+ std::vector<absl::string_view> lines = absl::StrSplit(data, '\n');
+ for (const std::string_view& line : lines) {
+ if (line.empty()) {
+ continue;
+ }
+ ASSIGN_OR_RETURN_ERRNO(const int32_t pid, Atoi<int32_t>(line));
+ res.insert(static_cast<pid_t>(pid));
+ }
+ return res;
+}
+
+int64_t Cgroup::next_id_ = 0;
+
+PosixErrorOr<Cgroup> Mounter::MountCgroupfs(std::string mopts) {
+ ASSIGN_OR_RETURN_ERRNO(TempPath mountpoint,
+ TempPath::CreateDirIn(root_.path()));
+ ASSIGN_OR_RETURN_ERRNO(
+ Cleanup mount, Mount("none", mountpoint.path(), "cgroup", 0, mopts, 0));
+ const std::string mountpath = mountpoint.path();
+ std::cerr << absl::StreamFormat(
+ "Mount(\"none\", \"%s\", \"cgroup\", 0, \"%s\", 0) => OK",
+ mountpath, mopts)
+ << std::endl;
+ Cgroup cg = Cgroup(mountpath);
+ mountpoints_[cg.id()] = std::move(mountpoint);
+ mounts_[cg.id()] = std::move(mount);
+ return cg;
+}
+
+PosixError Mounter::Unmount(const Cgroup& c) {
+ auto mount = mounts_.find(c.id());
+ auto mountpoint = mountpoints_.find(c.id());
+
+ if (mount == mounts_.end() || mountpoint == mountpoints_.end()) {
+ return PosixError(
+ ESRCH, absl::StrFormat("No mount found for cgroupfs containing cg#%d",
+ c.id()));
+ }
+
+ std::cerr << absl::StreamFormat("Unmount([cg#%d])", c.id()) << std::endl;
+
+ // Simply delete the entries, their destructors will unmount and delete the
+ // mountpoint. Note the order is important to avoid errors: mount then
+ // mountpoint.
+ mounts_.erase(mount);
+ mountpoints_.erase(mountpoint);
+
+ return NoError();
+}
+
+constexpr char kProcCgroupsHeader[] =
+ "#subsys_name\thierarchy\tnum_cgroups\tenabled";
+
+PosixErrorOr<absl::flat_hash_map<std::string, CgroupsEntry>>
+ProcCgroupsEntries() {
+ std::string content;
+ RETURN_IF_ERRNO(GetContents("/proc/cgroups", &content));
+
+ bool found_header = false;
+ absl::flat_hash_map<std::string, CgroupsEntry> entries;
+ std::vector<std::string> lines = absl::StrSplit(content, '\n');
+ std::cerr << "<contents of /proc/cgroups>" << std::endl;
+ for (const std::string& line : lines) {
+ std::cerr << line << std::endl;
+
+ if (!found_header) {
+ EXPECT_EQ(line, kProcCgroupsHeader);
+ found_header = true;
+ continue;
+ }
+ if (line.empty()) {
+ continue;
+ }
+
+ // Parse a single entry from /proc/cgroups.
+ //
+ // Example entries, fields are tab separated in the real file:
+ //
+ // #subsys_name hierarchy num_cgroups enabled
+ // cpuset 12 35 1
+ // cpu 3 222 1
+ // ^ ^ ^ ^
+ // 0 1 2 3
+
+ CgroupsEntry entry;
+ std::vector<std::string> fields =
+ StrSplit(line, absl::ByAnyChar(": \t"), absl::SkipEmpty());
+
+ entry.subsys_name = fields[0];
+ ASSIGN_OR_RETURN_ERRNO(entry.hierarchy, Atoi<uint32_t>(fields[1]));
+ ASSIGN_OR_RETURN_ERRNO(entry.num_cgroups, Atoi<uint64_t>(fields[2]));
+ ASSIGN_OR_RETURN_ERRNO(const int enabled, Atoi<int>(fields[3]));
+ entry.enabled = enabled != 0;
+
+ entries[entry.subsys_name] = entry;
+ }
+ std::cerr << "<end of /proc/cgroups>" << std::endl;
+
+ return entries;
+}
+
+PosixErrorOr<absl::flat_hash_map<std::string, PIDCgroupEntry>>
+ProcPIDCgroupEntries(pid_t pid) {
+ const std::string path = absl::StrFormat("/proc/%d/cgroup", pid);
+ std::string content;
+ RETURN_IF_ERRNO(GetContents(path, &content));
+
+ absl::flat_hash_map<std::string, PIDCgroupEntry> entries;
+ std::vector<std::string> lines = absl::StrSplit(content, '\n');
+
+ std::cerr << absl::StreamFormat("<contents of %s>", path) << std::endl;
+ for (const std::string& line : lines) {
+ std::cerr << line << std::endl;
+
+ if (line.empty()) {
+ continue;
+ }
+
+ // Parse a single entry from /proc/<pid>/cgroup.
+ //
+ // Example entries:
+ //
+ // 2:cpu:/path/to/cgroup
+ // 1:memory:/
+
+ PIDCgroupEntry entry;
+ std::vector<std::string> fields =
+ absl::StrSplit(line, absl::ByChar(':'), absl::SkipEmpty());
+
+ ASSIGN_OR_RETURN_ERRNO(entry.hierarchy, Atoi<uint32_t>(fields[0]));
+ entry.controllers = fields[1];
+ entry.path = fields[2];
+
+ entries[entry.controllers] = entry;
+ }
+ std::cerr << absl::StreamFormat("<end of %s>", path) << std::endl;
+
+ return entries;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/cgroup_util.h b/test/util/cgroup_util.h
new file mode 100644
index 000000000..c6e4303e1
--- /dev/null
+++ b/test/util/cgroup_util.h
@@ -0,0 +1,119 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_CGROUP_UTIL_H_
+#define GVISOR_TEST_UTIL_CGROUP_UTIL_H_
+
+#include <unistd.h>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "test/util/cleanup.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+
+namespace gvisor {
+namespace testing {
+
+// Cgroup represents a cgroup directory on a mounted cgroupfs.
+class Cgroup {
+ public:
+ Cgroup(std::string path);
+
+ uint64_t id() const { return id_; }
+
+ std::string Relpath(absl::string_view leaf) const {
+ return JoinPath(cgroup_path_, leaf);
+ }
+
+ // Returns the contents of a cgroup control file with the given name.
+ PosixErrorOr<std::string> ReadControlFile(absl::string_view name) const;
+
+ // Reads the contents of a cgroup control with the given name, and attempts
+ // to parse it as an integer.
+ PosixErrorOr<int64_t> ReadIntegerControlFile(absl::string_view name) const;
+
+ // Writes a string to a cgroup control file.
+ PosixError WriteControlFile(absl::string_view name,
+ const std::string& value) const;
+
+ // Writes an integer value to a cgroup control file.
+ PosixError WriteIntegerControlFile(absl::string_view name,
+ int64_t value) const;
+
+ // Returns the thread ids of the leaders of thread groups managed by this
+ // cgroup.
+ PosixErrorOr<absl::flat_hash_set<pid_t>> Procs() const;
+
+ PosixErrorOr<absl::flat_hash_set<pid_t>> Tasks() const;
+
+ // ContainsCallingProcess checks whether the calling process is part of the
+ PosixError ContainsCallingProcess() const;
+
+ private:
+ PosixErrorOr<absl::flat_hash_set<pid_t>> ParsePIDList(
+ absl::string_view data) const;
+
+ static int64_t next_id_;
+ int64_t id_;
+ const std::string cgroup_path_;
+};
+
+// Mounter is a utility for creating cgroupfs mounts. It automatically manages
+// the lifetime of created mounts.
+class Mounter {
+ public:
+ Mounter(TempPath root) : root_(std::move(root)) {}
+
+ PosixErrorOr<Cgroup> MountCgroupfs(std::string mopts);
+
+ PosixError Unmount(const Cgroup& c);
+
+ private:
+ // The destruction order of these members avoids errors during cleanup. We
+ // first unmount (by executing the mounts_ cleanups), then delete the
+ // mountpoint subdirs, then delete the root.
+ TempPath root_;
+ absl::flat_hash_map<int64_t, TempPath> mountpoints_;
+ absl::flat_hash_map<int64_t, Cleanup> mounts_;
+};
+
+// Represents a line from /proc/cgroups.
+struct CgroupsEntry {
+ std::string subsys_name;
+ uint32_t hierarchy;
+ uint64_t num_cgroups;
+ bool enabled;
+};
+
+// Returns a parsed representation of /proc/cgroups.
+PosixErrorOr<absl::flat_hash_map<std::string, CgroupsEntry>>
+ProcCgroupsEntries();
+
+// Represents a line from /proc/<pid>/cgroup.
+struct PIDCgroupEntry {
+ uint32_t hierarchy;
+ std::string controllers;
+ std::string path;
+};
+
+// Returns a parsed representation of /proc/<pid>/cgroup.
+PosixErrorOr<absl::flat_hash_map<std::string, PIDCgroupEntry>>
+ProcPIDCgroupEntries(pid_t pid);
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_CGROUP_UTIL_H_
diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc
index 5f1ce0d8a..483ae848d 100644
--- a/test/util/fs_util.cc
+++ b/test/util/fs_util.cc
@@ -28,6 +28,8 @@
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
#include "test/util/cleanup.h"
#include "test/util/file_descriptor.h"
#include "test/util/posix_error.h"
@@ -366,6 +368,48 @@ PosixErrorOr<std::vector<std::string>> ListDir(absl::string_view abspath,
return files;
}
+PosixError DirContains(absl::string_view path,
+ const std::vector<std::string>& expect,
+ const std::vector<std::string>& exclude) {
+ ASSIGN_OR_RETURN_ERRNO(auto listing, ListDir(path, false));
+
+ for (auto& expected_entry : expect) {
+ auto cursor = std::find(listing.begin(), listing.end(), expected_entry);
+ if (cursor == listing.end()) {
+ return PosixError(ENOENT, absl::StrFormat("Failed to find '%s' in '%s'",
+ expected_entry, path));
+ }
+ }
+ for (auto& excluded_entry : exclude) {
+ auto cursor = std::find(listing.begin(), listing.end(), excluded_entry);
+ if (cursor != listing.end()) {
+ return PosixError(ENOENT, absl::StrCat("File '", excluded_entry,
+ "' found in path '", path, "'"));
+ }
+ }
+ return NoError();
+}
+
+PosixError EventuallyDirContains(absl::string_view path,
+ const std::vector<std::string>& expect,
+ const std::vector<std::string>& exclude) {
+ constexpr int kRetryCount = 100;
+ const absl::Duration kRetryDelay = absl::Milliseconds(100);
+
+ for (int i = 0; i < kRetryCount; ++i) {
+ auto res = DirContains(path, expect, exclude);
+ if (res.ok()) {
+ return res;
+ }
+ if (i < kRetryCount - 1) {
+ // Sleep if this isn't the final iteration.
+ absl::SleepFor(kRetryDelay);
+ }
+ }
+ return PosixError(ETIMEDOUT,
+ "Timed out while waiting for directory to contain files ");
+}
+
PosixError RecursivelyDelete(absl::string_view path, int* undeleted_dirs,
int* undeleted_files) {
ASSIGN_OR_RETURN_ERRNO(bool exists, Exists(path));
diff --git a/test/util/fs_util.h b/test/util/fs_util.h
index 2190c3bca..bb2d1d3c8 100644
--- a/test/util/fs_util.h
+++ b/test/util/fs_util.h
@@ -129,6 +129,18 @@ PosixError WalkTree(
PosixErrorOr<std::vector<std::string>> ListDir(absl::string_view abspath,
bool skipdots);
+// Check that a directory contains children nodes named in expect, and does not
+// contain any children nodes named in exclude.
+PosixError DirContains(absl::string_view path,
+ const std::vector<std::string>& expect,
+ const std::vector<std::string>& exclude);
+
+// Same as DirContains, but adds a retry. Suitable for checking a directory
+// being modified asynchronously.
+PosixError EventuallyDirContains(absl::string_view path,
+ const std::vector<std::string>& expect,
+ const std::vector<std::string>& exclude);
+
// Attempt to recursively delete a directory or file. Returns an error and
// the number of undeleted directories and files. If either
// undeleted_dirs or undeleted_files is nullptr then it will not be used.
diff --git a/test/util/save_util.cc b/test/util/save_util.cc
index 59d47e06e..3e724d99b 100644
--- a/test/util/save_util.cc
+++ b/test/util/save_util.cc
@@ -27,23 +27,13 @@ namespace gvisor {
namespace testing {
namespace {
-std::atomic<absl::optional<bool>> cooperative_save_present;
-std::atomic<absl::optional<bool>> random_save_present;
+std::atomic<absl::optional<bool>> save_present;
-bool CooperativeSavePresent() {
- auto present = cooperative_save_present.load();
+bool SavePresent() {
+ auto present = save_present.load();
if (!present.has_value()) {
- present = getenv("GVISOR_COOPERATIVE_SAVE_TEST") != nullptr;
- cooperative_save_present.store(present);
- }
- return present.value();
-}
-
-bool RandomSavePresent() {
- auto present = random_save_present.load();
- if (!present.has_value()) {
- present = getenv("GVISOR_RANDOM_SAVE_TEST") != nullptr;
- random_save_present.store(present);
+ present = getenv("GVISOR_SAVE_TEST") != nullptr;
+ save_present.store(present);
}
return present.value();
}
@@ -52,12 +42,10 @@ std::atomic<int> save_disable;
} // namespace
-bool IsRunningWithSaveRestore() {
- return CooperativeSavePresent() || RandomSavePresent();
-}
+bool IsRunningWithSaveRestore() { return SavePresent(); }
void MaybeSave() {
- if (CooperativeSavePresent() && save_disable.load() == 0) {
+ if (SavePresent() && save_disable.load() == 0) {
internal::DoCooperativeSave();
}
}
diff --git a/tools/BUILD b/tools/BUILD
index faf310676..3861ff2a5 100644
--- a/tools/BUILD
+++ b/tools/BUILD
@@ -9,3 +9,11 @@ bzl_library(
"//:sandbox",
],
)
+
+bzl_library(
+ name = "deps_bzl",
+ srcs = ["deps.bzl"],
+ visibility = [
+ "//:sandbox",
+ ],
+)
diff --git a/tools/bazeldefs/go.bzl b/tools/bazeldefs/go.bzl
index bcd8cffe7..d16376032 100644
--- a/tools/bazeldefs/go.bzl
+++ b/tools/bazeldefs/go.bzl
@@ -132,7 +132,7 @@ def go_context(ctx, goos = None, goarch = None, std = False):
runfiles = depset([go_ctx.go] + go_ctx.sdk.srcs + go_ctx.sdk.tools + go_ctx.stdlib.libs),
goos = go_ctx.sdk.goos,
goarch = go_ctx.sdk.goarch,
- tags = go_ctx.tags,
+ gotags = go_ctx.tags,
)
def select_goarch():
diff --git a/tools/bigquery/BUILD b/tools/bigquery/BUILD
index 1cea9e1c9..81994f954 100644
--- a/tools/bigquery/BUILD
+++ b/tools/bigquery/BUILD
@@ -12,5 +12,6 @@ go_library(
deps = [
"@com_google_cloud_go_bigquery//:go_default_library",
"@org_golang_google_api//option:go_default_library",
+ "@org_golang_x_oauth2//:go_default_library",
],
)
diff --git a/tools/bigquery/bigquery.go b/tools/bigquery/bigquery.go
index a4ca93ec2..935154acc 100644
--- a/tools/bigquery/bigquery.go
+++ b/tools/bigquery/bigquery.go
@@ -119,6 +119,14 @@ func NewBenchmark(name string, iters int) *Benchmark {
}
}
+// NewBenchmarkWithMetric creates a new sending to BigQuery, initialized with a
+// single iteration and single metric.
+func NewBenchmarkWithMetric(name, metric, unit string, value float64) *Benchmark {
+ b := NewBenchmark(name, 1)
+ b.AddMetric(metric, unit, value)
+ return b
+}
+
// NewSuite initializes a new Suite.
func NewSuite(name string, official bool) *Suite {
return &Suite{
diff --git a/tools/deps.bzl b/tools/deps.bzl
new file mode 100644
index 000000000..ed1135a9e
--- /dev/null
+++ b/tools/deps.bzl
@@ -0,0 +1,114 @@
+"""Rules for dependency checking."""
+
+# DepsInfo provides a list of dependencies found when building a target.
+DepsInfo = provider(
+ "lists dependencies encountered while building",
+ fields = {
+ "nodes": "a dict from targets to a list of their dependencies",
+ },
+)
+
+def _deps_check_impl(target, ctx):
+ # Check the target's dependencies and add any of our own deps.
+ deps = []
+ for dep in ctx.rule.attr.deps:
+ deps.append(dep)
+ nodes = {}
+ if len(deps) != 0:
+ nodes[target] = deps
+
+ # Keep and propagate each dep's providers.
+ for dep in ctx.rule.attr.deps:
+ nodes.update(dep[DepsInfo].nodes)
+
+ return [DepsInfo(nodes = nodes)]
+
+_deps_check = aspect(
+ implementation = _deps_check_impl,
+ attr_aspects = ["deps"],
+)
+
+def _is_allowed(target, allowlist, prefixes):
+ # Check for allowed prefixes.
+ for prefix in prefixes:
+ workspace, pfx = prefix.split("//", 1)
+ if len(workspace) > 0 and workspace[0] == "@":
+ workspace = workspace[1:]
+ if target.workspace_name == workspace and target.package.startswith(pfx):
+ return True
+
+ # Check the allowlist.
+ for allowed in allowlist:
+ if target == allowed.label:
+ return True
+
+ return False
+
+def _deps_test_impl(ctx):
+ nodes = {}
+ for target in ctx.attr.targets:
+ for (node_target, node_deps) in target[DepsInfo].nodes.items():
+ # Ignore any disallowed targets. This generates more useful error
+ # messages. Consider the case where A dependes on B and B depends
+ # on C, and both B and C are disallowed. Avoid emitting an error
+ # that B depends on C, when the real issue is that A depends on B.
+ if not _is_allowed(node_target.label, ctx.attr.allowed, ctx.attr.allowed_prefixes) and node_target.label != target.label:
+ continue
+ bad_deps = []
+ for dep in node_deps:
+ if not _is_allowed(dep.label, ctx.attr.allowed, ctx.attr.allowed_prefixes):
+ bad_deps.append(dep)
+ if len(bad_deps) > 0:
+ nodes[node_target] = bad_deps
+
+ # If there aren't any violations, write a passing test.
+ if len(nodes) == 0:
+ ctx.actions.write(
+ output = ctx.outputs.executable,
+ content = "#!/bin/bash\n\nexit 0\n",
+ )
+ return []
+
+ # If we're here, we've found at least one violation.
+ script_lines = [
+ "#!/bin/bash",
+ "echo Invalid dependencies found. If you\\'re sure you want to add dependencies,",
+ "echo modify this target.",
+ "echo",
+ ]
+
+ # List the violations.
+ for target, deps in nodes.items():
+ script_lines.append(
+ 'echo "{target} depends on:"'.format(target = target.label),
+ )
+ for dep in deps:
+ script_lines.append('echo "\t{dep}"'.format(dep = dep.label))
+
+ # The test must fail.
+ script_lines.append("exit 1\n")
+
+ ctx.actions.write(
+ output = ctx.outputs.executable,
+ content = "\n".join(script_lines),
+ )
+ return []
+
+# Checks that library and its deps only depends on gVisor and an allowlist of
+# other dependencies.
+deps_test = rule(
+ implementation = _deps_test_impl,
+ attrs = {
+ "targets": attr.label_list(
+ doc = "The targets to check the transitive dependencies of.",
+ aspects = [_deps_check],
+ ),
+ "allowed": attr.label_list(
+ doc = "The allowed dependency targets.",
+ ),
+ "allowed_prefixes": attr.string_list(
+ doc = "Any packages beginning with these prefixes are allowed.",
+ ),
+ },
+ test = True,
+)
diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl
index f44f83eab..9f620cb76 100644
--- a/tools/go_marshal/defs.bzl
+++ b/tools/go_marshal/defs.bzl
@@ -57,8 +57,7 @@ go_marshal = rule(
# marshal_deps are the dependencies requied by generated code.
marshal_deps = [
"//pkg/gohacks",
- "//pkg/safecopy",
- "//pkg/usermem",
+ "//pkg/hostarch",
"//pkg/marshal",
]
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
index 39394d2a7..00961c90d 100644
--- a/tools/go_marshal/gomarshal/generator.go
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -112,10 +112,8 @@ func NewGenerator(srcs []string, out, outTest, outTestUnconditional, pkg string,
g.imports.add("runtime")
g.imports.add("unsafe")
g.imports.add("gvisor.dev/gvisor/pkg/gohacks")
- g.imports.add("gvisor.dev/gvisor/pkg/safecopy")
- g.imports.add("gvisor.dev/gvisor/pkg/usermem")
+ g.imports.add("gvisor.dev/gvisor/pkg/hostarch")
g.imports.add("gvisor.dev/gvisor/pkg/marshal")
-
return &g, nil
}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
index 65f5ea34d..3e643e77f 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces.go
@@ -120,16 +120,16 @@ func (g *interfaceGenerator) marshalScalar(accessor, typ, bufVar string) {
g.emit("%s[0] = byte(%s)\n", bufVar, accessor)
g.shift(bufVar, 1)
case "int16", "uint16":
- g.recordUsedImport("usermem")
- g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(%s))\n", bufVar, accessor)
+ g.recordUsedImport("hostarch")
+ g.emit("hostarch.ByteOrder.PutUint16(%s[:2], uint16(%s))\n", bufVar, accessor)
g.shift(bufVar, 2)
case "int32", "uint32":
- g.recordUsedImport("usermem")
- g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(%s))\n", bufVar, accessor)
+ g.recordUsedImport("hostarch")
+ g.emit("hostarch.ByteOrder.PutUint32(%s[:4], uint32(%s))\n", bufVar, accessor)
g.shift(bufVar, 4)
case "int64", "uint64":
- g.recordUsedImport("usermem")
- g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(%s))\n", bufVar, accessor)
+ g.recordUsedImport("hostarch")
+ g.emit("hostarch.ByteOrder.PutUint64(%s[:8], uint64(%s))\n", bufVar, accessor)
g.shift(bufVar, 8)
default:
g.emit("%s.MarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
@@ -147,16 +147,16 @@ func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) {
g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar)
g.shift(bufVar, 1)
case "int16", "uint16":
- g.recordUsedImport("usermem")
- g.emit("%s = %s(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar)
+ g.recordUsedImport("hostarch")
+ g.emit("%s = %s(hostarch.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar)
g.shift(bufVar, 2)
case "int32", "uint32":
- g.recordUsedImport("usermem")
- g.emit("%s = %s(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar)
+ g.recordUsedImport("hostarch")
+ g.emit("%s = %s(hostarch.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar)
g.shift(bufVar, 4)
case "int64", "uint64":
- g.recordUsedImport("usermem")
- g.emit("%s = %s(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar)
+ g.recordUsedImport("hostarch")
+ g.emit("%s = %s(hostarch.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar)
g.shift(bufVar, 8)
default:
g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
index 7525b52da..bd7741ae5 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
@@ -33,13 +33,13 @@ func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType
}
func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *ast.ArrayType, elt *ast.Ident) {
+ g.recordUsedImport("gohacks")
+ g.recordUsedImport("hostarch")
g.recordUsedImport("io")
g.recordUsedImport("marshal")
g.recordUsedImport("reflect")
g.recordUsedImport("runtime")
- g.recordUsedImport("safecopy")
g.recordUsedImport("unsafe")
- g.recordUsedImport("usermem")
lenExpr := g.arrayLenExpr(a)
@@ -89,20 +89,20 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *as
g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
g.inIndent(func() {
- g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&%s[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r)
})
g.emit("}\n\n")
g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
g.inIndent(func() {
- g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r)
})
g.emit("}\n\n")
g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
@@ -114,7 +114,7 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *as
g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r)
})
@@ -122,7 +122,7 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *as
g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go b/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go
index b1a8622cd..345020ddc 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go
@@ -46,8 +46,8 @@ func (g *interfaceGenerator) emitMarshallableForDynamicType() {
g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
g.emit("//go:nosplit\n")
g.recordUsedImport("marshal")
- g.recordUsedImport("usermem")
- g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.recordUsedImport("hostarch")
+ g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
g.emit("buf := cc.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r)
@@ -59,8 +59,8 @@ func (g *interfaceGenerator) emitMarshallableForDynamicType() {
g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
g.emit("//go:nosplit\n")
g.recordUsedImport("marshal")
- g.recordUsedImport("usermem")
- g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.recordUsedImport("hostarch")
+ g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r)
})
@@ -69,8 +69,8 @@ func (g *interfaceGenerator) emitMarshallableForDynamicType() {
g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
g.emit("//go:nosplit\n")
g.recordUsedImport("marshal")
- g.recordUsedImport("usermem")
- g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.recordUsedImport("hostarch")
+ g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
g.emit("buf := cc.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r)
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
index 7edaf666c..ba4b7324e 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
@@ -29,14 +29,14 @@ func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string
case "int8", "uint8", "byte":
g.emit("%s[0] = byte(*%s)\n", bufVar, accessor)
case "int16", "uint16":
- g.recordUsedImport("usermem")
- g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor)
+ g.recordUsedImport("hostarch")
+ g.emit("hostarch.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor)
case "int32", "uint32":
- g.recordUsedImport("usermem")
- g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor)
+ g.recordUsedImport("hostarch")
+ g.emit("hostarch.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor)
case "int64", "uint64":
- g.recordUsedImport("usermem")
- g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor)
+ g.recordUsedImport("hostarch")
+ g.emit("hostarch.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor)
default:
g.emit("// Explicilty cast to the underlying type before dispatching to\n")
g.emit("// MarshalBytes, so we don't recursively call %s.MarshalBytes\n", accessor)
@@ -53,14 +53,14 @@ func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typ
case "int8", "uint8":
g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar)
case "int16", "uint16":
- g.recordUsedImport("usermem")
- g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar)
+ g.recordUsedImport("hostarch")
+ g.emit("*%s = %s(%s(hostarch.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar)
case "int32", "uint32":
- g.recordUsedImport("usermem")
- g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar)
+ g.recordUsedImport("hostarch")
+ g.emit("*%s = %s(%s(hostarch.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar)
case "int64", "uint64":
- g.recordUsedImport("usermem")
- g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar)
+ g.recordUsedImport("hostarch")
+ g.emit("*%s = %s(%s(hostarch.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar)
default:
g.emit("// Explicilty cast to the underlying type before dispatching to\n")
g.emit("// UnmarshalBytes, so we don't recursively call %s.UnmarshalBytes\n", accessor)
@@ -95,13 +95,13 @@ func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) {
// newtypes are always packed, so we can omit the various fallbacks required for
// non-packed structs.
func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) {
+ g.recordUsedImport("gohacks")
+ g.recordUsedImport("hostarch")
g.recordUsedImport("io")
g.recordUsedImport("marshal")
g.recordUsedImport("reflect")
g.recordUsedImport("runtime")
- g.recordUsedImport("safecopy")
g.recordUsedImport("unsafe")
- g.recordUsedImport("usermem")
g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
g.emit("//go:nosplit\n")
@@ -141,20 +141,20 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident)
g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
g.inIndent(func() {
- g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(%s.SizeBytes()))\n", g.r, g.r)
})
g.emit("}\n\n")
g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
g.inIndent(func() {
- g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r)
})
g.emit("}\n\n")
g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
@@ -166,7 +166,7 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident)
g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r)
})
@@ -174,7 +174,7 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident)
g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
@@ -199,7 +199,7 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident)
func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Ident, slice *sliceAPI) {
g.recordUsedImport("marshal")
- g.recordUsedImport("usermem")
+ g.recordUsedImport("hostarch")
g.recordUsedImport("reflect")
g.recordUsedImport("runtime")
g.recordUsedImport("unsafe")
@@ -211,7 +211,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Id
g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, eltType)
g.emit("//go:nosplit\n")
- g.emit("func Copy%sIn(cc marshal.CopyContext, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, eltType)
+ g.emit("func Copy%sIn(cc marshal.CopyContext, addr hostarch.Addr, dst []%s) (int, error) {\n", slice.ident, eltType)
g.inIndent(func() {
g.emit("count := len(dst)\n")
g.emit("if count == 0 {\n")
@@ -231,7 +231,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Id
g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, eltType)
g.emit("//go:nosplit\n")
- g.emit("func Copy%sOut(cc marshal.CopyContext, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, eltType)
+ g.emit("func Copy%sOut(cc marshal.CopyContext, addr hostarch.Addr, src []%s) (int, error) {\n", slice.ident, eltType)
g.inIndent(func() {
g.emit("count := len(src)\n")
g.emit("if count == 0 {\n")
@@ -260,11 +260,9 @@ func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Id
g.emit("}\n")
g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
- g.emitNoEscapeSliceDataPointer("&src", "val")
-
- g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n")
- g.emitKeepAlive("src")
- g.emit("return length, err\n")
+ g.emit("dst = dst[:size*count]\n")
+ g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst)))\n")
+ g.emit("return size*count, nil\n")
})
g.emit("}\n\n")
@@ -279,11 +277,9 @@ func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Id
g.emit("}\n")
g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
- g.emitNoEscapeSliceDataPointer("&dst", "val")
-
- g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n")
- g.emitKeepAlive("dst")
- g.emit("return length, err\n")
+ g.emit("src = src[:(size*count)]\n")
+ g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src)))\n")
+ g.emit("return size*count, nil\n")
})
g.emit("}\n\n")
}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
index 5f6306b8f..4c47218f1 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces_struct.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
@@ -270,18 +270,18 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
g.emit("%s.MarshalBytes(dst)\n", g.r)
}
if thisPacked {
- g.recordUsedImport("safecopy")
+ g.recordUsedImport("gohacks")
g.recordUsedImport("unsafe")
if cond, ok := g.areFieldsPackedExpression(); ok {
g.emit("if %s {\n", cond)
g.inIndent(func() {
- g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(%s.SizeBytes()))\n", g.r, g.r)
})
g.emit("} else {\n")
g.inIndent(fallback)
g.emit("}\n")
} else {
- g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(%s.SizeBytes()))\n", g.r, g.r)
}
} else {
fallback()
@@ -297,30 +297,28 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
g.emit("%s.UnmarshalBytes(src)\n", g.r)
}
if thisPacked {
- g.recordUsedImport("safecopy")
- g.recordUsedImport("unsafe")
+ g.recordUsedImport("gohacks")
if cond, ok := g.areFieldsPackedExpression(); ok {
g.emit("if %s {\n", cond)
g.inIndent(func() {
- g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r)
})
g.emit("} else {\n")
g.inIndent(fallback)
g.emit("}\n")
} else {
- g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r)
}
} else {
fallback()
}
})
g.emit("}\n\n")
-
g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
g.emit("//go:nosplit\n")
g.recordUsedImport("marshal")
- g.recordUsedImport("usermem")
- g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.recordUsedImport("hostarch")
+ g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
fallback := func() {
g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
@@ -352,8 +350,8 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
g.emit("//go:nosplit\n")
g.recordUsedImport("marshal")
- g.recordUsedImport("usermem")
- g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.recordUsedImport("hostarch")
+ g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r)
})
@@ -362,8 +360,8 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
g.emit("//go:nosplit\n")
g.recordUsedImport("marshal")
- g.recordUsedImport("usermem")
- g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.recordUsedImport("hostarch")
+ g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName())
g.inIndent(func() {
fallback := func() {
g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
@@ -436,10 +434,10 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType,
}
g.recordUsedImport("marshal")
- g.recordUsedImport("usermem")
+ g.recordUsedImport("hostarch")
g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, g.typeName())
- g.emit("func Copy%sIn(cc marshal.CopyContext, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, g.typeName())
+ g.emit("func Copy%sIn(cc marshal.CopyContext, addr hostarch.Addr, dst []%s) (int, error) {\n", slice.ident, g.typeName())
g.inIndent(func() {
g.emit("count := len(dst)\n")
g.emit("if count == 0 {\n")
@@ -496,7 +494,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType,
g.emit("}\n\n")
g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, g.typeName())
- g.emit("func Copy%sOut(cc marshal.CopyContext, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, g.typeName())
+ g.emit("func Copy%sOut(cc marshal.CopyContext, addr hostarch.Addr, src []%s) (int, error) {\n", slice.ident, g.typeName())
g.inIndent(func() {
g.emit("count := len(src)\n")
g.emit("if count == 0 {\n")
@@ -561,16 +559,15 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType,
g.recordUsedImport("reflect")
g.recordUsedImport("runtime")
g.recordUsedImport("unsafe")
+ g.recordUsedImport("gohacks")
if _, ok := g.areFieldsPackedExpression(); ok {
g.emit("if !src[0].Packed() {\n")
g.inIndent(fallback)
g.emit("}\n\n")
}
- g.emitNoEscapeSliceDataPointer("&src", "val")
-
- g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n")
- g.emitKeepAlive("src")
- g.emit("return length, err\n")
+ g.emit("dst = dst[:size*count]\n")
+ g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst)))\n")
+ g.emit("return size * count, nil\n")
} else {
fallback()
}
@@ -598,19 +595,19 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType,
g.emit("return size * count, nil\n")
}
if thisPacked {
+ g.recordUsedImport("gohacks")
g.recordUsedImport("reflect")
g.recordUsedImport("runtime")
- g.recordUsedImport("unsafe")
if _, ok := g.areFieldsPackedExpression(); ok {
g.emit("if !dst[0].Packed() {\n")
g.inIndent(fallback)
g.emit("}\n\n")
}
- g.emitNoEscapeSliceDataPointer("&dst", "val")
- g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n")
- g.emitKeepAlive("dst")
- g.emit("return length, err\n")
+ g.emit("src = src[:(size*count)]\n")
+ g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src)))\n")
+
+ g.emit("return count*size, nil\n")
} else {
fallback()
}
diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go
index 6cf00843f..8f93a1de5 100644
--- a/tools/go_marshal/gomarshal/generator_tests.go
+++ b/tools/go_marshal/gomarshal/generator_tests.go
@@ -32,7 +32,7 @@ var standardImports = []string{
var sliceAPIImports = []string{
"encoding/binary",
- "gvisor.dev/gvisor/pkg/usermem",
+ "gvisor.dev/gvisor/pkg/hostarch",
}
type testGenerator struct {
@@ -143,7 +143,7 @@ func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() {
}
func (g *testGenerator) emitTestMarshalUnmarshalSlicePreservesData(slice *sliceAPI) {
- for _, name := range []string{"binary", "usermem"} {
+ for _, name := range []string{"binary", "hostarch"} {
if !g.imports.markUsed(name) {
panic(fmt.Sprintf("Generated test for '%s' referenced a non-existent import with local name '%s'", g.typeName(), name))
}
@@ -155,7 +155,7 @@ func (g *testGenerator) emitTestMarshalUnmarshalSlicePreservesData(slice *sliceA
g.emit("size := (*%s)(nil).SizeBytes() * len(x)\n", g.typeName())
g.emit("buf := bytes.NewBuffer(make([]byte, size))\n")
g.emit("buf.Reset()\n")
- g.emit("if err := binary.Write(buf, usermem.ByteOrder, x[:]); err != nil {\n")
+ g.emit("if err := binary.Write(buf, hostarch.ByteOrder, x[:]); err != nil {\n")
g.inIndent(func() {
g.emit("t.Fatal(fmt.Sprintf(\"binary.Write failed: %v\", err))\n")
})
diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD
index 5bceacd32..e872560a9 100644
--- a/tools/go_marshal/test/BUILD
+++ b/tools/go_marshal/test/BUILD
@@ -15,7 +15,7 @@ go_test(
deps = [
":test",
"//pkg/binary",
- "//pkg/usermem",
+ "//pkg/hostarch",
"//tools/go_marshal/analysis",
],
)
@@ -41,6 +41,7 @@ go_test(
srcs = ["marshal_test.go"],
deps = [
":test",
+ "//pkg/hostarch",
"//pkg/marshal",
"//pkg/marshal/primitive",
"//pkg/syserror",
diff --git a/tools/go_marshal/test/benchmark_test.go b/tools/go_marshal/test/benchmark_test.go
index 224d308c7..16f478ff7 100644
--- a/tools/go_marshal/test/benchmark_test.go
+++ b/tools/go_marshal/test/benchmark_test.go
@@ -22,7 +22,7 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/tools/go_marshal/analysis"
"gvisor.dev/gvisor/tools/go_marshal/test"
)
@@ -39,10 +39,10 @@ func BenchmarkEncodingBinary(b *testing.B) {
for n := 0; n < b.N; n++ {
buf := bytes.NewBuffer(make([]byte, size))
buf.Reset()
- if err := encbin.Write(buf, usermem.ByteOrder, &s1); err != nil {
+ if err := encbin.Write(buf, hostarch.ByteOrder, &s1); err != nil {
b.Error("Write:", err)
}
- if err := encbin.Read(buf, usermem.ByteOrder, &s2); err != nil {
+ if err := encbin.Read(buf, hostarch.ByteOrder, &s2); err != nil {
b.Error("Read:", err)
}
}
@@ -66,8 +66,8 @@ func BenchmarkBinary(b *testing.B) {
for n := 0; n < b.N; n++ {
buf := make([]byte, 0, size)
- buf = binary.Marshal(buf, usermem.ByteOrder, &s1)
- binary.Unmarshal(buf, usermem.ByteOrder, &s2)
+ buf = binary.Marshal(buf, hostarch.ByteOrder, &s1)
+ binary.Unmarshal(buf, hostarch.ByteOrder, &s2)
}
b.StopTimer()
@@ -89,42 +89,42 @@ func BenchmarkMarshalManual(b *testing.B) {
buf := make([]byte, 0, s1.SizeBytes())
// Marshal
- buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Dev)
- buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Ino)
- buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Nlink)
- buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.Mode)
- buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.UID)
- buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.GID)
- buf = binary.AppendUint32(buf, usermem.ByteOrder, 0)
- buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Rdev)
- buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Size))
- buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Blksize))
- buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Blocks))
- buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.ATime.Sec))
- buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.ATime.Nsec))
- buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.MTime.Sec))
- buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.MTime.Nsec))
- buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.CTime.Sec))
- buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.CTime.Nsec))
+ buf = binary.AppendUint64(buf, hostarch.ByteOrder, s1.Dev)
+ buf = binary.AppendUint64(buf, hostarch.ByteOrder, s1.Ino)
+ buf = binary.AppendUint64(buf, hostarch.ByteOrder, s1.Nlink)
+ buf = binary.AppendUint32(buf, hostarch.ByteOrder, s1.Mode)
+ buf = binary.AppendUint32(buf, hostarch.ByteOrder, s1.UID)
+ buf = binary.AppendUint32(buf, hostarch.ByteOrder, s1.GID)
+ buf = binary.AppendUint32(buf, hostarch.ByteOrder, 0)
+ buf = binary.AppendUint64(buf, hostarch.ByteOrder, s1.Rdev)
+ buf = binary.AppendUint64(buf, hostarch.ByteOrder, uint64(s1.Size))
+ buf = binary.AppendUint64(buf, hostarch.ByteOrder, uint64(s1.Blksize))
+ buf = binary.AppendUint64(buf, hostarch.ByteOrder, uint64(s1.Blocks))
+ buf = binary.AppendUint64(buf, hostarch.ByteOrder, uint64(s1.ATime.Sec))
+ buf = binary.AppendUint64(buf, hostarch.ByteOrder, uint64(s1.ATime.Nsec))
+ buf = binary.AppendUint64(buf, hostarch.ByteOrder, uint64(s1.MTime.Sec))
+ buf = binary.AppendUint64(buf, hostarch.ByteOrder, uint64(s1.MTime.Nsec))
+ buf = binary.AppendUint64(buf, hostarch.ByteOrder, uint64(s1.CTime.Sec))
+ buf = binary.AppendUint64(buf, hostarch.ByteOrder, uint64(s1.CTime.Nsec))
// Unmarshal
- s2.Dev = usermem.ByteOrder.Uint64(buf[0:8])
- s2.Ino = usermem.ByteOrder.Uint64(buf[8:16])
- s2.Nlink = usermem.ByteOrder.Uint64(buf[16:24])
- s2.Mode = usermem.ByteOrder.Uint32(buf[24:28])
- s2.UID = usermem.ByteOrder.Uint32(buf[28:32])
- s2.GID = usermem.ByteOrder.Uint32(buf[32:36])
+ s2.Dev = hostarch.ByteOrder.Uint64(buf[0:8])
+ s2.Ino = hostarch.ByteOrder.Uint64(buf[8:16])
+ s2.Nlink = hostarch.ByteOrder.Uint64(buf[16:24])
+ s2.Mode = hostarch.ByteOrder.Uint32(buf[24:28])
+ s2.UID = hostarch.ByteOrder.Uint32(buf[28:32])
+ s2.GID = hostarch.ByteOrder.Uint32(buf[32:36])
// Padding: buf[36:40]
- s2.Rdev = usermem.ByteOrder.Uint64(buf[40:48])
- s2.Size = int64(usermem.ByteOrder.Uint64(buf[48:56]))
- s2.Blksize = int64(usermem.ByteOrder.Uint64(buf[56:64]))
- s2.Blocks = int64(usermem.ByteOrder.Uint64(buf[64:72]))
- s2.ATime.Sec = int64(usermem.ByteOrder.Uint64(buf[72:80]))
- s2.ATime.Nsec = int64(usermem.ByteOrder.Uint64(buf[80:88]))
- s2.MTime.Sec = int64(usermem.ByteOrder.Uint64(buf[88:96]))
- s2.MTime.Nsec = int64(usermem.ByteOrder.Uint64(buf[96:104]))
- s2.CTime.Sec = int64(usermem.ByteOrder.Uint64(buf[104:112]))
- s2.CTime.Nsec = int64(usermem.ByteOrder.Uint64(buf[112:120]))
+ s2.Rdev = hostarch.ByteOrder.Uint64(buf[40:48])
+ s2.Size = int64(hostarch.ByteOrder.Uint64(buf[48:56]))
+ s2.Blksize = int64(hostarch.ByteOrder.Uint64(buf[56:64]))
+ s2.Blocks = int64(hostarch.ByteOrder.Uint64(buf[64:72]))
+ s2.ATime.Sec = int64(hostarch.ByteOrder.Uint64(buf[72:80]))
+ s2.ATime.Nsec = int64(hostarch.ByteOrder.Uint64(buf[80:88]))
+ s2.MTime.Sec = int64(hostarch.ByteOrder.Uint64(buf[88:96]))
+ s2.MTime.Nsec = int64(hostarch.ByteOrder.Uint64(buf[96:104]))
+ s2.CTime.Sec = int64(hostarch.ByteOrder.Uint64(buf[104:112]))
+ s2.CTime.Nsec = int64(hostarch.ByteOrder.Uint64(buf[112:120]))
}
b.StopTimer()
@@ -187,8 +187,8 @@ func BenchmarkBinarySlice(b *testing.B) {
for n := 0; n < b.N; n++ {
buf := make([]byte, 0, size)
- buf = binary.Marshal(buf, usermem.ByteOrder, &s1)
- binary.Unmarshal(buf, usermem.ByteOrder, &s2)
+ buf = binary.Marshal(buf, hostarch.ByteOrder, &s1)
+ binary.Unmarshal(buf, hostarch.ByteOrder, &s2)
}
b.StopTimer()
diff --git a/tools/go_marshal/test/escape/BUILD b/tools/go_marshal/test/escape/BUILD
index 2981ef196..62e0b4665 100644
--- a/tools/go_marshal/test/escape/BUILD
+++ b/tools/go_marshal/test/escape/BUILD
@@ -7,8 +7,8 @@ go_library(
testonly = 1,
srcs = ["escape.go"],
deps = [
+ "//pkg/hostarch",
"//pkg/marshal",
- "//pkg/usermem",
"//tools/go_marshal/test",
],
)
diff --git a/tools/go_marshal/test/escape/escape.go b/tools/go_marshal/test/escape/escape.go
index df14ae98e..1ac606862 100644
--- a/tools/go_marshal/test/escape/escape.go
+++ b/tools/go_marshal/test/escape/escape.go
@@ -16,8 +16,8 @@
package escape
import (
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
- "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/tools/go_marshal/test"
)
@@ -29,21 +29,21 @@ func (*dummyCopyContext) CopyScratchBuffer(size int) []byte {
return make([]byte, size)
}
-func (*dummyCopyContext) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) {
+func (*dummyCopyContext) CopyOutBytes(addr hostarch.Addr, b []byte) (int, error) {
return len(b), nil
}
-func (*dummyCopyContext) CopyInBytes(addr usermem.Addr, b []byte) (int, error) {
+func (*dummyCopyContext) CopyInBytes(addr hostarch.Addr, b []byte) (int, error) {
return len(b), nil
}
-func (t *dummyCopyContext) MarshalBytes(addr usermem.Addr, marshallable marshal.Marshallable) {
+func (t *dummyCopyContext) MarshalBytes(addr hostarch.Addr, marshallable marshal.Marshallable) {
buf := t.CopyScratchBuffer(marshallable.SizeBytes())
marshallable.MarshalBytes(buf)
t.CopyOutBytes(addr, buf)
}
-func (t *dummyCopyContext) MarshalUnsafe(addr usermem.Addr, marshallable marshal.Marshallable) {
+func (t *dummyCopyContext) MarshalUnsafe(addr hostarch.Addr, marshallable marshal.Marshallable) {
buf := t.CopyScratchBuffer(marshallable.SizeBytes())
marshallable.MarshalUnsafe(buf)
t.CopyOutBytes(addr, buf)
@@ -53,14 +53,14 @@ func (t *dummyCopyContext) MarshalUnsafe(addr usermem.Addr, marshallable marshal
//go:nosplit
func doCopyIn(t *dummyCopyContext) {
var stat test.Stat
- stat.CopyIn(t, usermem.Addr(0xf000ba12))
+ stat.CopyIn(t, hostarch.Addr(0xf000ba12))
}
// +checkescape:all
//go:nosplit
func doCopyOut(t *dummyCopyContext) {
var stat test.Stat
- stat.CopyOut(t, usermem.Addr(0xf000ba12))
+ stat.CopyOut(t, hostarch.Addr(0xf000ba12))
}
// +mustescape:builtin
@@ -70,7 +70,7 @@ func doMarshalBytesDirect(t *dummyCopyContext) {
var stat test.Stat
buf := t.CopyScratchBuffer(stat.SizeBytes())
stat.MarshalBytes(buf)
- t.CopyOutBytes(usermem.Addr(0xf000ba12), buf)
+ t.CopyOutBytes(hostarch.Addr(0xf000ba12), buf)
}
// +mustescape:builtin
@@ -80,7 +80,7 @@ func doMarshalUnsafeDirect(t *dummyCopyContext) {
var stat test.Stat
buf := t.CopyScratchBuffer(stat.SizeBytes())
stat.MarshalUnsafe(buf)
- t.CopyOutBytes(usermem.Addr(0xf000ba12), buf)
+ t.CopyOutBytes(hostarch.Addr(0xf000ba12), buf)
}
// +mustescape:local,heap
@@ -88,7 +88,7 @@ func doMarshalUnsafeDirect(t *dummyCopyContext) {
//go:nosplit
func doMarshalBytesViaMarshallable(t *dummyCopyContext) {
var stat test.Stat
- t.MarshalBytes(usermem.Addr(0xf000ba12), &stat)
+ t.MarshalBytes(hostarch.Addr(0xf000ba12), &stat)
}
// +mustescape:local,heap
@@ -96,5 +96,5 @@ func doMarshalBytesViaMarshallable(t *dummyCopyContext) {
//go:nosplit
func doMarshalUnsafeViaMarshallable(t *dummyCopyContext) {
var stat test.Stat
- t.MarshalUnsafe(usermem.Addr(0xf000ba12), &stat)
+ t.MarshalUnsafe(hostarch.Addr(0xf000ba12), &stat)
}
diff --git a/tools/go_marshal/test/marshal_test.go b/tools/go_marshal/test/marshal_test.go
index 733689c79..43bafbf96 100644
--- a/tools/go_marshal/test/marshal_test.go
+++ b/tools/go_marshal/test/marshal_test.go
@@ -27,6 +27,7 @@ import (
"unsafe"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/syserror"
@@ -47,7 +48,7 @@ func (t *mockCopyContext) populate(val interface{}) {
var buf bytes.Buffer
// Use binary.Write so we aren't testing go-marshal against its own
// potentially buggy implementation.
- if err := binary.Write(&buf, usermem.ByteOrder, val); err != nil {
+ if err := binary.Write(&buf, hostarch.ByteOrder, val); err != nil {
panic(err)
}
t.taskMem.Bytes = buf.Bytes()
@@ -71,14 +72,14 @@ func (t *mockCopyContext) CopyScratchBuffer(size int) []byte {
// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. The implementation
// completely ignores the target address and stores a copy of b in its
// internally buffer, overriding any previous contents.
-func (t *mockCopyContext) CopyOutBytes(_ usermem.Addr, b []byte) (int, error) {
+func (t *mockCopyContext) CopyOutBytes(_ hostarch.Addr, b []byte) (int, error) {
return t.taskMem.CopyOut(nil, 0, b, usermem.IOOpts{})
}
// CopyInBytes implements marshal.CopyContext.CopyInBytes. The implementation
// completely ignores the source address and always fills b from the begining of
// its internal buffer.
-func (t *mockCopyContext) CopyInBytes(_ usermem.Addr, b []byte) (int, error) {
+func (t *mockCopyContext) CopyInBytes(_ hostarch.Addr, b []byte) (int, error) {
return t.taskMem.CopyIn(nil, 0, b, usermem.IOOpts{})
}
@@ -91,7 +92,7 @@ func unsafeMemory(m marshal.Marshallable) []byte {
// since the layout isn't packed. Allocate a temporary buffer
// and marshal instead.
var buf bytes.Buffer
- if err := binary.Write(&buf, usermem.ByteOrder, m); err != nil {
+ if err := binary.Write(&buf, hostarch.ByteOrder, m); err != nil {
panic(err)
}
return buf.Bytes()
@@ -130,7 +131,7 @@ func unsafeMemorySlice(m interface{}, elt marshal.Marshallable) []byte {
// since the layout isn't packed. Allocate a temporary buffer
// and marshal instead.
var buf bytes.Buffer
- if err := binary.Write(&buf, usermem.ByteOrder, m); err != nil {
+ if err := binary.Write(&buf, hostarch.ByteOrder, m); err != nil {
panic(err)
}
return buf.Bytes()
@@ -176,7 +177,7 @@ func limitedCopyIn(t *testing.T, src, dst marshal.Marshallable, limit int) {
cc.populate(src)
cc.setLimit(limit)
- n, err := dst.CopyIn(&cc, usermem.Addr(0))
+ n, err := dst.CopyIn(&cc, hostarch.Addr(0))
if n != limit {
t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n)
}
@@ -206,7 +207,7 @@ func limitedCopyOut(t *testing.T, src marshal.Marshallable, limit int) {
var cc mockCopyContext
cc.setLimit(limit)
- n, err := src.CopyOut(&cc, usermem.Addr(0))
+ n, err := src.CopyOut(&cc, hostarch.Addr(0))
if n != limit {
t.Errorf("CopyOut copied unexpected number of bytes, expected %d, got %d", limit, n)
}
@@ -227,7 +228,7 @@ func copyOutN(t *testing.T, src marshal.Marshallable, limit int) {
var cc mockCopyContext
cc.setLimit(limit)
- n, err := src.CopyOutN(&cc, usermem.Addr(0), limit)
+ n, err := src.CopyOutN(&cc, hostarch.Addr(0), limit)
if err != nil {
t.Errorf("CopyOut returned unexpected error: %v", err)
}
@@ -304,18 +305,18 @@ func TestLimitedMarshalling(t *testing.T) {
func TestLimitedSliceMarshalling(t *testing.T) {
types := []struct {
arrayPtrType reflect.Type
- copySliceIn func(cc marshal.CopyContext, addr usermem.Addr, dstSlice interface{}) (int, error)
- copySliceOut func(cc marshal.CopyContext, addr usermem.Addr, srcSlice interface{}) (int, error)
+ copySliceIn func(cc marshal.CopyContext, addr hostarch.Addr, dstSlice interface{}) (int, error)
+ copySliceOut func(cc marshal.CopyContext, addr hostarch.Addr, srcSlice interface{}) (int, error)
unsafeMemory func(arrPtr interface{}) []byte
}{
// Packed types.
{
reflect.TypeOf((*[20]test.Stat)(nil)),
- func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr hostarch.Addr, dst interface{}) (int, error) {
slice := dst.(*[20]test.Stat)[:]
return test.CopyStatSliceIn(cc, addr, slice)
},
- func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr hostarch.Addr, src interface{}) (int, error) {
slice := src.(*[20]test.Stat)[:]
return test.CopyStatSliceOut(cc, addr, slice)
},
@@ -326,11 +327,11 @@ func TestLimitedSliceMarshalling(t *testing.T) {
},
{
reflect.TypeOf((*[1]test.Stat)(nil)),
- func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr hostarch.Addr, dst interface{}) (int, error) {
slice := dst.(*[1]test.Stat)[:]
return test.CopyStatSliceIn(cc, addr, slice)
},
- func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr hostarch.Addr, src interface{}) (int, error) {
slice := src.(*[1]test.Stat)[:]
return test.CopyStatSliceOut(cc, addr, slice)
},
@@ -341,11 +342,11 @@ func TestLimitedSliceMarshalling(t *testing.T) {
},
{
reflect.TypeOf((*[5]test.SignalSetAlias)(nil)),
- func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr hostarch.Addr, dst interface{}) (int, error) {
slice := dst.(*[5]test.SignalSetAlias)[:]
return test.CopySignalSetAliasSliceIn(cc, addr, slice)
},
- func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr hostarch.Addr, src interface{}) (int, error) {
slice := src.(*[5]test.SignalSetAlias)[:]
return test.CopySignalSetAliasSliceOut(cc, addr, slice)
},
@@ -357,11 +358,11 @@ func TestLimitedSliceMarshalling(t *testing.T) {
// Non-packed types.
{
reflect.TypeOf((*[20]test.Type1)(nil)),
- func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr hostarch.Addr, dst interface{}) (int, error) {
slice := dst.(*[20]test.Type1)[:]
return test.CopyType1SliceIn(cc, addr, slice)
},
- func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr hostarch.Addr, src interface{}) (int, error) {
slice := src.(*[20]test.Type1)[:]
return test.CopyType1SliceOut(cc, addr, slice)
},
@@ -372,11 +373,11 @@ func TestLimitedSliceMarshalling(t *testing.T) {
},
{
reflect.TypeOf((*[1]test.Type1)(nil)),
- func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr hostarch.Addr, dst interface{}) (int, error) {
slice := dst.(*[1]test.Type1)[:]
return test.CopyType1SliceIn(cc, addr, slice)
},
- func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr hostarch.Addr, src interface{}) (int, error) {
slice := src.(*[1]test.Type1)[:]
return test.CopyType1SliceOut(cc, addr, slice)
},
@@ -387,11 +388,11 @@ func TestLimitedSliceMarshalling(t *testing.T) {
},
{
reflect.TypeOf((*[7]test.Type8)(nil)),
- func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr hostarch.Addr, dst interface{}) (int, error) {
slice := dst.(*[7]test.Type8)[:]
return test.CopyType8SliceIn(cc, addr, slice)
},
- func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) {
+ func(cc marshal.CopyContext, addr hostarch.Addr, src interface{}) (int, error) {
slice := src.(*[7]test.Type8)[:]
return test.CopyType8SliceOut(cc, addr, slice)
},
@@ -444,7 +445,7 @@ func TestLimitedSliceMarshalling(t *testing.T) {
cc.populate(expected)
cc.setLimit(limit)
- n, err := tt.copySliceIn(&cc, usermem.Addr(0), actual)
+ n, err := tt.copySliceIn(&cc, hostarch.Addr(0), actual)
if n != limit {
t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n)
}
@@ -498,7 +499,7 @@ func TestLimitedSliceMarshalling(t *testing.T) {
cc.populate(expected)
cc.setLimit(limit)
- n, err := tt.copySliceOut(&cc, usermem.Addr(0), expected)
+ n, err := tt.copySliceOut(&cc, hostarch.Addr(0), expected)
if n != limit {
t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n)
}
@@ -523,14 +524,14 @@ func TestDynamicTypeStruct(t *testing.T) {
var cc mockCopyContext
cc.setLimit(t12.SizeBytes())
- if _, err := t12.CopyOut(&cc, usermem.Addr(0)); err != nil {
+ if _, err := t12.CopyOut(&cc, hostarch.Addr(0)); err != nil {
t.Fatalf("cc.CopyOut faile: %v", err)
}
res := test.Type12Dynamic{
Y: make([]primitive.Int64, len(t12.Y)),
}
- res.CopyIn(&cc, usermem.Addr(0))
+ res.CopyIn(&cc, hostarch.Addr(0))
if !reflect.DeepEqual(t12, res) {
t.Errorf("dynamic type is not same after marshalling and unmarshalling: before = %+v, after = %+v", t12, res)
}
@@ -541,12 +542,12 @@ func TestDynamicTypeIdentifier(t *testing.T) {
var cc mockCopyContext
cc.setLimit(s.SizeBytes())
- if _, err := s.CopyOut(&cc, usermem.Addr(0)); err != nil {
+ if _, err := s.CopyOut(&cc, hostarch.Addr(0)); err != nil {
t.Fatalf("cc.CopyOut faile: %v", err)
}
res := test.Type13Dynamic(make([]byte, len(s)))
- res.CopyIn(&cc, usermem.Addr(0))
+ res.CopyIn(&cc, hostarch.Addr(0))
if res != s {
t.Errorf("dynamic type is not same after marshalling and unmarshalling: before = %s, after = %s", s, res)
}
diff --git a/tools/nogo/analyzers.go b/tools/nogo/analyzers.go
index 8b4bff3b6..2b3c03fec 100644
--- a/tools/nogo/analyzers.go
+++ b/tools/nogo/analyzers.go
@@ -83,11 +83,6 @@ var AllAnalyzers = []*analysis.Analyzer{
checklocks.Analyzer,
}
-// EscapeAnalyzers is a list of escape-related analyzers.
-var EscapeAnalyzers = []*analysis.Analyzer{
- checkescape.EscapeAnalyzer,
-}
-
func register(all []*analysis.Analyzer) {
// Register all fact types.
//
@@ -129,5 +124,4 @@ func init() {
// Register lists.
register(AllAnalyzers)
- register(EscapeAnalyzers)
}
diff --git a/tools/nogo/check/main.go b/tools/nogo/check/main.go
index 69bdfe502..4194770be 100644
--- a/tools/nogo/check/main.go
+++ b/tools/nogo/check/main.go
@@ -31,7 +31,6 @@ var (
stdlibFile = flag.String("stdlib", "", "stdlib configuration file (in JSON format)")
findingsOutput = flag.String("findings", "", "output file (or stdout, if not specified)")
factsOutput = flag.String("facts", "", "output file for facts (optional)")
- escapesOutput = flag.String("escapes", "", "output file for escapes (optional)")
)
func loadConfig(file string, config interface{}) interface{} {
@@ -66,25 +65,13 @@ func main() {
// Run the configuration.
if *stdlibFile != "" {
- // Perform basic analysis.
+ // Perform stdlib analysis.
c := loadConfig(*stdlibFile, new(nogo.StdlibConfig)).(*nogo.StdlibConfig)
findings, factData, err = nogo.CheckStdlib(c, nogo.AllAnalyzers)
-
} else if *packageFile != "" {
- // Perform basic analysis.
+ // Perform standard analysis.
c := loadConfig(*packageFile, new(nogo.PackageConfig)).(*nogo.PackageConfig)
findings, factData, err = nogo.CheckPackage(c, nogo.AllAnalyzers, nil)
-
- // Do we need to do escape analysis?
- if *escapesOutput != "" {
- escapes, _, err := nogo.CheckPackage(c, nogo.EscapeAnalyzers, nil)
- if err != nil {
- log.Fatalf("error performing escape analysis: %v", err)
- }
- if err := nogo.WriteFindingsToFile(escapes, *escapesOutput); err != nil {
- log.Fatalf("error writing escapes to %q: %v", *escapesOutput, err)
- }
- }
} else {
log.Fatalf("please provide at least one of package or stdlib!")
}
diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl
index 0c48a7a5a..be8b82f9c 100644
--- a/tools/nogo/defs.bzl
+++ b/tools/nogo/defs.bzl
@@ -120,7 +120,7 @@ def _nogo_stdlib_impl(ctx):
Srcs = [f.path for f in go_ctx.stdlib_srcs],
GOOS = go_ctx.goos,
GOARCH = go_ctx.goarch,
- Tags = go_ctx.tags,
+ Tags = go_ctx.gotags,
)
config_file = ctx.actions.declare_file(ctx.label.name + ".cfg")
ctx.actions.write(config_file, config.to_json())
@@ -174,7 +174,6 @@ NogoInfo = provider(
fields = {
"facts": "serialized package facts",
"raw_findings": "raw package findings (if relevant)",
- "escapes": "escape-only findings (if relevant)",
"importpath": "package import path",
"binaries": "package binary files",
"srcs": "srcs (for go_test support)",
@@ -281,14 +280,13 @@ def _nogo_aspect_impl(target, ctx):
go_ctx = go_context(ctx, goos = nogo_target_info.goos, goarch = nogo_target_info.goarch)
facts = ctx.actions.declare_file(target.label.name + ".facts")
raw_findings = ctx.actions.declare_file(target.label.name + ".raw_findings")
- escapes = ctx.actions.declare_file(target.label.name + ".escapes")
config = struct(
ImportPath = importpath,
GoFiles = [src.path for src in srcs if src.path.endswith(".go")],
NonGoFiles = [src.path for src in srcs if not src.path.endswith(".go")],
GOOS = go_ctx.goos,
GOARCH = go_ctx.goarch,
- Tags = go_ctx.tags,
+ Tags = go_ctx.gotags,
FactMap = fact_map,
ImportMap = import_map,
StdlibFacts = stdlib_facts.path,
@@ -298,7 +296,7 @@ def _nogo_aspect_impl(target, ctx):
inputs.append(config_file)
ctx.actions.run(
inputs = inputs,
- outputs = [facts, raw_findings, escapes],
+ outputs = [facts, raw_findings],
tools = depset(go_ctx.runfiles.to_list() + ctx.files._nogo_objdump_tool),
executable = ctx.files._nogo_check[0],
mnemonic = "NogoAnalysis",
@@ -309,7 +307,6 @@ def _nogo_aspect_impl(target, ctx):
"-package=%s" % config_file.path,
"-findings=%s" % raw_findings.path,
"-facts=%s" % facts.path,
- "-escapes=%s" % escapes.path,
],
)
@@ -322,15 +319,16 @@ def _nogo_aspect_impl(target, ctx):
all_raw_findings = [stdlib_info.raw_findings] + depset(all_raw_findings).to_list() + [raw_findings]
# Return the package facts as output.
- return [NogoInfo(
- facts = facts,
- raw_findings = all_raw_findings,
- escapes = escapes,
- importpath = importpath,
- binaries = binaries,
- srcs = srcs,
- deps = deps,
- )]
+ return [
+ NogoInfo(
+ facts = facts,
+ raw_findings = all_raw_findings,
+ importpath = importpath,
+ binaries = binaries,
+ srcs = srcs,
+ deps = deps,
+ ),
+ ]
nogo_aspect = go_rule(
aspect,
@@ -367,7 +365,6 @@ def _nogo_test_impl(ctx):
if len(ctx.attr.deps) != 1:
fail("nogo_test requires exactly one dep.")
raw_findings = ctx.attr.deps[0][NogoInfo].raw_findings
- escapes = ctx.attr.deps[0][NogoInfo].escapes
# Build a step that applies the configuration.
config_srcs = ctx.attr.config[NogoConfigInfo].srcs
@@ -409,8 +406,6 @@ def _nogo_test_impl(ctx):
# pays attention to the mnemoic above, so this must be
# what is expected by the tooling.
nogo_findings = depset([findings]),
- # Expose all escape analysis findings (see above).
- nogo_escapes = depset([escapes]),
)]
nogo_test = rule(
@@ -432,3 +427,18 @@ nogo_test = rule(
},
test = True,
)
+
+def _nogo_aspect_tricorder_impl(target, ctx):
+ if ctx.rule.kind != "nogo_test" or OutputGroupInfo not in target:
+ return []
+ if not hasattr(target[OutputGroupInfo], "nogo_findings"):
+ return []
+ return [
+ OutputGroupInfo(tricorder = target[OutputGroupInfo].nogo_findings),
+ ]
+
+# Trivial aspect that forwards the findings from a nogo_test rule to
+# go/tricorder, which reads from the `tricorder` output group.
+nogo_aspect_tricorder = aspect(
+ implementation = _nogo_aspect_tricorder_impl,
+)
diff --git a/website/BUILD b/website/BUILD
index b5b3f6df6..6f52e9208 100644
--- a/website/BUILD
+++ b/website/BUILD
@@ -14,7 +14,7 @@ docker_image(
tags = [
"local",
"manual",
- "nosandbox",
+ "no-sandbox",
],
)
@@ -69,7 +69,7 @@ genrule(
tags = [
"local",
"manual",
- "nosandbox",
+ "no-sandbox",
],
)